- imports

In [1]:
import torch
from torch import nn

- hyperparameters

In [5]:
# ---- image geometry -------------------------------------------------------
IMAGE_SIZE = 28    # side length of the square input images
PATCH_SIZE = 4     # side length of each square patch
IN_CHANNELS = 1    # channels per input image
NUM_CLASSES = 10

# ---- quantities derived from the geometry ---------------------------------
_PATCHES_PER_SIDE = IMAGE_SIZE // PATCH_SIZE                 # 28 // 4 = 7
NUM_PATCHES = _PATCHES_PER_SIDE * _PATCHES_PER_SIDE          # 7 * 7 = 49
EMBEDDING_DIM = IN_CHANNELS * PATCH_SIZE * PATCH_SIZE        # 1 * 4 * 4 = 16

# ---- model ----------------------------------------------------------------
NUM_ENCODER_LAYERS = 4
NUM_HEADS = 8
HIDDEN_DIM = 768
DROPOUT = 0.001
ACTIVATION = nn.GELU   # the class itself, not an instance

# ---- optimizer ------------------------------------------------------------
LEARNING_RATE = 0.0001
ADAM_WEIGHT_DECAY = 0.0001
ADAM_BETAS = (0.9, 0.999)

# ---- training -------------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

- patches & position embeddings

In [8]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch tokens.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, each patch is projected to
    ``embed_dim`` features, one learnable class token is prepended, a
    learnable positional embedding is added, and dropout is applied.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of a square patch (used as both the kernel
            size and the stride of the projection convolution).
        num_patches: Number of patches produced per image; the positional
            embedding holds ``num_patches + 1`` positions (patches plus the
            class token).
        embed_dim: Feature size of every token.
        dropout: Dropout probability applied to the final token sequence.
    """

    def __init__(
        self,
        in_channels: int = 1,
        patch_size: int = 4,
        num_patches: int = 49,
        embed_dim: int = 16,
        dropout: float = 0.001,
    ):
        super().__init__()
        # kernel_size == stride -> every patch is seen exactly once;
        # Flatten(2) merges the two spatial axes into one "patch" axis.
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2),
        )

        # Exactly ONE class token per image, independent of the number of
        # input channels (the token axis must have length 1 here).
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        # One learnable position vector per token: num_patches patches + 1 class token.
        self.positional_embedding = nn.Parameter(
            torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches + 1, embed_dim)``.

        Raises:
            ValueError: If the number of tokens produced from ``x`` does not
                match the size of the positional embedding.
        """
        # Broadcast the single class token across the batch without copying.
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)

        # (batch, embed_dim, num_patches) -> (batch, num_patches, embed_dim)
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat((cls_tokens, x), dim=1)   # prepend class token to the patch tokens

        expected_tokens = self.positional_embedding.shape[1]
        if x.shape[1] != expected_tokens:
            raise ValueError(
                f"Input yields {x.shape[1]} tokens (patches + class token) but the "
                f"positional embedding was built for {expected_tokens}; check "
                f"image size, patch_size and num_patches."
            )

        # Inject position information; broadcasts over the batch axis.
        x = x + self.positional_embedding
        x = self.dropout(x)

        return x

# Smoke test: build the embedding layer on the selected device and push one
# random batch through it to confirm the output shape.
model = PatchEmbedding(
    in_channels=IN_CHANNELS,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    embed_dim=EMBEDDING_DIM,
    dropout=DROPOUT,
).to(device)

x = torch.randn(512, 1, IMAGE_SIZE, IMAGE_SIZE).to(device)
# Expected: torch.Size([512, 50, 16])
#   512 -> batch size
#    50 -> NUM_PATCHES + 1 class token per image
#    16 -> EMBEDDING_DIM
print(model(x).shape)

torch.Size([512, 50, 16])
