# Training

In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision import transforms
from torchvision import models
from torch import nn
from torch import optim
from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score
from PIL import Image
from tqdm import tqdm
import numpy as np

In [None]:
# Dataset locations: root folder, preprocessed images and the full label CSV.
dataset_root = 'C:/DATASETS/AGE-FER'
dataset_imgs_path = os.path.join(dataset_root, 'images-preprocessed')
full_dataset_labels_path = os.path.join(dataset_root, '24-datasets.csv')

# Expression classes; a class's integer id is its position in this list.
labels = ['anger', 'disgust', 'fear', 'happiness', 'sadness', 'surprise', 'neutral']
label_map = {name: class_id for class_id, name in enumerate(labels)}

# Column dtypes passed to pd.read_csv for every annotation file.
dtypes = {
    # identifiers and target
    'dataset': 'category', 'user_id': 'category', 'name': str, 'class': 'category',
    # demographics
    'age': 'Int8', 'gender': 'category', 'race': 'category',
    'perspective': 'category', 'age_group': 'category', 'subset': 'category',
    # flags and derived annotations
    'auto_age': bool, 'auto_gender': bool, 'age_group_clean': 'category',
    'gaze': 'category', 'auto_perspective': bool, 'key': 'category',
}

In [None]:
class FERDataset(Dataset):
    """Facial-expression dataset backed by a CSV annotation file.

    Each item is ``(image, label, age_group)``:
      * ``image``: the RGB PIL image at ``img_dir/<name>`` after ``transform``;
      * ``label``: the row's ``class`` value after ``target_transform``;
      * ``age_group``: the raw ``age_group_clean`` value (possibly NaN for
        rows without one -- not verified here).

    Args:
        annotations_file: CSV with at least ``name``, ``class`` and
            ``age_group_clean`` columns (plus ``dataset`` when ``subdataset``
            is given).
        dtypes: dtype mapping forwarded to ``pd.read_csv``.
        img_dir: directory containing the image files.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the label.
        subdataset: if truthy, keep only rows whose ``dataset`` equals it.
    """

    def __init__(self, annotations_file, dtypes, img_dir, transform=None, target_transform=None, subdataset=None):
        img_labels = pd.read_csv(annotations_file, dtype=dtypes, sep=',', quotechar='"')
        if subdataset:
            img_labels = img_labels[img_labels['dataset'] == subdataset]
        # Re-number rows 0..n-1 after filtering so positions and labels agree.
        self.img_labels = img_labels.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        row = self.img_labels.iloc[idx]
        # Only the file name is lower-cased (files are presumably stored
        # lower-case on disk). Lower-casing img_dir as well would break on
        # case-sensitive file systems.
        img_path = os.path.join(self.img_dir, row["name"].lower())
        # Close the file handle; convert() returns an independent image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = row["class"]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label, row["age_group_clean"]

def display_image_and_label(dataloader, labels):
    """Show the first image of one batch from ``dataloader`` and print its label.

    Args:
        dataloader: yields ``(images, labels, ...)`` batches; any extra fields
            (e.g. FERDataset's age group) are ignored.
        labels: sequence mapping an integer class id to its name.
    """
    # FERDataset yields (image, label, age_group); tolerate extra fields.
    train_features, train_labels, *_ = next(iter(dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    img = train_features[0]
    if img.dim() == 3:
        # Assumes (C, H, W) tensors as produced by torchvision transforms;
        # matplotlib expects (H, W, C).
        img = img.permute(1, 2, 0)
    img = img.squeeze()
    # Normalized tensors fall outside [0, 1]; rescale for display only.
    img_min, img_max = img.min(), img.max()
    if img_max > img_min:
        img = (img - img_min) / (img_max - img_min)
    label = train_labels[0]
    plt.imshow(img.cpu().numpy(), cmap="gray")
    plt.show()
    print(f"Label: {labels[int(label)]}")

def get_model(model_name, num_classes, pretraining_dataset='IMAGENET1K_V1', in_channels=3):
    if model_name == 'ConvNeXt_Small':
        model = models.convnext_small(weights=pretraining_dataset)
        model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
        model.features[0][0].in_channels = in_channels
    elif model_name == 'Swin_S':
        model = models.swin_s(weights=pretraining_dataset)
        model.head = nn.Linear(model.head.in_features, num_classes)
        model.features[0][0].in_channels = in_channels
    
    transform = models.get_weight(model_name + '_Weights.'+pretraining_dataset).transforms()
    
    return model, transform

def get_transform(default_transform, mode='train', augment=True):
    """Return the image pipeline for a training stage.

    Args:
        default_transform: the model's default preprocessing (resize/normalize).
        mode: ``'train'`` -> augmentations followed by ``default_transform``;
            ``'test'`` -> ``default_transform`` alone;
            ``'augment'`` -> augmentations only.
        augment: when False the augmentation list is empty.

    Raises:
        ValueError: for any other ``mode``.
    """
    augmentations = []
    if augment:
        augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.ColorJitter(brightness=0.25, contrast=0.25),
        ]

    if mode == 'test':
        return default_transform
    if mode == 'train':
        return transforms.Compose([*augmentations, default_transform])
    if mode == 'augment':
        return transforms.Compose(augmentations)
    raise ValueError('Invalid mode')

def train(dataloader, model, loss_fn, optimizer, verbose=True, f_logs=None):
    """Run one training epoch and return ``(avg_loss, accuracy, macro_f1)``.

    Args:
        dataloader: yields ``(X, y, age_group)`` batches; ``dataloader.dataset``
            must expose an ``img_labels`` DataFrame with a ``class`` column.
        model: network to train; inputs are moved to its parameters' device.
        loss_fn: called as ``loss_fn(pred, y, age_group)``.
        optimizer: optimizer over ``model``'s parameters.
        verbose: print summary metrics to ``f_logs``.
        f_logs: file-like object for logging (``None`` -> stdout).

    Returns:
        Mean per-batch loss, micro accuracy (tensor), and the macro F1 averaged
        only over the classes present in this dataset.
    """
    num_batches = len(dataloader)
    # Take the device from the model rather than from a module-level global.
    device = next(model.parameters()).device
    model.train()
    train_loss = 0.0
    # Class names present in this split; macro F1 averages over these only.
    present_classes = dataloader.dataset.img_labels['class'].unique().tolist()
    print(f"Num classes: {len(present_classes)} - {present_classes}", file=f_logs)

    accuracy = MulticlassAccuracy(num_classes=len(label_map), average='micro')
    f1_scores = MulticlassF1Score(num_classes=len(label_map), average=None)

    for X, y, age_group in tqdm(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y, age_group)
        train_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Metrics are built without a device argument, so feed them detached
        # CPU tensors to avoid cross-device updates when training on a GPU.
        pred_cpu, y_cpu = pred.detach().cpu(), y.cpu()
        accuracy.update(pred_cpu, y_cpu)
        f1_scores.update(pred_cpu, y_cpu)

    train_loss /= max(num_batches, 1)  # guard against an empty loader
    accuracy = accuracy.compute()
    f1_scores = f1_scores.compute()
    macro_f1_score = np.mean([f1_scores[label_map[label]] for label in present_classes])

    if verbose:
        print(f"Avg loss: {train_loss:>8f}. Accuracy: {accuracy:>8f}. Macro F1 Score: {macro_f1_score:>8f}", file=f_logs)
        print('F1 scores.', [f'{label}: {f1_scores[label_map[label]]:>8f}' for label in labels], file=f_logs)
    return train_loss, accuracy, macro_f1_score

def val(dataloader, model, loss_fn, verbose=True, f_logs=None):
    """Evaluate ``model`` without updating it; return ``(avg_loss, accuracy, macro_f1)``.

    Args:
        dataloader: yields ``(X, y, age_group)`` batches; ``dataloader.dataset``
            must expose an ``img_labels`` DataFrame with a ``class`` column.
        model: network to evaluate; inputs are moved to its parameters' device.
        loss_fn: called as ``loss_fn(pred, y, age_group)``.
        verbose: print summary metrics to ``f_logs``.
        f_logs: file-like object for logging (``None`` -> stdout).

    Returns:
        Mean per-batch loss, micro accuracy (tensor), and the macro F1 averaged
        only over the classes present in this dataset.
    """
    num_batches = len(dataloader)
    # Take the device from the model rather than from a module-level global.
    device = next(model.parameters()).device
    model.eval()
    val_loss = 0.0
    # Class names present in this split; macro F1 averages over these only.
    present_classes = dataloader.dataset.img_labels['class'].unique().tolist()
    print(f"Num classes: {len(present_classes)} - {present_classes}", file=f_logs)
    accuracy = MulticlassAccuracy(num_classes=len(label_map), average='micro')
    f1_scores = MulticlassF1Score(num_classes=len(label_map), average=None)
    with torch.no_grad():
        for X, y, age_group in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            val_loss += loss_fn(pred, y, age_group).item()
            # Metrics are built without a device argument; feed them CPU tensors.
            pred_cpu, y_cpu = pred.cpu(), y.cpu()
            accuracy.update(pred_cpu, y_cpu)
            f1_scores.update(pred_cpu, y_cpu)
    val_loss /= max(num_batches, 1)  # guard against an empty loader
    accuracy = accuracy.compute()
    f1_scores = f1_scores.compute()
    macro_f1_score = np.mean([f1_scores[label_map[label]] for label in present_classes])
    if verbose:
        print(f"Avg loss: {val_loss:>8f}. Accuracy: {accuracy:>8f}. Macro F1 Score: {macro_f1_score:>8f}", file=f_logs)
        print('F1 scores.', [f'{label}: {f1_scores[label_map[label]]:>8f}' for label in labels], file=f_logs)
    return val_loss, accuracy, macro_f1_score

def epochs_loop(train_dataloader, test_dataloader, model, loss_fn, optimizer, model_name, 
                subdataset, k, epochs_path, model_path, max_epochs, 
                use_patience = True, patience = 5, patience_th = 0.01, metric = 'val_acc', f_logs=None,
                aug=None):
    """Train for up to ``max_epochs`` epochs with checkpointing and early stopping.

    With ``use_patience`` the model is saved to ``model_path`` whenever
    ``metric`` improves by more than ``patience_th``. Training stops after
    ``patience`` consecutive epochs without such an improvement. If no epoch
    ever qualifies, the final weights are saved so ``model_path`` always exists.
    Without ``use_patience`` the final weights are saved after the last epoch.

    Per-epoch results are written to ``epochs_path`` only once the loop ends.
    Callers may therefore treat that file's existence as "run completed".

    Args:
        metric: one of 'train_loss', 'train_acc', 'train_f1', 'val_loss',
            'val_acc', 'val_f1'. Losses are minimized; the rest are maximized.
        aug: optional value recorded in the CSV's 'aug' column.

    Raises:
        ValueError: if ``metric`` is not one of the tracked metrics.
    """
    tracked_metrics = ('train_loss', 'train_acc', 'train_f1', 'val_loss', 'val_acc', 'val_f1')
    if metric not in tracked_metrics:
        raise ValueError(f"Invalid metric {metric!r}; expected one of {tracked_metrics}")
    lower_is_better = 'loss' in metric

    columns = ['model', 'subdataset', 'k', 'aug', 'epoch', 'train_loss', 'train_acc', 'train_f1', 'val_loss', 'val_acc', 'val_f1']
    rows = []

    # Early-stopping state: the first evaluated epoch always counts as an improvement.
    counter_patience = 0
    best_value = float('inf') if lower_is_better else float('-inf')
    best_epoch = None

    for epoch in range(max_epochs):

        print("-------------------------------", file=f_logs)
        print(f"Epoch [{epoch+1:>{len(str(max_epochs))}}/{max_epochs}]", file=f_logs)

        print("Train:", file=f_logs)
        train_loss, train_acc, train_f1 = train(train_dataloader, model, loss_fn, optimizer, verbose=True, f_logs=f_logs)

        print("Validate:", file=f_logs)
        val_loss, val_acc, val_f1 = val(test_dataloader, model, loss_fn, verbose=True, f_logs=f_logs)

        if f_logs is not None:
            f_logs.flush()

        row = {
            'model': model_name,
            'subdataset': subdataset,
            'k': k,
            'aug': aug,
            'epoch': epoch + 1,
            'train_loss': float(train_loss),
            'train_acc': float(train_acc),
            'train_f1': float(train_f1),
            'val_loss': float(val_loss),
            'val_acc': float(val_acc),
            'val_f1': float(val_f1)}
        rows.append(row)

        if not use_patience:
            continue

        value = row[metric]
        improvement = best_value - value if lower_is_better else value - best_value
        if improvement > patience_th:
            # Better value: save model
            counter_patience = 0
            best_value = value
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_path)
        else:
            # Not better by more than patience_th: increase patience counter
            counter_patience += 1

        if counter_patience >= patience:
            print(f"Early stopping at epoch {epoch+1} (best epoch: {best_epoch}).", file=f_logs)
            break

    if not use_patience:
        torch.save(model.state_dict(), model_path)
    elif best_epoch is None:
        # Never improved (or zero epochs): still leave a checkpoint behind.
        print("No improving epoch recorded; saving final weights.", file=f_logs)
        torch.save(model.state_dict(), model_path)

    # Written last on purpose: the file doubles as a completion marker.
    pd.DataFrame(rows, columns=columns).to_csv(epochs_path, index=False)

## Run

In [None]:
logs_file = 'training-logs.txt'
counter = 1
while os.path.exists(logs_file):
    logs_file = f'training-logs_{counter}.txt'
    counter += 1

with open(logs_file, 'w') as f_logs:

    # Experiment configuration
    training_dir = os.path.join(dataset_root, 'training')
    chosen_models = ['ConvNeXt_Small', 'Swin_S']
    chosen_augmentation = [True]  # NOTE(review): not read below; augmentation is always on
    chosen_balancing = ['weighted-loss']  # NOTE(review): not read below; weighted CE is always used
    epochs = 20
    metric = 'val_f1'
    patience_th = 0.01
    patience = 5

    # Get distinct datasets
    df = pd.read_csv(full_dataset_labels_path, dtype=dtypes, sep=',', quotechar='"')
    subdatasets = list(df['dataset'].unique())

    # Set device
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device", file=f_logs)

    # Create output folders
    models_dir = os.path.join(training_dir, 'models')
    epochs_dir = os.path.join(training_dir, 'epochs')
    for folder in (training_dir, models_dir, epochs_dir):
        os.makedirs(folder, exist_ok=True)

    for subdataset in subdatasets:
        for k in range(1, 6):
            for model_name in chosen_models:
                print(f"Model: {model_name}, Subdataset: {subdataset}, CV: {k}.", file=f_logs)

                base_name = f'{model_name}_{subdataset}_{k}'
                model_path = os.path.join(models_dir, f'{base_name}.pth')
                epochs_path = os.path.join(epochs_dir, f'{base_name}.csv')

                # The epochs CSV is written when a run finishes, so an existing
                # file marks a completed run that can be skipped.
                if os.path.exists(epochs_path):
                    continue

                # Define model, load pre-trained weights and get default transform
                model, default_transform = get_model(model_name, len(labels), 'IMAGENET1K_V1')
                model = model.to(device)

                # Define optimizer
                optimizer = optim.Adam(model.parameters(), lr=1e-4)

                # label_map.__getitem__ maps class name -> id (and is picklable, unlike a lambda)
                training_data = FERDataset(
                    os.path.join(dataset_root, 'cv-labels', f'19-datasets_train_{k}.csv'),
                    dtypes,
                    dataset_imgs_path,
                    transform=get_transform(default_transform, mode='train', augment=True),
                    target_transform=label_map.__getitem__,
                    subdataset=subdataset)

                test_data = FERDataset(
                    os.path.join(dataset_root, 'cv-labels', f'19-datasets_test_{k}.csv'),
                    dtypes,
                    dataset_imgs_path,
                    transform=default_transform,
                    target_transform=label_map.__getitem__,
                    subdataset=subdataset)

                if len(training_data) == 0 or len(test_data) == 0:
                    print(f"Skipping {base_name}: empty split (train={len(training_data)}, test={len(test_data)}).", file=f_logs)
                    continue

                # Class weights: min_count / count per class (0 for absent classes),
                # in the same order as `labels`, i.e. by class id.
                class_counts = training_data.img_labels['class'].value_counts().reindex(labels, fill_value=0)
                min_class_count = class_counts[class_counts > 0].min()
                weights = torch.tensor(
                    [min_class_count / count if count > 0 else 0.0 for count in class_counts.tolist()],
                    dtype=torch.float32).to(device)
                print(f"Weights: {weights.tolist()}", file=f_logs)

                # train()/val() call loss_fn(pred, y, age_group); weighted CE ignores age_group.
                ce_loss = nn.CrossEntropyLoss(weight=weights)
                loss_fn = lambda pred, y, age_group, _ce=ce_loss: _ce(pred, y)

                # Create dataloaders (no need to shuffle the evaluation split)
                train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
                test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

                # Run training loop
                f_logs.flush()
                epochs_loop(train_dataloader, test_dataloader, model, loss_fn, optimizer, model_name,
                            subdataset, k, epochs_path, model_path,
                            epochs, use_patience=True, patience=patience,
                            patience_th=patience_th, metric=metric, f_logs=f_logs)
                         

## Check training loss evolution

In [None]:
# Where the per-run epoch CSVs live and where the merged summary is written.
training_dir = os.path.join(dataset_root, 'training')
epochs_dir = os.path.join(dataset_root, 'training', 'epochs')
all_epochs_file = os.path.join(epochs_dir, 'all_trainings.csv')

# Architectures whose runs are collected below.
chosen_models = ['ConvNeXt_Small', 'Swin_S']

# Sub-dataset names, taken from the full annotation file.
df = pd.read_csv(full_dataset_labels_path, dtype=dtypes, sep=',', quotechar='"')
dataset_names = df['dataset'].unique()
subdatasets = list(dataset_names)
del dataset_names

### Merge all epoch logs into one CSV

In [None]:
# Merge every per-run epoch CSV into one file. The column order matches what
# epochs_loop writes. Runs without a CSV yet (not trained or still running)
# are reported and skipped instead of aborting the merge.
epoch_columns = ['model', 'subdataset', 'k', 'aug', 'epoch', 'train_loss', 'train_acc', 'train_f1', 'val_loss', 'val_acc', 'val_f1']
run_frames = []
missing_runs = []

for subdataset in subdatasets:
    for k in range(1, 6):
        for model_name in chosen_models:

            # CSV file name
            base_name = f'{model_name}_{subdataset}_{k}'
            epochs_path = os.path.join(epochs_dir, f'{base_name}.csv')

            if not os.path.exists(epochs_path):
                missing_runs.append(base_name)
                continue

            run_frames.append(pd.read_csv(epochs_path, sep=',', quotechar='"'))

if missing_runs:
    print(f"Skipped {len(missing_runs)} run(s) without an epochs CSV: {missing_runs}")

# A single concat avoids quadratic copying and concatenating onto an empty frame.
all_epochs_df = pd.concat(run_frames, ignore_index=True) if run_frames else pd.DataFrame(columns=epoch_columns)
extra_columns = [col for col in all_epochs_df.columns if col not in epoch_columns]
all_epochs_df = all_epochs_df.reindex(columns=epoch_columns + extra_columns)

# Save results
all_epochs_df.to_csv(all_epochs_file, index=False)
