In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and echo the full path of every file found.
for folder, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        full_path = os.path.join(folder, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/pcos-picture/PCOSGen-test/images/image10876.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image11269.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image11210.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10743.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10247.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10969.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10633.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10696.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10052.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image11106.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10497.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10349.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10425.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10432.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image11016.jpg
/kaggle/input/pcos-picture/PCOSGen-test/images/image10051.jpg
/kaggle/

In [2]:
# ======================================================================
# PCOSGen Kaggle Pipeline (Golden++ Version)
# Author: Rizvi + ChatGPT (Aug 2025)
# ======================================================================

# -------------------------
# Install / pin deps
# -------------------------
# Remove any already-installed albumentations/albucore, then install one pinned pair.
# NOTE(review): `%pip` is usually preferred over `!pip` in notebooks because it targets
# the running kernel's environment — confirm before switching.
!pip -q uninstall -y albumentations albucore
# --no-deps: install only these two packages and leave their dependencies untouched.
!pip -q install --no-deps albucore==0.0.20 albumentations==1.4.16

# Import immediately and print the versions as a check that the pins are what gets loaded.
import albucore, albumentations
print("albucore:", albucore.__version__, "albumentations:", albumentations.__version__)

# -------------------------
# Imports
# -------------------------
import os, gc, math, time, random, glob
from dataclasses import dataclass
from pathlib import Path
import importlib, sys
if 'albumentations' in sys.modules: importlib.reload(sys.modules['albumentations'])

import cv2
import numpy as np
import pandas as pd
from PIL import Image

from sklearn.metrics import (
    roc_auc_score, f1_score, accuracy_score, confusion_matrix,
    roc_curve, auc, precision_recall_curve, ConfusionMatrixDisplay
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from rich import print as rprint

# -------------------------
# Config
# -------------------------
@dataclass
class CFG:
    # Central configuration for the pipeline. The visible code never instantiates this
    # class: it reads the defaults as class attributes (e.g. CFG.OUT_DIR) and overwrites
    # CFG.num_workers once when a GPU is detected.

    # Seed used by set_seed, the stratified splitter and the DataFrame shuffle.
    seed: int = 42
    # DataLoader worker processes (replaced by max(2, 2 * n_gpus) when n_gpus > 0).
    num_workers: int = 4
    # Number of stratified cross-validation folds.
    folds: int = 5
    # Side length used by LongestMaxSize / PadIfNeeded / RandomResizedCrop.
    img_size: int = 448
    # Training batch size; validation and test loaders use twice this value.
    batch_size: int = 24
    # Maximum epochs per fold (early stopping may end a fold sooner).
    epochs: int = 12
    # AdamW learning rate and weight decay.
    lr: float = 2e-4
    weight_decay: float = 1e-5
    # Epochs of linear warm-up before the cosine decay starts.
    warmup_epochs: int = 1
    # Epochs without a validation-AUC improvement tolerated before early stopping.
    patience: int = 3
    # Model name passed to timm.create_model.
    backbone: str = "convnext_tiny.fb_in22k"
    # Enables autocast and the gradient scaler in the training loop.
    use_amp: bool = True
    # Selects FocalLoss (True) or BCEWithLogitsLoss with pos_weight (False).
    use_focal: bool = True
    focal_gamma: float = 2.0
    # Paths
    # NOTE(review): BASE is not referenced anywhere in the visible code.
    BASE: str = "/kaggle/input/pcos-picture"
    TRAIN_DIR: str = "/kaggle/input/pcos-picture/PCOSGen-train (1)/PCOSGen-train/images"
    TEST_DIR: str = "/kaggle/input/pcos-picture/PCOSGen-test/images"
    LABELS_FILE: str = "/kaggle/input/pcos-picture/PCOSGen-train (1)/PCOSGen-train/class_label.xlsx"
    # Destination for checkpoints, reports, plots and the submission file.
    OUT_DIR: str = "/kaggle/working/pcosgen_out"

# Create the output directory up front; exist_ok makes re-running the cell safe.
os.makedirs(CFG.OUT_DIR, exist_ok=True)

# -------------------------
# Utils & Determinism
# -------------------------
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    Also exports the seed as the PYTHONHASHSEED environment variable.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Seed every RNG before anything stochastic runs.
set_seed(CFG.seed)

# Prefer CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
rprint(f"[bold green]Device:[/bold green] {device}")

# DataParallel is only switched on when more than one GPU is visible.
n_gpus = torch.cuda.device_count()
USE_DP = n_gpus > 1
rprint(f"[bold green]GPUs:[/bold green] {n_gpus}  |  DataParallel: {USE_DP}")
if n_gpus > 0:
    # Two loader workers per GPU, never fewer than two.
    CFG.num_workers = max(2, 2 * n_gpus)

# Stronger determinism (optional but recommended): deterministic cuDNN kernels, no autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def _seed_worker(worker_id):
    """DataLoader `worker_init_fn`: seed NumPy and `random` inside a worker process.

    The seed is derived from `torch.initial_seed()`, which inside a DataLoader worker
    is already distinct per worker and changes each time a new iterator (epoch) is
    created. Seeding with a constant such as `seed + worker_id` would restart the
    NumPy/`random` streams identically every epoch, so augmentations would repeat.
    The run stays reproducible because the base seed comes from the seeded generator
    handed to the DataLoader.

    Args:
        worker_id: index of the worker; unused, kept to match the callback signature.
    """
    # np.random.seed only accepts 32-bit values, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# Dedicated torch RNG, seeded from CFG.seed, handed to the DataLoaders via `generator=`.
_gen = torch.Generator()
_gen.manual_seed(CFG.seed)

# -------------------------
# Data Loading (from Excel)
# -------------------------
def load_labels_from_table(path):
    """Read an image -> label table from an Excel (.xlsx/.xls) or CSV file.

    Column names are matched case-insensitively. The image column is the first of
    imagepath/image/filename/file/img/name that exists; the label column is the first
    of healthy/label/class/target/y. Image names are reduced to their lower-cased
    basename. Labels are normalised to 1 (1/healthy/h) or 0 (0/unhealthy/u); numeric
    cells that pandas read as floats ("1.0", "0.0") are accepted as well.

    Rows with a missing image name or an unrecognised label are dropped with a
    warning instead of aborting the whole load. When an image occurs more than once,
    the last occurrence wins.

    Args:
        path: path to the label table.

    Returns:
        dict mapping lower-cased image basename -> integer label (0 or 1).

    Raises:
        AssertionError: if no image column or no label column can be identified.
    """
    import warnings

    ext = os.path.splitext(path)[1].lower()
    if ext in (".xlsx", ".xls"):
        df_lab = pd.read_excel(path, sheet_name=0)
    else:
        df_lab = pd.read_csv(path)
    df_lab.columns = [str(c).strip().lower() for c in df_lab.columns]

    img_col = next((c for c in ["imagepath","image","filename","file","img","name"] if c in df_lab.columns), None)
    lab_col = next((c for c in ["healthy","label","class","target","y"] if c in df_lab.columns), None)
    assert img_col is not None and lab_col is not None, "Image or Label column not found in Excel/CSV."

    # Remember which rows really had a name: astype(str) would turn NaN into "nan".
    has_name = df_lab[img_col].notna()
    name = df_lab[img_col].astype(str).str.strip().apply(os.path.basename).str.lower()

    label_codes = {"1":1,"0":0,"healthy":1,"unhealthy":0,"h":1,"u":0}
    lab_text = (
        df_lab[lab_col].astype(str).str.strip().str.lower()
        # A numeric column containing blanks is read as float, so 1 arrives as "1.0".
        .str.replace(r"\.0+$", "", regex=True)
    )
    # Unmapped values become NaN here; they must be filtered BEFORE the int cast.
    lab = lab_text.map(label_codes)

    table = pd.DataFrame({"image": name, "label": lab})
    usable = has_name & table["label"].notna()
    n_dropped = int((~usable).sum())
    if n_dropped:
        warnings.warn(
            f"load_labels_from_table: dropped {n_dropped} row(s) with a missing image name "
            "or an unrecognised label."
        )
    table = table[usable].drop_duplicates("image", keep="last")
    return dict(zip(table["image"], table["label"].astype(int)))

def list_images(dir_path):
    """Return the sorted string paths of the direct children of `dir_path` that have
    an image extension (.jpg, .jpeg, .png, .bmp, .tif, .tiff; case-insensitive)."""
    allowed = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    found = []
    for entry in Path(dir_path).glob("*"):
        if entry.suffix.lower() in allowed:
            found.append(str(entry))
    found.sort()
    return found

# Enumerate the image files present on disk for both splits.
train_files = list_images(CFG.TRAIN_DIR)
test_files = list_images(CFG.TEST_DIR)

# Lower-cased basename -> full path, so table entries can be matched to files on disk.
bn_to_path = dict((os.path.basename(p).lower(), p) for p in train_files)
label_map = load_labels_from_table(CFG.LABELS_FILE)

# Keep only labelled images that exist on disk, in the order of the label table.
labeled_rows = [
    (bn_to_path[bn], lab)
    for bn, lab in label_map.items()
    if bn in bn_to_path
]
df = pd.DataFrame(labeled_rows, columns=["path","y"])
rprint(f"[bold]Labeled train images used:[/bold] {len(df)} / {len(train_files)}")
rprint(df.y.value_counts())

# -------------------------
# Duplicate / Leakage Scan (exact MD5 + near-exact aHash)
# -------------------------
import hashlib
def md5_bytes(fp, block=1<<20):
    """Return the hexadecimal MD5 digest of the file at `fp`.

    The file is read in chunks of `block` bytes so large files are never loaded
    into memory at once.
    """
    digest = hashlib.md5()
    with open(fp, "rb") as stream:
        while True:
            piece = stream.read(block)
            if piece == b"":
                break
            digest.update(piece)
    return digest.hexdigest()

def ahash_8x8(fp):
    """Compute a 64-bit average hash of the image at `fp`.

    The image is converted to grayscale and resized to 8x8; each pixel contributes
    one bit (1 when brighter than the mean of the 64 pixels) and the bits are packed
    into a single Python int.

    Returns:
        The hash as an int, or None when the file cannot be opened or processed
        (deliberate best-effort: unreadable files are simply skipped by the caller).
        NOTE(review): the packed value can be negative when the highest bit is set;
        that is harmless as long as it is only compared for equality.
    """
    try:
        # The context manager closes the underlying file handle deterministically
        # instead of leaving it to garbage collection.
        with Image.open(fp) as src:
            img = src.convert("L").resize((8,8), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.float32)
        return int((arr > arr.mean()).flatten().dot(1<<np.arange(64)))
    except Exception:
        return None

def scan_duplicates(train_files, test_files):
    """Look for duplicate images and write a CSV report when any are found.

    Three kinds of findings are recorded:
      * "exact_train_test"  - a test file byte-identical (same MD5) to a train file
      * "exact_train_train" - byte-identical files inside the train set; every extra
        copy is reported against the first file of its group
      * "ahash_train_test"  - a test file whose 8x8 average hash equals a train file's

    The report goes to `{CFG.OUT_DIR}/leakage_report.csv` with columns
    type / train_image / test_image. For "exact_train_train" rows the `test_image`
    column holds the second *train* file, since the column layout is shared.

    Args:
        train_files: paths of the training images.
        test_files: paths of the test images.

    Returns:
        None. The outcome is reported through the CSV and a console message.
    """
    out = []

    # --- exact duplicates: group train files by content hash ---
    md5_map = {}
    for p in train_files:
        md5_map.setdefault(md5_bytes(p), []).append(p)
    for p in test_files:
        for t in md5_map.get(md5_bytes(p), ()):
            out.append(("exact_train_test", os.path.basename(t), os.path.basename(p)))
    for group in md5_map.values():
        # Report every copy beyond the first, not only the first pair of the group.
        first = os.path.basename(group[0])
        for dup in group[1:]:
            out.append(("exact_train_train", first, os.path.basename(dup)))

    # --- near-exact duplicates: identical average hash, train <-> test only ---
    ah_map = {}
    for p in train_files:
        h = ahash_8x8(p)
        if h is not None:
            ah_map.setdefault(h, []).append(p)
    for p in test_files:
        h = ahash_8x8(p)
        if h is None:
            continue
        for t in ah_map.get(h, ()):
            out.append(("ahash_train_test", os.path.basename(t), os.path.basename(p)))

    if out:
        leak_path = f"{CFG.OUT_DIR}/leakage_report.csv"
        pd.DataFrame(out, columns=["type","train_image","test_image"]).to_csv(leak_path, index=False)
        rprint(f"[yellow]Leakage/duplicates report written to {leak_path}[/yellow]")
    else:
        rprint("[green]No exact/near-exact duplicates detected.[/green]")

scan_duplicates(train_files, test_files)

# -------------------------
# Augmentations, Dataset, Model, Losses, Scheduler
# -------------------------
# Per-channel mean / std handed to A.Normalize in both pipelines below.
IMAGENET_MEAN, IMAGENET_STD = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# Training pipeline. The transforms run in list order.
train_tfms = A.Compose([
    # Resize so the longest side equals img_size, then pad to a square by reflection.
    A.LongestMaxSize(max_size=CFG.img_size),
    A.PadIfNeeded(CFG.img_size, CFG.img_size, border_mode=cv2.BORDER_REFLECT),
    # Contrast enhancement, applied to 60% of training samples (always on at validation).
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8,8), p=0.6),
    # NOTE(review): the two positional size arguments rely on the albumentations version
    # pinned at the top of the notebook — re-check this call before changing the pin.
    A.RandomResizedCrop(CFG.img_size, CFG.img_size, scale=(0.85, 1.0), ratio=(0.9,1.1), p=0.8),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.Rotate(limit=10, p=0.5, border_mode=cv2.BORDER_REFLECT),
    A.MultiplicativeNoise(p=0.3),
    A.RandomBrightnessContrast(p=0.5),
    # Normalize and convert to a tensor last.
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Validation / test pipeline: same geometry and normalization, no random augmentation;
# CLAHE is applied to every image (p=1.0).
valid_tfms = A.Compose([
    A.LongestMaxSize(max_size=CFG.img_size),
    A.PadIfNeeded(CFG.img_size, CFG.img_size, border_mode=cv2.BORDER_REFLECT),
    A.CLAHE(clip_limit=2.0, tile_grid_size=(8,8), p=1.0),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Hardened Dataset (handles corrupt / grayscale)
class USDataset(Dataset):
    """Image dataset backed by a DataFrame with `path` and `y` columns.

    Each item is `(image_tensor, label_tensor)`; the label is a float32 tensor of
    shape (1,). Unreadable files are replaced by an all-zero image so one corrupt
    file does not abort an epoch.
    """

    def __init__(self, df, transform):
        self.transform = transform
        # Positional indexing in __getitem__ needs a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        pixels = cv2.imread(record.path, cv2.IMREAD_UNCHANGED)
        if pixels is None:
            # cv2.imread signals failure with None; substitute a black grayscale image.
            pixels = np.zeros((CFG.img_size, CFG.img_size), dtype=np.uint8)
        # 2-D arrays are grayscale; everything else is treated as BGR.
        colour_code = cv2.COLOR_GRAY2RGB if pixels.ndim == 2 else cv2.COLOR_BGR2RGB
        pixels = cv2.cvtColor(pixels, colour_code)
        tensor = self.transform(image=pixels)["image"]
        target = torch.tensor([int(record.y)], dtype=torch.float32)
        return tensor, target

class PCOSNet(nn.Module):
    """Binary classifier: a timm backbone with a single-logit head.

    Args:
        backbone: timm model name (defaults to CFG.backbone, read at definition time).
        pretrained: whether timm should load pretrained weights.
    """

    def __init__(self, backbone=CFG.backbone, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=1,
            in_chans=3,
            drop_rate=0.2,
        )

    def forward(self, x):
        # The backbone is built with num_classes=1; drop that axis so one logit per sample remains.
        logits = self.backbone(x)
        return logits.squeeze(1)

class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, averaged over all elements.

    Args:
        alpha: weight applied to positive targets; negatives get `1 - alpha`.
        gamma: focusing exponent applied to `1 - p_t`.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        per_sample_bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        # exp(-BCE) recovers the probability the model assigns to the true class.
        prob_true_class = torch.exp(-per_sample_bce)
        class_weight = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_term = (1 - prob_true_class).pow(self.gamma)
        return (class_weight * focal_term * per_sample_bce).mean()

def build_cosine_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR schedule: linear warm-up followed by a half-cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 over `num_warmup_steps`,
    then follows a half cosine down to 0 at `num_training_steps`. Beyond that point
    the multiplier stays at 0; without the clamp the cosine would swing back up if
    the scheduler were stepped more often than planned.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_steps: number of scheduler steps spent warming up.
        num_training_steps: total number of scheduler steps planned.

    Returns:
        A `torch.optim.lr_scheduler.LambdaLR` instance.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        # Clamp so stepping past the planned horizon keeps the LR at its floor.
        progress = min(1.0, progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# -------------------------
# Metrics & Calibration (bias-free temperature)
# -------------------------
def fit_temperature(logits, targets):
    """Fit a single temperature T so that `sigmoid(logits / T)` is better calibrated.

    A bias-free 1x1 linear layer (weight w, initialised to 1) is fitted with L-BFGS
    to minimise binary cross-entropy on `w * logits`; the temperature is `1 / w`.

    Args:
        logits: 1-D array-like of raw model logits.
        targets: 1-D array-like of 0/1 labels with the same length.

    Returns:
        The temperature as a float, never below 1e-3. If the fitted weight is not a
        finite positive number (which would mean a zero, negative or undefined
        temperature), 1.0 is returned so the logits are left uncalibrated rather
        than being sharpened to the extreme or causing a division by zero.
    """
    logits = torch.as_tensor(logits, dtype=torch.float32, device=device).unsqueeze(1)
    targets = torch.as_tensor(targets, dtype=torch.float32, device=device).unsqueeze(1)
    model = nn.Linear(1, 1, bias=False).to(device)
    with torch.no_grad():
        model.weight.fill_(1.0)
    opt = torch.optim.LBFGS(model.parameters(), lr=0.01, max_iter=50)

    def _closure():
        # L-BFGS re-evaluates the loss several times within one step() call.
        opt.zero_grad()
        loss = F.binary_cross_entropy_with_logits(model(logits), targets)
        loss.backward()
        return loss

    opt.step(_closure)
    w = model.weight.item()
    if not math.isfinite(w) or w <= 0.0:
        # No valid positive temperature exists for this weight; fall back to identity.
        return 1.0
    return max(1e-3, 1.0 / w)  # return T

def compute_metrics(y_true, y_prob, thr=0.5):
    """Score probabilistic predictions against binary labels at threshold `thr`.

    A sample is predicted positive (label 1) when `y_prob >= thr`.

    Returns:
        dict with AUC (0.5 when `y_true` holds a single class), F1, ACC,
        SPEC_unhealthy (true-negative rate), RECALL_healthy (true-positive rate)
        and J (recall + specificity - 1).
    """
    y_pred = (y_prob >= thr).astype(int)
    if len(np.unique(y_true)) > 1:
        auc_value = roc_auc_score(y_true, y_prob)
    else:
        auc_value = 0.5
    f1_value = f1_score(y_true, y_pred)
    acc_value = accuracy_score(y_true, y_pred)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()
    # The small epsilon keeps the ratios defined when a class is absent.
    specificity = tn / (tn+fp+1e-9)
    recall = tp / (tp+fn+1e-9)
    return {
        "AUC": auc_value,
        "F1": f1_value,
        "ACC": acc_value,
        "SPEC_unhealthy": specificity,
        "RECALL_healthy": recall,
        "J": recall+specificity-1.0,
    }

def best_threshold(y_true, y_prob, metric="J", n=1001):
    """Grid-search the decision threshold that maximises one `compute_metrics` entry.

    `n` evenly spaced thresholds in [0, 1] are tried; on ties the lowest threshold
    wins. Returns `(threshold, metric_value)` as floats, or `(0.5, -1.0)` if no
    threshold scores above -1.0.
    """
    candidates = np.linspace(0.0, 1.0, n)
    scores = [compute_metrics(y_true, y_prob, thr=t)[metric] for t in candidates]
    winner_t, winner_v = 0.5, -1.0
    for t, v in zip(candidates, scores):
        if v > winner_v:
            winner_t, winner_v = t, v
    return float(winner_t), float(winner_v)

def find_threshold_for_spec(y_true, y_prob, target_spec=0.80):
    """Return the lowest threshold on a 1001-point grid over [0, 1] whose
    SPEC_unhealthy is at least `target_spec`; 0.5 if no grid point qualifies."""
    grid = np.linspace(0, 1, 1001)
    hit = next(
        (t for t in grid
         if compute_metrics(y_true, y_prob, thr=t)['SPEC_unhealthy'] >= target_spec),
        None,
    )
    if hit is None:
        return 0.5
    return float(hit)

# -------------------------
# DataLoaders & Helpers
# -------------------------
def get_loaders(train_df, valid_df):
    """Build the training and validation DataLoaders for one fold.

    Training samples are drawn with replacement by a WeightedRandomSampler whose
    per-sample weight is the inverse frequency of the sample's class. Validation is
    iterated in order with twice the training batch size.

    Returns:
        `(dl_train, dl_valid)`.
    """
    per_class = train_df.y.value_counts()
    sample_w = (1.0 / train_df.y.map(per_class)).values
    balanced = WeightedRandomSampler(weights=sample_w, num_samples=len(train_df), replacement=True)

    # Settings shared by both loaders.
    shared = dict(
        num_workers=CFG.num_workers,
        pin_memory=True,
        worker_init_fn=_seed_worker,
        generator=_gen,
    )
    dl_train = DataLoader(
        USDataset(train_df, train_tfms),
        batch_size=CFG.batch_size,
        sampler=balanced,
        **shared,
    )
    dl_valid = DataLoader(
        USDataset(valid_df, valid_tfms),
        batch_size=CFG.batch_size*2,
        shuffle=False,
        **shared,
    )
    return dl_train, dl_valid

def load_ckpt_into_base(ckpt, backbone, device):
    """Rebuild a plain (non-DataParallel) PCOSNet and load checkpoint weights into it.

    Args:
        ckpt: checkpoint dict holding the weights under the "state_dict" key.
        backbone: timm model name used to construct the network.
        device: device the model is moved to before loading.

    Returns:
        The model with weights loaded (strict key matching).
    """
    model = PCOSNet(backbone, pretrained=False).to(device)
    prefix = "module."
    # Strip the DataParallel prefix only where it really is a prefix; an unanchored
    # str.replace would also rewrite a key that merely contains "module." further in.
    sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt["state_dict"].items()
    }
    model.load_state_dict(sd, strict=True)
    return model

def tta_logits(model, xb, do_transpose=True):
    """Average the model output over flip (and optionally transpose) views of a batch.

    Views: identity, horizontal flip, vertical flip, both flips. With
    `do_transpose=True` the same four views of the batch with its last two axes
    swapped are added, for eight forward passes in total.

    Returns:
        The mean of the model outputs over all views.
    """
    def _four_flips(batch):
        # Order matters only for bit-identical float summation; keep it fixed.
        return [
            model(batch),
            model(torch.flip(batch, dims=[-1])),
            model(torch.flip(batch, dims=[-2])),
            model(torch.flip(torch.flip(batch, [-1]), [-2])),
        ]

    outs = _four_flips(xb)
    if do_transpose:
        outs.extend(_four_flips(xb.transpose(-1, -2).contiguous()))
    return sum(outs) / len(outs)

# -------------------------
# K-fold Training
# -------------------------
from sklearn.model_selection import StratifiedKFold
# Stratified K-fold cross-validation: one model per fold, best epoch kept by validation
# AUC, then out-of-fold (OOF) logits and a per-fold temperature are collected.
skf = StratifiedKFold(n_splits=CFG.folds, shuffle=True, random_state=CFG.seed)
# Shuffle the rows once with a fixed seed; note this rebinds `df`, so re-running only
# this section reshuffles an already shuffled frame.
df = df.sample(frac=1.0, random_state=CFG.seed).reset_index(drop=True)
# oof_logits[i] receives the logit predicted for row i by the fold that held it out.
oof_logits, oof_y = np.zeros(len(df), dtype=np.float32), df.y.values.astype(np.float32)
# One fitted temperature per fold, in fold order (read by position at inference time).
fold_Ts = []

for fold, (tr_idx, va_idx) in enumerate(skf.split(df, df.y), 1):
    rprint(f"[bold cyan]\n========= FOLD {fold}/{CFG.folds} =========[/bold cyan]")
    tr_df, va_df = df.iloc[tr_idx], df.iloc[va_idx]
    dl_train, dl_valid = get_loaders(tr_df, va_df)

    # Fresh pretrained model for every fold.
    model = PCOSNet(CFG.backbone, pretrained=True).to(device)
    if USE_DP: model = nn.DataParallel(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    # The schedule is expressed in optimizer steps (batches), hence the len(dl_train) factor.
    scheduler = build_cosine_with_warmup(optimizer, len(dl_train)*CFG.warmup_epochs, len(dl_train)*CFG.epochs)

    # NOTE(review): get_loaders already balances classes with a WeightedRandomSampler;
    # weighting the loss by class share as well may over-correct — confirm this is intended.
    if CFG.use_focal:
        alpha_pos = (tr_df.y == 0).mean()  # weight positives by share of negatives
        criterion = FocalLoss(alpha=float(alpha_pos), gamma=CFG.focal_gamma)
    else:
        pos_weight = torch.tensor([(tr_df.y==0).sum()/(tr_df.y==1).sum()], device=device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    scaler = torch.amp.GradScaler('cuda', enabled=CFG.use_amp)
    # `patience` counts down the remaining epochs allowed without an AUC improvement.
    best_auc, best_path, patience = -1.0, f"{CFG.OUT_DIR}/model_fold{fold}.pt", CFG.patience

    for epoch in range(1, CFG.epochs + 1):
        model.train(); tr_loss = 0.0
        for xb, yb in dl_train:
            # The dataset yields labels of shape (B, 1); squeeze to (B,) to match the logits.
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True).squeeze(1)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', enabled=CFG.use_amp):
                logits = model(xb); loss = criterion(logits, yb)
            # Scaled backward + optimizer step, then one scheduler step per batch.
            scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update(); scheduler.step()
            tr_loss += loss.item() * xb.size(0)
        # The sampler draws len(tr_df) samples per epoch, so this is the mean per-sample loss.
        tr_loss /= len(tr_df)

        # Per-epoch validation: plain forward pass (no TTA), metrics at the default 0.5 threshold.
        model.eval()
        with torch.no_grad():
            val_logits_epoch = np.concatenate([model(xb.to(device)).detach().cpu().numpy() for xb, _ in dl_valid])
        probs = 1/(1+np.exp(-val_logits_epoch))
        m = compute_metrics(va_df.y.values, probs)
        rprint(f"E{epoch:02d}: train_loss={tr_loss:.4f} AUC={m['AUC']:.4f} F1={m['F1']:.4f} ACC={m['ACC']:.4f}")

        if m["AUC"] > best_auc:
            # New best epoch: reset patience and save the unwrapped (non-DataParallel) weights.
            best_auc, patience = m["AUC"], CFG.patience
            state_dict = model.module.state_dict() if USE_DP else model.state_dict()
            torch.save({"state_dict": state_dict, "auc": float(best_auc), "backbone": CFG.backbone}, best_path)
        else:
            patience -= 1
            if patience < 1:
                rprint("[yellow]Early stopping[/yellow]"); break

    # Reload the best checkpoint into a plain model and predict the held-out fold with TTA.
    ckpt = torch.load(best_path, map_location="cpu", weights_only=False)
    model = load_ckpt_into_base(ckpt, ckpt["backbone"], device)
    model.eval()
    with torch.no_grad():
        # dl_valid is not shuffled, so the concatenated outputs line up with va_idx.
        oof_logits[va_idx] = np.concatenate([tta_logits(model, xb.to(device)).detach().cpu().numpy() for xb, _ in dl_valid])

    # NOTE(review): the temperature is fitted on the same fold that drove early stopping,
    # so the resulting OOF calibration is likely optimistic.
    fold_T = fit_temperature(oof_logits[va_idx], oof_y[va_idx])
    fold_Ts.append(fold_T)
    rprint(f"[bold yellow]Fold {fold} fitted T: {fold_T:.3f}[/bold yellow]")

# -------------------------
# OOF Analysis & Thresholding + Visualizations
# -------------------------
# Fit one temperature on all out-of-fold logits and turn them into calibrated probabilities.
global_t_temp = fit_temperature(oof_logits, oof_y)
oof_probs_cal = 1/(1+np.exp(-(oof_logits / max(1e-3, global_t_temp))))
rprint(f"[bold magenta]\nOOF Calibrated:[/bold magenta] {compute_metrics(oof_y, oof_probs_cal)}")

# Two operating points chosen on the calibrated OOF probabilities:
#   thr_j   - threshold maximising J (recall + specificity - 1)
#   thr_s80 - lowest grid threshold with SPEC_unhealthy >= 0.80 (the function's default target)
# NOTE(review): both thresholds live on the scale produced by this *global* temperature.
thr_j, _ = best_threshold(oof_y, oof_probs_cal, metric="J")
thr_s80 = find_threshold_for_spec(oof_y, oof_probs_cal)
rprint(f"[bold blue]OOF-cal @ J* thr {thr_j:.3f} ->[/bold blue] {compute_metrics(oof_y, oof_probs_cal, thr=thr_j)}")
rprint(f"[bold blue]OOF-cal @ SPEC>=0.80 thr {thr_s80:.3f} ->[/bold blue] {compute_metrics(oof_y, oof_probs_cal, thr=thr_s80)}")

# --- Visualization pack ---
import matplotlib.pyplot as plt

# ROC curve of the calibrated out-of-fold probabilities, with the chance diagonal.
fpr, tpr, _ = roc_curve(oof_y, oof_probs_cal)
roc_auc = auc(fpr, tpr)
roc_fig, roc_ax = plt.subplots()
roc_ax.plot(fpr, tpr, label=f"AUC={roc_auc:.3f}")
roc_ax.plot([0,1],[0,1],'--')
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.legend()
roc_ax.set_title("OOF ROC (Calibrated)")
roc_fig.tight_layout()
roc_fig.savefig(f"{CFG.OUT_DIR}/oof_roc.png", dpi=150)
plt.close(roc_fig)

# Precision-recall curve of the same probabilities.
prec, rec, _ = precision_recall_curve(oof_y, oof_probs_cal)
pr_fig, pr_ax = plt.subplots()
pr_ax.plot(rec, prec)
pr_ax.set_xlabel("Recall")
pr_ax.set_ylabel("Precision")
pr_ax.set_title("OOF Precision–Recall (Calibrated)")
pr_fig.tight_layout()
pr_fig.savefig(f"{CFG.OUT_DIR}/oof_pr.png", dpi=150)
plt.close(pr_fig)

# Reliability
def reliability_plot(y_true, y_prob, path, n_bins=15):
    """Save a reliability diagram: observed positive rate per predicted-probability bin.

    [0, 1] is split into `n_bins` equal-width bins. For every non-empty bin a point
    is drawn at (bin centre, fraction of samples in the bin whose label is 1),
    together with the diagonal of perfect calibration.

    Bins are half-open `[lo, hi)` except the last one, which is closed on the right
    so that predictions exactly equal to 1.0 are counted instead of silently dropped.

    Args:
        y_true: array-like of 0/1 labels.
        y_prob: array-like of predicted probabilities, same length as `y_true`.
        path: output image path.
        n_bins: number of equal-width bins.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    bins = np.linspace(0,1,n_bins+1); xs, ys = [], []
    for i in range(n_bins):
        lo, hi = bins[i], bins[i+1]
        # Only the final bin includes its right edge.
        below_hi = (y_prob <= hi) if i == n_bins - 1 else (y_prob < hi)
        sel = (y_prob >= lo) & below_hi
        if sel.sum()==0: continue
        xs.append((lo+hi)/2); ys.append((y_true[sel]==1).mean())
    plt.figure(); plt.plot([0,1],[0,1],'--'); plt.scatter(xs, ys)
    plt.xlabel("Confidence"); plt.ylabel("Empirical Accuracy")
    plt.title("Reliability Diagram (OOF Calibrated)"); plt.tight_layout()
    plt.savefig(path, dpi=150); plt.close()

reliability_plot(oof_y, oof_probs_cal, f"{CFG.OUT_DIR}/oof_reliability.png")

# Confusion matrices
def save_cm(y_true, y_prob, thr, name):
    """Threshold `y_prob` at `thr` and save the confusion matrix as
    `{CFG.OUT_DIR}/{name}.png`; `name` is also used as the plot title."""
    y_pred = (y_prob >= thr).astype(int)
    display = ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=[0,1],
        display_labels=["Unhealthy","Healthy"],
        cmap="Blues",
    )
    fig = display.figure_
    plt.title(name)
    plt.tight_layout()
    out_file = f"{CFG.OUT_DIR}/{name}.png"
    fig.savefig(out_file, dpi=150)
    plt.close(fig)

# One confusion matrix per operating point selected on the OOF predictions.
for cm_thr, cm_name in ((thr_j, "cm_oof_Jstar"), (thr_s80, "cm_oof_spec80")):
    save_cm(oof_y, oof_probs_cal, cm_thr, cm_name)
rprint("[green]Saved ROC, PR, reliability, and confusion matrix plots to OUT_DIR.[/green]")

# -------------------------
# Inference on Test (per-fold temperature, TTA, averaged)
# -------------------------
# Test loader: same deterministic transforms as validation. The `y` column is a dummy 0
# because USDataset always reads a label; it is never used below.
test_loader = DataLoader(
    USDataset(pd.DataFrame({"path": test_files, "y": 0}), valid_tfms),
    batch_size=CFG.batch_size*2, shuffle=False,
    num_workers=CFG.num_workers, pin_memory=True,
    worker_init_fn=_seed_worker, generator=_gen
)

# Collect temperature-scaled TTA logits from every fold checkpoint that exists on disk.
fold_logits_list = []
for fold in range(1, CFG.folds + 1):
    best_path = f"{CFG.OUT_DIR}/model_fold{fold}.pt"
    if not os.path.exists(best_path): continue
    # NOTE(review): weights_only=False unpickles arbitrary objects; acceptable only because
    # these checkpoints were written by this notebook.
    ckpt = torch.load(best_path, map_location=device, weights_only=False)
    model = load_ckpt_into_base(ckpt, ckpt["backbone"], device)
    model.eval()
    logits_all = []
    with torch.no_grad():
        for xb, _ in test_loader:
            logits = tta_logits(model, xb.to(device, non_blocking=True))
            # fold_Ts is indexed by position, so it must come from the training loop of
            # this same session with all folds present.
            logits = logits / max(1e-3, fold_Ts[fold-1])  # per-fold T
            logits_all.append(logits.detach().cpu().numpy())
    fold_logits_list.append(np.concatenate(logits_all))

if not fold_logits_list: raise RuntimeError("No fold checkpoints found.")
# Ensemble by averaging the scaled logits across folds, then apply the sigmoid once.
test_logits = np.mean(np.stack(fold_logits_list, axis=0), axis=0)
test_probs = 1/(1+np.exp(-test_logits))

# NOTE(review): thr_j / thr_s80 were selected on OOF probabilities calibrated with the
# single global temperature, whereas these test probabilities use per-fold temperatures;
# the two scales may not match — confirm the thresholds transfer as intended.
sub = pd.DataFrame({"image": [os.path.basename(p) for p in test_files], "prob_healthy": test_probs})
sub["pred_J"] = (sub.prob_healthy >= thr_j).astype(int)
sub["pred_S80"] = (sub.prob_healthy >= thr_s80).astype(int)
sub_path = f"{CFG.OUT_DIR}/submission.csv"
sub.sort_values("image").to_csv(sub_path, index=False)
rprint(f"[bold green]Saved final submission to {sub_path}[/bold green]")

# Release Python garbage and cached CUDA memory at the end of the run.
gc.collect(); torch.cuda.empty_cache()


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.6/214.6 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25h

  check_for_updates()


albucore: 0.0.20 albumentations: 1.4.16


model.safetensors:   0%|          | 0.00/178M [00:00<?, ?B/s]