<a href="https://colab.research.google.com/github/appababba/USDA/blob/main/unet_gabor_adaptor_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip -q install segmentation-models-pytorch==0.3.3 albumentations==1.4.7 --no-deps
import os, random, shutil, math, json, csv, glob, hashlib
from datetime import datetime
from glob import glob as gglob
import numpy as np
import cv2
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from google.colab import drive
from tqdm import tqdm

# Central experiment configuration: run toggles, Gabor-stem settings,
# training hyper-parameters and the Drive path components used below.
CFG = {
    "RUN": {
        "DO_BASELINE": True,          # train UNetWithGabor
        "DO_LOVASZ": False,           # fine-tune the baseline with the Lovász objective
        "DO_ADAPTER_VARIANT": True,   # train UNetWithGabor_Adapted
        "SAVE_SAMPLES": 24,           # max number of predicted val masks written as PNG
    },
    "GABOR": {
        "enabled": True,
        "mode": "concat",             # "concat": RGB + Gabor channels; any other value: Gabor only
        "learnable": False,
        "kernel_size": 15,
        "orientations": 8,
        "wavelengths": [6, 10, 16],
        "sigmas": [3.0, 5.0, 7.0],    # paired element-wise with wavelengths
        "gamma": 0.5,
        "phases": [0.0, math.pi / 2],
        "magnitude": True,
    },
    "TRAIN": {
        "img_size": (256, 256),
        "batch_size": 8,
        "epochs": 15,
        "adapter_epochs": 20,
        "lr": 1e-4,
        "weight_decay": 1e-4,
        "pos_weight": 9.66,           # BCE positive-class weight
        "random_seed": 42,
    },
    "PATHS": {
        "base_drive": "/content/drive/Shared drives/USDA-Summer2025",
        "data": "data",
        "images": "Exported_Images",
        "masks": "Exported_Masks",
        "models": "models",
        "exports": "exports",
    },
}

# Runtime setup: compute device, cuDNN autotuning, global seeding, Drive mount
# and the Drive/local directory layout used by the rest of the notebook.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.backends.cudnn.benchmark = True  # autotuned conv algorithms; not bitwise deterministic

SEED = CFG["TRAIN"]["random_seed"]
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

drive.mount('/content/drive', force_remount=False)

# Drive layout: <base>/data/{images,masks}, <base>/models, <base>/exports
BASE_DRIVE = CFG["PATHS"]["base_drive"]
DATA_DIR = os.path.join(BASE_DRIVE, CFG["PATHS"]["data"])
IMG_DRIVE, MSK_DRIVE = (os.path.join(DATA_DIR, CFG["PATHS"][key]) for key in ("images", "masks"))
MODELS_DIR, EXPORTS_DIR = (os.path.join(BASE_DRIVE, CFG["PATHS"][key]) for key in ("models", "exports"))
for _out_dir in (MODELS_DIR, EXPORTS_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Local mirror of the image/mask folders.
LOCAL_ROOT = "/content/local_data"
IMG_LOCAL, MSK_LOCAL = (os.path.join(LOCAL_ROOT, CFG["PATHS"][key]) for key in ("images", "masks"))

def ensure_local(src, dst):
    """Mirror the directory tree *src* into *dst*.

    Does nothing if *dst* already exists and contains at least one entry. Note that this
    check does not verify completeness, so a partially finished earlier copy is reused
    as-is. Files already present in *dst* are never overwritten.
    """
    if os.path.isdir(dst):
        with os.scandir(dst) as entries:
            if next(entries, None) is not None:
                print(f"📦 Using existing local: {dst}")
                return
    print(f"📥 Copying {src} -> {dst}")
    os.makedirs(dst, exist_ok=True)
    for cur_dir, _subdirs, names in os.walk(src):
        rel = os.path.relpath(cur_dir, src)
        target_dir = dst if rel == "." else os.path.join(dst, rel)
        os.makedirs(target_dir, exist_ok=True)
        for name in names:
            target = os.path.join(target_dir, name)
            if os.path.exists(target):
                continue
            shutil.copy2(os.path.join(cur_dir, name), target)
    print("✅ Copy done.")

# Stage images and masks on the local disk (presumably for faster reads than the Drive mount).
for _src_dir, _dst_dir in ((IMG_DRIVE, IMG_LOCAL), (MSK_DRIVE, MSK_LOCAL)):
    ensure_local(_src_dir, _dst_dir)

class SegDataset(Dataset):
    """Binary segmentation dataset pairing each image with a mask file in *mask_dir*.

    Args:
        image_paths: list of image file paths.
        mask_dir: directory searched for ``<stem>_mask.png`` first, then ``<stem>`` with one
            of .png/.jpg/.jpeg/.tif/.tiff.
        size: dsize passed to ``cv2.resize`` when *transform* is None.
        transform: albumentations-style callable taking ``image=``/``mask=`` keywords and
            returning a dict with "image" and "mask" entries.

    Each item is ``(img, msk)``. ``msk`` is a float tensor with a leading channel
    dimension and values in {0, 1}.
    """

    def __init__(self, image_paths, mask_dir, size=(256,256), transform=None):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.size = size
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def _mask_path_for(self, img_path):
        """Return the mask path for *img_path*, or raise FileNotFoundError if none exists."""
        base = os.path.splitext(os.path.basename(img_path))[0]
        p1 = os.path.join(self.mask_dir, f"{base}_mask.png")
        if os.path.exists(p1):
            return p1
        for ext in ('.png', '.jpg', '.jpeg', '.tif', '.tiff'):
            p2 = os.path.join(self.mask_dir, base + ext)
            if os.path.exists(p2):
                return p2
        raise FileNotFoundError(f"No mask for {img_path}")

    def __getitem__(self, idx):
        """Load sample *idx* as ``(image, mask)``; any non-zero mask pixel becomes 1."""
        ip = self.image_paths[idx]
        mp = self._mask_path_for(ip)
        # cv2.imread signals failure by returning None rather than raising; report the path.
        bgr = cv2.imread(ip)
        if bgr is None:
            raise OSError(f"Could not read image: {ip}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        raw_mask = cv2.imread(mp, cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            raise OSError(f"Could not read mask: {mp}")
        msk = (raw_mask > 0).astype(np.uint8)

        # Explicit None test: a transform pipeline object may define __len__ and be falsy.
        if self.transform is not None:
            out = self.transform(image=img, mask=msk)
            img, msk = out["image"], out["mask"]
        else:
            # Fallback: plain resize and [0, 1] scaling (no mean/std normalisation), CHW tensors.
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR).astype(np.float32) / 255.0
            msk = cv2.resize(msk, self.size, interpolation=cv2.INTER_NEAREST)
            img = torch.from_numpy(img).permute(2, 0, 1).float()
            msk = torch.from_numpy(msk).unsqueeze(0).float()

        # A transform ending in a tensor conversion yields an HxW mask; add the channel dim.
        if isinstance(msk, torch.Tensor) and msk.ndim == 2:
            msk = msk.unsqueeze(0)
        return img, msk.float()

class GaborStem(nn.Module):
    """Fixed or learnable Gabor filter bank applied to a luminance projection of the input.

    One filter is built for every (wavelength, sigma) pair x orientation x phase, with the
    phase varying fastest. With ``magnitude=True`` and exactly two phases, each filter pair
    is combined as ``sqrt(r0**2 + r1**2)``. This is a quadrature energy when the phases are
    ``[0, pi/2]``, as in this notebook's config; only the phase count is checked, not the
    values. Features are instance-normalised (no affine).

    cfg keys: kernel_size (odd int), orientations (>= 1), wavelengths, sigmas (same length
    as wavelengths, or a single value), gamma (spatial aspect ratio), phases, magnitude,
    learnable.
    """

    def __init__(self, cfg):
        super().__init__()
        k = cfg["kernel_size"]
        if k % 2 != 1:
            raise ValueError("kernel_size must be odd")
        self.k = k
        self.gamma = cfg["gamma"]
        self.magnitude = cfg["magnitude"]
        self.learnable = cfg["learnable"]

        # Centred pixel-offset grids of shape (k, k).
        ax = torch.arange(-(k // 2), k // 2 + 1).float()
        X, Y = torch.meshgrid(ax, ax, indexing='xy')
        self.register_buffer('X', X)
        self.register_buffer('Y', Y)

        self.orientations = cfg["orientations"]
        if self.orientations < 1:
            raise ValueError("orientations must be >= 1")
        phases = torch.tensor(cfg["phases"], dtype=torch.float32).reshape(-1)
        self.register_buffer('phases_buf', phases)

        lambdas = torch.tensor(cfg["wavelengths"], dtype=torch.float32).reshape(-1)
        sigmas = torch.tensor(cfg["sigmas"], dtype=torch.float32).reshape(-1)
        if sigmas.numel() == 1:
            sigmas = sigmas.repeat(lambdas.numel())
        if lambdas.numel() != sigmas.numel():
            raise ValueError("wavelengths and sigmas must align")

        # Orientations evenly cover [0, pi). pi itself is excluded because rotating a Gabor
        # filter by pi reproduces the theta=0 filter (up to sign), which would give a
        # duplicated channel in magnitude mode.
        thetas = torch.arange(self.orientations, dtype=torch.float32) * (math.pi / self.orientations)
        rows = [
            (th, sig, lam, ph)  # column order: theta, sigma, lambda, phase
            for lam, sig in zip(lambdas.tolist(), sigmas.tolist())
            for th in thetas.tolist()
            for ph in phases.tolist()
        ]
        base = torch.tensor(rows, dtype=torch.float32)

        if self.learnable:
            self.params = nn.Parameter(base)
        else:
            self.register_buffer('params_buf', base)

        self.num_per_phase = lambdas.numel() * self.orientations
        use_magnitude = self.magnitude and phases.numel() == 2
        self.out_channels = self.num_per_phase if use_magnitude else base.shape[0]
        self.norm = nn.InstanceNorm2d(self.out_channels, affine=False)

    def _kernels(self, device, dtype):
        """Build the zero-mean, unit-L2-norm filter bank as a (num_filters, k, k) tensor."""
        P = self.params if self.learnable else self.params_buf
        P = P.to(device=device, dtype=dtype)
        X = self.X.to(device=device, dtype=dtype)
        Y = self.Y.to(device=device, dtype=dtype)
        gamma = torch.as_tensor(self.gamma, dtype=dtype, device=device)

        theta = P[:, 0].view(-1, 1, 1)
        sigma = P[:, 1].view(-1, 1, 1)
        lambd = P[:, 2].view(-1, 1, 1)
        phase = P[:, 3].view(-1, 1, 1)

        # Rotate the sampling grid into each filter's frame.
        Xp = X * torch.cos(theta) + Y * torch.sin(theta)
        Yp = -X * torch.sin(theta) + Y * torch.cos(theta)

        gauss = torch.exp(-(Xp ** 2 + (gamma * Yp) ** 2) / (2 * sigma ** 2))
        carrier = torch.cos(2 * math.pi * Xp / lambd + phase)
        g = gauss * carrier

        # Zero DC response, then unit L2 norm so all filters have comparable gain.
        g = g - g.mean(dim=(1, 2), keepdim=True)
        g = g / (g.square().sum(dim=(1, 2), keepdim=True).sqrt() + 1e-8)
        return g

    def forward(self, x):
        """Return instance-normalised Gabor features of the first three channels of *x*.

        The channels are mixed with luma weights (0.299, 0.587, 0.114).
        NOTE(review): in this notebook *x* is already mean/std normalised by the transforms,
        so this is a luminance-like projection rather than true luma.
        """
        lum = 0.299 * x[:, 0:1] + 0.587 * x[:, 1:2] + 0.114 * x[:, 2:3]
        K = self._kernels(x.device, x.dtype)
        pad = self.k // 2
        if self.magnitude and self.phases_buf.numel() == 2:
            # Rows alternate phase0 / phase1 because phase is the innermost loop in __init__.
            r0 = torch.conv2d(lum, K[0::2].unsqueeze(1), padding=pad)
            r1 = torch.conv2d(lum, K[1::2].unsqueeze(1), padding=pad)
            # Combine in float32 so the 1e-8 guard is not flushed to zero under fp16 autocast.
            feats = torch.sqrt(r0.float() ** 2 + r1.float() ** 2 + 1e-8)
        else:
            feats = torch.conv2d(lum, K.unsqueeze(1), padding=pad)
        return self.norm(feats)

class UNetWithGabor(nn.Module):
    """ResNet-50 U-Net with one output logit channel, optionally fed by a GaborStem.

    With ``mode == "concat"`` the encoder sees the image channels plus the Gabor channels.
    With any other mode it sees only the Gabor channels. ImageNet encoder weights are
    requested only when the encoder input has exactly 3 channels (e.g. Gabor disabled);
    otherwise ``encoder_weights=None`` is passed.
    """

    def __init__(self, gcfg, in_img_ch=3):
        super().__init__()
        self.use_gabor = gcfg["enabled"]
        n_in = in_img_ch
        if self.use_gabor:
            self.gabor = GaborStem(gcfg)
            extra_img = in_img_ch if gcfg["mode"] == "concat" else 0
            n_in = self.gabor.out_channels + extra_img
        self.net = smp.Unet(
            encoder_name='resnet50',
            encoder_weights=('imagenet' if n_in == 3 else None),
            in_channels=n_in,
            classes=1,
            activation=None,
        )
        self.mode = gcfg["mode"]

    def forward(self, x):
        if not self.use_gabor:
            return self.net(x)
        feats = self.gabor(x)
        net_in = torch.cat([x, feats], dim=1) if self.mode == "concat" else feats
        return self.net(net_in)

class UNetWithGabor_Adapted(nn.Module):
    """ImageNet-initialised ResNet-50 U-Net behind a learned 1x1 channel adapter.

    The optional GaborStem output is concatenated with the image (``mode == "concat"``) or
    used alone. A small 1x1 conv bottleneck then maps that many channels down to the 3 the
    ImageNet encoder expects. When the stacked input already has 3 channels, the adapter is
    an identity.
    """

    def __init__(self, gcfg, in_img_ch=3):
        super().__init__()
        self.use_gabor = gcfg["enabled"]
        self.mode = gcfg["mode"]
        n_in = in_img_ch
        if self.use_gabor:
            self.gabor = GaborStem(gcfg)
            if gcfg["mode"] == "concat":
                n_in = in_img_ch + self.gabor.out_channels
            else:
                n_in = self.gabor.out_channels

        if n_in != 3:
            hidden = max(16, n_in // 4)
            self.adapter = nn.Sequential(
                nn.Conv2d(n_in, hidden, kernel_size=1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden, 3, kernel_size=1, bias=False),
            )
        else:
            self.adapter = nn.Identity()
        self.net = smp.Unet(
            encoder_name='resnet50',
            encoder_weights='imagenet',
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, x):
        stacked = x
        if self.use_gabor:
            feats = self.gabor(x)
            stacked = torch.cat([x, feats], dim=1) if self.mode == "concat" else feats
        return self.net(self.adapter(stacked))

IMG_SIZE = CFG["TRAIN"]["img_size"]
BATCH = CFG["TRAIN"]["batch_size"]

# Training augmentation: random crop/flip/rotate/colour jitter, then ImageNet-style normalise.
train_tfms = A.Compose([
    A.RandomResizedCrop(IMG_SIZE[0], IMG_SIZE[1], scale=(0.6,1.0), ratio=(0.9,1.1), interpolation=cv2.INTER_LINEAR),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.RandomRotate90(p=0.2),
    A.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.02, p=0.3),
    A.Normalize(), ToTensorV2(),
])
val_tfms = A.Compose([A.Resize(IMG_SIZE[0], IMG_SIZE[1], interpolation=cv2.INTER_LINEAR), A.Normalize(), ToTensorV2()])

# glob returns files in filesystem-dependent order. Sorting first makes the seeded
# shuffle/split below reproducible across runtimes.
all_imgs = sorted(
    path
    for ext in ('*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff')
    for path in gglob(os.path.join(IMG_LOCAL, ext))
)
if not all_imgs:
    raise FileNotFoundError(f"No images found in {IMG_LOCAL}")
random.shuffle(all_imgs)

# 80/20 train+val/test, then 85/15 train/val of the remainder.
train_val, test_paths = train_test_split(all_imgs, test_size=0.2, random_state=SEED)
train_paths, val_paths = train_test_split(train_val, test_size=0.15, random_state=SEED)

train_ds = SegDataset(train_paths, MSK_LOCAL, size=IMG_SIZE, transform=train_tfms)
val_ds   = SegDataset(val_paths,   MSK_LOCAL, size=IMG_SIZE, transform=val_tfms)

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=4, pin_memory=True, persistent_workers=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

print(f"Data ready. Train: {len(train_ds)}, Val: {len(val_ds)}")

# Persist the split (one path per line) so later runs/evaluations can reuse it.
SPLIT_DIR = os.path.join(BASE_DRIVE, "splits")
os.makedirs(SPLIT_DIR, exist_ok=True)
for _split_file, _split_paths in (("train_files.txt", train_paths),
                                  ("val_files.txt", val_paths),
                                  ("test_files.txt", test_paths)):
    with open(os.path.join(SPLIT_DIR, _split_file), "w") as f:
        f.writelines(p + "\n" for p in _split_paths)

def dice_loss_binary(logits, y, eps=1e-6):
    """Soft Dice loss for binary logits.

    Per-sample Dice is computed over dims (1, 2, 3) from sigmoid probabilities, with *eps*
    added to numerator and denominator. Returns ``1 - mean(per-sample Dice)``.
    """
    probs = torch.sigmoid(logits)
    dims = (1, 2, 3)
    overlap = (probs * y).sum(dim=dims)
    pred_mass = (probs * probs).sum(dim=dims)
    true_mass = (y * y).sum(dim=dims)
    per_sample = (2 * overlap + eps) / (pred_mass + true_mass + eps)
    return 1 - per_sample.mean()

# BCE-with-logits with the configured positive-class weight, kept on DEVICE.
POS_WEIGHT = torch.tensor([CFG["TRAIN"]["pos_weight"]], device=DEVICE)
bce_w = nn.BCEWithLogitsLoss(pos_weight=POS_WEIGHT)

def combined_loss_baseline(logits, y):
    """Unweighted sum of soft Dice loss and positive-weighted BCE-with-logits."""
    dice_term = dice_loss_binary(logits, y)
    bce_term = bce_w(logits, y)
    return dice_term + bce_term

USE_BOUNDARY = True  # add the Laplacian edge-matching term to the Lovász objective

try:
    # NOTE: `lovasz_losses` is NOT installed by the pip line at the top of this notebook. It
    # must be importable from the Python path (presumably the reference lovasz_losses.py).
    from lovasz_losses import lovasz_hinge
except ImportError:
    def combined_loss_lovasz(*args, **kwargs):
        """Placeholder used when `lovasz_losses` cannot be imported; always raises."""
        raise RuntimeError(
            "Lovász loss unavailable: could not import `lovasz_hinge` from `lovasz_losses`. "
            "Put lovasz_losses.py on the Python path (e.g. the working directory) and re-run."
        )
else:
    def boundary_loss(logits, y):
        """L1 distance between absolute Laplacian responses of predicted probabilities and target."""
        lap = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]],
                           dtype=torch.float32, device=logits.device).view(1, 1, 3, 3)
        p = torch.sigmoid(logits)
        pe = F.conv2d(p, lap, padding=1)
        ye = F.conv2d(y, lap, padding=1)
        return F.l1_loss(torch.abs(pe), torch.abs(ye))

    def combined_loss_lovasz(logits, y):
        """0.45*Dice + 0.45*Lovász-hinge + 0.10*boundary, or 0.5/0.5 without the boundary term."""
        lh = lovasz_hinge(logits.squeeze(1), y.squeeze(1).float())
        dl = dice_loss_binary(logits, y)
        if USE_BOUNDARY:
            bl = boundary_loss(logits, y)
            return 0.45 * dl + 0.45 * lh + 0.10 * bl
        return 0.5 * dl + 0.5 * lh

@torch.no_grad()
def pooled_counts_at_thr(loader, model, thr=0.5, device=None):
    """Pool pixel-level TP/FP/FN counts over all batches of *loader* at threshold *thr*.

    Inference runs in eval mode. The model's previous train/eval mode is restored on
    return, so calling this mid-training does not leave BatchNorm/Dropout in inference mode.

    Args:
        loader: iterable of ``(x, y)`` batches; *y* is a {0,1} mask broadcastable to the logits.
        model: module returning logits.
        thr: a pixel is predicted positive when ``sigmoid(logit) > thr``.
        device: inference device; defaults to the global ``DEVICE``.

    Returns:
        ``(tp, fp, fn)`` totals (0 each for an empty loader).
    """
    dev = DEVICE if device is None else device
    was_training = model.training
    model.eval()
    tp = fp = fn = 0
    try:
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            pred = (torch.sigmoid(model(x)) > thr).to(y.dtype)
            tp += (pred * y).sum().item()
            fp += (pred * (1 - y)).sum().item()
            fn += ((1 - pred) * y).sum().item()
    finally:
        model.train(was_training)
    return tp, fp, fn

def iou_dice(tp, fp, fn, eps=1e-6):
    """Convert pooled TP/FP/FN counts into ``(IoU, Dice)`` floats.

    *eps* is added to numerator and denominator, so all-zero counts give (1.0, 1.0).
    """
    union = tp + fp + fn
    iou_score = (tp + eps) / (union + eps)
    dice_score = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return float(iou_score), float(dice_score)

@torch.no_grad()
def sweep_thresholds(loader, model, thr_grid=None, device=None):
    """Score pooled micro-IoU and Dice on *loader* for every threshold in *thr_grid*.

    Inference runs once. All thresholds are evaluated from the same probabilities instead
    of re-running the whole loader per threshold. The model is left in eval mode.

    Args:
        loader: iterable of ``(x, y)`` batches.
        model: module returning logits.
        thr_grid: iterable of thresholds; defaults to ``np.linspace(0.2, 0.8, 13)``.
        device: inference device; defaults to the global ``DEVICE``.

    Returns:
        ``(rows, best)``. *rows* holds one dict per threshold with keys
        thr/micro_IoU/Dice/tp/fp/fn. *best* is the first threshold achieving the highest
        micro-IoU (``thr`` is None if the grid is empty).
    """
    grid = np.linspace(0.2, 0.8, 13) if thr_grid is None else thr_grid
    thresholds = [float(t) for t in grid]
    dev = DEVICE if device is None else device
    counts = [[0, 0, 0] for _ in thresholds]  # per threshold: tp, fp, fn

    model.eval()
    for x, y in loader:
        x, y = x.to(dev), y.to(dev)
        probs = torch.sigmoid(model(x))
        for c, thr in zip(counts, thresholds):
            pred = (probs > thr).to(y.dtype)
            c[0] += (pred * y).sum().item()
            c[1] += (pred * (1 - y)).sum().item()
            c[2] += ((1 - pred) * y).sum().item()

    rows = []
    best = {"thr": None, "micro_IoU": -1, "Dice": -1}
    for thr, (tp, fp, fn) in zip(thresholds, counts):
        iou, dice = iou_dice(tp, fp, fn)
        rows.append({"thr": thr, "micro_IoU": iou, "Dice": dice, "tp": tp, "fp": fp, "fn": fn})
        if iou > best["micro_IoU"]:
            best = {"thr": thr, "micro_IoU": iou, "Dice": dice}
    return rows, best

def make_optimizer(model, base_lr, wd):
    """Build AdamW with discriminative learning rates.

    Parameter groups:
      * ``model.net.encoder`` at ``base_lr / 10``;
      * ``model.net.decoder`` and ``model.net.segmentation_head`` at ``base_lr``;
      * every other trainable parameter of *model* (e.g. the adapter or a learnable
        GaborStem) at ``base_lr``.

    Without the last group, a learnable GaborStem's filter parameters would never be
    updated. All groups share weight decay *wd*.
    """
    net = model.net
    groups = [
        {'params': list(net.encoder.parameters()), 'lr': base_lr / 10.0},
        {'params': list(net.decoder.parameters()), 'lr': base_lr},
        {'params': list(net.segmentation_head.parameters()), 'lr': base_lr},
    ]
    # Identity-based de-duplication: AdamW rejects a parameter listed in two groups.
    grouped = {id(p) for g in groups for p in g['params']}
    extra = [p for p in model.parameters() if p.requires_grad and id(p) not in grouped]
    if extra:
        groups.append({'params': extra, 'lr': base_lr})
    return torch.optim.AdamW(groups, lr=base_lr, weight_decay=wd)

def make_scheduler(optimizer):
    """Halve every group's LR when the monitored metric (higher is better) plateaus.

    The LR drops once the metric has not improved for more than 3 consecutive ``step()``
    calls.
    """
    plateau_kwargs = {"mode": "max", "factor": 0.5, "patience": 3}
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_kwargs)

def train_epoch(model, loader, optimizer, scaler, loss_fn):
    """Run one mixed-precision training epoch.

    Returns:
        ``(mean loss per batch, mean per-batch IoU at threshold 0.5)``.

    The IoU is computed from the logits already produced for the loss. Re-scoring through
    ``pooled_counts_at_thr`` would switch the model to eval mode after the first batch
    (freezing BatchNorm statistics for the rest of the epoch) and cost an extra forward pass.
    """
    model.train()
    amp_device = 'cuda' if DEVICE.type == 'cuda' else 'cpu'
    loss_sum = 0.0
    iou_sum = 0.0
    for x, y in tqdm(loader, desc="Train", leave=False):
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=amp_device):
            logits = model(x)
            loss = loss_fn(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loss_sum += loss.item()

        with torch.no_grad():
            pred = (torch.sigmoid(logits.detach().float()) > 0.5).to(y.dtype)
            tp = (pred * y).sum().item()
            fp = (pred * (1 - y)).sum().item()
            fn = ((1 - pred) * y).sum().item()
        batch_iou, _ = iou_dice(tp, fp, fn)
        iou_sum += batch_iou
    n_batches = max(1, len(loader))
    return loss_sum / n_batches, iou_sum / n_batches

@torch.no_grad()
def validate_epoch(model, loader, loss_fn):
    """Evaluate *model* on *loader* in eval mode under mixed precision.

    Returns:
        ``(mean loss per batch, mean per-batch IoU at threshold 0.5)``.

    The IoU uses the same logits as the loss, avoiding a second forward pass per batch.
    """
    model.eval()
    amp_device = 'cuda' if DEVICE.type == 'cuda' else 'cpu'
    loss_sum = 0.0
    iou_sum = 0.0
    for x, y in loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)
        with torch.amp.autocast(device_type=amp_device):
            logits = model(x)
            loss = loss_fn(logits, y)
        loss_sum += loss.item()
        pred = (torch.sigmoid(logits.float()) > 0.5).to(y.dtype)
        tp = (pred * y).sum().item()
        fp = (pred * (1 - y)).sum().item()
        fn = ((1 - pred) * y).sum().item()
        batch_iou, _ = iou_dice(tp, fp, fn)
        iou_sum += batch_iou
    n_batches = max(1, len(loader))
    return loss_sum / n_batches, iou_sum / n_batches

def run_training(model, epochs, loss_fn, tag, patience=5):
    """Train *model* with LR-on-plateau scheduling and early stopping on val IoU@0.5.

    Uses the global ``train_loader``/``val_loader``. Every improvement is checkpointed to
    ``MODELS_DIR/<tag>.pth``. Before returning, the in-memory model is restored to the
    best epoch, so later threshold sweeps and exports use the same weights as the checkpoint.

    Args:
        model: model to train (modified in place).
        epochs: maximum number of epochs.
        loss_fn: ``loss_fn(logits, y)`` training objective.
        tag: checkpoint file stem.
        patience: epochs without val-IoU improvement before stopping.

    Returns:
        ``(best_path, best_iou, model)``.
    """
    optimizer = make_optimizer(model, CFG["TRAIN"]["lr"], CFG["TRAIN"]["weight_decay"])
    scheduler = make_scheduler(optimizer)
    scaler = torch.amp.GradScaler('cuda' if DEVICE.type=='cuda' else 'cpu')
    best_iou = -1.0
    best_path = os.path.join(MODELS_DIR, f"{tag}.pth")
    best_state = None
    epochs_no_improve = 0
    print(f"\n🚀 Training: {tag}")
    for ep in range(1, epochs+1):
        trL, trI = train_epoch(model, train_loader, optimizer, scaler, loss_fn)
        vaL, vaI = validate_epoch(model, val_loader, loss_fn)
        scheduler.step(vaI)
        print(f"Epoch {ep}/{epochs} | Train L {trL:.4f} IoU@0.5 {trI:.3f} || Val L {vaL:.4f} IoU@0.5 {vaI:.3f}")
        if vaI > best_iou:
            best_iou = vaI
            torch.save({"state_dict": model.state_dict(),
                        "img_size": IMG_SIZE, "gabor_cfg": CFG["GABOR"]}, best_path)
            # CPU snapshot of the best weights, so they can be restored without re-reading Drive.
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}
            print(f"💾 Saved best -> {best_path} (IoU@0.5 {best_iou:.4f})")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"🛑 Early stopping triggered after {patience} epochs with no improvement.")
            break
    if best_state is not None:
        model.load_state_dict(best_state)
    return best_path, best_iou, model

# Checkpoint stem for the baseline model (MODELS_DIR/<baseline_tag>.pth).
if CFG["GABOR"]["enabled"]:
    _gabor_kind = 'learn' if CFG['GABOR']['learnable'] else 'fixed'
    baseline_tag = f"UNetRes50_Gabor-{CFG['GABOR']['mode']}_{_gabor_kind}"
else:
    baseline_tag = "UNetRes50_plain"

if CFG["RUN"]["DO_BASELINE"]:
    # Train the baseline, then report the validation threshold sweep.
    base_model = UNetWithGabor(CFG["GABOR"]).to(DEVICE)
    best_ckpt, best_iou, base_model = run_training(
        base_model, CFG["TRAIN"]["epochs"], combined_loss_baseline, baseline_tag
    )
    rows, best = sweep_thresholds(val_loader, base_model)
    print("\n=== Baseline sweep ===")
    for row in rows:
        print(f"thr={row['thr']:.2f}  micro_IoU={row['micro_IoU']:.4f}  Dice={row['Dice']:.4f}")
    print(f"BEST @ thr={best['thr']:.2f} | micro_IoU={best['micro_IoU']:.4f} Dice={best['Dice']:.4f}")

if CFG["RUN"]["DO_LOVASZ"]:
    # Reuse the baseline trained in this session; otherwise rebuild it from its checkpoint.
    if 'base_model' not in locals():
        base_model = UNetWithGabor(CFG["GABOR"]).to(DEVICE)
        baseline_ckpt_path = os.path.join(MODELS_DIR, baseline_tag + ".pth")
        sd = torch.load(baseline_ckpt_path, map_location=DEVICE)["state_dict"]
        base_model.load_state_dict(sd, strict=True)
    lovasz_tag = baseline_tag + "_lovasz"
    # Fine-tune for up to 10 epochs with the Lovász objective.
    best_ckpt_l, best_iou_l, base_model = run_training(
        base_model, 10, combined_loss_lovasz, lovasz_tag
    )
    rows, best = sweep_thresholds(val_loader, base_model)
    print("\n=== Lovász sweep ===")
    for row in rows:
        print(f"thr={row['thr']:.2f}  micro_IoU={row['micro_IoU']:.4f}  Dice={row['Dice']:.4f}")
    print(f"BEST @ thr={best['thr']:.2f} | micro_IoU={best['micro_IoU']:.4f} Dice={best['Dice']:.4f}")

# Checkpoint stem for the adapter variant (ImageNet encoder behind a 1x1 channel adapter).
if CFG["GABOR"]["enabled"]:
    adapter_tag = "UNetRes50_GaborAdapter_imagenet"
else:
    adapter_tag = "UNetRes50_Adapter_imagenet"

if CFG["RUN"]["DO_ADAPTER_VARIANT"]:
    adapter_model = UNetWithGabor_Adapted(CFG["GABOR"]).to(DEVICE)
    best_ckpt_ad, best_iou_ad, adapter_model = run_training(
        adapter_model, CFG["TRAIN"]["adapter_epochs"], combined_loss_baseline, adapter_tag
    )
    rows, best_ad = sweep_thresholds(val_loader, adapter_model)
    print("\n=== Adapter sweep ===")
    for row in rows:
        print(f"thr={row['thr']:.2f}  micro_IoU={row['micro_IoU']:.4f}  Dice={row['Dice']:.4f}")
    print(f"BEST @ thr={best_ad['thr']:.2f} | micro_IoU={best_ad['micro_IoU']:.4f} Dice={best_ad['Dice']:.4f}")

# Export name by flag precedence: adapter variant, then Lovász fine-tune, then baseline.
EXP_NAME = adapter_tag if CFG["RUN"]["DO_ADAPTER_VARIANT"] else (lovasz_tag if CFG["RUN"]["DO_LOVASZ"] else baseline_tag)
# Timestamped run folder: EXPORTS_DIR/<EXP_NAME>/<YYYYmmdd-HHMMSS>
OUT = os.path.join(EXPORTS_DIR, EXP_NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
os.makedirs(OUT, exist_ok=True)

# Model to export and evaluate.
# NOTE(review): base_model exists only if DO_BASELINE or DO_LOVASZ ran; with all RUN flags
# off this line raises NameError.
active_model = adapter_model if CFG["RUN"]["DO_ADAPTER_VARIANT"] else (base_model)

# Full checkpoint (weights plus the metadata needed to rebuild the model) and a bare state_dict.
torch.save({
    "state_dict": active_model.state_dict(),
    "gabor_cfg": CFG["GABOR"],
    "img_size": IMG_SIZE,
    "exp_name": EXP_NAME
}, os.path.join(OUT, "checkpoint_full.pth"))
torch.save(active_model.state_dict(), os.path.join(OUT, "weights_only.pth"))

# Validation threshold sweep for the exported model, written as CSV plus a JSON summary
# holding the best threshold (used later for the sample PNGs).
thr_grid = np.linspace(0.2,0.8,13)
rows, best = sweep_thresholds(val_loader, active_model, thr_grid)
with open(os.path.join(OUT, "val_threshold_sweep.csv"), "w", newline="") as f:
    w = csv.DictWriter(f, fieldnames=list(rows[0].keys())); w.writeheader(); w.writerows(rows)
with open(os.path.join(OUT, "summary.json"), "w") as f:
    json.dump({"exp": EXP_NAME, "best": best, "img_size": IMG_SIZE, "gabor_cfg": CFG["GABOR"], "device": str(DEVICE)}, f, indent=2)

# IPython shell escape (notebook-only syntax): record installed package versions with the export.
!pip freeze > "{OUT}/pip-freeze.txt"

# Best-effort deployment exports. TorchScript and ONNX are attempted independently, so a
# tracing failure no longer also skips the ONNX export.
active_model.eval()
H, W = IMG_SIZE
example = torch.randn(1, 3, H, W, device=DEVICE)

try:
    torch.jit.trace(active_model, example, strict=False).save(os.path.join(OUT, "model_traced.pt"))
except Exception as e:
    print("[TorchScript export skipped]", e)

try:
    torch.onnx.export(
        active_model, example, os.path.join(OUT, "model.onnx"),
        input_names=["image"], output_names=["logits"], opset_version=17,
        # Batch and spatial dims are declared dynamic in the exported graph.
        dynamic_axes={"image": {0: "N", 2: "H", 3: "W"}, "logits": {0: "N", 2: "H", 3: "W"}},
    )
except Exception as e:
    print("[ONNX export skipped]", e)

# Save up to SAVE_SAMPLES thresholded validation predictions (at most 4 per batch) as PNGs,
# using the best sweep threshold. Best-effort: any failure is reported and skipped.
try:
    import torchvision
    SAMPLE_OUT = os.path.join(OUT, "preds_val_samples")
    os.makedirs(SAMPLE_OUT, exist_ok=True)
    best_thr = float(best["thr"])
    sample_limit = CFG["RUN"]["SAVE_SAMPLES"]
    active_model.eval()
    saved = 0
    with torch.no_grad():
        for batch_idx, (x, _) in enumerate(val_loader):
            probs = torch.sigmoid(active_model(x.to(DEVICE)))
            pm = (probs > best_thr).float().cpu()
            for b in range(min(pm.size(0), 4)):
                png_path = os.path.join(SAMPLE_OUT, f"val_{batch_idx:04d}_{b:02d}.png")
                torchvision.utils.save_image(pm[b], png_path)
                saved += 1
                if saved >= sample_limit:
                    break
            if saved >= sample_limit:
                break
except Exception as e:
    print("[Sample export skipped]", e)

# Snapshot the config (non-JSON values stringified) and the file splits next to the export.
with open(os.path.join(OUT, "CFG.json"), "w") as f:
    json.dump(CFG, f, indent=2, default=str)
_export_splits = {
    "train_files.txt": train_paths,
    "val_files.txt": val_paths,
    "test_files.txt": test_paths,
}
for split_name, split_list in _export_splits.items():
    with open(os.path.join(OUT, split_name), "w") as f:
        f.write("".join(p + "\n" for p in split_list))

def sha256(p):
    """Return the hex SHA-256 digest of the file at *p*, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(p, 'rb') as fh:
        while True:
            block = fh.read(1 << 20)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
# Integrity manifest covering every artifact this notebook writes into OUT, including the
# deployable TorchScript/ONNX models, config, splits and environment snapshot.
# Artifacts that were not produced (e.g. a skipped ONNX export) are simply omitted.
artifacts = [
    "checkpoint_full.pth", "weights_only.pth", "val_threshold_sweep.csv", "summary.json",
    "model_traced.pt", "model.onnx", "CFG.json", "pip-freeze.txt",
    "train_files.txt", "val_files.txt", "test_files.txt",
]
hashes = {a: sha256(os.path.join(OUT, a)) for a in artifacts if os.path.exists(os.path.join(OUT, a))}
with open(os.path.join(OUT, "sha256.json"), "w") as f:
    json.dump(hashes, f, indent=2)

print("\n✅ Export saved to:", OUT)
print("🔐 Hashes:", hashes)