### Imports

In [2]:
# Vision Transformer (ViT) based on the original publication "An Image is Worth 16x16 Words":
# https://arxiv.org/pdf/2010.11929

# Code is partially based on a tutorial by Uyaar Kurt
# https://www.youtube.com/watch?v=Vonyoz6Yt9c
# Andrei Cartera -- Oct 2024



# Imports (plain PyTorch + torchvision; PyTorch Lightning is not used)

from pathlib import Path
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Subset
import torch
from torch import nn
import torchvision
import torch.optim as optim
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomCrop


# Log the installed PyTorch version so results can be tied to a specific release.
print(torch.__version__)

2.4.1+cu124


In [3]:
#from google.colab import drive
#drive.mount('/content/drive')

### Set Hyperparameters of the network and specify device

In [4]:
# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

RANDOM_SEED = 42
# Seed torch's CPU and CUDA generators so weight init, dropout, random augmentation
# and DataLoader shuffling repeat across kernel restarts (some cuDNN kernels may
# still be non-deterministic).
torch.manual_seed(RANDOM_SEED)

BATCH_SIZE = 25
EPOCHS = 50
LEARNING_RATE = 1e-4
NUM_CLASSES = 10
PATCH_SIZE = 16
IMG_SIZE = 224
IN_CHANNELS = 3
NUM_HEADS = 12
DROPOUT = 0.001
HIDDEN_DIM = 768  # NOTE(review): not referenced by the model/training cells — confirm before relying on it
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION="gelu"
NUM_ENCODERS = 4
# Token width = number of values in one flattened patch: 16 * 16 * 3.
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # 768
# Non-overlapping patches per image: (224 // 16) ** 2.
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2 # 196


cuda:0


In [9]:
# Resize the 32x32 CIFAR-10 images to IMG_SIZE x IMG_SIZE (224x224, the training
# resolution used in the paper).
# The augmentation pipeline follows
# https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/train_cifar10.py
# NOTE(review): the mean/std below are the ImageNet channel statistics, whereas the
# linked repo appears to use CIFAR-10-specific statistics — TODO confirm which is
# intended. Any code that un-normalizes images for display must use these same values.
import torch.utils

NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Set to a small int (e.g. 200) for a quick debug run on a subset; None = full dataset.
DEBUG_SUBSET_SIZE = None

# Training pipeline: RandomCrop + RandomHorizontalFlip are data augmentation.
transform_training_data = Compose([
    RandomCrop(32, padding=4),
    Resize(IMG_SIZE),
    RandomHorizontalFlip(),
    ToTensor(),
    Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Evaluation pipeline: same resize/normalization but no random augmentation, so
# test metrics are deterministic and measured on the real images.
transform_test_data = Compose([
    Resize(IMG_SIZE),
    ToTensor(),
    Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Load train and test datasets
filepath = Path('../datasets/')
train_data = torchvision.datasets.CIFAR10(
    root=filepath, train=True, download=True, transform=transform_training_data)

test_data = torchvision.datasets.CIFAR10(
    root=filepath, train=False, download=False, transform=transform_test_data)

if DEBUG_SUBSET_SIZE is None:
    train_source, test_source = train_data, test_data
else:
    train_source = Subset(train_data, list(range(min(DEBUG_SUBSET_SIZE, len(train_data)))))
    test_source = Subset(test_data, list(range(min(DEBUG_SUBSET_SIZE, len(test_data)))))

trainloader = torch.utils.data.DataLoader(train_source, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(test_source, batch_size=BATCH_SIZE, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified


Code to visualize samples from the CIFAR-10 dataset, adapted from the PyTorch CIFAR-10 tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized (C, H, W) image tensor with matplotlib.

    Args:
        img: tensor produced by a `Normalize(mean, std)` transform, e.g. the
            output of `torchvision.utils.make_grid` on a loader batch.
        mean: per-channel mean used to normalize `img`. The default matches the
            ImageNet values passed to `Normalize` in the data-loading cell.
        std: per-channel std used to normalize `img` (same default source).
    """
    # Invert Normalize: x = x_norm * std + mean. The tutorial's `img / 2 + 0.5`
    # only inverts Normalize(mean=0.5, std=0.5), which is not what the loaders use.
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    # Clamp into matplotlib's valid float RGB range to avoid clipping warnings.
    img = (img * std_t + mean_t).clamp(0, 1)
    npimg = img.cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Grab one shuffled (and augmented) batch from the training loader.
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Show the whole batch as a single image grid.
imshow(torchvision.utils.make_grid(images))
# Print the class name of every image in the batch. Iterate over the labels that
# were actually returned: a batch can hold fewer than BATCH_SIZE samples (e.g. when
# the loader wraps a dataset smaller than one batch).
print(' '.join(f'{classes[label]:5s}' for label in labels.tolist()))

### Input Layer

In [7]:
class PatchEmbedding(nn.Module):
  """Turn an image batch into a token sequence with a prepended [CLS] token.

  A Conv2d whose kernel size equals its stride cuts the image into
  non-overlapping patch_size x patch_size patches and projects each one to
  `embed_dim` values. A learnable [CLS] token is prepended, learnable position
  embeddings are added, and dropout is applied to the result.

  Args:
    embed_dim: width of every output token.
    patch_size: side length of a square patch (also the conv stride).
    num_patches: patches per image; must equal (H // patch_size) * (W // patch_size)
      for the images passed to forward(), since the position table has
      num_patches + 1 rows.
    dropout: dropout probability applied to the final token sequence.
    in_channels: number of input image channels.
  """

  def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
    super().__init__()
    # Keep the Sequential(Conv2d, Flatten) layout so state_dict keys stay
    # `patcher.0.weight` / `patcher.0.bias`.
    projection = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                           kernel_size=patch_size, stride=patch_size)
    self.patcher = nn.Sequential(projection, nn.Flatten(2))

    # Parameters are created in the same order as before so a seeded run
    # initializes them identically.
    # NOTE(review): standard-normal (std 1) init; reference ViT implementations
    # typically use a much smaller scale (~0.02) here — consider revisiting.
    self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim), requires_grad=True)
    self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim), requires_grad=True)
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, x):
    """Map x of shape (batch, in_channels, H, W) to (batch, num_patches + 1, embed_dim)."""
    # Flatten(2) leaves (batch, embed_dim, num_patches); swap to token-major order.
    patch_tokens = self.patcher(x).transpose(1, 2)
    # One copy of the shared [CLS] token per sample (expand does not copy memory).
    cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
    sequence = torch.cat((cls_tokens, patch_tokens), dim=1)
    return self.dropout(sequence + self.position_embeddings)
    
#model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)
#x = torch.randn(BATCH_SIZE, 3, 224, 224).to(device)
#print(model(x).shape)

### Put Everything Together

In [None]:
class ViT(nn.Module):
  """Vision Transformer classifier.

  Pipeline: PatchEmbedding -> stack of pre-norm Transformer encoder layers ->
  LayerNorm + Linear head applied to the [CLS] token only.

  Args:
    num_patches, patch_size, embed_dim, dropout, in_channels: forwarded to
      PatchEmbedding; `embed_dim` and `dropout` are also used by the encoder.
    num_classes: number of output logits.
    num_encoders: number of stacked encoder layers.
    num_heads: attention heads per layer (embed_dim must be divisible by it).
    activation: feed-forward activation passed to nn.TransformerEncoderLayer.
    mlp_dim: hidden width of each encoder layer's feed-forward block. Defaults
      to 2048 (PyTorch's default, i.e. the previous behaviour); the paper's
      ViT-Base configuration uses 3072 (= 4 * embed_dim).
  """

  def __init__(self, num_patches, num_classes, patch_size, embed_dim, num_encoders, num_heads, dropout, activation, in_channels, mlp_dim=2048):
    super().__init__()
    self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)

    encoder_layer = nn.TransformerEncoderLayer(
        d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_dim, dropout=dropout,
        activation=activation, batch_first=True, norm_first=True)
    # PyTorch already disables the nested-tensor path for norm_first=True layers and
    # warns about it at construction; saying so explicitly silences the warning
    # without changing outputs.
    self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders, enable_nested_tensor=False)

    # Pre-norm layers leave the final residual stream un-normalized, so the head
    # applies LayerNorm before the classifier.
    self.mlp_head = nn.Sequential(
        nn.LayerNorm(normalized_shape=embed_dim),
        nn.Linear(in_features=embed_dim, out_features=num_classes)
    )

  def forward(self, x):
    """Return class logits of shape (batch, num_classes) for an image batch x."""
    x = self.embeddings_block(x)
    x = self.encoder_blocks(x)
    x = self.mlp_head(x[:, 0, :])  # Apply MLP on the CLS token only
    return x

#model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)
#x = torch.randn(BATCH_SIZE, IN_CHANNELS, 224, 224).to(device)
#print(model(x).shape)


In [9]:
def validate_model(model, dataloader, criterion, device):
  """Evaluate `model` on every batch of `dataloader` without tracking gradients.

  Args:
    model: classifier returning logits of shape (batch, num_classes).
    dataloader: iterable of (inputs, labels) batches.
    criterion: loss function; assumed to average over the batch
      (reduction='mean', the nn.CrossEntropyLoss default).
    device: device the batches are moved to before the forward pass.

  Returns:
    (avg_loss, accuracy): the sample-weighted mean loss and the fraction
    (0-1) of correctly classified samples.

  Raises:
    ValueError: if `dataloader` yields no samples.
  """
  was_training = model.training
  model.eval()  # disable dropout so evaluation is deterministic
  total_loss = 0.0
  correct = 0
  seen = 0
  try:
    with torch.no_grad():  # no autograd graph needed for evaluation
      for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        batch_size = labels.size(0)
        # The criterion returns a batch mean; re-weight by batch size so a
        # shorter final batch contributes proportionally.
        total_loss += criterion(outputs, labels).item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size
  finally:
    model.train(was_training)  # hand the model back in the caller's mode

  # Release cached allocator blocks once; calling this per batch only slows
  # evaluation down. It is a no-op when CUDA was never initialized.
  torch.cuda.empty_cache()

  if seen == 0:
    raise ValueError("validate_model: dataloader yielded no samples")
  return total_loss / seen, correct / seen

In [10]:
def _train_one_epoch(model, dataloader, optimizer, criterion, epoch, log_every=200):
  """Train `model` for one pass over `dataloader`, printing the mean loss of every `log_every`-batch window."""
  model.train()
  running_loss = 0.0
  for batch_idx, (inputs, targets) in enumerate(tqdm(dataloader, leave=False), start=1):
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    # Report only after a full window of `log_every` batches, so the printed
    # value really is the average over that window.
    if batch_idx % log_every == 0:
      print('Batch {} epoch {} has loss = {}'.format(batch_idx, epoch, running_loss / log_every))
      running_loss = 0.0


def main():
  """Train a ViT on CIFAR-10 with the notebook's hyperparameters, save its weights, and print test metrics."""
  my_model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)

  optimizer = optim.Adam(my_model.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS, weight_decay=ADAM_WEIGHT_DECAY)
  criterion = nn.CrossEntropyLoss()
  # With its default arguments LinearLR ramps the LR from LEARNING_RATE / 3 up to
  # LEARNING_RATE over the first 5 scheduler steps (epochs here), then holds it.
  scheduler = optim.lr_scheduler.LinearLR(optimizer)

  for epoch in tqdm(range(EPOCHS), total=EPOCHS):
    _train_one_epoch(my_model, trainloader, optimizer, criterion, epoch)
    scheduler.step()
    torch.cuda.empty_cache()

  torch.save(my_model.state_dict(), 'my_model4.pth')

  # validate_model returns (avg_loss, accuracy) with accuracy as a fraction in [0, 1].
  test_loss, test_accuracy = validate_model(my_model, testloader, criterion, device)
  print(f'Test Loss: {test_loss:.4f}')
  print(f'Test Accuracy: {test_accuracy * 100:.2f}%')


In [None]:
# Entry point: run training followed by test-set evaluation (see main()).
# A Jupyter kernel also sets __name__ to "__main__", so executing this cell
# starts the full training run.
if __name__ == "__main__":  
  main()