<a href="https://colab.research.google.com/github/tannisthamaiti/DiffusionModels_DDPM_DDIM/blob/main/Augmix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# AugMix Implementation in PyTorch
# Source: Hendrycks et al. "AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty." ICLR 2020.

import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.data import DataLoader
from PIL import Image
from PIL import ImageOps
import numpy as np
import matplotlib.pyplot as plt

# --- AugMix Utils ---
def int_parameter(level, maxval):
    """Scale ``level`` (on a 0-10 severity scale) to ``maxval`` and return an int.

    The scaled value ``level * maxval / 10`` is truncated toward zero by ``int``.
    """
    scaled = level * maxval / 10
    return int(scaled)

def float_parameter(level, maxval):
    """Scale ``level`` (on a 0-10 severity scale) to ``maxval`` and return a float.

    Computes ``float(level) * maxval / 10``.
    """
    product = float(level) * maxval
    return product / 10.

def sample_level(n):
    """Draw a severity level uniformly from ``[0.1, n)`` using NumPy's global RNG."""
    low = 0.1
    return np.random.uniform(low, n)

def autocontrast(img):
    """Maximize the contrast of ``img`` and return an RGB PIL image.

    The image is converted to RGB first so the result mode is always RGB.
    ``ImageOps.autocontrast`` then remaps each channel so its darkest pixel
    becomes 0 and its lightest becomes 255. The ``level`` argument used by
    the other ops is not needed here: autocontrast has no magnitude.
    """
    return ImageOps.autocontrast(img.convert("RGB"))

def rotate(img, level):
    """Rotate a PIL image by ``int(level * 30 / 10)`` degrees in a random direction.

    The sign of the angle is flipped with probability 0.5, so clockwise and
    counter-clockwise rotations are equally likely. Without this, every
    rotation would go counter-clockwise. Bilinear resampling is used instead
    of Pillow's nearest-neighbour default to reduce aliasing.
    """
    degrees = int_parameter(level, 30)
    if np.random.uniform() > 0.5:
        degrees = -degrees
    return img.rotate(degrees, resample=Image.BILINEAR)

def shear_x(img, level):
    """Shear a PIL image horizontally by ``level * 0.3 / 10`` in a random direction.

    The shear factor's sign is flipped with probability 0.5, so positive and
    negative shears are equally likely. Bilinear resampling is used instead of
    Pillow's nearest-neighbour default to reduce aliasing.
    """
    factor = float_parameter(level, 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0),
                         resample=Image.BILINEAR)

def shear_y(img, level):
    """Shear a PIL image vertically by ``level * 0.3 / 10`` in a random direction.

    The shear factor's sign is flipped with probability 0.5, so positive and
    negative shears are equally likely. Bilinear resampling is used instead of
    Pillow's nearest-neighbour default to reduce aliasing.
    """
    factor = float_parameter(level, 0.3)
    if np.random.uniform() > 0.5:
        factor = -factor
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0),
                         resample=Image.BILINEAR)

# Primitive AugMix ops. Each takes a PIL image and an optional severity
# ``level`` and returns a PIL image. ``level`` is the upper bound for
# ``sample_level``. The default of 3 keeps the single-argument form
# ``op(img)`` working.
augmentations = [
    lambda x, level=3: x,
    lambda x, level=3: autocontrast(x),
    lambda x, level=3: rotate(x, sample_level(level)),
    lambda x, level=3: shear_x(x, sample_level(level)),
    lambda x, level=3: shear_y(x, sample_level(level)),
]


def augmix(image, severity=3, width=3, depth=-1, alpha=1.):
    """Return an AugMix-augmented version of a PIL image as a float tensor.

    ``width`` augmentation chains are built from copies of ``image``. Each
    chain applies ``depth`` ops drawn at random from ``augmentations``, all at
    the given ``severity``. The chains are combined with Dirichlet(alpha)
    weights into ``mix``. The result is ``(1 - m) * clean + m * mix``, where
    ``m ~ Beta(alpha, alpha)`` and ``clean`` is the unaugmented image tensor.

    Args:
        image: PIL image. The ops call PIL methods (``copy``, ``rotate``,
            ``transform``) on it, so tensors are not accepted.
        severity: upper bound passed to ``sample_level`` by every op.
        width: number of augmentation chains to mix.
        depth: ops per chain. A value ``<= 0`` picks 1-3 at random per chain.
        alpha: concentration parameter for the Dirichlet and Beta draws.

    Returns:
        A ``torch.Tensor`` in the layout produced by ``T.ToTensor()``.
    """
    to_tensor = T.ToTensor()
    ws = np.random.dirichlet([alpha] * width)
    # Convert to a Python float so the blend below is a plain tensor-scalar op.
    m = float(np.random.beta(alpha, alpha))

    clean = to_tensor(image)
    mix = torch.zeros_like(clean)
    for w in ws:
        image_aug = image.copy()
        d = depth if depth > 0 else np.random.randint(1, 4)
        for _ in range(d):
            op = random.choice(augmentations)
            # Forward severity; previously it was ignored (ops hard-coded 3).
            image_aug = op(image_aug, severity)
        mix += float(w) * to_tensor(image_aug)

    return (1 - m) * clean + m * mix

# --- Dataset Setup ---
# ToTensor is required: the default DataLoader collate function can only batch
# tensors/arrays/numbers, not the PIL images CIFAR10 returns on its own.
# Resize/CenterCrop to 32 leave native 32x32 CIFAR images unchanged.
transform_base = T.Compose([
    T.Resize(32),
    T.CenterCrop(32),
    T.ToTensor(),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_base)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# --- Model Setup ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Randomly initialised ResNet-18 with a 10-class head. Random init is the
# default, so the deprecated ``pretrained=False`` flag is not needed.
model = models.resnet18(num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def jsd_loss(p_clean, p_aug1, p_aug2):
    """Jensen-Shannon consistency loss over three batches of probabilities.

    Let ``M`` be the element-wise mean of the three inputs. The result is the
    average over the three inputs of ``KL(p || M)``, each computed with
    ``F.kl_div(..., reduction='batchmean')``. Inputs are assumed to be
    (batch, num_classes) probability tensors, like the ``softmax(dim=1)``
    outputs passed by the training loop. TODO: confirm for other shapes.
    """
    dists = (p_clean, p_aug1, p_aug2)
    # Clamp before the log so a zero-probability class cannot yield -inf.
    log_m = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
    total = sum(F.kl_div(log_m, p, reduction='batchmean') for p in dists)
    return total / 3.

# --- Training Loop (a single step for demonstration; remove `break` for a full epoch) ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.train()

# augmix() works on PIL images, but the DataLoader yields float tensors.
# Rounding to uint8 before ToPILImage makes the tensor -> PIL round trip exact
# and avoids off-by-one truncation in the float -> uint8 conversion.
to_pil = T.ToPILImage()

for images, labels in train_loader:
    pil_images = [to_pil((img * 255).round().to(torch.uint8)) for img in images]
    aug1 = torch.stack([augmix(img) for img in pil_images])
    aug2 = torch.stack([augmix(img) for img in pil_images])

    # A single forward pass over clean + both augmented views, so batch-norm
    # statistics are computed over all three rather than per view.
    batch_size = images.size(0)
    all_images = torch.cat([images, aug1, aug2]).to(device)
    labels = labels.to(device)
    logits_clean, logits_aug1, logits_aug2 = torch.split(model(all_images), batch_size)

    loss_ce = F.cross_entropy(logits_clean, labels)
    loss_jsd = jsd_loss(F.softmax(logits_clean, dim=1),
                        F.softmax(logits_aug1, dim=1),
                        F.softmax(logits_aug2, dim=1))
    # 12 weights the JSD consistency term relative to the cross-entropy.
    loss = loss_ce + 12 * loss_jsd

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    break  # remove this to train for full epoch

print("AugMix training step completed.")

# --- Citation ---
# BibTeX entry for the AugMix paper, echoed to stdout when the script runs.
CITATION_BIBTEX = """
@inproceedings{hendrycks2020augmix,
  title={AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty},
  author={Hendrycks, Dan and Mu, Norman and Cubuk, Ekin D and Zoph, Barret and Gilmer, Justin and Lakshminarayanan, Balaji},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}
}
"""
print(CITATION_BIBTEX)
