In [10]:
# Imports
import io
import os
import random
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms

from PIL import ImageFilter
import matplotlib.pyplot as plt

In [None]:
# Basic config: reproducibility seeds, device choice, label map, hyperparameters.
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

# Device agnostic code: use the GPU whenever one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Corruption labels: class index -> corruption name (index order is the label order).
CORRUPTION_NAMES = dict(
    enumerate(("clean", "gaussian_noise", "blur", "jpeg_compression"))
)
NUM_CORRUPTIONS = len(CORRUPTION_NAMES)

# Data location and optimisation hyperparameters.
DATA_ROOT = "./data"
BATCH_SIZE, EPOCHS, LEARNING_RATE = 128, 5, 1e-3


In [None]:
# Corruption functions

from PIL import Image

# Shared PIL <-> tensor converters used by the corruption helpers below
# (to_tensor for the input image, to_pil to convert the noisy tensor back).
to_tensor, to_pil = transforms.ToTensor(), transforms.ToPILImage()


def add_gaussian_noise(img: Image.Image, std: float = 0.1) -> Image.Image:
    """Return ``img`` with zero-mean Gaussian noise added to every pixel.

    The image is converted to a float tensor with the module-level
    ``to_tensor``, perturbed by ``std * N(0, 1)`` noise from torch's global
    RNG, clamped back into [0, 1], and converted to PIL via ``to_pil``.

    Args:
        img: input PIL image.
        std: standard deviation of the noise, in the [0, 1] tensor scale.
    """
    pixels = to_tensor(img)
    perturbed = pixels + std * torch.randn_like(pixels)
    return to_pil(perturbed.clamp(min=0.0, max=1.0))


def add_blur(img: Image.Image, radius: float = 1.5) -> Image.Image:
    """Return a Gaussian-blurred version of ``img`` (PIL ``GaussianBlur`` of ``radius``)."""
    blur_kernel = ImageFilter.GaussianBlur(radius=radius)
    blurred = img.filter(blur_kernel)
    return blurred


def add_jpeg_compression(
    img: Image.Image,
    downscale_factor: int = 2,
    quality: Optional[int] = None,
) -> Image.Image:
    """Degrade ``img`` by down/up-sampling and, optionally, a real JPEG round-trip.

    NOTE: despite the name, the default call performs *no* JPEG encoding. The
    image is shrunk by ``downscale_factor`` and scaled back to its original
    size with bilinear interpolation, which smooths it much like ``add_blur``.
    Pass ``quality`` to also encode/decode the result as JPEG so that genuine
    compression artifacts appear.

    Args:
        img: input PIL image.
        downscale_factor: integer shrink factor (>= 1). Each side is reduced to
            ``side // downscale_factor`` pixels, but never below 1 pixel.
        quality: optional JPEG quality forwarded to Pillow's JPEG encoder.
            ``None`` (the default) skips JPEG encoding entirely, preserving the
            original resample-only behaviour.

    Returns:
        A new image with the same size as ``img``. When ``quality`` is given the
        result is RGB, because JPEG cannot store every PIL mode (e.g. RGBA).

    Raises:
        ValueError: if ``downscale_factor`` is smaller than 1.
    """
    if downscale_factor < 1:
        raise ValueError(f"downscale_factor must be >= 1, got {downscale_factor}")

    w, h = img.size
    # Clamp to 1 px so tiny images (or large factors) never request a 0-sized resize,
    # which Pillow rejects.
    small_size = (max(1, w // downscale_factor), max(1, h // downscale_factor))
    down = img.resize(small_size, Image.BILINEAR)
    degraded = down.resize((w, h), Image.BILINEAR)

    if quality is not None:
        # Real JPEG round-trip through an in-memory buffer (no files touched).
        buffer = io.BytesIO()
        degraded.convert("RGB").save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        with Image.open(buffer) as decoded:
            # convert() forces a full decode before the buffer/handle is released.
            degraded = decoded.convert("RGB")

    return degraded


def apply_corruption(img: Image.Image, corruption_idx: int) -> Image.Image:
    """Apply the corruption identified by ``corruption_idx`` to ``img``.

    Index 0 returns ``img`` unchanged; 1, 2 and 3 dispatch to
    ``add_gaussian_noise``, ``add_blur`` and ``add_jpeg_compression`` with
    their default parameters.

    Raises:
        ValueError: for any index other than 0-3.
    """
    if corruption_idx == 0:
        return img

    # Checked in order with ``==`` so the matching rules mirror a plain if/elif chain.
    handlers = {
        1: add_gaussian_noise,
        2: add_blur,
        3: add_jpeg_compression,
    }
    for handler_idx, handler in handlers.items():
        if corruption_idx == handler_idx:
            return handler(img)

    raise ValueError(f"Invalid corruption_idx: {corruption_idx}")


In [None]:
# Corrupted CIFAR-10 Dataset

class CorruptedCIFAR10(Dataset):
    """CIFAR-10 where the target is the corruption applied, not the object class.

    Each base image is exposed ``NUM_CORRUPTIONS`` times. Flat index ``idx``
    maps to base image ``idx // NUM_CORRUPTIONS``, corrupted with corruption
    ``idx % NUM_CORRUPTIONS``. The original CIFAR-10 class label is discarded.
    """

    def __init__(self, root: str, train: bool = True, download: bool = True):
        # transform=None: the base dataset yields raw PIL images so the PIL-based
        # corruption helpers can operate on them directly.
        self.base = torchvision.datasets.CIFAR10(
            root=root, train=train, download=download, transform=None
        )
        self.num_corruptions = NUM_CORRUPTIONS
        self.to_tensor = transforms.ToTensor()

    def __len__(self) -> int:
        return self.num_corruptions * len(self.base)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image_tensor, corruption_idx)`` for flat index ``idx``."""
        base_idx, corruption_idx = idx // self.num_corruptions, idx % self.num_corruptions
        pil_img, _object_label = self.base[base_idx]
        corrupted = apply_corruption(pil_img, corruption_idx)
        # ToTensor -> [C, H, W] float tensor in [0, 1]
        return self.to_tensor(corrupted), corruption_idx

# Build train/test splits (CIFAR-10 is downloaded into DATA_ROOT if missing).
train_dataset, test_dataset = (
    CorruptedCIFAR10(root=DATA_ROOT, train=is_train, download=True)
    for is_train in (True, False)
)

# num_workers=0 keeps data loading in the main process.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=0)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=0)

print("Train samples:", len(train_dataset))
print("Test samples:", len(test_dataset))
print("Number of classes (corruptions):", NUM_CORRUPTIONS)


In [None]:
# Look at some examples
def show_batch(loader: DataLoader, n_images: int = 8):
    """Plot the first images of one batch, each titled with its corruption name.

    Args:
        loader: yields ``(images, labels)`` batches. Images are presumably
            ``[B, C, H, W]`` float tensors in [0, 1], as produced by
            CorruptedCIFAR10; TODO confirm for other loaders.
        n_images: maximum number of images to show. It is clipped to the
            batch size, and nothing is drawn when it is <= 0.
    """
    x_batch, y_batch = next(iter(loader))
    n_shown = min(n_images, len(x_batch))
    if n_shown <= 0:
        return

    # squeeze=False always yields a 2-D axes array, so n_shown == 1 also works
    # (plain subplots(1, 1) returns a bare Axes that cannot be indexed).
    fig, axes = plt.subplots(1, n_shown, figsize=(n_shown * 2, 2), squeeze=False)
    for ax, img, label in zip(axes[0], x_batch[:n_shown], y_batch[:n_shown]):
        ax.imshow(img.permute(1, 2, 0).cpu().numpy())  # [C,H,W] -> [H,W,C]
        ax.set_title(CORRUPTION_NAMES[int(label)], fontsize=8)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


show_batch(train_loader, n_images=8)


In [None]:
# Model Definition

class SimpleCNN(nn.Module):
    """Small three-stage conv net producing one logit per corruption class.

    Two conv/ReLU/max-pool stages (32 and 64 channels) are followed by a
    128-channel conv/ReLU, global average pooling, and a linear head.
    Spatial-size comments assume 32x32 inputs (CIFAR-10).
    """

    def __init__(self, num_classes: int = NUM_CORRUPTIONS):
        super().__init__()
        stages = []
        in_channels = 3
        # Stage 1: 32x32 -> 16x16, stage 2: 16x16 -> 8x8.
        for out_channels in (32, 64):
            stages += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
            in_channels = out_channels
        # Final stage collapses the spatial dims to [128, 1, 1].
        stages += [
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to class logits of shape ``[batch, num_classes]``."""
        pooled = self.features(x)
        flat = pooled.view(pooled.shape[0], -1)
        return self.classifier(flat)


# Instantiate the classifier on the selected device together with its loss and optimiser.
model = SimpleCNN(num_classes=NUM_CORRUPTIONS).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

print(model)


In [None]:
#Train
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: module to train; it is switched to train mode.
        loader: iterable of ``(inputs, integer_targets)`` batches.
        optimizer: optimiser stepping ``model``'s parameters once per batch.
        criterion: loss taking ``(logits, targets)``.
        device: device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correct argmax predictions over all samples seen.

    Raises:
        ValueError: if ``loader`` yields no samples, since the averages would be undefined.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for x, y in loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # Re-weight by batch size so a smaller last batch doesn't skew the mean;
        # this presumes criterion uses mean reduction (the CrossEntropyLoss default).
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == y).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


def evaluate(model, loader, criterion, device):
    """Compute loss and accuracy of ``model`` over ``loader`` without updating it.

    The model is put into eval mode and left there, and gradients are disabled
    for the whole pass.

    Args:
        model: module to evaluate.
        loader: iterable of ``(inputs, integer_targets)`` batches.
        criterion: loss taking ``(logits, targets)``.
        device: device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correct argmax predictions.

    Raises:
        ValueError: if ``loader`` yields no samples, since the averages would be undefined.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            outputs = model(x)
            loss = criterion(outputs, y)

            batch_size = y.size(0)
            # Presumes criterion uses mean reduction, so re-weight by batch size.
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == y).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


# Train for EPOCHS passes, reporting held-out metrics after each one.
# NOTE(review): the "Val" numbers are computed on test_loader (the CIFAR-10
# test split); there is no separate validation split in this notebook.
for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, test_loader, criterion, device)

    summary = " | ".join(
        (
            f"Epoch {epoch}/{EPOCHS}",
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}",
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}",
        )
    )
    print(summary)


In [None]:
def predict_sample(model, loader, device, n_samples: int = 5):
    """Plot the first samples of one batch with their true and predicted corruption.

    Titles read ``T:<true>`` / ``P:<predicted>``. Predictions use argmax over
    the model's logits, computed without gradients. The model's previous
    train/eval mode is restored afterwards.

    Args:
        model: classifier returning ``[batch, num_classes]`` logits.
        loader: yields ``(images, labels)`` batches. Images are presumably
            ``[B, C, H, W]`` tensors in [0, 1]; TODO confirm for other loaders.
        device: device to run the model on.
        n_samples: maximum number of samples to show. It is clipped to the
            batch size, and nothing is drawn when it is <= 0.
    """
    x_batch, y_batch = next(iter(loader))
    n_shown = min(n_samples, len(x_batch))
    if n_shown <= 0:
        return
    x_batch = x_batch[:n_shown]
    y_batch = y_batch[:n_shown]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds = model(x_batch.to(device)).argmax(dim=1).cpu()
    finally:
        # Don't leave a caller's model silently switched to eval mode.
        model.train(was_training)

    # squeeze=False keeps a 2-D axes array even when n_shown == 1.
    fig, axes = plt.subplots(1, n_shown, figsize=(n_shown * 2, 2), squeeze=False)
    for ax, img, true_idx, pred_idx in zip(axes[0], x_batch, y_batch, preds):
        true_label = CORRUPTION_NAMES[int(true_idx)]
        pred_label = CORRUPTION_NAMES[int(pred_idx)]
        ax.imshow(img.permute(1, 2, 0).cpu().numpy())  # [C,H,W] -> [H,W,C]
        ax.set_title(f"T:{true_label}\nP:{pred_label}", fontsize=7)
        ax.axis("off")
    plt.tight_layout()
    plt.show()


predict_sample(model, test_loader, device, n_samples=8)
