In [51]:
# 0. install prerequisites
!pip -q install scikit-learn medmnist tqdm ipywidgets jupyter robustbench torch-uncertainty



In [52]:
# 1. imports + global config
import os
import math
import time
import json
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, models
from tqdm.auto import tqdm

from sklearn.metrics import roc_auc_score


In [53]:
# 2. reproducibility utilities
def set_seed(seed: int):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE

print(f"Current device: {torch.cuda.current_device()}")
print(f"Device name: {torch.cuda.get_device_name(0)}")

Current device: 0
Device name: NVIDIA GeForce RTX 2050


In [54]:
# 3. settings
@dataclass
class TrainConfig:
    # Datasets: "cifar10", "cifar100", "tinyimagenet", "tissuemnist",
    #           "cifar10-c", "cifar100-c", "tinyimagenet-c"
    dataset: str = "cifar100"

    # Models: "shufflenetv2_0_5", "mobilenetv3_small", "efficientnetv2_s"
    model_name: str = "shufflenetv2_0_5"

    # Data paths
    data_root: str = "./data"
    tinyimagenet_root: str = "./tiny-imagenet-200"  # required only for tinyimagenet

    # Training
    epochs: int = 50
    batch_size: int = 64  # Reduced from 128 for smaller GPU
    num_workers: int = 0
    lr: float = 0.05
    weight_decay: float = 1e-4
    momentum: float = 0.9
    seed: int = 42

    # --- stability knobs ---
    label_smoothing: float = 0.1      # 0.0 to disable
    max_grad_norm: float = 1.0        # 0.0 to disable
    lr_efficientnet: float = 0.01     # override for efficientnetv2_s (was 0.05)

    # Percentile rejection
    reject_percentiles: Tuple[int, ...] = (10, 20, 30, 40, 50)

    # Hybrid definition
    hybrid_weight_entropy: float = 0.5
    hybrid_weight_grad: float = 0.5

    # Where to save outputs
    out_dir: str = "./outputs"

cfg = TrainConfig()
os.makedirs(cfg.out_dir, exist_ok=True)
asdict(cfg)

{'dataset': 'cifar100',
 'model_name': 'shufflenetv2_0_5',
 'data_root': './data',
 'tinyimagenet_root': './tiny-imagenet-200',
 'epochs': 50,
 'batch_size': 64,
 'num_workers': 0,
 'lr': 0.05,
 'weight_decay': 0.0001,
 'momentum': 0.9,
 'seed': 42,
 'reject_percentiles': (10, 20, 30, 40, 50),
 'hybrid_weight_entropy': 0.5,
 'hybrid_weight_grad': 0.5,
 'out_dir': './outputs'}

In [55]:
# 4. data transforms (Shahriar-aligned normalization)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

def build_transforms(train: bool, dataset_name: str = ""):
    if train:
        if dataset_name == "tissuemnist":
            return transforms.Compose([
                transforms.Grayscale(num_output_channels=3),  # Convert 1-channel to 3-channel
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
            ])
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])
    else:
        if dataset_name == "tissuemnist":
            return transforms.Compose([
                transforms.Grayscale(num_output_channels=3),  # Convert 1-channel to 3-channel
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
            ])
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ])

In [56]:
# 5. dataset helpers needed by loaders
# 5a. TinyImageNet downloader
def download_tinyimagenet(root: str):
    """
    Downloads and extracts TinyImageNet-200 to root directory.
    Download URL: http://cs231n.stanford.edu/tiny-imagenet-200.zip

    If automatic download fails, manually download the zip file and place it in the root directory.
    """
    import urllib.request
    import zipfile

    tiny_root = os.path.join(root, "tiny-imagenet-200")

    # Check if already downloaded
    if os.path.exists(os.path.join(tiny_root, "train")):
        print(f"✓ TinyImageNet already exists at {tiny_root}")
        return tiny_root

    # Download
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    zip_path = os.path.join(root, "tiny-imagenet-200.zip")

    print(f"Downloading TinyImageNet from {url}...")
    print("(This is ~270MB and will take a few minutes)")

    try:
        urllib.request.urlretrieve(url, zip_path)
        print("✓ Download complete")
    except Exception as e:
        print(f"❌ Auto-download failed: {e}")
        print(f"\nTo use TinyImageNet, manually download from:")
        print(f"  {url}")
        print(f"Then extract to: {root}")
        raise

    # Extract
    print(f"Extracting to {root}...")
    with zipfile.ZipFile(zip_path, 'r') as zf:
        zf.extractall(root)
    print("✓ Extraction complete")

    # Cleanup zip
    os.remove(zip_path)

    return tiny_root


# Helper class for local Tiny-ImageNet-C loading
class LocalTinyImageNetCDataset(torch.utils.data.Dataset):
    """Load Tiny ImageNet-C from local directory structure.

    Structure: root/{corruption}/{severity}/{class_id}/{image}.JPEG
    Loads all corruptions and all severity levels.
    """
    def __init__(self, root: str, transform=None):
        self.root = root
        self.transform = transform
        self.images = []
        self.labels = []

        # Build class name to label mapping from class directories
        class_to_label = {}
        class_dirs = sorted([d for d in os.listdir(root) 
                            if os.path.isdir(os.path.join(root, d)) and d[0] == 'n'])

        # Get label mapping from any corruption/severity directory
        for corruption_dir in os.listdir(root):
            corruption_path = os.path.join(root, corruption_dir)
            if not os.path.isdir(corruption_path):
                continue

            # Find a severity level directory
            for severity_dir in os.listdir(corruption_path):
                severity_path = os.path.join(corruption_path, severity_dir)
                if not os.path.isdir(severity_path):
                    continue

                # Build class mapping
                for class_id in sorted(os.listdir(severity_path)):
                    class_path = os.path.join(severity_path, class_id)
                    if os.path.isdir(class_path):
                        if class_id not in class_to_label:
                            class_to_label[class_id] = len(class_to_label)

                # Only need one severity to build the mapping
                break
            break

        if not class_to_label:
            raise FileNotFoundError(f"No class directories found in {root}")

        # Now load all images from all corruptions and severities
        for corruption_dir in sorted(os.listdir(root)):
            corruption_path = os.path.join(root, corruption_dir)
            if not os.path.isdir(corruption_path):
                continue

            for severity_dir in sorted(os.listdir(corruption_path)):
                severity_path = os.path.join(corruption_path, severity_dir)
                if not os.path.isdir(severity_path):
                    continue

                for class_id in sorted(os.listdir(severity_path)):
                    class_path = os.path.join(severity_path, class_id)
                    if not os.path.isdir(class_path):
                        continue

                    label = class_to_label[class_id]

                    # Load all images in this class
                    for img_file in sorted(os.listdir(class_path)):
                        if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                            img_path = os.path.join(class_path, img_file)
                            self.images.append(img_path)
                            self.labels.append(label)

        if not self.images:
            raise FileNotFoundError(f"No images found in {root}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        try:
            from PIL import Image
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            raise RuntimeError(f"Failed to load image {img_path}: {e}")

        if self.transform:
            img = self.transform(img)

        return img, label


# Helper classes for CIFAR-10-C and CIFAR-100-C loading
class CIFAR10CDataset(torch.utils.data.Dataset):
    """Load CIFAR-10-C from local .npy files.

    Structure: root/{corruption}.npy (e.g., gaussian_noise.npy)
               root/labels.npy
    Each .npy file contains 50000 images (10000 per severity, severities 1-5).
    """
    def __init__(self, root: str, transform=None, severity: int = 5):
        self.root = root
        self.transform = transform
        self.severity = severity

        # Load labels
        labels_path = os.path.join(root, "labels.npy")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels not found at {labels_path}")

        all_labels = np.load(labels_path)  # shape: (50000,)

        # Get corruption types (all .npy files except labels.npy)
        corruption_files = [f for f in os.listdir(root) 
                           if f.endswith('.npy') and f != 'labels.npy']

        if not corruption_files:
            raise FileNotFoundError(f"No corruption files found in {root}")

        # Load all corruptions for the specified severity
        all_images = []
        all_labels_list = []

        # Each corruption file has 50000 images: 10000 per severity (1-5)
        # Severity 1: indices 0-9999
        # Severity 5: indices 40000-49999
        start_idx = (severity - 1) * 10000
        end_idx = start_idx + 10000

        for corruption_file in sorted(corruption_files):
            corruption_path = os.path.join(root, corruption_file)
            images = np.load(corruption_path)  # shape: (50000, 32, 32, 3)

            # Extract images for the specified severity
            severity_images = images[start_idx:end_idx]
            all_images.append(severity_images)
            all_labels_list.append(all_labels[start_idx:end_idx])

        # Concatenate all corruptions
        self.images = np.concatenate(all_images, axis=0)  # (N, 32, 32, 3)
        self.labels = np.concatenate(all_labels_list, axis=0)  # (N,)

        print(f"✓ Loaded CIFAR-10-C: {len(corruption_files)} corruptions, "
              f"severity {severity}, {len(self.images)} total images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]  # (32, 32, 3) uint8
        label = int(self.labels[idx])

        # Convert to PIL Image
        from PIL import Image
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, label


class CIFAR100CDataset(torch.utils.data.Dataset):
    """Load CIFAR-100-C from local .npy files.

    Structure: root/{corruption}.npy (e.g., gaussian_noise.npy)
               root/labels.npy
    Each .npy file contains 50000 images (10000 per severity, severities 1-5).
    """
    def __init__(self, root: str, transform=None, severity: int = 5):
        self.root = root
        self.transform = transform
        self.severity = severity

        # Load labels
        labels_path = os.path.join(root, "labels.npy")
        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"Labels not found at {labels_path}")

        all_labels = np.load(labels_path)  # shape: (50000,)

        # Get corruption types (all .npy files except labels.npy)
        corruption_files = [f for f in os.listdir(root) 
                           if f.endswith('.npy') and f != 'labels.npy']

        if not corruption_files:
            raise FileNotFoundError(f"No corruption files found in {root}")

        # Load all corruptions for the specified severity
        all_images = []
        all_labels_list = []

        # Each corruption file has 50000 images: 10000 per severity (1-5)
        # Severity 1: indices 0-9999
        # Severity 5: indices 40000-49999
        start_idx = (severity - 1) * 10000
        end_idx = start_idx + 10000

        for corruption_file in sorted(corruption_files):
            corruption_path = os.path.join(root, corruption_file)
            images = np.load(corruption_path)  # shape: (50000, 32, 32, 3)

            # Extract images for the specified severity
            severity_images = images[start_idx:end_idx]
            all_images.append(severity_images)
            all_labels_list.append(all_labels[start_idx:end_idx])

        # Concatenate all corruptions
        self.images = np.concatenate(all_images, axis=0)  # (N, 32, 32, 3)
        self.labels = np.concatenate(all_labels_list, axis=0)  # (N,)

        print(f"✓ Loaded CIFAR-100-C: {len(corruption_files)} corruptions, "
              f"severity {severity}, {len(self.images)} total images")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]  # (32, 32, 3) uint8
        label = int(self.labels[idx])

        # Convert to PIL Image
        from PIL import Image
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        return img, label


def build_tinyimagenet_c_loaders(cfg: TrainConfig):
    """
    Load Tiny ImageNet-C from local directory structure.
    Structure: Tiny-ImageNet-C/{corruption}/{severity}/{class_id}/{image}.JPEG
    Train: on clean Tiny ImageNet
    Test: on corrupted Tiny ImageNet-C (all corruptions, all severities)
    """
    train_tf = build_transforms(train=True, dataset_name="tinyimagenet")
    test_tf  = build_transforms(train=False, dataset_name="tinyimagenet")

    # Load regular training data (standard Tiny ImageNet)
    tiny_root = download_tinyimagenet(cfg.data_root)
    train_dir = os.path.join(tiny_root, "train")

    def is_valid_file(x):
        return x.lower().endswith(('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'))

    train_ds = datasets.ImageFolder(root=train_dir, transform=train_tf, is_valid_file=is_valid_file)

    # Load corrupted Tiny ImageNet-C from local directory
    tinyimagenet_c_root = os.path.join(cfg.data_root, "Tiny-ImageNet-C")
    if not os.path.exists(tinyimagenet_c_root):
        raise FileNotFoundError(
            f"Tiny-ImageNet-C not found at {tinyimagenet_c_root}\n"
            f"Expected structure: Tiny-ImageNet-C/{{corruption}}/{{severity}}/{{class}}/{{image}}.JPEG"
        )

    test_ds = LocalTinyImageNetCDataset(root=tinyimagenet_c_root, transform=test_tf)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                              num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
    test_loader  = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False,
                              num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
    print(f"✓ Loaded Tiny ImageNet-C from local directory - {len(test_ds)} samples")
    return train_loader, test_loader

In [57]:
# 5b. Diagnostic: Check TinyImageNet structure
def check_tinyimagenet_structure(data_root: str):
    """Debug: inspect the TinyImageNet directory structure"""
    tiny_root = os.path.join(data_root, "tiny-imagenet-200")
    
    if not os.path.exists(tiny_root):
        print(f"❌ TinyImageNet not found at {tiny_root}")
        return
    
    print(f"✓ TinyImageNet found at {tiny_root}")
    
    # Check train folder
    train_dir = os.path.join(tiny_root, "train")
    if os.path.exists(train_dir):
        train_classes = os.listdir(train_dir)
        print(f"  ✓ Train: {len(train_classes)} classes")
    else:
        print(f"  ❌ Train folder not found")
    
    # Check val folder
    val_dir = os.path.join(tiny_root, "val")
    if os.path.exists(val_dir):
        val_contents = os.listdir(val_dir)
        print(f"  ✓ Val: {len(val_contents)} items: {val_contents[:5]}")
        
        # Check if images are organized into class folders
        class_folders = [d for d in val_contents if os.path.isdir(os.path.join(val_dir, d)) and d != "images"]
        if class_folders:
            sample_class = class_folders[0]
            sample_images = os.listdir(os.path.join(val_dir, sample_class))
            print(f"    - Class folder '{sample_class}' has {len(sample_images)} images")
        else:
            print(f"    - No class folders found (images still in flat structure)")
            images_dir = os.path.join(val_dir, "images")
            if os.path.exists(images_dir):
                img_count = len(os.listdir(images_dir))
                sample_imgs = os.listdir(images_dir)[:3]
                print(f"    - Found {img_count} images in 'images' folder")
                print(f"    - Sample: {sample_imgs}")
    else:
        print(f"  ❌ Val folder not found")


In [58]:
# 5c. Helper functions: get_num_classes and build_loaders
def get_num_classes(dataset_name: str) -> int:
    """Return number of classes for each dataset."""
    mapping = {
        "cifar10": 10,
        "cifar10-c": 10,
        "cifar100": 100,
        "cifar100-c": 100,
        "tinyimagenet": 200,
        "tinyimagenet-c": 200,
        "tissuemnist": 8,
    }
    if dataset_name not in mapping:
        raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(mapping.keys())}")
    return mapping[dataset_name]

def build_loaders(cfg: TrainConfig) -> Tuple[DataLoader, DataLoader]:
    """
    Build train and test loaders for all supported datasets.
    
    Protocol for corrupted datasets:
    - Train: on clean dataset
    - Test: on corrupted dataset
    
    Note: Test loader uses num_workers=0 for efficiency during evaluation with gradients
    """
    train_tf = build_transforms(train=True, dataset_name=cfg.dataset.replace("-c", ""))
    test_tf  = build_transforms(train=False, dataset_name=cfg.dataset.replace("-c", ""))
    
    # Test loaders use 0 workers for efficiency (gradient computation is GPU-bound, not I/O bound)
    test_num_workers = 0
    
    # === CIFAR-10 ===
    if cfg.dataset == "cifar10":
        train_ds = datasets.CIFAR10(root=cfg.data_root, train=True, transform=train_tf, download=True)
        test_ds = datasets.CIFAR10(root=cfg.data_root, train=False, transform=test_tf, download=True)
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                                  num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
        test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False,
                                 num_workers=test_num_workers, pin_memory=True, drop_last=False)
        return train_loader, test_loader
    
    # === CIFAR-10-C (train clean, test corrupted) ===
    elif cfg.dataset == "cifar10-c":
        # Train on clean CIFAR-10
        train_ds = datasets.CIFAR10(root=cfg.data_root, train=True, transform=train_tf, download=True)
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                                  num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
        
        # Test on corrupted CIFAR-10 (from local directory: ./data/cifar-10-c/)
        cifar10c_root = os.path.join(cfg.data_root, "cifar-10-c")
        if not os.path.exists(cifar10c_root):
            raise FileNotFoundError(
                f"CIFAR-10-C not found at {cifar10c_root}\n"
                f"Download from: https://zenodo.org/record/2535967\n"
                f"Extract to: {cfg.data_root}/cifar-10-c/"
            )
        
        # Load all corruptions at max severity (severity 5)
        test_ds = CIFAR10CDataset(root=cifar10c_root, transform=test_tf, severity=3)
        test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False,
                                 num_workers=test_num_workers, pin_memory=True, drop_last=False)
        return train_loader, test_loader
    
    # === CIFAR-100 ===
    elif cfg.dataset == "cifar100":
        train_ds = datasets.CIFAR100(root=cfg.data_root, train=True, transform=train_tf, download=True)
        test_ds = datasets.CIFAR100(root=cfg.data_root, train=False, transform=test_tf, download=True)
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                                  num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
        test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False,
                                 num_workers=test_num_workers, pin_memory=True, drop_last=False)
        return train_loader, test_loader
    
    # === CIFAR-100-C (train clean, test corrupted) ===
    elif cfg.dataset == "cifar100-c":
        # Train on clean CIFAR-100
        train_ds = datasets.CIFAR100(root=cfg.data_root, train=True, transform=train_tf, download=True)
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                                  num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
        
        # Test on corrupted CIFAR-100 (from local directory: ./data/cifar-100-c/)
        cifar100c_root = os.path.join(cfg.data_root, "cifar-100-c")
        if not os.path.exists(cifar100c_root):
            raise FileNotFoundError(
                f"CIFAR-100-C not found at {cifar100c_root}\n"
                f"Download from: https://zenodo.org/record/3555552\n"
                f"Extract to: {cfg.data_root}/cifar-100-c/"
            )
        
        # Load all corruptions at max severity (severity 5)
        test_ds = CIFAR100CDataset(root=cifar100c_root, transform=test_tf, severity=3)
        test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False,
                                 num_workers=test_num_workers, pin_memory=True, drop_last=False)
        return train_loader, test_loader
    
    # === Tiny ImageNet ===
    elif cfg.dataset == "tinyimagenet":
        tiny_root = download_tinyimagenet(cfg.data_root)
        train_dir = os.path.join(tiny_root, "train")
        val_dir = os.path.join(tiny_root, "val")
        
        def is_valid_file(x):
            return x.lower().endswith(('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'))
        
        train_ds = datasets.ImageFolder(root=train_dir, transform=train_tf, is_valid_file=is_valid_file)
        test_ds = datasets.ImageFolder(root=val_dir, transform=test_tf, is_valid_file=is_valid_file)
        
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                                  num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
        test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False,
                                 num_workers=test_num_workers, pin_memory=True, drop_last=False)
        return train_loader, test_loader
    
    # === Tiny ImageNet-C (train clean, test corrupted) ===
    elif cfg.dataset == "tinyimagenet-c":
        return build_tinyimagenet_c_loaders(cfg)
    
    # === TissueMNIST ===
    elif cfg.dataset == "tissuemnist":
        from medmnist import TissueMNIST
        train_ds = TissueMNIST(root=cfg.data_root, split="train", transform=train_tf, download=True)
        test_ds = TissueMNIST(root=cfg.data_root, split="test", transform=test_tf, download=True)
        
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                                  num_workers=cfg.num_workers, pin_memory=True, drop_last=False)
        test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False,
                                 num_workers=test_num_workers, pin_memory=True, drop_last=False)
        return train_loader, test_loader
    
    else:
        raise ValueError(f"Unknown dataset: {cfg.dataset}")


In [59]:
# 6. model builder
def build_model(model_name: str, num_classes: int) -> nn.Module:
    if model_name == "shufflenetv2_0_5":
        m = models.shufflenet_v2_x0_5(weights=None)
        in_features = m.fc.in_features
        m.fc = nn.Linear(in_features, num_classes)
        return m

    if model_name == "mobilenetv3_small":
        m = models.mobilenet_v3_small(weights=None)
        in_features = m.classifier[-1].in_features
        m.classifier[-1] = nn.Linear(in_features, num_classes)
        return m

    if model_name == "efficientnetv2_s":
        m = models.efficientnet_v2_s(weights=None)
        in_features = m.classifier[-1].in_features
        m.classifier[-1] = nn.Linear(in_features, num_classes)
        return m

    raise ValueError(f"Unknown model: {model_name}")

In [60]:
# 7. training loop
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0

    for x, y in tqdm(loader, desc="Train", leave=False):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True).long().view(-1)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()

        if cfg.max_grad_norm and cfg.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)

        optimizer.step()

        with torch.no_grad():
            preds = logits.argmax(dim=1)
            total_correct += (preds == y).sum().item()
            total += y.numel()
            total_loss += loss.item() * y.numel()

    return total_loss / total, total_correct / total

def train_model(cfg: TrainConfig, model, train_loader):
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss(label_smoothing=cfg.label_smoothing)

    # --- model-specific LR override (EfficientNet only) ---
    effective_lr = cfg.lr
    if "efficientnet" in cfg.model_name.lower():
        effective_lr = cfg.lr_efficientnet

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=effective_lr,
        momentum=cfg.momentum,
        weight_decay=cfg.weight_decay,
        nesterov=True
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)

    history = []
    for epoch in range(1, cfg.epochs + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        scheduler.step()
        history.append({"epoch": epoch, "train_loss": tr_loss, "train_acc": tr_acc, "lr": scheduler.get_last_lr()[0]})
        print(f"Epoch {epoch:03d}/{cfg.epochs} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | lr={history[-1]['lr']:.6f}")

    return model, history

In [61]:
# 8. metrics: ece, risk-coverage + arc area, avuc loss
def expected_calibration_error(probs: np.ndarray, y_true: np.ndarray, n_bins: int = 15) -> float:
    """
    probs: (N, K) softmax probabilities
    y_true: (N,)
    """
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    accuracies = (predictions == y_true).astype(np.float32)

    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = bin_edges[i], bin_edges[i + 1]
        mask = (confidences > lo) & (confidences <= hi) if i > 0 else (confidences >= lo) & (confidences <= hi)
        if mask.sum() == 0:
            continue
        bin_acc = accuracies[mask].mean()
        bin_conf = confidences[mask].mean()
        ece += (mask.sum() / len(y_true)) * abs(bin_acc - bin_conf)
    return float(ece)

def risk_coverage_curve(uncertainty: np.ndarray, correct: np.ndarray):
    """
    uncertainty: higher means more uncertain
    correct: 1 if correct prediction else 0

    Returns arrays coverage (ascending) and risk (= 1 - accuracy among kept).
    """
    order = np.argsort(uncertainty)  # keep least uncertain first
    correct_sorted = correct[order].astype(np.float32)

    n = len(correct_sorted)
    cum_correct = np.cumsum(correct_sorted)
    kept = np.arange(1, n + 1)
    acc_kept = cum_correct / kept
    risk = 1.0 - acc_kept
    coverage = kept / n
    return coverage, risk

def arc_area(uncertainty: np.ndarray, correct: np.ndarray) -> float:
    """
    Area under accuracy-coverage curve (kept least uncertain first).
    """
    coverage, risk = risk_coverage_curve(uncertainty, correct)
    acc = 1.0 - risk
    return float(np.trapezoid(acc, coverage))

def avuc_loss(danger: np.ndarray, correct: np.ndarray) -> float:
    """
    AvUC as defined in your thesis:
      AvUC = mean( | 1{correct} - (1 - D(x)) | )
    where D(x) is the danger score in [0,1] (higher = more dangerous / uncertain).
    Lower is better.
    """
    D = danger.astype(np.float64)
    c = correct.astype(np.float64)  # 1 if correct else 0
    conf = 1.0 - D
    return float(np.mean(np.abs(c - conf)))


In [62]:
# 9. core: unified test evaluation (one pass) + uncertainty signals
@torch.inference_mode()
def _forward_only(model, x):
    logits = model(x)
    probs = torch.softmax(logits, dim=1)
    return logits, probs

def evaluate_with_uncertainty(cfg: TrainConfig, model: nn.Module, test_loader: DataLoader):
    model = model.to(DEVICE)
    model.eval()

    # Freeze parameters explicitly so backward does not build grads for weights
    for p in model.parameters():
        p.requires_grad_(False)

    criterion = nn.CrossEntropyLoss(reduction="mean")

    all_y_true = []
    all_y_pred = []
    all_probs = []
    all_entropy = []
    all_gradnorm = []

    all_per_sample_loss = []

    total_loss = 0.0
    total_correct = 0
    total = 0
    
    batch_count = 0
    total_batches = len(test_loader)

    # diagnostics
    max_abs_logit = 0.0
    saw_nan = False
    saw_inf = False
    batch_diagnostics = []  # optional per-batch summary records

    for x, y in tqdm(test_loader, desc="Test+Uncertainty", leave=False):
        batch_count += 1
        
        # labels
        y = y.to(DEVICE, non_blocking=True).long().view(-1)

        # Inputs: need grad for gradient sensitivity
        x = x.to(DEVICE, non_blocking=True)
        x = x.detach()
        x.requires_grad_(True)

        # Forward (with grad enabled, since we need input gradients)
        logits = model(x)
        probs = torch.softmax(logits, dim=1)

        # Baseline loss and accuracy
        loss = criterion(logits, y)

        # Per-sample cross-entropy (for tail diagnostics)
        with torch.no_grad():
            per_sample_loss = F.cross_entropy(logits, y, reduction="none")  # shape (B,)
            all_per_sample_loss.append(per_sample_loss.detach().cpu().numpy())

        preds = probs.argmax(dim=1)
        correct = (preds == y)

        total_loss += loss.item() * y.numel() 
        total_correct += correct.sum().item()
        total += y.numel()

        # Entropy per sample
        entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=1)

        # Gradient sensitivity: grad of predicted-class logit w.r.t input
        # Pick the logit corresponding to predicted class for each sample
        idx = torch.arange(logits.size(0), device=DEVICE)
        selected = logits[idx, preds]  # shape: (B,)

        # Backward: we want d(selected)/dx for each sample
        # Use torch.autograd.grad to avoid storing param grads
        grads = torch.autograd.grad(
            outputs=selected.sum(),
            inputs=x,
            create_graph=False,
            retain_graph=False,
            only_inputs=True
        )[0]  # shape: (B, C, H, W)

        grad_norm = grads.view(grads.size(0), -1).norm(p=2, dim=1)

        # --- diagnostics checks (new) ---
        with torch.no_grad():
            logits_abs_max_batch = float(logits.abs().max().cpu().item())
            logits_mean_batch = float(logits.mean().cpu().item())
            loss_val = float(loss.item())
            batch_avg_loss = loss_val

            nan_in_logits = bool(torch.isnan(logits).any().item())
            inf_in_logits = bool(torch.isinf(logits).any().item())
            nan_in_probs = bool(torch.isnan(probs).any().item())
            inf_in_probs = bool(torch.isinf(probs).any().item())

            saw_nan = saw_nan or nan_in_logits or nan_in_probs
            saw_inf = saw_inf or inf_in_logits or inf_in_probs
            if logits_abs_max_batch > max_abs_logit:
                max_abs_logit = logits_abs_max_batch

            # record a compact batch summary (useful for debugging or saving)
            batch_diagnostics.append({
                "batch": batch_count,
                "loss_mean": loss_val,
                "batch_avg_loss": batch_avg_loss,
                "logits_max_abs": logits_abs_max_batch,
                "logits_mean": logits_mean_batch,
                "nan": nan_in_logits or nan_in_probs,
                "inf": inf_in_logits or inf_in_probs
            })

            # Print occasional progress diagnostics to help locate problem batches
            if batch_count % 50 == 0 or nan_in_logits or inf_in_logits:
                print(f"[BATCH {batch_count}/{total_batches}] loss_sum={loss_val:.4e} "
                      f"batch_avg_loss={batch_avg_loss:.4e} logits_max_abs={logits_abs_max_batch:.4e} "
                      f"logits_mean={logits_mean_batch:.4e} nan={nan_in_logits or nan_in_probs} inf={inf_in_logits or inf_in_probs}")

        # Store batch results
        all_y_true.append(y.detach().cpu().numpy())
        all_y_pred.append(preds.detach().cpu().numpy())
        all_probs.append(probs.detach().cpu().numpy())
        all_entropy.append(entropy.detach().cpu().numpy())
        all_gradnorm.append(grad_norm.detach().cpu().numpy())

    # concat results
    y_true = np.concatenate(all_y_true)
    y_pred = np.concatenate(all_y_pred)
    probs  = np.concatenate(all_probs)
    entropy = np.concatenate(all_entropy)
    gradnorm = np.concatenate(all_gradnorm)

    per_sample_loss_all = np.concatenate(all_per_sample_loss)

    print("=== LOSS TAIL DIAGNOSTICS ===")
    print(f"loss_median={np.median(per_sample_loss_all):.4f}")
    print(f"loss_p95={np.percentile(per_sample_loss_all, 95):.4f}")
    print(f"loss_p99={np.percentile(per_sample_loss_all, 99):.4f}")
    print(f"loss_max={np.max(per_sample_loss_all):.4f}")

    test_loss = total_loss / total
    test_acc = total_correct / total

    # "Accuracy Before Rejection" must match test_acc
    correct_vec = (y_pred == y_true).astype(np.int32)

    # final diagnostics print
    print("=== EVAL DIAGNOSTICS ===")
    print(f"total_samples={total} total_loss_sum={total_loss:.4e} test_loss_avg={test_loss:.4e} test_acc={test_acc:.4f}")
    print(f"max_abs_logit={max_abs_logit:.4e} saw_nan={saw_nan} saw_inf={saw_inf}")

    diagnostics = {
        "max_abs_logit": float(max_abs_logit),
        "saw_nan": bool(saw_nan),
        "saw_inf": bool(saw_inf),
        "batch_diagnostics_sample": batch_diagnostics[:10]  # keep small preview
    }

    return {
        "test_loss": float(test_loss),
        "test_acc": float(test_acc),
        "y_true": y_true,
        "y_pred": y_pred,
        "probs": probs,
        "entropy": entropy,
        "gradnorm": gradnorm,
        "correct": correct_vec,
        "diagnostics": diagnostics
    }


In [63]:
# 10. hybrid score + uncertainty evaluation + percentile rejection + reliability bins + histograms
def minmax_norm(x: np.ndarray) -> np.ndarray:
    lo, hi = float(np.min(x)), float(np.max(x))
    if hi - lo < 1e-12:
        return np.zeros_like(x, dtype=np.float32)
    return ((x - lo) / (hi - lo)).astype(np.float32)

def compute_uncertainty_metrics(method_scores: Dict[str, np.ndarray], eval_pack: Dict, cfg: TrainConfig):
    y_true = eval_pack["y_true"]
    y_pred = eval_pack["y_pred"]
    probs  = eval_pack["probs"]
    correct = eval_pack["correct"]

    errors = 1 - correct  # 1 if incorrect, 0 if correct

    results = {}

    # ECE uses probabilities (common baseline confidence calibration)
    ece = expected_calibration_error(probs, y_true, n_bins=15)

    for name, score in method_scores.items():
        # AUROC for detecting errors: higher uncertainty should indicate error
        # If all errors are same class, roc_auc_score can fail; handle safely
        try:
            auroc = roc_auc_score(errors, score)
        except ValueError:
            auroc = float("nan")

        # ARC area from risk-coverage curve (keep least uncertain first)
        arc = arc_area(score, correct)
        avuc = avuc_loss(score, correct)

        results[name] = {
            "AUROC_error": float(auroc),
            "ECE": float(ece),
            "ARC_area": float(arc),
            "AvUC": float(avuc),
        }

    return results

def percentile_rejection_table(method_scores: Dict[str, np.ndarray], eval_pack: Dict, cfg: TrainConfig):
    y_true = eval_pack["y_true"]
    y_pred = eval_pack["y_pred"]
    correct = (y_pred == y_true).astype(np.int32)
    base_acc = float(correct.mean())

    table = []
    N = len(correct)

    for name, score in method_scores.items():
        for p in cfg.reject_percentiles:
            reject_n = int(round(N * (p / 100.0)))
            # Reject most uncertain => sort descending
            order = np.argsort(score)[::-1]
            reject_idx = order[:reject_n]
            keep_mask = np.ones(N, dtype=bool)
            keep_mask[reject_idx] = False

            kept = keep_mask.sum()
            acc_after = float(correct[keep_mask].mean()) if kept > 0 else float("nan")

            table.append({
                "method": name,
                "reject_percent": p,
                "rejection_rate": float(reject_n / N),
                "accuracy_before_rejection": base_acc,
                "accuracy_after_rejection": acc_after,
                "kept_count": int(kept),
            })

    return table

def reliability_bins_from_probs(probs: np.ndarray, y_true: np.ndarray, n_bins: int = 15):
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    correct = (pred == y_true).astype(np.int32)

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(conf, edges, right=True) - 1
    bin_ids = np.clip(bin_ids, 0, n_bins - 1)

    bin_conf = np.zeros(n_bins, dtype=np.float32)
    bin_acc = np.zeros(n_bins, dtype=np.float32)
    bin_count = np.zeros(n_bins, dtype=np.int32)

    for b in range(n_bins):
        m = (bin_ids == b)
        bin_count[b] = int(m.sum())
        if bin_count[b] > 0:
            bin_conf[b] = float(conf[m].mean())
            bin_acc[b] = float(correct[m].mean())
        else:
            bin_conf[b] = np.nan
            bin_acc[b] = np.nan

    return {
        "n_bins": int(n_bins),
        "edges": edges.tolist(),
        "bin_conf": bin_conf.tolist(),
        "bin_acc": bin_acc.tolist(),
        "bin_count": bin_count.tolist(),
    }

def score_histogram_by_correct(score: np.ndarray, correct: np.ndarray, n_bins: int = 30, lo: float = 0.0, hi: float = 1.0):
    score = np.clip(score, lo, hi)
    edges = np.linspace(lo, hi, n_bins + 1)

    correct_mask = correct.astype(bool)
    incorrect_mask = ~correct_mask

    c_counts, _ = np.histogram(score[correct_mask], bins=edges)
    i_counts, _ = np.histogram(score[incorrect_mask], bins=edges)

    return {
        "n_bins": int(n_bins),
        "edges": edges.tolist(),
        "correct_counts": c_counts.astype(int).tolist(),
        "incorrect_counts": i_counts.astype(int).tolist(),
    }

In [64]:
# 11. run one full experiment (train → evaluate → save results)
def run_experiment(cfg: TrainConfig):
    # Clear GPU cache at start
    torch.cuda.empty_cache()
    set_seed(cfg.seed)
    
    num_classes = get_num_classes(cfg.dataset)
    train_loader, test_loader = build_loaders(cfg)

    model = build_model(cfg.model_name, num_classes)

    print("Training configuration:")
    print(json.dumps(asdict(cfg), indent=2))

    # For corrupted datasets, try to load pre-trained clean model
    tag = f"{cfg.dataset}_{cfg.model_name}_e{cfg.epochs}_bs{cfg.batch_size}"
    ckpt_path = os.path.join(cfg.out_dir, f"{tag}.pth")
    
    # Map corrupted dataset names to their clean counterparts
    clean_dataset_map = {
        "cifar10-c": "cifar10",
        "cifar100-c": "cifar100",
        "tinyimagenet-c": "tinyimagenet"
    }
    
    model_loaded = False
    if cfg.dataset in clean_dataset_map:
        # This is a corrupted dataset - try to load clean model
        clean_dataset = clean_dataset_map[cfg.dataset]
        clean_tag = f"{clean_dataset}_{cfg.model_name}_e{cfg.epochs}_bs{cfg.batch_size}"
        clean_ckpt_path = os.path.join(cfg.out_dir, f"{clean_tag}.pth")
        
        if os.path.exists(clean_ckpt_path):
            print(f"\n✓ Loading pre-trained model from {clean_ckpt_path}")
            model.load_state_dict(torch.load(clean_ckpt_path, map_location=DEVICE))
            model_loaded = True
            train_hist = []  # No training history since we skipped training
        else:
            print(f"\n⚠ Pre-trained model not found at {clean_ckpt_path}")
            print(f"   Training from scratch on clean {clean_dataset} data...")
    
    # Train if model wasn't loaded
    if not model_loaded:
        if os.path.exists(ckpt_path):
            print(f"\n✓ Loading existing model from {ckpt_path}")
            model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
            train_hist = []  # No training history since we loaded existing model
        else:
            print("\nTraining model from scratch...")
            model, train_hist = train_model(cfg, model, train_loader)
            # Save final model
            torch.save(model.state_dict(), ckpt_path)
            print(f"Saved checkpoint: {ckpt_path}")
    else:
        # For corrupted datasets, still save under the corrupted name for reference
        torch.save(model.state_dict(), ckpt_path)
        print(f"Saved checkpoint: {ckpt_path}")

    # Unified evaluation with uncertainty
    eval_pack = evaluate_with_uncertainty(cfg, model, test_loader)

    # Build uncertainty scores
    entropy = eval_pack["entropy"]
    grad = eval_pack["gradnorm"]

    # Entropy: normalize by ln(num_classes) to get [0,1] scale
    entropy_n = np.clip(entropy / np.log(num_classes), 0.0, 1.0).astype(np.float32)

    # Gradient norm: min-max normalize to [0,1]
    grad_n = minmax_norm(grad).astype(np.float32)

    hybrid = (cfg.hybrid_weight_entropy * entropy_n) + (cfg.hybrid_weight_grad * grad_n)

    method_scores = {
        "entropy": entropy_n,
        "gradient": grad_n,
        "hybrid": hybrid.astype(np.float32)
    }

    # Metrics
    metrics = compute_uncertainty_metrics(method_scores, eval_pack, cfg)
    rej_table = percentile_rejection_table(method_scores, eval_pack, cfg)

    # reliability + score distributions
    reliability_summary = reliability_bins_from_probs(
        probs=eval_pack["probs"],
        y_true=eval_pack["y_true"],
        n_bins=15
    )

    score_distributions = {
        name: score_histogram_by_correct(score, eval_pack["correct"], n_bins=30, lo=0.0, hi=1.0)
        for name, score in method_scores.items()
    }

    # Consistency check
    base_acc = eval_pack["test_acc"]
    acc_before = float((eval_pack["y_pred"] == eval_pack["y_true"]).mean())
    print(f"\nConsistency check: test_acc={base_acc:.6f} vs accuracy_before_rejection={acc_before:.6f}")

    # Save artifacts
    out = {
        "config": asdict(cfg),
        "train_history": train_hist,
        "baseline": {"test_acc": eval_pack["test_acc"], "test_loss": eval_pack["test_loss"]},
        "uncertainty_metrics": metrics,
        "percentile_rejection": rej_table,
        "reliability": reliability_summary,              
        "score_distributions": score_distributions,     
        "checkpoint_path": ckpt_path
    }

    out_path = os.path.join(cfg.out_dir, f"{tag}_results.json")
    with open(out_path, "w") as f:
        json.dump(out, f, indent=2)
    print(f"Saved results JSON: {out_path}")

    # Clear GPU cache at end
    torch.cuda.empty_cache()
    
    return out


In [65]:
# 12. Define dataset-model pairs to run
dataset_model_pairs = [
    # Clean datasets
    # ("cifar10", "shufflenetv2_0_5"),
    # ("cifar10", "mobilenetv3_small"),
    # ("cifar10", "efficientnetv2_s"),
    # ("cifar100", "shufflenetv2_0_5"),
    # ("cifar100", "mobilenetv3_small"),
    # ("cifar100", "efficientnetv2_s"),
    # ("tinyimagenet", "shufflenetv2_0_5"),
    # ("tinyimagenet", "mobilenetv3_small"),
    # ("tinyimagenet", "efficientnetv2_s"),
    # ("tissuemnist", "shufflenetv2_0_5"),
    # ("tissuemnist", "mobilenetv3_small"),
    ("tissuemnist", "efficientnetv2_s"),
    
    # Corrupted datasets
    ("cifar10-c", "shufflenetv2_0_5"),
    ("cifar10-c", "mobilenetv3_small"),
    ("cifar10-c", "efficientnetv2_s"),
    ("cifar100-c", "shufflenetv2_0_5"),
    ("cifar100-c", "mobilenetv3_small"),
    ("cifar100-c", "efficientnetv2_s"),
    ("tinyimagenet-c", "shufflenetv2_0_5"),
    ("tinyimagenet-c", "mobilenetv3_small"),
    ("tinyimagenet-c", "efficientnetv2_s"),
]

print(f"Running {len(dataset_model_pairs)} experiments:")
for i, (ds, model) in enumerate(dataset_model_pairs, 1):
    print(f"  {i}. {ds:15} + {model:20}")

Running 10 experiments:
  1. tissuemnist     + efficientnetv2_s    
  2. cifar10-c       + shufflenetv2_0_5    
  3. cifar10-c       + mobilenetv3_small   
  4. cifar10-c       + efficientnetv2_s    
  5. cifar100-c      + shufflenetv2_0_5    
  6. cifar100-c      + mobilenetv3_small   
  7. cifar100-c      + efficientnetv2_s    
  8. tinyimagenet-c  + shufflenetv2_0_5    
  9. tinyimagenet-c  + mobilenetv3_small   
  10. tinyimagenet-c  + efficientnetv2_s    


In [66]:
# 13. Run all experiments in sequence
all_results = []
start_time = time.time()

for idx, (dataset, model_name) in enumerate(dataset_model_pairs, 1):
    print(f"\n{'='*80}")
    print(f"Experiment {idx}/{len(dataset_model_pairs)}: {dataset} + {model_name}")
    print(f"{'='*80}\n")
    
    # Create a fresh config for this pair
    experiment_cfg = TrainConfig(
        dataset=dataset,
        model_name=model_name,
        data_root=cfg.data_root,
        tinyimagenet_root=cfg.tinyimagenet_root,
        epochs=cfg.epochs,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        momentum=cfg.momentum,
        seed=cfg.seed,
        reject_percentiles=cfg.reject_percentiles,
        hybrid_weight_entropy=cfg.hybrid_weight_entropy,
        hybrid_weight_grad=cfg.hybrid_weight_grad,
        out_dir=cfg.out_dir,
    )
    
    try:
        result = run_experiment(experiment_cfg)
        all_results.append({
            "dataset": dataset,
            "model": model_name,
            "status": "success",
            "result": result
        })
    except Exception as e:
        print(f"\n❌ Error in experiment {idx}: {e}")
        all_results.append({
            "dataset": dataset,
            "model": model_name,
            "status": "failed",
            "error": str(e)
        })

elapsed = time.time() - start_time
print(f"\n{'='*80}")
print(f"All experiments completed in {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
print(f"{'='*80}\n")

# Summary
successful = sum(1 for r in all_results if r["status"] == "success")
failed = sum(1 for r in all_results if r["status"] == "failed")
print(f"Summary: {successful} successful, {failed} failed")
for r in all_results:
    status_icon = "✓" if r["status"] == "success" else "✗"
    print(f"  {status_icon} {r['dataset']:15} + {r['model']:20}")



Experiment 1/10: tissuemnist + efficientnetv2_s

Training configuration:
{
  "dataset": "tissuemnist",
  "model_name": "efficientnetv2_s",
  "data_root": "./data",
  "tinyimagenet_root": "./tiny-imagenet-200",
  "epochs": 50,
  "batch_size": 64,
  "num_workers": 0,
  "lr": 0.05,
  "weight_decay": 0.0001,
  "momentum": 0.9,
  "seed": 42,
  "reject_percentiles": [
    10,
    20,
    30,
    40,
    50
  ],
  "hybrid_weight_entropy": 0.5,
  "hybrid_weight_grad": 0.5,
  "out_dir": "./outputs"
}

Training model from scratch...


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 001/50 | train_loss=1.6545 | train_acc=0.4273 | lr=0.049951


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 002/50 | train_loss=1.3968 | train_acc=0.4862 | lr=0.049803


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 003/50 | train_loss=1.3749 | train_acc=0.4877 | lr=0.049557


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 004/50 | train_loss=1.3406 | train_acc=0.4976 | lr=0.049215


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 005/50 | train_loss=1.3458 | train_acc=0.5014 | lr=0.048776


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 006/50 | train_loss=1.2876 | train_acc=0.5182 | lr=0.048244


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 007/50 | train_loss=1.3809 | train_acc=0.4835 | lr=0.047621


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 008/50 | train_loss=1.3747 | train_acc=0.4899 | lr=0.046908


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 009/50 | train_loss=1.3063 | train_acc=0.5124 | lr=0.046108


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 010/50 | train_loss=1.2539 | train_acc=0.5300 | lr=0.045225


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 011/50 | train_loss=1.2301 | train_acc=0.5381 | lr=0.044263


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 012/50 | train_loss=1.1751 | train_acc=0.5584 | lr=0.043224


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 013/50 | train_loss=1.1367 | train_acc=0.5760 | lr=0.042114


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 014/50 | train_loss=1.1054 | train_acc=0.5897 | lr=0.040936


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 015/50 | train_loss=1.0744 | train_acc=0.6030 | lr=0.039695


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 016/50 | train_loss=1.0521 | train_acc=0.6129 | lr=0.038396


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 017/50 | train_loss=1.0325 | train_acc=0.6192 | lr=0.037044


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 018/50 | train_loss=1.0140 | train_acc=0.6274 | lr=0.035644


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 019/50 | train_loss=0.9965 | train_acc=0.6352 | lr=0.034203


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 020/50 | train_loss=0.9832 | train_acc=0.6392 | lr=0.032725


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 021/50 | train_loss=0.9696 | train_acc=0.6447 | lr=0.031217


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 022/50 | train_loss=0.9557 | train_acc=0.6492 | lr=0.029685


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 023/50 | train_loss=0.9415 | train_acc=0.6546 | lr=0.028133


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 024/50 | train_loss=0.9300 | train_acc=0.6593 | lr=0.026570


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 025/50 | train_loss=0.9186 | train_acc=0.6646 | lr=0.025000


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 026/50 | train_loss=0.9058 | train_acc=0.6690 | lr=0.023430


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 027/50 | train_loss=0.8939 | train_acc=0.6727 | lr=0.021867


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 028/50 | train_loss=0.8840 | train_acc=0.6768 | lr=0.020315


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 029/50 | train_loss=0.8701 | train_acc=0.6813 | lr=0.018783


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 030/50 | train_loss=0.8597 | train_acc=0.6861 | lr=0.017275


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 031/50 | train_loss=0.8460 | train_acc=0.6905 | lr=0.015797


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 032/50 | train_loss=0.8340 | train_acc=0.6965 | lr=0.014356


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 033/50 | train_loss=0.8207 | train_acc=0.7002 | lr=0.012956


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 034/50 | train_loss=0.8058 | train_acc=0.7081 | lr=0.011604


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 035/50 | train_loss=0.7918 | train_acc=0.7115 | lr=0.010305


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 036/50 | train_loss=0.7779 | train_acc=0.7163 | lr=0.009064


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 037/50 | train_loss=0.7607 | train_acc=0.7231 | lr=0.007886


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 038/50 | train_loss=0.7437 | train_acc=0.7295 | lr=0.006776


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 039/50 | train_loss=0.7254 | train_acc=0.7371 | lr=0.005737


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 040/50 | train_loss=0.7056 | train_acc=0.7439 | lr=0.004775


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 041/50 | train_loss=0.6865 | train_acc=0.7508 | lr=0.003892


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 042/50 | train_loss=0.6631 | train_acc=0.7602 | lr=0.003092


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 043/50 | train_loss=0.6404 | train_acc=0.7690 | lr=0.002379


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 044/50 | train_loss=0.6175 | train_acc=0.7777 | lr=0.001756


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 045/50 | train_loss=0.5931 | train_acc=0.7878 | lr=0.001224


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 046/50 | train_loss=0.5686 | train_acc=0.7963 | lr=0.000785


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 047/50 | train_loss=0.5490 | train_acc=0.8038 | lr=0.000443


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 048/50 | train_loss=0.5299 | train_acc=0.8120 | lr=0.000197


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 049/50 | train_loss=0.5160 | train_acc=0.8165 | lr=0.000049


Train:   0%|          | 0/2586 [00:00<?, ?it/s]

Epoch 050/50 | train_loss=0.5086 | train_acc=0.8207 | lr=0.000000
Saved checkpoint: ./outputs\tissuemnist_efficientnetv2_s_e50_bs64.pth


Test+Uncertainty:   0%|          | 0/739 [00:00<?, ?it/s]

[BATCH 50/739] loss_sum=6.8287e+01 batch_avg_loss=1.0670e+00 logits_max_abs=1.0697e+01 logits_mean=-3.1021e-04 nan=False inf=False
[BATCH 100/739] loss_sum=5.9278e+01 batch_avg_loss=9.2622e-01 logits_max_abs=1.1182e+01 logits_mean=-1.8586e-04 nan=False inf=False
[BATCH 150/739] loss_sum=6.1283e+01 batch_avg_loss=9.5755e-01 logits_max_abs=1.1644e+01 logits_mean=-1.6010e-05 nan=False inf=False
[BATCH 200/739] loss_sum=5.2886e+01 batch_avg_loss=8.2634e-01 logits_max_abs=1.2195e+01 logits_mean=-2.0765e-04 nan=False inf=False
[BATCH 250/739] loss_sum=9.8172e+01 batch_avg_loss=1.5339e+00 logits_max_abs=1.2886e+01 logits_mean=-2.3213e-04 nan=False inf=False
[BATCH 300/739] loss_sum=6.4334e+01 batch_avg_loss=1.0052e+00 logits_max_abs=9.6771e+00 logits_mean=-1.8001e-04 nan=False inf=False
[BATCH 350/739] loss_sum=7.1567e+01 batch_avg_loss=1.1182e+00 logits_max_abs=1.2160e+01 logits_mean=-1.2834e-04 nan=False inf=False
[BATCH 400/739] loss_sum=6.2153e+01 batch_avg_loss=9.7114e-01 logits_max_abs=

Test+Uncertainty:   0%|          | 0/2969 [00:00<?, ?it/s]

[BATCH 50/2969] loss_sum=9.8565e+02 batch_avg_loss=1.5401e+01 logits_max_abs=1.9852e+01 logits_mean=2.4665e-02 nan=False inf=False
[BATCH 100/2969] loss_sum=9.4573e+02 batch_avg_loss=1.4777e+01 logits_max_abs=1.8122e+01 logits_mean=2.2849e-02 nan=False inf=False
[BATCH 150/2969] loss_sum=1.0170e+03 batch_avg_loss=1.5891e+01 logits_max_abs=1.7245e+01 logits_mean=2.3681e-02 nan=False inf=False
[BATCH 200/2969] loss_sum=1.0158e+03 batch_avg_loss=1.5872e+01 logits_max_abs=2.4342e+01 logits_mean=2.1999e-02 nan=False inf=False
[BATCH 250/2969] loss_sum=1.1173e+03 batch_avg_loss=1.7458e+01 logits_max_abs=1.7731e+01 logits_mean=2.3233e-02 nan=False inf=False
[BATCH 300/2969] loss_sum=9.6484e+02 batch_avg_loss=1.5076e+01 logits_max_abs=2.1500e+01 logits_mean=2.2441e-02 nan=False inf=False
[BATCH 350/2969] loss_sum=1.0327e+03 batch_avg_loss=1.6136e+01 logits_max_abs=2.0634e+01 logits_mean=2.8022e-02 nan=False inf=False
[BATCH 400/2969] loss_sum=1.0332e+03 batch_avg_loss=1.6143e+01 logits_max_abs

Test+Uncertainty:   0%|          | 0/2969 [00:00<?, ?it/s]

[BATCH 50/2969] loss_sum=2.3081e+02 batch_avg_loss=3.6063e+00 logits_max_abs=7.2321e+00 logits_mean=9.2990e-05 nan=False inf=False
[BATCH 100/2969] loss_sum=2.1384e+02 batch_avg_loss=3.3413e+00 logits_max_abs=6.1923e+00 logits_mean=-8.3705e-05 nan=False inf=False
[BATCH 150/2969] loss_sum=2.1575e+02 batch_avg_loss=3.3710e+00 logits_max_abs=5.0071e+00 logits_mean=-1.9174e-04 nan=False inf=False
[BATCH 200/2969] loss_sum=2.5237e+02 batch_avg_loss=3.9432e+00 logits_max_abs=8.9917e+00 logits_mean=-5.9755e-03 nan=False inf=False
[BATCH 250/2969] loss_sum=2.7487e+02 batch_avg_loss=4.2949e+00 logits_max_abs=6.1451e+00 logits_mean=-7.8064e-03 nan=False inf=False
[BATCH 300/2969] loss_sum=2.5639e+02 batch_avg_loss=4.0060e+00 logits_max_abs=9.1263e+00 logits_mean=-7.4939e-03 nan=False inf=False
[BATCH 350/2969] loss_sum=2.3705e+02 batch_avg_loss=3.7040e+00 logits_max_abs=5.3618e+00 logits_mean=-5.1426e-03 nan=False inf=False
[BATCH 400/2969] loss_sum=2.2940e+02 batch_avg_loss=3.5843e+00 logits_m

Test+Uncertainty:   0%|          | 0/2969 [00:00<?, ?it/s]

[BATCH 50/2969] loss_sum=4.6076e+05 batch_avg_loss=7.1994e+03 logits_max_abs=1.4193e+05 logits_mean=-6.7399e+01 nan=False inf=False
[BATCH 100/2969] loss_sum=3.8004e+05 batch_avg_loss=5.9381e+03 logits_max_abs=1.2251e+05 logits_mean=-6.4230e+01 nan=False inf=False
[BATCH 150/2969] loss_sum=6.8041e+05 batch_avg_loss=1.0631e+04 logits_max_abs=1.2499e+05 logits_mean=-1.0014e+02 nan=False inf=False
[BATCH 200/2969] loss_sum=1.4023e+06 batch_avg_loss=2.1911e+04 logits_max_abs=2.2807e+05 logits_mean=-2.0846e+02 nan=False inf=False
[BATCH 250/2969] loss_sum=7.4174e+05 batch_avg_loss=1.1590e+04 logits_max_abs=1.5467e+05 logits_mean=-1.1269e+02 nan=False inf=False
[BATCH 300/2969] loss_sum=6.9903e+05 batch_avg_loss=1.0922e+04 logits_max_abs=2.5956e+05 logits_mean=-1.1296e+02 nan=False inf=False
[BATCH 350/2969] loss_sum=6.6254e+05 batch_avg_loss=1.0352e+04 logits_max_abs=1.4663e+05 logits_mean=-9.6506e+01 nan=False inf=False
[BATCH 400/2969] loss_sum=5.4487e+05 batch_avg_loss=8.5136e+03 logits_

Test+Uncertainty:   0%|          | 0/2969 [00:00<?, ?it/s]

[BATCH 50/2969] loss_sum=9.3468e+02 batch_avg_loss=1.4604e+01 logits_max_abs=4.1375e+01 logits_mean=6.9806e-03 nan=False inf=False
[BATCH 100/2969] loss_sum=9.8477e+02 batch_avg_loss=1.5387e+01 logits_max_abs=5.9310e+01 logits_mean=8.4915e-03 nan=False inf=False
[BATCH 150/2969] loss_sum=8.8699e+02 batch_avg_loss=1.3859e+01 logits_max_abs=6.7738e+01 logits_mean=7.1796e-03 nan=False inf=False
[BATCH 200/2969] loss_sum=7.8643e+02 batch_avg_loss=1.2288e+01 logits_max_abs=2.8275e+01 logits_mean=-2.2349e-03 nan=False inf=False
[BATCH 250/2969] loss_sum=7.7311e+02 batch_avg_loss=1.2080e+01 logits_max_abs=2.6719e+01 logits_mean=-2.5569e-03 nan=False inf=False
[BATCH 300/2969] loss_sum=8.7779e+02 batch_avg_loss=1.3715e+01 logits_max_abs=4.2320e+01 logits_mean=-9.7475e-04 nan=False inf=False
[BATCH 350/2969] loss_sum=8.7321e+02 batch_avg_loss=1.3644e+01 logits_max_abs=2.5628e+01 logits_mean=1.1592e-03 nan=False inf=False
[BATCH 400/2969] loss_sum=9.2151e+02 batch_avg_loss=1.4399e+01 logits_max_

Test+Uncertainty:   0%|          | 0/2969 [00:00<?, ?it/s]

[BATCH 50/2969] loss_sum=5.6830e+02 batch_avg_loss=8.8797e+00 logits_max_abs=1.0569e+01 logits_mean=8.7137e-03 nan=False inf=False
[BATCH 100/2969] loss_sum=5.9107e+02 batch_avg_loss=9.2355e+00 logits_max_abs=1.0596e+01 logits_mean=8.4416e-03 nan=False inf=False
[BATCH 150/2969] loss_sum=5.5030e+02 batch_avg_loss=8.5984e+00 logits_max_abs=1.0550e+01 logits_mean=9.0808e-03 nan=False inf=False
[BATCH 200/2969] loss_sum=7.0517e+02 batch_avg_loss=1.1018e+01 logits_max_abs=1.5127e+01 logits_mean=4.0685e-03 nan=False inf=False
[BATCH 250/2969] loss_sum=6.5417e+02 batch_avg_loss=1.0221e+01 logits_max_abs=1.4442e+01 logits_mean=2.3311e-03 nan=False inf=False
[BATCH 300/2969] loss_sum=6.8757e+02 batch_avg_loss=1.0743e+01 logits_max_abs=1.4123e+01 logits_mean=3.4816e-03 nan=False inf=False
[BATCH 350/2969] loss_sum=6.4316e+02 batch_avg_loss=1.0049e+01 logits_max_abs=1.1954e+01 logits_mean=5.5866e-03 nan=False inf=False
[BATCH 400/2969] loss_sum=6.2022e+02 batch_avg_loss=9.6909e+00 logits_max_abs

Test+Uncertainty:   0%|          | 0/2969 [00:00<?, ?it/s]

[BATCH 50/2969] loss_sum=3.1085e+06 batch_avg_loss=4.8570e+04 logits_max_abs=7.9318e+05 logits_mean=-5.9052e+01 nan=False inf=False
[BATCH 100/2969] loss_sum=3.0227e+06 batch_avg_loss=4.7229e+04 logits_max_abs=1.0868e+06 logits_mean=-5.6801e+01 nan=False inf=False
[BATCH 150/2969] loss_sum=1.9194e+06 batch_avg_loss=2.9990e+04 logits_max_abs=7.6977e+05 logits_mean=-3.5234e+01 nan=False inf=False
[BATCH 200/2969] loss_sum=1.9906e+06 batch_avg_loss=3.1103e+04 logits_max_abs=1.0546e+06 logits_mean=-3.5388e+01 nan=False inf=False
[BATCH 250/2969] loss_sum=7.5457e+05 batch_avg_loss=1.1790e+04 logits_max_abs=4.2756e+05 logits_mean=-1.0982e+01 nan=False inf=False
[BATCH 300/2969] loss_sum=5.6538e+05 batch_avg_loss=8.8341e+03 logits_max_abs=4.7545e+05 logits_mean=-1.0307e+01 nan=False inf=False
[BATCH 350/2969] loss_sum=1.5501e+06 batch_avg_loss=2.4221e+04 logits_max_abs=5.1885e+05 logits_mean=-4.0184e+01 nan=False inf=False
[BATCH 400/2969] loss_sum=2.6313e+06 batch_avg_loss=4.1114e+04 logits_

KeyboardInterrupt: 

In [None]:
# 13a. DIAGNOSTIC: Quick test - where is it hanging?

print("DIAGNOSTIC TEST - Finding the bottleneck")
print("=" * 80)

# Step 1: Test CIFAR-10-C dataset loading (this is likely the culprit)
print("\n[STEP 1] Testing CIFAR-10-C dataset loading...")
t0 = time.time()

cifar10c_root = "./data/Tiny-ImageNet-C"
print(f"Looking for files in: {cifar10c_root}")

if os.path.exists(cifar10c_root):
    files = os.listdir(cifar10c_root)
    print(f"Found files: {files}")
    
    # Check one file
    data_file = os.path.join(cifar10c_root, "gaussian_noise.npy")
    if os.path.exists(data_file):
        print(f"\nLoading {data_file}...")
        t1 = time.time()
        images = np.load(data_file)
        elapsed = time.time() - t1
        print(f"✓ Loaded in {elapsed:.1f}s, shape: {images.shape}")
        
        # Try to slice it
        print(f"\nSlicing to severity 5 (rows 40000:50000)...")
        t1 = time.time()
        subset = images[40000:50000]
        elapsed = time.time() - t1
        print(f"✓ Sliced in {elapsed:.3f}s, shape: {subset.shape}")
    else:
        print(f"✗ File not found: {data_file}")
else:
    print(f"✗ Directory not found: {cifar10c_root}")

print("\n" + "=" * 80)
print("TOTAL DIAGNOSTIC TIME:", time.time() - t0, "seconds")
print("=" * 80)
