# ViT Classifier

In [1]:
import torch
from torch import nn

In [25]:
# Scratch arithmetic: 28 squared (784) minus 16.
28 ** 2 - 16

768

In [15]:
# --- Input geometry -------------------------------------------------------
IMG_SIZE = 28
IN_CHANNELS = 1
PATCH_SIZE = 4
NUM_CLASSES = 10

# --- Derived token geometry -----------------------------------------------
# Flattened patch dimension: the number of elements in one patch matrix
# (PATCH_SIZE x PATCH_SIZE pixels for every input channel).
EMBED_DIM = PATCH_SIZE * PATCH_SIZE * IN_CHANNELS
# Patches tile the image on a square grid of IMG_SIZE // PATCH_SIZE cells per side.
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) * (IMG_SIZE // PATCH_SIZE)

# --- Encoder --------------------------------------------------------------
NUM_ENCODERS = 4
NUM_HEADS = 8
HIDDEN_DIM = 768  # hidden dimension of MLP head
ACTIVATION = "gelu"
DROPOUT = 0.001

# --- Optimizer ------------------------------------------------------------
LEARNING_RATE = 1e-4
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [98]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a token sequence for a transformer encoder.

    A strided convolution projects every ``patch_size`` x ``patch_size`` patch
    to an ``embed_dim`` vector, one learnable class token is prepended,
    learnable position embeddings are added and dropout is applied.

    Args:
        embed_dim: Width of every token (number of conv output channels).
        patch_size: Side of a square patch; used as conv kernel size and stride.
        num_patches: Number of patches the conv yields per image. The position
            table has ``num_patches + 1`` rows (patches plus the class token).
        dropout: Dropout probability applied to the final token sequence.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()

        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                # stride == kernel size gives non-overlapping patches;
                # a smaller stride would make them overlap
                stride=patch_size,
            ),
            # keep batch and channel dims, merge the spatial grid: (B, E, H', W') -> (B, E, H'*W')
            nn.Flatten(2),
        )

        # Exactly ONE learnable class token per image. The token axis must not
        # depend on in_channels: with (1, in_channels, embed_dim) any
        # multi-channel input would prepend in_channels tokens and no longer
        # match the (num_patches + 1)-row position table below.
        self.cls_token = nn.Parameter(torch.randn((1, 1, embed_dim)), requires_grad=True)

        # learnable position embeddings, one row per patch plus one for the class token
        self.position_embeddings = nn.Parameter(torch.randn((1, num_patches + 1, embed_dim)), requires_grad=True)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed ``x`` (B, in_channels, H, W) into (B, num_patches + 1, embed_dim)."""
        # broadcast the single class token across the batch without copying
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # (B, E, N) -> (B, N, E): one row per patch
        x = self.patcher(x).permute(0, 2, 1)
        # shape trace of the patch tokens (before the class token is prepended)
        print(x.shape)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.position_embeddings
        x = self.dropout(x)

        return x
        

In [28]:
# Smoke test: embed a random batch of ten 1x28x28 images and inspect the output shape.
model = PatchEmbedding(
    embed_dim=EMBED_DIM,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    dropout=DROPOUT,
    in_channels=IN_CHANNELS,
)
model(torch.randn(10, 1, 28, 28)).shape

torch.Size([10, 49, 16])


torch.Size([10, 50, 16])

In [17]:
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding -> encoder stack -> linear head.

    Args:
        num_patches: Patch count per image, forwarded to ``PatchEmbedding``.
        img_size: Accepted for interface compatibility; not referenced here.
        num_classes: Number of output logits.
        patch_size: Patch side length, forwarded to ``PatchEmbedding``.
        embed_dim: Token width; also the encoder ``d_model``.
        num_encoders: Number of stacked encoder layers.
        num_heads: Attention heads per encoder layer.
        hidden_dim: Accepted for interface compatibility; not referenced here
            (the encoder layers keep their default feed-forward width).
        dropout: Dropout probability for the embedding and the encoder layers.
        activation: Activation name passed to the encoder layers.
        in_channels: Channels of the input images.
    """

    def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):
        super().__init__()

        self.embeddings_block = PatchEmbedding(
            embed_dim=embed_dim,
            patch_size=patch_size,
            num_patches=num_patches,
            dropout=dropout,
            in_channels=in_channels,
        )

        # Batch-first, pre-norm layer used as the template for the whole stack.
        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(layer_template, num_layers=num_encoders)

        head_norm = nn.LayerNorm(normalized_shape=embed_dim)
        head_proj = nn.Linear(in_features=embed_dim, out_features=num_classes)
        self.mlp_head = nn.Sequential(head_norm, head_proj)

    def forward(self, x):
        """Return class logits for an image batch, printing the shape after each stage."""
        tokens = self.embeddings_block(x)
        print("Patch embeddings: ", tokens.shape)

        encoded = self.encoder_blocks(tokens)
        print("Encode output: ", encoded.shape)

        # Only the first token of each sequence feeds the classification head.
        first_token = encoded[:, 0, :]
        logits = self.mlp_head(first_token)
        print("MLP Head output: ", logits.shape)

        return logits

In [18]:

# Build the classifier from the notebook hyper-parameters and push one random batch through it.
model = ViT(
    num_patches=NUM_PATCHES,
    img_size=IMG_SIZE,
    num_classes=NUM_CLASSES,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    num_encoders=NUM_ENCODERS,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    activation=ACTIVATION,
    in_channels=IN_CHANNELS,
).to(device)
x = torch.randn(512, 1, 28, 28).to(device)
print(model(x).shape)  # BATCH_SIZE X NUM_CLASSES



Patch embeddings:  torch.Size([512, 50, 16])
Encode output:  torch.Size([512, 50, 16])
MLP Head output:  torch.Size([512, 10])
torch.Size([512, 10])


In [21]:
# Total number of elements across all model parameters, "_" as thousands separator.
print(f"{sum(weights.numel() for weights in model.parameters()):_}")

276_298


# Vision Transformer for segmentation

In [92]:
class PatchEmbedding_raw(nn.Module):
    """Regroup raw pixel values into a token sequence and add position embeddings.

    No projection is learned. The input is moved to channels-last order,
    flattened per image, and cut into ``num_patches`` consecutive chunks of
    ``in_channels * patch_size**2`` values each. The position table is the
    only trainable parameter.

    Args:
        in_channels: Channels of the input images.
        patch_size: Sets the token width together with ``in_channels``.
        num_patches: Number of tokens produced per image.
    """

    def __init__(self, in_channels, patch_size, num_patches):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.num_patches = num_patches

        # One learnable row per token, broadcast over the batch in forward().
        token_width = in_channels * patch_size**2
        self.position_embeddings = nn.Parameter(
            torch.randn((1, num_patches, token_width)),
            requires_grad=True,
        )

    def forward(self, x):
        """Map a 4-D batch (B, C, H, W) to tokens of shape (B, num_patches, C * patch_size**2)."""
        batch, _, _, _ = x.shape
        token_width = self.in_channels * self.patch_size**2

        # (B, C, H, W) -> (B, H, W, C): the channel values of a pixel become adjacent.
        channels_last = x.permute(0, 2, 3, 1)
        per_image = channels_last.flatten(1)
        tokens = per_image.reshape(batch, self.num_patches, token_width)

        return tokens + self.position_embeddings

In [95]:
# Shape check on an integer ramp: 2 images, 3 channels, 4x4 pixels -> 4 tokens of 12 values.
model = PatchEmbedding_raw(in_channels=3, patch_size=2, num_patches=4)
model(torch.arange(4 * 4 * 3 * 2).reshape(2, 3, 4, 4)).shape

torch.Size([2, 4, 12])

In [99]:
# Pixel linearisation demo on an integer ramp shaped (2, 3, 3, 3):
# move channels last, then regroup each image into 9 tokens of 3 values.
(torch.arange(27 * 2)
 .reshape(2, 3, 3, 3)
 .permute(0, 2, 3, 1)
 .reshape(2, -1, 3)  # batch, num tokens, items in each
)

tensor([[[ 0,  9, 18],
         [ 1, 10, 19],
         [ 2, 11, 20],
         [ 3, 12, 21],
         [ 4, 13, 22],
         [ 5, 14, 23],
         [ 6, 15, 24],
         [ 7, 16, 25],
         [ 8, 17, 26]],

        [[27, 36, 45],
         [28, 37, 46],
         [29, 38, 47],
         [30, 39, 48],
         [31, 40, 49],
         [32, 41, 50],
         [33, 42, 51],
         [34, 43, 52],
         [35, 44, 53]]])

In [106]:
# 50x50 RGB input cut into (50 // 10) ** 2 = 25 tokens of 3 * 10 * 10 = 300 values.
model = PatchEmbedding_raw(3, 10, (50 // 10) ** 2)

model(torch.randn((1, 3, 50, 50))).shape

torch.Size([1, 25, 300])

In [135]:
class TrasnformerSegmenter(nn.Module):
    """Segment images with a transformer encoder followed by an image-space head.

    Pipeline: ``patcher`` turns the image batch into tokens, the encoder stack
    processes them, the tokens are laid back out as an image batch and
    ``segmentation_head`` produces the final map.

    Args:
        patcher: Callable mapping (B, C, H, W) to (B, tokens, embed_dim). It
            must keep the element count and emit values in channels-last
            raster order, as ``PatchEmbedding_raw`` does; the inverse in
            ``forward`` relies on that layout.
        segmentation_head: Module applied to the re-assembled (B, C, H, W) batch.
        embed_dim: Token width; the encoder ``d_model``.
        num_encoders: Number of stacked encoder layers.
        num_heads: Attention heads per encoder layer.
        dropout: Dropout probability inside the encoder layers.
        activation: Activation name passed to the encoder layers.
    """

    def __init__(self, patcher, segmentation_head, embed_dim, num_encoders, num_heads, dropout, activation):
        super().__init__()

        self.patcher = patcher

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation, batch_first=True)

        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.segmentation_head = segmentation_head

    def forward(self, x):
        """Run the pipeline on a (B, C, H, W) batch, printing token shapes along the way."""
        batch, channels, height, width = x.shape

        x = self.patcher(x)
        print(x.shape)

        x = self.encoder_blocks(x)
        print(x.shape)

        # imageify: the tokens hold pixels in (B, H, W, C) order, so restore
        # that layout first and only then move channels back to dim 1. A bare
        # reshape to (B, C, H, W) would reinterpret the buffer and scramble
        # channels and pixel positions before the convolutional head.
        x = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

        x = self.segmentation_head(x)

        return x

In [136]:
# Segmenter for 50x50 RGB inputs: 625 raw tokens of 12 values, two 4-head
# encoder layers, then a small conv head that ends in a single output channel.
model = TrasnformerSegmenter(
    patcher=PatchEmbedding_raw(in_channels=3, patch_size=2, num_patches=(50 // 2) ** 2),
    segmentation_head=nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
    ),
    embed_dim=3 * (2 * 2),
    num_encoders=2,
    num_heads=4,
    dropout=0.1,
    activation="gelu",
)


In [137]:
# Reference view: the integer ramp as a (2, 3, 3, 3) image batch,
# before any linearisation step is applied.
torch.arange(27 * 2).reshape(2, 3, 3, 3)

tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]],

         [[18, 19, 20],
          [21, 22, 23],
          [24, 25, 26]]],


        [[[27, 28, 29],
          [30, 31, 32],
          [33, 34, 35]],

         [[36, 37, 38],
          [39, 40, 41],
          [42, 43, 44]],

         [[45, 46, 47],
          [48, 49, 50],
          [51, 52, 53]]]])

In [120]:
# Reinterpret a (2, 9, 3) token tensor as a (2, 3, 3, 3) image batch.
# NOTE: a plain reshape only regroups the flat buffer in its existing order;
# it does not undo a channels-last permute applied before tokenisation.
torch.arange(27 * 2).reshape(2, 9, 3).reshape(2, 3, 3, 3)

tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8]],

         [[ 9, 10, 11],
          [12, 13, 14],
          [15, 16, 17]],

         [[18, 19, 20],
          [21, 22, 23],
          [24, 25, 26]]],


        [[[27, 28, 29],
          [30, 31, 32],
          [33, 34, 35]],

         [[36, 37, 38],
          [39, 40, 41],
          [42, 43, 44]],

         [[45, 46, 47],
          [48, 49, 50],
          [51, 52, 53]]]])

In [138]:
# Forward a random batch of two 50x50 RGB images and inspect the output shape.
model(torch.randn(2, 3, 50, 50)).shape

torch.Size([2, 625, 12])
torch.Size([2, 625, 12])


torch.Size([2, 1, 50, 50])

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier (re-declaration of the class defined earlier in this notebook).

    Images are embedded as patch tokens, passed through a stack of pre-norm
    transformer encoder layers, and the first token of each sequence is
    normalised and projected to ``num_classes`` logits.

    ``img_size`` and ``hidden_dim`` are part of the signature but are not
    referenced by the implementation.
    """

    def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):
        super().__init__()

        # Patch tokens + class token + position embeddings.
        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)

        # Shared configuration for every layer of the encoder stack.
        layer_config = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config),
            num_layers=num_encoders,
        )

        # Classification head: normalise, then project to the class logits.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        """Compute class logits, printing the tensor shape after each stage."""
        embedded = self.embeddings_block(x)
        print("Patch embeddings: ", embedded.shape)

        contextualised = self.encoder_blocks(embedded)
        print("Encode output: ", contextualised.shape)

        # The prediction is made from the first token only.
        logits = self.mlp_head(contextualised[:, 0, :])
        print("MLP Head output: ", logits.shape)

        return logits