In [1]:
import torch

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
import torch.nn.functional as F
from torch import nn
import torch

class MSA(nn.Module):
  '''Multi-head self-attention with bias-free K/Q/V/output projections and no masking.'''

  def __init__(self, input_dim, embed_dim, num_heads):
    '''
    input_dim: Dimension of input token embeddings
    embed_dim: Dimension of internal key, query, and value embeddings
    num_heads: Number of self-attention heads; must divide embed_dim

    Raises ValueError if embed_dim is not divisible by num_heads.
    '''

    super().__init__()

    # Each head gets embed_dim // num_heads channels, so the split must be exact.
    # (The sequence length has no divisibility requirement.)
    if embed_dim % num_heads != 0:
      raise ValueError(f'embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})')

    self.input_dim = input_dim
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    self.K_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.Q_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.V_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.out_embed = nn.Linear(embed_dim, embed_dim, bias=False)

  def _split_heads(self, t, batch_size, max_length):
    '''Reshape (batch_size, max_length, embed_dim) -> (batch_size * num_heads, max_length, head_dim).'''
    head_dim = self.embed_dim // self.num_heads
    t = t.reshape(batch_size, max_length, self.num_heads, head_dim)
    # Bring the head axis next to the batch axis, then fold them together for torch.bmm.
    return t.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, max_length, head_dim)

  def forward(self, x):
    '''
    x: input of shape (batch_size, max_length, input_dim)
    return: output of shape (batch_size, max_length, embed_dim)

    Raises ValueError if the last dimension of x is not input_dim.
    '''

    batch_size, max_length, given_input_dim = x.shape
    if given_input_dim != self.input_dim:
      raise ValueError(f'expected input_dim {self.input_dim}, got {given_input_dim}')

    head_dim = self.embed_dim // self.num_heads

    # nn.Linear acts on the last dimension, so x can be projected without flattening.
    K = self._split_heads(self.K_embed(x), batch_size, max_length)
    Q = self._split_heads(self.Q_embed(x), batch_size, max_length)
    V = self._split_heads(self.V_embed(x), batch_size, max_length)

    # Scaled dot-product attention. The scale is 1/sqrt(d_k), where d_k is the
    # per-head dimension (not the full embed_dim).
    scores = torch.bmm(Q, K.transpose(1, 2)) / (head_dim ** 0.5)  # (batch_size * num_heads, max_length, max_length)
    weights = F.softmax(scores, dim=-1)
    w_V = torch.bmm(weights, V)  # (batch_size * num_heads, max_length, head_dim)

    # Rejoin heads: (batch_size * num_heads, L, head_dim) -> (batch_size, L, embed_dim)
    w_V = w_V.reshape(batch_size, self.num_heads, max_length, head_dim)
    w_V = w_V.permute(0, 2, 1, 3).reshape(batch_size, max_length, self.embed_dim)

    return self.out_embed(w_V)

### Implement the ViT architecture
You will be implementing the ViT architecture from the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

Although the ViT and Transformer architecture are very similar, note a few differences:

1. Image patches instead of discrete tokens as input.
2. [GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used as the activation between the two linear layers of each transformer layer's MLP (instead of ReLU).
3. LayerNorm is applied to the input of each sublayer (pre-norm), instead of after the residual addition (post-norm).
4. Dropout is applied after every linear layer except the K, Q, and V projections, and also directly after adding the positional embeddings to the patch embeddings.
5. Learnable [CLS] token at the beginning of the input.

A useful reference is Figure 1 in the [paper](https://arxiv.org/pdf/2010.11929.pdf).

First, implement a single layer:

In [None]:
class ViTLayer(nn.Module):
  '''One pre-norm ViT encoder layer: x + Dropout(MSA(LN(x))), then x + MLP(LN(x)).'''

  def __init__(self, num_heads, input_dim, embed_dim, mlp_hidden_dim, dropout=0.1):
    '''
    num_heads: Number of heads for multi-head self-attention
    input_dim: Dimension of the incoming token embeddings; must equal embed_dim
               because both sublayers are wrapped in residual connections
    embed_dim: Dimension of internal key, query, and value embeddings
    mlp_hidden_dim: Hidden dimension of the linear layer
    dropout: Dropout rate

    Raises ValueError if input_dim != embed_dim.
    '''

    super().__init__()

    if input_dim != embed_dim:
      raise ValueError(f'residual connections require input_dim == embed_dim, got {input_dim} and {embed_dim}')

    self.input_dim = input_dim
    self.msa = MSA(input_dim, embed_dim, num_heads)

    self.layernorm1 = nn.LayerNorm(embed_dim)
    # Dropout on the attention output projection (W_o).
    self.w_o_dropout = nn.Dropout(dropout)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    # Dropout follows both linear layers of the MLP, so its output needs no further dropout.
    self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),
                              nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(mlp_hidden_dim, embed_dim),
                              nn.Dropout(dropout))

  def forward(self, x):
    '''
    x: input embeddings (batch_size, max_length, input_dim)
    return: output embeddings (batch_size, max_length, embed_dim)
    '''

    # Pre-norm: the LayerNorm is applied to each sublayer's input, and the residual
    # adds the un-normalised x. Additions are out-of-place: an in-place `+=` would
    # mutate the caller's tensor and can invalidate tensors autograd saved for backward.
    x = x + self.w_o_dropout(self.msa(self.layernorm1(x)))
    x = x + self.mlp(self.layernorm2(x))
    return x


A portion of the full network is already implemented for you. Your task is to implement the preprocessing code, converting raw images into patch embeddings + positional embeddings + dropout, with a learnable CLS token at the beginning of the input.

Note that patch embeddings are added to positional embeddings elementwise (not concatenated), so the input embedding dimension is embed_dim.

In [None]:
class ViT(nn.Module):
    '''Vision Transformer classifier.

    Pipeline: split images into patches -> linear patch embedding -> prepend a
    learnable [CLS] token -> add learnable positional embeddings -> dropout ->
    encoder layers -> LayerNorm + linear head on the [CLS] token.
    '''

    def __init__(self, patch_dim, image_dim, num_layers, num_heads, embed_dim, mlp_hidden_dim, num_classes, dropout):
        '''
        patch_dim: patch length and width to split image by
        image_dim: image length and width (must be a multiple of patch_dim)
        num_layers: number of layers in network
        num_heads: number of heads for multi-head attention
        embed_dim: dimension to project images patches to and dimension to use for position embeddings
        mlp_hidden_dim: hidden dimension of linear layer
        num_classes: number of classes to classify in data
        dropout: dropout rate

        Raises ValueError if image_dim is not a multiple of patch_dim.
        '''

        super().__init__()
        if image_dim % patch_dim != 0:
            raise ValueError(f'image_dim ({image_dim}) must be a multiple of patch_dim ({patch_dim})')

        self.num_layers = num_layers
        self.patch_dim = patch_dim
        self.image_dim = image_dim
        # Each flattened patch holds patch_dim * patch_dim pixels of 3 channels.
        self.input_dim = self.patch_dim * self.patch_dim * 3
        self.num_heads = num_heads

        self.patch_embedding = nn.Linear(self.input_dim, embed_dim)
        # One positional embedding per patch plus one for the [CLS] token.
        self.position_embedding = nn.Parameter(torch.zeros(1, (image_dim // patch_dim) ** 2 + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.embedding_dropout = nn.Dropout(dropout)

        self.encoder_layers = nn.ModuleList([])
        for i in range(num_layers):
            self.encoder_layers.append(ViTLayer(num_heads, embed_dim, embed_dim, mlp_hidden_dim, dropout))

        self.mlp_head = nn.Linear(embed_dim, num_classes)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, images):
        '''
        images: raw image data (batch_size, 3, image_dim, image_dim)
        return: class logits (batch_size, num_classes)

        Raises ValueError if images do not have that shape.
        '''

        expected = (3, self.image_dim, self.image_dim)
        if images.dim() != 4 or tuple(images.shape[1:]) != expected:
            raise ValueError(
                f'expected images of shape (batch_size, 3, {self.image_dim}, {self.image_dim}), '
                f'got {tuple(images.shape)}')

        N = images.size(0)
        p = self.patch_dim
        h = w = self.image_dim // p

        # (N, C, h*p, w*p) -> (N, C, h, p, w, p) -> (N, h, w, p, p, C), then flatten
        # each patch into one input_dim vector: (N, h*w, input_dim).
        patches = images.reshape(N, 3, h, p, w, p).permute(0, 2, 4, 3, 5, 1)
        patches = patches.reshape(N, h * w, self.input_dim)

        tokens = self.patch_embedding(patches)      # (N, h*w, embed_dim)
        cls = self.cls_token.expand(N, -1, -1)      # (N, 1, embed_dim), no copy
        # Positional embeddings are added (broadcast over the batch), not concatenated.
        out = torch.cat([cls, tokens], dim=1) + self.position_embedding
        out = self.embedding_dropout(out)

        # Pad the sequence so its length is a multiple of num_heads.
        # NOTE(review): this padding exists only to satisfy a
        # `max_length % num_heads == 0` check in MSA.forward. The padded tokens are
        # not masked, so they act as extra keys/values in every attention layer.
        # Remove this if MSA no longer imposes that check.
        add_len = (self.num_heads - out.shape[1]) % self.num_heads
        if add_len:
            # new_zeros matches out's device and dtype instead of a global device.
            out = torch.cat([out, out.new_zeros(N, add_len, out.shape[2])], dim=1)

        for layer in self.encoder_layers:
            out = layer(out)

        # Classify from the [CLS] token (position 0) after the final LayerNorm.
        cls_head = self.layernorm(out[:, 0])
        return self.mlp_head(cls_head)

def _vit_preset(num_classes, patch_dim, image_dim, num_heads, embed_dim):
    '''Build a 12-layer ViT with MLP width 4 * embed_dim and dropout 0.1.'''
    return ViT(
        patch_dim=patch_dim,
        image_dim=image_dim,
        num_layers=12,
        num_heads=num_heads,
        embed_dim=embed_dim,
        mlp_hidden_dim=4 * embed_dim,
        num_classes=num_classes,
        dropout=0.1,
    )


def get_vit_tiny(num_classes=10, patch_dim=4, image_dim=32):
    '''ViT-Tiny preset: 3 heads, embedding width 192, MLP width 768.'''
    return _vit_preset(num_classes, patch_dim, image_dim, num_heads=3, embed_dim=192)


def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
    '''ViT-Small preset: 6 heads, embedding width 384, MLP width 1536.'''
    return _vit_preset(num_classes, patch_dim, image_dim, num_heads=6, embed_dim=384)


def get_vit_base(num_classes=10, patch_dim=4, image_dim=32):
    '''ViT-Base preset: 12 heads, embedding width 768, MLP width 3072.'''
    return _vit_preset(num_classes, patch_dim, image_dim, num_heads=12, embed_dim=768)

In [None]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torchvision.datasets as datasets
import torchvision
import math
import torch.optim as optim
from tqdm.notebook import tqdm

data_root = './data/cifar10'
train_size, val_size = 40000, 10000

batch_size = 32

# Training augmentation: upscale to 40px, take a random 32px crop, flip
# horizontally at random, apply a small random translate/scale, then convert to
# a tensor and normalise per channel.
transform_train = T.Compose(
    [
        T.Resize(40),
        T.RandomCrop(32),
        T.RandomHorizontalFlip(),
        T.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.95, 1.05)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Validation: no augmentation, only tensor conversion and the same normalisation.
transform_val = T.Compose(
    [
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Both datasets wrap the CIFAR-10 *training* split and differ only in their
# transform. The DataLoader samplers pick disjoint index ranges from them.
train_dataset, val_dataset = (
    datasets.CIFAR10(data_root, train=True, transform=tf, download=True)
    for tf in (transform_train, transform_val)
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 37289802.04it/s]


Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10
Files already downloaded and verified


In [None]:
from torch.utils.data import sampler

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          sampler=sampler.SubsetRandomSampler(range(train_size)))

# Validation uses the next val_size indices of the same CIFAR-10 training split
# (val_dataset applies the non-augmenting transform).
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        sampler=sampler.SubsetRandomSampler(range(train_size, train_size + val_size)))

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

vit = get_vit_small().to(device)

# Learning rate scaled linearly from a reference batch size of 256.
learning_rate = 5e-4 * batch_size / 256
num_epochs = 30
weight_decay = 0.1

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=weight_decay)

train_losses = []
val_losses = []
val_accuracies = []
for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    train_total = 0
    vit.train()
    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = vit(inputs)
        loss = criterion(outputs, labels.long())
        # Without backward() no gradients are computed, so optimizer.step()
        # has nothing to apply and the model never learns.
        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted loss and correct-prediction counts.
        train_loss += loss.item() * inputs.shape[0]
        train_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
        train_total += inputs.shape[0]
    train_loss = train_loss / train_total
    train_acc = train_acc / train_total
    train_losses.append(train_loss)

    val_loss = 0.0
    val_acc = 0.0
    val_total = 0
    vit.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = vit(inputs)
            loss = criterion(outputs, labels.long())

            val_loss += loss.item() * inputs.shape[0]
            val_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
            val_total += inputs.shape[0]
    val_loss = val_loss / val_total
    val_acc = val_acc / val_total
    val_losses.append(val_loss)

    val_accuracies.append(val_acc)
    # Checkpoint whenever this epoch ties or beats the best validation accuracy so far.
    if val_acc >= max(val_accuracies):
        torch.save(vit.state_dict(), 'best_model.pth')

    print(f'[{epoch + 1:2d}] train loss: {train_loss:.3f} | train accuracy: {train_acc:.3f} | val loss: {val_loss:.3f} | val accuracy: {val_acc:.3f}')

print('Finished Training')