<a href="https://colab.research.google.com/github/SaiRajesh228/DA6401_Assignment2/blob/main/DA6401_Assignment2_PartB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import time
from tqdm.notebook import tqdm

import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
import wandb

import torch.optim as optims
from torch.utils.data import Dataset, DataLoader,ChainDataset, ConcatDataset
from torch.utils.data.distributed import DistributedSampler

import matplotlib.pyplot as plt

In [None]:


class DatasetManager:
    """Builds ``DataLoader`` objects over ``ImageFolder`` datasets under one root directory."""

    def __init__(self, root_path, computation_device, base_transforms=None):
        """
        Stores the dataset location and the transforms shared by every split.

        Parameters:
            root_path: Directory that contains the split sub-directories.
            computation_device: Device identifier; stored on the instance but
                not used by this class itself.
            base_transforms: Either a sequence (list/tuple) of transform
                callables, or a single transform callable such as a
                torchvision weights preset. ``None`` means no base transforms.
        """
        self.root_path = root_path
        self.device = computation_device

        # Always keep a real list: build_loader() both iterates it (Compose)
        # and concatenates it with augmentation pipelines, neither of which
        # works on a bare transform object.
        if base_transforms is None:
            self.base_transforms = []
        elif isinstance(base_transforms, (list, tuple)):
            self.base_transforms = list(base_transforms)
        else:
            self.base_transforms = [base_transforms]

    def build_loader(self, data_subset, batch_size=16, shuffle=True,
                    workers=2, augmentation_pipelines=None, pin_memory=False):
        """
        Constructs and returns a DataLoader for the specified dataset subset.

        Parameters:
            data_subset (str): Subdirectory name (e.g., 'train/', 'val/', 'test/')
            batch_size (int): Number of samples per batch
            shuffle (bool): Whether to shuffle the data; only honoured for
                subsets whose name starts with 'train', otherwise forced off
            workers (int): Number of parallel data loading processes
            augmentation_pipelines (list): Optional list of augmentation
                sequences. For training subsets, each sequence is appended to
                the base transforms to build one extra copy of the dataset,
                and all copies are concatenated with the un-augmented one.
            pin_memory (bool): Enable fast data transfer to GPU

        Returns:
            DataLoader configured for the specified dataset
        """
        print(f"Initializing {data_subset} dataset processing...")

        is_training_split = data_subset.startswith('train')

        # Base transformations applied to all datasets
        core_transforms = transforms.Compose(self.base_transforms)
        combined_data = self._create_dataset(data_subset, core_transforms)

        # Augmented copies are only added for the training split
        if is_training_split and augmentation_pipelines:
            augmented_sets = [
                self._create_dataset(
                    data_subset,
                    # list(...) lets callers pass tuples as well as lists
                    transforms.Compose(self.base_transforms + list(pipeline)),
                )
                for pipeline in augmentation_pipelines
            ]
            combined_data = ConcatDataset([combined_data] + augmented_sets)

        return DataLoader(
            combined_data,
            batch_size=batch_size,
            # Evaluation splits are never shuffled
            shuffle=shuffle if is_training_split else False,
            num_workers=workers,
            pin_memory=pin_memory,
            # persistent_workers is only valid when worker processes exist
            persistent_workers=workers > 0,
        )

    def _create_dataset(self, subset_path, transform_pipeline):
        """Helper to create ImageFolder dataset with specified transforms"""
        full_path = os.path.join(self.root_path, subset_path)
        return torchvision.datasets.ImageFolder(
            root=full_path,
            transform=transform_pipeline
        )

In [None]:


class TrainingPipeline:
    """Handles complete model training workflow including data preparation, model setup, and training"""

    def __init__(self, computation_device, data_root, use_wandb=False, kaggle_env=False):
        """
        Stores configuration; no model or data is loaded here.

        Parameters:
            computation_device: Device (string or ``torch.device``) that the
                model and batches are moved to.
            data_root: Root directory handed to ``DatasetManager``.
            use_wandb (bool): When True, per-epoch metrics are sent to ``wandb.log``.
            kaggle_env (bool): Selects the checkpoint path used by ``_save_checkpoint``.
        """
        self.device = computation_device
        self.data_root = data_root
        self.wandb_integration = use_wandb
        self.kaggle_environment = kaggle_env
        self.model = None
        self.optimizer = None
        # Set by initialize_resnet(); prepare_data_loaders() falls back to the
        # same preset if it is still None.
        self.preprocessing_pipeline = None
        # Defined up front so run_evaluation() works without a prior
        # execute_training() call (e.g. on a model restored from a checkpoint).
        self.loss_fn = nn.CrossEntropyLoss()
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None

    def prepare_data_loaders(self, batch_size, shuffle_data, augmentation_transforms=None,
                            worker_count=0, memory_pinning=False):
        """
        Initializes data loaders for all dataset splits.

        Reads the 'train/', 'validation/' and 'test/' sub-directories of
        ``data_root``. Only the training loader receives the augmentation
        pipelines and the (possibly forced) pin-memory setting.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``; the same
            loaders are also stored on the instance.
        """
        # Allow loaders to be built before initialize_resnet() has been called.
        if self.preprocessing_pipeline is None:
            self.preprocessing_pipeline = ResNet50_Weights.IMAGENET1K_V2.transforms()

        # The preset is a single transform object; DatasetManager composes and
        # concatenates its base transforms as a list, so wrap it.
        data_manager = DatasetManager(root_path=self.data_root,
                                    computation_device=self.device,
                                    base_transforms=[self.preprocessing_pipeline])

        # Compare on the device type so both "cpu" strings and torch.device
        # objects are recognised. Pinning is forced on for non-CPU devices.
        running_on_cpu = torch.device(self.device).type == "cpu"
        memory_pinning = memory_pinning if running_on_cpu else True

        self.train_loader = data_manager.build_loader(
            "train/", batch_size, shuffle_data, worker_count,
            augmentation_transforms, memory_pinning
        )

        self.val_loader = data_manager.build_loader(
            "validation/", batch_size, False, worker_count,
            pin_memory=False
        )

        self.test_loader = data_manager.build_loader(
            "test/", batch_size, False, worker_count,
            pin_memory=False
        )

        return self.train_loader, self.val_loader, self.test_loader

    def initialize_resnet(self, output_classes):
        """
        Configures a pretrained ResNet50 model for transfer learning.

        All pretrained parameters are frozen and the final fully connected
        layer is replaced by a new trainable ``Linear`` + ``LogSoftmax`` head
        with ``output_classes`` outputs. The model is moved to ``self.device``.
        """
        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.preprocessing_pipeline = ResNet50_Weights.IMAGENET1K_V2.transforms()

        # Freeze base layers
        for param in self.model.parameters():
            param.requires_grad = False

        # Modify final layer
        final_in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(final_in_features, output_classes),
            nn.LogSoftmax(dim=1)
        )

        # Initialize final layer
        nn.init.xavier_uniform_(self.model.fc[0].weight)
        self.model.fc[0].bias.data.fill_(0.01)
        self.model.fc[0].requires_grad_(True)

        self.model = self.model.to(self.device)

    def _calculate_metrics(self, model, data_loader):
        """
        Evaluates model performance on given dataset.

        Puts ``model`` in eval mode and iterates ``data_loader`` without
        gradient tracking.

        Returns:
            Tuple ``(average_loss, accuracy_percent)``, both rounded to two
            decimals. ``(0.0, 0.0)`` when the loader yields no samples.
        """
        model.eval()
        correct_count = 0
        total_samples = 0
        running_loss = 0

        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = model(inputs)

                # loss_fn returns a batch mean; weight by batch size so the
                # final average is per-sample even with a ragged last batch.
                running_loss += self.loss_fn(outputs, targets).item() * inputs.size(0)
                predictions = outputs.argmax(dim=1)

                correct_count += (predictions == targets).sum().item()
                total_samples += inputs.size(0)

        if total_samples == 0:
            # Empty loader: avoid ZeroDivisionError
            return 0.0, 0.0

        accuracy = 100 * correct_count / total_samples
        avg_loss = running_loss / total_samples
        return round(avg_loss, 2), round(accuracy, 2)

    def execute_training(self, learning_rate, l2_penalty, loss_type, optimizer_choice, num_epochs):
        """
        Executes the complete training process.

        Parameters:
            learning_rate: Learning rate passed to the optimizer.
            l2_penalty: Passed to the optimizer as ``weight_decay``.
            loss_type: Accepted for interface compatibility; currently unused,
                cross-entropy is always applied.
            optimizer_choice (str): 'adam', 'nadam' or 'rmsprop' (case-insensitive).
            num_epochs (int): Number of passes over the training loader.

        A checkpoint is written every fifth epoch and after the last epoch.
        """
        self._configure_optimizer(learning_rate, l2_penalty, optimizer_choice)
        # NOTE: the head built by initialize_resnet() already ends in
        # LogSoftmax. CrossEntropyLoss applies log-softmax again, which leaves
        # log-probabilities unchanged, so the pairing is redundant but correct.
        self.loss_fn = nn.CrossEntropyLoss().to(self.device)

        start_timestamp = time.time()

        for epoch in tqdm(range(num_epochs)):
            # NOTE(review): train() also puts BatchNorm layers of the frozen
            # backbone in training mode, so their running statistics keep
            # updating - confirm this is intended.
            self.model.train()
            epoch_correct = 0
            epoch_total = 0
            epoch_loss = 0.0

            for inputs, labels in self.train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.loss_fn(outputs, labels)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item() * inputs.size(0)
                predictions = outputs.argmax(dim=1)
                epoch_correct += (predictions == labels).sum().item()
                epoch_total += inputs.size(0)

            # Guard against an empty training loader
            sample_count = max(epoch_total, 1)
            train_acc = 100 * epoch_correct / sample_count
            train_loss = epoch_loss / sample_count
            val_loss, val_acc = self._calculate_metrics(self.model, self.val_loader)

            # Periodic checkpoint, plus one after the last epoch so the final
            # weights are always persisted.
            if epoch % 5 == 0 or epoch == num_epochs - 1:
                self._save_checkpoint()

            if self.wandb_integration:
                self._log_training_metrics(epoch, train_loss, train_acc, val_loss, val_acc)

            print(f"Epoch {epoch+1}: "
                f"Train Loss {train_loss:.2f} | Acc {train_acc:.2f}% | "
                f"Val Loss {val_loss:.2f} | Acc {val_acc:.2f}%")

        print(f"Training completed in {time.time()-start_timestamp:.1f}s")

    def _configure_optimizer(self, lr, decay, optim_choice):
        """
        Configures model optimizer.

        Only parameters with ``requires_grad=True`` are handed to the
        optimizer, so frozen layers are not tracked.

        Raises:
            KeyError: If ``optim_choice`` is not 'adam', 'nadam' or 'rmsprop'.
        """
        optimizers = {
            'adam': optim.Adam,
            'nadam': optim.NAdam,
            'rmsprop': optim.RMSprop
        }
        optimizer_cls = optimizers[optim_choice.lower()]
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optimizer_cls(trainable_params, lr=lr, weight_decay=decay)

    def _save_checkpoint(self):
        """Saves model state (the ``state_dict``) to a fixed path chosen by environment"""
        path = "/kaggle/working/model" if self.kaggle_environment else "model.pth"
        torch.save(self.model.state_dict(), path)

    def _log_training_metrics(self, epoch, t_loss, t_acc, v_loss, v_acc):
        """Handles experiment tracking; ``epoch`` is zero-based and logged one-based"""
        wandb.log({
            'epoch': epoch+1,
            'train_loss': t_loss,
            'train_acc': t_acc,
            'val_loss': v_loss,
            'val_acc': v_acc
        })

    def run_evaluation(self):
        """Executes final model evaluation on test set and prints loss and accuracy"""
        test_loss, test_acc = self._calculate_metrics(self.model, self.test_loader)
        print(f"Final Test Results - Loss: {test_loss:.2f} | Accuracy: {test_acc:.2f}%")