## Imports

In [1]:
# import kagglehub
import os
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.models as models
import torchvision.transforms as T
import random
import optuna
from optuna.trial import TrialState

In [2]:
# path = kagglehub.dataset_download("sidharkal/sports-image-classification")

# print("Path to dataset files:", path)

In [3]:
# Dataset locations. Directory strings keep a trailing slash because later
# cells build file paths by plain string concatenation.
data_dir = "../data"
images_dir = "../data/dataset/"

# Training split and its per-sport folders.
train_dir = f"{images_dir}train/"
badminton_train_dir = f"{train_dir}Badminton/"
tennis_train_dir = f"{train_dir}Tennis/"
cricket_train_dir = f"{train_dir}Cricket/"
soccer_train_dir = f"{train_dir}Soccer/"
swimming_train_dir = f"{train_dir}Swimming/"
karate_train_dir = f"{train_dir}Karate/"
wrestling_train_dir = f"{train_dir}Wrestling/"

# Test split and its per-sport folders.
test_dir = f"{images_dir}test/"
badminton_test_dir = f"{test_dir}Badminton/"
tennis_test_dir = f"{test_dir}Tennis/"
cricket_test_dir = f"{test_dir}Cricket/"
soccer_test_dir = f"{test_dir}Soccer/"
swimming_test_dir = f"{test_dir}Swimming/"
karate_test_dir = f"{test_dir}Karate/"
wrestling_test_dir = f"{test_dir}Wrestling/"

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Seed every RNG this notebook draws from, not only PyTorch: Python's
# `random` is used by the dataset classes, and NumPy is imported as well.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x242556b5bd0>

In [5]:
# shutil.copytree(path, data_dir, dirs_exist_ok=True)

## Organizing Data structure

In [None]:
# Load the CSV manifests shipped with the dataset. `images_dir` already ends
# with a slash, so join the file name directly instead of adding another "/".
train_df = pd.read_csv(os.path.join(images_dir, "train.csv"))
test_df = pd.read_csv(os.path.join(images_dir, "test.csv"))

In [None]:
# Preview only the first rows rather than dumping the whole table.
test_df.head()

In [None]:
# Preview only the first rows rather than dumping the whole table.
train_df.head()

In [None]:
# Index both tables by image file name so rows can be looked up by image ID.
# Done without `inplace` and guarded on the column still being present, so
# re-running this cell is a no-op instead of raising KeyError.
if "image_ID" in train_df.columns:
    train_df = train_df.set_index("image_ID")
if "image_ID" in test_df.columns:
    test_df = test_df.set_index("image_ID")

### Moving data to be per label

In [None]:
# Distinct class names present in the training manifest.
labels = train_df["label"].unique()

# Make sure a sub-folder exists for every class in both splits.
for label in labels:
    for split_root in (train_dir, test_dir):
        os.makedirs(f"{split_root}{label}", exist_ok=True)

In [None]:
# Display the distinct label values collected from the training manifest.
labels

In [None]:
# Spot-check one image's label with a single (row, column) lookup instead of
# chained indexing.
train_df.loc['7c225f7b61.jpg', 'label']

In [None]:
# Sanity check: show the training directory prefix used to build image paths.
print(train_dir)

In [None]:
# Move every training image from the flat train folder into its class folder.
# Iterate (image_id, label) pairs directly: `train_df['label'][i]` was an
# integer lookup on a string-indexed Series, which only worked through pandas'
# deprecated positional fallback.
for image_id, label in train_df["label"].items():
    old_path = train_dir + image_id
    new_path = train_dir + label + "/" + image_id
    # Files already moved by a previous run are simply skipped.
    if os.path.exists(old_path):
        shutil.move(old_path, new_path)

### Labeling Test data as it was unlabeled

In [None]:
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Zero-shot CLIP classifier used to pseudo-label the unlabeled test images.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
model.eval()  # inference only: make the evaluation mode explicit
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

classes = ['Badminton', 'Cricket', 'Tennis', 'Swimming', 'Soccer', 'Wrestling', 'Karate']

def classify_with_clip(image_path):
    """Return the entry of ``classes`` that CLIP scores highest for one image.

    Args:
        image_path (str): Path of the image file to classify.

    Returns:
        str: The class name with the largest image-text logit.
    """
    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(text=classes, images=image, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # No gradients are needed for labelling; skipping autograd avoids building
    # a graph for every image.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax is monotonic, so the argmax of the logits equals the argmax of
    # the probabilities; `.item()` yields a plain int for list indexing.
    pred = outputs.logits_per_image.argmax(dim=1).item()
    return classes[pred]

In [None]:
# Pseudo-label every test image with CLIP, record the label in `test_df`
# and move the file into the matching class folder.
for image_id in tqdm(test_df.index):
    image_path = test_dir + image_id
    # Not in the flat test folder: skip it (for example, it was already moved
    # by an earlier run of this cell).
    if not os.path.exists(image_path):
        continue
    label = classify_with_clip(image_path)
    test_df.at[image_id, 'label'] = label
    # Existence was checked above, so the move needs no second guard.
    shutil.move(image_path, test_dir + label + "/" + image_id)


## Statistics from the data

### Checking distribution of classes

In [None]:

classes = ['Badminton', 'Cricket', 'Tennis', 'Swimming', 'Soccer', 'Wrestling', 'Karate']

# One folder per class in each split.
train_dirs = [f"{train_dir}{cls}/" for cls in classes]
test_dirs = [f"{test_dir}{cls}/" for cls in classes]

# Number of directory entries per class folder; a missing folder counts as 0.
train_counts = [len(os.listdir(folder)) if os.path.exists(folder) else 0 for folder in train_dirs]
test_counts = [len(os.listdir(folder)) if os.path.exists(folder) else 0 for folder in test_dirs]

fig, axs = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: train split, right panel: test split.
panels = (
    (train_counts, 'skyblue', 'Train Class Distribution'),
    (test_counts, 'lightgreen', 'Test Class Distribution'),
)
for panel_ax, (counts, bar_color, panel_title) in zip(axs, panels):
    panel_ax.bar(classes, counts, color=bar_color)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Class')
    panel_ax.set_ylabel('Number of Images')
    panel_ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()


### Per class statistics

In [None]:
x = np.arange(len(classes))
width = 0.35

fig, ax = plt.subplots(figsize=(12, 6))

def annotate_bars(rects):
    """Write each bar's height just above the bar on the module-level ``ax``."""
    for bar in rects:
        bar_height = bar.get_height()
        bar_center = bar.get_x() + bar.get_width() / 2
        ax.annotate(
            f'{bar_height}',
            xy=(bar_center, bar_height),
            xytext=(0, 3),  # shift the text 3 points upward
            textcoords="offset points",
            ha='center',
            va='bottom',
            fontsize=8,
        )

# Grouped bars: train to the left of each tick, test to the right.
rects1 = ax.bar(x - width/2, train_counts, width, label='Train', color='skyblue')
rects2 = ax.bar(x + width/2, test_counts, width, label='Test', color='lightgreen')

ax.set_ylabel('Number of Images')
ax.set_title('Per-Class Distribution: Train vs Test')
ax.set_xticks(x)
ax.set_xticklabels(classes, rotation=45)
ax.legend()

for bar_group in (rects1, rects2):
    annotate_bars(bar_group)

plt.tight_layout()
plt.show()

### Pixel Value Distribution

In [None]:
# Mean grayscale intensity of every training image, grouped by class.
class_brightness = {cls: [] for cls in classes}

# The progress bar wraps the image scan, which is the slow part of this cell.
for cls, folder in tqdm(list(zip(classes, train_dirs))):
    if not os.path.exists(folder):
        continue
    for img_file in os.listdir(folder):
        img_path = os.path.join(folder, img_file)
        try:
            # The context manager closes the file handle right after the
            # pixels are read, so handles do not pile up over the whole scan.
            with Image.open(img_path) as img:
                img_arr = np.array(img.convert("L"))
            class_brightness[cls].append(img_arr.mean())
        except Exception as e:
            # Best effort: report unreadable files and keep going.
            print(f"Failed to process {img_path}: {e}")

# Plot distributions
fig, axs = plt.subplots(len(classes), 1, figsize=(8, len(classes)*3))

for idx, cls in enumerate(classes):
    axs[idx].hist(class_brightness[cls], bins=30, color='skyblue', edgecolor='black', density=True)
    axs[idx].set_title(f'Pixel Distribution: {cls}')
    axs[idx].set_xlabel('Mean Pixel Values')
    axs[idx].set_ylabel('Density')

plt.tight_layout()
plt.show()

- Mean pixel values are well spread across the images, meaning that the images are neither too dark nor too bright.
- Thus the images are considered well exposed, so little exposure-related noise is introduced.

## Showing some images per class

In [None]:
# Grid with one row per class and two example images per row.
fig, axs = plt.subplots(len(classes), 2, figsize=(8, len(classes) * 3))

image_suffixes = ('jpg', 'jpeg', 'png')

for row_idx, (cls, folder) in enumerate(zip(classes, train_dirs)):
    if not os.path.exists(folder):
        continue
    # First two image files in directory-listing order.
    selected_images = [
        name for name in os.listdir(folder) if name.lower().endswith(image_suffixes)
    ][:2]
    for col_idx in range(2):
        cell = axs[row_idx, col_idx]
        if col_idx >= len(selected_images):
            # Fewer than two images available: leave the cell blank.
            cell.axis('off')
            continue
        img = Image.open(os.path.join(folder, selected_images[col_idx]))
        cell.imshow(img)
        cell.axis('off')
        cell.set_title(f"{cls} - Sample {col_idx + 1}")

plt.tight_layout()
plt.show()

## Dataset class and data manager

### Dataset class 1

In [6]:
# Class names; list position is the integer label used by the datasets below.
# NOTE(review): this list is also defined in earlier cells; it is repeated here
# presumably so the modelling section can run without the EDA cells.
classes = ['Badminton', 'Cricket', 'Tennis', 'Swimming', 'Soccer', 'Wrestling', 'Karate']

In [7]:
class ImageDataset1(Dataset):
    """Folder-per-class image dataset that uses every image under ``root_dir``.

    Each entry of ``classes`` names a sub-folder of ``root_dir``; the position
    of the name in the list is the integer label of the images inside it.
    """

    # Accepted image suffixes, compared case-insensitively. The leading dot
    # stops names that merely end in "jpg"/"png" from being picked up.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, classes, transform=None, is_train=True):
        """
        Args:
            root_dir (str): Directory with all the class folders
            classes (list): List of class names (subfolder names)
            transform (callable, optional): Optional transform to be applied on a sample
            is_train (bool): Whether this is training data or not; only used
                to choose the default transform when ``transform`` is None
        """
        self.root_dir = root_dir
        self.classes = classes
        self.transform = transform
        self.is_train = is_train
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.samples = []

        # Default transforms if none provided
        if self.transform is None:
            if is_train:
                self.transform = T.Compose([
                    T.RandomResizedCrop(224),  # random crop, resized to 224x224
                    T.ToTensor(),
                ])
            else:
                self.transform = T.Compose([
                    T.Resize(224),      # shorter side -> 224
                    T.CenterCrop(224),  # central 224x224 patch
                    T.ToTensor(),
                ])

        for idx, cls in enumerate(classes):
            class_folder = os.path.join(root_dir, cls)
            # A class without a folder simply contributes no samples.
            if not os.path.isdir(class_folder):
                continue
            # Sort the listing so the sample order does not depend on the
            # order in which the filesystem returns directory entries.
            for img_name in sorted(os.listdir(class_folder)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(class_folder, img_name), idx))

    def __len__(self):
        """Number of image files discovered at construction time."""
        return len(self.samples)

    def __getitem__(self, idx, retry=0):
        """Return ``(image, label)`` for sample ``idx``.

        If the image cannot be loaded or transformed, the error is printed and
        a randomly chosen other sample is returned instead; after three such
        substitutions fail as well, ``RuntimeError`` is raised.
        """
        img_path, label = self.samples[idx]
        try:
            # Close the file handle as soon as the pixels have been decoded.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            if retry >= 3:
                raise RuntimeError("Too many failed image loads.") from e
            return self.__getitem__(random.randint(0, len(self) - 1), retry=retry + 1)

### Dataset class 2

In [8]:
class ImageDataset2(Dataset):
    """Folder-per-class image dataset that keeps one side of a seeded split.

    All images under ``root_dir`` are collected, shuffled with ``seed`` and cut
    at ``split_ratio``: the training instance keeps the first part, the
    validation instance the remainder. Two instances built with the same
    ``root_dir``, ``classes``, ``split_ratio`` and ``seed`` but opposite
    ``is_train`` therefore hold disjoint, complementary samples.
    """

    # Accepted image suffixes, compared case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, classes, transform=None, is_train=True, split_ratio=0.8, seed=42):
        """
        Args:
            root_dir (str): Directory with all the class folders
            classes (list): List of class names (subfolder names)
            transform (callable, optional): Optional transform to be applied on a sample
            is_train (bool): Whether this is training data or not
            split_ratio (float): Ratio for training data (default is 0.8)
            seed (int): Seed for reproducibility
        """
        self.root_dir = root_dir
        self.classes = classes
        self.transform = transform
        self.is_train = is_train
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.samples = []

        all_samples = []
        for idx, cls in enumerate(classes):
            class_folder = os.path.join(root_dir, cls)
            if not os.path.isdir(class_folder):
                continue
            # Sorted listing: the shuffle below is only reproducible if its
            # input order is, and os.listdir gives no ordering guarantee.
            for img_name in sorted(os.listdir(class_folder)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    all_samples.append((os.path.join(class_folder, img_name), idx))

        # Shuffle and split once, using a private generator so that building a
        # dataset does not reset the state of the global `random` module.
        random.Random(seed).shuffle(all_samples)
        split_point = int(len(all_samples) * split_ratio)
        if is_train:
            self.samples = all_samples[:split_point]
        else:
            self.samples = all_samples[split_point:]

        # Set default transforms if not provided
        if self.transform is None:
            if is_train:
                self.transform = T.Compose([
                    T.RandomResizedCrop(224),
                    T.ToTensor(),
                ])
            else:
                self.transform = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                ])

    def __len__(self):
        """Number of samples on this side of the split."""
        return len(self.samples)

    def __getitem__(self, idx, retry=0):
        """Return ``(image, label)`` for sample ``idx``.

        If the image cannot be loaded or transformed, the error is printed and
        a randomly chosen other sample is returned instead; after three such
        substitutions fail as well, ``RuntimeError`` is raised.
        """
        img_path, label = self.samples[idx]
        try:
            # Close the file handle as soon as the pixels have been decoded.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            if retry >= 3:
                raise RuntimeError("Too many failed image loads.") from e
            return self.__getitem__(random.randint(0, len(self) - 1), retry=retry + 1)

In [9]:
# Class names read as a module-level global by get_dataloaders().
# NOTE(review): identical to the list defined a few cells above.
classes = ['Badminton', 'Cricket', 'Tennis', 'Swimming', 'Soccer', 'Wrestling', 'Karate']

# Modelling

### Model 1: Simple CNN1

In [10]:
class Simplenet1(nn.Module):
    """Four-stage CNN: (conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool) x 4, then an MLP head.

    Every stage halves the spatial size, so an HxW input reaches the head as
    512 x H/16 x W/16 (14x14 for 224x224 images). The head starts with global
    average pooling, so it accepts any spatial size the pools can reduce.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        # (in_channels, out_channels) of each convolutional stage.
        stage_channels = [(3, 64), (64, 128), (128, 256), (256, 512)]
        feature_layers = []
        for c_in, c_out in stage_channels:
            feature_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),  # keeps H x W
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),  # halves H and W
            ]
        # Flat Sequential: sub-module indices (and state_dict keys) run 0..15.
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),  # 512 x h x w -> 512 x 1 x 1
            nn.Flatten(),                  # -> 512
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),   # class logits
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to class logits (N, num_classes)."""
        return self.classifier(self.features(x))

### Model 2: Simple CNN2

In [11]:
class Simplenet2(nn.Module):
    """VGG-style CNN with double-conv stages, LeakyReLU and a small MLP head.

    Layout: four stages of two (conv3x3 -> BatchNorm -> LeakyReLU) units and a
    fifth stage with a single unit, each followed by a 2x2 max-pool. A
    224x224 input therefore reaches the head as 256 x 7 x 7.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        # One inner list per stage; every stage ends with a 2x2 max-pool.
        stages = [
            [(3, 64), (64, 64)],       # 224 -> 112 (for 224x224 inputs)
            [(64, 128), (128, 128)],   # 112 -> 56
            [(128, 256), (256, 256)],  # 56 -> 28
            [(256, 512), (512, 512)],  # 28 -> 14
            [(512, 256)],              # 14 -> 7
        ]
        feature_layers = []
        for stage in stages:
            for c_in, c_out in stage:
                feature_layers += [
                    # The conv bias is omitted because BatchNorm follows directly.
                    nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU(inplace=True),
                ]
            feature_layers.append(nn.MaxPool2d(2))
        # Flat Sequential: sub-module indices (and state_dict keys) run 0..31.
        self.features = nn.Sequential(*feature_layers)

        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),  # 256 x h x w -> 256 x 1 x 1
            nn.Flatten(),                  # -> 256
            nn.Linear(256, 512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),   # class logits
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to class logits (N, num_classes)."""
        return self.classifier(self.features(x))

### Model 3: Simple CNN3

In [12]:
class SimpleNet3(nn.Module):
    """Lightweight CNN: three stride-2 convolutions with two max-pools in between.

    Spatial size is divided by 32 overall (224x224 -> 7x7 with 64 channels)
    before global average pooling and a three-layer MLP head.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        def strided_conv(c_in, c_out):
            # 3x3 convolution that halves H and W.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)

        def halving_pool():
            # 3x3 max-pool that halves H and W.
            return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.features = nn.Sequential(
            strided_conv(3, 16),    # 224 -> 112 (for 224x224 inputs)
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            halving_pool(),         # 112 -> 56

            strided_conv(16, 32),   # 56 -> 28
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            halving_pool(),         # 28 -> 14

            strided_conv(32, 64),   # 14 -> 7
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # 64 x h x w -> 64 x 1 x 1
            nn.Flatten(),                  # -> 64
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, num_classes),    # class logits
        )

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to class logits (N, num_classes)."""
        return self.classifier(self.features(x))

### Model 4: Simple CNN4

In [13]:
class SimpleNet4(nn.Module):
    """Four-conv CNN with a five-layer MLP head.

    The first two convolutions keep the resolution, the last two use stride 2;
    together with three 3x3/stride-2 max-pools the spatial size is divided by
    32 (224x224 -> 7x7 with 256 channels) before global average pooling.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        def halving_pool():
            # 3x3 max-pool that halves H and W.
            return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),               # 224 kept (for 224x224 inputs)
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            halving_pool(),                                           # 224 -> 112

            nn.Conv2d(32, 64, kernel_size=3, padding=1),              # 112 kept
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            halving_pool(),                                           # 112 -> 56

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),   # 56 -> 28
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            halving_pool(),                                           # 28 -> 14

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 14 -> 7
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Head: pooled 256-vector -> dropout -> Linear layers down to the logits.
        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),  # 256 x h x w -> 256 x 1 x 1
            nn.Flatten(),                  # -> 256
            nn.Dropout(0.5),
        ]
        for n_in, n_out in ((256, 128), (128, 64), (64, 32), (32, 16)):
            head_layers += [nn.Linear(n_in, n_out), nn.ReLU(inplace=True)]
        head_layers.append(nn.Linear(16, num_classes))  # class logits
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to class logits (N, num_classes)."""
        return self.classifier(self.features(x))

# Utils for Pretrained

In [14]:
def freeze_all_but_last_n(model, n=2):
    """Freeze ``model`` except for its last ``n`` parameter-owning layers.

    A "layer" is a module that holds parameters directly (for example Conv2d,
    Linear or BatchNorm); containers such as ``nn.Sequential`` are not counted,
    so ``n`` really is the number of layers left trainable.

    Args:
        model (nn.Module): Model to modify in place.
        n (int): Number of trailing layers to leave trainable. ``n <= 0``
            freezes the whole model.

    Returns:
        nn.Module: The same ``model`` object, for chaining.
    """
    for param in model.parameters():
        param.requires_grad = False

    # Modules that own parameters themselves, in registration (pre-order) order.
    layers_with_params = [
        module for module in model.modules()
        if next(module.parameters(recurse=False), None) is not None
    ]

    # Guard n <= 0 explicitly: `layers[-0:]` is the WHOLE list, which would
    # unfreeze everything instead of nothing.
    if n > 0:
        for module in layers_with_params[-n:]:
            for param in module.parameters(recurse=False):
                param.requires_grad = True

    return model


def print_trainable_params(model):
    """Print how many scalar parameters of ``model`` currently require gradients."""
    print("Trainable Parameters:")
    # Count only tensors that the optimizer would actually update.
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Trainable Parameters: {total}")


# Utils

In [15]:
def load_model(model_name):
    """Instantiate one of the scratch CNNs by its lowercase name.

    Args:
        model_name (str): "simplenet1", "simplenet2", "simplenet3" or "simplenet4".

    Returns:
        nn.Module: A freshly constructed model with default arguments.

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    # Lambdas defer both name lookup and construction until a model is chosen.
    builders = {
        "simplenet1": lambda: Simplenet1(),
        "simplenet2": lambda: Simplenet2(),
        "simplenet3": lambda: SimpleNet3(),
        "simplenet4": lambda: SimpleNet4(),
    }
    build = builders.get(model_name)
    if build is None:
        raise ValueError(f"Model {model_name} not recognized. Please choose a valid model name.")
    return build()

In [16]:
def get_dataloaders(config, transform=None):
    """Build the training and validation loaders selected by ``config``.

    Args:
        config (dict): Needs "dataset_class" and "batch_size". With
            "ImageClass1" the train folder is used for training and the test
            folder for validation; any other value splits the train folder
            80/20 with a fixed seed.
        transform (callable, optional): Transform handed to both datasets;
            when None each dataset falls back to its own default.

    Returns:
        tuple: ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    dataset_type = config["dataset_class"]
    batch_size = config["batch_size"]

    if dataset_type == "ImageClass1":
        # Separate directories: train/ for training, test/ for validation.
        train_dataset = ImageDataset1(root_dir=train_dir, transform=transform, classes=classes, is_train=True)
        val_dataset = ImageDataset1(root_dir=test_dir, transform=transform, classes=classes, is_train=False)
    else:
        # Both halves come from train/; identical split arguments keep them
        # complementary.
        split_args = dict(
            root_dir=train_dir,
            classes=classes,
            transform=transform,
            split_ratio=0.8,
            seed=42,
        )
        train_dataset = ImageDataset2(is_train=True, **split_args)
        val_dataset = ImageDataset2(is_train=False, **split_args)

    return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        DataLoader(val_dataset, batch_size=batch_size),
    )

def validate_model(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader``.

    The model is switched to eval mode and left in it. Inputs are moved to the
    device of the model's first parameter.

    Returns:
        tuple: ``(mean of the per-batch losses, accuracy in percent)``.
    """
    model.eval()
    target_device = next(model.parameters()).device

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(target_device)
            batch_labels = batch_labels.to(target_device)

            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()

            predicted = logits.max(1)[1]  # index of the largest logit per row
            n_seen += batch_labels.size(0)
            n_correct += predicted.eq(batch_labels).sum().item()

    mean_loss = loss_sum / len(val_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

In [17]:
def initialize_weights(model, method):
    """Re-initialise the weights of every Conv2d/Linear layer in ``model`` in place.

    Args:
        model (nn.Module): Model whose layers are re-initialised.
        method (str): "xavier" for Xavier-uniform, "kaiming" for Kaiming-uniform
            (ReLU gain). Any other value leaves the model untouched.

    Biases are not modified.
    """
    if method == "xavier":
        init_fn = nn.init.xavier_uniform_
    elif method == "kaiming":
        init_fn = lambda weight: nn.init.kaiming_uniform_(weight, nonlinearity='relu')
    else:
        # Unknown method: nothing to do.
        return

    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            init_fn(layer.weight)

def should_initialize(model_type):
    """Return True only for models trained from scratch (``model_type == "scratch"``)."""
    is_scratch = model_type == "scratch"
    return is_scratch


In [18]:
from datetime import datetime

def _build_optimizer(name, parameters, lr):
    """Return Adam for "adam", SGD for "sgd" and RMSprop for any other name."""
    if name == "adam":
        return torch.optim.Adam(parameters, lr=lr)
    if name == "sgd":
        return torch.optim.SGD(parameters, lr=lr)
    return torch.optim.RMSprop(parameters, lr=lr)


def train_model(config):
    """Train the model described by ``config`` and return its best validation accuracy.

    Args:
        config (dict): Requires "model_choice" (``(model_type, model_name)``),
            "dataset_class", "batch_size", "optimizer", "lr", "init_method" and
            "epochs". Optional "patience" (default 20) is the number of epochs
            without validation-loss improvement before training stops early.

    Returns:
        float: Highest validation accuracy (percent) over all epochs, or 0.0
        if no epoch was run.

    Side effects:
        Writes into ``logs_224_scratch/checkpoints/<run name>/``:
        periodic ``epoch_<k>.pt`` checkpoints, ``best_model.pt`` (weights with
        the lowest validation loss so far), ``metrics.pt`` and
        ``final_model.pt`` (weights after the last epoch).
    """
    model_type, model_name = config["model_choice"]
    train_loader, val_loader = get_dataloaders(config)

    init_method = config["init_method"]

    model = load_model(model_name)
    model.to(device)
    time_stamp = datetime.now().strftime("%Y%m%d_%H")
    unique_config = f"{model_name}_{config['dataset_class']}_{config['optimizer']}_{config['init_method']}_{config['batch_size']}_{config['lr']}_time_{time_stamp}"

    # Only scratch models get a custom init; "default" keeps PyTorch's own.
    if should_initialize(model_type) and init_method != "default":
        initialize_weights(model, init_method)

    optimizer = _build_optimizer(config["optimizer"], model.parameters(), config["lr"])
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    best_val_loss = float('inf')
    patience_counter = 0
    patience = config.get("patience", 20)
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    epochs = config["epochs"]
    save_interval = 2 if model_type == "pretrained" else 50
    save_dir = os.path.join("logs_224_scratch", "checkpoints")
    run_dir = os.path.join(save_dir, unique_config)
    os.makedirs(run_dir, exist_ok=True)

    total_batches = len(train_loader)
    total_steps = epochs * total_batches
    progress_bar = tqdm(total=total_steps, dynamic_ncols=True, desc="Training")

    # try/finally so the progress bar is released even if training raises.
    try:
        for epoch in range(epochs):
            model.train()
            running_loss, correct, total = 0.0, 0, 0
            # Defined up front so an empty loader cannot leave them unbound.
            train_loss, train_acc = float('nan'), 0.0
            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                # Running averages for the progress bar.
                train_loss = running_loss / (i + 1)
                train_acc = 100. * correct / total
                progress_bar.update(1)
                progress_bar.set_postfix({
                    "Epoch": f"{epoch+1}/{epochs}",
                    "Train Loss": f"{train_loss:.4f}",
                    "Train Acc": f"{train_acc:.2f}%"
                })

            # Validation phase
            val_loss, val_acc = validate_model(model, val_loader, criterion)
            scheduler.step(val_loss)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            # Periodic checkpoint
            if (epoch + 1) % save_interval == 0:
                torch.save(model.state_dict(), os.path.join(run_dir, f"epoch_{epoch+1}.pt"))

            # Early stopping on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Keep the best weights: the final model is saved only after
                # `patience` further epochs without improvement.
                torch.save(model.state_dict(), os.path.join(run_dir, "best_model.pt"))
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    progress_bar.set_description("Early Stopping")
                    break

            progress_bar.set_postfix({"Epoch": f"{epoch+1}/{epochs}", "Train Loss": f"{train_loss:.4f}", "Train Acc": f"{train_acc:.2f}%", "Val Loss": f"{val_loss:.4f}", "Val Acc": f"{val_acc:.2f}%"})
    finally:
        progress_bar.close()

    # Save metrics
    torch.save({
        "train_losses": train_losses,
        "val_losses": val_losses,
        "train_accs": train_accs,
        "val_accs": val_accs
    }, os.path.join(run_dir, "metrics.pt"))

    torch.save(model.state_dict(), os.path.join(run_dir, "final_model.pt"))

    # `default` covers epochs == 0, where no validation pass happened.
    return max(val_accs, default=0.0)


## Manual Configurations Simple CNN

In [19]:
# Hand-picked experiment configurations consumed by train_model().
# v3: SimpleNet4, SGD, 80/20 split of the train folder, Kaiming init.
configv3 = dict(
    model_choice=("scratch", "simplenet4"),
    optimizer="sgd",
    lr=0.01,
    batch_size=32,
    dataset_class="ImageClass2",
    epochs=100,
    init_method="kaiming",
)

# v4: SimpleNet3, Adam, train/test folders, Kaiming init.
configv4 = dict(
    model_choice=("scratch", "simplenet3"),
    optimizer="adam",
    lr=0.01,
    batch_size=64,
    dataset_class="ImageClass1",
    epochs=100,
    init_method="kaiming",
)

In [20]:
# v7: SimpleNet3, Adam, 80/20 split of the train folder, Kaiming init.
configv7 = dict(
    model_choice=("scratch", "simplenet3"),
    optimizer="adam",
    lr=0.01,
    batch_size=64,
    dataset_class="ImageClass2",
    epochs=100,
    init_method="kaiming",
)

# v8: SimpleNet4, Adam, train/test folders, Xavier init.
configv8 = dict(
    model_choice=("scratch", "simplenet4"),
    optimizer="adam",
    lr=0.01,
    batch_size=32,
    dataset_class="ImageClass1",
    epochs=100,
    init_method="xavier",
)

In [21]:
# train_model(configv3)
# train_model(configv4)
# train_model(configv7)
# train_model(configv8)

In [22]:
# Configurations for Simplenet1/Simplenet2 on the 80/20 split of the train
# folder. Variants are derived from a base dict so only the differences show.

# v1: Simplenet2, Adam, batch 16, Xavier init.
configv1 = dict(
    model_choice=("scratch", "simplenet2"),
    optimizer="adam",
    lr=0.01,
    batch_size=16,
    dataset_class="ImageClass2",
    epochs=100,
    init_method="xavier",
)

# v2: Simplenet1, Adam, batch 32, Xavier init.
configv2 = dict(
    model_choice=("scratch", "simplenet1"),
    optimizer="adam",
    lr=0.01,
    batch_size=32,
    dataset_class="ImageClass2",
    epochs=100,
    init_method="xavier",
)

# v5: same as v1 but optimised with SGD.
configv5 = dict(configv1, optimizer="sgd")

# v6: same as v2 but with batch size 16.
configv6 = dict(configv2, batch_size=16)

# v9: independent copy with exactly the settings of v1.
configv9 = dict(configv1)


In [23]:
# Run v5: Simplenet2 trained with SGD on the 80/20 split (batch 16, lr 0.01).
# The value displayed below the cell is train_model's return value.
train_model(configv5)

Early Stopping:  60%|██████    | 24720/41200 [3:03:11<2:02:07,  2.25it/s, Epoch=60/100, Train Loss=0.5477, Train Acc=81.08%]                            


85.90522478736331

In [24]:
# Run v9: Simplenet2 trained with Adam on the 80/20 split (batch 16, lr 0.01).
train_model(configv9)

Early Stopping:  93%|█████████▎| 38316/41200 [4:36:35<20:49,  2.31it/s, Epoch=93/100, Train Loss=0.6787, Train Acc=77.33%]                              


83.3535844471446

In [25]:
# Run v2: Simplenet1 trained with Adam on the 80/20 split (batch 32, lr 0.01).
train_model(configv2)

Early Stopping:  66%|██████▌   | 13596/20600 [11:41:22<6:01:18,  3.10s/it, Epoch=66/100, Train Loss=1.2807, Train Acc=48.79%]                            


55.28554070473876

In [26]:
# Run v1: Simplenet2 trained with Adam on the 80/20 split (batch 16, lr 0.01);
# v1 has the same settings as v9.
train_model(configv1)

Early Stopping:  91%|█████████ | 37492/41200 [5:00:25<29:42,  2.08it/s, Epoch=91/100, Train Loss=0.6249, Train Acc=78.39%]                              


83.77885783718105

In [None]:
# Batch run of configurations v1, v5, v9 and v2 in sequence (v6 is disabled).
# NOTE(review): the same four configurations are already trained one by one in
# the cells above; each of those runs took hours, so confirm this cell is
# meant to repeat them before executing it.
#train_model(configv6)
train_model(configv1)
train_model(configv5)
train_model(configv9)
train_model(configv2)

# Functions to load


In [None]:
def load_model_weights(model, path):
    """Load a saved ``state_dict`` from ``path`` into ``model`` and return the model.

    The file is mapped to CPU while loading so that checkpoints written on a
    GPU machine can also be opened where no GPU is available;
    ``load_state_dict`` then copies the values into the model's existing
    parameters, whatever device they are on.
    """
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model

def load_metrics(path):
    """Load and return an object saved with ``torch.save`` at ``path``.

    Intended for the ``metrics.pt`` files written by ``train_model``. Any
    tensors in the file are mapped to CPU so that the file can be read on a
    machine without a GPU.
    """
    return torch.load(path, map_location="cpu")