This notebook was executed in the Kaggle environment:  
it reads datasets from `/kaggle/input` and saves all outputs to `/kaggle/working`.

### copy checkpoints from input

In [None]:
# === Locate checkpoints inside /kaggle/input and copy to /kaggle/working/checkpoints ===
import glob, os, shutil
from pathlib import Path

INPUT_ROOT = Path("/kaggle/input")
CKPT_ROOT  = Path("/kaggle/working/checkpoints")
CKPT_ROOT.mkdir(parents=True, exist_ok=True)

# 1) First, search for latest_net_G.pth; if not found, take the most recent *_net_G.pth
candidates = glob.glob(str(INPUT_ROOT / "**/latest_net_G.pth"), recursive=True)
if not candidates:
    candidates = glob.glob(str(INPUT_ROOT / "**/*_net_G.pth"), recursive=True)

assert candidates, "No model files found (latest_net_G.pth or *_net_G.pth) inside the input dataset."

# 2) Select the most recently modified file
candidates = sorted([Path(p) for p in candidates], key=lambda p: p.stat().st_mtime, reverse=True)
best_G = candidates[0]
exp_dir_in_input = best_G.parent           # experiment directory inside Input
experiment_name  = exp_dir_in_input.name   # folder name (e.g. wafer_pix2pix_AtoB_256_out1)
dst_exp_dir      = CKPT_ROOT / experiment_name

print("Found model:", best_G)
print("Experiment directory:", exp_dir_in_input)
print("Destination:", dst_exp_dir)

# 3) Copy the entire experiment folder (files and subfolders)
if dst_exp_dir.exists():
    shutil.rmtree(dst_exp_dir)
shutil.copytree(exp_dir_in_input, dst_exp_dir)

# 4) Sanity check
assert (dst_exp_dir/"latest_net_G.pth").exists() or any(dst_exp_dir.glob("*_net_G.pth")), \
       "No G.pth files were copied to the destination; make sure the dataset contains them."

print("✅ Checkpoints successfully copied to:", dst_exp_dir)
!ls -lh {dst_exp_dir} | sed -n '1,200p'


In [None]:
# Install the `dominate` package into the notebook environment.
# NOTE(review): presumably needed by the pix2pix repo's test.py run below — confirm.
!pip install dominate

### run inference val sbs

In [None]:
# === Inference on VAL (SBS) — matching the training config ===
import sys, subprocess, shlex
from pathlib import Path

REPO       = Path("/kaggle/working/pix2pix")
DATAROOT   = Path("/kaggle/input/processed-images")        # train/val/test in SBS format
CKPT_ROOT  = Path("/kaggle/working/checkpoints")

# Select the most recently modified experiment folder
exp_dirs = sorted([p for p in CKPT_ROOT.iterdir() if p.is_dir()],
                  key=lambda p: p.stat().st_mtime, reverse=True)
assert exp_dirs, "No experiment folder found in /kaggle/working/checkpoints"
EXPERIMENT = exp_dirs[0].name
print("Running with experiment:", EXPERIMENT)

# Same parameters as in training
INPUT_NC  = 3
OUTPUT_NC = 1            # set to 3 if your SEM output is RGB
DIRECTION = "AtoB"
LOAD_SIZE = 286
CROP_SIZE = 256

# Restore the repo if missing
if not REPO.exists():
    !git clone --depth 1 https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git /kaggle/working/pix2pix

# Ensure G exists
assert (CKPT_ROOT/EXPERIMENT/"latest_net_G.pth").exists(), "Missing latest_net_G.pth"

RESULTS = Path("/kaggle/working/inference_val_sbs")
RESULTS.mkdir(parents=True, exist_ok=True)

cmd = [
    sys.executable, str(REPO/"test.py"),
    "--model", "pix2pix",             # <<< Important: not 'test'!
    "--netG", "unet_256",             # <<< must match training
    "--norm", "batch",                # <<< same as training
    "--dataroot", str(DATAROOT),
    "--phase", "val",
    "--dataset_mode", "aligned",
    "--direction", DIRECTION,
    "--name", EXPERIMENT,
    "--checkpoints_dir", str(CKPT_ROOT),
    "--preprocess", "resize_and_crop",
    "--load_size", str(LOAD_SIZE),
    "--crop_size", str(CROP_SIZE),
    "--input_nc", str(INPUT_NC),
    "--output_nc", str(OUTPUT_NC),
    "--serial_batches",
    "--num_test", "100000",
    "--results_dir", str(RESULTS),
    "--epoch", "latest",              # load the 'latest' checkpoint
    "--eval",                         # disables dropout/bn training mode
]

print(">>> TEST CMD:\n", " ".join(shlex.quote(c) for c in cmd))
ret = subprocess.run(cmd, check=False)
print("test.py exited with code:", ret.returncode)

print("\nResults in:", RESULTS)
!find /kaggle/working/inference_val_sbs -maxdepth 3 -type f | sed -n '1,120p'


### single image inference

In [None]:
# === Single-image inference using TestModel (A only) ===
from pathlib import Path
from PIL import Image
import sys, subprocess, shlex, glob

REPO       = Path("/kaggle/working/pix2pix")
CKPT_ROOT  = Path("/kaggle/working/checkpoints")
EXPERIMENT = "wafer_pix2pix_AtoB_256_out1"

INPUT_IMG  = Path("/kaggle/input/image10/image10.png")
TEMP_DIR   = Path("/kaggle/working/single_infer")
RESULTS    = Path("/kaggle/working/inference_single")

# Convert to PNG and save into a temporary dataroot folder
TEMP_DIR.mkdir(parents=True, exist_ok=True)
RESULTS.mkdir(parents=True, exist_ok=True)
Image.open(INPUT_IMG).convert("RGB").save(TEMP_DIR / "test.png")
print("Input saved:", TEMP_DIR/"test.png")

# Clone repo if it doesn't exist in the current notebook
if not REPO.exists():
    !git clone --depth 1 https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git /kaggle/working/pix2pix

# Run with --model test (not pix2pix), and same architecture arguments as training
cmd = [
    sys.executable, str(REPO/"test.py"),
    "--model", "test",           # <<< Important: 'test' model doesn't require B
    "--netG", "unet_256",        # must match training
    "--norm", "batch",           # same as in training
    "--dataroot", str(TEMP_DIR),
    "--dataset_mode", "single",  # A only
    "--direction", "AtoB",
    "--name", EXPERIMENT,
    "--checkpoints_dir", str(CKPT_ROOT),
    "--preprocess", "resize_and_crop",
    "--load_size", "286",
    "--crop_size", "256",
    "--input_nc", "3",
    "--output_nc", "1",          # set to 3 if your SEM output is RGB
    "--results_dir", str(RESULTS),
    "--epoch", "latest",
    "--num_test", "1",
    "--eval",
]
print(">>>", " ".join(shlex.quote(c) for c in cmd))
ret = subprocess.run(cmd, check=False)
print("Exit code:", ret.returncode)

# Locate and display output images
out_dir = RESULTS/EXPERIMENT/"test_latest"/"images"
print("Results located in:", out_dir)
!ls -lh $out_dir | sed -n '1,200p'


##### display single inference

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path

# Images folder written by the single-image inference cell
out_dir = Path("/kaggle/working/inference_single/wafer_pix2pix_AtoB_256_out1/test_latest/images")

imgA = Image.open(out_dir / "test_real.png")
imgB = Image.open(out_dir / "test_fake.png")

# One row, two panels: (image, title, extra imshow options)
panels = [
    (imgA, "Input (Segmentation)", {}),
    (imgB, "Generated SEM", {"cmap": "gray"}),
]

fig, axes = plt.subplots(1, len(panels), figsize=(10, 5))
for ax, (panel_img, panel_title, imshow_kw) in zip(axes, panels):
    ax.imshow(panel_img, **imshow_kw)
    ax.set_title(panel_title)
    ax.axis("off")

plt.show()


### run inference test sbs

In [None]:
# === Inference on TEST (SBS) — matching training config ===
import sys, subprocess, shlex
from pathlib import Path

REPO       = Path("/kaggle/working/pix2pix")
DATAROOT   = Path("/kaggle/input/processed-images")  # train/val/test in SBS format
CKPT_ROOT  = Path("/kaggle/working/checkpoints")

# Select the most recently modified experiment folder
exp_dirs = sorted([p for p in CKPT_ROOT.iterdir() if p.is_dir()],
                  key=lambda p: p.stat().st_mtime, reverse=True)
assert exp_dirs, "No experiment folder found in /kaggle/working/checkpoints"
EXPERIMENT = exp_dirs[0].name
print("Running with experiment:", EXPERIMENT)

# Parameters (same as in training)
INPUT_NC  = 3
OUTPUT_NC = 1       # set to 3 if your SEM is RGB
DIRECTION = "AtoB"
LOAD_SIZE = 286
CROP_SIZE = 256

# Clone repo if needed
if not REPO.exists():
    !git clone --depth 1 https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git /kaggle/working/pix2pix

# Ensure generator exists
assert (CKPT_ROOT/EXPERIMENT/"latest_net_G.pth").exists(), "Missing latest_net_G.pth"

RESULTS = Path("/kaggle/working/inference_test_sbs")
RESULTS.mkdir(parents=True, exist_ok=True)

cmd = [
    sys.executable, str(REPO/"test.py"),
    "--model", "pix2pix",
    "--netG", "unet_256",
    "--norm", "batch",
    "--dataroot", str(DATAROOT),
    "--phase", "test",              # <<< now using test phase
    "--dataset_mode", "aligned",
    "--direction", DIRECTION,
    "--name", EXPERIMENT,
    "--checkpoints_dir", str(CKPT_ROOT),
    "--preprocess", "resize_and_crop",
    "--load_size", str(LOAD_SIZE),
    "--crop_size", str(CROP_SIZE),
    "--input_nc", str(INPUT_NC),
    "--output_nc", str(OUTPUT_NC),
    "--serial_batches",
    "--num_test", "100000",
    "--results_dir", str(RESULTS),
    "--epoch", "latest",
    "--eval",
]

print(">>> TEST CMD:\n", " ".join(shlex.quote(c) for c in cmd))
ret = subprocess.run(cmd, check=False)
print("test.py exited with code:", ret.returncode)

print("\nResults saved in:", RESULTS)
!find /kaggle/working/inference_test_sbs -maxdepth 3 -type f | sed -n '1,80p'


### evaluate pix2pix metrics

In [None]:
# ===========================
# Evaluation for Pix2Pix test_latest (PSNR, SSIM, LPIPS, FID)
# ===========================
# Imports
!pip -q install lpips torch-fidelity >/dev/null

import os, re, shutil, glob, math, json, pathlib
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import lpips
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from torch_fidelity import calculate_metrics
import pandas as pd

# ---- Path configuration (adjust as needed)
# SBS test split; not required for metric calculations
TEST_SBS_DIR = Path("/kaggle/input/processed-images/test")
# Images folder read by the pair search below (*_fake_B.* / *_real_B.*)
RESULTS_DIR = Path("/kaggle/working/inference_test_sbs/wafer_pix2pix_AtoB_256_out1/test_latest/images")
# Destination for every evaluation artefact written by this cell
OUT_DIR = Path("/kaggle/working/eval_pix2pix_test")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ---- Helper functions
def load_image(path):
    """Open the image at `path` and return it as a numpy array in RGB mode.

    The image is converted to RGB even if the source is grayscale.
    """
    return np.array(Image.open(path).convert('RGB'))

def to_torch_lpips(arr_rgb_uint8):
    """Turn an HWC uint8 image array into the tensor layout passed to LPIPS.

    The result is a float32 tensor shaped 1xCxHxW whose values are rescaled
    from [0, 255] to [-1, 1].
    """
    unit_hwc = arr_rgb_uint8.astype(np.float32) / 255.0     # HWC, [0,1]
    chw = torch.from_numpy(unit_hwc).permute(2, 0, 1)       # CxHxW
    batched = chw[None]                                     # 1xCxHxW
    return batched * 2 - 1                                  # [-1,1]

# ---- Find fake_B / real_B pairs based on filenames
# Expected naming in the results folder: xxx_fake_B.png next to xxx_real_B.png
fake_paths = sorted(RESULTS_DIR.glob("*_fake_B.*"))
pairs = []
for fake_file in fake_paths:
    # Strip the "_fake_B.<ext>" tail to get the shared prefix
    prefix = re.sub(r"_fake_B\.[^.]+$", "", fake_file.name)
    # First matching ground-truth file, if any; fakes without one are skipped
    real_file = next(iter(RESULTS_DIR.glob(prefix + "_real_B.*")), None)
    if real_file is None:
        continue
    pairs.append((fake_file, real_file))

assert len(pairs) > 0, f"No Fake/Real pairs found in: {RESULTS_DIR}"

# ---- Initialize LPIPS
# AlexNet-backed LPIPS in eval mode. Neither the model nor the inputs are moved to a
# GPU in this cell, so the distance is computed on the CPU.
lpips_model = lpips.LPIPS(net='alex').eval()

# ---- Compute metrics for each image pair
rows = []
for fp, rp in pairs:
    fake = load_image(fp)
    real = load_image(rp)
    # Ensure same dimensions
    if fake.shape != real.shape:
        # If there's a minor mismatch, resize to GT size
        real_h, real_w = real.shape[:2]
        fake = np.array(Image.fromarray(fake).resize((real_w, real_h), Image.BICUBIC))

    # PSNR / SSIM (computed on RGB; for grayscale, convert both to gray)
    cur_psnr = psnr(real, fake, data_range=255)
    cur_ssim = ssim(real, fake, channel_axis=2, data_range=255)

    # LPIPS
    with torch.no_grad():
        t_fake = to_torch_lpips(fake)
        t_real = to_torch_lpips(real)
        # .item() extracts the single value directly; float() on the multi-dimensional
        # numpy array is a deprecated conversion in recent NumPy releases.
        cur_lpips = lpips_model(t_fake, t_real).item()

    rows.append({
        "name": fp.stem.replace("_fake_B",""),
        "psnr": cur_psnr,
        "ssim": cur_ssim,
        "lpips": cur_lpips
    })

# One row per image pair, ordered by name, persisted for later inspection
df = pd.DataFrame(rows).sort_values("name")
df.to_csv(OUT_DIR/"metrics_per_image.csv", index=False)

# ---- Summary (mean and std)
summary = {"N": len(df)}
for metric_col in ("psnr", "ssim", "lpips"):
    label = metric_col.upper()
    summary[f"{label}_mean"] = df[metric_col].mean()
    summary[f"{label}_std"] = df[metric_col].std()

with open(OUT_DIR/"metrics_summary.json","w") as f:
    json.dump(summary, f, indent=2, ensure_ascii=False)

print("Summary:", summary)

# ---- Plots
# One histogram per metric: (column, title, x-axis label)
hist_specs = [
    ("psnr",  "PSNR distribution",  "PSNR (dB)"),
    ("ssim",  "SSIM distribution",  "SSIM"),
    ("lpips", "LPIPS distribution", "LPIPS (lower is better)"),
]
for metric_col, hist_title, x_label in hist_specs:
    plt.figure()
    df[metric_col].hist(bins=40)
    plt.title(hist_title)
    plt.xlabel(x_label)
    plt.ylabel("count")
    plt.show()

# ---- FID: requires two folders (real and fake)
# Both folders are rebuilt from scratch so images from an earlier run cannot leak in.
fid_real_dir = OUT_DIR/"fid_real"; fid_fake_dir = OUT_DIR/"fid_fake"
for d in [fid_real_dir, fid_fake_dir]:
    if d.exists(): shutil.rmtree(d)
    d.mkdir(parents=True)

# Copy each pair under a shared base name (the "_fake_B" tag is removed)
for fp, rp in pairs:
    pair_key = Path(fp).stem.replace("_fake_B","")
    shutil.copy(rp, fid_real_dir/(pair_key + rp.suffix))
    shutil.copy(fp, fid_fake_dir/(pair_key + fp.suffix))

# KID is estimated on subsets of the images. The library's default subset size (1000) must
# not exceed the number of available images, otherwise calculate_metrics raises; cap it at
# the number of pairs so small test sets are evaluated too.
kid_subset = min(1000, len(pairs))
metrics = calculate_metrics(input1=fid_fake_dir.as_posix(),
                            input2=fid_real_dir.as_posix(),
                            cuda=torch.cuda.is_available(),
                            isc=False, kid=True, fid=True,
                            kid_subset_size=kid_subset,
                            verbose=False)
with open(OUT_DIR/"fid_kid.json","w") as f:
    json.dump(metrics, f, indent=2)

print("FID/KID:", metrics)
print(f"\nOutput files:\n- {OUT_DIR/'metrics_per_image.csv'}\n- {OUT_DIR/'metrics_summary.json'}\n- {OUT_DIR/'fid_kid.json'}")


#### evaluate pix2pix metrics 300

In [None]:
# ===========================
# Evaluation for Pix2Pix test_latest (PSNR, SSIM, LPIPS, FID, KID on 300)
# ===========================
!pip -q install lpips torch-fidelity >/dev/null

import os, re, shutil, glob, math, json, pathlib
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import lpips
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from torch_fidelity import calculate_metrics
import pandas as pd

# ---- Path configuration
# SBS test split; not required for metric evaluation
TEST_SBS_DIR = Path("/kaggle/input/processed-images/test")
# Images folder searched for *_fake_B.* / *_real_B.* files
RESULTS_DIR  = Path("/kaggle/working/inference_test_sbs/wafer_pix2pix_AtoB_256_out1/test_latest/images")
# Folder receiving the CSV / JSON outputs and the FID image folders
OUT_DIR      = Path("/kaggle/working/eval_pix2pix_test")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ---- Helper functions
def load_image(path):
    """Load `path` with PIL, force RGB mode, and hand back a numpy array."""
    rgb_image = Image.open(path).convert('RGB')
    pixels = np.array(rgb_image)
    return pixels

def to_torch_lpips(arr_rgb_uint8):
    """Convert an HWC uint8 image array to a 1xCxHxW float32 tensor in [-1, 1]."""
    scaled = arr_rgb_uint8.astype(np.float32) / 255.0        # HWC, [0,1]
    as_nchw = torch.from_numpy(scaled).permute(2, 0, 1).unsqueeze(0)   # 1xCxHxW
    return as_nchw * 2 - 1                                   # [-1,1]

# ---- Find fake_B / real_B pairs
fake_paths = sorted(RESULTS_DIR.glob("*_fake_B.*"))

def _real_for(fake_file):
    """Return the first *_real_B.* file sharing the fake file's prefix, or None."""
    prefix = re.sub(r"_fake_B\.[^.]+$", "", fake_file.name)
    hits = list(RESULTS_DIR.glob(prefix + "_real_B.*"))
    return hits[0] if hits else None

# Keep only the fakes that have a ground-truth partner
pairs = [
    (fake_file, real_file)
    for fake_file in fake_paths
    for real_file in [_real_for(fake_file)]
    if real_file is not None
]

assert len(pairs) > 0, f"No Fake/Real pairs found in: {RESULTS_DIR}"
print(f"Found {len(pairs)} pairs for evaluation.")

# ---- LPIPS
lpips_model = lpips.LPIPS(net='alex').eval()

def _score_pair(fake_path, real_path):
    """Compute PSNR, SSIM and LPIPS for one (fake, real) file pair; returns one table row."""
    fake_arr = load_image(fake_path)
    real_arr = load_image(real_path)

    # Bring the fake image to the ground-truth size when the shapes disagree
    if fake_arr.shape != real_arr.shape:
        target_h, target_w = real_arr.shape[:2]
        resized = Image.fromarray(fake_arr).resize((target_w, target_h), Image.BICUBIC)
        fake_arr = np.array(resized)

    psnr_value = psnr(real_arr, fake_arr, data_range=255)
    ssim_value = ssim(real_arr, fake_arr, channel_axis=2, data_range=255)

    with torch.no_grad():
        distance = lpips_model(to_torch_lpips(fake_arr), to_torch_lpips(real_arr))
        lpips_value = distance.item()  # fixed deprecation warning

    return {
        "name": fake_path.stem.replace("_fake_B",""),
        "psnr": psnr_value,
        "ssim": ssim_value,
        "lpips": lpips_value
    }

# ---- Compute metrics for each image
rows = [_score_pair(fp, rp) for fp, rp in pairs]

df = pd.DataFrame(rows).sort_values("name")
df.to_csv(OUT_DIR/"metrics_per_image.csv", index=False)

# Mean and standard deviation of every metric, plus the number of evaluated images
metric_series = {"PSNR": df.psnr, "SSIM": df.ssim, "LPIPS": df.lpips}
summary = {"N": len(df)}
for label, series in metric_series.items():
    summary[label + "_mean"] = series.mean()
    summary[label + "_std"] = series.std()

with open(OUT_DIR/"metrics_summary.json","w") as f:
    json.dump(summary, f, indent=2, ensure_ascii=False)

print("Summary:", summary)

# ---- Plots
plt.figure()
df.psnr.hist(bins=40)
plt.title("PSNR distribution")
plt.xlabel("PSNR (dB)")
plt.ylabel("count")
plt.show()

plt.figure()
df.ssim.hist(bins=40)
plt.title("SSIM distribution")
plt.xlabel("SSIM")
plt.ylabel("count")
plt.show()

plt.figure()
df.lpips.hist(bins=40)
plt.title("LPIPS distribution")
plt.xlabel("LPIPS (lower is better)")
plt.ylabel("count")
plt.show()

# ---- FID/KID: build real/fake folders
fid_real_dir = OUT_DIR/"fid_real"
fid_fake_dir = OUT_DIR/"fid_fake"
for folder in (fid_real_dir, fid_fake_dir):
    # start from an empty folder every time
    if folder.exists():
        shutil.rmtree(folder)
    folder.mkdir(parents=True)

for fp, rp in pairs:
    shared_stem = Path(fp).stem.replace("_fake_B","")
    shutil.copy(rp, fid_real_dir/(shared_stem + rp.suffix))
    shutil.copy(fp, fid_fake_dir/(shared_stem + fp.suffix))

# --- Compute FID + KID where KID subset <= number of examples (here 300)
kid_subset = min(300, len(pairs))  # if fewer than 300, automatically adjusts
fidelity_kwargs = dict(
    input1=fid_fake_dir.as_posix(),
    input2=fid_real_dir.as_posix(),
    cuda=torch.cuda.is_available(),
    isc=False,
    fid=True,
    kid=True,
    kid_subset_size=kid_subset,
    kid_subset_retries=10,   # improves statistical stability
    verbose=False,
)
metrics = calculate_metrics(**fidelity_kwargs)

with open(OUT_DIR/"fid_kid.json","w") as f:
    json.dump(metrics, f, indent=2)

print("FID/KID:", metrics)
print(f"\nOutput files:\n- {OUT_DIR/'metrics_per_image.csv'}\n- {OUT_DIR/'metrics_summary.json'}\n- {OUT_DIR/'fid_kid.json'}")


### show worst three panels

In [None]:
# ===========================
# Show worst-3 per metric (PSNR↓, SSIM↓, LPIPS↑)
# Saves panels to OUT_DIR and displays them inline
# ===========================
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import re, itertools

def load_img(path):
    """Read the image file at `path` as an RGB numpy array."""
    as_rgb = Image.open(path).convert('RGB')
    return np.array(as_rgb)

def try_find(path_pattern_list):
    """Return the first filesystem match for the first pattern that matches anything.

    Each entry is a path whose final component is used as a glob pattern inside
    its parent directory. Patterns are tried in order; None is returned when no
    pattern has a match.
    """
    for pattern_path in path_pattern_list:
        first_hit = next(iter(pattern_path.parent.glob(pattern_path.name)), None)
        if first_hit is not None:
            return first_hit
    return None

# Build quick index: name -> {fake, realB, realA?}
index = {}
for fake_file, real_file in pairs:
    key = Path(fake_file).stem.replace("_fake_B","")
    entry = index.setdefault(key, {})
    entry.update({"fake": fake_file, "realB": real_file})
    # the input image (real_A) is optional: record it only when it sits next to the pair
    prefix = re.sub(r"_fake_B\.[^.]+$", "", fake_file.name)
    input_hits = list(RESULTS_DIR.glob(prefix + "_real_A.*"))
    if len(input_hits) > 0:
        entry["realA"] = input_hits[0]

def safe_resize_like(img_src, img_ref):
    """Return `img_src` resized (bicubic) to the height/width of `img_ref`.

    When the first two dimensions already agree, `img_src` is returned unchanged.
    """
    ref_h, ref_w = img_ref.shape[:2]
    if img_src.shape[:2] == (ref_h, ref_w):
        return img_src
    resized = Image.fromarray(img_src).resize((ref_w, ref_h), Image.BICUBIC)
    return np.array(resized)

def show_worst(metric_name, worst_df, save_name):
    """Plot one row per entry of `worst_df`: Input A | Fake B | Real B | |Fake-Real|.

    Parameters
    ----------
    metric_name : str
        "PSNR", "SSIM" or "LPIPS" (case-insensitive). Used in the figure title and to
        pick the column of `worst_df` printed in each row label.
    worst_df : pandas.DataFrame
        Rows to display. Needs a "name" column (key into the module-level `index`)
        and the column of the selected metric.
    save_name : str
        File name of the saved figure inside the module-level `OUT_DIR`.

    The figure is saved, shown inline and its path is printed; nothing is returned.
    An empty `worst_df` is reported and skipped instead of raising.
    """
    rows = len(worst_df)
    if rows == 0:
        # plt.subplots(0, cols) raises ValueError; there is nothing to draw anyway.
        print(f"No images to show for {metric_name}; skipping {save_name}")
        return

    cols = 4  # InputA | FakeB | RealB | |Diff|
    # squeeze=False keeps `axes` two-dimensional even when there is a single row
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 3.5*rows), squeeze=False)

    titles = ["Input A", "Fake B", "Real B", "|Fake-Real|"]
    # How each metric value is printed in the row label
    value_formats = {"lpips": "{:.4f}", "psnr": "{:.2f} dB", "ssim": "{:.4f}"}
    metric_key = metric_name.lower()

    for r, (_, row) in enumerate(worst_df.iterrows()):
        name = row["name"]
        rec = index[name]
        img_fake = load_img(rec["fake"])
        img_real = load_img(rec["realB"])
        img_fake = safe_resize_like(img_fake, img_real)

        # Input A (optional)
        if "realA" in rec:
            img_A = safe_resize_like(load_img(rec["realA"]), img_real)
        else:
            img_A = None

        # |Diff| as grayscale heat: mean absolute difference over the colour channels
        diff = np.mean(np.abs(img_fake.astype(np.float32) - img_real.astype(np.float32)), axis=2)
        # normalize for display (left as-is when the images are identical)
        diff_disp = diff / diff.max() if diff.max() > 0 else diff

        imgs = [img_A, img_fake, img_real, diff_disp]
        for c in range(cols):
            ax = axes[r, c]
            ax.axis("off")
            if c == 0 and imgs[c] is None:
                ax.set_title(f"{titles[c]} (missing)", fontsize=11)
                continue
            if c == 3:
                ax.imshow(imgs[c], cmap="inferno")
            else:
                ax.imshow(imgs[c])
            if r == 0:
                ax.set_title(titles[c], fontsize=12)

        # row label with metric value
        val_str = f"{metric_name}="
        if metric_key in value_formats:
            val_str += value_formats[metric_key].format(row[metric_key])
        fig.text(0.01, 1 - (r+0.5)/rows, f"{r+1}. {name}  |  {val_str}", va="center", fontsize=11)

    # Report the real number of rows rather than a fixed "3"
    fig.suptitle(f"Worst {rows} by {metric_name}", fontsize=14, y=0.995)
    plt.tight_layout(rect=[0,0,1,0.97])
    out_path = OUT_DIR / save_name
    plt.savefig(out_path, dpi=130)
    plt.show()
    print(f"Saved: {out_path}")

# pick worst-3 per metric (lowest SSIM / PSNR, highest LPIPS)
worst_ssim  = df.nsmallest(3, "ssim")
worst_psnr  = df.nsmallest(3, "psnr")
worst_lpips = df.nlargest(3, "lpips")

# (label, selected rows, metric column) in display order
worst_report = [
    ("SSIM",  worst_ssim,  "ssim"),
    ("PSNR",  worst_psnr,  "psnr"),
    ("LPIPS", worst_lpips, "lpips"),
]

for position, (label, frame, column) in enumerate(worst_report):
    # a blank line separates every table from the previous one
    lead = "\n" if position else ""
    print(f"{lead}Worst by {label}:")
    print(frame[["name", column]])

# show/save panels
for label, frame, column in worst_report:
    show_worst(label, frame, f"worst3_{column}.png")


### batch inference and grids

In [None]:
# === Batch inference on folder (A-only) + side-by-side grids ===
# Input:  /kaggle/input/preds-yael  (a folder of single images, not SBS)
# Output: /kaggle/working/inference_preds_yael/<EXP>/test_latest/images
# Grids:  /kaggle/working/preds_yael_grids
# ZIPs for download: /kaggle/working/preds_yael_results.zip , preds_yael_grids.zip
!pip -q install dominate

import os, glob, shutil, subprocess, shlex
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

# ---------- Settings ----------
INPUT_DS   = Path("/kaggle/input/preds-yael")  # ← your folder with input images
REPO       = Path("/kaggle/working/pix2pix")

# Detect where checkpoints are: the working copy wins, otherwise use the input dataset
CKPT_WORK  = Path("/kaggle/working/checkpoints")
CKPT_INPUT = Path("/kaggle/input/wafer-pix2pix-checkpoints")
CKPT_ROOT  = CKPT_WORK if CKPT_WORK.exists() else CKPT_INPUT
assert CKPT_ROOT.exists(), f"Checkpoint directory not found: {CKPT_ROOT}"

# Find the experiment folder (the one containing latest_net_G.pth)
def find_experiment_dir(root: Path):
    """Return the most recently modified folder under `root` holding latest_net_G.pth.

    Direct children of `root` are examined first; grandchildren are only
    considered when no direct child qualifies. An AssertionError is raised
    when neither level contains such a folder.
    """
    def _holds_generator(folder):
        return folder.is_dir() and (folder / "latest_net_G.pth").exists()

    cands = []
    # search up to depth 2, stopping at the first level that has a hit
    for level_pattern in ("*", "*/*"):
        cands = [folder for folder in root.glob(level_pattern) if _holds_generator(folder)]
        if cands:
            break
    assert cands, f"latest_net_G.pth not found under {root}"
    # pick the most recently modified
    return max(cands, key=lambda folder: folder.stat().st_mtime)

EXP_DIR    = find_experiment_dir(CKPT_ROOT)
EXPERIMENT = EXP_DIR.name
print("Selected experiment:", EXPERIMENT)
print("Checkpoints under:", EXP_DIR)

# Parameters (match training):
INPUT_NC   = 3
OUTPUT_NC  = 1     # set to 3 if your SEM output is RGB
DIRECTION  = "AtoB"
LOAD_SIZE  = 286
CROP_SIZE  = 256

# ---------- Prepare repo ----------
if not REPO.exists():
    !git clone --depth 1 https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix.git /kaggle/working/pix2pix

# ---------- Collect input images and convert to PNG in a working folder ----------
TEMP_DIR   = Path("/kaggle/working/preds_yael_input")
TEMP_DIR.mkdir(parents=True, exist_ok=True)

exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")
# Match extensions case-insensitively so files such as IMG.PNG or scan.JPG are not
# silently left out. The directory tree is walked once; images stay grouped by extension
# (in `exts` order) and sorted by path within each group.
all_files = [p for p in INPUT_DS.rglob("*") if p.is_file()]
src_imgs = []
for ext in exts:
    src_imgs += sorted(p for p in all_files if p.name.lower().endswith(ext))

assert src_imgs, f"No images found under {INPUT_DS}"
print(f"Found {len(src_imgs)} images. Converting and saving as PNG...")

# clean first (leftovers of an earlier run would otherwise be fed to test.py as well)
for p in TEMP_DIR.glob("*"):
    if p.is_dir():
        shutil.rmtree(p)     # unlink() cannot remove a directory
    else:
        p.unlink()

name_map = {}  # mapping: base-name → source (for grids/display)
for i, p in enumerate(src_imgs, 1):
    # Save as "img_NNNNN.png" to avoid issues with special filenames
    out_name = f"img_{i:05d}.png"
    try:
        # the context manager closes the source file handle after conversion
        with Image.open(p) as im:
            im.convert("RGB").save(TEMP_DIR / out_name)
    except Exception as e:
        # best effort: an unreadable file is reported and skipped, not fatal
        print("Skipping problematic file:", p, e)
        continue
    name_map[out_name.replace(".png","")] = p

print("Converted and saved:", len(name_map))

# ---------- Run test.py with single dataset_mode + model=test (A only) ----------
import sys  # for sys.executable

RESULTS = Path("/kaggle/working/inference_preds_yael")
RESULTS.mkdir(parents=True, exist_ok=True)

cmd = [
    # Use the interpreter running this kernel (as the other inference cells do) rather
    # than a hard-coded /usr/bin/python3, so test.py runs with this kernel's packages.
    sys.executable, str(REPO/"test.py"),
    "--model", "test",               # Important: use 'test' (not pix2pix) when B is absent
    "--netG", "unet_256",
    "--norm", "batch",
    "--dataroot", str(TEMP_DIR),
    "--dataset_mode", "single",      # A only
    "--direction", DIRECTION,
    "--name", EXPERIMENT,
    "--checkpoints_dir", str(CKPT_ROOT),
    "--preprocess", "resize_and_crop",
    "--load_size", str(LOAD_SIZE),
    "--crop_size", str(CROP_SIZE),
    "--input_nc", str(INPUT_NC),
    "--output_nc", str(OUTPUT_NC),
    "--results_dir", str(RESULTS),
    "--epoch", "latest",
    "--num_test", "100000",
    "--eval",
]
print(">>>", " ".join(shlex.quote(c) for c in cmd))
ret = subprocess.run(cmd, check=False)
print("Exit code:", ret.returncode)
assert ret.returncode == 0, "test.py failed — check logs above"

# ---------- Locate the produced images folder ----------
images_dirs = list(RESULTS.glob(f"{EXPERIMENT}/test_latest/images"))
assert images_dirs, f"'images' folder not found under {RESULTS}"
IMDIR = images_dirs[0]
print("Output folder:", IMDIR)

# ---------- Build grids (Input | Output) and save ----------
GRIDS_DIR = Path("/kaggle/working/preds_yael_grids")
GRIDS_DIR.mkdir(parents=True, exist_ok=True)

fake_list = sorted(IMDIR.glob("*_fake.png"))
shown = 0

def _output_tile(generated):
    """Return the generated image as a 256x256 RGB tile (non-RGB modes go through gray)."""
    if generated.mode == "RGB":
        return generated.resize((256,256))
    gray = generated.convert("L").resize((256,256))
    return Image.merge("RGB", (gray, gray, gray))

for fake_path in fake_list:
    stem = fake_path.name.replace("_fake.png","")
    real_path = IMDIR / f"{stem}_real.png"
    # an output without its matching input image gets no grid
    if real_path.exists():
        input_tile = Image.open(real_path).convert("RGB").resize((256,256))
        result_tile = _output_tile(Image.open(fake_path))
        # side-by-side canvas: input on the left, output on the right
        canvas = Image.new("RGB", (256*2, 256))
        canvas.paste(input_tile, (0,0))
        canvas.paste(result_tile, (256,0))
        canvas.save(GRIDS_DIR / f"{stem}_grid.png")

print(f"Grids saved under: {GRIDS_DIR}")

# ---------- Display in notebook (first 12 examples to avoid clutter) ----------
to_show = sorted(GRIDS_DIR.glob("*_grid.png"))[:12]
fig, axes = plt.subplots(len(to_show), 1, figsize=(7, 3*len(to_show)))
if len(to_show) == 1:
    axes = [axes]
for ax, p in zip(axes, to_show):
    ax.imshow(Image.open(p))
    ax.set_title(p.name)
    ax.axis("off")
plt.tight_layout()
plt.show()

# ---------- Create ZIPs for download ----------
# All inference results
!cd /kaggle/working && zip -r -q preds_yael_results.zip inference_preds_yael
# All grids
!cd /kaggle/working && zip -r -q preds_yael_grids.zip preds_yael_grids
print("\nDownload from:")
print("/kaggle/working/preds_yael_results.zip")
print("/kaggle/working/preds_yael_grids.zip")
