# Preprocessing

In [None]:
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
import torch.utils.data as dl
import torch.nn as nn

import matplotlib.pyplot as plt

In [None]:
# Get MNIST data
# ToTensor converts each PIL image into a float tensor so the model can consume it.
transform = transforms.Compose([transforms.ToTensor()])
# torchvision is imported under the alias `tv`, and the dataset classes live in
# the `datasets` submodule (plural); both splits are downloaded into ./data.
train_dataset = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
val_dataset = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [None]:
# Define Hyperparameters
batch_size = 64    # samples per mini-batch handed to the DataLoaders
img_size = 28      # side length of the square input image, in pixels
num_channels = 1   # input channels fed to the patch-embedding convolution
num_classes = 10   # size of the classification output
patch_size = 7     # side length of each square patch (conv kernel and stride)
# Integer division keeps the patch count an int, so it can be used as a tensor
# dimension; true division would produce the float 16.0.
num_patches = (img_size // patch_size) ** 2
attn_heads = 4     # attention heads per encoder block; must divide emb_dim
emb_dim = 20       # width of every patch/token embedding
num_blocks = 4     # presumably the number of stacked encoder blocks (not used in the visible code)
mlp_nodes = 64     # hidden width of the MLP inside each encoder block

In [None]:
# Batch data with DataLoader
# Both loaders reshuffle their dataset and yield mini-batches of `batch_size` samples.
train_data = dl.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_data = dl.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Class for PatchEmbedding

class PatchEmbedding(nn.Module):
    """Cut an image batch into non-overlapping patches and embed each patch.

    A single strided convolution does both jobs: because the kernel size and
    the stride both equal ``patch_size``, every output position sees exactly
    one patch and projects it to an ``emb_dim``-wide vector.
    """

    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(
            in_channels=num_channels,
            out_channels=emb_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the patch embeddings of ``x`` as a sequence of tokens."""
        feature_map = self.patch_embed(x)          # [batch_size, embed_dim, pos_x, pos_y]
        tokens = feature_map.flatten(start_dim=2)  # [batch_size, embed_dim, num_patches]
        return tokens.transpose(1, 2)              # [batch_size, num_patches, embed_dim]

In [None]:
# Class for TransformerEncoder

class TransformerEncoder(nn.Module):
    """One pre-norm Transformer encoder block.

    The block applies LayerNorm -> multi-head self-attention -> residual add,
    followed by LayerNorm -> MLP -> residual add. The input and the output
    have the same shape.
    """

    def __init__(self):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(emb_dim)
        # batch_first=True because the tokens arrive as
        # [batch_size, num_patches, emb_dim]; the default layout is
        # sequence-first, which would attend across the batch axis instead.
        self.multi_head_attn = nn.MultiheadAttention(emb_dim, attn_heads, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(emb_dim)
        # nn.Linear needs both in_features and out_features. The single
        # argument the layers were originally given is read as the intended
        # output width, so the widths chain emb_dim -> emb_dim -> mlp_nodes -> emb_dim.
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, mlp_nodes),
            nn.GELU(),
            nn.Linear(mlp_nodes, emb_dim),
        )

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape.

        ``x`` is expected to be batch-first: [batch_size, seq_len, emb_dim].
        """
        # Self-attention sub-layer with a skip connection around it.
        residual1 = x
        x = self.layer_norm1(x)
        # Index 0 is the attention output; index 1 holds the attention weights.
        x = self.multi_head_attn(x, x, x)[0]
        x = x + residual1

        # Position-wise MLP sub-layer with a skip connection around it.
        residual2 = x
        x = self.layer_norm2(x)
        x = self.mlp(x)
        x = x + residual2
        return x

In [None]:
# Class for MLP Head

class MLP_Head(nn.Module):
    """Classification head: normalise one token and map it to class logits."""

    def __init__(self):
        super().__init__()
        self.layernorm = nn.LayerNorm(emb_dim)
        # nn.Linear needs both widths: the selected token is emb_dim wide and
        # the head emits one logit per class.
        self.mlphead = nn.Sequential(
            nn.Linear(emb_dim, num_classes),
        )

    def forward(self, x):
        """Return logits of shape [batch_size, num_classes].

        ``x`` is indexed as [batch, token, feature]; only the token at
        position 0 is used (presumably the class token prepended by the
        caller -- confirm against the enclosing model).
        """
        x = x[:, 0]
        x = self.layernorm(x)
        x = self.mlphead(x)
        return x

In [None]:
# Class for VisionTransformer

class VisionTransformer(nn.Module):
    def __init__(self):
        super.__init__()
        self.patch_embedding = PatchEmbedding()
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.position_embedding = nn.Parameter(torch.randn())
        self.transformer_blocks = TranformerEncoder()