In [None]:
import torch
import torch.nn.functional as f
import matplotlib.pyplot as plt
from collections import OrderedDict

from torch import nn
from torch import tensor
import torch.optim as optim
import torchvision
from torchvision.transforms import Compose, Resize, ToTensor,transforms


### Multi-Layer Perceptron Class

In [None]:
class MultilayerPerceptron(nn.Module):
    """Feed-forward network with a single hidden layer.

    The layers are applied in the order
    Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Number of features produced by the first linear layer.
        out_dim: Size of the last dimension of the output.
        droupout_p: Drop probability shared by both dropout layers.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, droupout_p: float) -> None:
        super().__init__()

        # First stage: linear map, non-linearity, dropout.
        self.lin_1 = nn.Linear(in_dim, hidden_dim)
        self.act_1 = nn.GELU()
        self.droupout_1 = nn.Dropout(droupout_p)

        # Second stage: linear map followed by dropout.
        self.lin_2 = nn.Linear(hidden_dim, out_dim)
        self.droupout_2 = nn.Dropout(droupout_p)

    def forward(self, input: tensor) -> tensor:
        """Run ``input`` through both stages and return the result."""
        hidden = self.droupout_1(self.act_1(self.lin_1(input)))
        return self.droupout_2(self.lin_2(hidden))

### Tokenizer, Class Token, and Position Embedding Classes

In [None]:
class Tokenizer(nn.Module):
    """Turns an image into a sequence of patch tokens.

    A convolution whose kernel size equals its stride is applied to the
    image, so each output position is computed from one non-overlapping
    ``patch_size`` x ``patch_size`` patch.

    Args:
        token_dim: Number of features of every produced token.
        patch_size: Side length of a patch (used as kernel size and stride).
        in_channels: Number of channels of the input image. Defaults to 3,
            which was previously hard-coded.
    """

    def __init__(self, token_dim: int, patch_size: int, in_channels: int = 3) -> None:
        super().__init__()

        self.input_to_tokens = nn.Conv2d(
            in_channels=in_channels,
            out_channels=token_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Returns tokens in shape of (batch_size, n_token, token_dim)."""
        # Convolution output: (batch_size, token_dim, rows, cols).
        patches = self.input_to_tokens(input)
        # Merge the two spatial axes into a single token axis.
        tokens = patches.flatten(start_dim=-2)
        # Move the feature axis last: (batch_size, n_token, token_dim).
        return tokens.transpose(-2, -1)
    
    
class ClasstokenConcatenator(nn.Module):
    """Prepends a learnable class token to every token sequence in a batch.

    Args:
        token_dim: Number of features of the class token.
    """

    def __init__(self, token_dim: int) -> None:
        super().__init__()
        # One learnable vector, initialised to zeros, shared by all samples.
        self.class_token = nn.Parameter(torch.zeros(token_dim))

    def forward(self, input: tensor) -> tensor:
        """Return ``input`` with the class token inserted at index 0 of dim 1."""
        batch_size = len(input)
        # Broadcast the single vector to (batch_size, 1, token_dim) without copying.
        leading_token = self.class_token.expand(batch_size, 1, -1)
        return torch.cat([leading_token, input], dim=1)
    
    
class PositionEmbeddingAdder(nn.Module):
    """Adds a learnable position embedding to a sequence of tokens.

    Args:
        n_token: Number of rows of the embedding table.
        token_dim: Number of features of every embedding row.
    """

    def __init__(self, n_token: int, token_dim: int) -> None:
        super().__init__()
        # Learnable (n_token, token_dim) table, initialised to zeros.
        self.position_embedding = nn.Parameter(torch.zeros(n_token, token_dim))

    def forward(self, input: tensor) -> tensor:
        """Return ``input`` plus the embedding table (broadcast by PyTorch)."""
        return input + self.position_embedding
                    

### Attention Module

In [None]:
class QueriesKeyValuesExtractor(nn.Module):
    """Projects tokens to per-head queries, keys and values.

    A single bias-free linear layer produces ``3 * n_heads * head_dim``
    features per token, which are then split into queries, keys and values
    for every attention head.

    Args:
        token_dim: Number of features of every input token.
        head_dim: Number of features per attention head.
        n_heads: Number of attention heads.
    """

    def __init__(self, token_dim: int, head_dim: int, n_heads: int) -> None:
        super().__init__()

        self.head_dim = head_dim
        self.n_heads = n_heads
        queries_key_values_dim = 3 * self.n_heads * self.head_dim

        self.input_to_queries_key_values = nn.Linear(
            in_features=token_dim,
            out_features=queries_key_values_dim,
            bias=False,
        )

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split the projected tokens into queries, keys and values.

        Args:
            input: Tokens of shape (batch_size, n_token, token_dim).

        Returns:
            Three tensors (queries, keys, values), each of shape
            (batch_size, n_heads, n_token, head_dim).
        """
        batch_size, n_token, _ = input.shape

        # (batch_size, n_token, 3 * n_heads * head_dim)
        queries_key_values = self.input_to_queries_key_values(input)

        # Only the feature axis may be split by the reshape. Reshaping straight
        # to (batch, 3, heads, tokens, head_dim) would reinterpret memory and
        # mix features of different tokens, so the token axis is kept in place
        # here and moved afterwards with an explicit permute.
        queries_key_values = queries_key_values.reshape(
            batch_size, n_token, 3, self.n_heads, self.head_dim
        )
        # -> (3, batch_size, n_heads, n_token, head_dim)
        queries_key_values = queries_key_values.permute(2, 0, 3, 1, 4)

        queries, keys, values = queries_key_values.unbind(dim=0)

        return queries, keys, values
    
    
def get_attention(queries: tensor, keys: tensor, values: tensor) -> tensor:
    """Compute scaled dot-product attention.

    The query/key dot products are multiplied by ``head_dim ** -0.5``, where
    ``head_dim`` is the size of the last dimension of ``queries``, turned into
    weights with a softmax over the last dimension, and used to average
    ``values``.
    """
    head_dim = queries.size(-1)
    similarity = torch.matmul(queries, keys.transpose(-1, -2)) * head_dim ** (-0.5)

    # Each row of weights sums to one over the key axis.
    weights = f.softmax(similarity, dim=-1)

    return torch.matmul(weights, values)
        

In [None]:
class Multiheadselfattention(nn.Module):
    """Multi-head self-attention followed by an output projection and dropout.

    Args:
        token_dim: Number of features of every input and output token.
        head_dim: Number of features per attention head.
        n_heads: Number of attention heads.
        droupout_p: Drop probability applied to the projected output.
    """

    def __init__(self, token_dim: int, head_dim: int, n_heads: int, droupout_p: float) -> None:
        super().__init__()

        self.query_key_value_extractor = QueriesKeyValuesExtractor(
            token_dim=token_dim,
            head_dim=head_dim,
            n_heads=n_heads,
        )
        # Feature count once the outputs of all heads are placed side by side.
        self.concatenated_head_dim = n_heads * head_dim

        self.attention_to_output = nn.Linear(self.concatenated_head_dim, token_dim)
        self.output_dropout = nn.Dropout(droupout_p)

    def forward(self, input: tensor) -> tensor:
        """Attend over the tokens of ``input`` (batch_size, n_tokens, token_dim)."""
        batch_size, n_tokens, _ = input.shape

        queries, keys, values = self.query_key_value_extractor(input)
        per_head = get_attention(queries=queries, keys=keys, values=values)

        # Swap dims 1 and 2, then merge the trailing dims so every token
        # carries the concatenated features of all heads.
        merged = per_head.transpose(1, 2).reshape(
            batch_size, n_tokens, self.concatenated_head_dim
        )

        return self.output_dropout(self.attention_to_output(merged))
        
    

### Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Args:
        token_dim: Number of features of every token.
        multihead_attention_head_dim: Features per attention head.
        multihead_attention_n_heads: Number of attention heads.
        multilayer_perceptron_hidden_dim: Hidden width of the perceptron.
        dropout_p: Drop probability passed to both sub-modules.
    """

    def __init__(self, token_dim: int, multihead_attention_head_dim: int, multihead_attention_n_heads: int,
                 multilayer_perceptron_hidden_dim: int, dropout_p: float) -> None:
        super().__init__()

        # Attention sub-layer and the normalisation applied before it.
        self.layer_norm_1 = nn.LayerNorm(token_dim)
        self.multi_head_attention = Multiheadselfattention(
            token_dim=token_dim,
            head_dim=multihead_attention_head_dim,
            n_heads=multihead_attention_n_heads,
            droupout_p=dropout_p,
        )

        # Perceptron sub-layer and the normalisation applied before it.
        self.layer_norm_2 = nn.LayerNorm(token_dim)
        self.multilayer_perceptron = MultilayerPerceptron(
            in_dim=token_dim,
            hidden_dim=multilayer_perceptron_hidden_dim,
            out_dim=token_dim,
            droupout_p=dropout_p,
        )

    def forward(self, input: tensor) -> tensor:
        """Runs the input through transformer block"""
        # Sub-layer 1: normalise, attend, add the skip connection.
        attended = self.multi_head_attention(self.layer_norm_1(input)) + input

        # Sub-layer 2: normalise, apply the perceptron, add the skip connection.
        return self.multilayer_perceptron(self.layer_norm_2(attended)) + attended

In [None]:
class Transformer(nn.Module):
    """Stack of ``n_layers`` transformer blocks applied one after another.

    The blocks are registered as ``transformer_block_1`` ...
    ``transformer_block_<n_layers>`` inside an ``nn.Sequential``.

    Args:
        n_layers: Number of blocks in the stack.
        token_dim: Number of features of every token.
        multihead_attention_head_dim: Features per attention head.
        multihead_attention_n_heads: Number of attention heads.
        mulitlayer_perceptron_hidden_dim: Hidden width of each block's perceptron.
        dropout_p: Drop probability passed to every block.
    """

    def __init__(self, n_layers: int, token_dim: int, multihead_attention_head_dim: int,
                 multihead_attention_n_heads: int, mulitlayer_perceptron_hidden_dim : int, dropout_p : float) -> None:
        super().__init__()

        # Names are 1-based so they match the layer count a reader expects.
        named_blocks = OrderedDict(
            (
                f'transformer_block_{index}',
                TransformerBlock(
                    token_dim=token_dim,
                    multihead_attention_head_dim=multihead_attention_head_dim,
                    multihead_attention_n_heads=multihead_attention_n_heads,
                    multilayer_perceptron_hidden_dim=mulitlayer_perceptron_hidden_dim,
                    dropout_p=dropout_p,
                ),
            )
            for index in range(1, n_layers + 1)
        )
        self.transformer_blocks = nn.Sequential(named_blocks)

    def forward(self, input: tensor) -> tensor:
        """Pass ``input`` through every block in order."""
        return self.transformer_blocks(input)
            

In [None]:
class visionTransformer(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch tokenizer -> class token -> position embedding ->
    transformer stack -> linear head applied to the class token.

    Args:
        token_dim: Number of features of every token.
        patch_size: Side length of the square patches.
        image_size: Side length of the (square) input images.
        n_layers: Number of transformer blocks.
        multihead_attention_head_dim: Features per attention head.
        multihead_attention_n_heads: Number of attention heads.
        multilayer_perceptron_hidden_dim: Hidden width of each block's perceptron.
        dropout_p: Drop probability used inside the transformer blocks.
        n_classes: Number of output logits.
    """

    def __init__(self, token_dim: int, patch_size: int, image_size: int, n_layers: int, multihead_attention_head_dim : int
                 ,multihead_attention_n_heads: int, multilayer_perceptron_hidden_dim: int, dropout_p: float, n_classes: int) -> None:
        super().__init__()

        self.tokenizer = Tokenizer(token_dim=token_dim, patch_size=patch_size)
        self.class_token_concatenator = ClasstokenConcatenator(token_dim=token_dim)

        # One token per patch of the square grid, plus the class token.
        patches_per_side = image_size // patch_size
        n_tokens = patches_per_side ** 2 + 1
        self.position_embedding_adder = PositionEmbeddingAdder(n_token=n_tokens, token_dim=token_dim)

        self.transformer = Transformer(
            n_layers=n_layers,
            token_dim=token_dim,
            multihead_attention_head_dim=multihead_attention_head_dim,
            multihead_attention_n_heads=multihead_attention_n_heads,
            mulitlayer_perceptron_hidden_dim=multilayer_perceptron_hidden_dim,
            dropout_p=dropout_p,
        )

        self.head = nn.Linear(token_dim, n_classes)

    def forward(self, input: tensor) -> tensor:
        """Return class logits of shape (batch_size, n_classes) for ``input``."""
        tokens = self.tokenizer(input)
        tokens = self.class_token_concatenator(tokens)
        tokens = self.position_embedding_adder(tokens)
        encoded = self.transformer(tokens)

        # Index 0 along dim 1 is where the class token was concatenated.
        class_features = encoded[:, 0]
        return self.head(class_features)
        
        
        

In [None]:
# Model used for the training run below: 32x32 inputs, 4x4 patches, 10 classes.
model = visionTransformer(
    token_dim=48,
    patch_size=4,
    image_size=32,
    n_layers=4,
    multihead_attention_head_dim=48,
    multihead_attention_n_heads=8,
    multilayer_perceptron_hidden_dim=512,
    dropout_p=0.1,
    n_classes=10,
)


In [None]:
# Preprocessing shared by both splits: resize, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Training split: reshuffled on every pass.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Test split: iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)


# Training of the Vision Transformer

# Alternative (larger) configuration kept for reference:
# model = visionTransformer(token_dim=192,patch_size=8,image_size=32,n_layers=4,multihead_attention_head_dim=192,multihead_attention_n_heads=8
#                           ,multilayer_perceptron_hidden_dim=512,dropout_p=0.2,n_classes=10)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

criterion = nn.CrossEntropyLoss()


num_epochs = 10

for epoch in range(num_epochs):
    # The evaluation pass at the end of every epoch switches the model to
    # eval mode; switch back here so dropout stays active while training.
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        # Predictions are only used for bookkeeping, so detach from the graph.
        _, predicted = torch.max(outputs.detach(), 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        running_loss += loss.item()
        # Report the mean loss over the last 200 mini-batches.
        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
    print('Accuracy of the network on the train images: %d %%' % (100 * correct / total))

    # Evaluation: disable dropout and skip autograd bookkeeping.
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

            



Files already downloaded and verified
Files already downloaded and verified
[1,   200] loss: 1.768
Accuracy of the network on the train images: 40 %
Accuracy of the network on the 10000 test images: 47 %
[2,   200] loss: 1.326
Accuracy of the network on the train images: 53 %
Accuracy of the network on the 10000 test images: 56 %
[3,   200] loss: 1.117
Accuracy of the network on the train images: 60 %
Accuracy of the network on the 10000 test images: 59 %
[4,   200] loss: 0.974
Accuracy of the network on the train images: 64 %
Accuracy of the network on the 10000 test images: 60 %
[5,   200] loss: 0.857
Accuracy of the network on the train images: 68 %
Accuracy of the network on the 10000 test images: 61 %
[6,   200] loss: 0.753
Accuracy of the network on the train images: 72 %
Accuracy of the network on the 10000 test images: 62 %
[7,   200] loss: 0.654
Accuracy of the network on the train images: 75 %
Accuracy of the network on the 10000 test images: 61 %
[8,   200] loss: 0.541
Accur