In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST
from torchsummary import summary

In [None]:
class PatchEmbedding(nn.Module):
    """Cut images into non-overlapping patches and project each patch to a token.

    A ``Conv2d`` whose kernel size equals its stride performs both the
    patchify step and the per-patch linear projection in a single operation.

    Args:
        in_channels: number of channels of the input images.
        num_hiddens: embedding width of each output patch token.
        img_size: image size, either an int (square image) or an (H, W) pair.
        patch_size: patch size, either an int (square patch) or an (h, w) pair.

    Attributes:
        num_patches: number of whole patches per image, i.e. the token count
            produced by ``forward``.
    """

    @staticmethod
    def _as_pair(value):
        # Scalars are duplicated into (value, value); lists/tuples pass through as-is.
        return value if isinstance(value, (list, tuple)) else (value, value)

    def __init__(self, in_channels, num_hiddens, img_size,
                 patch_size):
        super().__init__()
        img_hw = self._as_pair(img_size)
        patch_hw = self._as_pair(patch_size)
        # Floor division: rows/columns that do not fill a whole patch are ignored,
        # matching what an unpadded conv with stride == kernel produces.
        patches_per_col = img_hw[0] // patch_hw[0]
        patches_per_row = img_hw[1] // patch_hw[1]
        self.num_patches = patches_per_col * patches_per_row
        self.conv = nn.Conv2d(in_channels, num_hiddens,
                              kernel_size=patch_hw, stride=patch_hw)

    def forward(self, x):
        """Map images (N, in_channels, H, W) to tokens (N, num_patches, num_hiddens)."""
        feature_map = self.conv(x)        # (N, num_hiddens, H', W')
        tokens = feature_map.flatten(2)   # (N, num_hiddens, H' * W')
        return tokens.transpose(1, 2)     # (N, H' * W', num_hiddens)

In [None]:
class ViTMLP(nn.Module):
    """Two-layer feed-forward network used inside each ViT block.

    Computes ``Dropout(Linear2(Dropout(GELU(Linear1(x)))))`` over the last
    dimension of ``x``.

    Args:
        num_ins: size of the last dimension of the input.
        mlp_num_hiddens: width of the hidden layer.
        mlp_num_outputs: size of the last dimension of the output.
        dropout: dropout probability applied after the activation and again
            after the output projection.
    """

    def __init__(self, num_ins, mlp_num_hiddens, mlp_num_outputs, dropout):
        super().__init__()
        # Attribute names and creation order are kept stable so parameter
        # initialisation and state_dict keys (lin1.*, lin2.*) do not change.
        self.lin1 = nn.Linear(num_ins, mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.lin2 = nn.Linear(mlp_num_hiddens, mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        expanded = self.lin1(x)
        activated = self.dropout1(self.gelu(expanded))
        projected = self.lin2(activated)
        return self.dropout2(projected)

In [None]:
class ViTBlock(nn.Module):
    """Pre-norm Transformer encoder block used by the ViT.

    Computes ``y = x + MHA(LN1(x))`` followed by ``y + MLP(LN2(y))``.

    Args:
        num_ins: input width of the MLP. It must equal ``num_hiddens``, because
            the MLP consumes the normalised residual stream of width ``num_hiddens``.
        num_hiddens: token embedding width D.
        norm_shape: ``normalized_shape`` for both LayerNorms (normally ``num_hiddens``).
        mlp_num_hiddens: hidden width of the MLP.
        num_heads: number of attention heads (must divide ``num_hiddens``).
        dropout: dropout probability for the attention weights and the MLP.
    """

    def __init__(self, num_ins, num_hiddens, norm_shape,
                 mlp_num_hiddens, num_heads, dropout):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(norm_shape)
        # batch_first=True: tokens arrive as (batch, seq, D). With the default
        # batch_first=False, dim 0 would be read as the sequence axis, so each
        # token would attend over the *other images in the batch* instead of the
        # other patches of its own image.
        self.attention = nn.MultiheadAttention(num_hiddens, num_heads,
                                               dropout=dropout, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(num_ins, mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, x):
        """Apply the block to tokens of shape (batch, seq, num_hiddens)."""
        normed = self.layer_norm1(x)
        # The attention weights are discarded, so skip computing them.
        attn_out, _ = self.attention(normed, normed, normed, need_weights=False)
        residual = attn_out + x
        return self.mlp(self.layer_norm2(residual)) + residual

In [None]:
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Images are split into patch tokens, a learnable class token is prepended,
    learnable position embeddings are added, the sequence passes through
    ``num_blks`` ``ViTBlock``s, and the final class-token state is classified.

    Args:
        img_size: image size, int or (H, W).
        patch_size: patch size, int or (h, w).
        num_hiddens: token embedding width.
        mlp_num_hiddens: hidden width of each block's MLP.
        num_heads: attention heads per block.
        num_blks: number of Transformer blocks.
        emb_dropout: dropout applied to the embedded token sequence.
        blk_dropout: dropout used inside each block.
        num_ins: MLP input width passed to each block (should equal ``num_hiddens``).
        num_classes: number of output logits.
        in_channels: number of channels of the input images (1 for MNIST).
    """

    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, num_ins,
                 num_classes=10, in_channels=1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, num_hiddens,
                                              img_size, patch_size)
        # Learnable token prepended to every sequence; its final state feeds the head.
        self.class_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # patches + class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"blk{i}", ViTBlock(num_ins, num_hiddens,
                                                     num_hiddens, mlp_num_hiddens,
                                                     num_heads, blk_dropout))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, x):
        """Map images (N, in_channels, H, W) to logits (N, num_classes)."""
        # 1. Patch embedding + linear projection -> (N, num_patches, num_hiddens)
        tokens = self.patch_embedding(x)
        cls = self.class_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = self.dropout(tokens + self.pos_embedding)
        # 2. Transformer blocks (nn.Sequential applies them in order)
        tokens = self.blks(tokens)
        # 3. Classify from the class-token position
        return self.head(tokens[:, 0])

In [None]:
# ToTensor converts each PIL image into a float tensor laid out as (C, H, W).
transform = ToTensor()
BATCH_SIZE = 128

# Build the train split first, then the test split (downloaded on first use).
train_set, test_set = (
    MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training data is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# --- Hyperparameters ---
num_hiddens, mlp_num_hiddens, num_heads, num_blks, num_ins = 256, 1024, 8, 1, 256
# Adam-scale learning rate: 0.1 is an SGD-scale value and makes Adam-trained
# transformers diverge or stall.
emb_dropout, blk_dropout, lr = 0.2, 0.2, 1e-3
img_size, patch_size = 28, 7
num_epochs = 5

# Seed before model creation so weight init and loader shuffling are reproducible.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, num_ins).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    model.train()  # enable dropout for training
    train_loss = 0.0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        y_hat = model(X)
        loss = criterion(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() takes a plain float; summing the tensor itself would keep
        # chaining autograd history across the whole epoch.
        train_loss += loss.item()

    model.eval()  # disable dropout for evaluation
    correct = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            correct += (model(X).argmax(dim=1) == y).sum().item()
    test_acc = correct / len(test_loader.dataset)

    mean_loss = train_loss / len(train_loader)
    print(f"epoch {epoch}: mean train loss {mean_loss:.4f}, test accuracy {test_acc:.4f}")