In [None]:
%pip install kagglehub

In [None]:
from dataset.translate import *

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
from sklearn.metrics import classification_report, accuracy_score, recall_score, precision_score, f1_score
import wandb
import kagglehub

In [None]:
# Download the "alessiocorrado99/animals10" dataset through kagglehub. The returned
# location is the root under which the "raw-img" folder is walked when the file
# index is built later in this notebook.
path = kagglehub.dataset_download("alessiocorrado99/animals10")
print(path)

In [None]:
def translate_label(dirname: str, mapping: dict[str, str]) -> str:
    """Look up ``dirname`` in ``mapping`` and fall back to ``dirname`` itself.

    Args:
        dirname: Directory name to translate.
        mapping: Lookup table from original names to translated names.

    Returns:
        The mapped value when ``dirname`` is a key of ``mapping``; otherwise
        ``dirname`` unchanged.
    """
    if dirname in mapping:
        return mapping[dirname]
    return dirname

In [None]:
# Index every file below <path>/raw-img: one record per file, labelled with the
# translated name of the directory that contains it.
# NOTE(review): translate_it_to_en comes from the star import of dataset.translate;
# presumably an Italian -> English name mapping — confirm.
data = []

for dirpath, dirnames, filenames in os.walk(os.path.join(path, "raw-img")):
    # os.walk yields entries in arbitrary filesystem order. Sorting the directory
    # list in place (honoured by the default top-down walk) and the file list keeps
    # the row order stable, so the seeded train/val/test split is reproducible
    # across machines.
    dirnames.sort()
    if not filenames:
        continue
    translated_label = translate_label(os.path.basename(dirpath), translate_it_to_en)
    for filename in sorted(filenames):
        data.append({'label': translated_label, 'filepath': os.path.join(dirpath, filename)})

# Explicit columns keep the frame's schema intact even when no files were found.
df = pd.DataFrame(data, columns=['label', 'filepath'])

In [None]:
# Number the classes in sorted label order so the label -> index mapping does not
# depend on the order in which rows happen to appear in df.
# NOTE(review): indices are stored as strings; presumably ImageDataset converts
# them to integer class ids before they reach the loss — confirm.
labelEncoding = {label: str(idx) for idx, label in enumerate(sorted(df['label'].unique()))}
df['label_encoded'] = df['label'].map(labelEncoding)
# Show the first row of every class as a quick sanity check of the encoding.
df.groupby('label_encoded').first().head(df['label'].nunique())

In [None]:
def split_data(df: pd.DataFrame, test_size: float, val_size: float, random_state: int):
    """Split ``df`` into stratified train, validation and test parts.

    The test part is carved off first; the validation part is then taken from the
    remainder. ``val_size`` is expressed as a fraction of the *whole* frame, so it
    is rescaled to a fraction of what is left after removing the test part.

    Args:
        df: Frame with ``'filepath'`` and ``'label_encoded'`` columns.
        test_size: Fraction of all rows assigned to the test part.
        val_size: Fraction of all rows assigned to the validation part.
        random_state: Seed forwarded to both ``train_test_split`` calls.

    Returns:
        A 6-tuple ``(train_x, train_y, val_x, val_y, test_x, test_y)`` where each
        ``*_x`` is the ``'filepath'`` column and each ``*_y`` the
        ``'label_encoded'`` column of the corresponding part.
    """
    remainder, held_out = train_test_split(
        df,
        test_size=test_size,
        stratify=df['label_encoded'],
        random_state=random_state,
    )

    # Fraction of the remainder that corresponds to val_size of the full frame.
    val_fraction = val_size / (1 - test_size)

    training, validation = train_test_split(
        remainder,
        test_size=val_fraction,
        stratify=remainder['label_encoded'],
        random_state=random_state,
    )

    columns = []
    for part in (training, validation, held_out):
        columns.append(part['filepath'])
        columns.append(part['label_encoded'])
    return tuple(columns)

In [None]:
# 20% of all rows go to test and 10% to validation (split_data rescales the latter
# to 0.1 / (1 - 0.2) = 12.5% of the remaining 80%), leaving 70% for training.
# The fixed random_state is forwarded to both underlying splits.
train_x, train_y, val_x, val_y, test_x, test_y = split_data(df, test_size=0.2, val_size=0.1, random_state=24)

In [None]:
from image_dataset import ImageDataset

# Shared preprocessing: resize every image to 128x128 and convert it to a tensor.
tr = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.ToTensor(),
])

# Train on the training split (this was previously built from val_x/val_y, which
# left the training split unused). Only the training loader is shuffled.
train_dataset = ImageDataset(train_x, train_y, transform=tr)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

# Validation split: intended for model selection so the test split stays untouched.
val_dataset = ImageDataset(val_x, val_y, transform=tr)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)

# Held-out test split for the final evaluation.
test_dataset = ImageDataset(test_x, test_y, transform=tr)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
class baseCNN(nn.Module):
    """Four-stage convolutional classifier with a small fully connected head.

    Each stage is produced by :meth:`conv_block`. The stages are followed by
    global average pooling, a 128-unit hidden layer with ReLU and dropout, and a
    final linear layer that outputs ``num_classes`` logits.

    Every stage ends in ``MaxPool2d(3)``, which shrinks height and width by a
    factor of three (rounded down), so inputs must be at least 81x81 pixels for
    all four stages to produce a non-empty feature map.

    Args:
        num_classes: Number of output logits.
        conv_dropouts: Dropout probabilities, one per stage. Defaults to
            ``[0.2, 0.2, 0.2, 0.25]``. Only the first four entries are used.
        linear_dropout: Dropout probability applied before the final linear layer.
        filters: Output channel counts, one per stage. Defaults to
            ``[32, 128, 256, 512]``. Only the first four entries are used.

    Raises:
        ValueError: If ``filters`` or ``conv_dropouts`` has fewer than four entries.
    """

    def __init__(self, num_classes, conv_dropouts=None, linear_dropout=0.2, filters=None):
        super().__init__()

        # Default architecture
        if filters is None:
            filters = [32, 128, 256, 512]
        if conv_dropouts is None:
            conv_dropouts = [0.2, 0.2, 0.2, 0.25]
        if len(filters) < 4 or len(conv_dropouts) < 4:
            raise ValueError(
                "filters and conv_dropouts need one entry per convolutional stage (4); "
                f"got {len(filters)} and {len(conv_dropouts)}"
            )

        # Channels entering each stage: the 3-channel input, then the previous
        # stage's output channels.
        in_channels = [3, *filters[:3]]
        stages = [
            self.conv_block(c_in, c_out, drop)
            for c_in, c_out, drop in zip(in_channels, filters[:4], conv_dropouts[:4])
        ]

        # Layer order (and therefore the state_dict key layout) is: four stages,
        # pooling, flatten, hidden linear, ReLU, dropout, output linear.
        self.model = nn.Sequential(
            *stages,

            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(filters[3], 128),
            nn.ReLU(),
            nn.Dropout(linear_dropout),
            nn.Linear(128, num_classes)
        )

    def conv_block(self, in_chanel, out_chanel, drop):
        """Build one stage: two (3x3 conv, batch norm, ReLU) groups, then pooling and dropout.

        Args:
            in_chanel: Number of channels entering the stage.
            out_chanel: Number of channels produced by both convolutions.
            drop: Dropout probability applied after pooling.

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        return nn.Sequential(
            nn.Conv2d(in_chanel, out_chanel, 3, padding=1),
            nn.BatchNorm2d(out_chanel),
            nn.ReLU(),

            nn.Conv2d(out_chanel, out_chanel, 3, padding=1),
            nn.BatchNorm2d(out_chanel),
            nn.ReLU(),

            nn.MaxPool2d(3),  # Note: using 3 instead of 2
            nn.Dropout(drop)
        )

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.model(x)

    def _run_epoch(self, dataloader, criterion, device, optimizer=None):
        """Make one pass over ``dataloader``; update weights only if ``optimizer`` is given.

        Returns:
            ``(mean_loss, accuracy)`` averaged over all samples seen.

        Raises:
            ValueError: If the dataloader yields no samples.
        """
        training = optimizer is not None
        # train(True) enables dropout/batch-norm updates; train(False) is eval mode.
        self.train(training)

        total_loss = 0.0
        total_correct = 0
        total = 0

        # Gradients are only tracked when the pass actually updates the weights.
        with torch.set_grad_enabled(training):
            for features, targets in dataloader:
                features, targets = features.to(device), targets.to(device)

                if training:
                    optimizer.zero_grad()
                logits = self(features)
                loss = criterion(logits, targets)
                if training:
                    loss.backward()
                    optimizer.step()

                # Weight the batch loss by its size so the epoch figure is a
                # per-sample average (assumes criterion returns a batch mean —
                # TODO confirm for non-default reductions).
                batch_size = targets.size(0)
                total_loss += loss.item() * batch_size
                preds = logits.argmax(dim=1)
                total_correct += (preds == targets).sum().item()
                total += batch_size

        if total == 0:
            raise ValueError("dataloader yielded no samples; cannot compute loss and accuracy")
        return total_loss / total, total_correct / total

    def train_epoch(self, dataloader: DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer, device: str):
        """Train for one epoch and return ``(mean_loss, accuracy)`` on the training data.

        Args:
            dataloader: Yields ``(features, targets)`` batches.
            criterion: Loss called as ``criterion(logits, targets)``.
            optimizer: Optimizer stepped once per batch.
            device: Target passed to ``Tensor.to`` for every batch.

        Raises:
            ValueError: If the dataloader yields no samples.
        """
        return self._run_epoch(dataloader, criterion, device, optimizer=optimizer)

    def evaluate(self, dataloader: DataLoader, criterion: nn.Module, device: str):
        """Evaluate in eval mode without gradients and return ``(mean_loss, accuracy)``.

        Args:
            dataloader: Yields ``(features, targets)`` batches.
            criterion: Loss called as ``criterion(logits, targets)``.
            device: Target passed to ``Tensor.to`` for every batch.

        Raises:
            ValueError: If the dataloader yields no samples.
        """
        return self._run_epoch(dataloader, criterion, device)


In [None]:
# def train(run: wandb.Run, model: nn.Module, optimizer: torch.optim.Optimizer, criterion: nn.Module, train_loader: DataLoader, num_epochs=8):
#     loss_vals=  []

#     for epoch in range(num_epochs):
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         epoch_loss= []
#         for _, (data, targets) in enumerate(train_loader):
#             data = data.to(device=device)
#             targets = targets.to(device=device)

#             scores = model(data)
#             loss = criterion(scores, targets)

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()
#             _, predicted = scores.max(1)
#             total += targets.size(0)
#             correct += predicted.eq(targets).sum().item()

#         epoch_loss = running_loss / len(train_loader)
#         epoch_acc = 100. * correct / total
#         loss_vals.append(epoch_loss)
#         run.log({"Train Loss": epoch_loss, "Train Accuracy": epoch_acc, "epoch": epoch})

#         print(epoch, "Current Loss:", loss , "Acc:" , epoch_acc )
#     return loss_vals

# def evaluate(loader, model):
#     """
#         @returns: (all_preds, all_targets)
#     """
#     all_preds = []
#     all_targets = []
#     running_loss = 0.0
#     correct = 0
#     total = 0

#     model.eval()
#     criterion = nn.CrossEntropyLoss()

#     with torch.no_grad():
#         for x, y in loader:
#             x = x.to(device=device)
#             y = y.to(device=device)

#             scores = model(x)
#             loss = criterion(scores, y)
#             _, pred = scores.max(1)

#             all_preds.extend(pred.cpu().numpy())
#             all_targets.extend(y.cpu().numpy())

#             running_loss += loss.item()
#             total += y.size(0)
#             correct += pred.eq(y).sum().item()

#     accuracy = 100. * correct / total
#     avg_loss = running_loss / len(loader)

#     return all_preds, all_targets, accuracy, avg_loss

In [None]:
def train_with_config(run: wandb.Run, config, train_loader, val_loader, device, num_classes=10, default_epochs=8):
    """Train a fresh ``baseCNN`` with one hyper-parameter configuration.

    Args:
        run: W&B run that receives one ``run.log`` call per epoch.
        config: Mapping of hyper-parameters. ``'lr'`` is required; ``'epochs'``,
            ``'conv_dropouts'``, ``'linear_dropout'`` and ``'filters'`` are optional.
        train_loader: Loader used for the weight updates.
        val_loader: Loader used to score every epoch and pick the best weights.
        device: Device the model and batches are moved to.
        num_classes: Number of output classes of the model.
        default_epochs: Number of epochs used when ``config`` has no ``'epochs'`` entry.

    Returns:
        ``(best_state, best_val_loss)``: a CPU copy of the model's ``state_dict``
        from the epoch with the lowest validation loss, and that loss.

    Raises:
        ValueError: If the resolved number of epochs is smaller than one.
        RuntimeError: If no epoch produced a validation loss below infinity
            (for example when the loss is NaN throughout).
    """
    model = baseCNN(
        num_classes=num_classes,
        conv_dropouts=config.get('conv_dropouts', [0.2, 0.2, 0.2, 0.25]),
        linear_dropout=config.get('linear_dropout', 0.2),
        # Forward the optional architecture override; None selects the model default.
        filters=config.get('filters'),
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.CrossEntropyLoss()

    # Read the epoch count through the mapping interface like every other setting,
    # so plain dict configs work as well.
    epochs = int(config.get('epochs', default_epochs))
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    best_state = None
    best_val_loss = float('inf')

    for epoch in range(epochs):
        train_loss, train_acc = model.train_epoch(train_loader, criterion, optimizer, device)
        # evaluate() switches to eval mode and disables gradients itself.
        val_loss, val_acc = model.evaluate(val_loader, criterion, device)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot on the CPU so later epochs cannot overwrite the best weights.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        run.log({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss, "val_accuracy": val_acc, "train_accuracy": train_acc})

    if best_state is None:
        raise RuntimeError("validation loss never dropped below infinity; no best model state was recorded")

    return best_state, best_val_loss

In [None]:
# Hyper-parameter configurations tried by the sweep, one dict per training run.
# Entries prefixed with "disabled" are kept for reference but are not run.
configs = [
    # disabled: dict(lr=0.001, conv_dropouts=[0.2, 0.2, 0.2, 0.25], linear_dropout=0.2)
    dict(lr=0.0005, conv_dropouts=[0.2, 0.2, 0.2, 0.25], linear_dropout=0.2),

    # disabled: dict(lr=0.001, conv_dropouts=[0.1, 0.1, 0.15, 0.2], linear_dropout=0.1)
    dict(lr=0.001, conv_dropouts=[0.3, 0.3, 0.35, 0.4], linear_dropout=0.3),

    # Narrower network; dropout settings are left to the defaults.
    dict(lr=0.001, filters=[32, 64, 128, 256]),
]

In [None]:
# The W&B API key must be supplied through the WANDB_API_KEY environment variable
# (set it outside the notebook); credentials must never be written into a cell.
# NOTE(review): a key was previously hard-coded here — treat it as leaked and
# revoke it in the W&B account settings.
api_key = os.environ.get('WANDB_API_KEY')
if api_key:
    wandb.login(key=api_key)
    print("Logged in to W&B successfully.")
else:
    raise ValueError("W&B API key not found in environment variables.")

In [None]:
# Train one model per configuration and remember the one with the lowest
# validation loss.
run = wandb.init(project="animal_classification", name="test_hyperparam")
best_state = None
best_val_loss = float('inf')
best_config = None

try:
    for config in configs:
        # Keep each run's result in its own names: unpacking straight into
        # best_state / best_val_loss would overwrite the running best and make
        # the comparison below always false.
        # NOTE(review): test_loader is passed as the validation loader, so the test
        # split drives model selection; prefer a loader built from val_x / val_y.
        state, val_loss = train_with_config(run, config, train_loader, test_loader, device)
        if val_loss < best_val_loss:
            best_state = state
            best_val_loss = val_loss
            best_config = config
finally:
    # Close the W&B run even when a configuration fails part-way through.
    run.finish()

print("Training complete.")
print(f"Best validation loss: {best_val_loss} for model with config: {best_config}")