# Tutorial on using the training pipeline for the event-based eye tracking challenge.

In [1]:
import argparse, json, os, tempfile
from urllib.parse import unquote, urlparse

import mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import tonic.transforms as transforms
from tonic import SlicedDataset, DiskCachedDataset
from tqdm import tqdm

from model.BaselineEyeTrackingModel import CNN_GRU
from utils.training_utils import train_epoch, validate_epoch, top_k_checkpoints
from utils.metrics import weighted_MSELoss
from dataset import ThreeETplus_Eyetracking, ScaleLabel, NormalizeLabel, \
    LabelTemporalSubsample, SliceLongEventsToShort, \
    EventSlicesToVoxelGrid, SliceByTimeEventsTargets

#### Exemplar config file

In [2]:
# Read the JSON training configuration from ./configs and expose every key as an
# attribute of `args` (e.g. args.lr, args.device), mirroring an argparse result.
config_file = 'train_baseline.json'
config_path = os.path.join('./configs', config_file)
with open(config_path, 'r') as f:
    config = json.load(f)
args = argparse.Namespace(**config)

#### Setup mlflow tracking server (local)

In [3]:
# Send all MLflow logging to the store named in the config, then activate the
# experiment the runs below are recorded under. The call is left as the cell's
# last expression so the resulting Experiment object is displayed.
mlflow.set_tracking_uri(args.mlflow_path)
mlflow.set_experiment(args.experiment_name)

<Experiment: artifact_location='file:///C:/Users/Junkyy/CV_exercise/thesis/3et_challenge_2025-main/mlruns/617056576973668955', creation_time=1745660511476, experiment_id='617056576973668955', last_update_time=1745660511476, lifecycle_stage='active', name='trial_experiment', tags={}>

## Model and Optimizer Definition

In [4]:
# Build the network, its optimizer, and the training loss from the config.
# NOTE(review): eval() executes whatever string the config supplies as
# `architecture`; only load configs you trust, or swap this for an explicit
# name -> class mapping.
model_cls = eval(args.architecture)
model = model_cls(args).to(args.device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

if args.loss not in ("mse", "weighted_mse"):
    raise ValueError("Invalid loss name")

if args.loss == "mse":
    criterion = nn.MSELoss()
else:
    # Two per-coordinate weights: the sensor aspect ratio for the first entry, 1 for the second.
    axis_weights = torch.tensor((args.sensor_width/args.sensor_height, 1)).to(args.device)
    criterion = weighted_MSELoss(weights=axis_weights, reduction='mean')

## Data Loading and Preprocessing

First we define the label transformations

In [5]:
factor = args.spatial_factor  # spatial downsample factor, shared by labels and events
temp_subsample_factor = args.temporal_subsample_factor  # temporal subsampling of the original 100Hz labels (to 20Hz)

# Label pipeline, applied in this order: spatial scaling by `factor`, temporal
# subsampling, then normalization w.r.t. the downsampled width and height to [0,1].
label_steps = [
    ScaleLabel(factor),
    LabelTemporalSubsample(temp_subsample_factor),
    NormalizeLabel(pseudo_width=640*factor, pseudo_height=480*factor),
]
label_transform = transforms.Compose(label_steps)

Then we define the dataset of raw event recordings and labels. The spatial coordinates of the raw events are also downsampled by the same factor, to an 80x60 spatial resolution.

In [6]:
# One raw dataset per split. Both share the same construction: events are
# spatially downsampled by `factor` and labels go through `label_transform`.
# A fresh Downsample transform is created for each split.
train_data_orig, val_data_orig = (
    ThreeETplus_Eyetracking(
        save_to=args.data_dir,
        split=split_name,
        transform=transforms.Downsample(spatial_factor=factor),
        target_transform=label_transform,
    )
    for split_name in ("train", "val")
)

Then we slice the event recordings into sub-sequences. The time-window is determined by the sequence length (train_length, val_length) and the temporal subsample factor.

In [7]:
# Duration of one (temporally subsampled) label step, in microseconds.
label_step_us = int(10000/temp_subsample_factor)

# Training windows span `train_length` label steps and advance by `train_stride` steps.
slicing_time_window = args.train_length*label_step_us #microseconds
train_stride_time = int(10000/temp_subsample_factor*args.train_stride) #microseconds

# Validation windows must span `val_length` label steps. Reusing the training
# window here would only be correct when val_length == train_length.
val_slicing_time_window = args.val_length*label_step_us #microseconds

train_slicer = SliceByTimeEventsTargets(
    slicing_time_window,
    overlap=slicing_time_window-train_stride_time,
    seq_length=args.train_length,
    seq_stride=args.train_stride,
    include_incomplete=False,
)
# the validation set is sliced to non-overlapping sequences
val_slicer = SliceByTimeEventsTargets(
    val_slicing_time_window,
    overlap=0,
    seq_length=args.val_length,
    seq_stride=args.val_stride,
    include_incomplete=False,
)


After slicing the raw event recordings into sub-sequences, we convert each sub-sequence into your favorite event representation, in this case an event voxel grid.

You could also try other representations with the Tonic library easily.

In [8]:
# Applied to every sliced sub-sequence: first cut it into short event slices of
# one label step each, then turn those slices into a voxel-grid representation.
frame_time_window = int(10000/temp_subsample_factor)
voxel_sensor_size = (int(640*factor), int(480*factor), 2)

slice_to_frames = SliceLongEventsToShort(
    time_window=frame_time_window,
    overlap=0,
    include_incomplete=True,
)
frames_to_voxels = EventSlicesToVoxelGrid(
    sensor_size=voxel_sensor_size,
    n_time_bins=args.n_time_bins,
    per_channel_normalize=args.voxel_grid_ch_normaization,
)
post_slicer_transform = transforms.Compose([slice_to_frames, frames_to_voxels])

We use the Tonic SlicedDataset class to handle the collation of the sub-sequences into batches.

The slicing indices will be cached to disk for faster slicing in the future, for the same slice parameters.

In [9]:
# Slice-index metadata is stored in a directory named after the slicing
# parameters, so each parameter combination gets its own cached indices.
train_metadata_path = f"./metadata/3et_train_tl_{args.train_length}_ts{args.train_stride}_ch{args.n_time_bins}"
val_metadata_path = f"./metadata/3et_val_vl_{args.val_length}_vs{args.val_stride}_ch{args.n_time_bins}"

train_data = SlicedDataset(
    train_data_orig,
    train_slicer,
    transform=post_slicer_transform,
    metadata_path=train_metadata_path,
)
val_data = SlicedDataset(
    val_data_orig,
    val_slicer,
    transform=post_slicer_transform,
    metadata_path=val_metadata_path,
)

Metadata read from ./metadata/3et_train_tl_30_ts15_ch3\slice_metadata.h5.
Metadata read from ./metadata/3et_val_vl_30_vs30_ch3\slice_metadata.h5.


Cache the preprocessed data to disk to speed up training. The first epoch will be slow, but the following epochs will be fast. Note that the cache consumes additional disk space.

In [10]:
# Wrap the sliced datasets in an on-disk cache. The cache directory name encodes
# the slicing parameters so different settings do not share cached samples.
train_cache_path = f'./cached_dataset/train_tl_{args.train_length}_ts{args.train_stride}_ch{args.n_time_bins}'
val_cache_path = f'./cached_dataset/val_vl_{args.val_length}_vs{args.val_stride}_ch{args.n_time_bins}'

train_data = DiskCachedDataset(train_data, cache_path=train_cache_path)
val_data = DiskCachedDataset(val_data, cache_path=val_cache_path)


Finally, we wrap the datasets in PyTorch DataLoaders.

In [11]:
# Leave two cores free for the main process. os.cpu_count() may return None, and
# on hosts with two or fewer cores "count - 2" would be zero or negative;
# DataLoader rejects a negative num_workers, so clamp at 0 (load in the main process).
loader_workers = max(0, (os.cpu_count() or 1) - 2)

train_loader = DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=loader_workers,
    pin_memory=True,
)
val_loader = DataLoader(
    val_data,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=loader_workers,
)


## Define the Training Loop Functionalities

In [17]:
def train(model, train_loader, val_loader, criterion, optimizer, args):
    """Run the epoch loop: train, periodically validate, log to MLflow, keep the best checkpoints.

    NOTE(review): a later cell defines `train` again; when the cells are run
    top to bottom the later definition replaces this one.

    Args:
        model: network to optimise; forwarded to ``train_epoch`` / ``validate_epoch``.
        train_loader: iterable of training batches (wrapped in a tqdm bar every epoch).
        val_loader: iterable of validation batches.
        criterion: loss callable forwarded to the epoch helpers.
        optimizer: optimizer forwarded to ``train_epoch``.
        args: config namespace. This function reads ``num_epochs`` and
            ``val_interval`` and passes the whole namespace on to the helpers.

    Returns:
        The model as returned by the final ``train_epoch`` call (the input model
        if no epoch ran).
    """
    best_val_loss = float("inf")

    # Training loop
    for epoch in range(args.num_epochs):
        # Wrap train_loader with tqdm for progress bar
        train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{args.num_epochs}")
        model, train_loss, metrics = train_epoch(model, train_pbar, criterion, optimizer, args)
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metrics(metrics['tr_p_acc_all'], step=epoch)
        mlflow.log_metrics(metrics['tr_p_euc_error_all'], step=epoch)

        # Validate every `val_interval` epochs; a non-positive interval disables validation.
        if args.val_interval > 0 and (epoch + 1) % args.val_interval == 0:
            # Wrap val_loader with tqdm for progress bar
            val_pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{args.num_epochs}")
            val_loss, val_metrics = validate_epoch(model, val_pbar, criterion, args)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # mlflow.get_artifact_uri() returns a URI (e.g. "file:///..."), not a
                # filesystem path, so it must not be handed to torch.save. Stage the
                # checkpoint in a temporary directory and let MLflow copy it into the
                # run's artifact store. The file name carries the validation loss
                # because the top-k pruning ranks checkpoints by it.
                model_filename = f"model_best_ep{epoch}_val_loss_{val_loss:.4f}.pth"
                with tempfile.TemporaryDirectory() as staging_dir:
                    staged_path = os.path.join(staging_dir, model_filename)
                    torch.save(model.state_dict(), staged_path)
                    mlflow.log_artifact(staged_path)

                # Keep only top K checkpoints
                top_k_checkpoints(args, mlflow.get_artifact_uri())

            print(f"[Validation] at Epoch {epoch+1}/{args.num_epochs}: Val Loss: {val_loss:.4f}")
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metrics(val_metrics['val_p_acc_all'], step=epoch)
            mlflow.log_metrics(val_metrics['val_p_euc_error_all'], step=epoch)
        # Print progress
        print(f"Epoch {epoch+1}/{args.num_epochs}: Train Loss: {train_loss:.4f}")

    return model

In [12]:
def train(model, train_loader, val_loader, criterion, optimizer, args):
    """Train for ``args.num_epochs`` epochs, validating and checkpointing along the way.

    Each epoch calls ``train_epoch`` and logs its loss and metrics to MLflow.
    Every ``args.val_interval`` epochs (when positive) ``validate_epoch`` runs;
    a new lowest validation loss stores the model's ``state_dict`` as an MLflow
    artifact and then calls ``top_k_checkpoints`` to prune older checkpoints.

    Args:
        model: network to optimise; forwarded to the epoch helpers.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches.
        criterion: loss callable forwarded to the epoch helpers.
        optimizer: optimizer forwarded to ``train_epoch``.
        args: config namespace providing ``num_epochs`` and ``val_interval``; it
            is also passed unchanged to the helpers.

    Returns:
        The model as returned by the final ``train_epoch`` call (the input model
        if no epoch ran).
    """
    best_val_loss = float("inf")

    # Training loop
    for epoch in range(args.num_epochs):
        # Wrap train_loader with tqdm for progress bar
        train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{args.num_epochs}")
        model, train_loss, metrics = train_epoch(model, train_pbar, criterion, optimizer, args)
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metrics(metrics['tr_p_acc_all'], step=epoch)
        mlflow.log_metrics(metrics['tr_p_euc_error_all'], step=epoch)

        if args.val_interval > 0 and (epoch + 1) % args.val_interval == 0:
            # Wrap val_loader with tqdm for progress bar
            val_pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{args.num_epochs}")
            val_loss, val_metrics = validate_epoch(model, val_pbar, criterion, args)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # The file name embeds the validation loss; the top-k pruning
                # below ranks checkpoints by parsing it back out.
                model_filename = f"model_best_ep{epoch}_val_loss_{val_loss:.4f}.pth"
                # Stage the checkpoint in a throwaway directory instead of the
                # working directory: MLflow copies it into the artifact store, and
                # the staged copy would otherwise pile up (one file per new best)
                # where the top-k pruning never looks.
                with tempfile.TemporaryDirectory() as staging_dir:
                    staged_path = os.path.join(staging_dir, model_filename)
                    torch.save(model.state_dict(), staged_path)
                    # Log it to MLflow
                    mlflow.log_artifact(staged_path)

                # Keep only top K checkpoints
                top_k_checkpoints(args, mlflow.get_artifact_uri())

            print(f"[Validation] at Epoch {epoch+1}/{args.num_epochs}: Val Loss: {val_loss:.4f}")
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            mlflow.log_metrics(val_metrics['val_p_acc_all'], step=epoch)
            mlflow.log_metrics(val_metrics['val_p_euc_error_all'], step=epoch)
        # Print progress
        print(f"Epoch {epoch+1}/{args.num_epochs}: Train Loss: {train_loss:.4f}")

    return model

In [13]:
from urllib.parse import urlparse
import os

def top_k_checkpoints(args, artifact_uri):
    """Delete all but the ``args.save_k_best`` lowest-validation-loss checkpoints.

    Only files named ``*val_loss_<float>.pth`` are ranked and possibly removed.
    Any other file in the directory (for example a last-epoch ``.pth`` without a
    loss in its name) is left untouched.

    Args:
        args: object exposing ``save_k_best``, the number of checkpoints to keep.
        artifact_uri: artifact directory, given either as a ``file://`` URI or
            as a plain local path.
    """
    parsed_uri = urlparse(artifact_uri)
    if parsed_uri.scheme == "file":
        # URI paths are percent-encoded (a space becomes %20); decode before touching the disk.
        artifact_path = unquote(parsed_uri.path)
        # "file:///C:/dir" parses to "/C:/dir"; drop the leading slash on Windows.
        if os.name == 'nt' and artifact_path.startswith('/'):
            artifact_path = artifact_path[1:]
    else:
        # Not a file URI: treat the argument as a local path. Going through
        # urlparse would misread a Windows drive letter ("C:") as a URI scheme.
        artifact_path = artifact_uri

    # Collect (validation loss, file name) for every checkpoint that encodes a loss.
    ranked = []
    for filename in os.listdir(artifact_path):
        if not filename.endswith(".pth") or "val_loss_" not in filename:
            continue
        loss_text = filename.split("val_loss_")[1][:-len(".pth")]
        try:
            ranked.append((float(loss_text), filename))
        except ValueError:
            # Name resembles a checkpoint but carries no parseable loss; leave it alone.
            continue

    # Best (lowest loss) first; everything past the first K entries is deleted.
    ranked.sort(key=lambda entry: entry[0])
    for _, filename in ranked[args.save_k_best:]:
        os.remove(os.path.join(artifact_path, filename))


## Start Training!

This is the major training loop including validation.

In [None]:
# Start MLflow run
with mlflow.start_run(run_name=args.run_name):
    # The script version of this pipeline also logs the training file itself
    # (mlflow.log_artifact(__file__)); that is not available inside a notebook.

    # Log all hyperparameters to MLflow
    mlflow.log_params(vars(args))

    # Also store the args as a JSON artifact. mlflow.get_artifact_uri() returns a
    # URI (e.g. "file:///..."), not a filesystem path, so it cannot be passed to
    # open() or torch.save(). Stage each file in a temporary directory and let
    # MLflow copy it into the artifact store.
    with tempfile.TemporaryDirectory() as staging_dir:
        args_path = os.path.join(staging_dir, "args.json")
        with open(args_path, 'w') as f:
            json.dump(vars(args), f)
        mlflow.log_artifact(args_path)

    # Train your model
    model = train(model, train_loader, val_loader, criterion, optimizer, args)

    # Save your model for the last epoch
    with tempfile.TemporaryDirectory() as staging_dir:
        model_path = os.path.join(staging_dir, f"model_last_epoch{args.num_epochs}.pth")
        torch.save(model.state_dict(), model_path)
        mlflow.log_artifact(model_path)


In [14]:
# Start MLflow run
with mlflow.start_run(run_name=args.run_name):
    # Log all hyperparameters to MLflow
    mlflow.log_params(vars(args))

    # Save args to a temporary file and log it as an artifact. The file lives in
    # a throwaway directory so nothing is left in (or overwritten in) the working
    # directory once MLflow has copied it.
    with tempfile.TemporaryDirectory() as staging_dir:
        args_path = os.path.join(staging_dir, "args.json")
        with open(args_path, 'w') as f:
            json.dump(vars(args), f)
        mlflow.log_artifact(args_path)

    # Train your model
    model = train(model, train_loader, val_loader, criterion, optimizer, args)

    # Save model state_dict to a temporary file and log it
    with tempfile.TemporaryDirectory() as staging_dir:
        model_path = os.path.join(staging_dir, f"model_last_epoch{args.num_epochs}.pth")
        torch.save(model.state_dict(), model_path)
        mlflow.log_artifact(model_path)


The git executable must be specified in one of the following ways:
    - be included in your $PATH
    - be set via $GIT_PYTHON_GIT_EXECUTABLE
    - explicitly set via git.refresh(<full-path-to-git-executable>)

All git commands will error until this is rectified.

This initial message can be silenced or aggravated in the future by setting the
$GIT_PYTHON_REFRESH environment variable. Use one of the following values:
    - quiet|q|silence|s|silent|none|n|0: for no message or exception
    - error|e|exception|raise|r|2: for a raised exception

Example:
    export GIT_PYTHON_REFRESH=quiet

Training Epoch 1/20: 100%|██████████| 81/81 [11:07<00:00,  8.24s/it, loss=0.0181] 


Epoch 1/20: Train Loss: 0.1559


Training Epoch 2/20: 100%|██████████| 81/81 [09:18<00:00,  6.89s/it, loss=0.00844]
Validation Epoch 2/20: 100%|██████████| 15/15 [00:43<00:00,  2.88s/it, loss=0.00672]


[Validation] at Epoch 2/20: Val Loss: 0.0140
Epoch 2/20: Train Loss: 0.0088


Training Epoch 3/20: 100%|██████████| 81/81 [09:12<00:00,  6.82s/it, loss=0.00696]


Epoch 3/20: Train Loss: 0.0074


Training Epoch 4/20: 100%|██████████| 81/81 [09:15<00:00,  6.85s/it, loss=0.00324]
Validation Epoch 4/20: 100%|██████████| 15/15 [00:45<00:00,  3.03s/it, loss=0.014]  


[Validation] at Epoch 4/20: Val Loss: 0.0145
Epoch 4/20: Train Loss: 0.0063


Training Epoch 5/20: 100%|██████████| 81/81 [08:39<00:00,  6.42s/it, loss=0.0078] 


Epoch 5/20: Train Loss: 0.0054


Training Epoch 6/20: 100%|██████████| 81/81 [09:35<00:00,  7.10s/it, loss=0.00452]
Validation Epoch 6/20: 100%|██████████| 15/15 [00:46<00:00,  3.11s/it, loss=0.00963]


[Validation] at Epoch 6/20: Val Loss: 0.0145
Epoch 6/20: Train Loss: 0.0051


Training Epoch 7/20: 100%|██████████| 81/81 [09:23<00:00,  6.96s/it, loss=0.0024] 


Epoch 7/20: Train Loss: 0.0045


Training Epoch 8/20: 100%|██████████| 81/81 [09:00<00:00,  6.67s/it, loss=0.00215]
Validation Epoch 8/20: 100%|██████████| 15/15 [00:42<00:00,  2.86s/it, loss=0.00916]


[Validation] at Epoch 8/20: Val Loss: 0.0138
Epoch 8/20: Train Loss: 0.0041


Training Epoch 9/20: 100%|██████████| 81/81 [09:24<00:00,  6.97s/it, loss=0.00188]


Epoch 9/20: Train Loss: 0.0037


Training Epoch 10/20: 100%|██████████| 81/81 [08:54<00:00,  6.60s/it, loss=0.00562]
Validation Epoch 10/20: 100%|██████████| 15/15 [00:40<00:00,  2.67s/it, loss=0.00944]


[Validation] at Epoch 10/20: Val Loss: 0.0145
Epoch 10/20: Train Loss: 0.0035


Training Epoch 11/20: 100%|██████████| 81/81 [08:32<00:00,  6.33s/it, loss=0.00459]


Epoch 11/20: Train Loss: 0.0034


Training Epoch 12/20: 100%|██████████| 81/81 [08:48<00:00,  6.52s/it, loss=0.00236]
Validation Epoch 12/20: 100%|██████████| 15/15 [00:41<00:00,  2.76s/it, loss=0.00728]


[Validation] at Epoch 12/20: Val Loss: 0.0147
Epoch 12/20: Train Loss: 0.0029


Training Epoch 13/20: 100%|██████████| 81/81 [09:28<00:00,  7.02s/it, loss=0.00348]


Epoch 13/20: Train Loss: 0.0028


Training Epoch 14/20: 100%|██████████| 81/81 [08:43<00:00,  6.46s/it, loss=0.00233]
Validation Epoch 14/20: 100%|██████████| 15/15 [00:40<00:00,  2.70s/it, loss=0.0046] 


[Validation] at Epoch 14/20: Val Loss: 0.0142
Epoch 14/20: Train Loss: 0.0025


Training Epoch 15/20: 100%|██████████| 81/81 [08:35<00:00,  6.36s/it, loss=0.00257]


Epoch 15/20: Train Loss: 0.0024


Training Epoch 16/20: 100%|██████████| 81/81 [08:31<00:00,  6.32s/it, loss=0.00166]
Validation Epoch 16/20: 100%|██████████| 15/15 [00:40<00:00,  2.70s/it, loss=0.00536]


[Validation] at Epoch 16/20: Val Loss: 0.0134
Epoch 16/20: Train Loss: 0.0021


Training Epoch 17/20: 100%|██████████| 81/81 [09:21<00:00,  6.93s/it, loss=0.00129]


Epoch 17/20: Train Loss: 0.0020


Training Epoch 18/20: 100%|██████████| 81/81 [09:24<00:00,  6.96s/it, loss=0.00134]
Validation Epoch 18/20: 100%|██████████| 15/15 [00:42<00:00,  2.83s/it, loss=0.00648]


[Validation] at Epoch 18/20: Val Loss: 0.0145
Epoch 18/20: Train Loss: 0.0019


Training Epoch 19/20: 100%|██████████| 81/81 [09:08<00:00,  6.77s/it, loss=0.00158]


Epoch 19/20: Train Loss: 0.0019


Training Epoch 20/20: 100%|██████████| 81/81 [09:21<00:00,  6.93s/it, loss=0.00129] 
Validation Epoch 20/20: 100%|██████████| 15/15 [00:42<00:00,  2.83s/it, loss=0.00507]


[Validation] at Epoch 20/20: Val Loss: 0.0142
Epoch 20/20: Train Loss: 0.0016


In [None]:
# Start MLflow run
# NOTE(review): this cell repeats the executed training cell above. Running it
# starts a new MLflow run and calls train() again on the same `model` and
# `optimizer` objects, i.e. it continues from the already-trained weights.
with mlflow.start_run(run_name=args.run_name):
    # Log all hyperparameters to MLflow
    mlflow.log_params(vars(args))

    # Save args to a temporary file and log it as an artifact. Staging in a
    # throwaway directory keeps the working directory free of leftover files.
    with tempfile.TemporaryDirectory() as staging_dir:
        args_path = os.path.join(staging_dir, "args.json")
        with open(args_path, 'w') as f:
            json.dump(vars(args), f)
        mlflow.log_artifact(args_path)

    # Train your model
    model = train(model, train_loader, val_loader, criterion, optimizer, args)

    # Save model state_dict to a temporary file and log it
    with tempfile.TemporaryDirectory() as staging_dir:
        model_path = os.path.join(staging_dir, f"model_last_epoch{args.num_epochs}.pth")
        torch.save(model.state_dict(), model_path)
        mlflow.log_artifact(model_path)