In [None]:
# # Imports
import io
import os
import random
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
import torchvision.transforms.functional as F

from PIL import Image, ImageFilter, ImageChops
import matplotlib.pyplot as plt

In [None]:
# Basic config
# Seed both RNG sources the corruption pipeline draws from: Python's `random`
# (severity choice, patch placement) and torch's global generator
# (Gaussian noise, snow masks).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Corruption labels (index -> name); the index is the target class the model predicts.
CORRUPTION_NAMES: Dict[int, str] = dict(
    enumerate(
        [
            "clean",
            "gaussian_noise",
            "gaussian_blur",
            "jpeg_compression",
            "brightness",
            "contrast",
            "fog",
            "snow",
            "pixel_dropout",
            "motion_blur",
            "color_jitter",
            "grayscale",
        ]
    )
)
NUM_CORRUPTIONS = len(CORRUPTION_NAMES)

# Data location and optimisation hyper-parameters.
DATA_ROOT = "./data"
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 1e-3

# Shared PIL <-> tensor converters used by the corruption helpers below.
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()


In [None]:
# Corruption functions

def add_gaussian_noise(img: Image.Image, severity: int = 1) -> Image.Image:
    """Perturb ``img`` with additive zero-mean Gaussian noise.

    Severity 1/2/3 selects a noise standard deviation of 0.05/0.1/0.2 in the
    [0, 1] tensor scale; any other severity falls back to 0.1. The noisy
    result is clipped back into [0, 1] before converting to PIL.
    """
    sigma = {1: 0.05, 2: 0.1, 3: 0.2}.get(severity, 0.1)
    pixels = to_tensor(img)
    noisy = (pixels + torch.randn_like(pixels) * sigma).clamp(0.0, 1.0)
    return to_pil(noisy)


def add_gaussian_blur(img: Image.Image, severity: int = 1) -> Image.Image:
    """Blur ``img`` with PIL's Gaussian filter.

    Severity 1/2/3 maps to a blur radius of 0.8/1.5/2.5; unknown severities
    use 1.5.
    """
    blur_radius = {1: 0.8, 2: 1.5, 3: 2.5}.get(severity, 1.5)
    blur_filter = ImageFilter.GaussianBlur(radius=blur_radius)
    return img.filter(blur_filter)


def add_jpeg_compression(img: Image.Image, severity: int = 1) -> Image.Image:
    """Round-trip ``img`` through an in-memory lossy JPEG encode/decode.

    A downscale+upscale only smooths the image, which looks much like
    ``gaussian_blur`` and contains none of the blocking/ringing artifacts that
    characterise JPEG. Encoding at a low quality setting produces those
    artifacts, so the ``jpeg_compression`` label matches the corruption it names.

    Args:
        img: Input PIL image; it is converted to RGB before encoding.
        severity: 1/2/3 -> JPEG quality 40/25/10 (lower quality = stronger
            artifacts). Any other value falls back to the severity-2 quality.

    Returns:
        A new RGB PIL image with the same size as ``img``.
    """
    quality_map = {1: 40, 2: 25, 3: 10}
    quality = quality_map.get(severity, 25)

    buffer = io.BytesIO()
    img.convert("RGB").save(buffer, format="JPEG", quality=quality)
    buffer.seek(0)
    with Image.open(buffer) as decoded:
        # convert() forces the decode and returns an independent copy, so the
        # result stays valid after the buffer/file handle is closed.
        return decoded.convert("RGB")


def add_brightness(img: Image.Image, severity: int = 1) -> Image.Image:
    """Darken ``img`` by scaling its brightness.

    Despite the name, this only ever reduces brightness: severity 1/2/3 uses a
    brightness factor of 0.7/0.5/0.3, and unknown severities use 0.5.
    """
    brightness_factor = {1: 0.7, 2: 0.5, 3: 0.3}.get(severity, 0.5)
    return F.adjust_brightness(img, brightness_factor)


def add_contrast(img: Image.Image, severity: int = 1) -> Image.Image:
    """Reduce the contrast of ``img``.

    Severity 1/2/3 uses a contrast factor of 0.75/0.5/0.3 (unknown severities
    use 0.5); lower factors give flatter images.
    """
    contrast_factor = {1: 0.75, 2: 0.5, 3: 0.3}.get(severity, 0.5)
    return F.adjust_contrast(img, contrast_factor)


def add_fog(img: Image.Image, severity: int = 1) -> Image.Image:
    """Wash ``img`` out towards white, then soften it with a small blur.

    Every pixel becomes ``(1 - alpha) * pixel + alpha * 1.0`` with alpha
    0.3/0.5/0.7 for severity 1/2/3 (fallback 0.5). A Gaussian blur of
    radius 1.0 is then applied.
    """
    alpha = {1: 0.3, 2: 0.5, 3: 0.7}.get(severity, 0.5)
    pixels = to_tensor(img)
    white = torch.ones_like(pixels)
    hazy = torch.clamp((1 - alpha) * pixels + alpha * white, 0.0, 1.0)
    return to_pil(hazy).filter(ImageFilter.GaussianBlur(radius=1.0))


def add_snow(img: Image.Image, severity: int = 1) -> Image.Image:
    """Sprinkle saturated white pixels ("snow") over ``img``.

    Each spatial location independently turns white in every channel with
    probability 0.03/0.07/0.12 for severity 1/2/3 (fallback 0.07).
    """
    pixels = to_tensor(img)
    _, height, width = pixels.shape
    flake_prob = {1: 0.03, 2: 0.07, 3: 0.12}.get(severity, 0.07)
    # A single-channel mask broadcasts over all colour channels.
    flakes = (torch.rand(1, height, width) < flake_prob).float()
    return to_pil(torch.clamp(pixels + flakes, 0.0, 1.0))


def add_pixel_dropout(img: Image.Image, severity: int = 1) -> Image.Image:
    """Black out a few randomly placed square patches of ``img``.

    Severity 1/2/3 gives 2/4/6 patches with side length 4/6/8 pixels
    (fallback: 4 patches of side 6). Patches are placed independently and may
    overlap. A patch larger than the image is anchored at the top-left corner.
    """
    pixels = to_tensor(img)
    height, width = pixels.shape[1:3]
    n_patches = {1: 2, 2: 4, 3: 6}.get(severity, 4)
    side = {1: 4, 2: 6, 3: 8}.get(severity, 6)

    max_top = max(0, height - side)
    max_left = max(0, width - side)
    for _ in range(n_patches):
        y0 = random.randint(0, max_top)
        x0 = random.randint(0, max_left)
        pixels[:, y0:y0 + side, x0:x0 + side] = 0.0
    return to_pil(pixels)

def add_motion_blur(img: Image.Image, severity: int = 1) -> Image.Image:
    """Simulate horizontal camera motion with a 1-D box filter.

    Each output pixel is the mean of the ``k`` pixels centred on it in the same
    row, where ``k`` = 5/9/13 for severity 1/2/3 (fallback 9).

    Image borders are handled by replicating edge pixels. Zero padding (the
    previous behaviour) darkened the left/right borders: with k = 13 the edge
    column kept only 7/13 of its intensity. That is a spurious cue a classifier
    can key on instead of the blur itself.
    """
    ksize_map = {1: 5, 2: 9, 3: 13}
    k = ksize_map.get(severity, 9)
    half = k // 2

    x = to_tensor(img).unsqueeze(0)  # [1, C, H, W]
    c = x.shape[1]

    # Depthwise kernel: one [1, k] averaging row per channel -> weight [C, 1, 1, k].
    weight = torch.full((c, 1, 1, k), 1.0 / k, dtype=x.dtype)

    # Replicate-pad only the width so the (odd-sized) kernel keeps the image size.
    x_padded = torch.nn.functional.pad(x, (half, half, 0, 0), mode="replicate")
    x_blurred = torch.nn.functional.conv2d(x_padded, weight, groups=c)

    # An average of [0, 1] values stays in [0, 1] up to float rounding; clamp defensively.
    return to_pil(x_blurred.squeeze(0).clamp(0.0, 1.0))


def add_color_jitter(img: Image.Image, severity: int = 1) -> Image.Image:
    """Apply torchvision ``ColorJitter`` with severity-scaled ranges.

    The brightness/contrast/saturation strength and the hue range are:
    severity 1 -> (0.2, 0.05), severity 2 -> (0.4, 0.1), and any other
    severity -> (0.6, 0.15). A new ``ColorJitter`` is built on every call.
    """
    strength, hue_range = {1: (0.2, 0.05), 2: (0.4, 0.1)}.get(severity, (0.6, 0.15))
    jitter = transforms.ColorJitter(
        brightness=strength,
        contrast=strength,
        saturation=strength,
        hue=hue_range,
    )
    return jitter(img)


def add_grayscale(img: Image.Image, severity: int = 1) -> Image.Image:
    """Remove colour information, then optionally reduce contrast.

    The image is converted to 3-channel grayscale (so it still has three
    channels). Its contrast is then scaled by 1.0/0.8/0.6 for severity 1/2/3
    (fallback 0.8).
    """
    contrast_factor = {1: 1.0, 2: 0.8, 3: 0.6}.get(severity, 0.8)
    gray = F.rgb_to_grayscale(img, num_output_channels=3)
    return F.adjust_contrast(gray, contrast_factor)


def apply_corruption(img: Image.Image, corruption_idx: int, severity: int) -> Image.Image:
    """
    Dispatch to the corruption function for label ``corruption_idx``.

    Index 0 ("clean") returns ``img`` unchanged. Indices 1-11 select the
    corruption functions in the same order as ``CORRUPTION_NAMES``. Any other
    index raises ``ValueError``.
    """
    if corruption_idx == 0:
        return img

    # Position i in this tuple handles label i + 1.
    corruption_fns = (
        add_gaussian_noise,
        add_gaussian_blur,
        add_jpeg_compression,
        add_brightness,
        add_contrast,
        add_fog,
        add_snow,
        add_pixel_dropout,
        add_motion_blur,
        add_color_jitter,
        add_grayscale,
    )
    for label, corrupt in enumerate(corruption_fns, start=1):
        if corruption_idx == label:
            return corrupt(img, severity)
    raise ValueError(f"Invalid corruption_idx: {corruption_idx}")


In [None]:
# Corrupted CIFAR-10 Dataset

class CorruptedCIFAR10(Dataset):
    """CIFAR-10 images relabelled by which synthetic corruption was applied.

    Every base image appears ``NUM_CORRUPTIONS`` times. Flat index ``idx`` maps
    to base image ``idx // NUM_CORRUPTIONS`` with corruption
    ``idx % NUM_CORRUPTIONS``. The original CIFAR-10 class label is discarded;
    the target is the corruption index.

    Args:
        root: Directory for the torchvision CIFAR-10 files.
        train: Use the training split (True) or the test split (False).
        download: Passed through to torchvision to fetch missing data.
        deterministic: If True, the severity and every random draw made while
            corrupting item ``idx`` are identical on each access, so the
            evaluation set does not change between epochs. Defaults to
            ``not train``: a fixed test set and freshly sampled training
            corruptions.
        seed: Base seed for deterministic mode; item ``idx`` uses
            ``seed + idx``. Defaults to the notebook-wide ``SEED``.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        download: bool = True,
        deterministic: Optional[bool] = None,
        seed: Optional[int] = None,
    ):
        self.base = torchvision.datasets.CIFAR10(
            root=root,
            train=train,
            download=download,
            transform=None,
        )
        self.num_corruptions = NUM_CORRUPTIONS
        self.to_tensor = transforms.ToTensor()
        # Evaluation data should stay fixed across epochs; training data should vary.
        self.deterministic = (not train) if deterministic is None else deterministic
        self.seed = SEED if seed is None else seed

    def __len__(self) -> int:
        return len(self.base) * self.num_corruptions

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        base_idx = idx // self.num_corruptions
        corruption_idx = idx % self.num_corruptions

        img, _ = self.base[base_idx]  # original CIFAR-10 class label is unused
        if self.deterministic:
            img_corrupted = self._corrupt_deterministically(img, corruption_idx, idx)
        else:
            img_corrupted = self._corrupt(img, corruption_idx)

        x = self.to_tensor(img_corrupted)
        y = corruption_idx
        return x, y

    @staticmethod
    def _corrupt(img, corruption_idx: int):
        """Draw a severity uniformly from {1, 2, 3} and apply corruption ``corruption_idx``."""
        severity = random.randint(1, 3)
        return apply_corruption(img, corruption_idx, severity)

    def _corrupt_deterministically(self, img, corruption_idx: int, idx: int):
        """Run ``_corrupt`` with RNGs seeded from ``idx``, then restore global RNG state."""
        item_seed = self.seed + idx
        py_state = random.getstate()
        try:
            # fork_rng(devices=[]) snapshots and restores only the CPU torch RNG,
            # which the torch.rand/randn_like calls on CPU tensors draw from.
            # Seeding default_generator (not torch.manual_seed) leaves CUDA RNGs untouched.
            with torch.random.fork_rng(devices=[]):
                random.seed(item_seed)
                torch.default_generator.manual_seed(item_seed)
                return self._corrupt(img, corruption_idx)
        finally:
            random.setstate(py_state)


# Both splits are wrapped in CorruptedCIFAR10; the train split is built first.
train_dataset, test_dataset = (
    CorruptedCIFAR10(root=DATA_ROOT, train=split_is_train, download=True)
    for split_is_train in (True, False)
)

# Only the training loader reshuffles; the test loader iterates in index order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

print("Train samples:", len(train_dataset))
print("Test samples:", len(test_dataset))
print("Number of classes (corruptions):", NUM_CORRUPTIONS)


In [None]:
# Look at some samples
def show_batch(loader: DataLoader, n_images: int = 8):
    """Plot up to ``n_images`` images from the first batch of ``loader`` in one row.

    Each image is titled with its corruption name from ``CORRUPTION_NAMES``.
    This works for ``n_images == 1``: ``squeeze=False`` keeps ``axes`` a 2-D
    array instead of a bare Axes. It also works when the batch holds fewer than
    ``n_images`` images, in which case only the available ones are drawn.

    Raises:
        ValueError: if ``n_images`` is less than 1.
    """
    if n_images < 1:
        raise ValueError(f"n_images must be >= 1, got {n_images}")

    x_batch, y_batch = next(iter(loader))
    n_shown = min(n_images, len(x_batch))

    fig, axes = plt.subplots(1, n_shown, figsize=(n_shown * 2, 2), squeeze=False)
    for ax, img, label_idx in zip(axes[0], x_batch[:n_shown], y_batch[:n_shown]):
        ax.imshow(img.permute(1, 2, 0))  # [C,H,W] -> [H,W,C]
        ax.set_title(CORRUPTION_NAMES[int(label_idx)], fontsize=7)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


show_batch(train_loader, n_images=8)


In [None]:
# Model Definition
from torchvision import models

class ResNet18Corruption(nn.Module):
    """Randomly initialised torchvision ResNet-18 adapted to small inputs.

    Modifications to the stock network:
      * the stem convolution is replaced by a 3x3, stride-1, padding-1 conv;
      * the stem max-pool is replaced by an identity;
      * the final linear layer outputs ``num_classes`` logits.
    """

    def __init__(self, num_classes: int = NUM_CORRUPTIONS):
        super().__init__()
        net = models.resnet18(weights=None)
        net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        net.maxpool = nn.Identity()
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        self.backbone = net

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[N, num_classes]`` logits for 3-channel images ``x`` of shape ``[N, 3, H, W]``."""
        return self.backbone(x)


# Loss, model (moved to the selected device) and optimiser for training.
criterion = nn.CrossEntropyLoss()
model = ResNet18Corruption(num_classes=NUM_CORRUPTIONS).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

print(model.backbone.fc)


In [None]:
# Train

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Each ``(x, y)`` batch is moved to ``device``. The function then runs the
    forward pass, computes the loss, backpropagates and steps the optimizer.
    Loss and accuracy are measured on the same forward pass used for the update,
    i.e. before the optimizer step.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. The per-batch loss is
        weighted by batch size.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * inputs.size(0)
        n_correct += (logits.max(dim=1).indices == targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, loader, criterion, device, num_classes=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Classifier returning one logit per class for each sample.
        loader: Iterable of ``(x, y)`` batches with integer class targets ``y``.
        criterion: Batch loss; its value is multiplied by the batch size, which
            assumes mean reduction (as with the default ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.
        num_classes: Number of classes for the per-class statistics. Defaults
            to ``NUM_CORRUPTIONS``.

    Returns:
        ``(mean_loss, accuracy, per_class_stats)``. ``per_class_stats[c]`` is
        ``(n_correct, n_total)`` for each ``c`` in ``range(num_classes)``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if num_classes is None:
        num_classes = NUM_CORRUPTIONS

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    correct_per_class = torch.zeros(num_classes, dtype=torch.long)
    total_per_class = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            outputs = model(x)
            loss = criterion(outputs, y)
            running_loss += loss.item() * x.size(0)

            _, preds = torch.max(outputs, dim=1)
            hits = preds == y
            correct += hits.sum().item()
            total += y.size(0)

            # Vectorised per-class tallies: one bincount per batch instead of a
            # Python loop doing two .item() host syncs per class. Labels
            # >= num_classes are left out of the per-class counts (as before)
            # but still count towards `total`.
            total_per_class += torch.bincount(y, minlength=num_classes)[:num_classes].cpu()
            correct_per_class += torch.bincount(y[hits], minlength=num_classes)[:num_classes].cpu()

    if total == 0:
        raise ValueError("evaluate() got an empty loader; there are no samples to average over.")

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    per_class_stats = {
        cls: (int(correct_per_class[cls]), int(total_per_class[cls]))
        for cls in range(num_classes)
    }
    return epoch_loss, epoch_acc, per_class_stats


history: List[dict] = []

for epoch in range(1, EPOCHS + 1):
    # One optimisation pass over the training split, then a full evaluation pass.
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_stats = evaluate(model, test_loader, criterion, device)

    print(" | ".join((
        f"Epoch {epoch}/{EPOCHS}",
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}",
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}",
    )))

    print("Per-corruption validation accuracy:")
    for cls_idx, name in CORRUPTION_NAMES.items():
        correct_cls, total_cls = val_stats[cls_idx]
        if total_cls > 0:
            acc_cls = correct_cls / total_cls
        else:
            acc_cls = 0.0  # no samples of this class were evaluated
        print(f"  {cls_idx:2d} ({name:>15}): {acc_cls:.4f}  ({correct_cls}/{total_cls})")

    print("-" * 60)

    # Keep a plain-Python record of the epoch for later plotting / inspection.
    epoch_record = dict(
        epoch=epoch,
        train_loss=train_loss,
        train_acc=train_acc,
        val_loss=val_loss,
        val_acc=val_acc,
        per_class={cls: (val_stats[cls][0], val_stats[cls][1]) for cls in range(NUM_CORRUPTIONS)},
    )
    history.append(epoch_record)


In [None]:
def predict_sample(model, loader, device, n_samples: int = 8):
    """Visualise model predictions on the first batch yielded by ``loader``.

    Draws up to ``n_samples`` images in one row. Each image is titled with its
    true (``T:``) and predicted (``P:``) corruption name. ``model`` is left in
    eval mode. This works for ``n_samples == 1``: ``squeeze=False`` keeps
    ``axes`` a 2-D array instead of a bare Axes. It also works when the batch
    holds fewer than ``n_samples`` images.

    Raises:
        ValueError: if ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    model.eval()
    x_batch, y_batch = next(iter(loader))
    n_shown = min(n_samples, len(x_batch))
    x_batch = x_batch[:n_shown]
    y_batch = y_batch[:n_shown]

    with torch.no_grad():
        _, preds = torch.max(model(x_batch.to(device)), dim=1)
    preds = preds.cpu()

    fig, axes = plt.subplots(1, n_shown, figsize=(n_shown * 2, 2), squeeze=False)
    for ax, img, true_idx, pred_idx in zip(axes[0], x_batch, y_batch, preds):
        true_label = CORRUPTION_NAMES[int(true_idx)]
        pred_label = CORRUPTION_NAMES[int(pred_idx)]
        ax.imshow(img.permute(1, 2, 0))  # [C,H,W] -> [H,W,C]
        ax.set_title(f"T:{true_label}\nP:{pred_label}", fontsize=7)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


predict_sample(model, test_loader, device, n_samples=8)
