In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size`` is
    used, which applies the same linear projection to every patch.

    Args:
        img_size: Side length (in pixels) of the square input image.
        patch_size: Side length (in pixels) of each square patch; must divide
            ``img_size`` evenly.
        num_hiddens: Size of the embedding produced for each patch.
        in_channels: Number of channels of the input image.

    Attributes:
        num_patches: Number of patches (tokens) produced per image.

    Raises:
        ValueError: If ``img_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, img_size=28, patch_size=7, num_hiddens=64, in_channels=1):
        super().__init__()
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride -> every patch is seen exactly once, no overlap.
        self.conv = nn.Conv2d(in_channels, num_hiddens,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape (batch, in_channels, img_size, img_size).

        Returns:
            Tensor of shape (batch, num_patches, num_hiddens).
        """
        # (batch, num_hiddens, H/p, W/p) -> (batch, num_hiddens, num_patches)
        # -> (batch, num_patches, num_hiddens)
        return self.conv(x).flatten(2).transpose(1, 2)

In [None]:
class ViTMLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        num_ins: Size of the last dimension of the input.
        mlp_num_hiddens: Width of the hidden layer.
        mlp_num_outputs: Size of the last dimension of the output.
        dropout: Drop probability used after each linear stage.
    """

    def __init__(self, num_ins, mlp_num_hiddens, mlp_num_outputs, dropout):
        super().__init__()
        # Expansion stage.
        self.lin1 = nn.Linear(in_features=num_ins, out_features=mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(p=dropout)
        # Projection stage.
        self.lin2 = nn.Linear(in_features=mlp_num_hiddens, out_features=mlp_num_outputs)
        self.dropout2 = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run ``x`` through both stages and return the projected result."""
        hidden = self.lin1(x)
        hidden = self.gelu(hidden)
        hidden = self.dropout1(hidden)
        projected = self.lin2(hidden)
        return self.dropout2(projected)

In [None]:
class ViTBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> MLP -> residual add.

    Args:
        num_ins: Input feature size of the MLP.
        embed_dims: Embedding size used by the attention layer; must be
            divisible by ``num_heads``.
        num_hiddens: Currently unused; kept so existing callers keep working.
        norm_shape: ``normalized_shape`` passed to both LayerNorm layers.
        mlp_num_hiddens: Hidden width of the MLP.
        mlp_num_outputs: Output feature size of the MLP; must match the input
            feature size so the residual addition is valid.
        num_heads: Number of attention heads.
        dropout: Drop probability used in the attention layer and the MLP.
    """

    def __init__(self, num_ins, embed_dims, num_hiddens, norm_shape,
                 mlp_num_hiddens, mlp_num_outputs, num_heads, dropout):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(norm_shape)
        # batch_first=True: inputs are (batch, seq, feature); the default
        # layout of nn.MultiheadAttention is (seq, batch, feature).
        self.attention = nn.MultiheadAttention(embed_dims, num_heads,
                                               dropout=dropout, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(num_ins, mlp_num_hiddens, mlp_num_outputs, dropout)

    def forward(self, x):
        """Transform a batch of token sequences.

        Args:
            x: Tensor of shape (batch, seq_len, embed_dims).

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.layer_norm1(x)
        # MultiheadAttention returns (output, weights); only the output is used.
        attn_out, _ = self.attention(normed, normed, normed, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.layer_norm2(x))

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Images are turned into a sequence of patch embeddings, a learnable class
    token is prepended, learnable positional embeddings are added, the sequence
    is passed through a stack of ``ViTBlock`` layers, and the class token is
    fed to a linear classification head.

    Args:
        img_size: Side length (in pixels) of the square input image.
        patch_size: Side length (in pixels) of each square patch.
        num_hiddens: Embedding size used throughout the model.
        mlp_num_hiddens: Hidden width of the MLP inside each block.
        num_heads: Number of attention heads in each block.
        num_blks: Number of stacked ``ViTBlock`` layers.
        emb_dropout: Drop probability applied to the embedded sequence.
        blk_dropout: Drop probability used inside each block.
        num_classes: Number of output classes.
    """

    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads, num_blks,
                 emb_dropout, blk_dropout, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        # Requires patch_embedding to yield (img_size // patch_size) ** 2 tokens
        # of size num_hiddens per image; the positional table is sized for that.
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        # One positional vector per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(
                "blk" + str(i),
                ViTBlock(num_ins=num_hiddens, embed_dims=num_hiddens,
                         num_hiddens=num_hiddens, norm_shape=num_hiddens,
                         mlp_num_hiddens=mlp_num_hiddens, mlp_num_outputs=num_hiddens,
                         num_heads=num_heads, dropout=blk_dropout))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Batch of images accepted by the patch-embedding layer.

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalized logits.
        """
        x = self.patch_embedding(x)
        # Prepend the class token to every sequence in the batch.
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), 1)
        x = self.dropout(x + self.pos_embedding)
        for blk in self.blks:
            x = blk(x)
        # Only the class token (position 0) is used for classification.
        return self.head(x[:, 0])

In [None]:
# Convert each image to a tensor as it is read from the dataset.
transform = ToTensor()

# MNIST train / test splits; files are downloaded into ./data when missing.
train_set = MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
test_set = MNIST(
    root="./data",
    train=False,
    transform=transform,
    download=True,
)

# Only the training batches are shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

In [None]:
num_epochs = 5
# NOTE(review): 0.1 is far above Adam's documented default of 1e-3 and may
# make training unstable -- confirm this value is intentional.
lr = 0.1

# The model must be created before the optimizer can collect its parameters.
# img_size=28 matches MNIST images; the remaining sizes are starting values
# to tune (num_hiddens must be divisible by num_heads).
model = ViT(img_size=28, patch_size=7, num_hiddens=64, mlp_num_hiddens=128,
            num_heads=4, num_blks=2, emb_dropout=0.1, blk_dropout=0.1)

# parameters() has to be called: Adam expects an iterable of tensors,
# not the bound method itself.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

model.train()  # make sure dropout is active during optimization
for epoch in range(num_epochs):
    for X, y in train_loader:
        y_hat = model(X)
        loss = criterion(y_hat, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the last mini-batch of the epoch.
    print(f"loss on epoch: {epoch} was {loss.item():.2f}")