### Imports

In [None]:
# Vision Transformer (ViT) based on the original publication "An Image is Worth 16x16 Words":
# https://arxiv.org/pdf/2010.11929

# Code is written from tutorial by Ahmad Chalhoub
# https://www.youtube.com/watch?v=nZ22Ecg9XCQ&ab_channel=AhmadChalhoub
# Andrei Cartera -- Sep 2024

#COLAB
#https://saturncloud.io/blog/how-to-save-files-from-google-colab-to-google-drive-a-stepbystep-guide/#:~:text=Step-by-Step%20Guide%20to%20Save%20Files%20to%20Google%20Drive,3%20Step%203%3A%20Save%20Files%20to%20Google%20Drive
#https://stackoverflow.com/questions/59710439/google-colab-and-google-drive-copy-file-from-colab-to-google-drive
#https://stackoverflow.com/questions/63879856/saving-model-state-and-load-in-google-colab
#https://medium.com/@ml_kid/how-to-save-our-model-to-google-drive-and-reuse-it-2c1028058cb2


#Pytorch Lightning

from pathlib import Path
import einops
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Subset
import torch
from torch import nn
import torchvision
import torch.optim as optim
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomCrop

# einops tutorial https://www.youtube.com/watch?v=xGy75Pjsqzo


print(torch.__version__)

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
!jupyter nbextension enable --py widgetsnbextension

### Set Hyperparameters of the network and specify device

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# --- Model architecture ---
patch_size = 16     # side length (pixels) of each square image patch
latent_size = 768   # width of every token embedding
n_channels = 3      # colour channels of the input images
num_heads = 12      # attention heads per encoder block
num_encoders = 12   # number of stacked encoder blocks
dropout = 0.1       # dropout probability used inside the encoder
num_classes = 10    # CIFAR-10 has ten target classes
size = 224          # images are resized to size x size before patching

# --- Optimisation ---
epochs = 10
# NOTE(review): 1e-2 is a large step size for Adam on a transformer -- confirm it is intended.
base_lr = 1e-2
weight_decay = 0.01
batch_size = 25

In [None]:
# The input data is resized to `size` x `size` (224x224), the training resolution used in the paper.
# The mean and std values used to normalize CIFAR10 data are from here: https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/train_cifar10.py
import torch.utils

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip augmentation, then upsample and normalize.
transform_training_data = Compose(
    [RandomCrop(32, padding=4), Resize(size), RandomHorizontalFlip(), ToTensor(), Normalize(cifar10_mean, cifar10_std)]
    )

# Evaluation pipeline: no random augmentation, so test metrics are deterministic and are
# measured on the unmodified images (only the resize and normalization are shared with training).
transform_test_data = Compose(
    [Resize(size), ToTensor(), Normalize(cifar10_mean, cifar10_std)]
    )

# Load train and test datasets (downloaded into the working directory on first use)
filepath = Path('.')
train_data = torchvision.datasets.CIFAR10(
    root=filepath, train=True, download=True, transform=transform_training_data)

test_data = torchvision.datasets.CIFAR10(
  root=filepath, train=False, download=True, transform=transform_test_data)

#subset_indices = list(range(200))  # limit size to 200
#train_subset = Subset(train_data, subset_indices)
#test_subset = Subset(test_data, subset_indices)

#trainloader_part = torch.utils.data.DataLoader(train_subset, batch_size=batch_size,shuffle=True)
#testloader_part = torch.utils.data.DataLoader(test_subset, batch_size=batch_size, shuffle=False)

# Shuffle only the training split; the test split keeps its order.
trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Class names indexed by the integer CIFAR-10 label.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Code to visualize samples from CIFAR10 dataset. This code is copied from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Display a normalized (C, H, W) image tensor with matplotlib.

    Args:
        img: image (or image grid) tensor in channel-first layout that was
            normalized per channel as ``(x - mean) / std``.
        mean: per-channel mean used for normalization. Defaults to the
            CIFAR-10 values used by this notebook's transforms.
        std: per-channel standard deviation used for normalization.
    """
    img = img.detach().cpu()
    mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # Invert Normalize(mean, std); clamp so padding/rounding cannot leave [0, 1].
    img = (img * std_t + mean_t).clamp(0.0, 1.0)
    npimg = img.numpy()
    # matplotlib expects channel-last (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images (one shuffled batch)
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels -- iterate over the labels actually returned, so a batch that is
# smaller than batch_size (e.g. a small Subset) does not raise an IndexError
print(' '.join(f'{classes[label]:5s}' for label in labels))

### Input Layer

In [None]:
class InputEmbedding(nn.Module):
  """Turn a batch of images into a sequence of patch embeddings for the encoder.

  Each image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches,
  every flattened patch is linearly projected to ``latent_size``, a learnable
  class token is prepended, and a learnable positional embedding (one vector per
  sequence position) is added.

  Args:
    patch_size: side length of a square patch in pixels.
    n_channels: number of channels of the input images.
    device: stored on the module; the module itself is moved by the caller via ``.to(...)``.
    latent_size: width of the produced token embeddings.
    batch_size: kept for backward compatibility only; parameter shapes no longer depend on it.
    image_size: side length of the (square) input images; fixes the number of positions.

  Raises:
    ValueError: if ``image_size`` is not a multiple of ``patch_size``.
  """

  def __init__(self, patch_size=patch_size, n_channels=n_channels, device=device, latent_size=latent_size, batch_size=batch_size, image_size=size):
    super(InputEmbedding, self).__init__()
    self.latent_size = latent_size
    self.patch_size = patch_size
    self.n_channels = n_channels
    self.device = device
    self.batch_size = batch_size
    self.image_size = image_size
    self.input_size = self.patch_size*self.patch_size*self.n_channels

    if self.image_size % self.patch_size != 0:
      raise ValueError(
        f'image_size ({self.image_size}) must be a multiple of patch_size ({self.patch_size})')
    self.num_patches = (self.image_size // self.patch_size) ** 2

    # Linear projection
    self.linearProjection = nn.Linear(self.input_size, self.latent_size)

    # Class token: a single learnable vector shared by every sample in the batch.
    # It must stay an nn.Parameter attribute (no `.to(device)` on it here): calling
    # `.to` on a Parameter can return a plain tensor, which would not be registered
    # with the module and therefore never be optimized or saved.
    self.class_token = nn.Parameter(torch.randn(1, 1, self.latent_size))

    # Positional embedding: one learnable vector per sequence position
    # (num_patches + 1 for the class token), shared across the batch.
    self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.latent_size))

  def forward(self, input_data):
    """Embed ``input_data`` of shape (batch, channels, height, width).

    Returns:
      Tensor of shape (batch, num_patches + 1, latent_size).

    Raises:
      ValueError: if the image does not produce ``num_patches`` patches.
    """
    # Follow the module's parameters so CPU inputs work with a model moved to the GPU.
    input_data = input_data.to(self.linearProjection.weight.device)

    # Patchify input image.
    # In the einops pattern, "(h h1)" splits the height axis into h blocks of h1 pixels
    # (h = height / patch_size); the same holds for the width axis.
    patches = einops.rearrange(
      input_data, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)', h1=self.patch_size, w1=self.patch_size
    )

    linear_projection = self.linearProjection(patches)
    b, n, _ = linear_projection.shape
    if n != self.num_patches:
      raise ValueError(
        f'expected {self.num_patches} patches per image but got {n}; '
        f'inputs must be {self.image_size}x{self.image_size}')

    # Broadcast the single class token over the actual batch size (works for any batch).
    class_tokens = self.class_token.expand(b, -1, -1)
    linear_projection = torch.cat((class_tokens, linear_projection), dim=1)

    # pos_embedding broadcasts over the batch dimension.
    return linear_projection + self.pos_embedding


In [None]:
# Smoke test: push one random batch through the embedding layer and check the output shape.
test_input = torch.randn(batch_size, n_channels, size, size)
test_class = InputEmbedding().to(device)
embed_test = test_class(test_input)

# One token per patch plus the class token, each of width latent_size.
expected_tokens = (size // patch_size) ** 2 + 1
assert embed_test.shape == (batch_size, expected_tokens, latent_size), embed_test.shape

### Implementation of Encoder Block

In [None]:
class EncoderBlock(nn.Module):
  """Single pre-norm Transformer encoder block: self-attention followed by an MLP.

  Both sub-layers are wrapped in a residual connection and preceded by their own
  LayerNorm.

  Args:
    latent_size: width of the token embeddings.
    num_heads: number of attention heads (must divide ``latent_size``).
    device: stored on the module; the module itself is moved by the caller via ``.to(...)``.
    dropout: dropout probability for the attention weights and the MLP.
  """

  def __init__(self, latent_size=latent_size, num_heads=num_heads, device=device, dropout=dropout):
    super(EncoderBlock, self).__init__()

    self.latent_size = latent_size
    self.num_heads = num_heads
    self.device = device
    self.dropout = dropout

    # Normalization layers: one per sub-layer, so the attention and MLP branches
    # each learn their own scale/shift instead of sharing a single LayerNorm.
    self.norm = nn.LayerNorm(self.latent_size)
    self.norm2 = nn.LayerNorm(self.latent_size)

    # batch_first=True: inputs are (batch, tokens, dim). With the default
    # (batch_first=False) the first axis is treated as the sequence, so attention
    # would mix different samples of the batch instead of the patches of one image.
    self.multihead = nn.MultiheadAttention(
        self.latent_size, self.num_heads, dropout=self.dropout, batch_first=True
    )

    # Position-wise feed-forward network with a 4x hidden expansion.
    self.enc_MLP = nn.Sequential(
      nn.Linear(self.latent_size, self.latent_size*4),
      nn.GELU(),
      nn.Dropout(self.dropout),
      nn.Linear(self.latent_size*4, self.latent_size),
      nn.Dropout(self.dropout)
    )

  def forward(self, embedded_patches) :
    """Apply the block to ``embedded_patches`` of shape (batch, tokens, latent_size).

    Returns:
      Tensor with the same shape as the input.
    """
    firstnorm_out = self.norm(embedded_patches)
    # Only the attended values are used, so skip computing the averaged attention map.
    attention_out = self.multihead(
      firstnorm_out, firstnorm_out, firstnorm_out, need_weights=False)[0]

    #first residual connection
    first_added = attention_out + embedded_patches

    secondnorm_out = self.norm2(first_added)
    ff_out = self.enc_MLP(secondnorm_out)

    #second residual connection
    return ff_out + first_added



In [None]:
#test_encoder = EncoderBlock().to(device)
#test_encoder(embed_test)

### Put Everything Together

In [None]:
class ViT(nn.Module):
  """Vision Transformer classifier: patch embedding, encoder stack and MLP head.

  Args:
    num_encoders: number of stacked ``EncoderBlock`` layers.
    latent_size: width of the token embeddings used throughout the model.
    device: forwarded to the sub-modules; the model itself is moved by the caller via ``.to(...)``.
    num_classes: number of output logits.
    dropout: dropout probability forwarded to every encoder block.
    num_heads: attention heads per encoder block (must divide ``latent_size``).
  """

  def __init__(self, num_encoders=num_encoders, latent_size=latent_size, device=device, num_classes=num_classes, dropout=dropout, num_heads=num_heads):
    super(ViT, self).__init__()

    self.num_encoders=num_encoders
    self.latent_size=latent_size
    self.device=device
    self.num_classes=num_classes
    self.dropout=dropout
    self.num_heads=num_heads

    # Forward the constructor arguments so the sub-modules agree with this model's
    # configuration instead of silently falling back to the notebook-level defaults.
    self.embedding = InputEmbedding(device=self.device, latent_size=self.latent_size)

    #Create Stack of Encoders
    self.encStack = nn.ModuleList([
      EncoderBlock(
        latent_size=self.latent_size,
        num_heads=self.num_heads,
        device=self.device,
        dropout=self.dropout,
      )
      for _ in range(self.num_encoders)
    ])

    # NOTE(review): the two Linear layers have no activation between them, so the head
    # is equivalent to a single linear map -- confirm whether a non-linearity was intended.
    self.MLP_head = nn.Sequential(
      nn.LayerNorm(self.latent_size),
      nn.Linear(self.latent_size, self.latent_size),
      nn.Linear(self.latent_size, self.num_classes)
    )

  def forward(self, input):
    """Classify a batch of images.

    Args:
      input: image batch accepted by ``InputEmbedding`` (batch, channels, height, width).

    Returns:
      Logits of shape (batch, num_classes).
    """
    enc_output = self.embedding(input)

    for enc_layer in self.encStack:
      enc_output = enc_layer(enc_output)

    # The class token sits at sequence position 0; only it feeds the classifier head.
    cls_token_embed = enc_output[:, 0]

    return self.MLP_head(cls_token_embed)


In [None]:
#model = ViT().to(device)
#vit_output = model(test_input)
#print(vit_output)
#print(vit_output.size())

In [None]:
#def evaluation(data, model):
#    total, correct = 0, 0
#    for images, labels in data:
#        images = images.to(device)
#        ypred = model.forward(images)
#        _, predicted = torch.max(ypred.data, 1)
#        total += labels.size(0)
#        correct += (predicted == labels.to(device)).sum().item()
#    return 100 * correct / total



In [None]:
def validate_model(model, dataloader, criterion, device):
  """Evaluate ``model`` on every batch of ``dataloader`` without tracking gradients.

  The model is switched to evaluation mode and is left in that mode on return.

  Args:
    model: module mapping an input batch to per-class scores of shape (batch, classes).
    dataloader: iterable yielding ``(inputs, labels)`` tensor pairs.
    criterion: loss callable ``criterion(outputs, labels)``; its result is weighted by
      the batch size, which assumes it returns the batch mean -- TODO confirm for
      criteria other than the default ``nn.CrossEntropyLoss``.
    device: device the batches are moved to before the forward pass.

  Returns:
    Tuple ``(avg_loss, accuracy)``: the per-sample average loss and the fraction
    (0.0 - 1.0) of correctly classified samples.

  Raises:
    ValueError: if ``dataloader`` yields no samples.
  """
  model.eval()  # Set model to evaluation mode
  running_loss = 0.0
  correct_predictions = 0
  total_samples = 0

  with torch.no_grad():  # Disable gradient calculation
    for inputs, labels in dataloader:
      inputs, labels = inputs.to(device), labels.to(device)
      n_samples = labels.size(0)

      outputs = model(inputs)
      # Weight by batch size so a smaller final batch does not skew the average.
      running_loss += criterion(outputs, labels).item() * n_samples

      predicted = outputs.argmax(dim=1)
      correct_predictions += (predicted == labels).sum().item()
      total_samples += n_samples

  if total_samples == 0:
    raise ValueError('validate_model received an empty dataloader; nothing to evaluate')

  # Release cached GPU memory once, after the loop; doing it per batch only slows evaluation.
  torch.cuda.empty_cache()

  avg_loss = running_loss / total_samples
  accuracy = correct_predictions / total_samples
  return avg_loss, accuracy

In [None]:
#from google.colab import drive
#drive.mount('/content/gdrive')

In [None]:
# Build the network from the notebook-level hyperparameters, then move it to the target device.
my_model = ViT(
    num_encoders=num_encoders,
    latent_size=latent_size,
    device=device,
    num_classes=num_classes,
)
my_model = my_model.to(device)

# Classification loss on the raw logits returned by the model.
criterion = nn.CrossEntropyLoss()

# Betas used for Adam in the paper are 0.9 and 0.999, which are the defaults in PyTorch
optimizer = optim.Adam(
    params=my_model.parameters(),
    lr=base_lr,
    weight_decay=weight_decay,
)

# LinearLR with its default arguments (see the PyTorch docs for the exact ramp it applies).
scheduler = optim.lr_scheduler.LinearLR(optimizer)

In [None]:
def main():
    """Train ``my_model`` on ``trainloader``, save its weights and report test metrics.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``scheduler``. The average
    training loss is printed every 200 batches, the learning-rate scheduler is stepped
    once per epoch, the final weights are written to ``my_model2.pth`` and the model is
    then evaluated on ``testloader``.
    """
    log_every = 200  # number of batches averaged into each printed loss value

    my_model.train().to(device)

    for epoch in tqdm(range(epochs), total=epochs):
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):

            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()

            outputs = my_model(inputs)

            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Report only once a full window of `log_every` batches has been accumulated,
            # so the printed value really is the mean over those batches.
            if (batch_idx + 1) % log_every == 0:
                print('Batch {} epoch {} has loss = {}'.format(batch_idx, epoch, running_loss/log_every))
                running_loss = 0.0

        scheduler.step()
        torch.cuda.empty_cache()

    torch.save(my_model.state_dict(), 'my_model2.pth')

    # validate_model returns (average loss, accuracy as a fraction in [0, 1]).
    test_loss, test_accuracy = validate_model(my_model, testloader, criterion, device)
    print(f'Test Loss: {test_loss:.4f}')
    print(f'Test Accuracy: {100 * test_accuracy:.2f}%')


In [None]:
# Start training only when this code runs as the top-level script/notebook
# (i.e. __name__ is "__main__"), not when it is imported as a module.
if __name__ == "__main__":  
    main()