In [3]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Record the installed torch build so outputs can be tied to a version.
print("torch version:", torch.__version__)


torch version: 2.5.1+cu124


In [4]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    my_device = "cuda"
else:
    my_device = "cpu"
print("using device: " + my_device)

using device: cuda


In [19]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride applies the same linear
    projection to every ``patch_size x patch_size`` patch. The resulting
    feature map is flattened into a token sequence.

    Args:
        patch_size: side length of each square patch (also the conv stride).
        in_channels: number of channels in the input image.
        embed_dim: dimensionality of each output patch embedding.
    """

    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed a batch of images as a sequence of patch tokens.

        Args:
            x: image batch, assumed (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, n_patches, embed_dim), where
            n_patches = (H // patch_size) * (W // patch_size). The strided conv
            drops trailing rows/columns that do not fill a whole patch.
        """
        # No debug print here: forward runs on every training step.
        x = self.proj(x)  # (B, embed_dim, H // patch_size, W // patch_size)
        return x.flatten(2).transpose(1, 2)  # (B, n_patches, embed_dim)

In [24]:
class TinyViT(nn.Module):
    """Small Vision Transformer classifier.

    Pipeline:
    1. Images are cut into patches and embedded by ``PatchEmbedding``.
    2. A learnable [CLS] token is prepended.
    3. Learnable positional embeddings are added.
    4. The sequence runs through a pre-norm (``norm_first=True``) Transformer encoder.
    5. The final [CLS] vector is layer-normed and classified by a small MLP head.

    Every hyperparameter defaults to the value previously hard-coded, so
    ``TinyViT()`` builds the same architecture as before.

    Args:
        img_size: side length of the (square) input images.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_dim: token embedding size (must be divisible by ``num_heads``).
        num_heads: attention heads per encoder layer.
        dropout: dropout used inside the encoder layers and the MLP head.
        num_layers: number of stacked encoder layers.
        num_classes: number of output logits.
        mlp_ratio: encoder feed-forward width as a multiple of ``embed_dim``.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        embed_dim=48,
        num_heads=4,
        dropout=0.1,
        num_layers=4,
        num_classes=10,
        mlp_ratio=4,
    ):
        super().__init__()
        # One token per patch (strided conv floors img_size / patch_size) plus [CLS].
        num_tokens = 1 + (img_size // patch_size) ** 2

        self.patch_embedding = PatchEmbedding(patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))

        # Keep the prototype layer local. TransformerEncoder deep-copies it
        # num_layers times, so storing it on self would register an extra,
        # never-used set of parameters (seen by the optimizer, never trained).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            # The nested-tensor fast path is unavailable with norm_first=True.
            # Disabling it explicitly avoids a construction-time warning.
            enable_nested_tensor=False,
        )
        # Pre-norm layers leave the encoder output un-normalized, hence a final LayerNorm.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch, assumed (B, in_channels, img_size, img_size). The
               patch count must match the positional-embedding length, or the
               addition below fails to broadcast.

        Returns:
            Logits of shape (B, num_classes).
        """
        x = self.patch_embedding(x)  # (B, n_patches, embed_dim)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 1 + n_patches, embed_dim)
        x = x + self.pos_embedding
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]  # [CLS] token representation
        return self.head(x)

In [None]:
# Smoke test: one random 32x32 RGB image through a fresh model should yield
# one logit per class (10 with the default TinyViT settings).
model = TinyViT().to(my_device)
model.eval()  # disable dropout so the check is deterministic
img = torch.randn(1, 3, 32, 32, device=my_device)
with torch.no_grad():
    out = model(img)
assert out.shape == (1, 10), f"unexpected output shape {tuple(out.shape)}"
out.shape