In [None]:
import os
import numpy as np
from typing import List, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from torchvision.transforms import v2

import timm


# -------------------------------------------------
# Dataset: CIFAR-100-C (.npy corruption benchmark)
# -------------------------------------------------
class CIFAR100C(Dataset):
    """
    One (corruption, severity) slice of the CIFAR-100-C ``.npy`` benchmark.

    Files read from ``root``:

        ``<corruption>.npy``  images, expected (50000, 32, 32, 3) uint8 with
                              the five severity blocks stacked in order.
                              Only the leading dimension is validated.
        ``labels.npy``        either 10000 labels (shared by every severity
                              block) or 50000 labels (one per stacked image).

    Severity s in {1..5} selects rows ``[(s - 1) * 10000, s * 10000)``.

    Indexing follows the normal sequence protocol: negative indices count
    from the end of the selected severity block, and anything outside
    ``[-len(self), len(self))`` raises ``IndexError`` instead of silently
    reading an image that belongs to a different severity.
    """

    # Layout constants of the stacked benchmark files.
    IMAGES_PER_SEVERITY = 10000
    NUM_SEVERITIES = 5

    def __init__(self, root: str, corruption: str, severity: int,
                 transform=None):
        """
        Args:
            root: directory holding ``<corruption>.npy`` and ``labels.npy``.
            corruption: corruption name, i.e. the ``.npy`` file stem.
            severity: severity level, 1..5.
            transform: optional callable applied to the float CHW tensor
                produced by ``__getitem__``.

        Raises:
            AssertionError: if ``severity`` is outside 1..5.
            FileNotFoundError: if either ``.npy`` file is missing.
            ValueError: if the image or label count is unexpected.
        """
        assert 1 <= severity <= 5, "Severity must be 1..5"

        self.root = root
        self.corruption = corruption
        self.severity = severity
        self.transform = transform

        x_path = os.path.join(root, f"{corruption}.npy")
        y_path = os.path.join(root, "labels.npy")

        if not os.path.isfile(x_path):
            raise FileNotFoundError(f"Missing corruption file: {x_path}")
        if not os.path.isfile(y_path):
            raise FileNotFoundError(f"Missing labels file: {y_path}")

        # Memory-map the images so only the rows that are actually indexed
        # get read from disk; labels are small and loaded eagerly.
        self.x_all = np.load(x_path, mmap_mode="r")
        self.y_all = np.load(y_path)

        total_images = self.IMAGES_PER_SEVERITY * self.NUM_SEVERITIES
        if self.x_all.shape[0] != total_images:
            raise ValueError(
                f"{corruption}.npy must have 50000 images, got {self.x_all.shape[0]}"
            )

        if self.y_all.shape[0] not in (self.IMAGES_PER_SEVERITY, total_images):
            raise ValueError(
                f"labels.npy must have 10000 or 50000 entries, got {self.y_all.shape[0]}"
            )

        # Half-open row range [start, end) of this severity in the stacked file.
        self.start = (severity - 1) * self.IMAGES_PER_SEVERITY
        self.end = severity * self.IMAGES_PER_SEVERITY

    def __len__(self):
        return self.IMAGES_PER_SEVERITY

    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for position ``idx`` inside this severity.

        The image is a float32 tensor in channel-first layout scaled to
        [0, 1] (then passed through ``transform`` if one was given); the
        label is a Python ``int``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)

        # Resolve negative indices relative to THIS severity block. Without
        # this, ``start + idx`` would land in a neighbouring block (or wrap
        # to the end of the file) and return the wrong image.
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(
                f"index out of range for CIFAR100C of length {n}"
            )

        img = self.x_all[self.start + idx]

        # Handle label formats
        if self.y_all.shape[0] == self.IMAGES_PER_SEVERITY:
            label = int(self.y_all[idx])  # shared labels
        else:
            label = int(self.y_all[self.start + idx])  # expanded labels

        # The memmap row is read-only; copy before handing it to torch.
        img = torch.from_numpy(img.copy()).permute(2, 0, 1).float().div_(255.0)

        if self.transform is not None:
            img = self.transform(img)

        return img, label


# -------------------------------------------------
# Utility
# -------------------------------------------------
def list_corruptions(root: str) -> List[str]:
    """
    Return the sorted corruption names available in ``root``.

    A corruption name is the stem of any ``*.npy`` entry in the directory
    other than ``labels.npy``.

    Raises:
        RuntimeError: if no such entry exists.
    """
    names: List[str] = []
    for entry in os.listdir(root):
        # Skip the label file and anything that is not a .npy entry.
        if entry == "labels.npy" or not entry.endswith(".npy"):
            continue
        stem, _ext = os.path.splitext(entry)
        names.append(stem)

    names.sort()

    if len(names) == 0:
        raise RuntimeError(f"No corruption .npy files found in {root}")

    return names


# -------------------------------------------------
# Evaluation
# -------------------------------------------------
@torch.no_grad()
def evaluate_loader(model, loader, device="cuda"):
    """
    Compute top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode and run under bfloat16 autocast with
    gradients disabled. The model itself is NOT moved; the caller must have
    placed it on ``device`` already.

    Args:
        model: module mapping a batch ``x`` to logits with classes on dim 1.
        loader: iterable yielding ``(x, y)`` tensor batches.
        device: target device as a string (``"cuda"``, ``"cuda:0"``,
            ``"cpu"``, ...) or a ``torch.device``.

    Returns:
        Fraction of correct predictions as a float in [0, 1].

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    # autocast needs the bare backend name ("cuda" / "cpu"); an indexed
    # string such as "cuda:0" or a torch.device object is rejected, so
    # normalise once and use ``.type`` below.
    device = torch.device(device)

    model.eval()
    correct = 0
    total = 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits = model(x)
        pred = logits.argmax(dim=1)

        correct += (pred == y).sum().item()
        total += y.numel()

    return correct / total


def eval_cifar100c(
    model: torch.nn.Module,
    c_root: str,
    batch_size: int = 256,
    num_workers: int = 4,
    device: str = "cuda",
    corruptions: Optional[List[str]] = None,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
) -> Tuple[Dict[str, List[float]], Dict[int, float], float]:
    """
    Evaluate ``model`` on every (corruption, severity) pair of CIFAR-100-C.

    Images are resized to 224x224 and normalised with ``mean`` / ``std``
    before being fed to the model. Per-corruption accuracy is printed as
    the evaluation progresses.

    Args:
        model: classifier to evaluate; must already live on ``device``.
        c_root: directory with ``<corruption>.npy`` files and ``labels.npy``.
        batch_size: evaluation batch size.
        num_workers: DataLoader worker processes per loader.
        device: device the batches are moved to.
        corruptions: corruption names to evaluate; ``None`` means every
            corruption file found in ``c_root``.
        mean, std: per-channel normalisation statistics.

    Returns:
        ``(accs_by_corr, mean_acc_by_sev, mean_acc_all)`` where
        ``accs_by_corr[name]`` lists accuracy for severities 1..5 in order,
        ``mean_acc_by_sev[s]`` is the mean over corruptions at severity
        ``s``, and ``mean_acc_all`` is the mean over every pair.

    Raises:
        ValueError: if ``corruptions`` is given but empty.
    """
    if corruptions is None:
        corruptions = list_corruptions(c_root)
    else:
        # Materialise once: the names are iterated again for every severity.
        corruptions = list(corruptions)

    if not corruptions:
        # Otherwise every mean below would silently come out as NaN.
        raise ValueError("corruptions must contain at least one corruption name")

    val_transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Resize((224, 224), antialias=True),
        v2.Normalize(mean=mean, std=std),
    ])

    # Page-locked host memory only helps host-to-GPU copies.
    pin_memory = torch.device(device).type == "cuda"

    accs_by_corr: Dict[str, List[float]] = {}
    mean_acc_by_sev: Dict[int, float] = {}

    for severity in range(1, 6):
        sev_accs = []

        print(f"\nSeverity {severity}")

        for corr in corruptions:
            ds = CIFAR100C(
                root=c_root,
                corruption=corr,
                severity=severity,
                transform=val_transform,
            )

            # Each loader is consumed exactly once, so persistent workers
            # would only postpone worker shutdown until the loader object
            # is garbage-collected; let them exit when the pass finishes.
            loader = DataLoader(
                ds,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
                persistent_workers=False,
            )

            acc = evaluate_loader(model, loader, device)

            accs_by_corr.setdefault(corr, []).append(acc)
            sev_accs.append(acc)

            print(f"{corr:20s}: {acc*100:.2f}%")

        mean_acc_by_sev[severity] = float(np.mean(sev_accs))

    all_accs = [a for v in accs_by_corr.values() for a in v]
    mean_acc_all = float(np.mean(all_accs))

    return accs_by_corr, mean_acc_by_sev, mean_acc_all


# -------------------------------------------------
# Pretty Print
# -------------------------------------------------
def pretty_print(accs_by_corr, mean_acc_by_sev, mean_acc_all):
    """
    Print a summary of a CIFAR-100-C evaluation.

    Shows the mean accuracy for severities 1 through 5, the overall mean,
    and the (up to) five corruptions with the lowest accuracy averaged over
    their recorded severities.

    Args:
        accs_by_corr: corruption name -> list of accuracies (fractions).
        mean_acc_by_sev: severity (1..5) -> mean accuracy (fraction).
        mean_acc_all: overall mean accuracy (fraction).
    """
    print("\n==============================")
    print(" CIFAR-100-C Summary")
    print("==============================")

    for level in (1, 2, 3, 4, 5):
        level_pct = mean_acc_by_sev[level] * 100
        print(f"Severity {level}: {level_pct:.2f}%")

    overall_pct = mean_acc_all * 100
    print(f"\nMean over all: {overall_pct:.2f}%")

    # Average each corruption across its severities.
    per_corruption = {}
    for name, accs in accs_by_corr.items():
        per_corruption[name] = float(np.mean(accs))

    # Stable ascending sort on the averaged accuracy; lowest first.
    worst_first = sorted(per_corruption, key=per_corruption.get)

    print("\nHardest corruptions:")
    for name in worst_first[:5]:
        print(f"{name:20s}: {per_corruption[name]*100:.2f}%")


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# 1) Load your model
device = "cuda"

# Replace this with your actual model definition
num_classes = 100
# torchvision alternative:
# model = models.efficientnet_b0(weights=None)
# model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=num_classes)

ckpt_path = "/home/hice1/yyu496/scratch/Model_Checkpoint/efficientnet_b0_bs_512_teacher_cifar100.pth"
# map_location="cpu": load regardless of which GPU index the checkpoint was
# saved from; the model is moved to `device` afterwards.
# weights_only=True: the file is passed straight to load_state_dict, so only
# tensors need to be unpickled -- refuse arbitrary pickled objects.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
model.load_state_dict(ckpt)
model.to(device)

# For demo only:
# raise SystemExit("Edit the model loading part, then run eval.")

# 2) Point to CIFAR-100-C folder
# Expected structure:
#   CIFAR-100-C/
#     gaussian_noise.npy
#     shot_noise.npy
#     ...
#     labels.npy

c_root = "/home/hice1/yyu496/scratch/data/CIFAR-100-C"

# 3) Evaluate every corruption at every severity, then print the summary.
accs_by_corr, mean_acc_by_sev, mean_acc_all = eval_cifar100c(
    model, c_root,
    batch_size=256, num_workers=4, device=device
)
pretty_print(accs_by_corr, mean_acc_by_sev, mean_acc_all)



Severity 1
brightness          : 83.26%
contrast            : 83.08%
defocus_blur        : 82.69%
elastic_transform   : 76.53%
fog                 : 82.87%
frost               : 78.11%
gaussian_blur       : 82.68%
gaussian_noise      : 55.29%
glass_blur          : 33.37%
impulse_noise       : 62.66%
jpeg_compression    : 69.50%
motion_blur         : 77.40%
pixelate            : 78.94%
saturate            : 76.26%
shot_noise          : 66.36%
snow                : 76.75%
spatter             : 79.42%
speckle_noise       : 68.26%
zoom_blur           : 75.47%

Severity 2
brightness          : 83.14%
contrast            : 82.42%
defocus_blur        : 80.20%
elastic_transform   : 75.95%
fog                 : 81.89%
frost               : 73.17%
gaussian_blur       : 75.82%
gaussian_noise      : 37.01%
glass_blur          : 33.30%
impulse_noise       : 49.38%
jpeg_compression    : 59.99%
motion_blur         : 70.75%
pixelate            : 65.78%
saturate            : 67.90%
shot_noise         

In [None]:

Severity 1
brightness          : 77.81%
contrast            : 77.67%
defocus_blur        : 77.32%
elastic_transform   : 70.43%
fog                 : 77.57%
frost               : 71.64%
gaussian_blur       : 77.34%
gaussian_noise      : 48.49%
glass_blur          : 22.41%
impulse_noise       : 63.77%
jpeg_compression    : 68.96%
motion_blur         : 70.67%
pixelate            : 75.18%
saturate            : 73.00%
shot_noise          : 60.87%
snow                : 70.91%
spatter             : 74.86%
speckle_noise       : 62.42%
zoom_blur           : 67.61%

Severity 2
brightness          : 77.48%
contrast            : 76.69%
defocus_blur        : 74.39%
elastic_transform   : 69.22%
fog                 : 75.02%
frost               : 65.37%
gaussian_blur       : 67.16%
gaussian_noise      : 28.91%
glass_blur          : 24.21%
impulse_noise       : 50.06%
jpeg_compression    : 63.27%
motion_blur         : 63.21%
pixelate            : 70.18%
saturate            : 64.14%
shot_noise          : 48.35%
snow                : 57.42%
spatter             : 69.46%
speckle_noise       : 44.44%
zoom_blur           : 63.71%

Severity 3
brightness          : 76.90%
contrast            : 75.86%
defocus_blur        : 67.22%
elastic_transform   : 64.70%
fog                 : 71.89%
frost               : 55.38%
gaussian_blur       : 58.12%
gaussian_noise      : 16.85%
glass_blur          : 29.90%
impulse_noise       : 38.13%
jpeg_compression    : 60.85%
motion_blur         : 55.12%
pixelate            : 66.52%
saturate            : 77.01%
shot_noise          : 26.36%
snow                : 61.86%
spatter             : 63.36%
speckle_noise       : 36.20%
zoom_blur           : 57.96%

Severity 4
brightness          : 76.37%
contrast            : 73.95%
defocus_blur        : 57.87%
elastic_transform   : 58.83%
fog                 : 64.25%
frost               : 53.28%
gaussian_blur       : 48.34%
gaussian_noise      : 13.14%
glass_blur          : 14.27%
impulse_noise       : 19.09%
jpeg_compression    : 57.97%
motion_blur         : 55.09%
pixelate            : 48.95%
saturate            : 74.50%
shot_noise          : 20.20%
snow                : 61.01%
spatter             : 66.12%
speckle_noise       : 23.26%
zoom_blur           : 52.65%

Severity 5
brightness          : 74.47%
contrast            : 61.90%
defocus_blur        : 40.41%
elastic_transform   : 51.98%
fog                 : 45.16%
frost               : 44.55%
gaussian_blur       : 31.45%
gaussian_noise      : 10.46%
glass_blur          : 18.30%
impulse_noise       : 9.22%
jpeg_compression    : 53.94%
motion_blur         : 47.70%
pixelate            : 29.52%
saturate            : 70.56%
shot_noise          : 13.06%
snow                : 55.69%
spatter             : 55.32%
speckle_noise       : 15.33%
zoom_blur           : 44.21%

==============================
 CIFAR-100-C Summary
==============================
Severity 1: 67.84%
Severity 2: 60.67%
Severity 3: 55.80%
Severity 4: 49.43%
Severity 5: 40.70%

Mean over all: 54.89%

Hardest corruptions:
glass_blur          : 21.82%
gaussian_noise      : 23.57%
shot_noise          : 33.77%
impulse_noise       : 36.05%
speckle_noise       : 36.33%


==============================
 CIFAR-100-C Summary
==============================
Severity 1: 67.91%
Severity 2: 60.72%
Severity 3: 55.67%
Severity 4: 49.24%
Severity 5: 40.57%

Mean over all: 54.82%

Hardest corruptions:
glass_blur          : 22.70%
gaussian_noise      : 23.56%
shot_noise          : 34.16%
speckle_noise       : 36.64%
impulse_noise       : 38.42%

In [1]:
import timm
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Ask timm for the model names matching "*efficientnet*", passing
# pretrained=True; the list is shown as the cell's output.
timm.list_models("*efficientnet*", pretrained=True)

['efficientnet_b0.ra4_e3600_r224_in1k',
 'efficientnet_b0.ra_in1k',
 'efficientnet_b1.ft_in1k',
 'efficientnet_b1.ra4_e3600_r240_in1k',
 'efficientnet_b1_pruned.in1k',
 'efficientnet_b2.ra_in1k',
 'efficientnet_b2_pruned.in1k',
 'efficientnet_b3.ra2_in1k',
 'efficientnet_b3_pruned.in1k',
 'efficientnet_b4.ra2_in1k',
 'efficientnet_b5.sw_in12k',
 'efficientnet_b5.sw_in12k_ft_in1k',
 'efficientnet_el.ra_in1k',
 'efficientnet_el_pruned.in1k',
 'efficientnet_em.ra2_in1k',
 'efficientnet_es.ra_in1k',
 'efficientnet_es_pruned.in1k',
 'efficientnet_lite0.ra_in1k',
 'efficientnetv2_rw_m.agc_in1k',
 'efficientnetv2_rw_s.ra2_in1k',
 'efficientnetv2_rw_t.ra2_in1k',
 'gc_efficientnetv2_rw_t.agc_in1k',
 'test_efficientnet.r160_in1k',
 'test_efficientnet_evos.r160_in1k',
 'test_efficientnet_gn.r160_in1k',
 'test_efficientnet_ln.r160_in1k',
 'tf_efficientnet_b0.aa_in1k',
 'tf_efficientnet_b0.ap_in1k',
 'tf_efficientnet_b0.in1k',
 'tf_efficientnet_b0.ns_jft_in1k',
 'tf_efficientnet_b1.aa_in1k',
 'tf_e