# Training ViT with PyTorch

Code to fine-tune an ImageNet-pretrained Vision Transformer (ViT-B/16) on CIFAR-100.

Work in progress

In [1]:
import os

In [2]:
#python
import matplotlib.pyplot as plt
import numpy as np
import io
from PIL import Image
from tqdm import tqdm

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, random_split

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

## GPU Check

In [3]:
# gpu selection
use_cuda = torch.cuda.is_available()
# GPU 2 is the preferred device; clamp the index to the GPUs actually present so
# that the notebook also runs on hosts with fewer than three GPUs.
cuda_index = min(2, max(torch.cuda.device_count() - 1, 0))
device = torch.device(f"cuda:{cuda_index}" if use_cuda else "cpu")
print(device)

cuda:2


## Load Dataset CIFAR-100

In [None]:
def get_dataset_config(dataset):
    """Return the table of per-dataset settings.

    The ``dataset`` argument is accepted but not used for the lookup: the
    complete mapping is returned and callers index it themselves, e.g.
    ``get_dataset_config(name)[name]``.

    Every entry provides ``num_classes``, ``input_ch`` and the per-channel
    ``means`` / ``stds``.
    """
    cifar10 = dict(
        num_classes=10,
        input_ch=3,
        means=(0.424, 0.415, 0.384),
        stds=(0.283, 0.278, 0.284),
    )
    cifar100 = dict(
        num_classes=100,
        input_ch=3,
        means=(0.438, 0.418, 0.377),
        stds=(0.300, 0.287, 0.294),
    )
    imagenet = dict(
        num_classes=1000,
        input_ch=3,
        means=[0.485, 0.456, 0.406],
        stds=[0.229, 0.224, 0.225],
    )
    return {'CIFAR10': cifar10, 'CIFAR100': cifar100, 'ImageNet': imagenet}

In [None]:
def load_data(dataset, split, augment=False, shuffle_train=False, batch_size=64, num_workers=8, seed=42, data_dir='/srv/newpenny/dataset'):
    """
    Build a DataLoader for one split of a dataset.

    Images are resized to 256x256, center-cropped to 224x224, converted to
    tensors and normalized with the dataset's channel means / stds.

    Args:
        dataset (str): choices=['CIFAR10', 'CIFAR100', 'ImageNet']
        split (str): 'test' returns a loader over the whole loaded set. Any other
            value triggers an 80/20 random split of the loaded data: 'train'
            returns the 80% part, every other value (e.g. 'val') the 20% part.
            For ImageNet the value is also forwarded to the torchvision dataset.
        augment (bool): add AutoAugment (CIFAR10 policy) to the transform;
            only takes effect when split == 'train'.
        shuffle_train (bool): reshuffle the training subset at every epoch.
        batch_size (int): batch size of the returned loader.
        num_workers (int): number of DataLoader worker processes.
        seed (int): seed of the generator that drives the train/val split.
            Use the same value for the 'train' and the 'val' call, otherwise
            the two subsets are drawn from different partitions and overlap.
        data_dir (str): root directory that holds the datasets.

    Returns:
        torch.utils.data.DataLoader for the requested split.

    Raises:
        KeyError: if ``dataset`` is not present in the dataset configuration.
    """
    dc = get_dataset_config(dataset)

    means_ = dc[dataset]['means']
    stds_ = dc[dataset]['stds']

    # Shared pipeline; augmentation is inserted between resize and crop
    # and acts only on the training set.
    steps = [transforms.Resize((256, 256))]
    if augment and split == 'train':
        steps.append(transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10))
    steps += [
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(means_, stds_),
    ]
    transform = transforms.Compose(steps)

    dataset_cls = torchvision.datasets.__dict__[dataset]
    if dataset == 'ImageNet':
        # NB. the loader doesn't work for the 'test' split
        data_path = os.path.join(data_dir, 'imagenet-1k/data')
        data = dataset_cls(root=data_path, split=split, transform=transform)
    elif dataset.startswith('CIFAR'):
        data_path = os.path.join(data_dir, dataset)
        # train and val are both carved out of the official training set
        data = dataset_cls(root=data_path,
                           train=(split != 'test'),
                           transform=transform,
                           download=True)

    if split == 'test':
        return torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # split the dataset into train and validation sets
    train_size = int(0.8 * len(data))
    val_size = len(data) - train_size
    # honour the caller's seed (it used to be hard-wired to 42 here)
    generator = torch.Generator().manual_seed(seed)
    train_data, val_data = random_split(data, [train_size, val_size], generator=generator)

    if split == 'train':
        return torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle_train, num_workers=num_workers)
    return torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
data_dir = '/srv/newpenny/dataset'

dataset = 'CIFAR100'

# data_dir is forwarded explicitly so that editing it above actually takes effect.
# The training loader reshuffles its subset every epoch; augmentation is
# applied to the training split only.
train_loader = load_data(dataset, "train", augment=True, shuffle_train=True, data_dir=data_dir)
val_loader = load_data(dataset, "val", data_dir=data_dir)

In [None]:
# train_loader wraps the Subset returned by random_split; the list of class
# names lives on the dataset underneath that Subset.
class_names = train_loader.dataset.dataset.classes
num_classes = len(class_names)
num_classes

## Load Weights of the model

In [None]:
# torchvision architecture to fine-tune; the following cells handle
# 'resnet50', 'vgg16' and 'vit_b_16'
model_name = 'vit_b_16'

In [None]:
# Pick the ImageNet-1k (V1) pretrained weights that match the chosen architecture.
if model_name == 'resnet50':
    weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
elif model_name == 'vgg16':
    weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1
elif model_name == 'vit_b_16':
    weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1
else:
    # Fail here with a clear message instead of a NameError on `weights` below,
    # or (after a re-run) silently reusing the weights of a previous choice.
    raise ValueError(f"Unsupported model_name: {model_name!r}")

# Build the architecture by name with the selected pretrained weights.
model = torchvision.models.__dict__[model_name](weights=weights)

In [None]:
# show architecture of the model (the bare expression makes the notebook
# display the module's repr as the cell output)
model

In [None]:
# inspect the names of the entries stored in the model's state dict
model.state_dict().keys()

Replace the final classification layer so that it outputs 100 classes instead of the 1000 ImageNet classes (`out_features=100`).

In [None]:
# Number of target classes, read from the dataset underneath the training Subset.
n_classes = len(train_loader.dataset.dataset.classes)

# Swap the last linear layer for a freshly initialised one with n_classes outputs.
# The attribute that holds this layer differs per architecture.
if model_name == 'resnet50':
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, n_classes)
elif model_name == 'vgg16':
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features, n_classes)
elif model_name == 'vit_b_16':
    in_features = model.heads.head.in_features
    model.heads.head = torch.nn.Linear(in_features, n_classes)
else:
    # Without this branch an unsupported model would silently keep its original head.
    raise ValueError(f"Unsupported model_name: {model_name!r}")

In [None]:
# display the model again to check the replaced classification layer
model

## Training

### Parameter Settings

In [None]:
# --- hyperparameters ---
num_epochs = 15
initial_lr = 0.001
early_stopping_patience = 10  # epochs without validation-loss improvement before training stops
lr_patience = 5               # patience handed to the ReduceLROnPlateau scheduler

# --- loss ---
criterion = nn.CrossEntropyLoss()

# --- optimizer ---
optimizer = optim.SGD(params=model.parameters(), lr=initial_lr, momentum=0.9)
# alternative: optim.Adam(model.parameters(), lr=initial_lr, weight_decay=1e-4)  # more computationally intensive

# --- learning-rate scheduler ---
# alternatives: lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
#               lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.1,
    patience=lr_patience,
)


In [None]:
def accuracy(outputs, targets):
    """Return the fraction of samples whose highest-scoring class equals the target.

    ``outputs`` holds one row of class scores per sample and ``targets`` the
    corresponding class indices; the result is a Python float in [0, 1].
    """
    predicted = outputs.argmax(dim=1)  # index of the best-scoring class per row
    n_correct = int(predicted.eq(targets).sum())
    n_samples = targets.size(0)
    return n_correct / n_samples

In [None]:
# GPU selection
use_cuda = torch.cuda.is_available()
cuda_index = torch.cuda.device_count() - 1
device = torch.device(f"cuda:{cuda_index}" if use_cuda else "cpu")
print(f"Using {device} device")

In [None]:
# TensorBoard: event files are written to <home>/Documents/runs
home_dir = os.path.expanduser("~")
save_path = os.path.join(home_dir, "Documents", "runs")
writer = SummaryWriter(log_dir=save_path)
save_path

### Training Loop

In [None]:
# Early-stopping state
patience_counter = 0
best_val_loss = float('inf')

# Per-epoch history of the mean training / validation loss
tl = []
vl = []

# The batches are moved to `device` below, so the model has to live there as well.
# nn.Module.to() modifies the module in place, so the optimizer built earlier
# still references these parameters.
model.to(device)

try:
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for data in tqdm(train_loader):
            inputs, targets = data
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # accuracy() returns a fraction; scale it back to a per-batch count
            correct_predictions += accuracy(outputs, targets) * targets.size(0)
            total_predictions += targets.size(0)

        # ---- validation pass (no gradients) ----
        model.eval()
        val_loss = 0.0
        val_correct_predictions = 0
        val_total_predictions = 0

        with torch.no_grad():
            for data in val_loader:
                inputs, targets = data
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                val_correct_predictions += accuracy(outputs, targets) * targets.size(0)
                val_total_predictions += targets.size(0)

        # losses are averaged over the number of batches
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        tl.append(train_loss)
        vl.append(val_loss)

        train_accuracy = (correct_predictions / total_predictions) * 100
        val_accuracy = (val_correct_predictions / val_total_predictions) * 100

        # tensorboard
        writer.add_scalar('train loss', train_loss, epoch)
        writer.add_scalar('train accuracy', train_accuracy, epoch)
        writer.add_scalar('val loss', val_loss, epoch)
        writer.add_scalar('val accuracy', val_accuracy, epoch)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
              f'Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%')

        # early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping: Validation loss hasn't improved for", early_stopping_patience, "epochs.")
                break

        # step the scheduler (ReduceLROnPlateau needs the monitored metric)
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_loss)
        else:
            scheduler.step()

        current_lr = scheduler.get_last_lr()[0] if hasattr(scheduler, 'get_last_lr') else optimizer.param_groups[0]['lr']
        print(f'Current lr: {current_lr:.6f}')
finally:
    # close the writer even when training is interrupted or fails
    writer.close()