<a href="https://colab.research.google.com/github/TapasKumarDutta1/Transformer/blob/main/vision_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [37]:
from torch import nn
import torch
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every patch to an ``in_channels``-wide vector (the token width
    equals the number of input channels), and the resulting spatial grid is
    flattened into a sequence of tokens.

    Args:
        patch_size: side length of each square patch.
        in_channels: number of input channels; also the output token width.
    """

    def __init__(self, patch_size, in_channels):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map a (B, C, H, W) image batch to (B, H' * W', C) patch tokens."""
        feature_map = self.proj(x)  # (B, C, H', W')
        batch, channels, rows, cols = feature_map.shape
        tokens = feature_map.reshape(batch, channels, rows * cols)  # (B, C, H'*W')
        return tokens.permute(0, 2, 1)  # (B, H'*W', C)

class Concatenate_CLS_Token(nn.Module):
    """Prepend one learnable classification (CLS) token to every sequence.

    Args:
        N: accepted for signature symmetry with ``Add_Positional_Embedding``;
            not used here.
        embed_dim: token width of the CLS vector.
    """

    def __init__(self, N, embed_dim):
        super().__init__()
        self.CLS = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        """Return ``x`` of shape (B, L, embed_dim) as (B, L + 1, embed_dim) with CLS first."""
        batch_size, _, _ = x.shape
        # expand() gives every sample a view of the same learnable token.
        cls_tokens = self.CLS.expand(batch_size, -1, -1)
        return torch.cat((cls_tokens, x), dim=1)


class Add_Positional_Embedding(nn.Module):
    """Add a learnable, position-specific embedding to a token sequence.

    Args:
        N: number of patch tokens. One extra position is reserved for a CLS
            token prepended in front of the patches, so the table holds
            ``N + 1`` learnable vectors.
        embed_dim: token width.
    """

    def __init__(self, N, embed_dim):
        super().__init__()
        self.N = N
        # One learnable vector per position. The previous (1, 1, embed_dim)
        # table added the *same* vector to every token, so it carried no
        # positional information at all.
        self.positional = nn.Parameter(torch.zeros(1, N + 1, int(embed_dim)))

    def forward(self, x):
        """Return ``x`` plus the embeddings of its first ``x.shape[1]`` positions.

        Args:
            x: token sequence of shape (B, L, embed_dim) with L <= N + 1.

        Raises:
            ValueError: if ``x`` is not 3-D or L exceeds the table size.
        """
        _, seq_len, _ = x.shape
        max_len = self.positional.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_len} positions "
                "this embedding was built for"
            )
        # Broadcasting over the batch axis replaces an explicit repeat().
        return x + self.positional[:, :seq_len]

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        head: number of attention heads.
        in_dim: width of the input (and output) tokens.
        dim: per-head query/key/value width. Heads are concatenated to
            ``head * dim`` features and projected back to ``in_dim``.
    """

    def __init__(self, head, in_dim, dim):
        super().__init__()
        self.head = head
        self.dim = dim
        inner_width = dim * head
        self.query_projection = nn.Linear(in_dim, inner_width)
        self.key_projection = nn.Linear(in_dim, inner_width)
        self.value_projection = nn.Linear(in_dim, inner_width)
        self.out = nn.Linear(inner_width, in_dim)

    def _split_heads(self, projected, batch, tokens):
        """Reshape (B, N, H*D) projections into per-head (B, H, N, D) tensors."""
        return projected.view(batch, tokens, self.head, self.dim).transpose(1, 2)

    def forward(self, x):
        """Self-attend over the token axis of ``x`` (B, N, in_dim) -> (B, N, in_dim)."""
        batch, tokens, _ = x.shape
        q = self._split_heads(self.query_projection(x), batch, tokens)
        k = self._split_heads(self.key_projection(x), batch, tokens)
        v = self._split_heads(self.value_projection(x), batch, tokens)

        # Scores are scaled by sqrt(per-head width) before the softmax over keys.
        scale = torch.tensor(self.dim).sqrt()
        weights = torch.softmax(q @ k.transpose(-2, -1) / scale, dim=-1)  # (B, H, N, N)

        context = (weights @ v).transpose(1, 2)  # (B, N, H, D)
        merged = context.reshape(batch, tokens, self.head * self.dim)
        return self.out(merged)

class vision_transformer(nn.Module):
    """Minimal Vision Transformer (ViT) image classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch becomes an ``in_channel``-wide token, a learnable CLS
    token is prepended, a learnable positional embedding is added, and the
    sequence passes through ``depth`` pre-norm residual blocks
    (``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``). The final CLS token
    is fed to a linear classification head.

    The token (model) width equals ``in_channel``; ``embed_dim`` is the
    per-head width used inside ``MultiHeadAttention``.

    Args:
        H, W: input image height / width in pixels.
        embed_dim: per-head query/key/value width of each attention layer.
        MLP_size: hidden width of each MLP sub-layer.
        num_class: number of output logits.
        patch_size: side length of each square patch (conv kernel and stride).
        num_head: number of attention heads.
        batch_size: unused; kept so existing call sites keep working.
        in_channel: number of input image channels (also the token width).
        depth: number of transformer blocks (default 12, as before).

    NOTE(review): the reference ViT also applies a final LayerNorm before the
    head; this model feeds the CLS token to the head directly.
    """

    def __init__(
        self,
        H=32,
        W=32,
        embed_dim=4,
        MLP_size=2,
        num_class=2,
        patch_size=8,
        num_head=2,
        batch_size=1,
        in_channel=3,
        depth=12,
    ):
        super().__init__()
        # A stride == kernel conv yields floor(H/p) * floor(W/p) patches; the old
        # int(H*W/p**2) disagreed with that whenever H or W was not a multiple of p.
        N = (H // patch_size) * (W // patch_size)
        self.depth = depth
        self.preprocess = nn.Sequential(
            PatchEmbedding(patch_size, in_channel),
            Concatenate_CLS_Token(N, in_channel),
            Add_Positional_Embedding(N, in_channel),
        )
        # Sub-layers are registered under the historical names transformer{i} /
        # mlp{i}, in the same order as before, so state_dict key names match.
        # LayerNorm now normalises each token over its channels (standard ViT
        # pre-norm) instead of jointly over [N+1, C]; its weights are therefore
        # (in_channel,) and old checkpoints' norm weights will not load as-is.
        for i in range(1, depth + 1):
            attention = nn.Sequential(
                nn.LayerNorm(in_channel),
                MultiHeadAttention(num_head, in_channel, embed_dim),
            )
            mlp = nn.Sequential(
                nn.LayerNorm(in_channel),
                nn.Linear(in_channel, MLP_size),
                nn.GELU(),
                nn.Linear(MLP_size, in_channel),
            )
            setattr(self, f"transformer{i}", attention)
            setattr(self, f"mlp{i}", mlp)

        self.head = nn.Linear(in_channel, num_class)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch of shape (B, in_channel, H, W).

        Returns:
            (B, num_class) logits computed from the CLS token.
        """
        x = self.preprocess(x)
        for i in range(1, self.depth + 1):
            # Look sub-layers up by name at call time so replicas/replaced
            # attributes are always the ones actually used.
            x = x + getattr(self, f"transformer{i}")(x)
            x = x + getattr(self, f"mlp{i}")(x)
        return self.head(x[:, 0, :])
