In [1]:
import sys
import os

sys.path.append(os.path.join(os.getcwd(), ".."))


import torch
import torch.nn as nn
import torchinfo
from tqdm.notebook import tqdm 

from torchvision.datasets import CIFAR10

from vision_transformer.VisionTransformer import ViT
from src.read_config import read_config

In [2]:
# Load the transformer hyper-parameters from the project's YAML config
# (paths are relative to this notebook's working directory).
config = read_config(
    config_path='../configs',
    config_name='transformer_params.yaml',
)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path=config_path)


In [3]:
class TransformerClassifier(nn.Module):
    """Image classifier: a ``ViT`` backbone followed by a small MLP head.

    The backbone output is fed to a two-layer head
    (``d_model -> d_model // 2 -> n_classes``) that returns raw,
    unnormalized class scores (logits).

    Args:
        n_classes: Number of target classes (width of the final layer).
        img_size: Forwarded to ``ViT`` as ``img_size``.
        in_channels: Forwarded to ``ViT`` as ``in_channels``.
        transformer_config: Mapping of additional keyword arguments that is
            unpacked into the ``ViT`` constructor (must support ``**``).
    """

    def __init__(self, n_classes: int, img_size: int, in_channels: int, transformer_config):
        super().__init__()

        self.vit = ViT(img_size=img_size, in_channels=in_channels, **transformer_config)

        d_model = self.vit.d_model
        hidden_dim = d_model // 2

        # NOTE: there is deliberately no activation after the last Linear.
        # A trailing ReLU would clamp every logit to >= 0 and zero the
        # gradient of any class whose pre-activation is negative, which
        # cripples training with softmax / cross-entropy style losses.
        # The removed ReLU held no parameters and was the last entry of the
        # Sequential, so state_dict keys (classifier.0.*, classifier.2.*)
        # are unchanged and existing checkpoints still load.
        self.classifier = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of images.

        Args:
            x: Input batch handed directly to the ``ViT`` backbone
                (presumably ``(batch, in_channels, img_size, img_size)`` —
                confirm against the ``ViT`` implementation).

        Returns:
            Unnormalized logits whose last dimension has size ``n_classes``.
        """
        embeddings = self.vit(x)
        logits = self.classifier(embeddings)

        return logits

In [4]:
# Build the classifier for 10 classes on 3-channel 32x32 images, using the
# transformer hyper-parameters loaded from the config.
classifier = TransformerClassifier(
    n_classes=10,
    img_size=32,
    in_channels=3,
    transformer_config=config['params'],
)

# Layer-by-layer summary of the ViT backbone only (the MLP head is excluded),
# traced on CPU with a dummy batch of 15 images.
torchinfo.summary(
    classifier.vit,
    input_size=(15, 3, 32, 32),
    device='cpu',
)

Layer (type:depth-idx)                             Output Shape              Param #
ViT                                                [15, 768]                 768
├─PatchEmbedder: 1-1                               [15, 16, 768]             --
│    └─Conv2d: 2-1                                 [15, 768, 4, 4]           148,224
├─PositionalEncoding: 1-2                          [15, 16, 768]             --
├─Encoder: 1-3                                     [15, 17, 768]             --
│    └─ModuleList: 2-2                             --                        --
│    │    └─EncoderBlock: 3-1                      [15, 17, 768]             5,510,916
│    │    └─EncoderBlock: 3-2                      [15, 17, 768]             5,510,916
│    │    └─EncoderBlock: 3-3                      [15, 17, 768]             5,510,916
│    │    └─EncoderBlock: 3-4                      [15, 17, 768]             5,510,916
│    │    └─EncoderBlock: 3-5                      [15, 17, 768]             5,51

In [5]:
# Smoke test: push a random batch of 15 images through the full model and
# display the output shape (one score per class for each image).
classifier(torch.rand((15, 3, 32, 32))).shape

torch.Size([15, 10])