## Imports

In [None]:
# import kagglehub
import os
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.models as models
import torchvision.transforms as T
import random
import optuna
from optuna.trial import TrialState

In [None]:
# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling repeat across runs.
# NOTE(review): Python's `random` and NumPy are not seeded in this cell.
torch.manual_seed(42)

In [None]:

# Paths
# The dataset is copied from `input_path` into `working_dir`; every file move
# below operates on that copy.
input_path = "/kaggle/input/sports-image-classification/dataset/"
working_dir = "/kaggle/working/dataset/"
train_dir = os.path.join(working_dir, "train/")
test_dir = os.path.join(working_dir, "test/")

# Copy everything from input to working dir (if not already done)
if not os.path.exists(working_dir):
    shutil.copytree(input_path, working_dir)

# Load CSVs, indexed by image file name (column "image_ID")
train_df = pd.read_csv(os.path.join(working_dir, "train.csv")).set_index("image_ID")
test_df = pd.read_csv(os.path.join(working_dir, "test.csv")).set_index("image_ID")

# Class labels: create one sub-folder per label under both train/ and test/
labels = train_df["label"].unique()
for label in labels:
    os.makedirs(os.path.join(train_dir, label), exist_ok=True)
    os.makedirs(os.path.join(test_dir, label), exist_ok=True)

# Move training images into label folders
# Images that are no longer at the flat location (e.g. already moved by an
# earlier run of this cell) are skipped, so the cell can be re-run safely.
for image_id, row in train_df.iterrows():
    label = row["label"]
    old_path = os.path.join(train_dir, image_id)
    new_path = os.path.join(train_dir, label, image_id)
    if os.path.exists(old_path):
        shutil.move(old_path, new_path)


In [None]:
# Setup CLIP
# `device` is reused by the training code further down; `model` and `processor`
# are the CLIP objects used by classify_with_clip.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Class names double as the text prompts passed to CLIP and as folder names.
classes = ['Badminton', 'Cricket', 'Tennis', 'Swimming', 'Soccer', 'Wrestling', 'Karate']

def classify_with_clip(image_path):
    """Return the entry of ``classes`` that CLIP scores highest for one image.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        The class name (an element of the module-level ``classes`` list) with
        the highest image-text probability.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the RGB copy exists; convert() returns a new, fully loaded image.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # The class names themselves are used as the text prompts.
    inputs = processor(text=classes, images=image, return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return classes[probs.argmax().item()]

In [None]:
# Pseudo-label the test images with CLIP and sort them into per-class folders.
# NOTE(review): these labels are model predictions, not ground truth; the
# "ImageClass1" option in get_dataloaders validates against this directory.
for image_id in tqdm(test_df.index):
    image_path = os.path.join(test_dir, image_id)
    # Skip images that are not at the flat location (e.g. moved by an earlier run).
    if not os.path.exists(image_path):
        continue
    label = classify_with_clip(image_path)
    test_df.at[image_id, 'label'] = label
    target_path = os.path.join(test_dir, label)
    os.makedirs(target_path, exist_ok=True)
    shutil.move(image_path, os.path.join(target_path, image_id))


## Dataset class and data manager

### Dataset class 1

In [None]:
# Class names (also the sub-folder names). The position in this list is the
# integer label produced by the dataset classes, so the order must not change.
classes = ['Badminton', 'Cricket', 'Tennis', 'Swimming', 'Soccer', 'Wrestling', 'Karate']

In [None]:
class ImageDataset1(Dataset):
    """Image-folder dataset: one sub-folder of ``root_dir`` per class name.

    Every image file found in ``root_dir/<class>/`` becomes one sample whose
    label is the position of ``<class>`` in ``classes``. Class folders that do
    not exist are skipped silently.
    """

    # Accepted file extensions (compared case-insensitively). The leading dot
    # keeps names such as "fakejpg" from being treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, classes, transform=None, is_train=True):
        """
        Args:
            root_dir (str): Directory with all the class folders
            classes (list): List of class names (subfolder names)
            transform (callable, optional): Optional transform to be applied on a sample
            is_train (bool): Whether this is training data or not; only used
                to pick the default transform when ``transform`` is None
        """
        self.root_dir = root_dir
        self.classes = classes
        self.transform = transform
        self.is_train = is_train
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.samples = []

        # Default transforms if none provided
        if self.transform is None:
            if is_train:
                self.transform = T.Compose([
                    T.RandomResizedCrop(128),  # random crop, resized to 128x128
                    T.ToTensor(),
                ])
            else:
                # NOTE(review): the 128x128 centre crop is taken from an image
                # whose shorter side is 224, so it covers only the middle part.
                self.transform = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(128),
                    T.ToTensor(),
                ])

        for idx, cls in enumerate(classes):
            class_folder = os.path.join(root_dir, cls)
            if not os.path.isdir(class_folder):
                continue
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            for img_name in sorted(os.listdir(class_folder)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(class_folder, img_name), idx))

    def __len__(self):
        """Return the number of image samples found."""
        return len(self.samples)

    def __getitem__(self, idx, retry=0):
        """Load sample ``idx`` and return ``(transformed_image, label)``.

        If loading or transforming fails, the error is printed and a randomly
        chosen replacement sample is tried instead, up to three times.

        Raises:
            RuntimeError: If the initial load and all three replacements fail.
        """
        img_path, label = self.samples[idx]
        try:
            # Context manager closes the underlying file once the RGB copy exists.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            if retry < 3:
                return self.__getitem__(random.randint(0, len(self)-1), retry=retry+1)
            # Chain the last failure so the root cause stays visible.
            raise RuntimeError("Too many failed image loads.") from e

### Dataset class 2

In [None]:
class ImageDataset2(Dataset):
    """Image-folder dataset that serves one side of a seeded train/validation split.

    All images under ``root_dir/<class>/`` are collected, shuffled with
    ``seed`` and cut at ``split_ratio``. An instance built with
    ``is_train=True`` holds the first part, one built with ``is_train=False``
    holds the remainder. Two instances created with the same ``root_dir``,
    ``classes``, ``split_ratio`` and ``seed`` are therefore complementary.
    """

    # Accepted file extensions (compared case-insensitively). The leading dot
    # keeps names such as "fakejpg" from being treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, classes, transform=None, is_train=True, split_ratio=0.8, seed=42):
        """
        Args:
            root_dir (str): Directory with all the class folders
            classes (list): List of class names (subfolder names)
            transform (callable, optional): Optional transform to be applied on a sample
            is_train (bool): Whether this is training data or not
            split_ratio (float): Ratio for training data (default is 0.8)
            seed (int): Seed for reproducibility
        """
        self.root_dir = root_dir
        self.classes = classes
        self.transform = transform
        self.is_train = is_train
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.samples = []

        all_samples = []
        for idx, cls in enumerate(classes):
            class_folder = os.path.join(root_dir, cls)
            if not os.path.isdir(class_folder):
                continue
            # Sort the listing: os.listdir order is arbitrary, and the seeded
            # shuffle below only gives the same split when it starts from the
            # same order in the train and the validation instance.
            for img_name in sorted(os.listdir(class_folder)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    all_samples.append((os.path.join(class_folder, img_name), idx))

        # Shuffle and split once. A private Random instance yields the same
        # permutation as seeding the global generator, without resetting the
        # global `random` state for the rest of the program.
        random.Random(seed).shuffle(all_samples)
        split_point = int(len(all_samples) * split_ratio)
        if is_train:
            self.samples = all_samples[:split_point]
        else:
            self.samples = all_samples[split_point:]

        # Set default transforms if not provided
        if self.transform is None:
            if is_train:
                self.transform = T.Compose([
                    T.RandomResizedCrop(128),  # random crop, resized to 128x128
                    T.ToTensor(),
                ])
            else:
                # NOTE(review): the 128x128 centre crop is taken from an image
                # whose shorter side is 224, so it covers only the middle part.
                self.transform = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(128),
                    T.ToTensor(),
                ])

    def __len__(self):
        """Return the number of samples on this side of the split."""
        return len(self.samples)

    def __getitem__(self, idx, retry=0):
        """Load sample ``idx`` and return ``(transformed_image, label)``.

        If loading or transforming fails, the error is printed and a randomly
        chosen replacement sample is tried instead, up to three times.

        Raises:
            RuntimeError: If the initial load and all three replacements fail.
        """
        img_path, label = self.samples[idx]
        try:
            # Context manager closes the underlying file once the RGB copy exists.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            if retry < 3:
                return self.__getitem__(random.randint(0, len(self)-1), retry=retry+1)
            # Chain the last failure so the root cause stays visible.
            raise RuntimeError("Too many failed image loads.") from e

In [None]:
# Class names (also the sub-folder names). The position in this list is the
# integer label produced by the dataset classes, so the order must not change.
classes = ['Badminton', 'Cricket', 'Tennis', 'Swimming', 'Soccer', 'Wrestling', 'Karate']

# Modelling

### Model 1: Simple CNN1

In [None]:
class Simplenet1(nn.Module):
    """Four-stage CNN with a global-average-pooled MLP head.

    Each stage is Conv3x3(padding=1) -> BatchNorm -> ReLU -> MaxPool(2), so the
    spatial size halves per stage (128 -> 64 -> 32 -> 16 -> 8 for a 128x128
    input) while the channels grow 3 -> 64 -> 128 -> 256 -> 512. The adaptive
    average pool in the head makes the network independent of the input size.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        # Stages are appended in order so the Sequential indices stay
        # features.0 ... features.15.
        channel_plan = [3, 64, 128, 256, 512]
        stage_layers = []
        for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:]):
            stage_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
        self.features = nn.Sequential(*stage_layers)

        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),  # (C, H, W) -> (512, 1, 1)
            nn.Flatten(),                  # -> 512
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),   # class logits
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self.features(x))

### Model 2: Simple CNN2

In [None]:
class Simplenet2(nn.Module):
    """VGG-style CNN: four double-conv stages, one single-conv stage, MLP head.

    Every conv unit is Conv3x3(padding=1, no bias) -> BatchNorm -> LeakyReLU and
    every stage ends with MaxPool(2). For a 128x128 input the feature map
    shrinks 128 -> 64 -> 32 -> 16 -> 8 -> 4, ending at 256 channels, which is
    the 256 * 4 * 4 = 4096 features the first Linear layer expects.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        def conv_unit(in_ch, out_ch):
            # One Conv -> BN -> LeakyReLU triple; bias is redundant before BN.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True),
            ]

        # Layers are appended in order so indices stay features.0 ... features.31.
        layers = []
        previous = 3
        for width in (64, 128, 256, 512):
            layers += conv_unit(previous, width)
            layers += conv_unit(width, width)
            layers.append(nn.MaxPool2d(2))
            previous = width
        # Final stage narrows 512 -> 256 channels with a single conv unit.
        layers += conv_unit(512, 256)
        layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Flatten(),                 # 256 x 4 x 4 -> 4096
            nn.Linear(256 * 4 * 4, 512),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),  # class logits
        )

    def forward(self, x):
        """Return class logits for a batch of 128x128 images."""
        return self.classifier(self.features(x))

### Model 3: Simple CNN3

In [None]:
class SimpleNet3(nn.Module):
    """Small strided CNN with a funnel-shaped MLP head.

    The convolutions use stride 2 and no padding, the poolings use a 3x3 window
    with stride 2. For a 128x128 input the spatial size goes
    128 -> 63 -> 31 -> 15 -> 7 -> 3 with channels 3 -> 16 -> 32 -> 64, giving
    64 * 3 * 3 = 576 features for the head, which halves its width
    (576 -> 288 -> ... -> 18) before the output layer.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        backbone = [
            nn.Conv2d(3, 16, kernel_size=3, stride=2),    # 128 -> 63, 16 channels
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),        # 63 -> 31
            nn.Conv2d(16, 32, kernel_size=3, stride=2),   # 31 -> 15, 32 channels
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),        # 15 -> 7
            nn.Conv2d(32, 64, kernel_size=3, stride=2),   # 7 -> 3, 64 channels
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*backbone)

        # Hidden layers are created in order so indices stay classifier.1 ... classifier.11.
        funnel = [576, 288, 144, 72, 36, 18]
        head = [nn.Flatten()]
        for wide, narrow in zip(funnel[:-1], funnel[1:]):
            head.append(nn.Linear(wide, narrow))
            head.append(nn.ReLU(inplace=True))
        head.append(nn.Linear(funnel[-1], num_classes))  # class logits
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of 128x128 images."""
        return self.classifier(self.features(x))

### Model 4: Simple CNN4

In [None]:
class SimpleNet4(nn.Module):
    """Four-conv CNN with a deep, dropout-regularised MLP head.

    All convolutions and poolings use 3x3 windows with padding 1. For a
    128x128 input the spatial size goes 128 -> 64 -> 64 -> 32 -> 16 -> 8 -> 4
    (pool, conv, pool, strided conv, pool, strided conv) while the channels
    grow 3 -> 32 -> 64 -> 128 -> 256, giving 256 * 4 * 4 = 4096 features.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        # (in_channels, out_channels, conv stride, followed by max-pool?)
        conv_plan = [
            (3, 32, 1, True),
            (32, 64, 1, True),
            (64, 128, 2, True),
            (128, 256, 2, False),
        ]
        backbone = []
        for in_ch, out_ch, stride, pooled in conv_plan:
            backbone.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1))
            backbone.append(nn.BatchNorm2d(out_ch))
            backbone.append(nn.ReLU(inplace=True))
            if pooled:
                backbone.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.features = nn.Sequential(*backbone)

        # Dropout is applied only around the first (widest) Linear layer.
        head = [
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        ]
        widths = [1024, 512, 256, 128, 64, 32, 16]
        for wide, narrow in zip(widths[:-1], widths[1:]):
            head.append(nn.Linear(wide, narrow))
            head.append(nn.ReLU(inplace=True))
        head.append(nn.Linear(widths[-1], num_classes))  # class logits
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of 128x128 images."""
        return self.classifier(self.features(x))

### Pretrained models

In [None]:
def freeze_all_but_last_n(model, n=2):
    """Freeze every parameter except those of the last ``n`` parameterised layers.

    A "layer" is a module that owns parameters directly (e.g. Conv2d,
    BatchNorm2d, Linear). Container modules such as Sequential or residual
    blocks are not counted, so ``n`` always refers to individual layers in
    ``model.modules()`` order.

    Args:
        model: The ``nn.Module`` to modify in place.
        n: Number of trailing parameterised layers to keep trainable. Values
            of 0 or less freeze the whole model.

    Returns:
        The same ``model`` object, for call chaining.
    """
    for param in model.parameters():
        param.requires_grad = False

    # Guard explicitly: a slice of [-0:] would select *every* layer and
    # silently unfreeze the whole network.
    if n <= 0:
        return model

    # Only modules that hold parameters themselves; containers would otherwise
    # take up slots and unfreeze all of their children at once.
    layers_with_params = [
        module for module in model.modules()
        if next(module.parameters(recurse=False), None) is not None
    ]

    # Unfreeze last n layers with parameters
    for module in layers_with_params[-n:]:
        for param in module.parameters(recurse=False):
            param.requires_grad = True

    return model


def print_trainable_params(model):
    """Print how many parameter elements of ``model`` currently require gradients."""
    print("Trainable Parameters:")
    trainable_count = sum(
        tensor.numel() for tensor in model.parameters() if tensor.requires_grad
    )
    print(f"Total Trainable Parameters: {trainable_count}")


#### ResNet18

In [None]:
# Pretrained ResNet18 with a fresh 7-way output layer.
resnet18 = models.resnet18(weights='DEFAULT')
resnet18.fc = nn.Linear(resnet18.fc.in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for ResNet18:")
print_trainable_params(resnet18)
resnet18 = freeze_all_but_last_n(resnet18, 2)  # Keep only the last 2 parameter-holding modules trainable
print("Trainable parameters after freezing for ResNet18:")
print_trainable_params(resnet18)
resnet18 = resnet18.to(device)

#### ResNet34

In [None]:
# Pretrained ResNet34 with a fresh 7-way output layer.
resnet34 = models.resnet34(weights='DEFAULT')
resnet34.fc = nn.Linear(resnet34.fc.in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for ResNet34:")
print_trainable_params(resnet34)
resnet34 = freeze_all_but_last_n(resnet34, 2)  # Keep only the last 2 parameter-holding modules trainable
print("Trainable parameters after freezing for ResNet34:")
print_trainable_params(resnet34)
resnet34 = resnet34.to(device)

#### ResNet50

In [None]:
# Pretrained ResNet50 with a fresh 7-way output layer.
resnet50 = models.resnet50(weights='DEFAULT')
resnet50.fc = nn.Linear(resnet50.fc.in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for ResNet50:")
print_trainable_params(resnet50)
resnet50 = freeze_all_but_last_n(resnet50, 2)  # Keep only the last 2 parameter-holding modules trainable
print("Trainable parameters after freezing for ResNet50:")
print_trainable_params(resnet50)
resnet50 = resnet50.to(device)

#### ResNet101

In [None]:
# Pretrained ResNet101 with a fresh 7-way output layer.
resnet101 = models.resnet101(weights='DEFAULT')
resnet101.fc = nn.Linear(resnet101.fc.in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for ResNet101:")
print_trainable_params(resnet101)
resnet101 = freeze_all_but_last_n(resnet101, 2)  # Keep only the last 2 parameter-holding modules trainable
print("Trainable parameters after freezing for ResNet101:")
print_trainable_params(resnet101)
resnet101 = resnet101.to(device)

#### ResNet152

In [None]:
# Pretrained ResNet152 with a fresh 7-way output layer.
resnet152 = models.resnet152(weights='DEFAULT')
resnet152.fc = nn.Linear(resnet152.fc.in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for ResNet152:")
print_trainable_params(resnet152)
resnet152 = freeze_all_but_last_n(resnet152, 2)  # Keep only the last 2 parameter-holding modules trainable
print("Trainable parameters after freezing for ResNet152:")
print_trainable_params(resnet152)
resnet152 = resnet152.to(device)

#### VGG16

In [None]:
# Pretrained VGG16; the output layer is the Linear at index 6 of `classifier`.
vgg16 = models.vgg16(weights='DEFAULT')
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for VGG16:")
print_trainable_params(vgg16)
vgg16 = freeze_all_but_last_n(vgg16, 1)  # Keep only the last parameter-holding module (the new output layer) trainable
print("Trainable parameters after freezing for VGG16:")
print_trainable_params(vgg16)
vgg16 = vgg16.to(device)

#### AlexNet

In [None]:
# Pretrained AlexNet; the output layer is the Linear at index 6 of `classifier`.
alexnet = models.alexnet(weights='DEFAULT')
alexnet.classifier[6] = nn.Linear(alexnet.classifier[6].in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for AlexNet:")
print_trainable_params(alexnet)
alexnet = freeze_all_but_last_n(alexnet, 1)  # Keep only the last parameter-holding module (the new output layer) trainable
print("Trainable parameters after freezing for AlexNet:")
print_trainable_params(alexnet)
alexnet = alexnet.to(device)

#### GoogLeNet

In [None]:
# Pretrained GoogLeNet with a fresh 7-way output layer.
googlenet = models.googlenet(weights='DEFAULT')
googlenet.fc = nn.Linear(googlenet.fc.in_features, 7)  # Change the output layer to match the number of classes

print("Trainable parameters before freezing for GoogLeNet:")
print_trainable_params(googlenet)
googlenet = freeze_all_but_last_n(googlenet, 2)  # Keep only the last 2 parameter-holding modules trainable
print("Trainable parameters after freezing for GoogLeNet:")
print_trainable_params(googlenet)
googlenet = googlenet.to(device)

# Using Optuna for different hyperparameter combinations

In [None]:
def load_model(model_name, num_classes=7):
    """Build the model identified by ``model_name`` and move it to ``device``.

    Pretrained torchvision models get a new output layer with ``num_classes``
    outputs and are frozen except for their last layers (see
    ``freeze_all_but_last_n``). The "simplenet*" models are created untrained.

    Args:
        model_name: One of "resnet18", "resnet34", "resnet50", "resnet101",
            "resnet152", "vgg16", "alexnet", "googlenet", "simplenet1",
            "simplenet2", "simplenet3", "simplenet4".
        num_classes: Size of the output layer (default 7).

    Returns:
        The model, placed on the module-level ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the names listed above.
    """
    # name -> (attribute holding the output layer, n for freeze_all_but_last_n)
    pretrained_specs = {
        "resnet18": ("fc", 2),
        "resnet34": ("fc", 2),
        "resnet50": ("fc", 2),
        "resnet101": ("fc", 2),
        "resnet152": ("fc", 2),
        "vgg16": ("classifier", 1),
        "alexnet": ("classifier", 1),
        "googlenet": ("fc", 2),
    }

    if model_name in pretrained_specs:
        head_attr, n_trainable = pretrained_specs[model_name]
        # The torchvision constructor has the same name as the model key.
        net = getattr(models, model_name)(weights='DEFAULT')
        if head_attr == "fc":
            net.fc = nn.Linear(net.fc.in_features, num_classes)
        else:
            # VGG16 / AlexNet: the output layer is element 6 of `classifier`.
            net.classifier[6] = nn.Linear(net.classifier[6].in_features, num_classes)
        net = freeze_all_but_last_n(net, n_trainable)
        return net.to(device)

    # Lambdas defer construction until the requested model is known.
    scratch_builders = {
        "simplenet1": lambda: Simplenet1(num_classes=num_classes),
        "simplenet2": lambda: Simplenet2(num_classes=num_classes),
        "simplenet3": lambda: SimpleNet3(num_classes=num_classes),
        "simplenet4": lambda: SimpleNet4(num_classes=num_classes),
    }
    builder = scratch_builders.get(model_name)
    if builder is None:
        raise ValueError(f"Model {model_name} not recognized. Please choose a valid model name.")

    # Scratch models are placed on the device too, matching the pretrained branch.
    return builder().to(device)

In [None]:
def get_dataloaders(config, transform=None):
    """Create the training and validation DataLoaders selected by ``config``.

    Args:
        config: Mapping with "dataset_class" and "batch_size". With
            "ImageClass1" the train and test directories are used as-is; any
            other value splits the train directory 80/20 with seed 42.
        transform: Optional transform handed to both datasets; when None each
            dataset falls back to its own default.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    use_separate_dirs = config["dataset_class"] == "ImageClass1"
    batch_size = config["batch_size"]

    if use_separate_dirs:
        # Validation data comes from the separate test directory.
        train_set = ImageDataset1(root_dir=train_dir, transform=transform, classes=classes, is_train=True)
        val_set = ImageDataset1(root_dir=test_dir, transform=transform, classes=classes, is_train=False)
    else:
        # Both halves come from the train directory; identical arguments make
        # the two instances complementary.
        shared_args = dict(
            root_dir=train_dir,
            classes=classes,
            transform=transform,
            split_ratio=0.8,
            seed=42,
        )
        train_set = ImageDataset2(is_train=True, **shared_args)
        val_set = ImageDataset2(is_train=False, **shared_args)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    return train_loader, val_loader

def validate_model(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    The model is switched to eval mode and left in it. Batches are moved to
    the device of the model's first parameter.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches, and the
        accuracy in percent over all samples.
    """
    model.eval()
    target_device = next(model.parameters()).device

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(target_device)
            batch_labels = batch_labels.to(target_device)

            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item()

            # Index of the largest logit per row is the predicted class.
            predictions = logits.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()

    mean_loss = loss_sum / len(val_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

In [None]:
def initialize_weights(model, method):
    """Re-initialise the weights of all Conv2d and Linear layers in place.

    Args:
        model: Module whose layers are re-initialised.
        method: "xavier" for Xavier-uniform, "kaiming" for Kaiming-uniform
            (ReLU gain). Any other value leaves the model untouched.

    Biases are not modified.
    """
    if method == "xavier":
        def apply_init(weight):
            nn.init.xavier_uniform_(weight)
    elif method == "kaiming":
        def apply_init(weight):
            nn.init.kaiming_uniform_(weight, nonlinearity='relu')
    else:
        # Unknown method: nothing to do.
        return

    for layer in model.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            apply_init(layer.weight)

def should_initialize(model_type):
    """Return True only for models trained from scratch (``model_type == "scratch"``)."""
    is_scratch = model_type == "scratch"
    return is_scratch


In [None]:
import datetime


def train_model(config):
    """Train the model described by ``config`` and return its best validation accuracy.

    Args:
        config: Mapping with the keys
            "model_choice" (``(model_type, model_name)`` tuple),
            "init_method" ("xavier", "kaiming" or "default"),
            "optimizer" ("adam", "sgd"; anything else selects RMSprop),
            "lr", "epochs", and "batch_size" / "dataset_class" (also read by
            ``get_dataloaders``).

    Returns:
        The highest validation accuracy (percent) seen in any epoch, or 0.0
        when no epoch ran.

    Side effects:
        Writes checkpoints (every ``save_interval`` epochs) and a final
        ``metrics.pt`` into a run directory below
        ``/kaggle/working/logs/checkpoints``.
    """
    # The file-level `import datetime` binds the module, so `datetime.now()`
    # would raise AttributeError; bind the class locally instead.
    from datetime import datetime

    model_type, model_name = config["model_choice"]
    train_loader, val_loader = get_dataloaders(config)

    init_method = config["init_method"]

    model = load_model(model_name)
    model.to(device)
    # Pretrained weights must not be overwritten; only scratch models are re-initialised.
    if should_initialize(model_type) and init_method != "default":
        initialize_weights(model, init_method)

    # Optimizer
    if config["optimizer"] == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    elif config["optimizer"] == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    else:
        optimizer = torch.optim.RMSprop(model.parameters(), lr=config["lr"])

    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    best_val_loss = float('inf')
    patience_counter = 0
    early_stop_patience = 20
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    epochs = config["epochs"]
    save_interval = 2 if epochs <= 6 else 50

    # The run directory is fixed once, before training. Previously its name was
    # only bound when a checkpoint was written, so runs that never reached a
    # checkpoint epoch failed when saving the metrics; computing the timestamp
    # once also keeps all files of a run in the same directory.
    # NOTE: the timestamp has hour resolution, so two runs with the same
    # configuration in the same hour share (and overwrite) a directory.
    time_stamp = datetime.now().strftime("%Y%m%d_%H")
    unique_config = f"{model_name}_{config['dataset_class']}_{config['optimizer']}_{config['init_method']}_{config['batch_size']}_{config['lr']}_time_{time_stamp}"
    save_dir = os.path.join("/kaggle/working/logs", "checkpoints")
    run_dir = os.path.join(save_dir, unique_config)
    os.makedirs(run_dir, exist_ok=True)

    total_batches = len(train_loader)
    total_steps = epochs * total_batches
    progress_bar = tqdm(total=total_steps, dynamic_ncols=True, desc="Training")

    try:
        for epoch in range(epochs):
            model.train()
            running_loss, correct, total = 0.0, 0, 0
            # Defaults keep the bookkeeping below valid for an empty loader.
            train_loss, train_acc = 0.0, 0.0
            for i, (inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                # Update tqdm with the running averages of this epoch
                train_loss = running_loss / (i + 1)
                train_acc = 100. * correct / total
                progress_bar.update(1)
                progress_bar.set_postfix({
                    "Epoch": f"{epoch+1}/{epochs}",
                    "Train Loss": f"{train_loss:.4f}",
                    "Train Acc": f"{train_acc:.2f}%"
                })

            # Validation phase
            val_loss, val_acc = validate_model(model, val_loader, criterion)
            scheduler.step(val_loss)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            # Save model checkpoint
            if (epoch + 1) % save_interval == 0:
                torch.save(model.state_dict(), os.path.join(run_dir, f"epoch_{epoch+1}.pt"))

            # Early stopping on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= early_stop_patience:
                    progress_bar.set_description("Early Stopping")
                    break
    finally:
        # Close the bar even when training aborts with an exception.
        progress_bar.close()

    # Save metrics
    torch.save({
        "train_losses": train_losses,
        "val_losses": val_losses,
        "train_accs": train_accs,
        "val_accs": val_accs
    }, os.path.join(run_dir, "metrics.pt"))

    return max(val_accs, default=0.0)


In [None]:
def objective1(trial):
    """Optuna objective for the from-scratch CNNs (simplenet1-4).

    Samples one hyperparameter combination from ``trial``, trains for 150
    epochs via ``train_model`` and returns the accuracy it reports.
    """
    pretrained_names = (
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "googlenet",
        "alexnet", "vgg16",
    )

    # Sampling order is fixed: model, init, optimizer, lr, batch size, dataset.
    chosen_model = trial.suggest_categorical("model_name", [
     "simplenet1", "simplenet2", "simplenet3", "simplenet4"
    ])
    chosen_init = trial.suggest_categorical("init_method", ["xavier", "kaiming", "default"])
    chosen_optimizer = trial.suggest_categorical("optimizer", ["adam", "sgd", "rmsprop"])
    chosen_lr = trial.suggest_categorical("lr", [0.1, 0.01, 0.001])
    chosen_batch = trial.suggest_categorical("batch_size", [16, 32, 64])
    chosen_dataset = trial.suggest_categorical("dataset_class", ["ImageClass1", "ImageClass2"])

    if chosen_model in pretrained_names:
        kind = "pretrained"
    else:
        kind = "scratch"

    config = {
        "model_choice": (kind, chosen_model),
        "optimizer": chosen_optimizer,
        "lr": chosen_lr,
        "batch_size": chosen_batch,
        "dataset_class": chosen_dataset,
        "epochs": 150,
        "init_method": chosen_init,
    }

    # The study maximises this value.
    return train_model(config)

def objective2(trial):
    """Optuna objective for the pretrained torchvision models.

    Samples one hyperparameter combination from ``trial``, fine-tunes for 8
    epochs via ``train_model`` and returns the accuracy it reports.
    """
    pretrained_names = (
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "googlenet",
        "alexnet", "vgg16",
    )

    # Sampling order is fixed: model, init, optimizer, lr, batch size, dataset.
    chosen_model = trial.suggest_categorical("model_name", [
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "googlenet",
        "alexnet", "vgg16"
    ])
    chosen_init = trial.suggest_categorical("init_method", ["xavier", "kaiming", "default"])
    chosen_optimizer = trial.suggest_categorical("optimizer", ["adam", "sgd", "rmsprop"])
    chosen_lr = trial.suggest_categorical("lr", [0.1, 0.01, 0.001])
    chosen_batch = trial.suggest_categorical("batch_size", [16, 32, 64])
    chosen_dataset = trial.suggest_categorical("dataset_class", ["ImageClass1", "ImageClass2"])

    if chosen_model in pretrained_names:
        kind = "pretrained"
    else:
        kind = "scratch"

    config = {
        "model_choice": (kind, chosen_model),
        "optimizer": chosen_optimizer,
        "lr": chosen_lr,
        "batch_size": chosen_batch,
        "dataset_class": chosen_dataset,
        "epochs": 8,
        "init_method": chosen_init,
    }

    # The study maximises this value.
    return train_model(config)



## To run on a device with a GPU

In [None]:
# Hyperparameter search over the from-scratch CNNs; objective1's return value is maximised.
study = optuna.create_study(direction="maximize")
study.optimize(objective1, n_trials=100)  # adjust trials as needed

In [None]:
# Hyperparameter search over the pretrained models; objective2's return value is maximised.
# NOTE: this rebinds `study`, so the study object from the previous cell is no
# longer reachable through this name.
study = optuna.create_study(direction="maximize")
study.optimize(objective2, n_trials=100)  # adjust trials as needed

# Functions to load saved weights and metrics

In [None]:
def load_model_weights(model, path):
    """Load a saved ``state_dict`` from ``path`` into ``model`` and return the model.

    Args:
        model: Module whose architecture matches the checkpoint.
        path: File written with ``torch.save(model.state_dict(), ...)``.

    Returns:
        The same ``model`` object with the loaded weights.
    """
    # Map tensors to the CPU while reading so a checkpoint saved on a GPU also
    # loads on a machine without one; load_state_dict then copies the values
    # into the model's existing parameters on whatever device they live on.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model

def load_metrics(path):
    """Return the object stored at ``path`` by ``torch.save`` (e.g. a metrics dict)."""
    metrics = torch.load(path)
    return metrics