<a href="https://colab.research.google.com/github/AriPathak/ViT-CS198-Fall-2022-HW-4-Solution-/blob/main/CS198_HW_4_ViTs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ViT Assignment
Authors: Alexander Wan, Aryan Jain

### Assignment Goals


1. Familiarity with the Vision Transformer architecture
2. Familiarity with the self-attention algorithm
3. Practice with PyTorch matrix operations



### Tasks
1. Implement multi-head self-attention
2. Incorporate that into a ViT

### Runtime Acceleration
Colab limits GPU usage, so while you're developing, set `device` below to `'cpu'` and switch your runtime to CPU as well (Runtime > Change runtime type). Only change it to `'cuda'` (and your runtime to GPU) when you're ready to train.
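
A common alternative, which is not part of the assignment, is to derive `device` from what the current runtime exposes instead of editing the string by hand. The snippet below is a minimal sketch of that idea; it assumes only that PyTorch is installed.

```python
import torch

# Use the GPU only when the current runtime actually exposes one
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
```

With this approach, switching the Colab runtime type is the only manual step.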

In [None]:
#device = 'cpu'
device = 'cuda'

### Multi-head self-attention
Begin by implementing multi-head self-attention. Do **not** use any `for` loops; instead, express all of the calculations as [batch matrix multiplications](https://pytorch.org/docs/stable/generated/torch.bmm.html) or [Linear layers](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).

Useful references include the lecture slides and the [illustrated transformer](https://jalammar.github.io/illustrated-transformer/).
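
As a warm-up, here is a minimal sketch of scaled dot-product attention for a single head using `torch.bmm`. The function name `single_head_attention` and the toy tensor sizes are illustrative assumptions, not part of the assignment skeleton.

```python
import math
import torch
import torch.nn.functional as F

def single_head_attention(q, k, v):
    """q, k, v: [batch, seq_len, head_dim] tensors for a single head."""
    head_dim = q.shape[-1]
    # Similarity of every query with every key: [batch, seq_len, seq_len]
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)
    # Normalize over the key axis so each query's weights sum to 1
    weights = F.softmax(scores, dim=-1)
    # Weighted average of the values: [batch, seq_len, head_dim]
    return torch.bmm(weights, v), weights

torch.manual_seed(0)
q, k, v = (torch.randn(2, 5, 8) for _ in range(3))
out, weights = single_head_attention(q, k, v)
print(out.shape)      # torch.Size([2, 5, 8])
print(weights.shape)  # torch.Size([2, 5, 5])
```

Multiple heads can reuse the same computation without a loop by folding the head axis into the batch axis, or by relying on `torch.matmul`, which broadcasts over leading dimensions.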


In [None]:
import math  # needed for math.sqrt in MSA.forward

import torch.nn.functional as F
from torch import nn
import torch
import numpy as np

class MSA(nn.Module):
  def __init__(self, input_dim, embed_dim, num_heads):
      super().__init__()
      self.n_heads = num_heads
      self.embed_dim = embed_dim
      self.dim = input_dim
      self.head_dim = self.embed_dim // self.n_heads

      # Linear projections for Q, K, and V
      self.wq = nn.Linear(self.dim, self.embed_dim, bias=False)
      self.wk = nn.Linear(self.dim, self.embed_dim, bias=False)
      self.wv = nn.Linear(self.dim, self.embed_dim, bias=False)
      self.wo = nn.Linear(self.embed_dim, self.dim, bias=False)

  def forward(self, x):
      b, seq_len, dim = x.shape  # b: batch size, seq_len: sequence length

      assert dim == self.dim, "dim is not matching"

      q = self.wq(x)  # [b, seq_len, n_heads*head_dim]
      k = self.wk(x)  # [b, seq_len, n_heads*head_dim]
      v = self.wv(x)  # [b, seq_len, n_heads*head_dim]

      # Reshape the tensors for multi-head operations
      q = q.contiguous().view(b, seq_len, self.n_heads, self.head_dim)  # [b, seq_len, n_heads, head_dim]
      k = k.contiguous().view(b, seq_len, self.n_heads, self.head_dim)  # [b, seq_len, n_heads, head_dim]
      v = v.contiguous().view(b, seq_len, self.n_heads, self.head_dim)  # [b, seq_len, n_heads, head_dim]

      # Transpose to bring the head dimension to the front
      q = q.transpose(1, 2)  # [b, n_heads, seq_len, head_dim]
      k = k.transpose(1, 2)  # [b, n_heads, seq_len, head_dim]
      v = v.transpose(1, 2)  # [b, n_heads, seq_len, head_dim]


      # Compute scaled attention scores and apply softmax over the key axis
      attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)  # [b, n_heads, seq_len, seq_len]

      attn_scores = F.softmax(attn, dim=-1)  # [b, n_heads, seq_len, seq_len]

      # Compute the attended features
      out = torch.matmul(attn_scores, v)  # [b, n_heads, seq_len, head_dim]
      # Move the head axis back next to head_dim before merging, so each token keeps its own features
      out = out.transpose(1, 2).contiguous().view(b, seq_len, -1)  # [b, seq_len, n_heads*head_dim]


      return self.wo(out)  # [b, seq_len, dim]

### Implement the ViT architecture
You will be implementing the ViT architecture based on the "An image is worth 16x16 words" paper.

Although the ViT and Transformer architectures are very similar, note a few differences:

1. Image patches instead of discrete tokens as input.
2. [GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used as the activation in the MLP of each transformer layer (instead of ReLU).
3. LayerNorm before the sublayer instead of after.
4. Dropout after every linear layer except for KQV projections and also directly after adding positional embeddings to the patch embeddings.
5. Learnable [CLS] token at the beginning of the input.

A useful reference is Figure 1 in the [paper](https://arxiv.org/pdf/2010.11929.pdf).
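
To make difference 3 concrete, the sketch below contrasts the two LayerNorm placements on a toy sublayer. The dimensions and the `sublayer` module are illustrative assumptions; only the position of `norm` relative to the residual addition matters.

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 16
norm = nn.LayerNorm(dim)
sublayer = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

x = torch.randn(2, 5, dim)  # [batch, seq_len, dim]

# Pre-norm (ViT): normalize the sublayer input, leave the residual path untouched
pre_norm_out = x + sublayer(norm(x))

# Post-norm (original Transformer): normalize after the residual addition
post_norm_out = norm(x + sublayer(x))

print(pre_norm_out.shape, post_norm_out.shape)
```

In the pre-norm form the residual stream is never normalized inside the layer, which is why the full ViT applies one final LayerNorm before the classification head.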

First, implement a single layer:

In [None]:
class ViTLayer(nn.Module):
  def __init__(self, num_heads, input_dim, embed_dim, mlp_hidden_dim, dropout=0.1):
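    '''
    num_heads: Number of heads for multi-head self-attention
    input_dim: Dimension of the input token embeddings (must equal embed_dim here,
               since the LayerNorms and MLP below are sized by embed_dim)
    embed_dim: Dimension of internal key, query, and value embeddings
    mlp_hidden_dim: Hidden dimension of the linear layer
    dropout: Dropout rate
    '''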

    super().__init__()

    self.input_dim = input_dim
    self.msa = MSA(input_dim, embed_dim, num_heads)

    self.layernorm1 = nn.LayerNorm(embed_dim)
    self.w_o_dropout = nn.Dropout(dropout)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),
                              nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(mlp_hidden_dim, embed_dim),
                              nn.Dropout(dropout))

  def forward(self, x):
    '''
    x: input embeddings (batch_size, max_length, input_dim)
    return: output embeddings (batch_size, max_length, embed_dim)
    '''
    # Everything needed here is already initialized in __init__.
    # 1) LayerNorm of x, 2) self-attention, 3) dropout, 4) residual with the original x
    x = x + self.w_o_dropout(self.msa(self.layernorm1(x)))
    # 5) LayerNorm, 6) MLP, 7) residual
    x = x + self.mlp(self.layernorm2(x))
    return x


A portion of the full network is already implemented for you. Your task is to implement the preprocessing code, converting raw images into patch embeddings + positional embeddings + dropout, with a learnable CLS token at the beginning of the input.

Note that patch embeddings are added to positional embeddings elementwise, so the input embedding dimension is `embed_dim`.
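
Before wiring patches into the model, it can help to check the reshape-based patch extraction on a small tensor. The sketch below is one way to do that; the helper name `patchify` and the 8x8 toy image are assumptions made for illustration, and `permute` is used here in place of an equivalent `einsum`.

```python
import torch

def patchify(images, patch_dim):
    """Split [N, C, H, W] images into [N, num_patches, patch_dim * patch_dim * C] rows."""
    n, c, height, width = images.shape
    h, w = height // patch_dim, width // patch_dim
    # Split each spatial axis into (patch index, offset inside the patch)
    x = images.reshape(n, c, h, patch_dim, w, patch_dim)
    # Bring the patch indices forward and the channel last: [N, h, w, p, p, C]
    x = x.permute(0, 2, 4, 3, 5, 1)
    return x.reshape(n, h * w, patch_dim * patch_dim * c)

images = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
patches = patchify(images, patch_dim=4)
print(patches.shape)  # torch.Size([2, 4, 48])

# Patch 1 is the top-right 4x4 block of the first image
top_right = images[0, :, 0:4, 4:8].permute(1, 2, 0).reshape(-1)
print(torch.equal(patches[0, 1], top_right))  # True
```

Comparing one patch against a direct slice of the image is a cheap way to catch a wrong axis order, which would otherwise train silently with scrambled patches.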

In [None]:
class ViT(nn.Module):
  def __init__(self, patch_dim, image_dim, num_layers, num_heads, embed_dim, mlp_hidden_dim, num_classes, dropout):
    '''
    patch_dim: patch length and width to split image by
    image_dim: image length and width
    num_layers: number of layers in network
    num_heads: number of heads for multi-head attention
    embed_dim: dimension to project images patches to and dimension to use for position embeddings
    mlp_hidden_dim: hidden dimension of linear layer
    num_classes: number of classes to classify in data
    dropout: dropout rate
    '''

    super().__init__()
    self.num_layers = num_layers
    self.patch_dim = patch_dim
    self.image_dim = image_dim
    self.input_dim = self.patch_dim * self.patch_dim * 3
    self.num_heads = num_heads

    self.patch_embedding = nn.Linear(self.input_dim, embed_dim)
    self.position_embedding = nn.Parameter(torch.zeros(1, (image_dim // patch_dim) ** 2 + 1, embed_dim))
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.embedding_dropout = nn.Dropout(dropout)

    self.encoder_layers = nn.ModuleList([])
    for i in range(num_layers):
      self.encoder_layers.append(ViTLayer(num_heads, embed_dim, embed_dim, mlp_hidden_dim, dropout))

    self.mlp_head = nn.Linear(embed_dim, num_classes)
    self.layernorm = nn.LayerNorm(embed_dim)

  def forward(self, images):
    '''
    images: raw image data (batch_size, channels, rows, cols)
    '''

    # Don't hardcode dimensions (except for maybe channels = 3), use the variables in __init__.
    # You shouldn't need to add anything else to __init__, all of the embeddings,
    # dropout etc. are already initialized for you.

    # Put the preprocessed patches in variable "out" with shape (batch_size, length, embed_dim).

    # HINT: You can make image patches with .reshape
    # e.g.
    # x = torch.ones((100, 100))
    # x_patches = x.reshape(4, 25, 4, 25)
    # where you have 4 * 4 patches with each patch being 25 by 25
    h = w = self.image_dim // self.patch_dim
    N = images.size(0)
    images = images.reshape(N, 3, h, self.patch_dim, w, self.patch_dim)
    images = torch.einsum("nchpwq -> nhwpqc", images)
    patches = images.reshape(N, h * w, self.input_dim) # (batch, num_patches_per_image, patch_size_unrolled)

    # Project each flattened patch to embed_dim
    patch_embeddings = self.patch_embedding(patches)
    # Prepend the learnable CLS token to every sequence in the batch
    patch_embeddings = torch.cat([torch.tile(self.cls_token, (N, 1, 1)),
                                  patch_embeddings], dim=1)
    # Positional embeddings are added to the tokens, not concatenated
    out = patch_embeddings + torch.tile(self.position_embedding, (N, 1, 1))
    out = self.embedding_dropout(out)

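    # Pad with zero tokens so that the sequence length is a multiple of num_heads.
    # The padding is created on the same device and dtype as `out` rather than a global device.
    add_len = (self.num_heads - out.shape[1]) % self.num_heads
    padding = torch.zeros(N, add_len, out.shape[2], device=out.device, dtype=out.dtype)
    out = torch.cat([out, padding], dim=1)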

    # Pass through each encoder layer in turn
    for layer in self.encoder_layers:
      out = layer(out)
    # Read out the CLS token (position 0) and normalize it before classification
    cls_head = self.layernorm(out[:, 0])
    logits = self.mlp_head(cls_head)
    return logits

def get_vit_tiny(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=3,
              embed_dim=192, mlp_hidden_dim=768, num_classes=num_classes, dropout=0.1)

def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=6,
               embed_dim=384, mlp_hidden_dim=1536, num_classes=num_classes, dropout=0.1)
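
The sequence-length arithmetic in `ViT.forward` can be checked by hand. The sketch below reproduces it in plain Python for the two configurations above; `padded_sequence_length` is a hypothetical helper written for this illustration only.

```python
def padded_sequence_length(image_dim, patch_dim, num_heads):
    """Return (tokens before padding, padding tokens) for the ViT above."""
    num_patches = (image_dim // patch_dim) ** 2
    seq_len = num_patches + 1  # +1 for the CLS token
    # Python's % is non-negative for a positive modulus, so this is the
    # smallest padding that makes the length a multiple of num_heads
    add_len = (num_heads - seq_len) % num_heads
    return seq_len, add_len

print(padded_sequence_length(32, 4, num_heads=3))  # (65, 1)
print(padded_sequence_length(32, 4, num_heads=6))  # (65, 1)
```

For 32x32 images with 4x4 patches, both models see 64 patches plus the CLS token, and one zero token is appended to reach a length of 66.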

Now let's train the model! You don't need to write any code for this - just run the cell.

Remember to change the device variable (in the cell at the beginning of the notebook) to 'cuda' and change your runtime to GPU (Runtime > Change runtime type) as well.

Try to get 60%+ accuracy after 30 epochs.

In [None]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision
import math
import torch.optim as optim
from tqdm.notebook import tqdm

cifar10_mean = torch.tensor([0.49139968, 0.48215827, 0.44653124])
cifar10_std = torch.tensor([0.24703233, 0.24348505, 0.26158768])

class Cifar10Dataset(Dataset):
    def __init__(self, train):
        self.transform = transforms.Compose([
                                                transforms.Resize(40),
                                                transforms.RandomCrop(32),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.ToTensor(),
                                                transforms.Normalize(cifar10_mean, cifar10_std)
                                            ])
        self.dataset = torchvision.datasets.CIFAR10(root='./SSL-Vision/data',
                                                    train=train,
                                                    download=True)
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        img = self.transform(img)
        return img, label

batch_size = 512

trainset = Cifar10Dataset(True)
# Training loader: shuffled mini-batches of size batch_size
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = Cifar10Dataset(False)
# Test loader: built from the held-out test split (not trainset), without shuffling.
# batch_size=1 is kept because the per-sample cell at the end of the notebook calls .item()
testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')



Files already downloaded and verified
Files already downloaded and verified


In [None]:
vit = get_vit_small().to(device)
vit = torch.nn.DataParallel(vit)

learning_rate = 5e-4 * batch_size / 256
num_epochs = 30
warmup_fraction = 0.1
weight_decay = 0.1

total_steps = math.ceil(len(trainset) / batch_size) * num_epochs
# total_steps = num_epochs
warmup_steps = total_steps * warmup_fraction
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=weight_decay)

train_losses = []
test_losses = []
for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    train_total = 0
    vit.train()
    for inputs, labels in tqdm(trainloader):
        """TODO:
        1. Set inputs and labels to be on device
        2. zero out our gradients
        3. pass our inputs through the ViT
        4. pass our outputs / labels into our loss / criterion
        5. backpropagate
        6. step our optimizer
        """
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = vit(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.shape[0]
        train_acc += torch.sum((torch.argmax(output, dim=1) == labels)).item()
        train_total += inputs.shape[0]
    train_loss = train_loss / train_total
    train_acc = train_acc / train_total
    train_losses.append(train_loss)

    test_loss = 0.0
    test_acc = 0.0
    test_total = 0
    vit.eval()
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = vit(inputs)
            loss = criterion(outputs, labels.long())

            test_loss += loss.item() * inputs.shape[0]
            test_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
            test_total += inputs.shape[0]
    test_loss = test_loss / test_total
    test_acc = test_acc / test_total
    test_losses.append(test_loss)

    print(f'[{epoch + 1:2d}] train loss: {train_loss:.3f} | train accuracy: {train_acc:.3f} | test_loss: {test_loss:.3f} | test_accuracy: {test_acc:.3f}')

print('Finished Training')

[ 1] train loss: 2.034 | train accuracy: 0.252 | test_loss: 1.788 | test_accuracy: 0.328
[ 2] train loss: 1.734 | train accuracy: 0.348 | test_loss: 1.677 | test_accuracy: 0.375
[ 3] train loss: 1.602 | train accuracy: 0.409 | test_loss: 1.540 | test_accuracy: 0.428
[ 4] train loss: 1.525 | train accuracy: 0.439 | test_loss: 1.458 | test_accuracy: 0.466
[ 5] train loss: 1.465 | train accuracy: 0.461 | test_loss: 1.439 | test_accuracy: 0.468
[ 6] train loss: 1.427 | train accuracy: 0.477 | test_loss: 1.414 | test_accuracy: 0.483
[ 7] train loss: 1.391 | train accuracy: 0.492 | test_loss: 1.357 | test_accuracy: 0.504
[ 8] train loss: 1.362 | train accuracy: 0.504 | test_loss: 1.332 | test_accuracy: 0.511
[ 9] train loss: 1.329 | train accuracy: 0.517 | test_loss: 1.294 | test_accuracy: 0.530
[10] train loss: 1.314 | train accuracy: 0.523 | test_loss: 1.279 | test_accuracy: 0.536
[11] train loss: 1.280 | train accuracy: 0.536 | test_loss: 1.279 | test_accuracy: 0.536
[12] train loss: 1.270 | train accuracy: 0.539 | test_loss: 1.282 | test_accuracy: 0.538
[13] train loss: 1.246 | train accuracy: 0.547 | test_loss: 1.203 | test_accuracy: 0.563
[14] train loss: 1.235 | train accuracy: 0.550 | test_loss: 1.203 | test_accuracy: 0.564
[15] train loss: 1.218 | train accuracy: 0.555 | test_loss: 1.238 | test_accuracy: 0.558
[16] train loss: 1.195 | train accuracy: 0.569 | test_loss: 1.179 | test_accuracy: 0.573
[17] train loss: 1.177 | train accuracy: 0.575 | test_loss: 1.154 | test_accuracy: 0.582
[18] train loss: 1.175 | train accuracy: 0.574 | test_loss: 1.140 | test_accuracy: 0.592
[19] train loss: 1.151 | train accuracy: 0.584 | test_loss: 1.131 | test_accuracy: 0.591
[20] train loss: 1.145 | train accuracy: 0.588 | test_loss: 1.113 | test_accuracy: 0.597
[21] train loss: 1.123 | train accuracy: 0.596 | test_loss: 1.107 | test_accuracy: 0.601
[22] train loss: 1.125 | train accuracy: 0.596 | test_loss: 1.083 | test_accuracy: 0.613
[23] train loss: 1.104 | train accuracy: 0.603 | test_loss: 1.103 | test_accuracy: 0.599
[24] train loss: 1.098 | train accuracy: 0.605 | test_loss: 1.047 | test_accuracy: 0.623
[25] train loss: 1.084 | train accuracy: 0.610 | test_loss: 1.060 | test_accuracy: 0.620
[26] train loss: 1.068 | train accuracy: 0.617 | test_loss: 1.023 | test_accuracy: 0.633
[27] train loss: 1.052 | train accuracy: 0.620 | test_loss: 1.013 | test_accuracy: 0.639
[28] train loss: 1.046 | train accuracy: 0.622 | test_loss: 1.008 | test_accuracy: 0.639
[29] train loss: 1.038 | train accuracy: 0.626 | test_loss: 1.007 | test_accuracy: 0.638
[30] train loss: 1.018 | train accuracy: 0.635 | test_loss: 0.981 | test_accuracy: 0.649
Finished Training
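
The training cell computes `warmup_steps` but never passes it to a learning-rate scheduler. The notebook does not say which schedule was intended, so the sketch below assumes linear warmup followed by cosine decay; `lr_scale` is a hypothetical helper, and the 50,000-image training set size is what yields the 98 steps per epoch at batch size 512.

```python
import math

def lr_scale(step, total_steps, warmup_steps):
    """Multiplier on the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        # Linear warmup from 0 towards 1
        return step / max(1, warmup_steps)
    # Cosine decay from 1 to 0 over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

batch_size, num_epochs, warmup_fraction = 512, 30, 0.1
total_steps = math.ceil(50000 / batch_size) * num_epochs  # 98 steps per epoch
warmup_steps = total_steps * warmup_fraction

print(total_steps)                             # 2940
print(lr_scale(0, total_steps, warmup_steps))  # 0.0
print(lr_scale(1470, total_steps, warmup_steps))
```

One way to apply it would be `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_scale(s, total_steps, warmup_steps))`, calling `scheduler.step()` after each `optimizer.step()`. The accuracies reported above were obtained without any scheduler.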


In [None]:
torch.save(vit.module.state_dict(), "ViT_MSA_BerkleyHW4.pth")

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from time import sleep

vit.eval()
with torch.no_grad():
  for inputs, labels in testloader:
    inputs = inputs.to(device)
    output = vit(inputs)
    _, pred = torch.max(output, dim=1)
    print(f"ViT CLS token Classification: {pred.detach().item()}")
    print(f"Ground Truth label: {labels.item()}")
    sleep(0.5)

ViT CLS token Classification: 6
Ground Truth label: 6
ViT CLS token Classification: 9
Ground Truth label: 9
ViT CLS token Classification: 9
Ground Truth label: 9
ViT CLS token Classification: 6
Ground Truth label: 4
ViT CLS token Classification: 1
Ground Truth label: 1
ViT CLS token Classification: 1
Ground Truth label: 1
ViT CLS token Classification: 2
Ground Truth label: 2
ViT CLS token Classification: 7
Ground Truth label: 7
ViT CLS token Classification: 8
Ground Truth label: 8
ViT CLS token Classification: 3
Ground Truth label: 3
ViT CLS token Classification: 4
Ground Truth label: 4
ViT CLS token Classification: 4
Ground Truth label: 7
ViT CLS token Classification: 7
Ground Truth label: 7
ViT CLS token Classification: 1
Ground Truth label: 2
ViT CLS token Classification: 9
Ground Truth label: 9
ViT CLS token Classification: 8
Ground Truth label: 9
ViT CLS token Classification: 9
Ground Truth label: 9
ViT CLS token Classification: 5
Ground Truth label: 3
ViT CLS token Classification

KeyboardInterrupt: 
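
The integer indices printed above are easier to read as class names. The sketch below maps them through the `classes` tuple defined in the data-loading cell; `describe` is a hypothetical helper added for illustration.

```python
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def describe(pred_idx, label_idx):
    """Format one prediction using class names instead of integer indices."""
    verdict = 'correct' if pred_idx == label_idx else 'wrong'
    return f"predicted {classes[pred_idx]}, ground truth {classes[label_idx]} ({verdict})"

print(describe(6, 6))  # predicted frog, ground truth frog (correct)
print(describe(6, 4))  # predicted frog, ground truth deer (wrong)
```

Inside the loop, this would be called as `describe(pred.item(), labels.item())`.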