In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# are reloaded before each cell runs; edits to the project's `src` package
# are then picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
import torch, sys
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from pathlib import Path
from tqdm import tqdm

# Check if running in Google Colab
IN_COLAB = 'google.colab' in sys.modules

# Set dataset path accordingly
if IN_COLAB:
    ! git clone --branch refactored https://github.com/MrKiwix/IAPR-project.git
    %cd IAPR-project
    from google.colab import drive
    drive.mount('/content/drive')
    ROOT_DIR = Path('/content/drive/MyDrive/IAPR')
else:
    ROOT_DIR = Path('./')

# Training Notebook

This notebook contains the function calls needed to start a training run.

It assumes the following:
- The training data is present at the expected location (`data/train` under the root directory)
- The CSV file with the label information (`data/train.csv`) has already been created

> Warning: training exports the best model weights and will therefore overwrite any previously saved weights with the same file name.

We start by setting up the constants and the image transformations:

In [None]:
# ---- Hyper-parameters ----
NUM_CLASSES = 13
IMG_SIZE = (120, 180)  # (height, width) every image is resized to
BATCH_SIZE = 32
NUM_EPOCHS = 60

# ---- Input locations, all resolved relative to ROOT_DIR ----
label_csv = ROOT_DIR / "./data/train.csv"
images_dir = ROOT_DIR / "./data/train"
alpha_reference = ROOT_DIR / "./data/alpha_references/"
synth_dir = ROOT_DIR / "./data/synthetic_data/"

# ---- Output location of the best model weights ----
best_model_path = ROOT_DIR / "./model/best_choco_model.pt"
model_dir = best_model_path.parent

# Make sure the folders we write into exist (no-op when already present).
for _required_dir in (synth_dir, model_dir):
    _required_dir.mkdir(parents=True, exist_ok=True)

# Per-channel (R, G, B) statistics computed on the whole training set.
# Kept in one place so the training and validation pipelines can never
# drift apart.
DATASET_MEAN = [0.6887134909629822, 0.666830837726593, 0.6608285307884216]
DATASET_STD = [0.15740245580673218, 0.1555258184671402, 0.17858198285102844]

# Training transform: resize, random flips for augmentation, then convert
# to float32 in [0, 1] and normalise with the dataset statistics.
train_tf = v2.Compose([
    v2.ToImage(),
    v2.Resize(IMG_SIZE, antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=DATASET_MEAN, std=DATASET_STD),
])

# Evaluation transform: same deterministic steps as `train_tf`, in the same
# order (resize BEFORE the float conversion), without the random flips.
# Resizing on a different dtype than during training would feed the model
# slightly different pixel values at validation time.
val_tf = v2.Compose([
    v2.ToImage(),
    v2.Resize(IMG_SIZE, antialias=True),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=DATASET_MEAN, std=DATASET_STD),
])



Now we can load our dataset and create the training and validation splits.

In [15]:
from src.data.TrainChocolateDataset import * 
from src.data.SyntheticChocolateDataset import SyntheticChocolateDataset

# Split settings: fraction of the labelled images used for training and the
# seed that makes the random split reproducible across runs.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42

# General dataset, used only to know how many labelled images exist.
# No image transform is attached because the two splits use different ones.
train_eval_dataset = ChocolateDataset(
    data_dir=images_dir,
    label_csv=label_csv,
    transform=None,
    target_transform=LabelToTensor(),
)

# Split the *indices* (not the samples) so that each split can wrap its own
# dataset instance with the appropriate transform.
train_len = int(TRAIN_FRACTION * len(train_eval_dataset))
test_len = len(train_eval_dataset) - train_len
train_idxs, test_idxs = torch.utils.data.random_split(
    range(len(train_eval_dataset)),
    [train_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

training_dataset = Subset(
    ChocolateDataset(images_dir, label_csv, transform=train_tf, target_transform=LabelToTensor()),
    train_idxs)
val_dataset = Subset(
    ChocolateDataset(images_dir, label_csv, transform=val_tf, target_transform=LabelToTensor()),
    test_idxs)

# Sanity check: the two splits must cover the dataset exactly.
assert len(training_dataset) + len(val_dataset) == len(train_eval_dataset), \
    "train/validation split does not cover the whole dataset"

# Synthetic dataset; it receives the training indices and uses the training
# images directory as its background source.
# NOTE(review): presumably only training images are used as backgrounds so
# nothing leaks into validation - confirm in SyntheticChocolateDataset.
synth_dataset = SyntheticChocolateDataset(
    background_dir=images_dir,
    alpha_reference_dir=alpha_reference,
    synth_dir=synth_dir,
    original_label_csv=label_csv,
    train_idx=train_idxs,
    per_background=3,
    transform=train_tf,
    target_transform=LabelToTensor(),
)

# Real training images + synthetic images form the final training set.
merged_training_dataset = torch.utils.data.ConcatDataset([training_dataset, synth_dataset])

# Create DataLoaders.
num_workers = 0
# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine merely triggers a PyTorch warning, so enable it only
# when CUDA is actually available.
pin_memory = torch.cuda.is_available()
train_loader = DataLoader(merged_training_dataset, BATCH_SIZE,
                          shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, BATCH_SIZE,
                        shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

# Print the size of the datasets.
print(f"Training dataset size: {len(training_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Synthetic dataset size: {len(synth_dataset)}")
print(f"Merged training dataset size: {len(merged_training_dataset)}")

Generating the synthetic images in the background directory...
Using 82 as background images


  background_name = row[0]
Generating synthetic images: 100%|██████████| 82/82 [02:20<00:00,  1.71s/it]

Training dataset size: 82
Validation dataset size: 21
Synthetic dataset size: 246
Merged training dataset size: 328





The data is now ready in our loaders, so let's instantiate the model and train it.

In [16]:
import csv # for logging

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Import the model and the training settings
from src.model.ChocoNetwork import ChocoNetwork
from src.training.training import *

model = ChocoNetwork().to(device)
loss = torch.nn.SmoothL1Loss()  # loss that is optimised and reported as train/val loss
kaggle_loss = ChocolateCountF1Loss() # Custom loss to match the one on the leaderboard (not used by the loop below)
optimizer = get_optimizer(model) # AdamW with weight decay (per the original note; see get_optimizer)

# Best validation F1 seen so far. Start at -inf rather than 0.0: the training
# loop saves a checkpoint only when `val_f1 > best_val_f1`, so with 0.0 a run
# whose F1 never rises above zero would save nothing and leave the weights of
# a previous run at `best_model_path`.
best_val_f1 = float("-inf")

# Early stopping (to avoid overfitting): stop once `patience` consecutive
# epochs have passed without improving the best validation F1.
patience = 10
no_improvement_epoch = 0

# Train for at most NUM_EPOCHS epochs. Every epoch is logged to stdout and
# appended to a CSV file; the weights are exported whenever the validation
# F1 improves, and training stops early after `patience` stale epochs.
log_path = f"run_{NUM_EPOCHS}epochs_{BATCH_SIZE}batches.csv"
with open(log_path, "w", newline="") as csv_file:

    # CSV header
    writer = csv.writer(csv_file)
    writer.writerow(["epoch", "train_loss", "val_loss", "val_f1", "val_mae"])

    for epoch in tqdm(range(1, NUM_EPOCHS + 1)):

        # One optimisation pass over the training data, then one evaluation pass.
        train_loss = train_epoch(train_loader, model, loss, optimizer, device)
        val_loss, val_f1, val_mae = eval_epoch(val_loader, model, loss, NUM_CLASSES, device)

        # ---- logging ----
        mae_str = "; ".join(f"{m:.2f}" for m in val_mae)
        epoch_report = (
            f"Epoch {epoch:02d} | "
            f"train loss L1: {train_loss:.4f} | "
            f"val loss l1: {val_loss:.4f} | "
            f"val custom F1: {val_f1:.4f} | "
            f"val MAE/class: [{mae_str}]"
        )
        print(epoch_report)

        # Flush after every row so the log survives an interrupted run.
        writer.writerow([epoch, train_loss, val_loss, val_f1, mae_str])
        csv_file.flush()

        # No improvement: count the stale epoch and possibly stop early.
        if not (val_f1 > best_val_f1):
            no_improvement_epoch += 1
            if no_improvement_epoch >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
            continue

        # Improvement: remember the score, export the weights, reset the counter.
        best_val_f1 = val_f1
        print(f"New best model found at epoch {epoch} with val F1: {val_f1:.4f}")
        torch.save(model.state_dict(), best_model_path)
        no_improvement_epoch = 0

print(f"The training finished at epoch {epoch} with a best val F1 of {best_val_f1:.4f}")


Using device: cuda


  2%|▏         | 1/60 [03:00<2:57:42, 180.72s/it]

Epoch 01 | train loss L1: 0.3270 | val loss l1: 0.2996 | val custom F1: 0.0000 | val MAE/class: [0.48; 0.56; 0.52; 0.37; 0.59; 0.47; 0.78; 0.49; 0.85; 0.34; 0.50; 0.70; 0.47]
New best model found at epoch 1 with val F1: 0.0000


  3%|▎         | 2/60 [05:49<2:47:49, 173.62s/it]

Epoch 02 | train loss L1: 0.2985 | val loss l1: 0.2862 | val custom F1: 0.0000 | val MAE/class: [0.46; 0.61; 0.53; 0.42; 0.64; 0.56; 0.78; 0.56; 0.84; 0.44; 0.51; 0.71; 0.44]


  5%|▌         | 3/60 [08:42<2:44:41, 173.36s/it]

Epoch 03 | train loss L1: 0.2863 | val loss l1: 0.2760 | val custom F1: 0.1457 | val MAE/class: [0.51; 0.62; 0.52; 0.41; 0.64; 0.57; 0.79; 0.63; 0.81; 0.56; 0.56; 0.71; 0.47]
New best model found at epoch 3 with val F1: 0.1457


  7%|▋         | 4/60 [11:30<2:39:59, 171.42s/it]

Epoch 04 | train loss L1: 0.2839 | val loss l1: 0.2873 | val custom F1: 0.0339 | val MAE/class: [0.50; 0.64; 0.59; 0.41; 0.67; 0.51; 0.78; 0.58; 0.84; 0.59; 0.61; 0.79; 0.58]


  7%|▋         | 4/60 [12:14<2:51:27, 183.71s/it]


KeyboardInterrupt: 