Import packages needed.

In [1]:
import einops
from tqdm.notebook import tqdm

from torchsummary import summary

import torch
from torch import nn
import torchvision
import torch.optim as optim
from torchvision.transforms import Compose, Resize, Normalize, RandomHorizontalFlip, RandomCrop

In [36]:
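# Optional one-time setup: enable the ipywidgets notebook extension so that
# tqdm.notebook progress bars render. Uncomment and run once if they do not show up.
# !jupyter nbextension enable --py widgetsnbextension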

Set hyperparameters and specify device.

In [2]:
# Run on the first CUDA GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# --- Model architecture ---
size = 224           # image side length in pixels (matches the 224x224 test input)
patch_size = 16      # side length of each square patch cut from the image
n_channels = 3       # channels per input image
latent_size = 768    # width of every token embedding
num_heads = 12       # attention heads per encoder block
num_encoders = 12    # number of stacked encoder blocks
dropout = 0.1        # dropout probability used inside the encoder blocks
num_classes = 10     # number of logits produced by the classification head

# --- Training settings ---
epochs = 10
# NOTE(review): this was spelled `10e-3`, which is 0.01 (not 0.001). The value is
# kept as-is; confirm that 1e-2 rather than 1e-3 is the intended learning rate.
base_lr = 1e-2
weight_decay = 0.03
batch_size = 8

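Implementation of the input embedding: the image is split into patches, each patch is linearly projected, a class token is prepended, and a positional embedding is added.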

In [38]:
class InputEmbedding(nn.Module):
    """Convert a batch of images into a sequence of token embeddings.

    Each image is cut into non-overlapping ``patch_size x patch_size`` patches,
    every flattened patch is projected to ``latent_size`` features, a learnable
    class token is prepended, and a learnable positional embedding (one vector
    per token position) is added.

    Args:
        patch_size: Side length of a square patch.
        n_channels: Number of channels of the input images.
        device: Stored on the instance for reference. Parameters are placed by
            the usual ``module.to(...)`` call, not by this argument.
        latent_size: Size of each output token embedding.
        batch_size: Kept for backward compatibility and stored on the instance;
            the module no longer depends on it and accepts any batch size.
        image_size: Side length of the (square) input images. It fixes the
            number of patches and therefore the size of the positional table.
    """

    def __init__(self, patch_size = patch_size, n_channels = n_channels, device = device, latent_size = latent_size, batch_size = batch_size, image_size = size):
        super(InputEmbedding, self).__init__()
        self.latent_size = latent_size
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.device = device
        self.batch_size = batch_size
        self.image_size = image_size
        self.input_size = self.patch_size*self.patch_size*self.n_channels
        self.num_patches = (self.image_size // self.patch_size) ** 2

        # Linear projection of the flattened patches.
        self.linearProjection = nn.Linear(self.input_size, self.latent_size)

        # Class token: a single learnable vector shared by every sample.
        # It must stay an nn.Parameter: calling `.to(device)` on it here would,
        # on a GPU, yield a plain tensor that is neither registered nor trained.
        self.class_token = nn.Parameter(torch.randn(1, 1, self.latent_size))

        # Positional embedding: one learnable vector per token position
        # (all patches plus the class token), shared across the batch.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, self.num_patches + 1, self.latent_size))

    def forward(self, input_data):
        """Embed ``input_data`` of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, latent_size)``.

        Raises:
            ValueError: If the number of patches in the input does not match
                the ``image_size`` / ``patch_size`` the module was built with.
        """
        # Follow the module's parameters, wherever `.to(...)` has put them.
        input_data = input_data.to(self.linearProjection.weight.device)

        # Patchify the input image: (b, c, H, W) -> (b, num_patches, patch_pixels * c).
        patches = einops.rearrange(
            input_data, "b c (h h1) (w w1) -> b (h w) (h1 w1 c)", h1 = self.patch_size, w1 = self.patch_size)

        tokens = self.linearProjection(patches)
        b, n, _ = tokens.shape

        if n != self.num_patches:
            raise ValueError(
                f"Expected {self.num_patches} patches per image "
                f"(image_size={self.image_size}, patch_size={self.patch_size}), got {n}.")

        # Prepend the class token along the token dimension (dim 1); `expand`
        # broadcasts the single token to the actual batch size without copying.
        class_tokens = self.class_token.expand(b, -1, -1)
        tokens = torch.cat((class_tokens, tokens), dim = 1)

        # The positional table broadcasts over the batch dimension.
        return tokens + self.pos_embedding

In [39]:
# Smoke test for the embedding: build a random batch whose shape comes from the
# configuration cell (instead of repeating the numbers) and embed it.
test_input = torch.randn((batch_size, n_channels, size, size))
test_class = InputEmbedding().to(device)
embed_test = test_class(test_input)

# One token per patch plus the class token, each of width latent_size.
assert embed_test.shape == (batch_size, (size // patch_size) ** 2 + 1, latent_size)

Implement the encoder block.

In [40]:
class EncoderBlock(nn.Module):
    """One pre-norm Transformer encoder block.

    Applies LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> MLP (Linear, GELU, Dropout, Linear, Dropout) -> residual add.

    Args:
        latent_size: Size of each token embedding (input and output).
        num_heads: Number of attention heads.
        device: Stored on the instance for reference only.
        dropout: Dropout probability used in the attention layer and the MLP.
    """

    def __init__(self, latent_size = latent_size, num_heads = num_heads, device = device, dropout = dropout):
        super(EncoderBlock, self).__init__()

        self.latent_size = latent_size
        self.num_heads = num_heads
        self.device = device
        self.dropout = dropout

        # Separate normalisation layers for the two sub-layers so that each one
        # learns its own scale and shift.
        self.norm = nn.LayerNorm(self.latent_size)      # before attention
        self.norm_mlp = nn.LayerNorm(self.latent_size)  # before the MLP

        # batch_first=True: inputs are (batch, tokens, features). With the
        # default (False) the layer would treat the batch axis as the sequence
        # and attend across different samples instead of across tokens.
        self.multihead = nn.MultiheadAttention(
            self.latent_size, self.num_heads, dropout = self.dropout, batch_first = True)

        self.enc_MLP = nn.Sequential(
            nn.Linear(self.latent_size, self.latent_size*4),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_size*4,self.latent_size),
            nn.Dropout(self.dropout)
        )

    def forward(self, embedded_patches):
        """Transform tokens of shape ``(batch, tokens, latent_size)``; the shape is preserved."""
        firstnorm_out = self.norm(embedded_patches)
        # Only the attended values are used, so skip computing the weight matrix.
        attention_out = self.multihead(
            firstnorm_out, firstnorm_out, firstnorm_out, need_weights = False)[0]

        # First residual connection.
        first_added = attention_out + embedded_patches

        secondnorm_out = self.norm_mlp(first_added)
        ff_output = self.enc_MLP(secondnorm_out)

        # Second residual connection.
        return ff_output + first_added

In [41]:
# Smoke test for a single encoder block on the embedded test batch.
test_encoder = EncoderBlock().to(device)
encoder_out = test_encoder(embed_test)

# An encoder block must return a tensor of the same shape as its input.
assert encoder_out.shape == embed_test.shape
encoder_out

tensor([[[ 5.2930e-01,  1.8278e-01,  1.6022e+00,  ..., -8.8799e-01,
           1.5994e+00, -2.6420e+00],
         [ 3.8978e-01,  1.0842e+00,  5.1064e-01,  ..., -3.4221e-01,
          -2.1966e-01, -1.1527e+00],
         [ 9.1445e-01,  5.4362e-01, -4.8330e-01,  ...,  5.4173e-01,
          -6.3929e-01, -1.8867e+00],
         ...,
         [ 8.4386e-02,  8.7726e-01,  5.5275e-01,  ..., -1.4566e-01,
          -8.4374e-01, -2.2914e+00],
         [-3.1912e-01, -5.6400e-01,  3.6490e-01,  ...,  2.0771e-01,
          -7.0197e-02, -2.5003e+00],
         [ 3.0916e-01,  5.8354e-01, -4.3770e-01,  ..., -4.7897e-01,
          -4.1521e-01, -1.6929e+00]],

        [[-9.4845e-01, -5.4261e-01,  1.1081e+00,  ...,  7.3119e-02,
           3.7971e-01,  6.8748e-01],
         [-2.1748e+00, -1.2811e+00, -6.4552e-01,  ...,  1.1636e+00,
          -2.0746e+00,  4.2849e-01],
         [-1.1954e+00,  5.2933e-01, -1.0266e+00,  ...,  1.2726e+00,
          -1.8763e+00,  1.7916e+00],
         ...,
         [-1.1037e+00, -1

Put everything together: embed the input, pass it through the stack of encoder blocks, and classify from the class token.

In [42]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: ``InputEmbedding`` -> ``num_encoders`` x ``EncoderBlock`` ->
    take the first token (the class token) -> MLP head producing
    ``num_classes`` logits.

    Args:
        num_encoders: Number of stacked encoder blocks.
        latent_size: Token embedding size, forwarded to the embedding, the
            encoder blocks and the head so that all of them agree.
        device: Forwarded to the sub-modules and stored on the instance.
        num_classes: Number of output logits.
        dropout: Dropout probability forwarded to the encoder blocks.
    """

    def __init__(self, num_encoders=num_encoders, latent_size=latent_size, device=device, num_classes=num_classes, dropout=dropout):
        super(ViT, self).__init__()

        self.num_encoders = num_encoders
        self.latent_size = latent_size
        self.device = device
        self.num_classes = num_classes
        self.dropout = dropout

        # Forward the constructor arguments: building the sub-modules from
        # their own defaults would ignore a non-default latent_size/dropout and
        # produce a shape mismatch at the head.
        self.embedding = InputEmbedding(
            device=self.device, latent_size=self.latent_size)

        # Create the stack of encoders. The number of attention heads is not a
        # ViT argument, so each block uses EncoderBlock's own default.
        self.encStack = nn.ModuleList([
            EncoderBlock(
                latent_size=self.latent_size, device=self.device, dropout=self.dropout)
            for _ in range(self.num_encoders)
        ])

        # NOTE(review): there is no activation between the two Linear layers, so
        # the head is equivalent to a single affine map after the LayerNorm.
        # Left unchanged here; consider inserting a non-linearity.
        self.MLP_head = nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(self.latent_size, self.latent_size),
            nn.Linear(self.latent_size, self.num_classes)
        )

    def forward(self, test_input):
        """Return class logits of shape ``(batch, num_classes)`` for a batch of images."""
        enc_output = self.embedding(test_input)

        for enc_layer in self.encStack:
            enc_output = enc_layer(enc_output)

        # The class token sits at position 0 of the token dimension.
        cls_token_embed = enc_output[:, 0]

        return self.MLP_head(cls_token_embed)

Test the ViT model on the random input batch.

In [43]:
# Smoke test for the full model: build it, move it to the selected device and
# run the random test batch through it.
model = ViT()
model = model.to(device)  # Module.to returns the module itself

vit_output = model(test_input)

# Show the raw logits, then their shape.
print(vit_output)
print(vit_output.shape)

tensor([[ 0.0624, -0.4705,  0.1533,  0.2916, -0.2076, -0.5198, -0.0619,  0.6661,
          0.3178, -0.2557],
        [ 0.2574, -0.1355, -0.0260,  0.3431, -0.1586, -0.3548,  0.4580,  0.6288,
          0.1264, -0.0667],
        [ 0.1556,  0.0872, -0.0180, -0.0119, -0.6361, -0.5154,  0.3149, -0.1762,
         -0.0030,  0.4549],
        [-0.1011, -0.2652,  0.1837, -0.4612,  0.0316, -0.1454,  0.1524,  0.3946,
         -0.2170,  0.7480],
        [-0.0068, -0.4800, -0.2421, -0.4517, -0.0962,  0.1658,  0.4228, -0.2135,
          0.0531,  0.2083],
        [ 0.2849, -0.0120,  0.1465, -0.6187, -0.3015, -0.1000, -0.1478, -0.2426,
          0.4067, -0.2238],
        [ 0.0549, -0.0337,  0.1273, -0.0380,  0.3378, -0.4076,  0.1464, -0.1642,
          0.6452,  0.3868],
        [-0.0370,  0.1946, -0.3381, -0.2465, -0.4302, -0.3944, -0.0939, -0.5192,
          0.0309,  0.0898]], grad_fn=<AddmmBackward0>)
torch.Size([8, 10])
