<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/mlp/image_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from datasets import load_dataset, Dataset
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import pandas
import matplotlib.pyplot as plt

from diffusers.utils import make_image_grid
from PIL import Image

from torch.utils.data import DataLoader

import sklearn.metrics as metrics
from collections import defaultdict
from tqdm import tqdm


# Report whether PyTorch can see a CUDA device in this runtime.
print(torch.cuda.is_available())

In [None]:
from functools import partial
# Step 1 - load the "uoft-cs/cifar10" dataset via `datasets.load_dataset`
ds = load_dataset("uoft-cs/cifar10")

# Per-channel statistics passed to transforms.Normalize below.
# (presumably the CIFAR-10 training-set mean/std in RGB order — TODO confirm)
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Step 2 - Define transform
# Deterministic pipeline: tensor conversion followed by normalisation only.
validate_transform = transforms.Compose([
    transforms.ToTensor(),  # (H, W, C) PIL → (C, H, W) float32 [0,1]
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Augmenting pipeline: same conversion/normalisation, then random crop, flip and rotation.
# NOTE(review): this transform is not referenced by either loader built below —
# both splits use validate_transform.
# NOTE(review): the random crop/rotation run after Normalize, so any padded or
# filled pixels are presumably zeros in normalised space — confirm this is intended
# before using it.
train_transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(p=0.5),
  transforms.RandomRotation(15),
])

# Step 3 - Wrap in dataset class
def transform_batch(batch, transform):
    """Apply ``transform`` to every image of a batch and rename the columns.

    Args:
        batch: mapping with an ``'img'`` sequence and a ``'label'`` entry.
        transform: callable applied to each element of ``batch['img']``.

    Returns:
        A dict with ``'pixel_values'`` (list of transformed images, in input
        order) and ``'labels'`` (``batch['label']`` passed through untouched).
    """
    return {
        'pixel_values': list(map(transform, batch['img'])),
        'labels': batch['label'],
    }


# Attach the transform to each split; `partial` binds the transform argument of
# transform_batch so the datasets receive a one-argument callable.
ds_val = ds['test'].with_transform(partial(transform_batch, transform=validate_transform))
ds_train = ds['train'].with_transform(partial(transform_batch, transform=validate_transform)) # Because we're creating a noisy dataset, doesn't make sense to add augs

# Step 4 - Torch DataLoader
# Only the training loader shuffles; both use batches of 128 and two worker processes.
train_loader = DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=2)

In [None]:
class NoisyDataset(torch.utils.data.Dataset):
    """Wrap a dataset and inject synthetic label noise and feature noise.

    Label noise: each original sample is independently flagged with
    probability ``label_noise``; flagged samples return a random label instead
    of their own. Random labels are drawn from the first
    ``min(noisy_label_classes, num_classes)`` classes only, which deliberately
    skews the class balance (with the defaults: classes 0..3).

    Feature noise: ``int(len(dataset) * feature_noise)`` extra samples are
    appended. Each one is a randomly chosen original image (sampled with
    replacement) plus Gaussian jitter, and keeps that image's original label.
    The jitter is re-drawn on every access.

    Samples may be dicts with ``'pixel_values'``/``'labels'`` keys or
    ``(image, label)`` sequences; the same container kind is returned.

    Args:
        original_dataset: map-style dataset supporting ``len()`` and integer indexing.
        feature_noise: number of jittered duplicates, as a fraction of the dataset size.
        label_noise: probability that an original sample gets a random label.
        num_classes: number of valid classes; upper bound for random labels.
        noisy_label_classes: size of the label pool used for random labels.
        feature_noise_std: standard deviation of the jitter added to duplicates.

    Raises:
        ValueError: from ``np.random.randint`` when the resulting label pool is empty.
    """

    def __init__(self, original_dataset, feature_noise=0.0, label_noise=0.0, num_classes=10,
                 noisy_label_classes=4, feature_noise_std=0.01):
        self.data = original_dataset
        self.original_size = len(original_dataset)
        self.feature_noise_std = feature_noise_std

        # Label noise: which samples get random labels
        self.noisy_labels = np.random.random(self.original_size) < label_noise
        # NOTE: drawing from the full range(num_classes) would keep the class
        # distribution roughly unchanged. To make the task harder the pool is
        # restricted to the first few classes, so label noise also unbalances
        # the dataset. It is capped at num_classes so no invalid label is produced.
        label_pool_size = min(noisy_label_classes, num_classes)
        self.random_labels = np.random.randint(0, label_pool_size, self.original_size)

        # Feature noise: duplicate some samples (jitter is applied lazily in __getitem__)
        num_dups = int(self.original_size * feature_noise)
        if num_dups > 0:
            self.dup_indices = np.random.choice(self.original_size, num_dups, replace=True)
        else:
            self.dup_indices = np.empty(0, dtype=np.int64)

    def __len__(self):
        return self.original_size + len(self.dup_indices)

    def __getitem__(self, idx):
        # Indices past the original range address the appended duplicates.
        is_duplicate = idx >= self.original_size
        orig_idx = int(self.dup_indices[idx - self.original_size]) if is_duplicate else idx
        sample = self.data[orig_idx]

        # Unpack
        as_dict = isinstance(sample, dict)
        if as_dict:
            img, label = sample['pixel_values'], sample['labels']
        else:
            img, label = sample[0], sample[1]

        # Apply noise: duplicates get jittered pixels, flagged originals a random label.
        if is_duplicate:
            img = img + torch.randn_like(img) * self.feature_noise_std
        elif self.noisy_labels[idx]:
            label = int(self.random_labels[idx])

        return {'pixel_values': img, 'labels': label} if as_dict else (img, label)


def inject_noise(dataloader, feature_noise: float, label_noise: float):
    """Return a new DataLoader that serves a noisy view of ``dataloader``'s dataset.

    The underlying dataset is wrapped in ``NoisyDataset``; batch size, worker
    count, ``pin_memory`` and ``drop_last`` are copied from the source loader,
    and shuffling is enabled only when the source loader uses a RandomSampler.
    """
    wrapped = NoisyDataset(dataloader.dataset, feature_noise, label_noise)

    # A RandomSampler on the source loader means it was created with shuffling.
    source_shuffles = isinstance(dataloader.sampler, torch.utils.data.sampler.RandomSampler)

    loader_options = {
        'batch_size': dataloader.batch_size,
        'shuffle': source_shuffles,
        'num_workers': dataloader.num_workers,
        'pin_memory': dataloader.pin_memory,
        'drop_last': dataloader.drop_last,
    }
    return DataLoader(wrapped, **loader_options)

# Usage: build the noisy training loader. feature_noise=0.4 appends jittered
# duplicates equal to 40% of the dataset size; label_noise=0.2 flags each original
# sample for a random label with probability 0.2.
noisy_train_loader = inject_noise(train_loader, feature_noise=0.4, label_noise=0.2)

In [None]:
def visualize_distribution(dataloader, title):
    """Plot channel statistics, per-class mean pixel value and label counts.

    Iterates once over ``dataloader`` (batches are dicts with
    ``'pixel_values'`` of shape (N, C, H, W) and integer ``'labels'``) and draws
    three bar charts:

    1. per-channel mean with the per-channel std as error bars (both are
       averages of the per-batch statistics, i.e. approximations),
    2. mean pixel value per class, weighted by the number of samples,
    3. number of samples per label.

    Labels are assumed to lie in ``range(10)``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    num_classes = 10
    label_freq = torch.zeros(num_classes)
    per_class_sum = torch.zeros(num_classes)
    image_mean = []
    image_std = []

    for batch in dataloader:
        imgs = batch['pixel_values']
        labels = batch['labels']
        label_freq += torch.bincount(labels, minlength=num_classes)

        for cls in range(num_classes):
            members = imgs[labels == cls]
            if len(members) > 0:
                # mean * count is the sum of per-sample means; dividing by the
                # total class count afterwards gives a sample-weighted mean that
                # is not distorted by batches where the class is rare or absent.
                per_class_sum[cls] += members.mean() * len(members)

        image_mean.append(imgs.mean(dim=(0, 2, 3)))
        image_std.append(imgs.std(dim=(0, 2, 3)))

    if not image_mean:
        raise ValueError("visualize_distribution() received an empty dataloader")

    # Classes that never occurred keep a value of 0 instead of dividing by zero.
    per_class_feature_distribution = per_class_sum / label_freq.clamp(min=1)
    image_mean = torch.vstack(image_mean).mean(0).numpy()
    image_std = torch.vstack(image_std).mean(0).numpy()

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    fig.suptitle(title)

    ax[0].bar(["R", "G", "B"], image_mean, yerr=image_std, capsize=5)
    ax[0].set_title("Channel Distribution")

    ax[1].bar(np.arange(num_classes), per_class_feature_distribution.numpy())
    ax[1].set_title("Per class feature distribution")

    ax[2].bar(np.arange(num_classes), label_freq.numpy())
    ax[2].set_title("Label distribution")

    plt.tight_layout()
    plt.show()

In [None]:
# def featurize(batch: torch.Tensor, out_size=8):
#     # simple featurizer that just interpolates the image to a very small resolution 8 x 8
#     gray = batch.mean(1, keepdim=True)
#     small = F.interpolate(gray, size=(out_size, out_size), mode='bilinear', align_corners=False)
#     feat  = small.flatten(1)
#     return feat

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ResNet-50 with pretrained weights, used only as a frozen feature extractor.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version before changing.
resnet = models.resnet50(pretrained=True).to(device)
# Inference mode: the network is never trained in this notebook.
resnet.eval()
# Remove the final classification layer
# (keeps every child module except the last one; featurize() squeezes the
# trailing spatial dimensions of the output)
feature_extractor = torch.nn.Sequential(*list(resnet.children())[:-1])

def featurize(batch: torch.Tensor, out_size=None, mini_batch_size=32):
    """
    Extract features using pretrained ResNet50 from torchvision
    batch: (batch_size, 3, H, W) tensor with values in [0, 1]
    out_size: unused; kept so existing call sites stay valid
    mini_batch_size: number of images normalised, resized and pushed through the network at a time
    returns: (batch_size, 2048) feature vectors on the same device as `batch`

    Normalisation and resizing are done per mini-batch rather than on the whole
    input, so peak memory is bounded by `mini_batch_size` 224x224 images even
    when `batch` holds an entire dataset.

    NOTE(review): the loaders in this notebook yield CIFAR-normalised tensors, not
    [0, 1] images, so the ImageNet normalisation below is applied on top of that —
    confirm whether this is intended.
    """
    # Normalize for ImageNet pretrained model
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(batch.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(batch.device)

    # Resize to 224x224 if needed (ResNet expects this size)
    needs_resize = batch.shape[-1] != 224

    features_list = []

    # Process in mini-batches
    with torch.no_grad():
        for start in range(0, batch.shape[0], mini_batch_size):
            chunk = (batch[start:start + mini_batch_size] - mean) / std
            if needs_resize:
                chunk = F.interpolate(chunk, size=(224, 224), mode='bilinear', align_corners=False)

            chunk_features = feature_extractor(chunk.to(device))
            chunk_features = chunk_features.squeeze(-1).squeeze(-1)  # Remove spatial dims
            features_list.append(chunk_features.cpu())  # Move to CPU to save GPU memory

            # Clear CUDA cache after each mini-batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Concatenate all features
    features = torch.cat(features_list, dim=0).to(batch.device)

    return features


def distance(samples: torch.Tensor) -> torch.Tensor:
  """Return the pairwise cosine-similarity matrix of the rows of `samples`.

  samples: (batch_size, d_embed)
  returns: (batch_size, batch_size), entry [i, j] is cos(samples[i], samples[j])

  All-zero rows have no direction; their similarity to every row is 0 instead
  of NaN.
  """
  # NOTE: the norm is taken per row (over the embedding dimension). It is clamped
  # away from zero so that an all-zero embedding does not divide by zero.
  row_norms = samples.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
  samples = samples / row_norms
  out = samples @ samples.t()
  eps = 1e-3
  assert ((out >= -1 - eps) & (out <= 1 + eps)).all(), out # the bounds of cosine function are -1, 1
  return out

def batch_de_dup(batch, k: int, threshold: float, use_pbar: bool = False, can_vote: bool = False):
  '''
  Prune samples that are near-duplicates in feature space and optionally
  repair labels by majority vote.

  Algorithm:
  1. Featurize the batch.
  2. L2-normalise the embeddings and compute pairwise cosine similarity.
  3. Ideally one would drop every sample within ε of a cluster centroid; for
     simplicity this works per sample instead: walking the batch in order, each
     sample that has not been pruned yet removes up to k of its nearest
     neighbours whose similarity is >= threshold.

  batch: dict with 'pixel_values' (N, C, H, W) and 'labels' (N,) tensors
  k: maximum number of neighbours a single sample may prune
  threshold: minimum cosine similarity for a neighbour to count as a duplicate
  use_pbar: show a tqdm progress bar over the samples
  can_vote: if True, a sample with duplicates takes the most frequent label
            among itself and those duplicates (ties go to the smallest label)

  Returns a dict with the surviving 'pixel_values' and 'labels' (original
  order preserved) and 'labels_corrected', the number of samples whose label was
  changed by voting as a plain int (always 0 when can_vote is False).
  '''
  images, labels = batch['pixel_values'], batch['labels']
  embeddings = featurize(images) # (batch_size, d_embed)
  batch_size = len(embeddings)
  indices_pruned = set()  # plain Python ints, so the `in` checks below are reliable

  new_labels = labels.clone()

  pbar = tqdm(range(batch_size)) if use_pbar else range(batch_size)
  # Calculate the similarity matrix between each query and sample position
  sim: torch.Tensor = distance(embeddings)  # (batch_size, batch_size)
  sim.fill_diagonal_(-1)  # a query must never be returned as its own neighbour by topk

  num_neighbours = min(k, batch_size - 1)
  labels_corrected = 0

  for i in pbar:
    if i in indices_pruned:
      continue

    # take top k samples and corresponding indices
    val, idx = torch.topk(sim[i], num_neighbours)

    # 1. Apply de-duplication: reduces feature noise
    # 2. Apply de-noise via majority voting: reduces label noise
    high_confident_sample_indices = idx[val >= threshold]
    if len(high_confident_sample_indices) == 0:
      continue

    if can_vote:
      # The query votes together with its duplicates, using the original labels.
      voters = torch.cat([high_confident_sample_indices, torch.tensor([i], device=idx.device, dtype=idx.dtype)])
      majority_label = torch.bincount(labels[voters].cpu()).argmax()
      new_labels[i] = majority_label
      labels_corrected += int(new_labels[i] != labels[i])

    # the duplicates are dropped from the batch; the query itself is kept for now
    indices_pruned.update(high_confident_sample_indices.tolist())

  unique_idx_mask = torch.ones(batch_size, dtype=torch.bool)
  if indices_pruned:
    unique_idx_mask[list(indices_pruned)] = False

  unique_idx = torch.where(unique_idx_mask)[0]

  return {'pixel_values': images[unique_idx], 'labels': new_labels[unique_idx], 'labels_corrected': labels_corrected}


def de_dup_across_all_batches(dataloader, k: int, threshold: float):
  """Two-stage de-duplication of everything `dataloader` yields.

  Stage 1 prunes near-duplicates inside each batch (no label voting).
  Stage 2 concatenates the survivors and runs the same pruning once more over
  the whole set, this time with majority-vote label correction and a progress
  bar. The result is wrapped in a Hugging Face `Dataset` and returned as a
  shuffling DataLoader (batch size 128, 2 workers).
  """
  kept_images, kept_labels = [], []
  corrected = 0

  # Stage 1: per-batch pruning
  for raw_batch in tqdm(dataloader):
    stage_one = batch_de_dup(raw_batch, k, threshold, can_vote=False)
    corrected += stage_one['labels_corrected']
    kept_images.append(stage_one['pixel_values'])
    kept_labels.append(stage_one['labels'])

  merged = {
      'pixel_values': torch.cat(kept_images, dim=0),
      'labels': torch.cat(kept_labels, dim=0),
  }

  dataset_size = len(dataloader.dataset)
  print(f"Total labels corrected (first stage): {corrected} / {dataset_size}")

  # Stage 2: global pruning plus label voting across all surviving samples
  stage_two = batch_de_dup(merged, k, threshold, use_pbar=True, can_vote=True)
  corrected += stage_two['labels_corrected']
  final_images = stage_two["pixel_values"]
  final_labels = stage_two["labels"]

  print(f"Total labels corrected (second stage): {corrected} / {dataset_size}\nFinal dataset size: {len(final_labels)}")

  hf_dataset = Dataset.from_dict({
      "pixel_values": final_images.cpu().numpy(),
      "labels": final_labels.cpu().numpy(),
  })
  # Have the dataset hand back torch tensors for both columns.
  hf_dataset.set_format(type="torch", columns=["pixel_values", "labels"])
  return DataLoader(hf_dataset, batch_size=128, shuffle=True, num_workers=2)

In [None]:
%%time
# Build the cleaned training loader: each sample may prune up to k=5 neighbours
# whose cosine similarity is at least 0.98.
clean_dataloader = de_dup_across_all_batches(noisy_train_loader, k=5, threshold=0.98)

In [None]:
%%time
# Compare the de-duplicated set against the synthetic noisy set it was built from.
# The original train/test plots are kept below, disabled.
# visualize_distribution(train_loader, "Original Train Set")
visualize_distribution(clean_dataloader, "Pruned out noisy set")
visualize_distribution(noisy_train_loader, "Synthetic Noisy Train Set")
# visualize_distribution(test_loader, "Original Test Set")

In [None]:
class SimpleCNN(nn.Module):
  """Small three-stage convolutional classifier.

  Each stage is Conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool, taking the
  channels 3 -> 16 -> 32 -> 64 and halving the spatial size each time
  (32x32 -> 16x16 -> 8x8 -> 4x4 for a 32x32 input). The head global-average-pools
  to a 64-d vector and maps it to `num_classes` logits.
  """

  # (in_channels, out_channels) of the successive convolutional stages
  _STAGE_CHANNELS = ((3, 16), (16, 32), (32, 64))

  def __init__(self, num_classes):
    super().__init__()

    conv_layers = []
    for in_channels, out_channels in self._STAGE_CHANNELS:
      conv_layers.extend(self._conv_stage(in_channels, out_channels))
    self.feature_extractor = nn.Sequential(*conv_layers)

    head_layers = [
        nn.AdaptiveAvgPool2d((1, 1)),  # (N, 64, h, w) -> (N, 64, 1, 1)
        nn.Flatten(),  # -> (N, 64)
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(64, num_classes),  # logits
    ]
    self.classifier = nn.Sequential(*head_layers)

  @staticmethod
  def _conv_stage(in_channels, out_channels):
    """Layers of one stage: same-size 3x3 conv, batch norm, ReLU, 2x2 max-pool."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
    ]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.classifier(self.feature_extractor(x))

In [None]:
def train_one_epoch(dataloader, model, criterion, optimizer):
  """Run one optimisation pass over `dataloader` and return the mean batch loss.

  Batches may be dicts with 'pixel_values'/'labels' or (images, labels)
  sequences. Inputs are moved to whatever device the model's parameters live
  on, so the same code runs on CPU and GPU.

  Returns the average of the per-batch losses as a float (0.0 if the loader
  yields no batches).
  """
  # Follow the model's device instead of assuming CUDA.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')

  model.train()
  total_loss = 0.0
  num_batches = 0
  for batch in dataloader:
    if isinstance(batch, dict):
      imgs, labels = batch['pixel_values'], batch['labels']
    else:
      imgs, labels = batch[0], batch[1]
    imgs, labels = imgs.to(device), labels.to(device)

    optimizer.zero_grad()
    logits = model(imgs)
    loss = criterion(logits, labels)

    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    num_batches += 1

  return total_loss / num_batches if num_batches else 0.0

def validate(model, dataloader):
  """Evaluate `model` on `dataloader`.

  Batches may be dicts with "pixel_values"/"labels" or (images, labels)
  sequences. Inputs are moved to the device of the model's parameters.

  Returns a tuple `(macro_f1, confusion_matrix)` computed with sklearn over
  all samples of the loader; confusion-matrix rows are true classes and
  columns are predicted classes.
  """
  # Follow the model's device instead of assuming CUDA.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')

  model.eval()
  all_pred, all_true = [], []
  with torch.no_grad():
    for batch in dataloader:
      if isinstance(batch, dict):
        imgs, labels = batch["pixel_values"], batch["labels"]
      else:
        imgs, labels = batch[0], batch[1]
      # argmax over the logits; softmax is monotonic and would not change the winner
      preds = model(imgs.to(device)).argmax(dim=-1)
      all_pred.append(preds.cpu())
      all_true.append(labels.cpu())
  y_pred = torch.cat(all_pred).numpy()
  y_true = torch.cat(all_true).numpy()
  cm = metrics.confusion_matrix(y_true, y_pred)
  return metrics.f1_score(y_true, y_pred, average="macro"), cm


def calibration(logits, y_true, bin_count: int = 10):
  '''
  Compute the Expected Calibration Error (ECE) of a batch of predictions.

  Algorithm:
  1. Compute the confidence (max softmax probability) and correctness of each
     prediction in the batch.
  2. Split [0, 1] into `bin_count` equal-width confidence bins. Bins are
     half-open [lo, hi) except the last one, which also includes 1.0 so that
     fully confident predictions are not dropped.
  3. For each bin, compute:
     1) accuracy - proportion of correct predictions in that bin
     2) conf - average confidence in that bin
  4. ECE = sum over bins of |conf - accuracy| * bin_size / total_size.

  logits: (batch_size, num_classes) raw model outputs
  y_true: (batch_size,) integer class labels
  bin_count: number of confidence bins

  Returns `(ece, accuracy_per_bin, confidence_per_bin)` as a float and two
  lists of `bin_count` floats; empty bins report 0.0 for both statistics.
  '''
  probs = torch.softmax(logits, dim=-1)
  conf, preds = probs.max(dim=1)  # max returns (values, indices); both have shape (batch_size,)

  # float so that .mean() is defined (it is not for integer tensors)
  correct = (y_true.to(preds.device) == preds).float()

  bin_edges = torch.linspace(0, 1, bin_count + 1).tolist()
  total = len(conf)

  total_ece = 0.0
  individual_confidence_in_bin = []
  individual_accuracy_in_bin = []

  for i in range(bin_count):
    lower, upper = bin_edges[i], bin_edges[i + 1]
    is_last_bin = i == bin_count - 1
    below_upper = (conf <= upper) if is_last_bin else (conf < upper)
    mask = (conf >= lower) & below_upper

    bin_size = int(mask.sum())
    if bin_size == 0:  # no predictions in bin
      individual_accuracy_in_bin.append(0.0)
      individual_confidence_in_bin.append(0.0)
      continue

    average_accuracy_in_bin = correct[mask].mean().item()
    average_confidence_in_bin = conf[mask].mean().item()

    # This bin's contribution, weighted by the share of predictions it holds
    total_ece += abs(average_accuracy_in_bin - average_confidence_in_bin) * bin_size / total
    individual_accuracy_in_bin.append(average_accuracy_in_bin)
    individual_confidence_in_bin.append(average_confidence_in_bin)

  return total_ece, individual_accuracy_in_bin, individual_confidence_in_bin

In [None]:
# Two identically configured classifiers, one per training set, so that any
# difference in results comes from the data. Both are placed on the device
# selected earlier in the notebook instead of assuming CUDA.
clean_model = SimpleCNN(10).to(device)
clean_optimizer = optim.AdamW(clean_model.parameters(), lr=1e-3)

noisy_model = SimpleCNN(10).to(device)
noisy_optimizer = optim.AdamW(noisy_model.parameters(), lr=1e-3)

# Label smoothing mixes the one-hot target with a uniform distribution: with
# smoothing 0.1 over 10 classes every class gets 0.1 / 10 = 0.01 and the true class
# gets 1 - 0.1 + 0.01 = 0.91. This discourages overconfident predictions.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

In [None]:
%%time
# Train the two models side by side and evaluate both on the same test loader
# after every epoch.
NUM_EPOCHS = 50
# Per-epoch history: mean training loss and macro-F1 on the test set for each model.
results = {"noisy_loss": [], "clean_loss": [], "noisy_f1": [], "clean_f1": []}

for epoch in range(NUM_EPOCHS):
    # Each model sees only its own training data.
    noisy_loss = train_one_epoch(noisy_train_loader, noisy_model, criterion, noisy_optimizer)
    clean_loss = train_one_epoch(clean_dataloader, clean_model, criterion, clean_optimizer)

    # validate() returns (macro F1, confusion matrix); only the F1 score is tracked here.
    noisy_f1 = validate(noisy_model, test_loader)[0]
    clean_f1 = validate(clean_model, test_loader)[0]

    results["noisy_loss"].append(noisy_loss)
    results["clean_loss"].append(clean_loss)
    results["noisy_f1"].append(noisy_f1)
    results["clean_f1"].append(clean_f1)

    print(f"[Epoch {epoch+1}/{NUM_EPOCHS}] "
          f"Noisy Loss: {noisy_loss:.3f}, F1: {noisy_f1:.3f} | "
          f"Clean Loss: {clean_loss:.3f}, F1: {clean_f1:.3f}")

In [None]:
# Side-by-side training curves for the clean and noisy runs, read from `results`.
plt.figure(figsize=(10,4))

# Loss subplot (left): mean training loss per epoch
plt.subplot(1,2,1)
plt.plot(results["clean_loss"], label="Clean", marker='o')
plt.plot(results["noisy_loss"], label="Noisy", marker='o')
plt.title("Loss"); plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.grid(alpha=0.3)

# F1 subplot (right): macro-F1 on the test loader per epoch
plt.subplot(1,2,2)
plt.plot(results["clean_f1"], label="Clean", marker='o')
plt.plot(results["noisy_f1"], label="Noisy", marker='o')
plt.title("F1 Score"); plt.xlabel("Epoch"); plt.ylabel("F1"); plt.legend(); plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Validation

# Build a confusion matrix to understand which class is performing the best

def visualize_outputs(model, dataloader):
  """Show the confusion matrix and per-class accuracy of `model` on `dataloader`.

  The confusion matrix comes from `validate` (rows: true class, columns:
  predicted class) and is drawn as an annotated heatmap. Per-class accuracy is
  the diagonal divided by the row totals; classes with no samples are shown
  as 0.
  """
  import seaborn as sns

  # Evaluate on the loader that was passed in.
  cm = torch.tensor(validate(model, dataloader)[1])

  sns.heatmap(cm.numpy(), annot=True, fmt='d', cmap='Blues')
  # Title and axis labels are set after drawing so the heatmap call cannot
  # replace them with its own axis configuration.
  plt.title("Confusion Matrix visualized")
  plt.xlabel("Predicted Class")
  plt.ylabel("True Class")
  plt.show()

  # Per class accuracy
  row_totals = cm.sum(dim=1).clamp(min=1)  # avoid 0/0 for classes without samples
  per_class_accuracy = cm.diag() / row_totals
  plt.title("Per class accuracy")
  plt.bar(np.arange(cm.shape[0]), per_class_accuracy.numpy())
  plt.show()

In [None]:
# Training output visualization
%%time

print("CLEAN STATS")
visualize_outputs(clean_model, clean_dataloader)
visualize_outputs(clean_model, test_loader)

print("NOISY STATS")
visualize_outputs(noisy_model, noisy_train_loader)
visualize_outputs(noisy_model, test_loader)