# ViT

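This notebook trains a Vision Transformer (ViT) to classify images from the State Farm Distracted Driver Detection dataset into ten classes (`c0`–`c9`). It was developed on Kaggle in order to use a GPU.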

Reference: Stanford CS231n (2022) project report, https://cs231n.stanford.edu/reports/2022/pdfs/151.pdf


In [1]:
!pip install vit_pytorch

Collecting vit_pytorch
  Downloading vit_pytorch-1.7.12-py3-none-any.whl.metadata (67 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.8/67.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.7.0 (from vit_pytorch)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading vit_pytorch-1.7.12-py3-none-any.whl (131 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m131.5/131.5 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops, vit_pytorch
Successfully installed einops-0.8.0 vit_pytorch-1.7.12


In [2]:
import pandas as pd
import numpy as np
import torch
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
from vit_pytorch import ViT
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import log_loss, accuracy_score

# Functions & Classes

In [3]:
class StateFarmDD(Dataset):
    """Image dataset driven by an annotations CSV for the State Farm driver images.

    Each CSV row is read positionally: column 1 is the class label string
    (e.g. ``'c3'``) and column 2 is the image file name. Labelled images are
    read from ``{path_prefix}/{label}/{image_name}``; when ``test`` is True they
    are read from ``{path_prefix}/{image_name}`` instead.

    Args:
        annotations_path: Anything accepted by ``pd.read_csv``.
        path_prefix: Root directory containing the images.
        transform_pipeline: Optional callable applied to a copy of each loaded image.
        test: If True, images are looked up without a per-class subdirectory.
    """

    def __init__(self, annotations_path, path_prefix='../data/imgs/train', transform_pipeline=None, test=False):
        self.annotations = pd.read_csv(annotations_path)
        self.transformation_pipeline = transform_pipeline
        self.path_prefix = path_prefix
        # 'c0'..'c9' -> 0..9. 'c10' -> -1 is presumably a placeholder label for
        # unlabelled (test) rows — TODO confirm against the annotation files.
        self.label_to_int_dict = {f'c{i}': i for i in range(10)}
        self.label_to_int_dict['c10'] = -1
        self.test = test

    def __len__(self):
        """Return the number of annotated rows."""
        return self.annotations.shape[0]

    def __getitem__(self, index):
        """Return ``(image, class_index)`` for the row at position ``index``."""
        label = self.annotations.iloc[index, 1]
        image_name = self.annotations.iloc[index, 2]

        # Labelled splits keep images in one subdirectory per class.
        class_dir = '' if self.test else f'{label}/'
        image = plt.imread(f'{self.path_prefix}/{class_dir}{image_name}')

        if self.transformation_pipeline:
            # The pipeline gets a copy, presumably because imread's array may be
            # read-only — TODO confirm. squeeze(0) drops a leading singleton dim
            # if the pipeline adds one and is a no-op otherwise.
            image = self.transformation_pipeline(image.copy()).squeeze(0)

        return image, self.label_to_int_dict[label]

In [4]:
# Function to calculate the metrics
def calculate_metrics(model_id, model_name, training_targets, training_predictions,
                      val_targets, val_predictions, labels=None):
    """Compute cross-entropy loss and accuracy for the training and validation splits.

    Args:
        model_id: Identifier stored verbatim in the result.
        model_name: Name stored verbatim in the result.
        training_targets, val_targets: True class label per sample.
        training_predictions, val_predictions: Predicted class probabilities,
            one row per sample and one column per class.
        labels: Class value of each prediction column, in column order. This must
            be sorted, matching sklearn's ``log_loss`` column convention. Defaults
            to ``0 .. n_columns - 1``, i.e. targets are integer column indices.
            The original argmax-based accuracy already assumed this.

    Returns:
        dict with keys ``model_id``, ``model_name``, ``train_CE_loss``,
        ``train_acc``, ``validation_CE_loss``, ``validation_acc``.
    """
    train_probs = np.asarray(training_predictions)
    val_probs = np.asarray(val_predictions)

    # Pass the class set explicitly. Otherwise log_loss infers it from the targets
    # and raises when a split is missing a class that still has a probability column.
    if labels is None:
        labels = np.arange(train_probs.shape[1])
    label_values = np.asarray(labels)

    # Cross-entropy (log) loss
    train_CE_loss = log_loss(training_targets, train_probs, labels=label_values)
    validation_CE_loss = log_loss(val_targets, val_probs, labels=label_values)

    # Accuracy: map each row's most probable column back to its class label
    train_acc = accuracy_score(training_targets, label_values[train_probs.argmax(axis=1)])
    valid_acc = accuracy_score(val_targets, label_values[val_probs.argmax(axis=1)])

    return {'model_id': model_id, 'model_name': model_name, 'train_CE_loss': train_CE_loss,
            'train_acc': train_acc, 'validation_CE_loss': validation_CE_loss, 'validation_acc': valid_acc}

## Model Building & Training

In [5]:
# Annotation CSVs for the train/validation split, plus the directory holding the competition images.
training_path = '/kaggle/input/statefarmdd/training.csv'
validation_path = '/kaggle/input/statefarmdd/validation.csv'
image_dir = '/kaggle/input/state-farm-distracted-driver-detection/imgs/train'
batch_size = 32

# Creating the Datasets: convert to a tensor image, resize to 224x224 and convert to
# float32 (scale=True rescales integer pixel values to [0, 1]).
transformation_pipeline = v2.Compose([
        v2.ToImage(),
        v2.Resize([224, 224]),
        v2.ToDtype(torch.float32, scale=True),
])
train_dataset = StateFarmDD(training_path, path_prefix=image_dir, transform_pipeline=transformation_pipeline)
valid_dataset = StateFarmDD(validation_path, path_prefix=image_dir, transform_pipeline=transformation_pipeline)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling does not change validation metrics; a fixed order keeps evaluation deterministic.
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [6]:
# Fix the global RNG so weight initialisation and the training-loader shuffle order are reproducible.
torch.manual_seed(42)

# 224x224 inputs split into 16x16 patches (14x14 = 196 tokens) for 10 classes.
model = ViT(
    image_size=224,
    patch_size=16,
    num_classes=10,
    dim=768,
    depth=3,
    heads=8,
    mlp_dim=3072,
    channels=3,
    dropout=0.2,
    emb_dropout=0
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model_name = 'vit'
model_number = 23
# reduction='sum' so each epoch's summed loss can be divided by the dataset size for a per-sample mean.
loss_fn = torch.nn.CrossEntropyLoss(reduction='sum').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
training_history_loss = []
validation_history_loss = []
train_history_accuracy = []
valid_history_accuracy = []
epochs = 25
current_count = 0                # consecutive epochs without a sufficient validation improvement
early_stopping_threshold = 1e-4  # minimum drop in validation loss that counts as an improvement
early_stopping_count = 5         # patience: stop after this many non-improving epochs
best_val_loss = float('inf')
best_epoch = -1


def _run_epoch(net, data_loader, criterion, target_device, opt=None):
    """Make one pass over `data_loader`: train when `opt` is given, otherwise evaluate.

    Returns (mean per-sample loss, accuracy) over `data_loader.dataset`.
    Assumes `criterion` sums over the batch (reduction='sum').
    """
    training = opt is not None
    net.train(training)
    total_loss = 0.0
    total_correct = 0
    # Only build the autograd graph when we actually back-propagate.
    with torch.set_grad_enabled(training):
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = X_batch.to(target_device), y_batch.to(target_device)
            logits = net(X_batch)
            loss = criterion(logits, y_batch)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total_loss += loss.item()
            total_correct += (logits.argmax(dim=1) == y_batch).sum().item()
    n_samples = len(data_loader.dataset)
    return total_loss / n_samples, total_correct / n_samples


# Training the model
for epoch in range(epochs):
    train_loss, train_accuracy = _run_epoch(model, train_loader, loss_fn, device, optimizer)
    valid_loss, valid_accuracy = _run_epoch(model, valid_loader, loss_fn, device)
    training_history_loss.append(train_loss)
    validation_history_loss.append(valid_loss)
    train_history_accuracy.append(train_accuracy)
    valid_history_accuracy.append(valid_accuracy)

    # Early stopping: an epoch counts as an improvement only if the validation loss
    # falls below the best so far by more than the threshold. A slightly worse loss
    # must never replace the best one.
    if valid_loss < best_val_loss - early_stopping_threshold:
        best_val_loss = valid_loss
        best_epoch = epoch
        current_count = 0
        # Save the best weights; tensors are copied to CPU so the checkpoint loads without a GPU.
        torch.save({name: tensor.cpu() for name, tensor in model.state_dict().items()}, f'{model_name}.pth')
    else:
        current_count += 1

    print('-----------------------------------')
    print(f'Epoch {epoch}')
    print(f'Training Loss: {round(train_loss,4)}')
    print(f'Validation Loss: {round(valid_loss,4)}')
    print(f'Training Accuracy: {round(train_accuracy*100,4)}%')
    print(f'Validation Accuracy: {round(valid_accuracy*100,4)}%')
    print()
    print(f'Best Validation Loss: {round(best_val_loss,4)}')
    print(f'Best Epoch: {best_epoch}')
    print('-----------------------------------')
    print()

    # Checked after printing so the final epoch's stats are always reported.
    if current_count >= early_stopping_count:
        print('Stopping training due to early stopping!!!')
        break

-----------------------------------
Epoch 0
Training Loss: 2.2929
Validation Loss: 2.2899
Training Accuracy: 12.9133%
Validation Accuracy: 12.2562%

Best Validation Loss: 2.2899
Best Epoch: 0
-----------------------------------

-----------------------------------
Epoch 1
Training Loss: 1.7983
Validation Loss: 2.8484
Training Accuracy: 39.4011%
Validation Accuracy: 14.602%

Best Validation Loss: 2.2899
Best Epoch: 0
-----------------------------------

-----------------------------------
Epoch 2
Training Loss: 0.542
Validation Loss: 3.9561
Training Accuracy: 85.4523%
Validation Accuracy: 19.2033%

Best Validation Loss: 2.2899
Best Epoch: 0
-----------------------------------

-----------------------------------
Epoch 3
Training Loss: 0.1901
Validation Loss: 4.4705
Training Accuracy: 95.8079%
Validation Accuracy: 18.3843%

Best Validation Loss: 2.2899
Best Epoch: 0
-----------------------------------

-----------------------------------
Epoch 4
Training Loss: 0.1115
Validation Loss: 4.7

In [7]:
def _save_history_plot(train_values, valid_values, metric_name, output_path):
    """Plot per-epoch training vs. validation values of one metric and save the figure to `output_path`."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_values, label=f'Training {metric_name}')
    ax.plot(valid_values, label=f'Validation {metric_name}')
    ax.legend()
    ax.set(xlabel='Epochs', ylabel=metric_name, title=f'Training and Validation {metric_name}')
    fig.savefig(output_path)
    # Close explicitly so saved-only figures do not accumulate or render inline.
    plt.close(fig)


# Saving plots for the training and validation loss and accuracy
_save_history_plot(training_history_loss, validation_history_loss, 'Loss', f'{model_name}_loss.png')
_save_history_plot(train_history_accuracy, valid_history_accuracy, 'Accuracy', f'{model_name}_accuracy.png')

In [8]:
# Reload the best checkpoint saved during training. The in-memory weights come from the
# last epoch that ran, which may be worse. map_location makes the load device-independent;
# weights_only=True refuses to unpickle arbitrary objects from the file.
model.load_state_dict(torch.load(f'{model_name}.pth', map_location=device, weights_only=True))
model.to(device)
model.eval()

# Accumulators for predicted probabilities and ground-truth labels
train_pred = []
valid_pred = []
train_truth = []
valid_truth = []


In [9]:
def _collect_probabilities(net, data_loader, target_device):
    """Run `net` in eval mode over `data_loader`; return (softmax probabilities, true labels).

    Both are numpy arrays concatenated over batches and aligned row-for-row, so a
    shuffled loader is fine. Assumes `net` returns (batch, n_classes) logits.
    """
    net.eval()
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            logits = net(X_batch.to(target_device))
            prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())
            label_batches.append(y_batch.numpy())
    return np.concatenate(prob_batches), np.concatenate(label_batches)


# Recompute predictions from scratch so re-running this cell never double-counts samples.
train_pred, train_truth = _collect_probabilities(model, train_loader, device)
valid_pred, valid_truth = _collect_probabilities(model, valid_loader, device)

# Printing out the metrics
metrics = calculate_metrics(model_number, model_name, train_truth, train_pred, valid_truth, valid_pred)
print(metrics)

{'model_id': 23, 'model_name': 'vit', 'train_CE_loss': 2.214328674014989, 'train_acc': 0.2832189644416719, 'validation_CE_loss': 2.289904279112905, 'validation_acc': 0.12256228745922687}
