In [190]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## Hyperparameters

In [191]:
batch_size = 32
learning_rate = 0.001

## Datasets

In [192]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize(384)])
train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('../data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

![ViT IMAGE](assets/vit.png)

### Patching and Embedding

In [193]:
class PatchEmbed(nn.Module):
	"""Split image into patches and then embed them.

	Params:
		img_size: int

		patch_size: int

		in_chans: int
		
		embed_dim: int

	Attributes:
		n_patches : int
			Number of patches inside of our image
		
		proj : nn.Conv2d
			Convolutional layer that does both the splitting into patches and their embedding
	"""
	def __init__(self, img_size, patch_size, in_chans=1, embed_dim=384*2):
		super().__init__()
		self.img_size = img_size
		self.patch_size = patch_size
		self.n_patches = (img_size // patch_size) ** 2
		self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

	def forward(self, x):
		x = self.proj(x)
		x = x.flatten(2) # flatten patches into a single dimension
		x = x.transpose(1, 2)

		return x

### Attention Mechanism

In [194]:
class Attention(nn.Module):
	"""Attention mechanism
	
	Params:
		dim: int
			input and output dimension of token features
			
		n_heads: int
			number of attention heads
		
		qkv_bias: bool
			if True, we include bias to query, key and projections

		attn_p: float
			dropout probability applied to the query, key and value tensors

		proj_p: float
			dropout probability applied to the output tensor

	Attributes:
		scale: float
			normalizing for the doat product

		qkv: nn.Linear
			get a token and make a linear projection for the query, key and value

		proj: nn.Linear
			linear mapping that takes in the concatenated output of all attention heads and maps it into a new space
		
		attn_drop, proj_drop: nn.Dropout
			dropout layers
		"""
	def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
		super().__init__()
		self.n_heads = n_heads
		self.dim = dim
		self.head_dim = dim // n_heads
		self.scale = self.head_dim ** -0.5

		self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
		self.attn_drop = nn.Dropout(attn_p)
		self.proj = nn.Linear(dim, dim)
		self.proj_drop = nn.Dropout(proj_p)

	def forward(self, x):
		n_samples, n_tokens, dim = x.shape

		if dim != self.dim:
			raise ValueError
		
		qkv = self.qkv(x)
		qkv = qkv.reshape(n_samples, n_tokens, 1, self.n_heads, self.head_dim) # (n_samples, n_patches + 1, 3, n_heads, head_dim)
		qkv = qkv.permute(2, 0, 3, 1, 4) # (3, n_samples, n_heads, n_patches + 1, head_dim)
		q, k, v = qkv[0], qkv[1], qkv[2]
		k_t = k.transpose(-2, -1) # (n_samples, n_heads, head_dim, n_patches + 1)
		dot_product = (q @ k_t) * self.scale # (n_samples, n_heads, n_patches + 1, n_patches + 1)
		attn = dot_product.softmax(dim=-1)

		weighted_average = attn @ v
		weighted_average = weighted_average = weighted_average.transpose(1, 2) #(n_samples, n_patches + 1, n_heads, head_dim)
		weighted_average = weighted_average.flatten(2) # (n_samples, n_patches + 1, dim)

		x = self.proj(weighted_average) # (n_samples, n_patches + 1, dim)
		x = self.proj_drop(x)

		return x

### Multilayer Perceptron

In [195]:
class MLP(nn.Module):
	"""Multilayer perceptron
	
	Params:
		in_features: int
			number of input features

		hidden_features: int
			number of nodes in the hidden layer

		out_features: int
			number of output features
		
		p: float
			dropout probability
	
	Attributes:
		fc: nn.Linear
			first linear layer
		
		act: nn.GELU
			GELU action function
		
		fc2: nn.Linear
			second linear layer

		drop: nn.Dropout
			Dropout layer
	"""

	def __init__(self, in_features, hidden_features, out_features, p=0.):
		super().__init__()
		self.fc1 = nn.Linear(in_features, hidden_features)
		self.act = nn.GELU()
		self.fc2 = nn.Linear(hidden_features, out_features)
		self.drop = nn.Dropout(p)

	def forward(self, x):
		x = self.fc1
		x = self.act(x)
		x = self.drop(x)
		x = self.fc2(x)
		x = self.drop(x)

		return x

### Block Module

In [196]:
class Block(nn.Module):
	"""Transformer Block
	
	Params:
		dim: int
			embedding dimension

		n_heads: int
			number of attention heads
			
		mlp_ratio: float
			hidden_dimension size of the MLP module with respect to dim
		
		qkv_bias: bool
			if True, include bias to query, key and value projections

		p, attn_p: float
			dropout probability

	Attributes:
		norm1, norm2: LayerNorm
			normalization layers

		attn: Attention
			Attention module

		mlp: MLP
			MLP Module
	"""

	def __init__(self, dim, n_heads, mlp_ratio=0.4, qkv_bias=True, p=0., attn_p=0.):
		super().__init__()
		self.norm1 = nn.LayerNorm(dim, eps=1e-6)
		self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_p=attn_p, proj_p=p)
		self.norm2 = nn.LayerNorm(dim, eps=1e-6)

		hidden_features = int(dim * mlp_ratio)
		self.mlp = MLP(in_features=dim, hidden_features=hidden_features,out_features=dim)

	def forward(self, x):
		x = x + self.attn(self.norm1(x))
		x = x + self.mlp(self.norm2(x))

		return x

### Vision Transformer Module

In [197]:
class VisionTransformer(nn.Module):
	"""Vision Transformer module
	
	Params:
		img_size: int
			height and width of a squared image
			
		patch_size: int
			height and width of the patch
			
		in_chans: int
			number of input channel
			
		n_classes: int
			number of classes
			
		embed_dim: int
			dimension of the token/patch embedding
			
		depth: int
			number of blocks
			
		n_heads: int
			number of attention heads

		mlp_ratio: float
			hidden dimension of the MLP module

		qkv_bias: bool
			if True, include bias to query, key and value projections

		p, attn_p: float
			dropout probability

	Attributes:
		patch_embed: PatchEmbed
			instance of PatchEmbed layer

		cls_token: nn.Parameter
			learnable parameter that is the first token in the sequence

		pos_emb: nn.Parameter
			Positionnal embedding of the cls token + all the patches

		pos_drop: nn.Dropout
			dropout layer

		blocks: nn.ModuleList
			List of Block modules.

		norm: nn.LayerNorm
			Layer normalization
		"""
	def __init__(self, img_size=28, patch_size=5, in_chans=1, n_classes=10, embed_dim=56, depth=12, n_heads=3, mlp_ratio=4, qkv_bias=True, p=0., attn_p=0.):
		super().__init__()

		self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
		self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
		self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
		self.pos_drop = nn.Dropout(p=p)
		self.blocks = nn.ModuleList([
			Block(dim=embed_dim, n_heads=n_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, p=p, attn_p=attn_p) 
			for _ in range(depth)
		])

		self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
		self.head = nn.Linear(embed_dim, n_classes)

	def forward(self, x):
		n_samples = x.shape[0]
		x = self.patch_embed(x)

		cls_token = self.cls_token.expand(n_samples, -1, -1) # (n_samples, 1, embed_dim)
		x = torch.cat((cls_token, x), dim = 1) # (n_samples, 1 + n_patches, embed_dim)
		x = self.pos_drop(x)

		for block in self.blocks:
			x = block(x)

		x = self.norm(x)

		cls_token_final = x[:, 0]
		x = self.head(cls_token_final)

		return x

In [198]:
model = VisionTransformer()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [199]:
def train(train_loader, model, loss_fn, optimizer):
	size = len(train_loader.dataset)
	batches_l = len(train_loader)
	loss = 0
	correct = 0

	for batch_idx, (data, target) in enumerate(train_loader):
		pred = model(data)
		loss = loss_fn(pred, target)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		if batch_idx % 100 == 0:
			print(f'{loss= }')

In [200]:
train(train_loader, model, loss_fn, optimizer)

RuntimeError: shape '[32, 5777, 1, 3, 18]' is invalid for input of size 31057152