# Lab 3: Transformers

In this lab we apply the Transformer architecture to two tasks:
1. **Sequence-to-Sequence** modelling (reversing a sequence)
2. **Set Anomaly Detection** (finding the odd image out)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

DATASET_PATH = os.environ.get("DATASET_PATH", "./data")

---
## Part 1: Sequence to Sequence

Given a sequence of $N$ integers in $\{0, \dots, M-1\}$, the task is to **reverse** it. In NumPy notation, if the input is `x`, the output should be `x[::-1]`.
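A quick check of the target definition: the `torch.flip` call used by `ReverseDataset` below produces the same sequence as NumPy's `[::-1]`. The input values are arbitrary.

```python
import numpy as np
import torch

x = np.array([3, 7, 0, 2, 9])                            # arbitrary input, M = 10 categories
target_np = x[::-1]                                      # NumPy reversal
target_pt = torch.flip(torch.from_numpy(x), dims=(0,))   # what ReverseDataset computes

print(target_np)           # [9 2 0 7 3]
print(target_pt.tolist())  # [9, 2, 0, 7, 3]
```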

We'll use only a Transformer encoder for this simple task.

### The data

In [None]:
class ReverseDataset(data.Dataset):

    def __init__(self, num_categories, seq_len, size):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size
        self.data = torch.randint(self.num_categories, size=(self.size, self.seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        inp_data = self.data[idx]
        labels = torch.flip(inp_data, dims=(0,))
        return inp_data, labels

In [None]:
dataset = partial(ReverseDataset, 10, 16)
train_loader = data.DataLoader(dataset(50000), batch_size=128, shuffle=True, drop_last=True, pin_memory=True)
val_loader   = data.DataLoader(dataset(1000),  batch_size=128)
test_loader  = data.DataLoader(dataset(10000), batch_size=128)

In [None]:
# Look at a sample
inp_data, labels = train_loader.dataset[0]
print("Input data:", inp_data)
print("Labels:    ", labels)

### Exercise 1.1: Build the model

Create a `ReversePredictor` model:
- Use `nn.Embedding` to convert each input number into an embedding vector
- Add **positional encoding** so the model knows about the order of the sequence
- Pass through one or more `nn.TransformerEncoderLayer` blocks
- Predict the output for each position using a linear layer on top
- Use Cross-Entropy loss

Hint: `nn.TransformerEncoder` wraps multiple `nn.TransformerEncoderLayer` blocks. Start with 1 layer and 1 attention head.
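One possible sketch of such a model, assuming batch-first tensors and the sinusoidal positional encoding from the original Transformer paper (a learned embedding over positions would also work). The class names, `d_model=32` and the feed-forward width are illustrative choices, not prescribed by the lab.

```python
import math
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the embeddings (assumes even d_model)."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)   # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dims: cosine
        # Buffer: moved by .to(device) but not trained
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):  # x: [batch, seq_len, d_model]
        return x + self.pe[:, : x.size(1)]


class ReversePredictor(nn.Module):
    def __init__(self, num_classes, d_model=32, num_heads=1, num_layers=1, dropout=0.0):
        super().__init__()
        self.embed = nn.Embedding(num_classes, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, num_heads, dim_feedforward=2 * d_model, dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.output_net = nn.Linear(d_model, num_classes)

    def forward(self, x):  # x: [batch, seq_len] integer tokens
        h = self.pos_enc(self.embed(x))
        h = self.encoder(h)
        return self.output_net(h)  # [batch, seq_len, num_classes] logits
```

The positional encoding is what makes this task solvable at all: without it, self-attention cannot tell which position an element came from.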

### Exercise 1.2: Train the model

Write a training loop and train the model. Tips:
- A single encoder block and single attention head should be enough
- Try gradient clipping (`torch.nn.utils.clip_grad_norm_`) to stabilize training
- Test your model on the test set — it should reach near-perfect accuracy
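A minimal training and evaluation sketch, assuming the model returns logits of shape `[batch, seq_len, num_categories]`. The function names and the clipping threshold `max_grad_norm=1.0` are illustrative; every sequence position is treated as an independent classification example for the cross-entropy loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_epoch(model, loader, optimizer, device="cpu", max_grad_norm=1.0):
    """One epoch of per-token cross-entropy training; returns the mean batch loss."""
    model.train()
    total_loss, num_batches = 0.0, 0
    for inp, labels in loader:
        inp, labels = inp.to(device), labels.to(device)
        logits = model(inp)  # [B, L, C]
        # Flatten so that each position is one classification example
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Fraction of positions whose predicted token matches the reversed target."""
    model.eval()
    correct, total = 0, 0
    for inp, labels in loader:
        inp, labels = inp.to(device), labels.to(device)
        preds = model(inp).argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total
```

Here `model` is any module mapping `[B, L]` tokens to `[B, L, C]` logits, such as your `ReversePredictor` from Exercise 1.1; call `train_epoch(model, train_loader, optimizer, device)` once per epoch and monitor `evaluate(model, val_loader, device)`.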

### Exercise 1.3: Visualize the attention

Visualize the attention weights from the Multi-Head Attention block for an arbitrary input.

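Hint: `nn.TransformerEncoderLayer` calls its internal `self_attn` module with `need_weights=False`, so the weights are never returned and hooks on `self_attn` will not give them to you. Instead, call `layer.self_attn(x, x, x, need_weights=True)` yourself on the layer's input, or write your own encoder block that returns them. The attention pattern should show that position $i$ (0-based) attends to position $N-1-i$, the reversed position.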
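A sketch of that approach for a batch-first `nn.TransformerEncoder`, assuming no attention masks and PyTorch 1.11 or newer (for `average_attn_weights=False`, which keeps one map per head). The helper name `get_attention_maps` is hypothetical. It recomputes each layer's self-attention on that layer's input and then runs the layer to get the input for the next one.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def get_attention_maps(encoder, x):
    """Per-layer attention weights of a batch-first nn.TransformerEncoder (hypothetical helper).

    x is the encoder input (embeddings + positional encoding), shape [B, L, d_model].
    Returns a list with one tensor of shape [B, num_heads, L, L] per layer.
    """
    was_training = encoder.training
    encoder.eval()  # disables dropout on the attention weights
    maps = []
    for layer in encoder.layers:
        # Post-norm layers (the default) feed x directly into self_attn; pre-norm layers normalize first
        attn_in = layer.norm1(x) if layer.norm_first else x
        _, weights = layer.self_attn(attn_in, attn_in, attn_in,
                                     need_weights=True, average_attn_weights=False)
        maps.append(weights)
        x = layer(x)  # input for the next layer
    encoder.train(was_training)  # restore the original mode
    return maps
```

For the reverse model, build `x` the same way your forward pass does (embedding followed by positional encoding) and plot e.g. `maps[0][0, 0]` with `plt.imshow`.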

---
## Part 2: Set Anomaly Detection

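Transformers are well suited to **set** problems because Multi-Head Attention is permutation-equivariant: permuting the input elements permutes the outputs in the same way, provided no positional encoding is added.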
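A small numerical check of this property, using `nn.MultiheadAttention` with random weights (the sizes are arbitrary): permuting the set and then applying attention gives the same result as applying attention and then permuting.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()

x = torch.randn(1, 5, 8)  # a "set" of 5 elements, no positional encoding
perm = torch.randperm(5)

out, _ = mha(x, x, x)
out_perm, _ = mha(x[:, perm], x[:, perm], x[:, perm])

# Permuting the inputs permutes the outputs in exactly the same way
is_equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)
print(is_equivariant)  # True
```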

**Task:** Given a set of 10 images where 9 belong to the same class and 1 does not, identify the anomaly.

We use CIFAR-100 (100 classes of 32×32 images, with 500 training and 100 test images per class). We first extract features with a ResNet34 pre-trained on ImageNet, then feed these features into a Transformer.

### The data

In [None]:
# ImageNet normalization statistics
DATA_MEANS = np.array([0.485, 0.456, 0.406])
DATA_STD = np.array([0.229, 0.224, 0.225])
TORCH_DATA_MEANS = torch.from_numpy(DATA_MEANS).view(1, 3, 1, 1)
TORCH_DATA_STD = torch.from_numpy(DATA_STD).view(1, 3, 1, 1)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(DATA_MEANS, DATA_STD)
])

train_set = CIFAR100(root=DATASET_PATH, train=True,  transform=transform, download=True)
test_set  = CIFAR100(root=DATASET_PATH, train=False, transform=transform, download=True)

### Exercise 2.1: Extract features with a pre-trained ResNet

Load a pre-trained ResNet34 from `torchvision.models` and extract features (the output before the final classification layer) for all images. Store these features so you don't need to recompute them.

Hint: Remove the last FC layer or use a hook to capture the features from the penultimate layer.
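A sketch of the "remove the last FC layer" option, assuming torchvision 0.13 or newer for the `weights=` API (older versions use `pretrained=True`). The helper names are hypothetical. The loader must not shuffle, so that row `i` of the feature tensor matches `train_set.targets[i]`.

```python
import torch
import torch.nn as nn
import torch.utils.data as data


def build_feature_extractor(pretrained=True):
    """ResNet34 with its final FC layer replaced by an identity (512-dim outputs)."""
    import torchvision  # imported here so extract_features works with any backbone
    weights = torchvision.models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None
    model = torchvision.models.resnet34(weights=weights)
    model.fc = nn.Identity()  # keep the globally pooled features, drop the classifier
    return model.eval()       # eval mode: BatchNorm uses its running statistics


@torch.no_grad()
def extract_features(model, dataset, batch_size=256, device="cpu"):
    """Run `model` over `dataset` in order and return a [len(dataset), feat_dim] tensor."""
    model = model.to(device).eval()
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    feats = [model(imgs.to(device)).cpu() for imgs, _ in loader]
    return torch.cat(feats, dim=0)
```

With `resnet = build_feature_extractor()`, `extract_features(resnet, train_set, device=device)` gives the `train_set_feats` used below, and the same call on `test_set` gives `test_feats`; caching both with `torch.save` avoids recomputing them.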

### Setting up train/val/test splits and the anomaly dataset

We split the original training set 90/10 into train/val in a class-balanced way (10% of the images of each class go to validation), then build the anomaly detection datasets.

In [None]:
# Split train into train + val (90/10 balanced)
labels = torch.LongTensor(train_set.targets)
num_labels = labels.max() + 1
sorted_indices = torch.argsort(labels).reshape(num_labels, -1)

num_val_exmps = sorted_indices.shape[1] // 10
val_indices   = sorted_indices[:, :num_val_exmps].reshape(-1)
train_indices = sorted_indices[:, num_val_exmps:].reshape(-1)

# Replace train_set_feats below with your extracted features
# train_feats, train_labels = train_set_feats[train_indices], labels[train_indices]
# val_feats,   val_labels   = train_set_feats[val_indices],   labels[val_indices]

In [None]:
class SetAnomalyDataset(data.Dataset):

    def __init__(self, img_feats, labels, set_size=10, train=True):
        """
        img_feats - Tensor [num_imgs, feat_dim]: high-level features from ResNet
        labels    - Tensor [num_imgs]: class labels
        set_size  - Number of elements in a set (N-1 same class + 1 anomaly)
        train     - If True, sample a new set each time __getitem__ is called
        """
        super().__init__()
        self.img_feats = img_feats
        self.labels = labels
        self.set_size = set_size - 1  # number of same-class images; the anomaly is appended in __getitem__
        self.train = train

        self.num_labels = labels.max() + 1
        self.img_idx_by_label = torch.argsort(self.labels).reshape(self.num_labels, -1)

        if not train:
            self.test_sets = self._create_test_sets()

    def _create_test_sets(self):
        np.random.seed(42)
        test_sets = [self.sample_img_set(self.labels[idx]) for idx in range(len(self.img_feats))]
        return torch.stack(test_sets, dim=0)

    def sample_img_set(self, anomaly_label):
        set_label = np.random.randint(self.num_labels - 1)
        if set_label >= anomaly_label:
            set_label += 1
        img_indices = np.random.choice(self.img_idx_by_label.shape[1], size=self.set_size, replace=False)
        img_indices = self.img_idx_by_label[set_label, img_indices]
        return img_indices

    def __len__(self):
        return self.img_feats.shape[0]

    def __getitem__(self, idx):
        anomaly = self.img_feats[idx]
        if self.train:
            img_indices = self.sample_img_set(self.labels[idx])
        else:
            img_indices = self.test_sets[idx]
        # Anomaly is always the last image
        img_set = torch.cat([self.img_feats[img_indices], anomaly[None]], dim=0)
        indices = torch.cat([img_indices, torch.LongTensor([idx])], dim=0)
        label = img_set.shape[0] - 1
        return img_set, indices, label

In [None]:
# Create data loaders (uncomment after extracting features)
SET_SIZE = 10
test_labels = torch.LongTensor(test_set.targets)

# train_anom_dataset = SetAnomalyDataset(train_feats, train_labels, set_size=SET_SIZE, train=True)
# val_anom_dataset   = SetAnomalyDataset(val_feats,   val_labels,   set_size=SET_SIZE, train=False)
# test_anom_dataset  = SetAnomalyDataset(test_feats,  test_labels,  set_size=SET_SIZE, train=False)

# train_anom_loader = data.DataLoader(train_anom_dataset, batch_size=64, shuffle=True,  drop_last=True, pin_memory=True)
# val_anom_loader   = data.DataLoader(val_anom_dataset,   batch_size=64, shuffle=False)
# test_anom_loader  = data.DataLoader(test_anom_dataset,  batch_size=64, shuffle=False)

In [None]:
def visualize_exmp(indices, orig_dataset):
    images = [orig_dataset[idx][0] for idx in indices.reshape(-1)]
    images = torch.stack(images, dim=0)
    images = images * TORCH_DATA_STD + TORCH_DATA_MEANS
    img_grid = torchvision.utils.make_grid(images, nrow=SET_SIZE, normalize=True, pad_value=0.5, padding=16)
    plt.figure(figsize=(12, 8))
    plt.title("Anomaly examples on CIFAR100 (last image = anomaly)")
    plt.imshow(img_grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

# Uncomment after creating test_anom_loader:
# _, indices, _ = next(iter(test_anom_loader))
# visualize_exmp(indices[:4], test_set)

### Exercise 2.2: Build the anomaly detection model

Write a Transformer-based model that takes a **set** of image features and outputs one logit per image. Apply a softmax over the logits of each set and train with cross-entropy so that the anomaly receives the highest probability.

Since the prediction must be permutation-equivariant, a Transformer without positional encoding is a natural choice. Each set element is a ResNet feature vector (512-dim for ResNet34).
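One possible sketch of such a model: a linear input projection, an `nn.TransformerEncoder` without positional encoding, and a linear head producing one logit per element. The class name and all hyperparameters (`d_model=256`, 4 heads, 4 layers, dropout 0.1) are illustrative assumptions.

```python
import torch
import torch.nn as nn


class AnomalyPredictor(nn.Module):
    """Maps a set of feature vectors [B, N, input_dim] to one logit per element [B, N].

    No positional encoding is used, so permuting the set permutes the logits.
    """

    def __init__(self, input_dim=512, d_model=256, num_heads=4, num_layers=4, dropout=0.1):
        super().__init__()
        self.input_net = nn.Sequential(nn.Dropout(dropout), nn.Linear(input_dim, d_model))
        layer = nn.TransformerEncoderLayer(
            d_model, num_heads, dim_feedforward=2 * d_model, dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.output_net = nn.Linear(d_model, 1)

    def forward(self, x):
        h = self.encoder(self.input_net(x))
        return self.output_net(h).squeeze(-1)  # [B, N] logits
```

Leaving out positional encoding matters here beyond elegance: `SetAnomalyDataset` always places the anomaly last, so a position-aware model could learn "always predict the last element" instead of comparing images.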

### Exercise 2.3: Train the model

Write a training loop (same structure as for the reverse task). Train your model and evaluate on the test set.
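In the training loop the loss can be `F.cross_entropy(model(img_sets), labels)`, since each set's label is the index of its anomaly. For evaluation, a sketch of a set-level accuracy helper (the name `anomaly_accuracy` is hypothetical) that counts how often the highest-scoring element is the anomaly:

```python
import torch


@torch.no_grad()
def anomaly_accuracy(model, loader, device="cpu"):
    """Fraction of sets where the highest-scoring element is the labelled anomaly."""
    model.eval()
    correct, total = 0, 0
    for img_sets, _, labels in loader:  # SetAnomalyDataset yields (img_set, indices, label)
        logits = model(img_sets.to(device))  # [B, set_size]
        preds = logits.argmax(dim=-1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total
```

Random guessing gives an accuracy of $1/10$ with `SET_SIZE = 10`, which is a useful baseline when reading the numbers.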

### Exercise 2.4: Visualize the attention

Plot the images in the input set, the model's prediction, and the attention maps from different heads/layers. This helps interpret what information the model shares between images.

Explore different input examples — are there cases where the task is harder (e.g. visually similar classes)?
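For the attention plots, a sketch of a helper (hypothetical name `plot_attention_heads`) that draws one map per head for a single set. It assumes you already have weights of shape `[num_heads, N, N]` for one layer, e.g. obtained by calling that layer's `self_attn` with `need_weights=True, average_attn_weights=False` and selecting one batch element.

```python
import matplotlib.pyplot as plt
import torch


def plot_attention_heads(attn, title="Attention"):
    """Plot one [N, N] attention map per head; attn has shape [num_heads, N, N]."""
    attn = attn.detach().cpu()
    num_heads = attn.shape[0]
    fig, axes = plt.subplots(1, num_heads, figsize=(3 * num_heads, 3), squeeze=False)
    for h in range(num_heads):
        ax = axes[0, h]
        ax.imshow(attn[h].numpy(), vmin=0, vmax=attn[h].max().item(), cmap="viridis")
        ax.set_title(f"{title}, head {h}")
        ax.set_xlabel("key (attended element)")
        ax.set_ylabel("query element")
    fig.tight_layout()
    return fig
```

Rows are queries and columns are keys, so a bright last column means that many elements attend to the anomaly (the last image in each set).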