In [1]:
import torch

# Pick the compute device used by the rest of the notebook.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('GPU device:',torch.cuda.get_device_name(0))
    # Select GPU 0 only when CUDA is actually present: calling
    # torch.cuda.set_device() on a machine without a usable GPU raises, which
    # previously made the CPU fallback branch below unreachable in practice.
    torch.cuda.set_device(0)
else:
    device = torch.device('cpu')
    print('No GPU avaialable, Using CPU')

GPU device: Tesla V100-SXM2-32GB


In [2]:
# Standard Library Imports
import os
import sys
import time
import logging
import getpass
from glob import glob
from pathlib import Path
import random
from typing import Dict, List, Tuple, Optional

# Third-Party Library Imports
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Local Imports
sys.path.append('/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline')
from src.components import data_setup
from src.components.dataset import ImageFolderCustom
from src.components import utils
from src.components.config_manager_baseline import get_config

In [4]:
def format_time(seconds):
    """Render a duration in seconds as a verbose hours/minutes/seconds string.

    Args:
        seconds (int | float): Duration in seconds.

    Returns:
        str: Text of the form "H hours, M minutes, S seconds"; each component
        is truncated to an integer with ``int``.

    Note:
        A later cell of this notebook defines ``format_time`` again with a
        compact "h:m:s" layout, which replaces this definition once it runs.
    """
    whole_hours = int(seconds // 3600)
    whole_minutes = int((seconds % 3600) // 60)
    whole_seconds = int(seconds % 60)
    return f"{whole_hours} hours, {whole_minutes} minutes, {whole_seconds} seconds"

In [5]:
def calculate_balanced_accuracy(y_pred, y_true, num_classes, epsilon=1e-9):
    """
    Calculates the balanced accuracy score (mean of the per-class recalls).

    The confusion matrix is built with a single ``torch.bincount`` call instead
    of a Python loop over individual elements, so the cost no longer grows with
    one interpreter iteration (and one tensor index operation) per sample.

    Args:
        y_pred (torch.Tensor): Predicted labels; any shape, flattened internally.
        y_true (torch.Tensor): True labels; same number of elements as ``y_pred``.
        num_classes (int): Number of classes in the dataset.
        epsilon (float): A small value to add to denominators to prevent division by zero.

    Returns:
        float: Balanced accuracy score. A class that never occurs in ``y_true``
        has recall 0 / epsilon = 0 and still counts in the mean, so it lowers
        the score.

    Raises:
        ValueError: If ``y_pred`` and ``y_true`` hold a different number of elements.
        IndexError: If any label lies outside ``[0, num_classes)``.
    """
    # reshape (rather than view) also accepts non-contiguous inputs.
    flat_pred = y_pred.reshape(-1).long()
    flat_true = y_true.reshape(-1).long().to(flat_pred.device)

    if flat_pred.numel() != flat_true.numel():
        raise ValueError(
            f"y_pred and y_true must have the same number of elements, "
            f"got {flat_pred.numel()} and {flat_true.numel()}."
        )

    # bincount would silently mis-bin out-of-range labels, so validate first.
    if flat_true.numel() > 0:
        lowest = min(int(flat_true.min()), int(flat_pred.min()))
        highest = max(int(flat_true.max()), int(flat_pred.max()))
        if lowest < 0 or highest >= num_classes:
            raise IndexError(
                f"Labels must lie in [0, {num_classes}), got values in [{lowest}, {highest}]."
            )

    # Encode each (true, predicted) pair as one flat index: row = true class,
    # column = predicted class. Counting the indices yields the confusion matrix.
    pair_index = flat_true * num_classes + flat_pred
    pair_counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
    confusion_matrix = pair_counts.reshape(num_classes, num_classes).to(torch.get_default_dtype())

    # Recall per class = true positives (diagonal) / all samples of that class
    # (row sum); epsilon keeps the division defined for classes with no samples.
    recall = torch.diag(confusion_matrix) / (confusion_matrix.sum(1) + epsilon)

    # Balanced accuracy is the unweighted mean of the per-class recalls.
    balanced_accuracy = recall.mean().item()

    return balanced_accuracy

In [6]:
def load_train_objs(model, num_epochs, optimizer_choice, scheduler_choice, initial_lr, momentum, weight_decay_adam, wd_sgd):
    """Create the optimizer and learning-rate scheduler for a training run.

    Args:
        model (nn.Module): Model whose ``parameters()`` are handed to the optimizer.
        num_epochs (int): Used as ``T_max`` for ``CosineAnnealingLR``.
        optimizer_choice (str): 'ADAM' or 'SGD'. Matching is case-insensitive, so
            the lower-case spellings named in the error message are accepted too.
        scheduler_choice (str): 'CosineAnnealingLR' or 'LambdaLR' (exact match).
        initial_lr (float): Initial learning rate for the optimizer.
        momentum (float): Momentum; only used by SGD.
        weight_decay_adam (float): Weight decay; only used by Adam.
        wd_sgd (float): Weight decay; only used by SGD.

    Returns:
        tuple: ``(optimizer, lr_scheduler)``.

    Raises:
        ValueError: If the optimizer or scheduler choice is not recognised.
    """
    # Setup the optimizer. Previously only upper-case names were accepted even
    # though the error message told the user to pass 'adam' or 'sgd'.
    optimizer_key = str(optimizer_choice).upper()
    if optimizer_key == 'ADAM':
        optimizer = optim.Adam(
            params=model.parameters(),
            lr=initial_lr,
            betas=(0.9, 0.999),
            weight_decay=weight_decay_adam
        )
    elif optimizer_key == 'SGD':
        optimizer = optim.SGD(
            params=model.parameters(),
            lr=initial_lr,
            momentum=momentum,
            weight_decay=wd_sgd
        )
    else:
        raise ValueError("Invalid optimizer choice. Choose 'adam' or 'sgd'.")

    # Setup the learning rate scheduler
    if scheduler_choice == 'CosineAnnealingLR':
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=num_epochs
        )
    elif scheduler_choice == 'LambdaLR':
        def lr_lambda(epoch):
            # Decrease the learning rate by a factor of 10 every 30 epochs
            return 0.1 ** (epoch // 30)

        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lr_lambda
        )
    else:
        raise ValueError("Invalid scheduler choice. Choose 'LambdaLR' or 'CosineAnnealingLR'")

    return optimizer, lr_scheduler

In [7]:
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets

def prepare_dataset(image_path=None, num_classes=3):
    """Build the train/validation datasets and the ViT-B/16 model to fine-tune.

    Despite its name, this function also creates the model: it loads
    torchvision's ViT-B/16 with its default pretrained weights, freezes every
    existing parameter and replaces ``heads`` with a new linear classifier.
    Both datasets use the preprocessing transforms that ship with those weights.

    Args:
        image_path (str | Path | None): Dataset root containing ``train`` and
            ``test`` sub-folders in ``ImageFolder`` layout. ``None`` keeps the
            original hard-coded pizza/steak/sushi location.
        num_classes (int): Output size of the new classifier head. Defaults to
            3, the value that was previously hard-coded.

    Returns:
        tuple: ``(train_dataset, val_dataset, class_names, pretrained_vit)``,
        where ``class_names`` comes from the training dataset.
    """
    if image_path is None:
        image_path = "/home/sur06423/hiwi/vit_exp/vision_tranformer_baseline/src/ddp_code/data_pizza/pizza_steak_sushi"

    pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)

    # Freeze the base parameters
    for parameter in pretrained_vit.parameters():
        parameter.requires_grad = False

    # Change the classifier head. It is created after the freeze loop, so its
    # parameters keep requires_grad=True. The input width is read from the
    # model (768 for ViT-B/16) instead of being repeated as a literal.
    pretrained_vit.heads = nn.Linear(in_features=pretrained_vit.hidden_dim, out_features=num_classes)
    pretrained_vit_transforms = pretrained_vit_weights.transforms()

    # Accept both str and Path inputs for the dataset root.
    image_path = Path(image_path)
    train_dir = image_path / "train"
    test_dir = image_path / "test"

    # Use ImageFolder to create dataset(s)
    train_dataset = datasets.ImageFolder(str(train_dir), transform=pretrained_vit_transforms)
    val_dataset = datasets.ImageFolder(str(test_dir), transform=pretrained_vit_transforms)

    # Get class names
    class_names = train_dataset.classes

    return train_dataset, val_dataset, class_names, pretrained_vit

In [8]:
def prepare_dataloader(dataset: Dataset, batch_size: int, num_workers, prefetch_factor, shuffle: bool = False):
    """Wrap ``dataset`` in a ``DataLoader`` with the notebook's standard settings.

    Args:
        dataset (Dataset): Dataset to iterate over.
        batch_size (int): Number of samples per batch.
        num_workers (int | None): Number of loader worker processes. ``None``
            or 0 loads data in the main process.
        prefetch_factor (int): Batches prefetched per worker. Only forwarded
            when worker processes are used.
        shuffle (bool): Reshuffle the data every epoch. Defaults to ``False``,
            the previous fixed behaviour; pass ``True`` for training data.

    Returns:
        DataLoader: Loader that keeps the final partial batch (``drop_last=False``).
    """
    worker_count = num_workers or 0

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=worker_count,
        # Pinned host memory only helps host-to-GPU copies; asking for it
        # without CUDA merely triggers a warning.
        pin_memory=torch.cuda.is_available(),
        shuffle=shuffle,
        drop_last=False,
    )
    # DataLoader rejects prefetch_factor when no worker processes are used,
    # so pass it only in the multi-process case.
    if worker_count > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor

    return DataLoader(dataset, **loader_kwargs)

In [9]:
# Here you can change things for experimentation
batch_size = 64
prefetch_factor = 2
num_epochs = 10

# The optimizer *name* gets its own variable: `optimizer` is bound to the
# optimizer object at the end of this cell, and using one name for both a
# string and an object makes the cell hard to reason about.
optimizer_name = 'SGD'
scheduler = 'CosineAnnealingLR'
lr = 0.001
momentum = 0.9
weight_decay = 0.0001   # weight decay used when the optimizer is SGD
w_decay_adam = 0.03     # weight decay used when the optimizer is Adam

if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.set_device(0)
else:
    device = torch.device('cpu')
    print('No GPU avaialable, Using CPU')

utils.set_seeds(1)
# os.cpu_count() returns None when the CPU count cannot be determined;
# fall back to a single worker so the loaders still get a valid integer.
num_workers = os.cpu_count() or 1

train_dataset, val_dataset, class_names, pretrained_vit = prepare_dataset()
# NOTE(review): prepare_dataloader builds unshuffled loaders, so the training
# batches follow the dataset's stored order (presumably grouped by class
# folder for ImageFolder - confirm). Consider shuffling the training data.
train_dataloader = prepare_dataloader(dataset= train_dataset, batch_size = batch_size, num_workers = num_workers, prefetch_factor = prefetch_factor)
val_dataloader = prepare_dataloader(dataset= val_dataset, batch_size = batch_size, num_workers = num_workers, prefetch_factor = prefetch_factor)

# Keyword arguments guard against silently swapping the two weight-decay
# values (or any other pair) in this long positional signature.
optimizer, lr_scheduler = load_train_objs(
    model=pretrained_vit,
    num_epochs=num_epochs,
    optimizer_choice=optimizer_name,
    scheduler_choice=scheduler,
    initial_lr=lr,
    momentum=momentum,
    weight_decay_adam=w_decay_adam,
    wd_sgd=weight_decay,
)

In [10]:
# Sanity check of the loaders, printed one value per line in this order:
# train sample count, train batch count, validation sample count,
# validation batch count, configured batch size.
for loader_stat in (
    len(train_dataloader.dataset),
    len(train_dataloader),
    len(val_dataloader.dataset),
    len(val_dataloader),
    train_dataloader.batch_size,
):
    print(loader_stat)

225
4
75
2
64


In [13]:
import time
import torch
from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, gpu_id, num_classes):
    """Run one training pass over ``dataloader`` and return the epoch metrics.

    Args:
        model: Model to optimise; moved to ``gpu_id`` and put in train mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        gpu_id: Device (or device identifier) the model and batches are moved to.
        num_classes (int): Number of classes, forwarded to the metric computation.

    Returns:
        tuple: ``(avg_loss, avg_accuracy, balanced_accuracy)`` as produced by
        ``calculate_metrics``.
    """
    model.to(gpu_id)
    model.train()
    running_loss, correct_predictions, num_samples = 0, 0, 0
    y_pred_all, y_all = [], []
    for X, y in dataloader:
        X, y = X.to(gpu_id), y.to(gpu_id)
        optimizer.zero_grad()
        logits = model(X)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

        samples_in_batch = X.size(0)
        # cross_entropy returns the batch mean; weight it by the batch size so
        # a smaller final batch does not skew the epoch average.
        running_loss += loss.item() * samples_in_batch
        # Softmax is monotonic, so the argmax of the raw logits already is the
        # predicted class; the extra softmax was redundant work.
        y_pred_class = logits.argmax(dim=1)
        correct_predictions += (y_pred_class == y).sum().item()
        num_samples += samples_in_batch
        y_pred_all.append(y_pred_class)
        y_all.append(y)
    return calculate_metrics(running_loss, correct_predictions, num_samples, y_pred_all, y_all, num_classes)

def validate_epoch(model, dataloader, gpu_id, num_classes):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the metrics.

    Args:
        model: Model to evaluate; moved to ``gpu_id`` and put in eval mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        gpu_id: Device (or device identifier) the model and batches are moved to.
        num_classes (int): Number of classes, forwarded to the metric computation.

    Returns:
        tuple: ``(avg_loss, avg_accuracy, balanced_accuracy)`` as produced by
        ``calculate_metrics``.
    """
    model.to(gpu_id)
    model.eval()
    running_loss, correct_predictions, num_samples = 0, 0, 0
    y_pred_all, y_all = [], []
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(gpu_id), y.to(gpu_id)
            logits = model(X)
            loss = F.cross_entropy(logits, y)

            samples_in_batch = X.size(0)
            # cross_entropy returns the batch mean; weight it by the batch size
            # so a smaller final batch does not skew the epoch average.
            running_loss += loss.item() * samples_in_batch
            # Softmax is monotonic, so the argmax of the raw logits already is
            # the predicted class; the extra softmax was redundant work.
            y_pred_class = logits.argmax(dim=1)
            correct_predictions += (y_pred_class == y).sum().item()
            num_samples += samples_in_batch
            y_pred_all.append(y_pred_class)
            y_all.append(y)
    return calculate_metrics(running_loss, correct_predictions, num_samples, y_pred_all, y_all, num_classes)

def calculate_metrics(running_loss, correct_predictions, num_samples, y_pred_all, y_all, num_classes):
    """Turn the totals accumulated over one epoch into averaged metrics.

    Args:
        running_loss (float): Sum over batches of (mean batch loss * batch size).
        correct_predictions (int | float): Number of correctly classified samples.
        num_samples (int): Total number of samples seen in the epoch.
        y_pred_all (list[torch.Tensor]): Per-batch predicted class indices.
        y_all (list[torch.Tensor]): Per-batch true class indices.
        num_classes (int): Number of classes, forwarded to
            ``calculate_balanced_accuracy``.

    Returns:
        tuple: ``(avg_loss, avg_accuracy, balanced_accuracy)``.

    Raises:
        ValueError: If ``num_samples`` is 0, i.e. the epoch processed no data.
    """
    # An empty dataloader used to surface as a bare ZeroDivisionError here.
    if num_samples == 0:
        raise ValueError("Cannot compute metrics: no samples were processed (is the dataloader empty?).")

    avg_loss = running_loss / num_samples
    avg_accuracy = correct_predictions / num_samples
    # torch.cat is the long-standing spelling; torch.concatenate is an alias
    # that only exists in newer PyTorch releases.
    balanced_accuracy = calculate_balanced_accuracy(torch.cat(y_pred_all), torch.cat(y_all), num_classes)
    return avg_loss, avg_accuracy, balanced_accuracy

def format_time(seconds):
    """Render a duration in seconds as a compact ``"<H>h:<M>m:<S>s"`` string.

    Each component is truncated to an integer with ``int``. This definition
    replaces the verbose ``format_time`` from an earlier cell and is the one
    ``training`` uses for its final timing line.
    """
    whole_hours = int(seconds // 3600)
    whole_minutes = int((seconds % 3600) // 60)
    whole_seconds = int(seconds % 60)
    return f"{whole_hours}h:{whole_minutes}m:{whole_seconds}s"

def training(max_epochs, num_classes, lr_scheduler, gpu_id, model, train_dataloader, optimizer, val_dataloader):
    """Train and validate ``model`` for ``max_epochs`` epochs, printing metrics.

    Each epoch runs ``train_epoch``, steps ``lr_scheduler`` once, then runs
    ``validate_epoch``. Loss, accuracy and balanced accuracy of both phases are
    printed per epoch, followed by the total wall-clock time at the end.
    Nothing is returned.
    """
    total_start_time = time.time()
    epoch_progress = tqdm(range(max_epochs))
    for epoch in epoch_progress:
        train_loss, train_acc, train_bal_acc = train_epoch(model, train_dataloader, optimizer, gpu_id, num_classes)
        # The scheduler advances once per epoch, after the optimizer updates.
        lr_scheduler.step()
        val_loss, val_acc, val_bal_acc = validate_epoch(model, val_dataloader, gpu_id, num_classes)
        # Report this epoch's metrics
        print(f"Epoch: {epoch} | Training - Loss: {train_loss}, Accuracy: {train_acc}, Balanced Acc: {train_bal_acc}")
        print(f"Epoch: {epoch} | Validation - Loss: {val_loss}, Accuracy: {val_acc}, Balanced Acc: {val_bal_acc}")
    elapsed_seconds = time.time() - total_start_time
    print(f"Total training and validation time: {format_time(elapsed_seconds)}.")


In [15]:
# Take the epoch count and class count from the configuration cell instead of
# repeating literals: the cosine schedule was created with T_max=num_epochs,
# so a different max_epochs here would silently desynchronise the schedule.
training(max_epochs = num_epochs,
         num_classes = len(class_names),
         lr_scheduler= lr_scheduler,
         gpu_id = device,
         model= pretrained_vit,
         train_dataloader= train_dataloader,
         optimizer = optimizer,
         val_dataloader = val_dataloader)

 10%|█         | 1/10 [00:05<00:49,  5.50s/it]

Epoch: 0 | Training - Loss: 0.6472325812445746, Accuracy: 0.8711111111111111, Balanced Acc: 0.8707549571990967
Epoch: 0 | Validation - Loss: 0.5093457555770874, Accuracy: 0.88, Balanced Acc: 0.8802716135978699


 20%|██        | 2/10 [00:10<00:43,  5.45s/it]

Epoch: 1 | Training - Loss: 0.4953925808270772, Accuracy: 0.9111111111111111, Balanced Acc: 0.9111966490745544
Epoch: 1 | Validation - Loss: 0.3991367868582408, Accuracy: 0.92, Balanced Acc: 0.912529706954956


 30%|███       | 3/10 [00:16<00:38,  5.47s/it]

Epoch: 2 | Training - Loss: 0.4106656911638048, Accuracy: 0.9155555555555556, Balanced Acc: 0.9167379140853882
Epoch: 2 | Validation - Loss: 0.3420335308710734, Accuracy: 0.92, Balanced Acc: 0.912529706954956


 40%|████      | 4/10 [00:21<00:32,  5.47s/it]

Epoch: 3 | Training - Loss: 0.3613481405046251, Accuracy: 0.9288888888888889, Balanced Acc: 0.9299003481864929
Epoch: 3 | Validation - Loss: 0.31162968158721926, Accuracy: 0.92, Balanced Acc: 0.912529706954956


 50%|█████     | 5/10 [00:27<00:28,  5.66s/it]

Epoch: 4 | Training - Loss: 0.3313290837075975, Accuracy: 0.9333333333333333, Balanced Acc: 0.9343447685241699
Epoch: 4 | Validation - Loss: 0.2955687038103739, Accuracy: 0.92, Balanced Acc: 0.912529706954956


 60%|██████    | 6/10 [00:33<00:22,  5.64s/it]

Epoch: 5 | Training - Loss: 0.31377708461549547, Accuracy: 0.9377777777777778, Balanced Acc: 0.9387892484664917
Epoch: 5 | Validation - Loss: 0.2877530578772227, Accuracy: 0.92, Balanced Acc: 0.912529706954956


 70%|███████   | 7/10 [00:38<00:16,  5.49s/it]

Epoch: 6 | Training - Loss: 0.30467132574982114, Accuracy: 0.9422222222222222, Balanced Acc: 0.9432336688041687
Epoch: 6 | Validation - Loss: 0.2846335061391195, Accuracy: 0.92, Balanced Acc: 0.912529706954956


 80%|████████  | 8/10 [00:44<00:10,  5.47s/it]

Epoch: 7 | Training - Loss: 0.30096037520302665, Accuracy: 0.9422222222222222, Balanced Acc: 0.9432336688041687
Epoch: 7 | Validation - Loss: 0.2839108137289683, Accuracy: 0.92, Balanced Acc: 0.912529706954956


 90%|█████████ | 9/10 [00:49<00:05,  5.45s/it]

Epoch: 8 | Training - Loss: 0.3001131491528617, Accuracy: 0.9422222222222222, Balanced Acc: 0.9432336688041687
Epoch: 8 | Validation - Loss: 0.2839108137289683, Accuracy: 0.92, Balanced Acc: 0.912529706954956


100%|██████████| 10/10 [00:54<00:00,  5.49s/it]

Epoch: 9 | Training - Loss: 0.29988351517253453, Accuracy: 0.9422222222222222, Balanced Acc: 0.9432336688041687
Epoch: 9 | Validation - Loss: 0.28326441287994386, Accuracy: 0.92, Balanced Acc: 0.912529706954956
Total training and validation time: 0h:0m:54s.



