- imports

In [1]:
import torch
from torch import nn

- hyperparameters

In [2]:
# REPRODUCIBILITY
# Seed torch's global RNG so parameter initialisation and torch.randn draws
# are repeatable after "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# IMAGES
NUM_CLASSES = 10
PATCH_SIZE = 4
IMAGE_SIZE = 28
IN_CHANNELS = 1
# NUM_PATCHES below uses integer division, so a non-divisible image size
# would silently drop border pixels from the patch grid.
assert IMAGE_SIZE % PATCH_SIZE == 0, "IMAGE_SIZE must be a multiple of PATCH_SIZE"

# MODEL
NUM_HEADS = 8
DROPOUT = 0.001
HIDDEN_DIM = 768                                  # passed to ViT as the encoder feed-forward width
ACTIVATION = "gelu"
NUM_ENCODER_LAYERS = 4
EMBEDDING_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS   # 16
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2     # 49
# Multi-head attention splits the embedding evenly across heads.
assert EMBEDDING_DIM % NUM_HEADS == 0, "EMBEDDING_DIM must be a multiple of NUM_HEADS"

# DEVICE
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


- patches & position embeddings

In [3]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch tokens plus a class token.

    A strided convolution (kernel size == stride == ``patch_size``) projects each
    non-overlapping patch to an ``embed_dim``-sized vector. One learnable class
    token is prepended to the patch sequence, a learnable positional embedding is
    added, and dropout is applied.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch (conv kernel size and stride).
        num_patches: Number of patches the convolution produces per image; the
            positional embedding holds ``num_patches + 1`` positions.
        embed_dim: Size of every token vector.
        dropout: Dropout probability applied to the final token sequence.
    """

    def __init__(
        self,
        in_channels: int = 1,
        patch_size: int = 4,
        num_patches: int = 49,
        embed_dim: int = 16,
        dropout: float = 0.001,
    ):
        super().__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2),  # merge the two spatial dims: (B, D, H', W') -> (B, D, H'*W')
        )

        # Exactly ONE class token per image. Its sequence length must not depend
        # on in_channels, otherwise multi-channel inputs would get several class
        # tokens and no longer match the num_patches + 1 positional slots.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        self.positional_embedding = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed images.

        Args:
            x: Image batch of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embed_dim)``; index 0 along
            dim 1 is the class token.
        """
        # Broadcast the single learnable token across the batch (no copy).
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)

        x = self.patcher(x).permute(0, 2, 1)    # (B, D, N) -> (B, N, D)
        x = torch.cat((cls_tokens, x), dim=1)   # merge cls_token and patches
        x = x + self.positional_embedding       # out-of-place: never mutates an intermediate
        x = self.dropout(x)

        return x

# Smoke test for PatchEmbedding. A dedicated name is used so the global `model`
# is not bound to an embedding layer here and to a different module type later.
patch_embedder = PatchEmbedding(
    in_channels=IN_CHANNELS,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    embed_dim=EMBEDDING_DIM,
    dropout=DROPOUT,
).to(device)

# Use IN_CHANNELS rather than a literal so the check follows the config cell.
sample_images = torch.randn(512, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE).to(device)

# Only the output shape matters here, so skip building an autograd graph.
with torch.no_grad():
    patch_tokens = patch_embedder(sample_images)

assert patch_tokens.shape == (512, NUM_PATCHES + 1, EMBEDDING_DIM)
print(patch_tokens.shape)  # torch.Size([512, 50, 16]) - batch size, num_patches + one cls_token for image, embedding_dim

torch.Size([512, 50, 16])


- ViT

In [4]:
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding -> Transformer encoder -> linear head.

    Args:
        num_classes: Number of output logits.
        img_size: Not used by this class; kept so existing positional calls keep
            working. The patch count is taken from ``num_patches``.
        hidden_dim: Feed-forward width inside each encoder layer
            (``dim_feedforward``).
        num_heads: Number of attention heads per encoder layer.
        num_encoder_layers: Number of stacked encoder layers.
        dropout: Dropout probability for the embedding block and encoder layers.
        patch_size: Side length of each square patch.
        num_patches: Number of patches per image.
        embed_dim: Token dimension (``d_model`` of the encoder).
        activation: Feed-forward activation. Accepts whatever
            ``nn.TransformerEncoderLayer`` accepts (a string such as ``"gelu"``
            or a callable on tensors) and, in addition, an ``nn.Module``
            subclass such as ``nn.GELU``, which is instantiated here.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes: int = 10, img_size: int = 28, hidden_dim: int = 768, num_heads: int = 8, num_encoder_layers: int = 4, dropout: float = 0.001,
                 patch_size: int = 4, num_patches: int = 49, embed_dim: int = 16, activation=nn.GELU, in_channels: int = 1):
        super().__init__()
        self.embeddings_block = PatchEmbedding(in_channels, patch_size, num_patches, embed_dim, dropout)

        # nn.TransformerEncoderLayer stores a non-string activation as-is and
        # later calls it on a tensor. A module *class* (the default, nn.GELU)
        # would then be constructed with the tensor as its argument instead of
        # being applied to it, so turn classes into instances first.
        if isinstance(activation, type) and issubclass(activation, nn.Module):
            activation = activation()

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            activation=activation,
            dropout=dropout,
            batch_first=True,   # tokens arrive as (batch, sequence, embedding)
            norm_first=False
        )
        self.encoder_block = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``(batch, in_channels, height, width)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = self.embeddings_block(x)
        x = self.encoder_block(x)
        x = self.mlp_head(x[:, 0, :])  # extract cls_token

        return x
    
# Build the classifier with keyword arguments: with eleven parameters a
# positional call is easy to misorder without any error being raised.
model = ViT(
    num_classes=NUM_CLASSES,
    img_size=IMAGE_SIZE,
    hidden_dim=HIDDEN_DIM,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    dropout=DROPOUT,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    embed_dim=EMBEDDING_DIM,
    activation=ACTIVATION,
    in_channels=IN_CHANNELS,
)
model = model.to(device)

# Use IN_CHANNELS rather than a literal so the check follows the config cell.
x = torch.randn(512, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE).to(device)

# Shape-only smoke test: no gradients are needed.
with torch.no_grad():
    logits = model(x)

assert logits.shape == (512, NUM_CLASSES)
print(logits.shape)  # torch.Size([512, 10]) - batch size, num_classes

torch.Size([512, 10])
