In [3]:
import torch
from torch import nn
from torchvision import transforms
from torchinfo import summary
import matplotlib.pyplot as plt
from os import cpu_count

import engine
import data_setup
from helper_function import set_seeds, plot_loss_curves
from data_download import download_dataset

device = "cuda" if torch.cuda.is_available() else "cpu"

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
# training and testing directories
TRAIN_DIR = "./data/pizza_steak_sushi/train"
TEST_DIR = "./data/pizza_steak_sushi/test"

IMG_SIZE = 224
BATCH_SIZE = 32

# ViT image transformation
simple_transform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor()
])

# Creates training and testing dataloaders
train_dl, test_dl, class_names = data_setup.create_dataloaders(train_dir=TRAIN_DIR,
                                                    test_dir=TEST_DIR,
                                                    transform=simple_transform,
                                                    batch_size=BATCH_SIZE,
                                                    num_workers=cpu_count())

In [5]:
# # batch of img, labels
# img_batch, label_batch = next(iter(train_dl))
# # single image
# img, label = img_batch[0], label_batch[0]
# img_permute = img.permute(1,2,0)

# plt.imshow(img_permute)
# plt.title(f"Class is {class_names[label]}")
# plt.show()

In [6]:
# ## experimenting to see patch inputs, outputs and patches

# # Get image dimensions and patches
# H, W, C = 224, 224, 3 #img_permute.shape
# patch_size = 16 #patch dimension
# N = H*W // patch_size**2 # total number of patches
# assert H*W % patch_size**2 == 0, "Image must be divided into equal patches"
# # print(f"The sequence of patches has length {N}")
# D = 768 # P**2 * C

# # G
# patch_per_row = H // patch_size
# patch_per_col = W // patch_size

# # Create a series of subplots
# fig, axs = plt.subplots(nrows=patch_per_row,
#                         ncols=patch_per_col,
#                         figsize=(H/patch_size, W/patch_size), # no. of patches per col & row
#                         sharex=True,
#                         sharey=True)

# for i, patch_i in enumerate(range(0, H, patch_size)):
#     for j, patch_j in enumerate(range(0, W, patch_size)):
#         axs[i][j].imshow(img_permute[patch_i:patch_i+patch_size, patch_j:patch_j+patch_size, :])
#         axs[i][j].set_xticks([])
#         axs[i][j].set_yticks([])

# plt.tight_layout()
# plt.show()

In [7]:
# create image patch embeddnigs (2D image linear projection layer) / Implementation of Equation 1
class Patch_Embed_Layer(nn.Module):
    def __init__(self, in_channels, num_patch, patch_size, embed_size) -> None:
        super().__init__()

        self.patch_embedding_layer = nn.Sequential(
            # extract the 2D feature maps (learnable patches)
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_patch,
                      kernel_size=patch_size,
                      stride=patch_size,
                      padding=0),
            # flatten the feature maps
            nn.Flatten(start_dim=2,
                       end_dim=3),
            #  linear projection to create patch embedings
            nn.Linear(in_features=14*14, out_features=embed_size)
        )
    
    def forward(self, x):
        # concat class embedding with patch embedding
        # have to generate batchwise class tokens!!!
        # x = torch.concat([self.class_embedding, self.patch_embedding_layer(x)], dim=1)
        
        # generate positional embeddng for given sequence
        return self.patch_embedding_layer(x)

# BATCH_SIZE = 2
# embed_size = 768
# num_patch = 196
# patch_size = 16
# # print(f"Image shape {img.unsqueeze(0).shape}")
# dm_img = torch.rand(size=(BATCH_SIZE, 3, 224, 224))
# print(dm_img.shape)
# patch_embd = Patch_Embed_Layer(num_patch, patch_size, embed_size)
# patch_embd_output = patch_embd(dm_img)
# print(f"Image patches embeding shape {patch_embd_output.shape}")

In [8]:
# create MSA Block / Implementation of Equation 2 without skip connection
class MultiheadAttention(nn.Module):
    def __init__(self, num_heads, embed_size, attn_dropout=0) -> None:
        super().__init__()

        # qkv for input in the attention block
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)

        # layer norm
        self.layer_norm = nn.LayerNorm(normalized_shape=embed_size)
        
        # attention block
        self.self_attention_layer = nn.MultiheadAttention(embed_dim=embed_size,
                                                          num_heads=num_heads,
                                                          dropout=attn_dropout,
                                                          batch_first=True)
        
    def forward(self, x):

        x = self.layer_norm(x)

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # get the attn output and ignore attn output weights
        return self.self_attention_layer(query=q,
                                         key=k,
                                         value=v,
                                         need_weights=False)[0]

# msa_layer = MultiheadAttention(4, 768)
# msa_output = msa_layer(patch_embd_output)
# msa_output.shape

In [9]:
# MLP Block / Implementation of Equation 3 without skip connection

class MultiLayerPerceptron(nn.Module):
    def __init__(self,
                 embed_size, # hidden size D from Table 1
                 mlp_size:int=3072, # from Table 1 of ViT-Base
                 dropout:float=0.1): # from Table 3 of ViT-Base
        super().__init__()
        
        # layer normalization
        self.layer_norm = nn.LayerNorm(normalized_shape=embed_size)

        # mlp layer as specified in ViT Paper
        self.mlp_layer = nn.Sequential(
            nn.Linear(embed_size, mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(mlp_size, embed_size),
            nn.Dropout(p=dropout)
        )

    def forward(self, x):
        return self.mlp_layer(self.layer_norm(x))
    
# mlp_layer = MultiLayerPerceptron(embed_size)
# output_mlp_layer = mlp_layer(msa_output)
# output_mlp_layer.shape

In [10]:
# Create Transformer encoder block
class EncoderBlock(nn.Module):
    def __init__(self, num_heads, embed_size, mlp_size, dropout_attn, dropout) -> None:
        super().__init__()

        # Multi-headed attention block with layer norm
        self.msa_block = MultiheadAttention(num_heads,
                                            embed_size,
                                            dropout_attn)
        
        # MLP Block with layer norm
        self.mlp_block = MultiLayerPerceptron(embed_size,
                                              mlp_size,
                                              dropout)
        
    def forward(self, x):
        # msa with layer norm + skip connection
        x = self.msa_block(x) + x
        # mlp with layer norm + skip connection
        return self.mlp_block(x) + x
    
# encoder = EncoderBlock(4, embed_size, 3072, 0.1)
# import timeit
# start = timeit.default_timer()
# encoder(patch_embd_output).shape
# print(timeit.default_timer() - start)

# summary(encoder)

In [11]:
# Complete ViT Architecture (ViT Base Model)
BATCH_SIZE = 2
IMAGE_SIZE = 224  # image resolution
IMG_CHANNELS = 3  # image channels
NUM_CLASSES = 3 # output labels
PATCH_SIZE = 16  # dimension of the image patches
NUM_PATCH = IMAGE_SIZE**2 // PATCH_SIZE**2  # number of image patches
D_MODEL = 768  # patch embedding dimension throughout ViT
NUM_HEADS = 12 # number of heads for multiheaded attention block
NUM_LAYERS = 12 # number of encoder blocks in ViT
MLP_SIZE = 3072 # size of the MLP block in Encoder
DROPOUT_EMBEDS = 0.1 # dropout of the patch embeddings
DROPOUT_MLP = 0.1 # dropout of the MLP block in Encoder
DROPOUT_ATTN = 0 # dropout of the MSA block in Encoder

class ViTBase(nn.Module):
    def __init__(self, in_channels, num_patch, patch_size, embed_size, num_heads, num_layers, mlp_size, dropout_embeds, dropout_mlp, dropout_attn, out_features) -> None:
        super().__init__()
        
        # Image patch embeddings
        self.patch_embedding = Patch_Embed_Layer(in_channels, num_patch, patch_size, embed_size)
        
        # Class token embedding
        self.class_embedding = nn.Parameter(torch.randn(size=(1, 1, embed_size), requires_grad=True))
        
        # Postition embeddings of flatten patches
        self.position_embedding = nn.Parameter(torch.randn(size=(1, num_patch+1, embed_size), requires_grad=True))

        # Embeddings dropout
        self.embedding_dropout = nn.Dropout(p=dropout_embeds)
        
        # Encoder block layers
        self.encoder_blocks = nn.Sequential(
            *[EncoderBlock(num_heads, embed_size, mlp_size, dropout_attn, dropout_mlp) for _ in range(num_layers)]
        )

        # Classification head for image classification
        self.classifier_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_size),
            nn.Linear(embed_size, out_features)
        )

    def forward(self, x):
        batch_size = x.shape[0]  # Get batch dimension

        # Create the image patches' embeddings
        patch_embeddings = self.patch_embedding(x)
        
        # Create the output tokens similar to BERT's class token for the batch
        class_embedding = self.class_embedding.expand(batch_size, -1, -1)

        # Create the complete sequence of embeddings with class token, patch and positional embeddings
        x_embedded = torch.concat([class_embedding, patch_embeddings], dim=1) + self.position_embedding

        # Run embedding dropout (Appendix B.1 in ViT Paper)
        x_embedded_drpout = self.embedding_dropout(x_embedded)
        
        # Pass the input through the Encoder layers
        encoder_output = self.encoder_blocks(x_embedded_drpout)
        
        # Get the class token embeddings
        class_token = encoder_output[:, 0,:]
        
        # Get the output labels from class token
        return self.classifier_head(class_token)

dm_img = torch.rand(size=(BATCH_SIZE, 3, 224, 224))
# print(dm_img.shape)
vit_model = ViTBase(IMG_CHANNELS, NUM_PATCH, PATCH_SIZE, D_MODEL, NUM_HEADS, NUM_LAYERS, MLP_SIZE, DROPOUT_EMBEDS, DROPOUT_MLP, DROPOUT_ATTN, NUM_CLASSES)
# vit_model(dm_img)

In [12]:
summary(model=vit_model, 
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                           Input Shape          Output Shape         Param #              Trainable
ViTBase (ViTBase)                                                 [32, 3, 224, 224]    [32, 3]              152,064              True
├─Patch_Embed_Layer (patch_embedding)                             [32, 3, 224, 224]    [32, 196, 768]       --                   True
│    └─Sequential (patch_embedding_layer)                         [32, 3, 224, 224]    [32, 196, 768]       --                   True
│    │    └─Conv2d (0)                                            [32, 3, 224, 224]    [32, 196, 14, 14]    150,724              True
│    │    └─Flatten (1)                                           [32, 196, 14, 14]    [32, 196, 196]       --                   --
│    │    └─Linear (2)                                            [32, 196, 196]       [32, 196, 768]       151,296              True
├─Dropout (embedding_dropout)                              

In [13]:
# from torchvision import models

# vit_weights = models.ViT_B_16_Weights.DEFAULT
# vit_transfermodel = models.vit_b_16(weights=vit_weights)

# vit_transfermodel.classifier