In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torchinfo import summary
import numpy as np
from tqdm.notebook import tqdm
import os, random

In [2]:
# def set_seeds(seed):
#     os.environ["PL_GLOBAL_SEED"] = str(seed)
#     random.seed(seed)
#     np.random.seed(seed)
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.cuda.manual_seed_all(seed)
#
# set_seeds(2025)
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
BATCH_SIZE = 64

In [4]:
transform_train = transforms.Compose([
    # transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
transform_test = transforms.Compose([
    # transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [5]:
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform_train)
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform_test)

train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_set, valid_set = random_split(train_dataset, [train_size, valid_size])

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
def patchify(images, patch_no):
  n, c, h, w = images.shape

  patches = torch.zeros(n, patch_no ** 2, (h*c*w) // patch_no ** 2 , device=images.device)
  patch_size = h // patch_no

  for idx, image in enumerate(images):
    for i in range(patch_no):
      for j in range(patch_no):
        patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
        patches[idx, i * patch_no + j] = patch.flatten()
  return patches

In [7]:
def get_positional_embeddings(sequence_length, dimension):
    result = torch.ones(sequence_length, dimension)
    for i in range(sequence_length):
        for j in range(dimension):
            result[i][j] = np.sin(i / (10000 ** (j / dimension))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / dimension)))
    return result

In [8]:
class Multi_Self_Attention(nn.Module):
    def __init__(self, dimension, head_no, dropout=0.1):
        super(Multi_Self_Attention, self).__init__()
        self.dimension = dimension
        self.head_no = head_no
        self.head_dimension = dimension // head_no

        self.query = nn.Linear(dimension, dimension)
        self.key = nn.Linear(dimension, dimension)
        self.value = nn.Linear(dimension, dimension)
        self.out = nn.Linear(dimension, dimension)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size, seq_len, dimension = x.size()

        q = self.query(x).view(batch_size, seq_len, self.head_no, self.head_dimension).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.head_no, self.head_dimension).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.head_no, self.head_dimension).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dimension ** 0.5)
        attn = self.softmax(scores)
        context = torch.matmul(attn, v)

        context = context.transpose(1, 2).reshape(batch_size, seq_len, dimension)
        out = self.dropout(self.out(context))
        return out

In [9]:
class ViT_Block(nn.Module):
  def __init__(self, hidden_dimension, head_no, mlp_ratio=4, dropout=0.1):
    super(ViT_Block, self).__init__()
    self.hidden_dimension = hidden_dimension
    self.head_no = head_no

    self.dropout = nn.Dropout(dropout)
    self.norm1 = nn.LayerNorm(hidden_dimension)
    self.multi_head_self_attention = Multi_Self_Attention(hidden_dimension, head_no, dropout)
    self.norm2 = nn.LayerNorm(hidden_dimension)

    self.mlp = nn.Sequential(
        nn.Linear(hidden_dimension, mlp_ratio * hidden_dimension),
        nn.GELU(),
        nn.Linear(mlp_ratio * hidden_dimension, hidden_dimension)
    )

  def forward(self, input):
    attention = input + self.dropout(self.multi_head_self_attention(self.norm1(input)))
    out = attention + self.dropout(self.mlp(self.norm2(attention)))
    return out

In [10]:
test = ViT_Block(hidden_dimension=8, head_no=2)
x = torch.randn(7, 50, 8)
test(x).shape

torch.Size([7, 50, 8])

In [11]:
class ViT_Model(nn.Module):
  def __init__(self, data_shape, patch_no, hidden_dimension, output_dimension, block_no, head_no):
    super(ViT_Model, self).__init__()

    self.data_shape = data_shape
    self.patch_no = patch_no
    self.hidden_dimension = hidden_dimension
    self.block_no = block_no
    self.head_no = head_no

    self.patch_size = (data_shape[1] // patch_no, data_shape[2] // patch_no)
    self.data_dimension = int(self.data_shape[0] * self.patch_size[0] * self.patch_size[1])

    self.mapper = nn.Linear(self.data_dimension, self.hidden_dimension)

    self.class_token = nn.Parameter(torch.rand(1, self.hidden_dimension))

    self.register_buffer(
            'positional_embedding',
            get_positional_embeddings(patch_no ** 2 + 1, hidden_dimension),
            persistent=False
        )

    self.blocks = nn.ModuleList([ViT_Block(self.hidden_dimension, self.head_no, dropout=0.1) for _ in range(self.block_no)])

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(self.hidden_dimension),
        nn.GELU(),
        nn.Linear(self.hidden_dimension, self.hidden_dimension),
        nn.Dropout(0.1),
        nn.GELU(),
        nn.Linear(self.hidden_dimension, output_dimension)
    )


  def forward(self, input):
    patches = patchify(input, self.patch_no)
    tokens = self.mapper(patches)

    tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

    positional_embedding = self.positional_embedding.repeat(input.shape[0], 1, 1)
    out = tokens + positional_embedding

    for block in self.blocks:
      out = block(out)

    out = out[:, 0]

    out = self.mlp_head(out)

    return out

model = ViT_Model(data_shape=(1, 28, 28), patch_no=4, hidden_dimension=256, output_dimension=10, block_no=6, head_no=8).to(device)

input_data = torch.randn(BATCH_SIZE, 1, 28, 28).to(device)
summary(model, input_data=input_data, device=str(device))

Layer (type:depth-idx)                        Output Shape              Param #
ViT_Model                                     [64, 10]                  256
├─Linear: 1-1                                 [64, 16, 256]             12,800
├─ModuleList: 1-2                             --                        --
│    └─ViT_Block: 2-1                         [64, 17, 256]             --
│    │    └─LayerNorm: 3-1                    [64, 17, 256]             512
│    │    └─Multi_Self_Attention: 3-2         [64, 17, 256]             263,168
│    │    └─Dropout: 3-3                      [64, 17, 256]             --
│    │    └─LayerNorm: 3-4                    [64, 17, 256]             512
│    │    └─Sequential: 3-5                   [64, 17, 256]             525,568
│    │    └─Dropout: 3-6                      [64, 17, 256]             --
│    └─ViT_Block: 2-2                         [64, 17, 256]             --
│    │    └─LayerNorm: 3-7                    [64, 17, 256]             512
│ 

In [12]:
epoch_no = 10
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()

In [13]:
progress_bar = tqdm(range(1, epoch_no + 1), desc="Epochs")

for epoch in progress_bar:
    model.train()
    train_loss = 0.0
    num_batches = 0

    print(f"********************************************************************")
    epoch_progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch} training")
    for batch_no, (data, labels) in epoch_progress_bar:
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1

    avg_train_loss = train_loss / num_batches
    print(f"Epoch {epoch}/{epoch_no} average training loss: {avg_train_loss:.4f}")
    print(f"----------------------------------------------------------------")

    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_test_batches = 0

    with torch.no_grad():
        test_progress_bar = tqdm(enumerate(test_loader), total=len(test_loader), desc=f"Epoch {epoch} testing")
        for batch_no, (data, labels) in test_progress_bar:
            data = data.to(device)
            labels = labels.to(device)

            output = model(data)
            loss = criterion(output, labels)
            test_loss += loss.item()
            num_test_batches += 1

            _, predicted = torch.max(output.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_test_loss = test_loss / num_test_batches
    accuracy = correct / total * 100
    print(f"Epoch {epoch}/{epoch_no} average test loss: {avg_test_loss:.4f}")
    print(f"Epoch {epoch}/{epoch_no} test accuracy: {accuracy:.2f}%")

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

********************************************************************


Epoch 1 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 1/10 average training loss: 0.6756
----------------------------------------------------------------


Epoch 1 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 1/10 average test loss: 0.3983
Epoch 1/10 test accuracy: 87.62%
********************************************************************


Epoch 2 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 2/10 average training loss: 0.3367
----------------------------------------------------------------


Epoch 2 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 2/10 average test loss: 0.2986
Epoch 2/10 test accuracy: 90.26%
********************************************************************


Epoch 3 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 3/10 average training loss: 0.3180
----------------------------------------------------------------


Epoch 3 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 3/10 average test loss: 0.2633
Epoch 3/10 test accuracy: 91.76%
********************************************************************


Epoch 4 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 4/10 average training loss: 0.3463
----------------------------------------------------------------


Epoch 4 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 4/10 average test loss: 0.2417
Epoch 4/10 test accuracy: 92.46%
********************************************************************


Epoch 5 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 5/10 average training loss: 0.2678
----------------------------------------------------------------


Epoch 5 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 5/10 average test loss: 0.2081
Epoch 5/10 test accuracy: 93.46%
********************************************************************


Epoch 6 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 6/10 average training loss: 0.3272
----------------------------------------------------------------


Epoch 6 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 6/10 average test loss: 0.2003
Epoch 6/10 test accuracy: 93.87%
********************************************************************


Epoch 7 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 7/10 average training loss: 0.2091
----------------------------------------------------------------


Epoch 7 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 7/10 average test loss: 0.1959
Epoch 7/10 test accuracy: 94.04%
********************************************************************


Epoch 8 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 8/10 average training loss: 0.2298
----------------------------------------------------------------


Epoch 8 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 8/10 average test loss: 0.2295
Epoch 8/10 test accuracy: 92.96%
********************************************************************


Epoch 9 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 9/10 average training loss: 0.3042
----------------------------------------------------------------


Epoch 9 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 9/10 average test loss: 0.2694
Epoch 9/10 test accuracy: 91.40%
********************************************************************


Epoch 10 training:   0%|          | 0/750 [00:00<?, ?it/s]

Epoch 10/10 average training loss: 0.2494
----------------------------------------------------------------


Epoch 10 testing:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 10/10 average test loss: 0.1989
Epoch 10/10 test accuracy: 93.45%
