In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# VIT Implementation

The vision transformer can be seperated into three parts, we will implement each part and combine them in the end.

For the implementation, feel free to experiment different kinds of setup, as long as you use attention as the main computation unit and the ViT can be train to perform the image classification task present later.
You can read about the ViT implement from other libary: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding is responsible for dividing the input image into non-overlapping patches and projecting them into a specified embedding dimension. It uses a 2D convolution layer with a kernel size and stride equal to the patch size. The output is a sequence of linear embeddings for each patch.

In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
      # TODO

        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.proj = nn.Conv2d(self.in_channels, 
                            self.embed_dim, 
                            kernel_size=self.patch_size, 
                            stride=self.patch_size
                           )
        self.norm = nn.LayerNorm(self.embed_dim)

    def forward(self, x):
        # TODO
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

In [4]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        # TODO
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3)
        self.qk_norm = False
        self.use_activation = False
        self.activation = nn.ReLU() if self.use_activation else nn.Identity()
        self.q_norm = nn.LayerNorm(self.head_dim) if self.qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if self.qk_norm else nn.Identity()
        self.attn_dropout = nn.Dropout(0.1)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.proj_dropout = nn.Dropout(0)

    def forward(self, x):
        # TODO
        batch_si, seq_len, emb_dim = x.shape
        qkv = self.qkv(x).reshape(batch_si, seq_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        q = q * self.scale
        attention = q @ k.transpose(-2, -1)
        attention = attention.softmax(dim=-1)
        attention = self.attn_dropout(attention)

        z = attention @ v
        z = z.transpose(1, 2).reshape(batch_si, seq_len, emb_dim)
        z = self.proj(z)
#         z = self.proj_dropout(z)
        return z

## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is surrounded by residual connections.
You may also want to use layer normalization or other type of normalization.

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        # TODO
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)
        self.attention_norm = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_dim),
                                  nn.GELU(),
                                  nn.Dropout(dropout),
                                  nn.Linear(mlp_dim, embed_dim)
                                  # nn.Dropout(dropout)
        )
        self.mlp_norm = nn.LayerNorm(embed_dim)


    def forward(self, x):
        # TODO
        res = x
        x = self.attention_norm(x)
        x = self.attention(x)
        x = x + res # residual connection
        res = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = x + res
        return x

## VisionTransformer:
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A special class token is added to the sequence, and positional embeddings are added to both the patch and class tokens. The sequence of patch embeddings is then passed through multiple TransformerBlock layers. The final output is the logits for all classes

In [6]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        # TODO
        super(VisionTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.embed_len = self.patch_embed.num_patches + 1
        self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_len, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for i in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.cls_head = nn.Sequential(nn.Linear(embed_dim, embed_dim//2),
                                nn.GELU(),
                                nn.Dropout(dropout),
                                nn.Linear(embed_dim // 2, num_classes),
                                # nn.Dropout(dropout)     
                                )                           

    def forward(self, x):
        # TODO
        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.pos_embed
        # x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x)
        # x = self.norm(x)
        logits = self.cls_head(x[:, 0])

        return logits



## Let's train the ViT!

We will train the vit to do the image classification with cifar100. Free free to change the optimizer and or add other tricks to improve the training

In [7]:
# Example usage:
image_size = 32
patch_size = 4
in_channels = 3
embed_dim = 512
num_heads = 8
mlp_dim = 1024
num_layers = 4
num_classes = 10
dropout = 0.1
batch_size = 256

In [8]:
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout).to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)

torch.Size([1, 10])


In [9]:
# Load the CIFAR-10 dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 80121955.41it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [10]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
# TODO
lr = 0.003
weight_decay = 0.0001
num_epochs = 150
optimizer = torch.optim.Adam(model.parameters(),
                                lr=lr,
                                weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(trainloader), epochs=num_epochs)

In [11]:
# Train the model
best_val_acc = 0
train_accs = []
test_accs = []
epochs_no_improve = 0
max_patience = 20
early_stop = False
pbar=tqdm(range(num_epochs))
for epoch in pbar:
    # if not load_pretrained:
    running_accuracy = 0.0
    running_loss = 0.0
    model.train()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        running_accuracy += acc / len(trainloader)
        running_loss += loss.item() / len(trainloader)
    
    train_accs.append(running_accuracy)


    # TODO Feel free to modify the training loop youself.

    # Validate the model
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = 100 * correct / total

    pbar.set_postfix({"Epoch": epoch+1, "Train Accuracy": running_accuracy*100, "Training Loss": running_loss, "Validation Accuracy": val_acc})

    # Save the best model

    if val_acc > best_val_acc:
        epochs_no_improve = 0
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer,
            'scheduler' : scheduler,
            'train_acc': train_accs,
            'test_acc': val_acc
        },  'best_model.pth')

    else:
        epochs_no_improve += 1

    if epoch > 100 and epochs_no_improve >= max_patience:
        print('Early stopping!')
        early_stop = True
        break
    else:
        continue

100%|██████████| 150/150 [2:27:17<00:00, 58.91s/it, Epoch=150, Train Accuracy=tensor(95.5106, device='cuda:0'), Training Loss=0.124, Validation Accuracy=86.9]


In [12]:
print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

Best Validation Accuracy: 87.00%


In [13]:
import os
os.chdir(r'/kaggle/working')

In [14]:
!pwd

/kaggle/working


In [15]:
# from IPython.display import FileLink
# FileLink(r'best_model.pth')

model_dict = torch.load('best_model.pth', map_location=device)
model.load_state_dict(model_dict['model'])

<All keys matched successfully>

In [16]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in tqdm(testloader):
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
val_acc = 100 * correct / total
print(f"Validation Accuracy: {val_acc:.2f}%")

100%|██████████| 40/40 [00:04<00:00,  9.88it/s]

Validation Accuracy: 87.00%





In [17]:
model_dict

{'epoch': 141,
 'model': OrderedDict([('cls_token',
               tensor([[[-1.3369e-03,  3.8733e-02,  1.1144e-03, -2.7063e-02,  2.4515e-03,
                         -9.5406e-03, -2.0734e-03, -2.2666e-02,  1.0553e-01, -1.8779e-03,
                         -2.1803e-03, -7.2460e-03, -5.1628e-04,  4.0150e-04,  1.9246e-03,
                         -3.0420e-06, -3.9180e-04,  3.1528e-02, -9.6840e-04, -2.6527e-04,
                         -2.2775e-03,  1.1251e-03, -2.1313e-03, -1.0821e-04,  5.2767e-04,
                         -2.1984e-03,  7.6103e-06, -6.6514e-04, -2.9190e-03, -2.0474e-03,
                          1.7090e-03,  9.6956e-03, -1.3528e-03, -1.0410e-03, -1.9553e-04,
                         -2.5496e-04, -1.6900e-03, -5.4030e-04, -1.0925e-03, -9.0844e-04,
                         -2.3434e-03,  2.5149e-03,  2.9364e-02,  5.6712e-04, -1.2888e-03,
                          6.9708e-04,  2.6167e-02, -3.1013e-03,  7.7722e-04,  4.0732e-03,
                         -2.2072e-05,  9.5709e-0