# Vision Transformer using Pytorch

In [47]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torchvision import transforms
import math
from torch.utils.data import random_split
import warnings
warnings.filterwarnings("ignore")

In [48]:
class MultiHeadedSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        assert embed_dim % num_heads == 0, f"Can't divide dimension {embed_dim} into {num_heads} heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_head = embed_dim // num_heads

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, sequences):
        # sequences: (seq_length, N, embed_dim)
        sequences = sequences.permute(1, 0, 2)  # (N, seq_length, embed_dim)
        attn_output, attn_weights = self.multihead_attn(sequences, sequences, sequences)
        attn_output = attn_output.permute(1, 0, 2)  # (seq_length, N, embed_dim)
        return attn_output,attn_weights

In [49]:
class EncoderBlock(nn.Module):
    def __init__(self,num_heads,hidden_dimension,mlp_ratio=4):
        super().__init__()
        self.num_heads=num_heads
        self.hidden_dimension=hidden_dimension
        self.normalize_layer_1=nn.LayerNorm(hidden_dimension)
        self.multi_headed_self_attention=MultiHeadedSelfAttention(num_heads=num_heads,embed_dim=hidden_dimension)
        self.normalize_layer_2=nn.LayerNorm(hidden_dimension)
        self.neural_network=nn.Sequential(
            nn.Linear(hidden_dimension,mlp_ratio*hidden_dimension),
            nn.GELU(),
            nn.Linear(mlp_ratio*hidden_dimension,hidden_dimension)
        )
    def forward(self,encoder_input):
        norm_layer1_output=self.normalize_layer_1(encoder_input)
        self_attention_output,attention_weights=self.multi_headed_self_attention(norm_layer1_output)
        norm_layer2_output=self.normalize_layer_2(self_attention_output)
        neural_network_output=self.neural_network(norm_layer2_output)
        final_output=neural_network_output
        return final_output,attention_weights

In [50]:
class VisionTransformer(nn.Module):
    def __init__(self,patch_size,image_size,hidden_dimension,num_heads,out_dimension,num_encoder_blocks=2):
        super().__init__()
        self.patch_size=patch_size
        self.hidden_dimension=hidden_dimension
        self.num_heads=num_heads
        self.num_encoder_blocks=num_encoder_blocks
        self.output_dimension=out_dimension
        # dimension of the input patches after they have been flattened
        self.input_dimension = int(3 * self.patch_size * self.patch_size)
        # linear embedding
        self.linear_mapper = nn.Linear(self.input_dimension, self.hidden_dimension)
        #class token 
        self.class_token=nn.Parameter(torch.randn(1,self.hidden_dimension))
        batch_size,channel,height,width=image_size
        num_of_patches=(height*width)/(self.patch_size**2)
        # print(num_of_patches)
        self.positional_encodings = self._generate_positional_encodings(int(num_of_patches), self.hidden_dimension)

        self.blocks = nn.ModuleList([EncoderBlock(hidden_dimension=self.hidden_dimension, num_heads=self.num_heads) for _ in range(self.num_encoder_blocks)])
        
        # 5) Classification MLPk
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_dimension, self.output_dimension),
            nn.Softmax(dim=-1)
        )

    def _generate_positional_encodings(self, num_patches, hidden_dimension):
        position_encodings = torch.zeros(num_patches + 1, hidden_dimension)  # +1 for class token
        position_encodings[1:, 0::2] = torch.sin(self._get_angles(torch.arange(0, num_patches).float(),
                                                                                     2 * torch.arange(0, hidden_dimension // 2).float() / hidden_dimension))
        position_encodings[1:, 1::2] = torch.cos(self._get_angles(torch.arange(0, num_patches).float(),
                                                                                     2 * torch.arange(0, hidden_dimension // 2).float() / hidden_dimension))
        return position_encodings.unsqueeze(0)

    def _get_angles(self, positions, i):
        angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / torch.tensor(self.hidden_dimension).float())
        return positions.unsqueeze(-1) * angle_rates
   
    def forward(self,input_images):
        # image_patches=rearrange(input_images, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=self.patch_size, s2=self.patch_size)
        # nn.functional.unfold(input_images,kernel_size=self.patch_size,stride=self.patch_size)
        batch_size=input_images.shape[0]
        # unfold = nn.Unfold(self.patch_size,stride=1)
        image_patches=input_images.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)

        # Reshape the patches
        image_patches = image_patches.contiguous().view(input_images.size(0), -1, self.patch_size * self.patch_size * input_images.size(1))
        tokens=self.linear_mapper(image_patches)

        # Adding the class_tokens to the tokens
        tokens_with_class=torch.stack([torch.vstack((self.class_token,tokens[i])) for i in range(len(tokens))])
        # class_token = self.class_token.expand(batch_size, -1)
        # tokens_with_class = torch.cat((class_token.unsqueeze(1), tokens), dim=1)
        tokens_with_position = tokens_with_class + self.positional_encodings[:, :tokens_with_class.size(1)]

        # Add positional encodings
        # tokens_with_position = tokens_with_class + self.positional_encodings[:, :tokens_with_class.size(1)]
        # print(tokens_with_position)
        attention_weights=[]
        for block in self.blocks:
            tokens_with_position ,attn_weight= block(tokens_with_position)
            attention_weights.append(attn_weight)
         # Getting the classification token only
        tokens_with_position = tokens_with_position[:, 0]
        
        return self.mlp(tokens_with_position),attention_weights
        # return image_patches

In [51]:
BATCH_SIZE=64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_data = datasets.CIFAR10(
    'train_data',
    transform=transform,
    target_transform=transforms.Compose([
        lambda x: torch.tensor(x),  # Convert label to tensor
    ]),
    train=True,
    download=True
)

test_data = datasets.CIFAR10(
    'test_data',
    transform=transform,
    target_transform=transforms.Compose([
        lambda x: torch.tensor(x),  # Convert label to tensor
    ]),
    train=False,
    download=True
)

# train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [52]:
# Defining model and training options
device = torch.cuda.set_device(1)
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
transformer_model = VisionTransformer(image_size=(BATCH_SIZE,3, 32, 32), num_encoder_blocks=2, hidden_dimension=8, num_heads=2, out_dimension=10,patch_size=3)

# model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
N_EPOCHS = 5
LR = 0.005

# # Training loop
# def train():
#     optimizer = Adam(model.parameters(), lr=LR)
#     criterion = CrossEntropyLoss()
#     for epoch in range(N_EPOCHS):
#         train_loss = 0.0
#         for batch in train_loader:
#             x, y = batch
#             x, y = x.to(device), y.to(device)
#             y_hat = model(x)
#             loss = criterion(y_hat, y)
    
#             train_loss += loss.detach().cpu().item() / len(train_loader)
    
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
    
#         print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")
criterion = CrossEntropyLoss()
# Test loop
def test(model):
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in test_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)
    
            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")


Using device:  None (NVIDIA GeForce RTX 3090)


# Training the model with different data sizes: 5% , 10% , 25% , 50% , 100%

In [53]:
N_EPOCHS = 5
LR = 0.005
def train(train_data_loader,model):
    train_loader=train_data_loader
    optimizer = Adam(model.parameters(), lr=LR)
    for epoch in range(N_EPOCHS):
        train_loss = 0.0
        for batch in train_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
    
            train_loss += loss.detach().cpu().item() / len(train_loader)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

## Trainig with 5% training data

In [19]:
five_percent = int(0.05 * len(train_data))

subset_dataset, _ = random_split(train_data, [five_percent, len(train_data) - five_percent])

# Create a DataLoader for the subset dataset
subset_loader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=True)
train(subset_loader,model=transformer_model)

Epoch 1/5 loss: 2.30
Epoch 2/5 loss: 2.29
Epoch 3/5 loss: 2.27
Epoch 4/5 loss: 2.27
Epoch 5/5 loss: 2.26


In [20]:
test(model=transformer_model)

Test loss: 2.26
Test accuracy: 17.07%


## Training with 10% training data

In [34]:
ten_percent = int(0.1 * len(train_data))

subset_dataset, _ = random_split(train_data, [ten_percent, len(train_data) - ten_percent])

# Create a DataLoader for the subset dataset
subset_loader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=True)
train(subset_loader,model=transformer_model)

Epoch 1/5 loss: 2.27
Epoch 2/5 loss: 2.26
Epoch 3/5 loss: 2.26
Epoch 4/5 loss: 2.26
Epoch 5/5 loss: 2.25


In [35]:
test(model=transformer_model)

Test loss: 2.25
Test accuracy: 18.99%


## Training with 25% training data

In [36]:
twenty_five_percent = int(0.25 * len(train_data))

subset_dataset, _ = random_split(train_data, [twenty_five_percent, len(train_data) - twenty_five_percent])

# Create a DataLoader for the subset dataset
subset_loader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=True)
train(subset_loader,model=transformer_model)

Epoch 1/5 loss: 2.25
Epoch 2/5 loss: 2.23
Epoch 3/5 loss: 2.22
Epoch 4/5 loss: 2.22
Epoch 5/5 loss: 2.22


In [37]:
test(model=transformer_model)

Test loss: 2.21
Test accuracy: 23.17%


## Training with 50% training data

In [38]:
fifty_percent = int(0.5 * len(train_data))

subset_dataset, _ = random_split(train_data, [fifty_percent, len(train_data) - fifty_percent])

# Create a DataLoader for the subset dataset
subset_loader = DataLoader(subset_dataset, batch_size=BATCH_SIZE, shuffle=True)
train(subset_loader,model=transformer_model)

Epoch 1/5 loss: 2.21
Epoch 2/5 loss: 2.20
Epoch 3/5 loss: 2.19
Epoch 4/5 loss: 2.18
Epoch 5/5 loss: 2.18


In [39]:
test(model=transformer_model)

Test loss: 2.18
Test accuracy: 26.29%


## Training with 100% training data

In [40]:
subset_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
train(subset_loader,model=transformer_model)

Epoch 1/5 loss: 2.18
Epoch 2/5 loss: 2.17
Epoch 3/5 loss: 2.17
Epoch 4/5 loss: 2.16
Epoch 5/5 loss: 2.16


In [41]:
test(model=transformer_model)

Test loss: 2.15
Test accuracy: 30.00%


# Training with different patch sizes: (4x4) , (8x8) , (16x16)

## Training with patch size (4x4)

In [22]:
model_patch_4=VisionTransformer(image_size=(BATCH_SIZE,3, 32, 32), num_encoder_blocks=2, hidden_dimension=8, num_heads=2, out_dimension=10,patch_size=4)
train_data_loader=DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)
train(train_data_loader,model=model_patch_4)

Epoch 1/5 loss: 2.24
Epoch 2/5 loss: 2.21
Epoch 3/5 loss: 2.19
Epoch 4/5 loss: 2.18
Epoch 5/5 loss: 2.19


In [23]:
test(model=model_patch_4)

Test loss: 2.16
Test accuracy: 28.66%


## Training with patch size (8x8)

In [54]:
model_patch_8=VisionTransformer(image_size=(BATCH_SIZE,3, 32, 32), num_encoder_blocks=2, hidden_dimension=8, num_heads=2, out_dimension=10,patch_size=8)
train(train_data_loader,model=model_patch_8)

Epoch 1/5 loss: 2.26
Epoch 2/5 loss: 2.24
Epoch 3/5 loss: 2.23
Epoch 4/5 loss: 2.22
Epoch 5/5 loss: 2.22


In [55]:
test(model=model_patch_8)

Test loss: 2.20
Test accuracy: 23.97%


## Training with patch size (16x16)

In [56]:
model_patch_16=VisionTransformer(image_size=(BATCH_SIZE,3, 32, 32), num_encoder_blocks=2, hidden_dimension=8, num_heads=2, out_dimension=10,patch_size=16)
train(train_data_loader,model=model_patch_16)

Epoch 1/5 loss: 2.25
Epoch 2/5 loss: 2.23
Epoch 3/5 loss: 2.22
Epoch 4/5 loss: 2.22
Epoch 5/5 loss: 2.20


In [57]:
test(model=model_patch_16)

Test loss: 2.19
Test accuracy: 25.48%


# Training with different number of attention heads

## Training with 4 attention heads

In [25]:
model_attention_4=VisionTransformer(image_size=(BATCH_SIZE,3, 32, 32), num_encoder_blocks=2, hidden_dimension=8, num_heads=4, out_dimension=10,patch_size=2)
train(train_data_loader,model=model_attention_4)

Epoch 1/5 loss: 2.26
Epoch 2/5 loss: 2.26
Epoch 3/5 loss: 2.26
Epoch 4/5 loss: 2.24
Epoch 5/5 loss: 2.21


In [26]:
test(model=model_attention_4)

Test loss: 2.19
Test accuracy: 25.10%


## Training with 8 attention heads

In [27]:
model_attention_8=VisionTransformer(image_size=(BATCH_SIZE,3, 32, 32), num_encoder_blocks=2, hidden_dimension=8, num_heads=8, out_dimension=10,patch_size=2)
train(train_data_loader,model=model_attention_8)

Epoch 1/5 loss: 2.25
Epoch 2/5 loss: 2.21
Epoch 3/5 loss: 2.19
Epoch 4/5 loss: 2.17
Epoch 5/5 loss: 2.16


In [28]:
test(model=model_attention_8)

Test loss: 2.15
Test accuracy: 30.02%


## Training with 12 attention heads

In [None]:
model_attention_12=VisionTransformer(image_size=(BATCH_SIZE,3, 32, 32), num_encoder_blocks=5, hidden_dimension=12, num_heads=12, out_dimension=10,patch_size=2)
train(train_data_loader,model=model_attention_12)

In [31]:
test(model=model_attention_12)

Test loss: 2.15
Test accuracy: 30.07%


In [None]:
from matplotlib import pyplot as plt

def visualize_attention_maps(model, test_dataset, num_classes):
    model.eval()
    for cls in range(num_classes):
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=True)
        count = 0
        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                for i in range(len(labels)):
                    if labels[i] == cls:
                        outputs, attn_weights = model(images[i].unsqueeze(0))  # Add batch dimension
                        # print(attn_weights)
                        # Visualize attention maps
                        for layer, attn_weight in enumerate(attn_weights):
                            plt.imshow(attn_weight[layer])
                            break
                        break
                    break
                break
                #             attn_weight = attn_weight.squeeze(0)  # Remove batch dimension
                #             num_heads = attn_weight.size(0)
                #             for head in range(num_heads):
                #                 plt.figure(figsize=(5, 5))
                #                 plt.imshow(attn_weight[layer][head].cpu(), cmap='hot', interpolation='nearest')
                #                 plt.title(f'Layer {layer + 1}, Head {head + 1}')
                #                 plt.colorbar()
                #                 plt.show()
                #         count += 1
                #         if count == 2:
                #             break
                # if count == 2:
                #     break



visualize_attention_maps(transformer_model, test_data, 10)