In [None]:
import os
import re
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from torchvision import transforms
import import_ipynb           # makes Python aware of .ipynb modules
import unetDiceLoss_with_L1_modified               # references deepunet.ipynb
from unetDiceLoss_with_L1_modified import UNet

# 1) Configuration
# Focal-plane directories. Their order defines the order in which the
# per-plane images are stacked as model input channels, so it must match the
# order used when the network was trained (TODO confirm training order).
BASE_PATHS = [
    r"C:\Projects\Embryo\Dataset\embryo_dataset_F15",   # +15
    r"C:\Projects\Embryo\Dataset\embryo_dataset_F-15",  # -15
    r"C:\Projects\Embryo\Dataset\embryo_dataset_F30",   # +30
    r"C:\Projects\Embryo\Dataset\embryo_dataset_F-30",  # -30
    r"C:\Projects\Embryo\Dataset\embryo_dataset_F45",   # +45
    r"C:\Projects\Embryo\Dataset\embryo_dataset_F-45",  # -45
]

# Reference (F0) images, laid out as <GT_PATH>/<EMBRYO_ID>/<frame file>.
GT_PATH = r"C:\Projects\Embryo\Dataset\embryo_dataset"

# Checkpoint file; a relative path, so it is resolved against the CWD.
MODEL_PATH = "embryo_unet.pth"

# Embryo series to evaluate -- change to your series.
EMBRYO_ID = "AB91-1"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# 2) Preprocessing transforms
#   - for model input: resize the PIL image to 256x256, then turn it into a
#     float tensor of shape (C, H, W). Per the torchvision docs, ToTensor
#     scales 8-bit images (e.g. mode 'L') into [0, 1].
_resize_256 = transforms.Resize(size=(256, 256))
_pil_to_tensor = transforms.ToTensor()
tf_model = transforms.Compose([_resize_256, _pil_to_tensor])

# 3) Helper to extract frame number
_RUN_PATTERN = re.compile(r'RUN(\d+)')


def frame_num(fn):
    """Return the integer frame index encoded as ``RUN<digits>`` in *fn*.

    Only the first ``RUN<digits>`` occurrence is used; the match is
    case-sensitive. Returns ``None`` when the name contains no such token.
    """
    match = _RUN_PATTERN.search(fn)
    if match is None:
        return None
    return int(match.group(1))

# 4) Load model
# NOTE(review): the saved output of this cell is
# "NameError: name 'MODEL_PATH' is not defined", yet MODEL_PATH is assigned
# earlier in this same cell. The error therefore most likely comes from
# `import unetDiceLoss_with_L1_modified`, which (via import_ipynb) runs that
# notebook's top-level cells; one of them presumably uses its own undefined
# MODEL_PATH. Confirm, and guard or remove that code in the imported notebook.
model = UNet(in_channels=6, out_channels=1)  # one input channel per focal plane
model.to(DEVICE)  # nn.Module.to moves parameters in place
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()  # eval mode: affects dropout / batch-norm layers, if UNet has any

# 5) Gather all focal filenames for this embryo
# Frame names are listed from the first focal plane only; the loop below
# opens the identically named file in every other plane and in GT_PATH.
series_dir = os.path.join(BASE_PATHS[0], EMBRYO_ID)
# keep only names containing RUN\d+, in lexicographic order
fns = [fn for fn in sorted(os.listdir(series_dir)) if frame_num(fn) is not None]
if not fns:
    # Fail fast with a readable message. Otherwise the empty result only
    # surfaces later as an opaque unpacking error when frames are sorted.
    raise FileNotFoundError(
        f"No RUN<number> frames found in {series_dir!r} for embryo {EMBRYO_ID!r}"
    )

scores = []
frames = []

with torch.no_grad():
    for fn in fns:
        num = frame_num(fn)

        # Load the focal planes for this frame; channel order follows
        # BASE_PATHS. `with` releases each file handle promptly.
        channels = []
        for bp in BASE_PATHS:
            with Image.open(os.path.join(bp, EMBRYO_ID, fn)) as img:
                channels.append(tf_model(img.convert('L')))
        inp = torch.cat(channels, dim=0).unsqueeze(0).to(DEVICE)  # (1,6,H,W)

        # Forward pass, then drop the batch and channel axes -> (H, W) array.
        out = model(inp).squeeze(0).cpu()  # (1,H,W)
        # NOTE(review): assumes the network predicts intensities on the same
        # [0, 1] scale that ToTensor gives its inputs (e.g. a sigmoid head) --
        # confirm. Clipping keeps the fixed data_range below valid.
        fused = np.clip(out.numpy().squeeze(), 0.0, 1.0)

        # Load the GT F0 image through the *same* resize + ToTensor pipeline
        # as the inputs. GT and prediction then share size, interpolation and
        # the [0, 1] intensity scale. The raw uint8 array would be 0-255,
        # which made the SSIM comparison meaningless.
        with Image.open(os.path.join(GT_PATH, EMBRYO_ID, fn)) as gt_img:
            gt = tf_model(gt_img.convert('L')).numpy().squeeze()

        # SSIM over the fixed [0, 1] intensity range. Taking data_range from
        # the prediction's own min/max made every frame use a different
        # range, so scores were not comparable over time.
        score = ssim(gt, fused,
                     data_range=1.0,
                     win_size=11,
                     gaussian_weights=True)
        frames.append(num)
        scores.append(score)

# 6) Sort by frame number: order the (frame, score) pairs by frame index
# (ties, if any, fall back to score) and split them back into two tuples.
frames, scores = zip(*sorted(zip(frames, scores)))

# 7) Plot SSIM per frame, save a PNG next to the notebook, then display it.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(frames, scores, marker='o')
ax.set_xlabel("Frame Number (Time)")
ax.set_ylabel("SSIM")
ax.set_title(f"SSIM vs. Time for Embryo {EMBRYO_ID}")
ax.grid(True)
fig.tight_layout()
fig.savefig(f"ssim_vs_time_{EMBRYO_ID}.png", dpi=150)
plt.show()


NameError: name 'MODEL_PATH' is not defined