In [12]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transform
import torch.utils.data as dataloader
import torch.nn as nn

In [13]:
# --- Data -------------------------------------------------------------------
img_size = 28        # side length (pixels) used to derive the patch count below
input_channel = 1    # channel count fed to the patch-embedding convolution
num_classes = 10     # width of the classifier output
batch_size = 64      # samples per DataLoader batch

# --- Patching ---------------------------------------------------------------
patch_size = 7       # side length of each square patch (conv kernel and stride)
patch_num = pow(img_size // patch_size, 2)   # patches per image (grid side squared)

# --- Model ------------------------------------------------------------------
embed_dim = 20               # token embedding width; must be divisible by attn_heads
attn_heads = 4               # attention heads per encoder block
num_transformer_block = 4    # number of stacked encoder blocks
mlp_nodes = 64               # NOTE(review): not referenced by the cells visible here
dropout = 0.1                # dropout probability passed to the encoder blocks

# --- Optimisation -----------------------------------------------------------
learning_rate = 1e-3
epochs = 5

In [14]:
# Preprocessing pipeline applied to every image: convert it to a tensor.
#
# The pipeline is built from the fully qualified `torchvision.transforms`
# namespace rather than the `transform` module alias, because this cell rebinds
# the name `transform` to the pipeline object. Going through the alias would make
# the cell fail with AttributeError the second time it is executed.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

In [15]:
# Both splits share the storage root, download flag and preprocessing;
# only the `train` flag differs between them.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
val_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [16]:
# Training batches are reshuffled every epoch.
train_data = dataloader.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
# Validation only aggregates metrics over the whole split, so sample order is
# irrelevant; keeping it fixed makes evaluation deterministic and avoids drawing
# from the global RNG between training epochs.
val_data = dataloader.DataLoader(dataset=val_dataset,batch_size=batch_size,shuffle=False)

In [17]:
class MultiheadAttention(nn.Module):
  """Unmasked multi-head scaled dot-product self-attention.

  Args:
    d_in: feature size of each input token.
    d_out: total output feature size; split evenly across the heads.
    num_heads: number of attention heads (must divide ``d_out``).
    dropout: dropout probability applied to the attention weights.
    context_len: accepted for interface compatibility; not used by this class.
    qvk_bias: whether the query/key/value projections carry a bias term.
  """

  def __init__(self,d_in,d_out,num_heads,dropout,context_len,qvk_bias = False):
    super().__init__()
    assert (d_out % num_heads == 0), "Output Dimesion is not divisble by number of heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    # Projections are created in a fixed order (query, key, value, output) so
    # that seeded initialisation stays reproducible.
    self.W_query = nn.Linear(d_in, d_out, bias=qvk_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qvk_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qvk_bias)

    self.out_proj = nn.Linear(d_out, d_out)

    self.dropout = nn.Dropout(dropout)

  def _split_heads(self, projected):
    """Reshape (batch, tokens, d_out) into (batch, heads, tokens, head_dim)."""
    batch, seq_len, _ = projected.shape
    per_head = projected.view(batch, seq_len, self.num_heads, self.head_dim)
    return per_head.transpose(1, 2)

  def forward(self,x):
    """Attend over the token axis of ``x`` (batch, tokens, d_in) -> (batch, tokens, d_out)."""
    batch, seq_len, _ = x.shape

    k = self._split_heads(self.W_key(x))
    q = self._split_heads(self.W_query(x))
    v = self._split_heads(self.W_value(x))

    # Pairwise token similarities per head, scaled by sqrt(head_dim) before softmax.
    raw_scores = q @ k.transpose(-2, -1)
    weights = torch.softmax(raw_scores / (self.head_dim**0.5), dim=-1)
    weights = self.dropout(weights)

    # Move the head axis back next to the feature axis, then merge the two.
    # The transpose leaves a non-contiguous tensor, hence `.contiguous()` before `.view`.
    per_head_out = (weights @ v).transpose(1, 2)
    merged = per_head_out.contiguous().view(batch, seq_len, self.d_out)

    return self.out_proj(merged)


In [18]:
class PatchEmbedding(nn.Module):
  """Turn an image batch into a sequence of patch embeddings.

  A convolution whose kernel size equals its stride projects each
  non-overlapping ``patch_size`` x ``patch_size`` patch to ``embed_dim`` features.
  """

  def __init__(self,input_channel,embed_dim,patch_size):
    super().__init__()
    self.patch_embed = nn.Conv2d(
        input_channel,
        embed_dim,
        kernel_size=patch_size,
        stride=patch_size,
    )

  def forward(self,x):
    """Map (batch, channels, H, W) to (batch, num_patches, embed_dim)."""
    feature_map = self.patch_embed(x)
    # Collapse the spatial grid into one axis, then put features last.
    tokens = feature_map.flatten(start_dim=2)
    return tokens.permute(0, 2, 1)


class LayerNorm(nn.Module):
  """Layer normalisation over the last axis with a learnable scale and shift."""

  def __init__(self,embed_dim):
    super().__init__()
    self.eps = 1e-5  # added to the variance to avoid division by zero
    self.scale = nn.Parameter(torch.ones(embed_dim))
    self.shift = nn.Parameter(torch.zeros(embed_dim))

  def forward(self,x):
    """Standardise ``x`` along its last axis, then apply scale and shift."""
    mu = x.mean(dim=-1, keepdim=True)
    # Population variance (unbiased=False): divide by N rather than N - 1.
    sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)

    standardised = (x - mu) / (sigma_sq + self.eps).sqrt()

    return standardised * self.scale + self.shift

class GELU(nn.Module):
  """GELU activation computed with the tanh approximation."""

  def __init__(self):
    super().__init__()
    # sqrt(2 / pi), stored as a buffer so it moves with the module across devices.
    self.register_buffer('c', torch.tensor(2.0 / torch.pi).sqrt())

  def forward(self, x):
    """Return 0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3))) element-wise."""
    inner = x + 0.044715 * torch.pow(x, 3)
    return 0.5 * x * (1 + torch.tanh(self.c * inner))


class FeedForward(nn.Module):
  """Position-wise MLP: expand to 4x the embedding width, apply GELU, project back."""

  def __init__(self,embed_dim):
    super().__init__()
    hidden_dim = 4 * embed_dim
    self.layers = nn.Sequential(
        nn.Linear(embed_dim, hidden_dim),
        GELU(),
        nn.Linear(hidden_dim, embed_dim),
    )

  def forward(self,x):
    """Apply the MLP to the last axis of ``x``; the output has the same last-axis size."""
    return self.layers(x)

class MLP_head(nn.Module):
  """Classification head: normalise the first token and project it to class logits.

  The output width is taken from the notebook-level ``num_classes`` setting.
  """

  def __init__(self,embed_dim):
    super().__init__()
    self.layer_norm1 = LayerNorm(embed_dim)
    self.mlp_head = nn.Linear(embed_dim, num_classes)

  def forward(self,x):
    """Map (batch, tokens, embed_dim) to (batch, num_classes) logits."""
    # Only the token at sequence index 0 (the CLS token) feeds the classifier.
    cls_repr = x[:, 0]
    return self.mlp_head(self.layer_norm1(cls_repr))

In [19]:
class TransformerEncoder(nn.Module):
  """Pre-norm transformer encoder block.

  Two residual sub-layers are applied in sequence: self-attention, then a
  feed-forward MLP. Each sub-layer sees a layer-normalised input and its
  output passes through dropout before being added back to the residual.

  Args:
    embed_dim: token embedding width (input and output size of the block).
    num_heads: number of attention heads.
    dropout: dropout probability for attention weights and sub-layer outputs.
    max_seq_len: forwarded to the attention module as ``context_len``.
  """

  def __init__(self,embed_dim,num_heads,dropout,max_seq_len) -> None:
    super().__init__()
    self.attn = MultiheadAttention(
        d_in=embed_dim,
        d_out=embed_dim,
        num_heads=num_heads,
        dropout=dropout,
        context_len=max_seq_len,
    )
    self.ff = FeedForward(embed_dim)
    self.norm1 = LayerNorm(embed_dim)
    self.norm2 = LayerNorm(embed_dim)
    # A single dropout module is shared by both residual branches.
    self.dropout_shortcut = nn.Dropout(dropout)

  def forward(self,x):
    """Run both residual sub-layers; the output has the same shape as ``x``."""
    # Sub-layer 1: norm -> attention -> dropout, added to the residual.
    attn_branch = self.dropout_shortcut(self.attn(self.norm1(x)))
    x = x + attn_branch

    # Sub-layer 2: norm -> feed-forward -> dropout, added to the residual.
    ff_branch = self.dropout_shortcut(self.ff(self.norm2(x)))
    return ff_branch + x

In [20]:
class VisionTransformer(nn.Module):
  """Small Vision Transformer classifier configured from the notebook-level settings.

  Pipeline: patch embedding -> prepend a learnable CLS token -> add learnable
  position embeddings -> stack of encoder blocks -> classification head on the
  CLS token.
  """

  def __init__(self) -> None:
    super().__init__()
    seq_len = patch_num + 1  # one extra position for the CLS token

    self.patch_embedding = PatchEmbedding(input_channel, embed_dim, patch_size)
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
    self.position_embedding = nn.Parameter(torch.randn(1, seq_len, embed_dim))

    encoder_blocks = [
        TransformerEncoder(embed_dim, attn_heads, dropout, seq_len)
        for _ in range(num_transformer_block)
    ]
    self.transformer_block = nn.Sequential(*encoder_blocks)
    self.mlp_head = MLP_head(embed_dim)

  def forward(self,x):
    """Return class logits for an image batch ``x``."""
    patches = self.patch_embedding(x)

    # Broadcast the single learned CLS token across the batch (no copy via expand).
    cls_tokens = self.cls_token.expand(patches.shape[0], -1, -1)
    tokens = torch.concat((cls_tokens, patches), dim=1)
    tokens = tokens + self.position_embedding

    encoded = self.transformer_block(tokens)
    return self.mlp_head(encoded)


In [21]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Model, optimiser and loss used by the training loop.
model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [22]:
# Train for `epochs` passes over `train_data`, evaluating on `val_data` after each.
#
# Per-batch sample counts are kept in `n_samples` rather than `batch_size`, so the
# notebook-level `batch_size` setting is not overwritten while the loop runs
# (otherwise it would end up holding the size of the last validation batch).
for epoch in range(epochs):
    # ---- training phase (dropout active) ----
    model.train()
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    num_batches = len(train_data)

    print(f"\nEpoch [{epoch+1}/{epochs}]")

    for batch_idx, (images, labels) in enumerate(train_data):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_samples = labels.size(0)
        preds = outputs.argmax(dim=1)
        batch_correct = (preds == labels).sum().item()
        batch_acc = batch_correct / n_samples

        # The criterion returns a per-batch mean, so weight it by the sample
        # count to get a correct running average when the last batch is smaller.
        running_loss += loss.item() * n_samples
        running_correct += batch_correct
        running_total += n_samples

        avg_loss = running_loss / running_total
        avg_acc = running_correct / running_total

        # Report every 50 batches and on the final batch of the epoch.
        if (batch_idx + 1) % 50 == 0 or (batch_idx + 1) == num_batches:
            print(
                f"  Batch [{batch_idx+1}/{num_batches}] | "
                f"Batch Loss: {loss.item():.4f}, Batch Acc: {batch_acc:.4f} | "
                f"Avg Loss: {avg_loss:.4f}, Avg Acc: {avg_acc:.4f}"
            )

    # ---- validation phase (dropout disabled, no gradients tracked) ----
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in val_data:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            n_samples = labels.size(0)
            preds = outputs.argmax(dim=1)

            val_loss += loss.item() * n_samples
            val_correct += (preds == labels).sum().item()
            val_total += n_samples

    val_loss /= val_total
    val_acc = val_correct / val_total

    print(
        f"Epoch [{epoch+1}/{epochs}] Summary | "
        f"Train Loss: {avg_loss:.4f}, Train Acc: {avg_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )



Epoch [1/5]
  Batch [50/938] | Batch Loss: 2.3012, Batch Acc: 0.1250 | Avg Loss: 2.3152, Avg Acc: 0.1178
  Batch [100/938] | Batch Loss: 1.8586, Batch Acc: 0.4062 | Avg Loss: 2.2232, Avg Acc: 0.1761
  Batch [150/938] | Batch Loss: 1.4177, Batch Acc: 0.4375 | Avg Loss: 2.0413, Avg Acc: 0.2551
  Batch [200/938] | Batch Loss: 1.1937, Batch Acc: 0.6250 | Avg Loss: 1.8863, Avg Acc: 0.3229
  Batch [250/938] | Batch Loss: 1.1543, Batch Acc: 0.5625 | Avg Loss: 1.7498, Avg Acc: 0.3773
  Batch [300/938] | Batch Loss: 1.0711, Batch Acc: 0.6094 | Avg Loss: 1.6370, Avg Acc: 0.4191
  Batch [350/938] | Batch Loss: 0.8970, Batch Acc: 0.7500 | Avg Loss: 1.5359, Avg Acc: 0.4593
  Batch [400/938] | Batch Loss: 0.7757, Batch Acc: 0.7500 | Avg Loss: 1.4393, Avg Acc: 0.4954
  Batch [450/938] | Batch Loss: 0.7723, Batch Acc: 0.7812 | Avg Loss: 1.3571, Avg Acc: 0.5259
  Batch [500/938] | Batch Loss: 0.6390, Batch Acc: 0.7656 | Avg Loss: 1.2888, Avg Acc: 0.5513
  Batch [550/938] | Batch Loss: 0.6323, Batch Ac

In [26]:
from torchsummary import summary

In [31]:
# Show the nested module structure of the model.
print(model)

VisionTransformer(
  (patch_embedding): PatchEmbedding(
    (patch_embed): Conv2d(1, 20, kernel_size=(7, 7), stride=(7, 7))
  )
  (transformer_block): Sequential(
    (0): TransformerEncoder(
      (attn): MultiheadAttention(
        (W_query): Linear(in_features=20, out_features=20, bias=False)
        (W_key): Linear(in_features=20, out_features=20, bias=False)
        (W_value): Linear(in_features=20, out_features=20, bias=False)
        (out_proj): Linear(in_features=20, out_features=20, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=20, out_features=80, bias=True)
          (1): GELU()
          (2): Linear(in_features=80, out_features=20, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (dropout_shortcut): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerEncoder(
      (attn): MultiheadAttention(
        (W_query): Linear(in_

In [32]:
# Total number of scalar entries across all registered parameters of the model.
total_params = sum(param.numel() for param in model.parameters())
print("Params:", total_params)

Params: 21610
