# Import libraries

In [13]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.utils.data as dataloader
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import classification_report

# Define data transformations for training and validation

In [14]:
# Steps applied to every image regardless of split: tensor conversion followed
# by normalization with mean 0.5 and std 0.5 on the single channel.
_common_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]

# Training puts random augmentation (rotation by up to 10 degrees, horizontal
# flips) in front of the common steps.
train_transform = transforms.Compose(
    [transforms.RandomRotation(10), transforms.RandomHorizontalFlip()] + _common_steps
)

# Validation uses the common steps only, so evaluation images are not augmented.
val_transform = transforms.Compose(list(_common_steps))

# Load FashionMNIST datasets for training and validation

In [15]:
# Arguments common to both splits: same root directory, download enabled.
_fashion_mnist_kwargs = {"root": "./data", "download": True}

# train=True selects the training split; it is paired with the augmenting transform.
train_dataset = torchvision.datasets.FashionMNIST(
    train=True, transform=train_transform, **_fashion_mnist_kwargs
)
# train=False selects the other split; it is paired with the non-augmenting transform.
val_dataset = torchvision.datasets.FashionMNIST(
    train=False, transform=val_transform, **_fashion_mnist_kwargs
)

# Print the tensor shape and label of one example from the training dataset

In [16]:
# Indexing the dataset yields an (image tensor, label) pair with the training
# transform already applied.
img, label = train_dataset[0]
print(img.size(), label)

torch.Size([1, 28, 28]) 9


# Set up constants and hyperparameters

In [17]:
# --- Data ---
batch_size = 64       # samples per DataLoader batch
img_size = 28         # side length, in pixels, of the square input images
num_channels = 1      # channels per input image

# --- Patching ---
patch_size = 4        # side length, in pixels, of one square patch
# The patch grid is computed with floor division, so a non-divisible image size
# would silently drop border pixels; fail loudly instead.
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
num_patches = (img_size // patch_size) ** 2

# --- Transformer ---
embed_dim = 128       # width of every token embedding
num_heads = 4         # attention heads per block
# Each head works on an equal slice of the embedding.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
mlp_dim = 256         # hidden width of the per-block MLP
transformer_units = 6 # number of stacked Transformer blocks
dropout_rate = 0.1    # NOTE(review): not referenced by any model cell visible here — confirm intent

# Create DataLoader objects for training and validation datasets

In [18]:
# Training batches are drawn in a new random order on every pass.
train_data = dataloader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation metrics are sums over the whole split, so sample order cannot change
# them. Keeping the order fixed makes every evaluation pass (and the single
# batch previewed in the next cell) reproducible.
val_data = dataloader.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Inspect the shapes of one batch of images and of its patch embedding

In [19]:
# Pull a single batch from the validation loader to use as a shape probe.
images, labels = next(iter(val_data))
print("Shape of images in a batch:", images.shape)

# A convolution whose kernel size and stride both equal the patch size emits
# one embed_dim-sized vector per non-overlapping patch.
patch_embed = nn.Conv2d(
    in_channels=num_channels,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
)

# Merge the two spatial axes into one, then swap it with the channel axis so
# the result is laid out as [batch, patches, embed_dim].
embedded_data = patch_embed(images).flatten(2).transpose(1, 2)
print("Shape of embedded data:", embedded_data.shape)

# Shape of a single [1, 1, embed_dim] tensor, for comparison with one token.
print(torch.randn(1, 1, embed_dim).shape)

Shape of images in a batch: torch.Size([64, 1, 28, 28])
Shape of embedded data: torch.Size([64, 49, 128])
torch.Size([1, 1, 128])


# Define Patch Embedding module

In [20]:
class PatchEmbedding(nn.Module):
    """Split an image batch into non-overlapping square patches and embed each one.

    A single convolution whose kernel size and stride both equal the patch size
    produces one embedding vector per patch.

    Args:
        in_channels: Number of input image channels. When omitted, the
            notebook-level ``num_channels`` is used.
        dim: Size of each patch embedding. When omitted, the notebook-level
            ``embed_dim`` is used.
        patch: Side length of one square patch. When omitted, the notebook-level
            ``patch_size`` is used.
    """

    def __init__(self, in_channels=None, dim=None, patch=None):
        super().__init__()
        # Only omitted arguments fall back to the notebook configuration, so the
        # module can also be built (and tested) without those globals.
        if in_channels is None:
            in_channels = num_channels
        if dim is None:
            dim = embed_dim
        if patch is None:
            patch = patch_size
        self.patch_embed = nn.Conv2d(
            in_channels=in_channels,
            out_channels=dim,
            kernel_size=patch,
            stride=patch,
        )

    def forward(self, x):
        """Embed the patches of ``x``.

        Args:
            x: Image batch of shape [B, C, H, W].

        Returns:
            Tensor of shape [B, number_of_patches, dim].
        """
        x = self.patch_embed(x)  # [B, dim, H // patch, W // patch]
        x = x.flatten(2)         # merge the two spatial axes into one patch axis
        x = x.transpose(1, 2)    # put the patch axis before the embedding axis
        return x

# Define Transformer block module

In [21]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies LayerNorm followed by multi-head self-attention, adds the result to
    the block input, then applies LayerNorm followed by a two-layer GELU MLP and
    adds that to the intermediate result.

    Args:
        dim: Token embedding width. When omitted, the notebook-level
            ``embed_dim`` is used.
        heads: Number of attention heads. When omitted, the notebook-level
            ``num_heads`` is used.
        hidden_dim: Hidden width of the MLP. When omitted, the notebook-level
            ``mlp_dim`` is used.
    """

    def __init__(self, dim=None, heads=None, hidden_dim=None):
        super().__init__()
        # Only omitted arguments fall back to the notebook configuration, so the
        # block can also be built (and tested) without those globals.
        if dim is None:
            dim = embed_dim
        if heads is None:
            heads = num_heads
        if hidden_dim is None:
            hidden_dim = mlp_dim

        self.layer_norm_1 = nn.LayerNorm(
            normalized_shape=dim,
            eps=1e-5,
            elementwise_affine=True
        )
        self.self_attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            dropout=0.0,
            batch_first=True,
            bias=True
        )
        self.layer_norm_2 = nn.LayerNorm(
            normalized_shape=dim,
            eps=1e-5,
            elementwise_affine=True
        )
        # Layout (Linear, GELU, Linear) is kept so saved state_dict keys still match.
        self.mlp = nn.Sequential(
            nn.Linear(
                in_features=dim,
                out_features=hidden_dim,
                bias=True
            ),
            nn.GELU(),
            nn.Linear(
                in_features=hidden_dim,
                out_features=dim,
                bias=True
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run one encoder block.

        Args:
            x: Token sequence of shape [batch_size, num_tokens, dim].

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer. The attention weights are never used, so ask the
        # module not to compute them; the second return value is then None.
        residual_1 = x
        x_norm = self.layer_norm_1(input=x)
        attention_output, _ = self.self_attention(
            query=x_norm,
            key=x_norm,
            value=x_norm,
            need_weights=False
        )
        x = attention_output + residual_1

        # MLP sub-layer.
        residual_2 = x
        x_norm = self.layer_norm_2(input=x)
        mlp_output = self.mlp(x_norm)
        x = mlp_output + residual_2
        return x

# Define Vision Transformer model class

In [22]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding, prepend a learnable class token, add learnable
    position embeddings, run ``transformer_units`` Transformer blocks, then
    classify from the class-token output with LayerNorm + Linear.

    Args:
        num_classes: Number of output logits. Defaults to 10, the value that
            was previously hard-coded in the classification head.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding()
        # One learnable token shared by every sample; expanded per batch in forward().
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, embed_dim)
        )
        # One position vector per patch plus one for the class token. Reuses the
        # notebook-level num_patches rather than recomputing the same value.
        self.pos_embed = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim)
        )
        self.transformer_layers = nn.Sequential(
            *[TransformerBlock() for _ in range(transformer_units)]
        )
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=embed_dim,
                eps=1e-5,
                elementwise_affine=True
            ),
            nn.Linear(
                in_features=embed_dim,
                out_features=num_classes,
                bias=True
            )
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Input images, shape [B, C, H, W].

        Returns:
            Class logits, shape [B, num_classes].
        """
        x = self.patch_embedding(x)
        B = x.size(0)
        # expand() broadcasts the single class token across the batch dimension.
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )
        x = torch.cat((cls_tokens, x), dim=1)  # class token sits at sequence index 0
        x = x + self.pos_embed
        x = self.transformer_layers(x)
        x = x[:, 0]                            # keep only the class-token output
        x = self.mlp_head(x)
        return x

# Set up the device, model, optimizer, scheduler, and loss function

In [23]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network, then move its parameters to the chosen device.
model = VisionTransformer()
model = model.to(device)

# AdamW optimizer whose learning rate follows a cosine-annealing schedule
# configured with T_max=100 scheduler steps.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100,
)

# Cross-entropy computed from the model's class logits and integer labels.
criterion = nn.CrossEntropyLoss()

# Training and validation loop with early stopping

In [24]:
# Early-stopping state: the lowest validation loss seen so far and the number of
# consecutive epochs that failed to improve on it.
best_val_loss = float('inf')
patience = 12
counter = 0

# 100 epochs matches the T_max=100 given to the cosine scheduler.
for epoch in range(100):
    # ---------------- training pass ----------------
    model.train()
    total_loss = 0      # sum of per-batch mean losses over the epoch
    correct_epoch = 0
    total_epoch = 0
    print(f"\nEpoch {epoch+1}")
    for batch_idx, (images, labels) in enumerate(train_data):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct = (preds == labels).sum().item()
        correct_epoch += correct
        total_epoch += labels.size(0)
        # Report the current batch every 100 batches.
        if batch_idx % 100 == 0:
            batch_acc = 100.0 * correct / labels.size(0)
            print(f"  Batch {batch_idx+1:3d}: Loss = {loss.item():.4f}, Accuracy = {batch_acc:.2f}%")
    epoch_acc = 100.0 * correct_epoch / total_epoch
    print(f"==> Epoch {epoch+1} Summary: Total Loss = {total_loss:.4f}, Accuracy = {epoch_acc:.2f}%")

    # ---------------- validation pass ----------------
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_data:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # The criterion returns a per-batch mean. Weight it by the batch
            # size so the short final batch does not count as much as a full
            # one; dividing by the sample count below then gives the true
            # per-sample mean that drives early stopping.
            batch_count = labels.size(0)
            val_loss += loss.item() * batch_count
            preds = outputs.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += batch_count
    val_loss /= val_total
    val_acc = 100.0 * val_correct / val_total
    print(f"==> Validation: Loss = {val_loss:.4f}, Accuracy = {val_acc:.2f}%")

    # ---------------- checkpointing / early stopping ----------------
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), 'best_vit_model.pth')
        print("  Best model saved.")
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
    # Advance the learning-rate schedule once per completed epoch.
    scheduler.step()


Epoch 1
  Batch   1: Loss = 2.4523, Accuracy = 3.12%
  Batch 101: Loss = 0.7898, Accuracy = 78.12%
  Batch 201: Loss = 0.5979, Accuracy = 73.44%
  Batch 301: Loss = 0.4584, Accuracy = 81.25%
  Batch 401: Loss = 0.6835, Accuracy = 76.56%
  Batch 501: Loss = 0.4989, Accuracy = 81.25%
  Batch 601: Loss = 0.5479, Accuracy = 81.25%
  Batch 701: Loss = 0.3869, Accuracy = 85.94%
  Batch 801: Loss = 0.4943, Accuracy = 84.38%
  Batch 901: Loss = 0.6060, Accuracy = 78.12%
==> Epoch 1 Summary: Total Loss = 639.7198, Accuracy = 74.90%
==> Validation: Loss = 0.4864, Accuracy = 82.21%
  Best model saved.

Epoch 2
  Batch   1: Loss = 0.5172, Accuracy = 87.50%
  Batch 101: Loss = 0.6866, Accuracy = 76.56%
  Batch 201: Loss = 0.4278, Accuracy = 85.94%
  Batch 301: Loss = 0.3614, Accuracy = 84.38%
  Batch 401: Loss = 0.5254, Accuracy = 79.69%
  Batch 501: Loss = 0.2055, Accuracy = 95.31%
  Batch 601: Loss = 0.4335, Accuracy = 79.69%
  Batch 701: Loss = 0.5141, Accuracy = 76.56%
  Batch 801: Loss = 0.27

# Perform model evaluation

In [27]:
# Restore the weights written by the training loop at its lowest validation loss.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be loaded when only a CPU is available.
model.load_state_dict(torch.load('best_vit_model.pth', map_location=device))
model.to(device)
model.eval()

# Count correct predictions over the whole validation loader.
val_correct = 0
val_total = 0
with torch.no_grad():
    for images, labels in val_data:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        val_correct += (preds == labels).sum().item()
        val_total += labels.size(0)
final_val_acc = 100.0 * val_correct / val_total
print(f"\n Final Validation Accuracy = {final_val_acc:.2f}%")


 Final Validation Accuracy = 90.77%


# Compute per-class precision, recall, and F1 on the validation set

In [28]:
# Switch to inference mode here so this cell does not depend on an earlier cell
# having already done it.
model.eval()

# Gather ground-truth labels and predictions for the whole validation loader as
# plain Python ints.
all_labels = []
all_preds = []
with torch.no_grad():
    for images, labels in val_data:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        all_labels.extend(labels.cpu().tolist())
        all_preds.extend(preds.cpu().tolist())

# classification_report is imported in the notebook's import cell.
print(classification_report(all_labels, all_preds))

              precision    recall  f1-score   support

           0       0.86      0.84      0.85      1000
           1       0.98      0.98      0.98      1000
           2       0.86      0.83      0.85      1000
           3       0.89      0.93      0.91      1000
           4       0.86      0.85      0.85      1000
           5       0.98      0.96      0.97      1000
           6       0.75      0.77      0.76      1000
           7       0.94      0.97      0.96      1000
           8       0.97      0.98      0.98      1000
           9       0.97      0.96      0.96      1000

    accuracy                           0.91     10000
   macro avg       0.91      0.91      0.91     10000
weighted avg       0.91      0.91      0.91     10000

