<a href="https://colab.research.google.com/github/KTK-Jadoo/pytorch-practice/blob/main/%5BFall_2024%5D_HW_3A_Vision_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ViT Assignment
Authors: Alexander Wan, Aryan Jain

### Assignment Goals


1. Familiarity with the Vision Transformer architecture
2. Familiarity with the self-attention algorithm
3. Practice with PyTorch matrix operations



### Tasks
1. Implement multi-head self-attention
2. Incorporate that into a ViT

### Runtime Acceleration
Colab limits GPU usage, so set `device` below as `'cpu'` and change your runtime to CPU as well (Runtime > Change runtime type) when you're developing, and only change it to `'cuda'` (and your runtime to GPU) when you're ready to train.

In [5]:
#device = 'cpu'
device = 'cuda'

### Multi-head self-attention
Begin by implementing multiheaded self-attention. Do **not** use any `for` loops, and instead put all of the calculations into [batch matrix multiplications](https://pytorch.org/docs/stable/generated/torch.bmm.html) or [Linear layers](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).

Useful references include the lecture slides on transformers and ViTs, and the [illustrated transformer](https://jalammar.github.io/illustrated-transformer/) blog post.

Hint: you are not required to use the exact skeleton code below. Feel free to use `torch.einsum` if you prefer it (this is something you will have to figure out from the PyTorch documentation yourself; this function is somewhat non-intuitive at first but it's extremely powerful once you truly understand how it works!).


In [2]:
import torch.nn.functional as F
from torch import nn
import torch

class MSA(nn.Module):
  def __init__(self, input_dim, embed_dim, num_heads):
    '''
    input_dim: Dimension of input token embeddings
    embed_dim: Dimension of internal key, query, and value embeddings
    num_heads: Number of self-attention heads
    '''

    super().__init__()

    self.input_dim = input_dim
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    self.K_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.Q_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.V_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.out_embed = nn.Linear(embed_dim, embed_dim, bias=False)

  def forward(self, x):
    '''
    x: input of shape (batch_size, max_length, input_dim)
    return: output of shape (batch_size, max_length, embed_dim)
    '''

    batch_size, max_length, given_input_dim = x.shape
    assert given_input_dim == self.input_dim
    assert max_length % self.num_heads == 0

    # You shouldn't need to initialize any new modules. Everything you need is
    # already in __init__

    # HINT: If you're stuck on how to handle multiple heads without for loops, try to
    # reshape matrix such that the batch_size is num_heads * batch_size
    # e.g. if you have two heads, you'd be doing self-attention twice per instance
    # in the batch, so you essentially have batch_size * 2

    # HINT 2: Feel free to reference: https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
    # although make sure you understand what each command does

    # this implementation projects KQV before splitting into multiple heads
    # but you can also split into multiple heads first

    # compute KQV as a whole, embedding and
    x = x.reshape(batch_size * max_length, -1)
    K = self.K_embed(x).reshape(batch_size, max_length, self.embed_dim) # (batch_size, max_length, embed_dim)
    Q = self.Q_embed(x).reshape(batch_size, max_length, self.embed_dim)
    V = self.V_embed(x).reshape(batch_size, max_length, self.embed_dim)


    # TODO: split each KQV into heads, by reshaping each into (batch_size, max_length, self.num_heads, indiv_dim)
    indiv_dim = self.embed_dim // self.num_heads
    K = K.view(batch_size, max_length, self.num_heads, indiv_dim)
    Q = Q.view(batch_size, max_length, self.num_heads, indiv_dim)
    V = V.view(batch_size, max_length, self.num_heads, indiv_dim)

    K = K.permute(0, 2, 1, 3) # (batch_size, num_heads, max_length, embed_dim / num_heads)
    Q = Q.permute(0, 2, 1, 3)
    V = V.permute(0, 2, 1, 3)

    K = K.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    Q = Q.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    V = V.reshape(batch_size * self.num_heads, max_length, indiv_dim)

    # transpose and batch matrix multiply
    # This is our K transposed so we can do a simple batched matrix multiplication (see torch.bmm for more details and the quick solution)
    K_T = K.permute(0, 2, 1)
    # TODO: Compute the weights before dividing by square root of d (batch_size * num_heads, max_length, max_length)
    QK = torch.bmm(Q, K_T)

    # calculate weights by dividing everything by the square root of d (self.embed_dim)
    # Scale and softmax
    weights = QK / torch.sqrt(torch.tensor(indiv_dim, dtype=torch.float32, device=x.device))
    weights = F.softmax(weights, dim=-1) # TODO Take the softmax over the last dimension (see torch.functional.Softmax) (batch_size * num_heads, max_length, max_length)

    # TODO get weighted average... see torch.bmm for a one line solution
    # weights is (batch_size * num_heads, max_length, max_length) and V is (batch_size * self.num_heads, max_length, indiv_dim)
    # so we want the matrix multiplication of weights and V
    w_V = torch.bmm(weights, V)

    # rejoin heads
    w_V = w_V.reshape(batch_size, self.num_heads, max_length, indiv_dim)
    w_V = w_V.permute(0, 2, 1, 3) # (batch_size, max_length, num_heads, embed_dim / num_heads)
    w_V = w_V.reshape(batch_size, max_length, self.embed_dim)

    out = self.out_embed(w_V)

    return out

### Implement the ViT architecture
You will be implementing the ViT architecture based on the "An image is worth 16x16 words" paper.

Although the ViT and Transformer architecture are very similar, note a few differences:

1. Image patches instead of discrete tokens as input.
2. [GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used for the linear layers in the transformer layer (instead of ReLU)
3. LayerNorm before the sublayer instead of after.
4. Dropout after every linear layer except for KQV projections and also directly after adding positional embeddings to the patch embeddings.
5. Learnable [CLS] token at the beginning of the input.

A useful reference is Figure 1 in the [paper](https://arxiv.org/pdf/2010.11929.pdf).

First, implement a single layer:

In [3]:
class ViTLayer(nn.Module):
  def __init__(self, num_heads, input_dim, embed_dim, mlp_hidden_dim, dropout=0.1):
    '''
    num_heads: Number of heads for multi-head self-attention
    embed_dim: Dimension of internal key, query, and value embeddings
    mlp_hidden_dim: Hidden dimension of the linear layer
    dropout: Dropout rate
    '''

    super().__init__()

    self.input_dim = input_dim
    self.msa = MSA(input_dim, embed_dim, num_heads)

    self.layernorm1 = nn.LayerNorm(embed_dim)
    self.w_o_dropout = nn.Dropout(dropout)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),
                              nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(mlp_hidden_dim, embed_dim),
                              nn.Dropout(dropout))

  def forward(self, x):
    '''
    x: input embeddings (batch_size, max_length, input_dim)
    return: output embeddings (batch_size, max_length, embed_dim)
    '''

    # Step 1: LayerNorm of x
    norm_x1 = self.layernorm1(x)  # (batch_size, max_length, embed_dim)

    # Step 2: Self-Attention on normalized input
    attn_out = self.msa(norm_x1)  # (batch_size, max_length, embed_dim)

    # Step 3: Dropout on attention output
    attn_out = self.w_o_dropout(attn_out)

    # Step 4: Residual connection with original x
    res_out1 = x + attn_out  # (batch_size, max_length, embed_dim)

    # Step 5: LayerNorm of residual output
    norm_x2 = self.layernorm2(res_out1)  # (batch_size, max_length, embed_dim)

    # Step 6: MLP block
    mlp_out = self.mlp(norm_x2)  # (batch_size, max_length, embed_dim)

    # Step 7: Residual connection with res_out1
    out = res_out1 + mlp_out  # (batch_size, max_length, embed_dim)

    return out


A portion of the full network is already implemented for you. Your task is to implement the preprocessing code, converting raw images into patch embeddings + positional embeddings + dropout, with a learnable CLS token at the beginning of the input.

Note that patch embeddings are to be added to positional embeddings elementwise, so the input embedding dimensions is size embed_dim.

In [4]:
class ViT(nn.Module):
    def __init__(self, patch_dim, image_dim, num_layers, num_heads, embed_dim, mlp_hidden_dim, num_classes, dropout):
        '''
        patch_dim: patch length and width to split image by
        image_dim: image length and width
        num_layers: number of layers in network
        num_heads: number of heads for multi-head attention
        embed_dim: dimension to project images patches to and dimension to use for position embeddings
        mlp_hidden_dim: hidden dimension of linear layer
        num_classes: number of classes to classify in data
        dropout: dropout rate
        '''

        super().__init__()
        self.num_layers = num_layers
        self.patch_dim = patch_dim
        self.image_dim = image_dim
        self.input_dim = self.patch_dim * self.patch_dim * 3
        self.num_heads = num_heads

        self.patch_embedding = nn.Linear(self.input_dim, embed_dim)
        self.position_embedding = nn.Parameter(torch.zeros(1, (image_dim // patch_dim) ** 2 + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.embedding_dropout = nn.Dropout(dropout)

        self.encoder_layers = nn.ModuleList([])
        for i in range(num_layers):
            self.encoder_layers.append(ViTLayer(num_heads, embed_dim, embed_dim, mlp_hidden_dim, dropout))

        self.mlp_head = nn.Linear(embed_dim, num_classes)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, images):
        '''
        images: raw image data (batch_size, channels, rows, cols)
        '''

        # Don't hardcode dimensions (except for maybe channels = 3), use the variables in __init__.
        # You shouldn't need to add anything else to __init__, all of the embeddings,
        # dropout etc. are already initialized for you.

        # Put the preprocessed patches in variable "out" with shape (batch_size, length, embed_dim).

        # HINT: You can make image patches with .reshape
        # e.g.
        # x = torch.ones((100, 100))
        # x_patches = x.reshape(4, 25, 4, 25)
        # where you have 4 * 4 patches with each patch being 25 by 25

        h = w = self.image_dim // self.patch_dim
        N = images.size(0)
        images = images.reshape(N, 3, h, self.patch_dim, w, self.patch_dim)
        images = torch.einsum("nchpwq -> nhwpqc", images)
        patches = images.reshape(N, h * w, self.input_dim) # (batch, num_patches_per_image, patch_size_unrolled)

        patch_embeddings = self.patch_embedding(patches)
        patch_embeddings = torch.cat([torch.tile(self.cls_token, (N, 1, 1)), patch_embeddings], dim=1)
        out = patch_embeddings + torch.tile(self.position_embedding, (N, 1, 1)) # We add positional embeddings to our tokens (not concatenated)
        out = self.embedding_dropout(patch_embeddings)

        # add padding s.t. input length is multiple of num_heads
        add_len = (self.num_heads - out.shape[1]) % self.num_heads
        out = torch.cat([out, torch.zeros(N, add_len, out.shape[2], device=device)], dim=1)

        # TODO: Pass through each one of our encoder layers
        # Add padding so input length is a multiple of num_heads
        add_len = (self.num_heads - out.shape[1] % self.num_heads) % self.num_heads
        if add_len > 0:
            padding = torch.zeros(N, add_len, out.shape[2], device=images.device)
            out = torch.cat([out, padding], dim=1)  # Add padding

        # Pass through encoder layers
        for layer in self.encoder_layers:
            out = layer(out)  # Each layer modifies out

        # Pop off and read our classification token we added, see what the value is
        cls_head = self.layernorm(torch.squeeze(out[:, 0], dim=1))
        logits = self.mlp_head(cls_head)
        return logits

def get_vit_tiny(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=3,
               embed_dim=192, mlp_hidden_dim=768, num_classes=num_classes, dropout=0.1)

def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=6,
               embed_dim=384, mlp_hidden_dim=1536, num_classes=num_classes, dropout=0.1)

def get_vit_base(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=12,
               embed_dim=768, mlp_hidden_dim=3072, num_classes=num_classes, dropout=0.1)

Now let's train the model! You don't need to write any code for this - just run the cell.

Remember to change the device variable (in the cell at the beginning of the notebook) to 'cuda' and change your runtime to GPU (Runtime > Change runtime type) as well. For reference, each epoch in the staff solution takes ~3 minutes (so training for 30 epochs will take ~1.5 hours on the Colab GPU; we know this is a long training session)

Try to get 65%+ accuracy after 30 epochs.

In [7]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision
import math
import torch.optim as optim
from tqdm.notebook import tqdm

data_root = './data/cifar10'
train_size = 40000
val_size = 10000

batch_size = 32

transform_train = T.Compose([
    T.Resize(40),
    T.RandomCrop(32),
    T.RandomHorizontalFlip(),
    T.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.95, 1.05)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

transform_val = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform_train,
)

val_dataset = datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform_val,
)

from torch.utils.data import sampler

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          sampler=sampler.SubsetRandomSampler(range(train_size)))

val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        sampler=sampler.SubsetRandomSampler(range(train_size, 50000)))

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

vit = get_vit_small().to(device)

learning_rate = 5e-4 * batch_size / 256
num_epochs = 30
weight_decay = 0.1

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=weight_decay)

train_losses = []
val_losses = []
val_accuracies = []
for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    train_total = 0
    vit.train()
    for inputs, labels in tqdm(train_loader):
        # 1. Set inputs and labels to the device
        inputs = inputs.to(device)
        labels = labels.to(device)

        # 2. Zero out gradients
        optimizer.zero_grad()

        # 3. Forward pass through the ViT model
        outputs = vit(inputs)

        # 4. Compute the loss
        loss = criterion(outputs, labels.long())

        # 5. Backpropagate
        loss.backward()

        # 6. Perform an optimizer step
        optimizer.step()

        train_loss += loss.item() * inputs.shape[0]
        train_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
        train_total += inputs.shape[0]
    train_loss = train_loss / train_total
    train_acc = train_acc / train_total
    train_losses.append(train_loss)

    val_loss = 0.0
    val_acc = 0.0
    val_total = 0
    vit.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = vit(inputs)
            loss = criterion(outputs, labels.long())

            val_loss += loss.item() * inputs.shape[0]
            val_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
            val_total += inputs.shape[0]
    val_loss = val_loss / val_total
    val_acc = val_acc / val_total
    val_losses.append(val_loss)

    val_accuracies.append(val_acc)
    if val_acc >= max(val_accuracies):
        torch.save(vit.state_dict(), 'best_model.pth')

    print(f'[{epoch + 1:2d}] train loss: {train_loss:.3f} | train accuracy: {train_acc:.3f} | val loss: {val_loss:.3f} | val accuracy: {val_acc:.3f}')

print('Finished Training')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 44.6MB/s]


Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10
Files already downloaded and verified


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 1] train loss: 1.846 | train accuracy: 0.298 | val loss: 1.660 | val accuracy: 0.391


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 2] train loss: 1.579 | train accuracy: 0.414 | val loss: 1.669 | val accuracy: 0.411


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 3] train loss: 1.451 | train accuracy: 0.472 | val loss: 1.491 | val accuracy: 0.476


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 4] train loss: 1.368 | train accuracy: 0.505 | val loss: 1.447 | val accuracy: 0.482


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 5] train loss: 1.318 | train accuracy: 0.524 | val loss: 1.454 | val accuracy: 0.493


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 6] train loss: 1.268 | train accuracy: 0.544 | val loss: 1.480 | val accuracy: 0.497


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 7] train loss: 1.236 | train accuracy: 0.556 | val loss: 1.474 | val accuracy: 0.486


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 8] train loss: 1.197 | train accuracy: 0.570 | val loss: 1.475 | val accuracy: 0.501


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 9] train loss: 1.163 | train accuracy: 0.582 | val loss: 1.286 | val accuracy: 0.543


  0%|          | 0/1250 [00:00<?, ?it/s]

[10] train loss: 1.136 | train accuracy: 0.593 | val loss: 1.300 | val accuracy: 0.548


  0%|          | 0/1250 [00:00<?, ?it/s]

[11] train loss: 1.108 | train accuracy: 0.604 | val loss: 1.308 | val accuracy: 0.551


  0%|          | 0/1250 [00:00<?, ?it/s]

[12] train loss: 1.085 | train accuracy: 0.611 | val loss: 1.414 | val accuracy: 0.527


  0%|          | 0/1250 [00:00<?, ?it/s]

[13] train loss: 1.071 | train accuracy: 0.618 | val loss: 1.312 | val accuracy: 0.553


  0%|          | 0/1250 [00:00<?, ?it/s]

[14] train loss: 1.050 | train accuracy: 0.626 | val loss: 1.390 | val accuracy: 0.539


  0%|          | 0/1250 [00:00<?, ?it/s]

[15] train loss: 1.028 | train accuracy: 0.635 | val loss: 1.273 | val accuracy: 0.567


  0%|          | 0/1250 [00:00<?, ?it/s]

[16] train loss: 1.022 | train accuracy: 0.633 | val loss: 1.389 | val accuracy: 0.552


  0%|          | 0/1250 [00:00<?, ?it/s]

[17] train loss: 0.994 | train accuracy: 0.647 | val loss: 1.204 | val accuracy: 0.590


  0%|          | 0/1250 [00:00<?, ?it/s]

[18] train loss: 0.982 | train accuracy: 0.651 | val loss: 1.395 | val accuracy: 0.545


  0%|          | 0/1250 [00:00<?, ?it/s]

[19] train loss: 0.975 | train accuracy: 0.653 | val loss: 1.288 | val accuracy: 0.567


  0%|          | 0/1250 [00:00<?, ?it/s]

[20] train loss: 0.959 | train accuracy: 0.660 | val loss: 1.307 | val accuracy: 0.563


  0%|          | 0/1250 [00:00<?, ?it/s]

[21] train loss: 0.942 | train accuracy: 0.661 | val loss: 1.188 | val accuracy: 0.601


  0%|          | 0/1250 [00:00<?, ?it/s]

[22] train loss: 0.922 | train accuracy: 0.675 | val loss: 1.270 | val accuracy: 0.582


  0%|          | 0/1250 [00:00<?, ?it/s]

[23] train loss: 0.914 | train accuracy: 0.676 | val loss: 1.257 | val accuracy: 0.573


  0%|          | 0/1250 [00:00<?, ?it/s]

[24] train loss: 0.904 | train accuracy: 0.676 | val loss: 1.447 | val accuracy: 0.554


  0%|          | 0/1250 [00:00<?, ?it/s]

[25] train loss: 0.894 | train accuracy: 0.684 | val loss: 1.351 | val accuracy: 0.572


  0%|          | 0/1250 [00:00<?, ?it/s]

[26] train loss: 0.882 | train accuracy: 0.688 | val loss: 1.307 | val accuracy: 0.570


  0%|          | 0/1250 [00:00<?, ?it/s]

[27] train loss: 0.873 | train accuracy: 0.690 | val loss: 1.239 | val accuracy: 0.602


  0%|          | 0/1250 [00:00<?, ?it/s]

[28] train loss: 0.866 | train accuracy: 0.691 | val loss: 1.333 | val accuracy: 0.578


  0%|          | 0/1250 [00:00<?, ?it/s]

[29] train loss: 0.856 | train accuracy: 0.698 | val loss: 1.215 | val accuracy: 0.596


  0%|          | 0/1250 [00:00<?, ?it/s]

[30] train loss: 0.846 | train accuracy: 0.699 | val loss: 1.323 | val accuracy: 0.581
Finished Training


### Autograder and Submission

After you feel confident that you have a decent model, run the cell below.

Feel free to read the code block but **PLEASE DO NOT TOUCH IT**: this will produce a pickle file that will contain your model's predictions on the CIFAR-10 validation set --- tampering with the code block below might mess up the file that you will submit to the Gradescope autograder.

In [9]:
import pickle

cifar_test = datasets.CIFAR10('./data/cifar10_test', download = True, train = False, transform = transform_val)
loader_test = DataLoader(cifar_test, batch_size=32, shuffle=False)

vit.load_state_dict(torch.load('best_model.pth'))
vit.eval()  # set model to evaluation mode
predictions = []
with torch.no_grad():
    for x, _ in loader_test:
        x = x.to(device=device)  # move to device, e.g. GPU
        scores = vit(x)
        _, preds = scores.max(1)
        predictions.append(preds)
predictions = torch.cat(predictions).tolist()
with open("my_predictions.pickle", "wb") as file:
    pickle.dump(predictions, file)

Files already downloaded and verified


  vit.load_state_dict(torch.load('best_model.pth'))
