In [1]:
import os
import yaml
import json
import sys
import pprint
import gc
import torch
from torch import nn, optim
from modules.util import plant_seed, extract_parameters_from_filename
from modules.load_data import get_test_val_train_indices
from modules.dataset import MultiRepresentationDataset
from models.load_model import load_create_model
from modules.lr_scheduler import CosineScheduler
from torch.utils.data import DataLoader
from models.train import train_model, eval_model

In [None]:

# Reproducibility setup.
# plant_seed(0) hands back a worker-init hook and an RNG generator; both are
# passed to every DataLoader in the training cell. What else plant_seed
# configures is not visible from this notebook.
seed_worker, g = plant_seed(0)

# Print the cuBLAS workspace setting before and after overriding it
# (an unset variable is shown as None).
print(f"os.environ.get('CUBLAS_WORKSPACE_CONFIG'): {os.getenv('CUBLAS_WORKSPACE_CONFIG')}") # Default is None
os.environ.update(CUBLAS_WORKSPACE_CONFIG=":4096:8")
print(f"os.environ.get('CUBLAS_WORKSPACE_CONFIG'): {os.getenv('CUBLAS_WORKSPACE_CONFIG')}")

In [None]:
# Load the label map and the model configuration from disk.
with open('models/label_map.yaml', 'r') as file:
    label_map = yaml.safe_load(file)

with open('models/model_config.json', 'r') as file:
    model_config = json.load(file)

# Parse the six windowing/sampling parameters out of `data_type_path` and
# store them in the config under their own keys.
window_time, periocular_rate, frame_size, frame_channel, gaze_pupil_rate, overlap = extract_parameters_from_filename(model_config['data_type_path'])
model_config.update(
    window_time=window_time,
    periocular_rate=periocular_rate,
    frame_size=frame_size,
    frame_channel=frame_channel,
    gaze_pupil_rate=gaze_pupil_rate,
    overlap=overlap,
)

# The dataset directory must exist before anything else is attempted.
dataset_dir = os.path.join(model_config['data_path'], model_config['data_type_path'])
if not os.path.exists(dataset_dir):
    raise FileNotFoundError(f"The dataset directory {dataset_dir} does not exist.")

# One .h5 file per subject. Only the trailing '.h5' is stripped: str.replace
# would also remove the same substring from the middle of a file name.
h5_suffix = '.h5'
h5_files = sorted(f for f in os.listdir(dataset_dir) if f.endswith(h5_suffix))
subjects = [f[:-len(h5_suffix)] for f in h5_files]

# The subjects must split evenly into the configured number of folds, with at
# least one subject per fold. The bound follows the configured fold count
# rather than a fixed number, so the check matches the message it prints.
n_folds = model_config['cross_validation_fold']
if n_folds < 1:
    sys.exit(f"cross_validation_fold must be a positive integer, got {n_folds}.")
if len(subjects) < n_folds or len(subjects) % n_folds != 0:
    sys.exit(f"The amount of subjects ({len(subjects)}) is not enough for {model_config['cross_validation_fold']}-fold cross validation.")

# Consecutive, equally sized slices of the sorted subject list; each slice is
# the held-out test group of one fold.
subjects_per_fold = len(subjects) // n_folds
subjects_folds = [subjects[i * subjects_per_fold:(i + 1) * subjects_per_fold] for i in range(n_folds)]
pprint.pprint(subjects_folds)

# This notebook trains from scratch, so no existing model path may be configured.
if model_config['model_path'] is not None:
    sys.exit("model_path should be None when training a new model.")

pprint.pprint(model_config)
pprint.pprint(label_map)

In [None]:
# Running sums of the per-fold test metrics; averaged in the next cell.
total_acc = 0
total_f1 = 0

for fold_idx, test_subjects in enumerate(subjects_folds):
    #==============================================================================
    # Drop the previous fold's model/optimizer before building the next one, so
    # two folds' worth of parameters and optimizer state are never held at the
    # same time. The last fold's objects stay available after the loop.
    model = optimizer = scheduler = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Every subject outside the held-out group is used for training/validation.
    train_subjects = [subject for subject in subjects if subject not in test_subjects]

    print(f"Training fold {fold_idx + 1} with test subjects: {test_subjects}")
    print(f"Training fold {fold_idx + 1} with train subjects: {train_subjects}")
    
    test_indices, val_indices, train_indices, index_info = get_test_val_train_indices(test_subjects, train_subjects, dataset_dir, model_config['training_proportion'])

    train_dataset = MultiRepresentationDataset(dataset_dir, train_indices, index_info)
    val_dataset = MultiRepresentationDataset(dataset_dir, val_indices, index_info)
    test_dataset = MultiRepresentationDataset(dataset_dir, test_indices, index_info)

    # Settings shared by all three loaders; only the dataset and shuffling differ.
    loader_kwargs = dict(
        batch_size=model_config['batch_size'],
        num_workers=6,
        worker_init_fn=seed_worker,
        generator=g,
        pin_memory=True,
        prefetch_factor=2,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    data_loaders = (train_loader, val_loader, test_loader)

    #==============================================================================
    model, vivit_params, ts_transformer_params, train_params, device = load_create_model(model_config, fold_idx, num_classes=len(label_map))

    # reduction='none' yields one loss value per sample; presumably train_model
    # does the reduction itself -- TODO confirm.
    criterion = nn.CrossEntropyLoss(reduction='none', label_smoothing=model_config['label_smoothing'])

    # LambdaLR multiplies the optimizer's initial lr by the lambda's return value.
    # With lr=1.0 the value produced by CosineScheduler is therefore the
    # effective learning rate.
    optimizer = optim.Adam(model.parameters(), lr=1.0, betas=(train_params['beta_1'], train_params['beta_2']))
    lr_schedule = CosineScheduler(max_epochs=train_params['max_update_epochs'], base_lr=train_params['base_lr'], final_lr=train_params['final_lr'], warmup_epochs=train_params['warmup_epochs'], warmup_begin_lr=train_params['warmup_begin_lr'])
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)

    #==============================================================================
    # Train the model
    tensorboard_writer = train_model(model, criterion, optimizer, scheduler, data_loaders, vivit_params, ts_transformer_params, train_params, model_config, fold_idx)

    # Evaluate the 'best' and the 'last' checkpoint on the held-out test loader.
    best_model_acc, best_model_f1 = eval_model(model, test_loader, model_config, label_map, tensorboard_writer, 'best')
    last_model_acc, last_model_f1 = eval_model(model, test_loader, model_config, label_map, tensorboard_writer, 'last')

    # Release the writer once this fold's logging is done. Assumes the object
    # returned by train_model is a SummaryWriter-like logger -- TODO confirm.
    if hasattr(tensorboard_writer, 'close'):
        tensorboard_writer.close()

    # NOTE(review): the reported metric is the better of the two checkpoints as
    # measured on the test split, and accuracy and F1 are maximised
    # independently. That selects on test data and may bias the averages
    # upward; consider choosing the checkpoint on the validation split instead.
    test_acc = max(best_model_acc, last_model_acc)
    test_f1 = max(best_model_f1, last_model_f1)

    total_acc += test_acc
    total_f1 += test_f1

In [None]:
# Average the per-fold test metrics accumulated by the training loop.
average_acc = total_acc / model_config['cross_validation_fold']
average_f1 = total_f1 / model_config['cross_validation_fold']

print(f"Average Accuracy across all folds: {average_acc}")
print(f"Average F1 Score across all folds: {average_f1}")

# The summary file goes next to the model file. The setup cell requires
# model_path to be None, so it is only a string here if a training helper
# filled it in; fall back to the current working directory rather than
# failing in os.path.dirname(None) after all folds have finished.
model_path = model_config.get('model_path')
parent_dir = os.path.dirname(model_path) if model_path else os.getcwd()

with open(os.path.join(parent_dir, "average_acc_and_f1.txt"), 'w') as f:
    f.write(f"Average Accuracy: {average_acc}\n")
    f.write(f"Average F1 Score: {average_f1}\n")

# Drop the seeding helpers; tolerate a second run of this cell, where they
# are already gone.
try:
    del seed_worker, g
except NameError:
    pass
gc.collect()