In [17]:
from pathlib import Path
import numpy as np
import tensorflow as tf
import json

# ── 0) Paths ────────────────────────────────────────────────────────────────
BASE_DATA_DIR       = Path(r"E:\ArtifactRemovalProject\data\modelinference")
SAUMYA_MODEL_PATH   = Path(r"E:\ArtifactRemovalProject\Saumya NNArtifact\saved_model\NNArtifact_tf2")
ENSEMBLE_CFG_PATH   = Path(r"E:\ArtifactRemovalProject\results\ensembles\config.json")
EXP_KEY             = "all_four"   # change this if you want another ensemble

# ── 1) Load ensemble config once ─────────────────────────────────────────────
with open(ENSEMBLE_CFG_PATH, "r") as f:
    ensemble_cfg = json.load(f)
info = ensemble_cfg[EXP_KEY]

# ── 2) Helper: z-score normalizer ────────────────────────────────────────────
def zscore_per_spectrum(x, eps=1e-6):
    mu  = x.mean(axis=1, keepdims=True)
    std = x.std(axis=1, keepdims=True) + eps
    return (x - mu) / std

# ── 3) Main loop over subjects/dates ────────────────────────────────────────
for subject_dir in BASE_DATA_DIR.iterdir():
    if not subject_dir.is_dir(): 
        continue

    for date_dir in subject_dir.iterdir():
        if not date_dir.is_dir(): 
            continue

        print(f"\n--- Processing {subject_dir.name} / {date_dir.name} ---")

        # ── 3a) Load your raw arrays ─────────────────────────────────────────
        si    = np.load(date_dir / "si.npy")       # shape (Z, X, Y, S)
        water = np.load(date_dir / "siref.npy")
        fit1  = np.load(date_dir / "midasfit.npy")
        fit2  = np.load(date_dir / "nnfit.npy")

        Z, X, Y, S = si.shape
        BATCH_SIZE = X * Y

        # ── 4) Inference with Saumya's model ────────────────────────────────
        saumya_model = tf.saved_model.load(SAUMYA_MODEL_PATH)
        num_classes  = 2
        ypred_saumya = np.zeros((Z, X, Y, num_classes), dtype=np.float32)

        for z in range(Z):
            xtest = np.real(si[z]).reshape(BATCH_SIZE, S)
            preds = saumya_model.predict(xtest)            # raw logits or probabilities
            probs = tf.nn.softmax(preds).numpy()           # ensure softmax
            ypred_saumya[z] = probs.reshape(X, Y, num_classes)

        out_path = date_dir / "ypred_saumya.npy"
        np.save(out_path, ypred_saumya)
        print(f" Saved Saumya predictions → {out_path}")

        # Clean up Saumya's model
        del saumya_model
        tf.keras.backend.clear_session()

        # ── 5) Inference with your ensemble model ───────────────────────────
        model = tf.keras.models.load_model(info["model_path"])
        print(f" Loaded ensemble '{EXP_KEY}'")

        # Flatten + preprocess each channel
        raw_flat   = si.reshape(-1, S)
        fit1_flat  = fit1.reshape(-1, S)
        fit2_flat  = fit2.reshape(-1, S)
        water_flat = water.reshape(-1, S)

        raw_norm   = zscore_per_spectrum(raw_flat)
        fit1_norm  = zscore_per_spectrum(fit1_flat)
        fit2_norm  = zscore_per_spectrum(fit2_flat)

        # log-normalize water
        wat_log = np.log10(np.abs(water_flat) + 1e-6)
        wmin    = wat_log.min(axis=1, keepdims=True)
        wmax    = wat_log.max(axis=1, keepdims=True) + 1e-6
        water_norm = (wat_log - wmin) / (wmax - wmin)

        # stack channels in configured order
        channel_arrays = {
            "raw":   raw_norm,
            "water": water_norm,
            "fit1":  fit1_norm,
            "fit2":  fit2_norm,
        }
        X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")

        # run inference
        probs = model.predict(X_input, batch_size=4096).ravel()
        thresh = info["threshold"]
        preds = (probs >= thresh).astype(np.int32) ^ 1

        # reshape back to volumes
        probs_vol = probs.reshape(Z, X, Y)
        preds_vol = preds.reshape(Z, X, Y)

        # save them
        np.save(date_dir / f"probs_vol_{EXP_KEY}.npy", probs_vol)
        np.save(date_dir / f"preds_vol_{EXP_KEY}.npy", preds_vol)
        print(f" Saved ensemble probs → {date_dir/'probs_vol_{EXP_KEY}.npy'}")
        print(f" Saved ensemble preds → {date_dir/'preds_vol_{EXP_KEY}.npy'}")

        # clean up ensemble model
        del model
        tf.keras.backend.clear_session()




--- Processing DOSEESC_EM09 / 11.20.2018 ---
 Saved Saumya predictions → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_EM09\11.20.2018\ypred_saumya.npy
 Loaded ensemble 'all_four'


  X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")


 Saved ensemble probs → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_EM09\11.20.2018\probs_vol_{EXP_KEY}.npy
 Saved ensemble preds → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_EM09\11.20.2018\preds_vol_{EXP_KEY}.npy

--- Processing DOSEESC_EM09 / 12.17.2018 ---
 Saved Saumya predictions → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_EM09\12.17.2018\ypred_saumya.npy
 Loaded ensemble 'all_four'


  X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")


 Saved ensemble probs → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_EM09\12.17.2018\probs_vol_{EXP_KEY}.npy
 Saved ensemble preds → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_EM09\12.17.2018\preds_vol_{EXP_KEY}.npy

--- Processing DOSEESC_JH07 / 08.27.2018 ---
 Saved Saumya predictions → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_JH07\08.27.2018\ypred_saumya.npy
 Loaded ensemble 'all_four'


  X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")


 Saved ensemble probs → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_JH07\08.27.2018\probs_vol_{EXP_KEY}.npy
 Saved ensemble preds → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_JH07\08.27.2018\preds_vol_{EXP_KEY}.npy

--- Processing DOSEESC_JH07 / 09.20.2018 ---
 Saved Saumya predictions → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_JH07\09.20.2018\ypred_saumya.npy
 Loaded ensemble 'all_four'


  X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")


 Saved ensemble probs → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_JH07\09.20.2018\probs_vol_{EXP_KEY}.npy
 Saved ensemble preds → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_JH07\09.20.2018\preds_vol_{EXP_KEY}.npy

--- Processing DOSEESC_UM08 / 06.13.2018 ---
 Saved Saumya predictions → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_UM08\06.13.2018\ypred_saumya.npy
 Loaded ensemble 'all_four'


  X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")


 Saved ensemble probs → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_UM08\06.13.2018\probs_vol_{EXP_KEY}.npy
 Saved ensemble preds → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_UM08\06.13.2018\preds_vol_{EXP_KEY}.npy

--- Processing DOSEESC_UM08 / 07.11.2018 ---
 Saved Saumya predictions → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_UM08\07.11.2018\ypred_saumya.npy
 Loaded ensemble 'all_four'


  X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")


 Saved ensemble probs → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_UM08\07.11.2018\probs_vol_{EXP_KEY}.npy
 Saved ensemble preds → E:\ArtifactRemovalProject\data\modelinference\DOSEESC_UM08\07.11.2018\preds_vol_{EXP_KEY}.npy

--- Processing rGBM_001_02_27_2023 / 02.27.2023 ---
 Saved Saumya predictions → E:\ArtifactRemovalProject\data\modelinference\rGBM_001_02_27_2023\02.27.2023\ypred_saumya.npy
 Loaded ensemble 'all_four'


  X_input = np.stack([channel_arrays[ch] for ch in info["channels"]], axis=-1).astype("float32")


 Saved ensemble probs → E:\ArtifactRemovalProject\data\modelinference\rGBM_001_02_27_2023\02.27.2023\probs_vol_{EXP_KEY}.npy
 Saved ensemble preds → E:\ArtifactRemovalProject\data\modelinference\rGBM_001_02_27_2023\02.27.2023\preds_vol_{EXP_KEY}.npy
