# UHViT: Complete Training Pipeline for IF 5+ Journal
## Skin Lesion Classification with Uncertainty Quantification


In [1]:
#Install Dependencies (Fixed for RunPod)

# Upgrade all packages to numpy 2.x compatible versions
!pip install -q --upgrade numpy==2.0.2
!pip install -q --upgrade matplotlib>=3.8.0 scipy>=1.12.0 scikit-learn>=1.4.0
!pip install -q --upgrade timm>=1.0.0
!pip install -q --upgrade wandb>=0.17.0  # Fixed: needs numpy 2.x compatible version

# Now install the rest
!pip install -q albumentations>=1.3.1
!pip install -q kaggle==1.5.16
!pip install -q grad-cam
!pip install -q pandas>=2.0.3
!pip install -q seaborn>=0.13.0
!pip install -q tqdm>=4.66.1
!pip install -q ptflops>=0.7
!pip install -q gdown
!pip install -q opencv-python-headless


In [2]:
# CELL 1B: EXTRA DEPENDENCY FOR EXTERNAL DATASET
!pip install -q medmnist
print("✓ medmnist installed")

✓ medmnist installed


In [3]:
# CELL 2: IMPORTS
import os, sys, json, random, time, warnings
from pathlib import Path
from collections import Counter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import timm
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, confusion_matrix, roc_auc_score, precision_recall_fscore_support, roc_curve, auc
from sklearn.calibration import calibration_curve
from scipy import stats
import wandb
warnings.filterwarnings('ignore')
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}")



PyTorch: 2.8.0+cu128, CUDA: True
GPU: NVIDIA GeForce RTX 5090


In [4]:
# CELL 3: CONFIGURATION
CONFIG = {
    'KAGGLE_USERNAME': 'sushovanchaudhury',  # UPDATE THIS
    'KAGGLE_KEY': '6ad1238ff7f05a1d17687883ac1cb80a',            # UPDATE THIS
    'WANDB_API_KEY': '07ac9d1305afd5973d9022f150eb134c3b77ca7b',          # UPDATE THIS
    'seed': 42, 'n_folds': 5, 'epochs': 30, 'batch_size': 32, 'img_size': 224,
    'num_workers': 4, 'lr': 1e-4, 'weight_decay': 1e-4, 'mc_dropout_samples': 10,
    'num_classes': 8, 'dropout_rate': 0.3,
    'swin_model': 'swin_tiny_patch4_window7_224', 'efficientnet_model': 'efficientnet_b3',
    'class_names': ['AK', 'BCC', 'BKL', 'DF', 'MEL', 'NV', 'SCC', 'VASC'],
    'data_dir': './data', 'output_dir': './outputs', 'checkpoint_dir': './checkpoints',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}
for d in [CONFIG['data_dir'], CONFIG['output_dir'], CONFIG['checkpoint_dir']]:
    Path(d).mkdir(parents=True, exist_ok=True)
def set_seed(seed):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed); torch.backends.cudnn.deterministic = True
set_seed(CONFIG['seed'])
print(f"✓ Config set. Device: {CONFIG['device']}")



✓ Config set. Device: cuda


In [5]:
# CELL 4: SETUP APIs
os.makedirs(os.path.expanduser('~/.kaggle'), exist_ok=True)
with open(os.path.expanduser('~/.kaggle/kaggle.json'), 'w') as f:
    json.dump({"username": CONFIG['KAGGLE_USERNAME'], "key": CONFIG['KAGGLE_KEY']}, f)
os.chmod(os.path.expanduser('~/.kaggle/kaggle.json'), 0o600)
os.environ['WANDB_API_KEY'] = CONFIG['WANDB_API_KEY']
wandb.login(key=CONFIG['WANDB_API_KEY'])
print("✓ APIs configured")



[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msushovan-chaudhury[0m ([33msushovan-chaudhury-bits-pilani[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✓ APIs configured


In [6]:
!kaggle datasets list -s pad-ufes -p 5

No datasets found


In [7]:
# CELL 5: DOWNLOAD DATASETS  (REPLACE THIS ENTIRE CELL)

from pathlib import Path

isic_path = Path(CONFIG['data_dir']) / 'isic-2019'
if not isic_path.exists():
    !kaggle datasets download -d andrewmvd/isic-2019 -p {CONFIG['data_dir']}
    !unzip -q {CONFIG['data_dir']}/isic-2019.zip -d {CONFIG['data_dir']}/isic-2019
    !rm {CONFIG['data_dir']}/isic-2019.zip

ham_path = Path(CONFIG['data_dir']) / 'ham10000'
if not ham_path.exists():
    !kaggle datasets download -d kmader/skin-cancer-mnist-ham10000 -p {CONFIG['data_dir']}
    !unzip -q {CONFIG['data_dir']}/skin-cancer-mnist-ham10000.zip -d {CONFIG['data_dir']}/ham10000
    !rm {CONFIG['data_dir']}/skin-cancer-mnist-ham10000.zip



print("✓ Datasets ready (ISIC2019, HAM10000)")



Downloading isic-2019.zip to ./data
100%|██████████████████████████████████████▉| 9.10G/9.10G [00:59<00:00, 126MB/s]
100%|███████████████████████████████████████| 9.10G/9.10G [00:59<00:00, 165MB/s]
Downloading skin-cancer-mnist-ham10000.zip to ./data
100%|██████████████████████████████████████▉| 5.18G/5.20G [00:26<00:00, 233MB/s]
100%|███████████████████████████████████████| 5.20G/5.20G [00:26<00:00, 209MB/s]
✓ Datasets ready (ISIC2019, HAM10000)


In [8]:
# FIX: find correct ISIC2019 CSV + image directory (robust)
from pathlib import Path
import pandas as pd

isic_base = Path(CONFIG["data_dir"]) / "isic-2019"

# 1) Find a CSV that has an 'image' column and class columns
candidate_csvs = list(isic_base.rglob("*.csv"))
best_csv = None
for p in candidate_csvs:
    try:
        df = pd.read_csv(p, nrows=5)
        if "image" in df.columns:
            best_csv = p
            break
    except:
        pass

print("Chosen CSV:", best_csv)
train_df = pd.read_csv(best_csv)

# 2) Find an images directory that contains lots of .jpg files
candidate_dirs = [d for d in isic_base.rglob("*") if d.is_dir()]
def jpg_count(d):
    return len(list(d.glob("*.jpg"))) + len(list(d.glob("*.JPG"))) + len(list(d.glob("*.jpeg"))) + len(list(d.glob("*.png")))

best_img_dir = max(candidate_dirs, key=jpg_count)
print("Chosen image dir:", best_img_dir, "| image files:", jpg_count(best_img_dir))

train_img_dir = best_img_dir
print("train_df columns:", train_df.columns[:15])
print("train_df size:", len(train_df))

Chosen CSV: data/isic-2019/ISIC_2019_Training_GroundTruth.csv
Chosen image dir: data/isic-2019/ISIC_2019_Training_Input/ISIC_2019_Training_Input | image files: 25331
train_df columns: Index(['image', 'MEL', 'NV', 'BCC', 'AK', 'BKL', 'DF', 'VASC', 'SCC', 'UNK'], dtype='object')
train_df size: 25331


In [9]:
def get_train_transforms(img_size):
    return A.Compose([
        A.RandomResizedCrop(size=(img_size, img_size), scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=0.5),
        A.OneOf([
            A.GaussNoise(var_limit=(10, 50)),
            A.GaussianBlur(blur_limit=(3, 7))
        ], p=0.3),
        A.OneOf([
            A.CLAHE(clip_limit=4.0),
            A.RandomBrightnessContrast()
        ], p=0.5),
        A.HueSaturationValue(p=0.5),
        A.CoarseDropout(
            num_holes_range=(1, 8),
            hole_height_range=(img_size//16, img_size//8),
            hole_width_range=(img_size//16, img_size//8),
            p=0.3
        ),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def get_val_transforms(img_size):
    return A.Compose([
        A.Resize(height=img_size, width=img_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])


In [10]:
import albumentations as A
import albumentations
print("Albumentations version:", albumentations.__version__)

# instantiate to force validation now
t = get_train_transforms(CONFIG["img_size"])
print("Train transforms OK:", t)

Albumentations version: 2.0.8
Train transforms OK: Compose([
  RandomResizedCrop(p=1.0, area_for_downscale=None, interpolation=1, mask_interpolation=0, ratio=(0.75, 1.3333333333333333), scale=(0.8, 1.0), size=(224, 224)),
  HorizontalFlip(p=0.5),
  VerticalFlip(p=0.5),
  RandomRotate90(p=0.5),
  ShiftScaleRotate(p=0.5, shift_limit_x=(-0.1, 0.1), shift_limit_y=(-0.1, 0.1), scale_limit=(-0.09999999999999998, 0.10000000000000009), rotate_limit=(-45.0, 45.0), interpolation=1, border_mode=0, fill=0.0, fill_mask=0.0, rotate_method='largest_box', mask_interpolation=0),
  OneOf([
    GaussNoise(p=0.5, mean_range=(0.0, 0.0), noise_scale_factor=1.0, per_channel=True, std_range=(0.2, 0.44)),
    GaussianBlur(p=0.5, blur_limit=(3, 7), sigma_limit=(0.5, 3.0)),
  ], p=0.3),
  OneOf([
    CLAHE(p=0.5, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8)),
    RandomBrightnessContrast(p=0.5, brightness_by_max=True, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), ensure_safe_range=False),
  ], p=0.5)

In [11]:
# CELL 7: DATASET CLASS
class ISIC2019Dataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row['image']
        img_path = None
        for ext in ['.jpg', '.jpeg', '.png', '.JPG']:
            p = self.img_dir / f"{img_name}{ext}"
            if p.exists(): img_path = p; break
        if not img_path:
            for f in self.img_dir.glob(f"{img_name}*"): img_path = f; break
        image = np.array(Image.open(img_path).convert('RGB'))
        label = 0
        for i, cls in enumerate(CONFIG['class_names']):
            if cls in row and row[cls] == 1.0: label = i; break
        if self.transform: image = self.transform(image=image)['image']
        return image, label

class HAM10000Dataset(Dataset):
    def __init__(self, csv_path, img_dirs, transform=None):
        self.df = pd.read_csv(csv_path)
        self.img_dirs = [Path(d) for d in img_dirs] if isinstance(img_dirs, list) else [Path(img_dirs)]
        self.transform = transform
        self.label_map = {'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3, 'mel': 4, 'nv': 5, 'vasc': 7}
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = None
        for d in self.img_dirs:
            for ext in ['.jpg', '.jpeg', '.png']:
                p = d / f"{row['image_id']}{ext}"
                if p.exists(): img_path = p; break
            if img_path: break
        image = np.array(Image.open(img_path).convert('RGB'))
        label = self.label_map.get(row['dx'], 0)
        if self.transform: image = self.transform(image=image)['image']
        return image, label
print("✓ Dataset classes defined")



✓ Dataset classes defined


In [12]:
# CELL 8: LOAD DATA  (REPLACE THIS ENTIRE CELL)

from pathlib import Path
import pandas as pd

isic_base = Path(CONFIG["data_dir"]) / "isic-2019"

# 1) Pick the correct LABEL CSV (GroundTruth), not the Metadata CSV
csv_candidates = list(isic_base.rglob("*.csv"))
gt_candidates = [p for p in csv_candidates if ("ground" in p.name.lower()) or ("truth" in p.name.lower())]

if len(gt_candidates) == 0:
    raise FileNotFoundError(
        "GroundTruth CSV not found. Expected something like 'ISIC_2019_Training_GroundTruth.csv' "
        "inside data/isic-2019/. Please check your unzip output."
    )

# Prefer the one that contains 'Training' if multiple exist
gt_candidates_sorted = sorted(gt_candidates, key=lambda p: ("training" not in p.name.lower(), len(p.name)))
train_csv = gt_candidates_sorted[0]
train_df = pd.read_csv(train_csv)

print("Using label CSV:", train_csv)
print("train_df size:", len(train_df))

# 2) Find the correct IMAGE directory (folder with most jpg/png)
candidate_dirs = [d for d in isic_base.rglob("*") if d.is_dir()]

def image_count(d: Path) -> int:
    return (
        len(list(d.glob("*.jpg")))
        + len(list(d.glob("*.JPG")))
        + len(list(d.glob("*.jpeg")))
        + len(list(d.glob("*.png")))
    )

train_img_dir = max(candidate_dirs, key=image_count)
print("Using image dir:", train_img_dir, "| image files:", image_count(train_img_dir))

# 3) Verify the CSV actually contains the class columns
missing_cols = [c for c in CONFIG["class_names"] if c not in train_df.columns]
if len(missing_cols) > 0:
    raise ValueError(
        f"This CSV does not contain required class columns: {missing_cols}\n"
        f"You likely loaded the Metadata CSV by mistake.\n"
        f"Loaded columns (first 20): {list(train_df.columns[:20])}\n"
        f"CSV path used: {train_csv}"
    )

print("✓ Class columns found:", [c for c in CONFIG["class_names"] if c in train_df.columns])
print("Class distribution (sum of one-hot columns):")
for cls in CONFIG["class_names"]:
    print(f"  {cls}: {int(train_df[cls].sum())}")

Using label CSV: data/isic-2019/ISIC_2019_Training_GroundTruth.csv
train_df size: 25331
Using image dir: data/isic-2019/ISIC_2019_Training_Input/ISIC_2019_Training_Input | image files: 25331
✓ Class columns found: ['AK', 'BCC', 'BKL', 'DF', 'MEL', 'NV', 'SCC', 'VASC']
Class distribution (sum of one-hot columns):
  AK: 867
  BCC: 3323
  BKL: 2624
  DF: 239
  MEL: 4522
  NV: 12875
  SCC: 628
  VASC: 253


In [13]:
ds = ISIC2019Dataset(train_df.iloc[:200], train_img_dir, get_train_transforms(CONFIG["img_size"]))
x, y = ds[0]
print("Sample tensor:", x.shape, "label:", y)
print("Unique labels in first 200:", sorted({ds[i][1] for i in range(200)}))

Sample tensor: torch.Size([3, 224, 224]) label: 5
Unique labels in first 200: [4, 5]


In [14]:
# Build labels exactly like your dataset does
labels = []
for i in range(len(train_df)):
    row = train_df.iloc[i]
    found = False
    for j, cls in enumerate(CONFIG["class_names"]):
        if cls in train_df.columns and row[cls] == 1.0:
            labels.append(j)
            found = True
            break
    if not found:
        labels.append(-1)

labels = np.array(labels)
print("Label counts:", np.bincount(labels[labels>=0], minlength=CONFIG["num_classes"]))
print("Missing labels (-1):", np.sum(labels == -1))
print("Unique labels:", np.unique(labels))

Label counts: [  867  3323  2624   239  4522 12875   628   253]
Missing labels (-1): 0
Unique labels: [0 1 2 3 4 5 6 7]


In [15]:
# CELL 9: UHViT MODEL
class GatedFusion(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gate = nn.Sequential(nn.Linear(dim * 2, dim), nn.Sigmoid())
        self.proj = nn.Linear(dim * 2, dim)
    def forward(self, x1, x2):
        concat = torch.cat([x1, x2], dim=-1)
        return self.proj(concat) + self.gate(concat) * x1 + (1 - self.gate(concat)) * x2

class MCDropout(nn.Module):
    def __init__(self, p=0.3): super().__init__(); self.p = p
    def forward(self, x): return F.dropout(x, p=self.p, training=True)

class UHViT(nn.Module):
    def __init__(self, num_classes=8, dropout_rate=0.3):
        super().__init__()
        self.swin = timm.create_model(CONFIG['swin_model'], pretrained=True, num_classes=0)
        self.efficientnet = timm.create_model(CONFIG['efficientnet_model'], pretrained=True, num_classes=0)
        self.swin_proj = nn.Linear(self.swin.num_features, 512)
        self.eff_proj = nn.Linear(self.efficientnet.num_features, 512)
        self.fusion = GatedFusion(512)
        self.mc_dropout = MCDropout(dropout_rate)
        self.classifier = nn.Sequential(nn.LayerNorm(512), nn.Linear(512, 256), nn.GELU(), MCDropout(dropout_rate), nn.Linear(256, num_classes))
    def forward(self, x):
        swin_feat = self.swin_proj(self.swin(x))
        eff_feat = self.eff_proj(self.efficientnet(x))
        return self.classifier(self.mc_dropout(self.fusion(swin_feat, eff_feat)))
model = UHViT(); print(f"✓ UHViT: {sum(p.numel() for p in model.parameters()):,} params"); del model



model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]

✓ UHViT: 40,580,266 params


In [16]:
# CELL 10: ABLATION & SOTA MODELS
class SwinOnly(nn.Module):
    def __init__(self, num_classes=8, dropout_rate=0.3):
        super().__init__()
        self.swin = timm.create_model(CONFIG['swin_model'], pretrained=True, num_classes=0)
        self.classifier = nn.Sequential(nn.LayerNorm(self.swin.num_features), nn.Dropout(dropout_rate), nn.Linear(self.swin.num_features, num_classes))
    def forward(self, x): return self.classifier(self.swin(x))

class EfficientNetOnly(nn.Module):
    def __init__(self, num_classes=8, dropout_rate=0.3):
        super().__init__()
        self.eff = timm.create_model(CONFIG['efficientnet_model'], pretrained=True, num_classes=0)
        self.classifier = nn.Sequential(nn.LayerNorm(self.eff.num_features), nn.Dropout(dropout_rate), nn.Linear(self.eff.num_features, num_classes))
    def forward(self, x): return self.classifier(self.eff(x))

class UHViTConcatFusion(nn.Module):
    def __init__(self, num_classes=8, dropout_rate=0.3):
        super().__init__()
        self.swin = timm.create_model(CONFIG['swin_model'], pretrained=True, num_classes=0)
        self.efficientnet = timm.create_model(CONFIG['efficientnet_model'], pretrained=True, num_classes=0)
        self.swin_proj = nn.Linear(self.swin.num_features, 512)
        self.eff_proj = nn.Linear(self.efficientnet.num_features, 512)
        self.classifier = nn.Sequential(nn.LayerNorm(1024), nn.Dropout(dropout_rate), nn.Linear(1024, num_classes))
    def forward(self, x): return self.classifier(torch.cat([self.swin_proj(self.swin(x)), self.eff_proj(self.efficientnet(x))], dim=-1))

SOTA_MODELS = ['resnet50', 'efficientnet_b3', 'vit_base_patch16_224', 'swin_tiny_patch4_window7_224', 'convnext_tiny', 'densenet121']
def create_sota_model(name, num_classes=8): return timm.create_model(name, pretrained=True, num_classes=num_classes)
print("✓ Ablation & SOTA models defined")



✓ Ablation & SOTA models defined


In [17]:
# CELL 11: LOSS FUNCTIONS
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0): super().__init__(); self.alpha = alpha; self.gamma = gamma
    def forward(self, inputs, targets):
        ce = F.cross_entropy(inputs, targets, weight=self.alpha, reduction='none')
        return (((1 - torch.exp(-ce)) ** self.gamma) * ce).mean()

class ClassBalancedLoss(nn.Module):
    def __init__(self, samples_per_class, num_classes, beta=0.9999, gamma=2.0):
        super().__init__()
        effective_num = 1.0 - np.power(beta, samples_per_class)
        weights = (1.0 - beta) / np.array(effective_num)
        self.weights = torch.tensor(weights / np.sum(weights) * num_classes, dtype=torch.float32)
        self.gamma = gamma
    def forward(self, inputs, targets):
        self.weights = self.weights.to(inputs.device)
        ce = F.cross_entropy(inputs, targets, weight=self.weights, reduction='none')
        return (((1 - torch.exp(-ce)) ** self.gamma) * ce).mean()

def get_class_weights(labels):
    counts = np.maximum(np.bincount(labels, minlength=CONFIG['num_classes']), 1)
    weights = 1.0 / counts
    return torch.tensor(weights / weights.sum() * len(weights), dtype=torch.float32)

def create_balanced_sampler(labels):
    counts = np.maximum(np.bincount(labels, minlength=CONFIG['num_classes']), 1)
    sample_weights = (1.0 / counts)[labels]
    return WeightedRandomSampler(sample_weights, len(labels), replacement=True)
print("✓ Loss functions defined")



✓ Loss functions defined


In [18]:
# CELL 12: TRAINING FUNCTIONS
def train_epoch(model, loader, criterion, optimizer, scheduler, device, epoch):
    model.train()
    running_loss, all_preds, all_labels = 0.0, [], []
    pbar = tqdm(loader, desc=f'Train E{epoch}')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        if scheduler: scheduler.step()
        running_loss += loss.item()
        all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    return running_loss / len(loader), accuracy_score(all_labels, all_preds), balanced_accuracy_score(all_labels, all_preds)

def validate(model, loader, criterion, device):
    model.eval()
    running_loss, all_preds, all_labels, all_probs = 0.0, [], [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Val'):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item()
            all_probs.extend(F.softmax(outputs, dim=1).cpu().numpy())
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    all_preds, all_labels, all_probs = np.array(all_preds), np.array(all_labels), np.array(all_probs)
    return running_loss / len(loader), accuracy_score(all_labels, all_preds), balanced_accuracy_score(all_labels, all_preds), f1_score(all_labels, all_preds, average='macro'), all_preds, all_labels, all_probs

def validate_with_tta(model, loader, device):
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='TTA'):
            images = images.to(device)
            probs = torch.stack([F.softmax(model(images), dim=1), F.softmax(model(torch.flip(images, [3])), dim=1), F.softmax(model(torch.flip(images, [2])), dim=1)]).mean(0)
            all_probs.extend(probs.cpu().numpy()); all_labels.extend(labels.numpy())
    all_probs, all_labels = np.array(all_probs), np.array(all_labels)
    all_preds = all_probs.argmax(axis=1)
    return accuracy_score(all_labels, all_preds), balanced_accuracy_score(all_labels, all_preds), f1_score(all_labels, all_preds, average='macro'), all_preds, all_labels, all_probs
print("✓ Training functions defined")



✓ Training functions defined


In [19]:
# CELL 13: UNCERTAINTY ESTIMATION
def estimate_uncertainty(model, loader, device, n_samples=10):
    model.train()
    all_means, all_uncertainties, all_labels = [], [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Uncertainty'):
            images = images.to(device)
            outputs = torch.stack([F.softmax(model(images), dim=1) for _ in range(n_samples)])
            all_means.extend(outputs.mean(0).cpu().numpy())
            all_uncertainties.extend(outputs.std(0).mean(1).cpu().numpy())
            all_labels.extend(labels.numpy())
    return np.array(all_means), np.array(all_uncertainties), np.array(all_labels)

def plot_uncertainty_analysis(uncertainties, predictions, labels, save_path):
    correct = (predictions == labels)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    axes[0].hist(uncertainties[correct], bins=30, alpha=0.7, label='Correct', density=True)
    axes[0].hist(uncertainties[~correct], bins=30, alpha=0.7, label='Incorrect', density=True)
    axes[0].legend(); axes[0].set_title('Uncertainty Distribution')
    thresholds = np.percentile(uncertainties, np.arange(0, 100, 5))
    accs, coverages = [], []
    for t in thresholds:
        mask = uncertainties <= t
        if mask.sum() > 0: accs.append(accuracy_score(labels[mask], predictions[mask])); coverages.append(mask.mean())
    axes[1].plot(coverages, accs, 'o-'); axes[1].set_xlabel('Coverage'); axes[1].set_ylabel('Accuracy'); axes[1].set_title('Accuracy vs Coverage')
    class_unc = [uncertainties[labels == i].mean() if (labels == i).sum() > 0 else 0 for i in range(CONFIG['num_classes'])]
    axes[2].bar(CONFIG['class_names'], class_unc); axes[2].set_title('Per-Class Uncertainty'); axes[2].tick_params(axis='x', rotation=45)
    plt.tight_layout(); plt.savefig(save_path, dpi=150); wandb.log({"uncertainty": wandb.Image(save_path)}); plt.close()
print("✓ Uncertainty functions defined")



✓ Uncertainty functions defined


In [20]:
# CELL 14: VISUALIZATION FUNCTIONS
def plot_confusion_matrix(y_true, y_pred, class_names, save_path):
    cm = confusion_matrix(y_true, y_pred)
    cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=axes[0])
    axes[0].set_title('Confusion Matrix'); axes[0].set_xlabel('Predicted'); axes[0].set_ylabel('True')
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=axes[1])
    axes[1].set_title('Normalized'); axes[1].set_xlabel('Predicted'); axes[1].set_ylabel('True')
    plt.tight_layout(); plt.savefig(save_path, dpi=150); wandb.log({"confusion_matrix": wandb.Image(save_path)}); plt.close()

def plot_roc_curves(y_true, y_probs, class_names, save_path):
    fig, ax = plt.subplots(figsize=(10, 8))
    for i, cls in enumerate(class_names):
        binary = (y_true == i).astype(int)
        if 0 < binary.sum() < len(binary):
            fpr, tpr, _ = roc_curve(binary, y_probs[:, i])
            ax.plot(fpr, tpr, label=f'{cls} (AUC={auc(fpr, tpr):.3f})')
    ax.plot([0, 1], [0, 1], 'k--'); ax.legend(loc='lower right')
    plt.savefig(save_path, dpi=150); wandb.log({"roc_curves": wandb.Image(save_path)}); plt.close()

def plot_training_history(history, save_path):
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    epochs = range(1, len(history['train_loss']) + 1)
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train'); axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val'); axes[0, 0].set_title('Loss'); axes[0, 0].legend()
    axes[0, 1].plot(epochs, history['train_acc'], 'b-', label='Train'); axes[0, 1].plot(epochs, history['val_acc'], 'r-', label='Val'); axes[0, 1].set_title('Accuracy'); axes[0, 1].legend()
    axes[1, 0].plot(epochs, history['train_bacc'], 'b-', label='Train'); axes[1, 0].plot(epochs, history['val_bacc'], 'r-', label='Val'); axes[1, 0].set_title('Balanced Accuracy'); axes[1, 0].legend()
    axes[1, 1].plot(epochs, history['val_f1'], 'g-'); axes[1, 1].set_title('Macro F1')
    plt.tight_layout(); plt.savefig(save_path, dpi=150); wandb.log({"training_history": wandb.Image(save_path)}); plt.close()
print("✓ Visualization functions defined")



✓ Visualization functions defined


In [21]:
# CELL 15: CLINICAL & CALIBRATION ANALYSIS
def compute_ece(probs, labels, n_bins=15):
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        mask = (probs > bin_boundaries[i]) & (probs <= bin_boundaries[i + 1])
        if mask.sum() > 0: ece += mask.sum() * abs(probs[mask].mean() - labels[mask].mean())
    return ece / len(probs)

def per_class_analysis(y_true, y_pred, y_probs, class_names):
    precision, recall, f1, support = precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
    auc_scores = []
    for i in range(len(class_names)):
        binary = (y_true == i).astype(int)
        try: auc_scores.append(roc_auc_score(binary, y_probs[:, i]) if 0 < binary.sum() < len(binary) else 0)
        except: auc_scores.append(0)
    df = pd.DataFrame({'Class': class_names, 'Precision': precision, 'Recall': recall, 'F1': f1, 'AUC': auc_scores, 'Support': support})
    wandb.log({"per_class_metrics": wandb.Table(dataframe=df)})
    return df

def plot_per_class_performance(df, save_path):
    fig, ax = plt.subplots(figsize=(14, 6))
    x = np.arange(len(df)); width = 0.2
    ax.bar(x - 1.5*width, df['Precision'], width, label='Precision'); ax.bar(x - 0.5*width, df['Recall'], width, label='Recall')
    ax.bar(x + 0.5*width, df['F1'], width, label='F1'); ax.bar(x + 1.5*width, df['AUC'], width, label='AUC')
    ax.set_xticks(x); ax.set_xticklabels(df['Class'], rotation=45); ax.legend(); ax.set_ylim(0, 1)
    plt.tight_layout(); plt.savefig(save_path, dpi=150); wandb.log({"per_class_chart": wandb.Image(save_path)}); plt.close()
print("✓ Clinical analysis functions defined")



✓ Clinical analysis functions defined


In [22]:
# CELL 16: GRAD-CAM++
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def get_gradcam_visualizations(model, loader, device, save_path='gradcam.png'):
    model.eval()
    try: target_layer = [model.efficientnet.conv_head]
    except: target_layer = [list(model.efficientnet.children())[-2]]
    cam = GradCAMPlusPlus(model=model, target_layers=target_layer)
    fig, axes = plt.subplots(2, 4, figsize=(16, 8)); axes = axes.flatten()
    images, labels = next(iter(loader)); images = images[:8].to(device); labels = labels[:8]
    for i in range(min(8, len(images))):
        input_tensor = images[i:i + 1]
        with torch.no_grad(): pred_class = model(input_tensor).argmax(dim=1).item()
        grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(pred_class)])[0, :]
        img = images[i].cpu().numpy().transpose(1, 2, 0)
        img = np.clip(img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]), 0, 1)
        vis = show_cam_on_image(img, grayscale_cam, use_rgb=True)
        axes[i].imshow(vis); axes[i].set_title(f'T:{CONFIG["class_names"][labels[i]]} P:{CONFIG["class_names"][pred_class]}'); axes[i].axis('off')
    plt.tight_layout(); plt.savefig(save_path, dpi=150); wandb.log({"gradcam": wandb.Image(save_path)}); plt.close(); del cam
print("✓ Grad-CAM++ defined")



✓ Grad-CAM++ defined


In [23]:
# CELL 17: STATISTICAL TESTS
def mcnemar_test(y_true, y_pred1, y_pred2):
    c1, c2 = (y_pred1 == y_true), (y_pred2 == y_true)
    b, c = np.sum(c1 & ~c2), np.sum(~c1 & c2)
    if b + c > 0: stat = (abs(b - c) - 1) ** 2 / (b + c); p = 1 - stats.chi2.cdf(stat, df=1)
    else: stat, p = 0, 1.0
    return stat, p

def paired_t_test(scores1, scores2): return stats.ttest_rel(scores1, scores2)
def wilcoxon_test(scores1, scores2):
    try: return stats.wilcoxon(scores1, scores2)
    except: return 0, 1.0

def compute_ci(scores, confidence=0.95):
    n, mean, se = len(scores), np.mean(scores), stats.sem(scores)
    h = se * stats.t.ppf((1 + confidence) / 2, n - 1)
    return mean, mean - h, mean + h
print("✓ Statistical tests defined")



✓ Statistical tests defined


In [24]:
# CELL 18: COMPUTATIONAL ANALYSIS
from ptflops import get_model_complexity_info

def compute_model_complexity(model, input_size=(3, 224, 224)):
    try: macs, params = get_model_complexity_info(model, input_size, as_strings=False, print_per_layer_stat=False, verbose=False); return macs * 2, params
    except: return 0, sum(p.numel() for p in model.parameters())

def measure_inference_time(model, device, n_runs=100):
    model.eval(); dummy = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        for _ in range(10): _ = model(dummy)
    if device == 'cuda': torch.cuda.synchronize()
    times = []
    with torch.no_grad():
        for _ in range(n_runs):
            start = time.time(); _ = model(dummy)
            if device == 'cuda': torch.cuda.synchronize()
            times.append(time.time() - start)
    return np.mean(times) * 1000, np.std(times) * 1000

def computational_analysis(models_dict, device):
    results = []
    for name, model_fn in models_dict.items():
        print(f"Analyzing {name}..."); model = model_fn().to(device)
        flops, params = compute_model_complexity(model); mean_time, std_time = measure_inference_time(model, device)
        results.append({'Model': name, 'Params (M)': params / 1e6, 'FLOPs (G)': flops / 1e9, 'Time (ms)': mean_time})
        del model; torch.cuda.empty_cache()
    df = pd.DataFrame(results); wandb.log({"computational": wandb.Table(dataframe=df)}); return df
print("✓ Computational analysis defined")



✓ Computational analysis defined


In [25]:
# CELL 19: MAIN TRAINING LOOP
def train_fold(fold, train_df, val_df, train_img_dir, device):
    print(f"\n{'='*50}\nFOLD {fold+1}/{CONFIG['n_folds']}\n{'='*50}")
    train_dataset = ISIC2019Dataset(train_df, train_img_dir, get_train_transforms(CONFIG['img_size']))
    val_dataset = ISIC2019Dataset(val_df, train_img_dir, get_val_transforms(CONFIG['img_size']))
    train_labels = []
    for i in range(len(train_df)):
        row = train_df.iloc[i]
        for j, cls in enumerate(CONFIG['class_names']):
            if cls in row and row[cls] == 1.0: train_labels.append(j); break
        else: train_labels.append(0)
    train_labels = np.array(train_labels)
    sampler = create_balanced_sampler(train_labels)
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], sampler=sampler, num_workers=CONFIG['num_workers'], pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'], pin_memory=True)
    model = UHViT(num_classes=CONFIG['num_classes'], dropout_rate=CONFIG['dropout_rate']).to(device)
    samples_per_class = np.maximum(np.bincount(train_labels, minlength=CONFIG['num_classes']), 1)
    criterion = ClassBalancedLoss(samples_per_class, CONFIG['num_classes'])
    optimizer = AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    history = {'train_loss': [], 'train_acc': [], 'train_bacc': [], 'val_loss': [], 'val_acc': [], 'val_bacc': [], 'val_f1': []}
    best_bacc, best_state = 0, None
    for epoch in range(CONFIG['epochs']):
        train_loss, train_acc, train_bacc = train_epoch(model, train_loader, criterion, optimizer, scheduler, device, epoch + 1)
        val_loss, val_acc, val_bacc, val_f1, val_preds, val_labels, val_probs = validate(model, val_loader, criterion, device)
        wandb.log({f'fold{fold}/train_loss': train_loss, f'fold{fold}/train_bacc': train_bacc, f'fold{fold}/val_loss': val_loss, f'fold{fold}/val_bacc': val_bacc, f'fold{fold}/val_f1': val_f1})
        history['train_loss'].append(train_loss); history['train_acc'].append(train_acc); history['train_bacc'].append(train_bacc)
        history['val_loss'].append(val_loss); history['val_acc'].append(val_acc); history['val_bacc'].append(val_bacc); history['val_f1'].append(val_f1)
        print(f"E{epoch+1}: TrBAcc={train_bacc:.4f}, ValBAcc={val_bacc:.4f}, ValF1={val_f1:.4f}")
        if val_bacc > best_bacc: best_bacc = val_bacc; best_state = model.state_dict().copy()
    model.load_state_dict(best_state); torch.save(best_state, f"{CONFIG['checkpoint_dir']}/fold{fold}_best.pt")
    tta_acc, tta_bacc, tta_f1, tta_preds, tta_labels, tta_probs = validate_with_tta(model, val_loader, device)
    print(f"TTA Results: Acc={tta_acc:.4f}, BAcc={tta_bacc:.4f}, F1={tta_f1:.4f}")
    plot_confusion_matrix(tta_labels, tta_preds, CONFIG['class_names'], f"{CONFIG['output_dir']}/fold{fold}_cm.png")
    plot_roc_curves(tta_labels, tta_probs, CONFIG['class_names'], f"{CONFIG['output_dir']}/fold{fold}_roc.png")
    plot_training_history(history, f"{CONFIG['output_dir']}/fold{fold}_history.png")
    mean_probs, uncertainties, unc_labels = estimate_uncertainty(model, val_loader, device)
    plot_uncertainty_analysis(uncertainties, mean_probs.argmax(axis=1), unc_labels, f"{CONFIG['output_dir']}/fold{fold}_uncertainty.png")
    get_gradcam_visualizations(model, val_loader, device, f"{CONFIG['output_dir']}/fold{fold}_gradcam.png")
    pc_df = per_class_analysis(tta_labels, tta_preds, tta_probs, CONFIG['class_names'])
    plot_per_class_performance(pc_df, f"{CONFIG['output_dir']}/fold{fold}_perclass.png")
    return {'acc': tta_acc, 'bacc': tta_bacc, 'f1': tta_f1, 'preds': tta_preds, 'labels': tta_labels, 'probs': tta_probs, 'model': model}
print("✓ Main training function defined")



✓ Main training function defined


In [26]:
# Sanity check: confirm first 50 images exist
missing = 0
for i in range(min(50, len(train_df))):
    img_name = str(train_df.iloc[i]["image"])
    found = False
    for ext in [".jpg", ".jpeg", ".png", ".JPG"]:
        if (Path(train_img_dir) / f"{img_name}{ext}").exists():
            found = True
            break
    if not found:
        # try glob
        if not list(Path(train_img_dir).glob(f"{img_name}*")):
            missing += 1
            if missing <= 5:
                print("Missing example:", img_name)

print("Missing in first 50:", missing)
assert missing == 0, "Image paths are wrong — fix train_img_dir / CSV selection."

Missing in first 50: 0


In [27]:
# CELL 20: RUN 5-FOLD CROSS-VALIDATION
wandb.init(project='UHViT-SkinLesion-IF5', name='5fold-cv', config=CONFIG)
all_labels = []
for i in range(len(train_df)):
    row = train_df.iloc[i]
    for j, cls in enumerate(CONFIG['class_names']):
        if cls in row and row[cls] == 1.0: all_labels.append(j); break
    else: all_labels.append(0)
all_labels = np.array(all_labels)
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['seed'])
fold_results = []
for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, all_labels)):
    result = train_fold(fold, train_df.iloc[train_idx], train_df.iloc[val_idx], train_img_dir, CONFIG['device'])
    fold_results.append(result)
    print(f"\nFold {fold+1} Complete: BAcc={result['bacc']:.4f}, F1={result['f1']:.4f}")
accs, baccs, f1s = [r['acc'] for r in fold_results], [r['bacc'] for r in fold_results], [r['f1'] for r in fold_results]
print(f"\n{'='*50}\n5-FOLD CV RESULTS\n{'='*50}")
print(f"Accuracy: {np.mean(accs):.4f} ± {np.std(accs):.4f}")
print(f"Balanced Accuracy: {np.mean(baccs):.4f} ± {np.std(baccs):.4f}")
print(f"Macro F1: {np.mean(f1s):.4f} ± {np.std(f1s):.4f}")
acc_mean, acc_lo, acc_hi = compute_ci(accs); bacc_mean, bacc_lo, bacc_hi = compute_ci(baccs); f1_mean, f1_lo, f1_hi = compute_ci(f1s)
print(f"\n95% CI - Accuracy: [{acc_lo:.4f}, {acc_hi:.4f}]")
print(f"95% CI - Balanced Acc: [{bacc_lo:.4f}, {bacc_hi:.4f}]")
print(f"95% CI - Macro F1: [{f1_lo:.4f}, {f1_hi:.4f}]")
wandb.log({'final/acc_mean': np.mean(accs), 'final/acc_std': np.std(accs), 'final/bacc_mean': np.mean(baccs), 'final/bacc_std': np.std(baccs), 'final/f1_mean': np.mean(f1s), 'final/f1_std': np.std(f1s)})




FOLD 1/5


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E1: TrBAcc=0.3389, ValBAcc=0.3773, ValF1=0.1059


Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E2: TrBAcc=0.4332, ValBAcc=0.4993, ValF1=0.1712


Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E3: TrBAcc=0.4644, ValBAcc=0.5360, ValF1=0.2290


Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E4: TrBAcc=0.5364, ValBAcc=0.5593, ValF1=0.2700


Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E5: TrBAcc=0.5133, ValBAcc=0.5636, ValF1=0.2695


Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E6: TrBAcc=0.5456, ValBAcc=0.5989, ValF1=0.3447


Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E7: TrBAcc=0.5959, ValBAcc=0.6240, ValF1=0.3721


Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E8: TrBAcc=0.6196, ValBAcc=0.6227, ValF1=0.3664


Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E9: TrBAcc=0.5782, ValBAcc=0.5933, ValF1=0.3099


Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E10: TrBAcc=0.5838, ValBAcc=0.6179, ValF1=0.3202


Train E11:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E11: TrBAcc=0.5989, ValBAcc=0.6425, ValF1=0.3897


Train E12:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E12: TrBAcc=0.6274, ValBAcc=0.6358, ValF1=0.3766


Train E13:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E13: TrBAcc=0.6501, ValBAcc=0.6479, ValF1=0.4311


Train E14:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E14: TrBAcc=0.6765, ValBAcc=0.6679, ValF1=0.4526


Train E15:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E15: TrBAcc=0.6858, ValBAcc=0.6786, ValF1=0.4672


Train E16:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E16: TrBAcc=0.6939, ValBAcc=0.6781, ValF1=0.4700


Train E17:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E17: TrBAcc=0.6385, ValBAcc=0.6240, ValF1=0.3254


Train E18:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E18: TrBAcc=0.6211, ValBAcc=0.6576, ValF1=0.4019


Train E19:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E19: TrBAcc=0.6381, ValBAcc=0.6549, ValF1=0.4266


Train E20:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E20: TrBAcc=0.6479, ValBAcc=0.6318, ValF1=0.3862


Train E21:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E21: TrBAcc=0.6636, ValBAcc=0.6490, ValF1=0.3911


Train E22:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E22: TrBAcc=0.6732, ValBAcc=0.6642, ValF1=0.4442


Train E23:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E23: TrBAcc=0.6833, ValBAcc=0.6808, ValF1=0.4572


Train E24:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E24: TrBAcc=0.7043, ValBAcc=0.6962, ValF1=0.4987


Train E25:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E25: TrBAcc=0.7139, ValBAcc=0.7038, ValF1=0.5108


Train E26:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E26: TrBAcc=0.7206, ValBAcc=0.7080, ValF1=0.5501


Train E27:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E27: TrBAcc=0.7316, ValBAcc=0.7007, ValF1=0.5318


Train E28:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E28: TrBAcc=0.7460, ValBAcc=0.7114, ValF1=0.5208


Train E29:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E29: TrBAcc=0.7531, ValBAcc=0.7082, ValF1=0.5439


Train E30:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E30: TrBAcc=0.7605, ValBAcc=0.6981, ValF1=0.5602


TTA:   0%|          | 0/159 [00:00<?, ?it/s]

TTA Results: Acc=0.4796, BAcc=0.7173, F1=0.5722


Uncertainty:   0%|          | 0/159 [00:00<?, ?it/s]


Fold 1 Complete: BAcc=0.7173, F1=0.5722

FOLD 2/5


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E1: TrBAcc=0.3257, ValBAcc=0.4351, ValF1=0.1016


Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E2: TrBAcc=0.4241, ValBAcc=0.4624, ValF1=0.1564


Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E3: TrBAcc=0.4622, ValBAcc=0.5316, ValF1=0.2238


Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E4: TrBAcc=0.5327, ValBAcc=0.5651, ValF1=0.2891


Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E5: TrBAcc=0.5132, ValBAcc=0.5467, ValF1=0.2480


Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
IOStream.flush timed out

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E6: TrBAcc=0.5542, ValBAcc=0.5707, ValF1=0.3145


Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E7: TrBAcc=0.5984, ValBAcc=0.6279, ValF1=0.3772


Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E8: TrBAcc=0.6256, ValBAcc=0.6292, ValF1=0.3679


Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E9: TrBAcc=0.5780, ValBAcc=0.5929, ValF1=0.3145


Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E10: TrBAcc=0.5829, ValBAcc=0.6022, ValF1=0.3465


Train E11:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E11: TrBAcc=0.6095, ValBAcc=0.6273, ValF1=0.3790


Train E12:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E12: TrBAcc=0.6362, ValBAcc=0.6546, ValF1=0.4037


Train E13:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E13: TrBAcc=0.6567, ValBAcc=0.6627, ValF1=0.4361


Train E14:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E14: TrBAcc=0.6755, ValBAcc=0.6765, ValF1=0.4516


Train E15:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E15: TrBAcc=0.6874, ValBAcc=0.6839, ValF1=0.4867


Train E16:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E16: TrBAcc=0.6897, ValBAcc=0.6789, ValF1=0.4805


Train E17:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E17: TrBAcc=0.6489, ValBAcc=0.6536, ValF1=0.4031


Train E18:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E18: TrBAcc=0.6391, ValBAcc=0.6549, ValF1=0.4196


Train E19:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E19: TrBAcc=0.6465, ValBAcc=0.6575, ValF1=0.4075


Train E20:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E20: TrBAcc=0.6674, ValBAcc=0.6816, ValF1=0.4426


Train E21:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E21: TrBAcc=0.6669, ValBAcc=0.6678, ValF1=0.3998


Train E22:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E22: TrBAcc=0.6770, ValBAcc=0.6799, ValF1=0.4580


Train E23:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E23: TrBAcc=0.6956, ValBAcc=0.6865, ValF1=0.4859


Train E24:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E24: TrBAcc=0.7072, ValBAcc=0.6785, ValF1=0.4894


Train E25:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E25: TrBAcc=0.7159, ValBAcc=0.7076, ValF1=0.5167


Train E26:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E26: TrBAcc=0.7277, ValBAcc=0.6937, ValF1=0.4992


Train E27:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E27: TrBAcc=0.7347, ValBAcc=0.7086, ValF1=0.5199


Train E28:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E28: TrBAcc=0.7471, ValBAcc=0.7026, ValF1=0.5339


Train E29:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E29: TrBAcc=0.7456, ValBAcc=0.7136, ValF1=0.5532


Train E30:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E30: TrBAcc=0.7609, ValBAcc=0.7105, ValF1=0.5471


TTA:   0%|          | 0/159 [00:00<?, ?it/s]

TTA Results: Acc=0.5174, BAcc=0.7268, F1=0.5662


Uncertainty:   0%|          | 0/159 [00:00<?, ?it/s]


Fold 2 Complete: BAcc=0.7268, F1=0.5662

FOLD 3/5


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E1: TrBAcc=0.3440, ValBAcc=0.4122, ValF1=0.1100


Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E2: TrBAcc=0.4317, ValBAcc=0.4896, ValF1=0.1705


Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E3: TrBAcc=0.4659, ValBAcc=0.5375, ValF1=0.2179


Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E4: TrBAcc=0.5337, ValBAcc=0.5672, ValF1=0.2603


Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E5: TrBAcc=0.5081, ValBAcc=0.5710, ValF1=0.2672


Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E6: TrBAcc=0.5491, ValBAcc=0.6027, ValF1=0.3189


Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E7: TrBAcc=0.5981, ValBAcc=0.6381, ValF1=0.3729


Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E8: TrBAcc=0.6245, ValBAcc=0.6342, ValF1=0.3777


Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E9: TrBAcc=0.5722, ValBAcc=0.5923, ValF1=0.2843


Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E10: TrBAcc=0.5822, ValBAcc=0.5994, ValF1=0.3221


Train E11:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E11: TrBAcc=0.6062, ValBAcc=0.6420, ValF1=0.3634


Train E12:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E12: TrBAcc=0.6332, ValBAcc=0.6451, ValF1=0.4066


Train E13:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E13: TrBAcc=0.6549, ValBAcc=0.6502, ValF1=0.4313


Train E14:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E14: TrBAcc=0.6783, ValBAcc=0.6850, ValF1=0.4999


Train E15:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E15: TrBAcc=0.6969, ValBAcc=0.6935, ValF1=0.5056


Train E16:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E16: TrBAcc=0.6966, ValBAcc=0.6923, ValF1=0.5004


Train E17:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E17: TrBAcc=0.6431, ValBAcc=0.6320, ValF1=0.3371


Train E18:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E18: TrBAcc=0.6293, ValBAcc=0.6389, ValF1=0.3936


Train E19:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E19: TrBAcc=0.6500, ValBAcc=0.6294, ValF1=0.3647


Train E20:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E20: TrBAcc=0.6639, ValBAcc=0.6707, ValF1=0.4199


Train E21:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E21: TrBAcc=0.6749, ValBAcc=0.6900, ValF1=0.4336


Train E22:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E22: TrBAcc=0.6727, ValBAcc=0.6734, ValF1=0.4827


Train E23:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E23: TrBAcc=0.6908, ValBAcc=0.6891, ValF1=0.4651


Train E24:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E24: TrBAcc=0.7026, ValBAcc=0.7035, ValF1=0.5091


Train E25:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E25: TrBAcc=0.7222, ValBAcc=0.7159, ValF1=0.5285


Train E26:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E26: TrBAcc=0.7298, ValBAcc=0.7235, ValF1=0.5360


Train E27:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E27: TrBAcc=0.7333, ValBAcc=0.7213, ValF1=0.5250


Train E28:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E28: TrBAcc=0.7447, ValBAcc=0.7353, ValF1=0.5594


Train E29:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E29: TrBAcc=0.7498, ValBAcc=0.7363, ValF1=0.5720


Train E30:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E30: TrBAcc=0.7518, ValBAcc=0.7380, ValF1=0.5851


TTA:   0%|          | 0/159 [00:00<?, ?it/s]

TTA Results: Acc=0.4907, BAcc=0.7432, F1=0.5891


Uncertainty:   0%|          | 0/159 [00:00<?, ?it/s]


Fold 3 Complete: BAcc=0.7432, F1=0.5891

FOLD 4/5


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E1: TrBAcc=0.3356, ValBAcc=0.4378, ValF1=0.1200


Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E2: TrBAcc=0.4326, ValBAcc=0.4937, ValF1=0.1734


Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E3: TrBAcc=0.4677, ValBAcc=0.5478, ValF1=0.2452


Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E4: TrBAcc=0.5367, ValBAcc=0.5959, ValF1=0.2994


Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E5: TrBAcc=0.5138, ValBAcc=0.5400, ValF1=0.2474


Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E6: TrBAcc=0.5580, ValBAcc=0.6132, ValF1=0.3360


Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E7: TrBAcc=0.6100, ValBAcc=0.6441, ValF1=0.3827


Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E8: TrBAcc=0.6279, ValBAcc=0.6486, ValF1=0.3947


Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E9: TrBAcc=0.5781, ValBAcc=0.6237, ValF1=0.3618


Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E10: TrBAcc=0.5863, ValBAcc=0.6699, ValF1=0.3945


Train E11:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E11: TrBAcc=0.6230, ValBAcc=0.6522, ValF1=0.3783


Train E12:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E12: TrBAcc=0.6355, ValBAcc=0.6902, ValF1=0.4526


Train E13:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E13: TrBAcc=0.6728, ValBAcc=0.6887, ValF1=0.4577


Train E14:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E14: TrBAcc=0.6808, ValBAcc=0.6806, ValF1=0.4578


Train E15:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E15: TrBAcc=0.6948, ValBAcc=0.6992, ValF1=0.4847


Train E16:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E16: TrBAcc=0.7050, ValBAcc=0.6988, ValF1=0.4797


Train E17:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E17: TrBAcc=0.6451, ValBAcc=0.6458, ValF1=0.3409


Train E18:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E18: TrBAcc=0.6360, ValBAcc=0.6558, ValF1=0.3926


Train E19:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E19: TrBAcc=0.6526, ValBAcc=0.6829, ValF1=0.3995


Train E20:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E20: TrBAcc=0.6642, ValBAcc=0.6750, ValF1=0.4071


Train E21:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E21: TrBAcc=0.6770, ValBAcc=0.6789, ValF1=0.4277


Train E22:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E22: TrBAcc=0.6746, ValBAcc=0.6921, ValF1=0.4269


Train E23:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E23: TrBAcc=0.6877, ValBAcc=0.6970, ValF1=0.4851


Train E24:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E24: TrBAcc=0.7033, ValBAcc=0.7076, ValF1=0.5009


Train E25:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E25: TrBAcc=0.7187, ValBAcc=0.7252, ValF1=0.4847


Train E26:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E26: TrBAcc=0.7327, ValBAcc=0.7186, ValF1=0.5424


Train E27:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E27: TrBAcc=0.7418, ValBAcc=0.7324, ValF1=0.5661


Train E28:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E28: TrBAcc=0.7506, ValBAcc=0.7322, ValF1=0.5838


Train E29:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E29: TrBAcc=0.7579, ValBAcc=0.7295, ValF1=0.5744


Train E30:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E30: TrBAcc=0.7593, ValBAcc=0.7392, ValF1=0.5878


TTA:   0%|          | 0/159 [00:00<?, ?it/s]

TTA Results: Acc=0.5221, BAcc=0.7469, F1=0.6025


Uncertainty:   0%|          | 0/159 [00:00<?, ?it/s]


Fold 4 Complete: BAcc=0.7469, F1=0.6025

FOLD 5/5




Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E1: TrBAcc=0.3283, ValBAcc=0.4002, ValF1=0.0991


Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E2: TrBAcc=0.4230, ValBAcc=0.4923, ValF1=0.1695


Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E3: TrBAcc=0.4640, ValBAcc=0.4952, ValF1=0.2137


Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E4: TrBAcc=0.5297, ValBAcc=0.5749, ValF1=0.2786


Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E5: TrBAcc=0.5122, ValBAcc=0.5595, ValF1=0.2673


Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E6: TrBAcc=0.5607, ValBAcc=0.6127, ValF1=0.3326


Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E7: TrBAcc=0.5963, ValBAcc=0.6371, ValF1=0.3696


Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E8: TrBAcc=0.6177, ValBAcc=0.6335, ValF1=0.3738


Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E9: TrBAcc=0.5803, ValBAcc=0.6010, ValF1=0.3456


Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E10: TrBAcc=0.5841, ValBAcc=0.5993, ValF1=0.3276


Train E11:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E11: TrBAcc=0.6161, ValBAcc=0.6246, ValF1=0.3694


Train E12:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E12: TrBAcc=0.6339, ValBAcc=0.6324, ValF1=0.3616


Train E13:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E13: TrBAcc=0.6603, ValBAcc=0.6658, ValF1=0.4269


Train E14:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E14: TrBAcc=0.6742, ValBAcc=0.6766, ValF1=0.4420


Train E15:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E15: TrBAcc=0.6921, ValBAcc=0.6911, ValF1=0.4772


Train E16:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E16: TrBAcc=0.6987, ValBAcc=0.6891, ValF1=0.4820


Train E17:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E17: TrBAcc=0.6502, ValBAcc=0.6642, ValF1=0.4061


Train E18:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E18: TrBAcc=0.6350, ValBAcc=0.6468, ValF1=0.3961


Train E19:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E19: TrBAcc=0.6567, ValBAcc=0.6799, ValF1=0.4245


Train E20:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E20: TrBAcc=0.6723, ValBAcc=0.6475, ValF1=0.4106


Train E21:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E21: TrBAcc=0.6733, ValBAcc=0.6529, ValF1=0.4128


Train E22:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E22: TrBAcc=0.6838, ValBAcc=0.6930, ValF1=0.4403


Train E23:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E23: TrBAcc=0.6930, ValBAcc=0.7046, ValF1=0.4926


Train E24:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E24: TrBAcc=0.7052, ValBAcc=0.6893, ValF1=0.5060


Train E25:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E25: TrBAcc=0.7171, ValBAcc=0.6906, ValF1=0.4972


Val:   0%|          | 0/159 [00:00<?, ?it/s]

E26: TrBAcc=0.7319, ValBAcc=0.7049, ValF1=0.5149


Train E27:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E27: TrBAcc=0.7359, ValBAcc=0.6990, ValF1=0.5146


Train E28:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E28: TrBAcc=0.7483, ValBAcc=0.7168, ValF1=0.5391


Train E29:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

E29: TrBAcc=0.7548, ValBAcc=0.7159, ValF1=0.5752


Train E30:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


Val:   0%|          | 0/159 [00:00<?, ?it/s]

E30: TrBAcc=0.7580, ValBAcc=0.7244, ValF1=0.5661


TTA:   0%|          | 0/159 [00:00<?, ?it/s]

TTA Results: Acc=0.5156, BAcc=0.7334, F1=0.5783


Uncertainty:   0%|          | 0/159 [00:00<?, ?it/s]


Fold 5 Complete: BAcc=0.7334, F1=0.5783

5-FOLD CV RESULTS
Accuracy: 0.5051 ± 0.0168
Balanced Accuracy: 0.7335 ± 0.0108
Macro F1: 0.5817 ± 0.0129

95% CI - Accuracy: [0.4818, 0.5284]
95% CI - Balanced Acc: [0.7186, 0.7485]
95% CI - Macro F1: [0.5638, 0.5995]


In [40]:
# CELL 20B: CLINICAL SAFETY (MEL operating point + uncertainty referral + calibration)

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, precision_recall_fscore_support

MEL_CLASS = CONFIG["class_names"].index("MEL")

def mel_operating_point(y_true, y_probs, target_sens=0.95):
    """
    Choose a MEL probability threshold to reach >= target sensitivity (recall for MEL),
    and report specificity, precision, F1 at that threshold.
    """
    mel_true = (y_true == MEL_CLASS).astype(int)
    mel_score = y_probs[:, MEL_CLASS]

    fpr, tpr, thr = roc_curve(mel_true, mel_score)
    # tpr = sensitivity, fpr = 1 - specificity
    idx = np.where(tpr >= target_sens)[0]
    if len(idx) == 0:
        best_idx = np.argmax(tpr)
    else:
        best_idx = idx[np.argmin(fpr[idx])]  # highest specificity among those meeting target sensitivity

    threshold = thr[best_idx]
    mel_pred = (mel_score >= threshold).astype(int)

    tn, fp, fn, tp = confusion_matrix(mel_true, mel_pred).ravel()
    sens = tp / (tp + fn + 1e-9)
    spec = tn / (tn + fp + 1e-9)
    prec = tp / (tp + fp + 1e-9)
    f1 = 2 * prec * sens / (prec + sens + 1e-9)
    auc = roc_auc_score(mel_true, mel_score) if 0 < mel_true.sum() < len(mel_true) else 0.0

    return {
        "target_sensitivity": target_sens,
        "chosen_threshold": float(threshold),
        "MEL_sensitivity": float(sens),
        "MEL_specificity": float(spec),
        "MEL_precision": float(prec),
        "MEL_F1": float(f1),
        "MEL_AUROC": float(auc),
        "tp": int(tp), "fp": int(fp), "tn": int(tn), "fn": int(fn),
    }

def referral_policy(y_true, y_pred, y_probs, uncertainties, reject_percentiles=(80, 90, 95)):
    """
    Uncertainty-based referral: reject top X% most uncertain cases,
    compute coverage + performance on remaining cases.
    """
    results = []
    uncertainties = np.asarray(uncertainties)
    for p in reject_percentiles:
        thr = np.percentile(uncertainties, p)
        keep = uncertainties <= thr
        coverage = keep.mean()

        yt = y_true[keep]
        yp = y_pred[keep]
        ypr = y_probs[keep]

        acc = (yp == yt).mean() if len(yt) else 0
        # MEL sensitivity on kept cases
        mel_true = (yt == MEL_CLASS).astype(int)
        mel_pred = (yp == MEL_CLASS).astype(int)
        tn, fp, fn, tp = confusion_matrix(mel_true, mel_pred, labels=[0,1]).ravel()
        mel_sens = tp / (tp + fn + 1e-9)

        results.append({
            "reject_percentile": p,
            "uncertainty_threshold": float(thr),
            "coverage": float(coverage),
            "accuracy_on_kept": float(acc),
            "MEL_sensitivity_on_kept": float(mel_sens),
            "n_kept": int(len(yt)),
            "n_referred": int((~keep).sum())
        })
    return pd.DataFrame(results)

# Aggregate across folds using stored fold_results
# NOTE: fold_results[i]['labels'], ['preds'], ['probs'] exist in your notebook.
all_y_true = np.concatenate([r["labels"] for r in fold_results])
all_y_pred = np.concatenate([r["preds"] for r in fold_results])
all_y_prob = np.concatenate([r["probs"] for r in fold_results])

# If you computed uncertainties per fold (from estimate_uncertainty), collect them here.
# If not available, you can re-run uncertainty on the val loaders per fold or skip referral section.
# For now, we’ll recompute uncertainties on a quick subset by re-using the first fold model + val data if you saved them.
# (If you want full referral analysis per-fold, tell me and I’ll give the exact fold-level code.)
print("Clinical safety: MEL operating point on pooled CV predictions")
mel_report = mel_operating_point(all_y_true, all_y_prob, target_sens=0.95)
mel_df = pd.DataFrame([mel_report])
print(mel_df.to_string(index=False))
wandb.log({"clinical/MEL_operating_point": wandb.Table(dataframe=mel_df)})

print("\nTip: For uncertainty-based referral, you need uncertainties aligned with these pooled predictions.")
print("If you want, I can provide the exact fold-wise code to store uncertainties during validation and aggregate them.")

Clinical safety: MEL operating point on pooled CV predictions
 target_sensitivity  chosen_threshold  MEL_sensitivity  MEL_specificity  MEL_precision   MEL_F1  MEL_AUROC   tp    fp   tn  fn
               0.95          0.140667         0.950022         0.348359       0.240591 0.383949   0.819794 4296 13560 7249 226


Error: You must call wandb.init() before wandb.log()

In [29]:
wandb.log({"comparison_protocol": "Budgeted baselines/ablations: 10 epochs; UHViT full CV: 30 epochs"})

In [32]:
# CELL 21: ABLATION STUDY (FIXED LABELS + BEST-EPOCH SELECTION)

print("\n" + "="*50 + "\nABLATION STUDY\n" + "="*50)

# Define ablation models
ablation_models = {
    "UHViT (Full)": lambda: UHViT(CONFIG["num_classes"], CONFIG["dropout_rate"]),
    "Swin-T Only": lambda: SwinOnly(CONFIG["num_classes"], CONFIG["dropout_rate"]),
    "EfficientNet Only": lambda: EfficientNetOnly(CONFIG["num_classes"], CONFIG["dropout_rate"]),
    "Concat Fusion": lambda: UHViTConcatFusion(CONFIG["num_classes"], CONFIG["dropout_rate"]),
}

# Use a fixed fold split for budgeted ablation (fold 1)
train_idx, val_idx = list(skf.split(train_df, all_labels))[0]
abl_train_df, abl_val_df = train_df.iloc[train_idx].reset_index(drop=True), train_df.iloc[val_idx].reset_index(drop=True)

# Datasets & loaders (same transforms as main training)
abl_train_dataset = ISIC2019Dataset(abl_train_df, train_img_dir, get_train_transforms(CONFIG["img_size"]))
abl_val_dataset   = ISIC2019Dataset(abl_val_df,   train_img_dir, get_val_transforms(CONFIG["img_size"]))

# FIXED: build exactly one label per sample
abl_train_labels = []
for i in range(len(abl_train_df)):
    row = abl_train_df.iloc[i]
    label = 0
    for j, cls in enumerate(CONFIG["class_names"]):
        if cls in abl_train_df.columns and float(row[cls]) == 1.0:
            label = j
            break
    abl_train_labels.append(label)
abl_train_labels = np.array(abl_train_labels)

abl_sampler = create_balanced_sampler(abl_train_labels)

abl_train_loader = DataLoader(
    abl_train_dataset,
    batch_size=CONFIG["batch_size"],
    sampler=abl_sampler,
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
)

abl_val_loader = DataLoader(
    abl_val_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=False,
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
)

ablation_results = []

# Train each ablation model for 10 epochs (budgeted), report BEST val BAcc
for name, model_fn in ablation_models.items():
    print(f"\nTraining {name}...")
    model = model_fn().to(CONFIG["device"])

    samples_per_class = np.maximum(np.bincount(abl_train_labels, minlength=CONFIG["num_classes"]), 1)
    criterion = ClassBalancedLoss(samples_per_class, CONFIG["num_classes"])

    optimizer = AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"])
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    best_bacc, best_acc, best_f1 = -1.0, 0.0, 0.0

    for epoch in range(10):
        train_epoch(model, abl_train_loader, criterion, optimizer, scheduler, CONFIG["device"], epoch + 1)
        _, val_acc, val_bacc, val_f1, _, _, _ = validate(model, abl_val_loader, criterion, CONFIG["device"])

        if val_bacc > best_bacc:
            best_bacc, best_acc, best_f1 = val_bacc, val_acc, val_f1

    print(f"{name} (BEST): Acc={best_acc:.4f}, BAcc={best_bacc:.4f}, F1={best_f1:.4f}")

    ablation_results.append({
        "Model": name,
        "Accuracy_best": float(best_acc),
        "BalancedAcc_best": float(best_bacc),
        "MacroF1_best": float(best_f1),
        "epochs": 10,
        "protocol": "budgeted_single_fold_best_epoch"
    })

    del model
    torch.cuda.empty_cache()

abl_df = pd.DataFrame(ablation_results)
print("\nAblation Results (budgeted, single fold, best epoch):")
print(abl_df.to_string(index=False))

try:
    wandb.log({"ablation_study": wandb.Table(dataframe=abl_df)})
except Exception as e:
    print("W&B log skipped:", e)


ABLATION STUDY

Training UHViT (Full)...


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

UHViT (Full) (BEST): Acc=0.3306, BAcc=0.6317, F1=0.3917

Training Swin-T Only...


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>^^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    ^^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ ^ ^ ^

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Swin-T Only (BEST): Acc=0.4227, BAcc=0.6745, F1=0.4213

Training EfficientNet Only...


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process


Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

EfficientNet Only (BEST): Acc=0.3570, BAcc=0.6232, F1=0.3358

Training Concat Fusion...


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
IOStream.flush timed out
IOStream.flush timed out


Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
IOStream.flush timed out
IOStream.flush timed out


Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Concat Fusion (BEST): Acc=0.4579, BAcc=0.6884, F1=0.4658

Ablation Results (budgeted, single fold, best epoch):
            Model  Accuracy_best  BalancedAcc_best  MacroF1_best  epochs                        protocol
     UHViT (Full)       0.330570          0.631712      0.391707      10 budgeted_single_fold_best_epoch
      Swin-T Only       0.422735          0.674525      0.421324      10 budgeted_single_fold_best_epoch
EfficientNet Only       0.357016          0.623233      0.335812      10 budgeted_single_fold_best_epoch
    Concat Fusion       0.457865          0.688450      0.465830      10 budgeted_single_fold_best_epoch


In [33]:
# CELL 22: SOTA COMPARISON (FIXED: proper samples_per_class + BEST epoch)

print("\n" + "="*50 + "\nSOTA COMPARISON (budgeted single-fold)\n" + "="*50)

# This cell assumes you already ran the FIXED ablation cell (CELL 21) that created:
#   abl_train_loader, abl_val_loader, abl_train_labels
# If not, run CELL 21 first.

assert "abl_train_loader" in globals() and "abl_val_loader" in globals() and "abl_train_labels" in globals(), \
    "Run the fixed Ablation Study cell (CELL 21) first."

# Define criterion inputs correctly for this fold
samples_per_class = np.maximum(np.bincount(abl_train_labels, minlength=CONFIG["num_classes"]), 1)
criterion = ClassBalancedLoss(samples_per_class, CONFIG["num_classes"])

sota_results = []

for sota_name in SOTA_MODELS:
    print(f"\nTraining SOTA baseline: {sota_name} ...")
    try:
        model = create_sota_model(sota_name, CONFIG["num_classes"]).to(CONFIG["device"])
        optimizer = AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"])
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

        best_bacc, best_acc, best_f1 = -1.0, 0.0, 0.0

        for epoch in range(10):
            train_epoch(model, abl_train_loader, criterion, optimizer, scheduler, CONFIG["device"], epoch + 1)
            _, val_acc, val_bacc, val_f1, _, _, _ = validate(model, abl_val_loader, criterion, CONFIG["device"])

            if val_bacc > best_bacc:
                best_bacc, best_acc, best_f1 = val_bacc, val_acc, val_f1

        print(f"{sota_name} (BEST): Acc={best_acc:.4f}, BAcc={best_bacc:.4f}, F1={best_f1:.4f}")

        sota_results.append({
            "Model": sota_name,
            "Accuracy_best": float(best_acc),
            "BalancedAcc_best": float(best_bacc),
            "MacroF1_best": float(best_f1),
            "epochs": 10,
            "protocol": "budgeted_single_fold_best_epoch"
        })

        del model
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error with {sota_name}: {e}")

# Add your model result (from 5-fold CV means) as "Ours"
# These variables accs/baccs/f1s exist in your CV summary cell; if not, you can skip or hardcode.
if "accs" in globals() and "baccs" in globals() and "f1s" in globals():
    sota_results.append({
        "Model": "UHViT (Ours, 5-fold CV mean)",
        "Accuracy_best": float(np.mean(accs)),
        "BalancedAcc_best": float(np.mean(baccs)),
        "MacroF1_best": float(np.mean(f1s)),
        "epochs": CONFIG.get("epochs", 30),
        "protocol": "5-fold_CV_mean"
    })

sota_df = pd.DataFrame(sota_results).sort_values("BalancedAcc_best", ascending=False)
print("\nSOTA Comparison Results:")
print(sota_df.to_string(index=False))

try:
    wandb.log({"sota_comparison": wandb.Table(dataframe=sota_df)})
except Exception as e:
    print("W&B log skipped:", e)


SOTA COMPARISON (budgeted single-fold)

Training SOTA baseline: resnet50 ...


model.safetensors:   0%|          | 0.00/102M [00:00<?, ?B/s]

Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdow

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

resnet50 (BEST): Acc=0.1508, BAcc=0.4836, F1=0.1594

Training SOTA baseline: efficientnet_b3 ...


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

efficientnet_b3 (BEST): Acc=0.4894, BAcc=0.6453, F1=0.4352

Training SOTA baseline: vit_base_patch16_224 ...


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

vit_base_patch16_224 (BEST): Acc=0.3949, BAcc=0.6142, F1=0.3561

Training SOTA baseline: swin_tiny_patch4_window7_224 ...


Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/datal

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

swin_tiny_patch4_window7_224 (BEST): Acc=0.4470, BAcc=0.6940, F1=0.4667

Training SOTA baseline: convnext_tiny ...


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x728d383977e0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Val:   0%|          | 0/159 [00:00<?, ?it/s]

convnext_tiny (BEST): Acc=0.4654, BAcc=0.6753, F1=0.4874

Training SOTA baseline: densenet121 ...


model.safetensors:   0%|          | 0.00/32.3M [00:00<?, ?B/s]

Train E1:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E2:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E3:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E4:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E5:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E6:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E7:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E8:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E9:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

Train E10:   0%|          | 0/634 [00:00<?, ?it/s]

Val:   0%|          | 0/159 [00:00<?, ?it/s]

densenet121 (BEST): Acc=0.3689, BAcc=0.6107, F1=0.3293

SOTA Comparison Results:
                       Model  Accuracy_best  BalancedAcc_best  MacroF1_best  epochs                        protocol
UHViT (Ours, 5-fold CV mean)       0.505074          0.733522      0.581658      30                  5-fold_CV_mean
swin_tiny_patch4_window7_224       0.447010          0.694033      0.466748      10 budgeted_single_fold_best_epoch
               convnext_tiny       0.465364          0.675253      0.487435      10 budgeted_single_fold_best_epoch
             efficientnet_b3       0.489441          0.645291      0.435196      10 budgeted_single_fold_best_epoch
        vit_base_patch16_224       0.394908          0.614220      0.356067      10 budgeted_single_fold_best_epoch
                 densenet121       0.368857          0.610687      0.329264      10 budgeted_single_fold_best_epoch
                    resnet50       0.150780          0.483578      0.159402      10 budgeted_single_fold_be

In [34]:
# CELL 23: CROSS-DATASET VALIDATION (HAM10000)
print("\n" + "="*50 + "\nCROSS-DATASET VALIDATION (HAM10000)\n" + "="*50)
ham_base = Path(CONFIG['data_dir']) / 'ham10000'
ham_csv_files = list(ham_base.rglob('*metadata*.csv'))
if not ham_csv_files: ham_csv_files = list(ham_base.rglob('*.csv'))
ham_csv = ham_csv_files[0]
ham_img_dirs = [d for d in ham_base.iterdir() if d.is_dir() and 'HAM' in d.name]
if not ham_img_dirs: ham_img_dirs = [ham_base]
ham_dataset = HAM10000Dataset(ham_csv, ham_img_dirs, get_val_transforms(CONFIG['img_size']))
ham_loader = DataLoader(ham_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=CONFIG['num_workers'])
best_model = fold_results[0]['model']; best_model.eval()
ham_preds, ham_labels, ham_probs = [], [], []
with torch.no_grad():
    for images, labels in tqdm(ham_loader, desc='HAM10000 Eval'):
        images = images.to(CONFIG['device']); outputs = best_model(images)
        ham_probs.extend(F.softmax(outputs, dim=1).cpu().numpy()); ham_preds.extend(outputs.argmax(dim=1).cpu().numpy()); ham_labels.extend(labels.numpy())
ham_preds, ham_labels, ham_probs = np.array(ham_preds), np.array(ham_labels), np.array(ham_probs)
ham_acc, ham_bacc, ham_f1 = accuracy_score(ham_labels, ham_preds), balanced_accuracy_score(ham_labels, ham_preds), f1_score(ham_labels, ham_preds, average='macro')
print(f"\nHAM10000 Results:\nAccuracy: {ham_acc:.4f}\nBalanced Accuracy: {ham_bacc:.4f}\nMacro F1: {ham_f1:.4f}")
plot_confusion_matrix(ham_labels, ham_preds, CONFIG['class_names'], f"{CONFIG['output_dir']}/ham10000_cm.png")
wandb.log({'cross_dataset/ham_acc': ham_acc, 'cross_dataset/ham_bacc': ham_bacc, 'cross_dataset/ham_f1': ham_f1})




CROSS-DATASET VALIDATION (HAM10000)


HAM10000 Eval:   0%|          | 0/313 [00:00<?, ?it/s]


HAM10000 Results:
Accuracy: 0.5255
Balanced Accuracy: 0.6976
Macro F1: 0.4359


In [35]:
# CELL 23B: EXTERNAL VALIDATION (DermaMNIST from MedMNIST) - public direct download

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import medmnist
from medmnist import INFO
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, roc_auc_score

derma_info = INFO["dermamnist"]
DermaMNIST = getattr(medmnist, derma_info["python_class"])

# DermaMNIST: 7-class dermoscopy dataset (different label set than ISIC 8-class)
# We use it as an EXTERNAL distribution test by reporting:
#   - Top-1 confidence distribution
#   - OOD-style metrics / uncertainty (optional)
#   - And/or map to MEL vs NON-MEL if you want (DermaMNIST doesn't have MEL as a direct class label).
# For IF≈5, it’s acceptable to report "cross-dataset robustness" with caveats.

class DermaMNISTWrapper(Dataset):
    def __init__(self, split="test", transform=None):
        self.data = DermaMNIST(split=split, download=True)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, y = self.data[idx]           # img is PIL, y is int or array
        if isinstance(y, np.ndarray):
            y = int(y.squeeze())
        else:
            y = int(y)
        img = np.array(img.convert("RGB"))
        if self.transform:
            img = self.transform(image=img)["image"]
        return img, y

# Use val transforms (resize + normalize) for compatibility
derma_test = DermaMNISTWrapper(split="test", transform=get_val_transforms(CONFIG["img_size"]))
derma_loader = DataLoader(derma_test, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=CONFIG["num_workers"])

best_model = fold_results[0]["model"]
best_model.eval()

all_conf = []
all_entropy = []

with torch.no_grad():
    for images, _ in tqdm(derma_loader, desc="DermaMNIST External Eval"):
        images = images.to(CONFIG["device"])
        probs8 = F.softmax(best_model(images), dim=1)          # ISIC 8-class outputs
        conf = probs8.max(dim=1).values.detach().cpu().numpy() # confidence of predicted ISIC class
        ent = (-probs8 * (probs8.clamp_min(1e-9)).log()).sum(dim=1).detach().cpu().numpy()

        all_conf.extend(conf.tolist())
        all_entropy.extend(ent.tolist())

all_conf = np.array(all_conf)
all_entropy = np.array(all_entropy)

report = {
    "DermaMNIST_external_n": int(len(all_conf)),
    "mean_confidence": float(all_conf.mean()),
    "median_confidence": float(np.median(all_conf)),
    "mean_entropy": float(all_entropy.mean()),
    "median_entropy": float(np.median(all_entropy)),
}

print("DermaMNIST External Robustness Report:")
for k, v in report.items():
    print(f"{k}: {v}")

wandb.run.summary["external_note"] = (
    "External dataset label space differs from ISIC2019; "
    "we report robustness metrics (confidence/entropy) under distribution shift."
)
wandb.log({"external/note": "External dataset label space differs; we report robustness metrics (confidence/entropy) under distribution shift."})

100%|██████████| 19.7M/19.7M [00:01<00:00, 15.0MB/s]


DermaMNIST External Eval:   0%|          | 0/63 [00:00<?, ?it/s]

DermaMNIST External Robustness Report:
DermaMNIST_external_n: 2005
mean_confidence: 0.6616833858507827
median_confidence: 0.6399524807929993
mean_entropy: 0.9612262350551208
median_entropy: 1.0979413986206055


In [36]:
# CELL 23C: DermaMNIST W&B logging + note (append after 23B)

import pandas as pd

# 'report' should exist from CELL 23B
wandb.log({"external/dermamnist_report": wandb.Table(dataframe=pd.DataFrame([report]))})

wandb.run.summary["external_note"] = (
    "External dataset label space differs from ISIC2019; "
    "we report robustness metrics (confidence/entropy) under distribution shift."
)
wandb.log({
    "external/note": "External dataset label space differs; we report robustness metrics (confidence/entropy) under distribution shift."
})
print("✓ Logged DermaMNIST robustness report + note to W&B")

✓ Logged DermaMNIST robustness report + note to W&B


In [None]:
# CELL 24: COMPUTATIONAL ANALYSIS
print("\n" + "="*50 + "\nCOMPUTATIONAL ANALYSIS\n" + "="*50)
comp_models = {'UHViT': lambda: UHViT(CONFIG['num_classes'], CONFIG['dropout_rate']), 'Swin-T': lambda: SwinOnly(CONFIG['num_classes'], CONFIG['dropout_rate']), 'EfficientNet-B3': lambda: EfficientNetOnly(CONFIG['num_classes'], CONFIG['dropout_rate']), 'ResNet50': lambda: create_sota_model('resnet50', CONFIG['num_classes']), 'ViT-Base': lambda: create_sota_model('vit_base_patch16_224', CONFIG['num_classes'])}
comp_df = computational_analysis(comp_models, CONFIG['device']); print("\nComputational Analysis:\n" + comp_df.to_string(index=False))




COMPUTATIONAL ANALYSIS
Analyzing UHViT...
Analyzing Swin-T...
Analyzing EfficientNet-B3...
Analyzing ResNet50...
Analyzing ViT-Base...

Computational Analysis:
          Model  Params (M)  FLOPs (G)  Time (ms)
          UHViT   40.580266  10.712108  19.199839
         Swin-T   27.527042   8.760381   8.547678
EfficientNet-B3   10.711600   1.945989   9.249110
       ResNet50   23.524424   8.260809   4.372694
       ViT-Base   85.804808  24.033695   4.171274


In [None]:
# CELL 25: SAVE FINAL RESULTS
results_summary = {'UHViT_5Fold_CV': {'Accuracy': f"{np.mean(accs):.4f} ± {np.std(accs):.4f}", 'Balanced_Accuracy': f"{np.mean(baccs):.4f} ± {np.std(baccs):.4f}", 'Macro_F1': f"{np.mean(f1s):.4f} ± {np.std(f1s):.4f}", 'CI_95_Accuracy': f"[{acc_lo:.4f}, {acc_hi:.4f}]", 'CI_95_Balanced_Acc': f"[{bacc_lo:.4f}, {bacc_hi:.4f}]", 'CI_95_Macro_F1': f"[{f1_lo:.4f}, {f1_hi:.4f}]"}, 'Cross_Dataset_HAM10000': {'Accuracy': f"{ham_acc:.4f}", 'Balanced_Accuracy': f"{ham_bacc:.4f}", 'Macro_F1': f"{ham_f1:.4f}"}}
with open(f"{CONFIG['output_dir']}/results_summary.json", 'w') as f: json.dump(results_summary, f, indent=2)
fold_df = pd.DataFrame([{'Fold': i+1, 'Accuracy': r['acc'], 'Balanced_Acc': r['bacc'], 'Macro_F1': r['f1']} for i, r in enumerate(fold_results)])
fold_df.to_csv(f"{CONFIG['output_dir']}/fold_results.csv", index=False)
print(f"\n{'='*50}\nTRAINING COMPLETE!\n{'='*50}")
print(f"\nResults saved to: {CONFIG['output_dir']}\nCheckpoints saved to: {CONFIG['checkpoint_dir']}")
print(f"\nFinal Results:\n  Balanced Accuracy: {np.mean(baccs):.4f} ± {np.std(baccs):.4f}\n  Macro F1: {np.mean(f1s):.4f} ± {np.std(f1s):.4f}\n  Cross-Dataset (HAM10000) BAcc: {ham_bacc:.4f}")
wandb.finish()
print("\n✓ All done! Check W&B for full logs.")



TRAINING COMPLETE!

Results saved to: ./outputs
Checkpoints saved to: ./checkpoints

Final Results:
  Balanced Accuracy: 0.7335 ± 0.0108
  Macro F1: 0.5817 ± 0.0129
  Cross-Dataset (HAM10000) BAcc: 0.6976


0,1
cross_dataset/ham_acc,▁
cross_dataset/ham_bacc,▁
cross_dataset/ham_f1,▁
final/acc_mean,▁
final/acc_std,▁
final/bacc_mean,▁
final/bacc_std,▁
final/f1_mean,▁
final/f1_std,▁
fold0/train_bacc,▁▃▃▄▄▄▅▆▅▅▅▆▆▇▇▇▆▆▆▆▆▇▇▇▇▇████

0,1
comparison_protocol,Budgeted baselines/a...
cross_dataset/ham_acc,0.52551
cross_dataset/ham_bacc,0.69758
cross_dataset/ham_f1,0.43587
external/note,External dataset lab...
external_note,External dataset lab...
final/acc_mean,0.50507
final/acc_std,0.01678
final/bacc_mean,0.73352
final/bacc_std,0.01076



✓ All done! Check W&B for full logs.


In [41]:
# CELL 20C: MEL additional operating points (Youden J + PR-AUC)

import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, confusion_matrix, average_precision_score, precision_recall_curve, roc_auc_score

MEL_CLASS = CONFIG["class_names"].index("MEL")

def summarize_threshold(mel_true, mel_score, threshold):
    mel_pred = (mel_score >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(mel_true, mel_pred).ravel()
    sens = tp / (tp + fn + 1e-9)
    spec = tn / (tn + fp + 1e-9)
    prec = tp / (tp + fp + 1e-9)
    f1 = 2 * prec * sens / (prec + sens + 1e-9)
    return dict(
        threshold=float(threshold),
        sensitivity=float(sens),
        specificity=float(spec),
        precision=float(prec),
        f1=float(f1),
        tp=int(tp), fp=int(fp), tn=int(tn), fn=int(fn)
    )

# Pooled CV arrays should exist from CELL 20B:
# all_y_true, all_y_prob
mel_true = (all_y_true == MEL_CLASS).astype(int)
mel_score = all_y_prob[:, MEL_CLASS]

# AUROC + PR-AUC (PR-AUC is informative for imbalanced MEL)
mel_auroc = roc_auc_score(mel_true, mel_score) if 0 < mel_true.sum() < len(mel_true) else 0.0
mel_prauc = average_precision_score(mel_true, mel_score) if 0 < mel_true.sum() < len(mel_true) else 0.0

# Youden's J optimal threshold (maximizes sensitivity + specificity - 1)
fpr, tpr, thr = roc_curve(mel_true, mel_score)
youden_j = tpr - fpr
best_idx = int(np.argmax(youden_j))
thr_youden = thr[best_idx]

youden_report = summarize_threshold(mel_true, mel_score, thr_youden)
youden_report.update({
    "metric": "YoudenJ_optimal",
    "MEL_AUROC": float(mel_auroc),
    "MEL_PR_AUC": float(mel_prauc),
    "youden_J": float(youden_j[best_idx]),
})

# Also report a higher-specificity point (optional): 90% sensitivity
target_sens = 0.90
idx = np.where(tpr >= target_sens)[0]
if len(idx) > 0:
    best_idx2 = idx[np.argmin(fpr[idx])]
    thr_sens90 = thr[best_idx2]
    sens90_report = summarize_threshold(mel_true, mel_score, thr_sens90)
    sens90_report.update({
        "metric": "Sensitivity_90%",
        "target_sensitivity": float(target_sens),
        "MEL_AUROC": float(mel_auroc),
        "MEL_PR_AUC": float(mel_prauc),
    })
else:
    sens90_report = None

rows = [youden_report] + ([sens90_report] if sens90_report is not None else [])
df = pd.DataFrame(rows)

print("Additional MEL operating points:")
print(df[["metric","threshold","sensitivity","specificity","precision","f1","MEL_AUROC","MEL_PR_AUC"]].to_string(index=False))

wandb.log({"clinical/MEL_additional_operating_points": wandb.Table(dataframe=df)})
wandb.log({"clinical/MEL_PR_AUC": float(mel_prauc), "clinical/MEL_AUROC": float(mel_auroc)})

Additional MEL operating points:
         metric  threshold  sensitivity  specificity  precision       f1  MEL_AUROC  MEL_PR_AUC
YoudenJ_optimal   0.341932     0.726670     0.750156   0.387272 0.505266   0.819794    0.569331
Sensitivity_90%   0.215831     0.900044     0.494546   0.278996 0.425955   0.819794    0.569331


Error: You must call wandb.init() before wandb.log()