# Model Training Notebook

This notebook trains three CNN architectures (ResNet50, DenseNet121, ConvNeXt-Tiny), each under four augmentation strategies (baseline, geometric, photometric, mixed).


In [1]:
import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
import time
import copy
import os
import random

In [2]:
# Display the installed PyTorch version as this cell's output.
torch.__version__

'2.9.0+cu128'

In [3]:
# Dataset root folders, given relative to the notebook's working directory.
input_path = Path("..", "data", "datasets", "trashnet")
input_path2 = Path("..", "data", "datasets", "trashvariety")

In [4]:
# Output directory for trained checkpoints; created up front so saving never
# fails on a missing folder.

model_dir = "../trained_models"
Path(model_dir).mkdir(parents=True, exist_ok=True)

def save_model(model, model_name, history=None):
    """Write a checkpoint for ``model`` to ``<model_dir>/<model_name>.pth``.

    The checkpoint is a dict with the key ``"model_state_dict"`` and, when
    ``history`` is given, an additional ``"history"`` key.

    Args:
        model: module whose ``state_dict()`` is stored.
        model_name: file name stem (``.pth`` is appended).
        history: optional training history to store next to the weights.
    """
    # Re-create the directory if needed so a long run does not fail at save
    # time when the folder was removed after the notebook started.
    os.makedirs(model_dir, exist_ok=True)
    save_path = os.path.join(model_dir, f"{model_name}.pth")

    checkpoint = {"model_state_dict": model.state_dict()}
    if history is not None:
        checkpoint["history"] = history

    torch.save(checkpoint, save_path)
    # Report the destination so each run in a multi-run loop can be traced
    # back to its file.
    print(f"Model saved to {save_path}")

In [5]:
# Global seed for the setup cells of this notebook.
seed = 64

# Ask cuDNN for deterministic behaviour and turn off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed PyTorch (CPU and every CUDA device), NumPy and Python's `random`.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

In [6]:
# Use the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
#calculate class weighted cross entropy loss

from collections import Counter

def compute_class_weights_from_imagefolder(train_root, num_classes):
    """Compute normalized inverse-frequency class weights for an image folder.

    Each class ``i`` gets the raw weight ``total / count_i``; the weights are
    then rescaled so that they sum to ``num_classes``.

    Args:
        train_root: directory passed to ``datasets.ImageFolder``.
        num_classes: number of class indices (``0 .. num_classes - 1``) to
            produce weights for.

    Returns:
        A tuple ``(weights, classes, counts)``: a float tensor of length
        ``num_classes``, the dataset's ``classes`` list, and a ``Counter``
        mapping class index to number of samples.

    Raises:
        ValueError: if any class index below ``num_classes`` has no samples
            (its inverse frequency would be undefined).
    """
    # Only the labels are read here, so no image transform is attached.
    ds = datasets.ImageFolder(train_root)
    counts = Counter(ds.targets)
    total = sum(counts.values())

    # Counter returns 0 for unseen keys; fail with a clear message instead of
    # a bare ZeroDivisionError further down.
    empty = [i for i in range(num_classes) if counts[i] == 0]
    if empty:
        raise ValueError(
            f"No training samples for class indices {empty} in {train_root}; "
            f"cannot compute inverse-frequency weights for num_classes={num_classes}"
        )

    weights = torch.tensor(
        [total / counts[i] for i in range(num_classes)], dtype=torch.float
    )
    weights = weights / weights.sum() * num_classes
    return weights, ds.classes, counts

# The number of classes is taken from the folder layout of the training split.
tmp_ds = datasets.ImageFolder(input_path / "train")
num_classes = len(tmp_ds.classes)

# Per-class weights from the training split, moved to the training device so
# the loss can use them.
class_weights, class_names, train_counts = compute_class_weights_from_imagefolder(
    train_root=input_path / "train", num_classes=num_classes
)
class_weights = class_weights.to(device)

# Cross-entropy loss weighted per class.
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights)

print("Classes:", class_names)
print("Train counts:", train_counts)
print("Class weights:", class_weights)

Classes: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
Train counts: Counter({3: 475, 1: 400, 4: 385, 2: 328, 0: 322, 5: 109})
Class weights: tensor([0.8270, 0.6657, 0.8119, 0.5606, 0.6917, 2.4431], device='cuda:0')


## Data augmentation transforms

In [8]:
# Baseline pipeline: random resized crop + horizontal flip for training,
# plain resize for evaluation.

# Per-channel normalization shared by every pipeline in this notebook.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]),
}

# The test split is evaluated with the same (non-augmenting) transform as val.
image_datasets = {
    split: datasets.ImageFolder(
        input_path / split,
        data_transforms["train" if split == "train" else "val"],
    )
    for split in ("train", "val", "test")
}

# Only the training loader shuffles.
dataloaders = {
    split: torch.utils.data.DataLoader(
        split_ds, batch_size=32, shuffle=(split == "train"), num_workers=4
    )
    for split, split_ds in image_datasets.items()
}

# Sizes for every split. "test" is included so this dict has the same keys as
# dataset_size_geo / dataset_size_photo / dataset_size_mix.
dataset_size = {split: len(split_ds) for split, split_ds in image_datasets.items()}

num_classes = len(image_datasets["train"].classes)

In [9]:
# Geometric augmentation: baseline crop/flip plus rotation and a small
# affine jitter (translate / scale / shear) on the training split.

data_transforms_geo = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.9, 1.1), shear=5),
        transforms.ToTensor(),
        normalize,
    ]),
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]),
}

# val and test share the non-augmenting "val" transform.
image_datasets_geo = {
    split: datasets.ImageFolder(
        input_path / split,
        data_transforms_geo["train" if split == "train" else "val"],
    )
    for split in ("train", "val", "test")
}

# Only the training loader shuffles.
dataloaders_geo = {
    split: torch.utils.data.DataLoader(
        split_ds, batch_size=32, shuffle=(split == "train"), num_workers=4
    )
    for split, split_ds in image_datasets_geo.items()
}

dataset_size_geo = {
    split: len(split_ds) for split, split_ds in image_datasets_geo.items()
}

num_classes_geo = len(image_datasets_geo["train"].classes)


In [10]:
# Photometric augmentation: baseline crop/flip plus colour jitter and
# occasional grayscale conversion on the training split.

data_transforms_photo = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        normalize,
    ]),
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]),
}

# val and test share the non-augmenting "val" transform.
image_datasets_photo = {
    split: datasets.ImageFolder(
        input_path / split,
        data_transforms_photo["train" if split == "train" else "val"],
    )
    for split in ("train", "val", "test")
}

# Only the training loader shuffles.
dataloaders_photo = {
    split: torch.utils.data.DataLoader(
        split_ds, batch_size=32, shuffle=(split == "train"), num_workers=4
    )
    for split, split_ds in image_datasets_photo.items()
}

dataset_size_photo = {
    split: len(split_ds) for split, split_ds in image_datasets_photo.items()
}

num_classes_photo = len(image_datasets_photo["train"].classes)


In [11]:
# Mixed augmentation: geometric + photometric transforms, Gaussian blur, and
# random erasing (applied on the normalized tensor) for the training split.

data_transforms_mix = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.9, 1.1), shear=5),
        transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
    ]),
    "val": transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]),
}

# val and test share the non-augmenting "val" transform.
image_datasets_mix = {
    split: datasets.ImageFolder(
        input_path / split,
        data_transforms_mix["train" if split == "train" else "val"],
    )
    for split in ("train", "val", "test")
}

# Only the training loader shuffles.
dataloaders_mix = {
    split: torch.utils.data.DataLoader(
        split_ds, batch_size=32, shuffle=(split == "train"), num_workers=4
    )
    for split, split_ds in image_datasets_mix.items()
}

dataset_size_mix = {
    split: len(split_ds) for split, split_ds in image_datasets_mix.items()
}

num_classes_mix = len(image_datasets_mix["train"].classes)


## Model Training function

In [12]:
def train_model(model, dataloaders, dataset_size, device,
                criterion, optimizer, num_epochs=25):
    """Train ``model`` and return it with its best-validation weights loaded.

    Every epoch runs a "train" phase followed by a "val" phase. Per-phase loss
    and accuracy are averaged using ``dataset_size[phase]``, printed, and
    appended to the history. The weights from the epoch with the highest
    validation accuracy are kept and loaded back into the model at the end.

    Args:
        model: network to optimize (already on ``device``).
        dataloaders: mapping with "train" and "val" iterables of
            ``(inputs, labels)`` batches.
        dataset_size: mapping with the number of samples for "train" and "val".
        device: device that each batch is moved to.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        num_epochs: number of epochs to run.

    Returns:
        ``(model, history)`` where ``history`` maps "train_loss", "train_acc",
        "val_loss" and "val_acc" to one value per epoch.
    """
    started_at = time.time()

    # Snapshot of the weights with the highest validation accuracy so far.
    best_state = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    history = {key: [] for key in ("train_loss", "train_acc", "val_loss", "val_acc")}

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print("-" * 10)

        for phase in ("train", "val"):
            is_training = phase == "train"
            if is_training:
                model.train()
            else:
                model.eval()

            loss_total = 0.0
            correct_total = 0.0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    preds = torch.max(outputs, 1).indices
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Scale the batch loss by the batch size so the division by
                # dataset_size below yields a per-sample average.
                loss_total += loss.item() * inputs.size(0)
                correct_total += (preds == labels).sum().item()

            epoch_loss = loss_total / dataset_size[phase]
            epoch_acc = correct_total / dataset_size[phase]

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            history[f"{phase}_loss"].append(epoch_loss)
            history[f"{phase}_acc"].append(epoch_acc)

            # Keep a copy of the weights whenever validation accuracy improves.
            if not is_training and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_state = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - started_at
    print(f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best val Acc: {best_acc:.4f}")

    model.load_state_dict(best_state)
    return model, history


## Multi-run training (5 seeds)

Run each architecture (ResNet50, DenseNet121, ConvNeXt-Tiny) under each augmentation strategy (baseline, photometric, geometric, mixed) with 5 different seeds. Each trained model is saved with its seed in the filename, and the best validation accuracy of every run is collected in `results`.

In [None]:
seeds = [64, 128, 256, 512, 1024]
archs = ['resnet50', 'densenet121', 'convnext_tiny']
# Augmentation name -> (dataloaders, dataset sizes) built in the cells above.
augs = {
    'baseline': (dataloaders, dataset_size),
    'photo': (dataloaders_photo, dataset_size_photo),
    'geo': (dataloaders_geo, dataset_size_geo),
    'mixed': (dataloaders_mix, dataset_size_mix),
}

# results[arch][aug_name] -> best validation accuracy of each seed, in `seeds` order.
results = {arch: {k: [] for k in augs.keys()} for arch in archs}


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _build_model(arch, num_classes):
    """Load ``arch`` with torchvision's default pretrained weights and replace
    its final linear layer with one that outputs ``num_classes`` logits.

    Raises:
        ValueError: if ``arch`` is not one of the supported architectures.
    """
    if arch == 'resnet50':
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif arch == 'densenet121':
        model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif arch == 'convnext_tiny':
        model = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
        in_features = model.classifier[2].in_features
        model.classifier[2] = nn.Linear(in_features, num_classes)
    else:
        raise ValueError(f"Unknown architecture: {arch}")
    return model


for seed in seeds:
    print(f"\n--- Starting run with seed {seed} ---")

    for aug_name, (dl, ds) in augs.items():
        for arch in archs:
            print(f"Training {arch} with {aug_name} augmentations (seed {seed})")

            # Reseed before every single run, not just once per seed. This way
            # a given (arch, aug, seed) run does not depend on how much
            # randomness earlier runs consumed, and it can be repeated alone.
            _seed_everything(seed)

            model = _build_model(arch, num_classes).to(device)
            optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

            model, history = train_model(
                model=model,
                dataloaders=dl,
                dataset_size=ds,
                device=device,
                criterion=criterion_weighted,
                optimizer=optimizer,
                num_epochs=25
            )

            save_model(model, f"{arch}_{aug_name}_seed{seed}", history)

            # Best validation accuracy of this run (None if nothing was recorded).
            best_val = max(history.get('val_acc') or [], default=None)
            results[arch][aug_name].append(best_val)

            # Drop references and release cached GPU memory before the next model.
            del model
            del optimizer
            torch.cuda.empty_cache()