# Tutorial on using the training pipeline for the event-based eye tracking challenge.

In [1]:
import argparse, json, os, mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model.BaselineEyeTrackingModel import CNN_GRU
from utils.training_utils import train_epoch, validate_epoch, top_k_checkpoints
from utils.metrics import weighted_MSELoss
from dataset import ThreeETplus_Eyetracking, ScaleLabel, NormalizeLabel, \
    TemporalSubsample, NormalizeLabel, SliceLongEventsToShort, \
    EventSlicesToVoxelGrid, SliceByTimeEventsTargets
import tonic.transforms as transforms
from tonic import SlicedDataset, DiskCachedDataset

  from .autonotebook import tqdm as notebook_tqdm


#### Example config file

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Channel width of each of the four backbone stages, keyed by model size
# ("T" is the default model_type used by the model classes below).
CHANNELS = dict(
    T=[32, 64, 128, 256],
    S=[48, 96, 192, 384],
    B=[64, 128, 256, 512],
)

# Per-stage convolution kernel size and stride (stage 0 first).
KERNELS = [7, 3, 3, 3]
STRIDES = [4, 2, 2, 2]


class BlockSelfAttention(nn.Module):
    """Self-attention restricted to non-overlapping P x P spatial windows.

    The (B, C, H, W) input is cut into (H/P) * (W/P) windows of P*P tokens.
    Each token's feature vector is its C channel values, and every token
    attends only to the tokens of its own window. Q, K and V are per-token
    linear projections over the channel dimension.

    Args:
        channels: number of input/output channels C.
        window_size: window side length P; H and W must be divisible by it.
    """

    def __init__(self, channels, window_size):
        super(BlockSelfAttention, self).__init__()
        self.channels = channels
        self.window_size = window_size  # P

        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

    def forward(self, x):
        """Apply windowed self-attention.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W).

        Raises:
            ValueError: if H or W is not divisible by the window size.
        """
        B, C, H, W = x.size()
        P = self.window_size

        if H % P != 0 or W % P != 0:
            raise ValueError("The height and width of the input must be divisible by the window size.")
        n_h, n_w = H // P, W // P

        # unfold -> (B, C, n_h, n_w, P, P), where x[b, c, i*P + p, j*P + q]
        # lands at [b, c, i, j, p, q]. Channels must be moved LAST before
        # flattening; a plain view would splice channel and spatial values
        # into the same token vector.
        windows = x.unfold(2, P, P).unfold(3, P, P)
        windows = windows.permute(0, 2, 3, 4, 5, 1).reshape(B, n_h * n_w, P * P, C)

        Q = self.query(windows)
        K = self.key(windows)
        V = self.value(windows)

        # (B, num_windows, P*P, P*P) scores; softmax over the keys of each window.
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (C ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # (B, num_windows, P*P, C)

        # Exact inverse of the partition above:
        # [b, i, j, p, q, c] -> [b, c, i, p, j, q] -> (B, C, H, W).
        attn_output = attn_output.view(B, n_h, n_w, P, P, C).permute(0, 5, 1, 3, 2, 4)
        return attn_output.reshape(B, C, H, W)

class GridAttention(nn.Module):
    """Dilated ("grid") self-attention over a G x G partition of offsets.

    Pixel (row, col) of the (B, C, H, W) input belongs to group
    (row % G, col % G). Each of the G*G groups is therefore a sparse grid of
    (H/G) * (W/G) pixels spanning the whole image. Tokens (C-dim channel
    vectors) attend to every other token in their own group.

    Args:
        channels: number of input/output channels C.
        grid_size: G; H and W must be divisible by it.
    """

    def __init__(self, channels, grid_size):
        super(GridAttention, self).__init__()
        self.channels = channels
        self.grid_size = grid_size  # G

        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

    def forward(self, x):
        """Apply grid self-attention.

        Args:
            x: tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W).

        Raises:
            ValueError: if H or W is not divisible by the grid size.
        """
        B, C, H, W = x.size()
        G = self.grid_size

        if H % G != 0 or W % G != 0:
            raise ValueError("The height and width of the input must be divisible by the grid size.")
        h, w = H // G, W // G

        # Row index = hi * G + gi and column index = wi * G + gj. Move the group
        # offsets (gi, gj) to the front and channels LAST, so each token vector
        # is exactly one pixel's channels: (B, G*G, h*w, C).
        groups = x.reshape(B, C, h, G, w, G).permute(0, 3, 5, 2, 4, 1).reshape(B, G * G, h * w, C)

        Q = self.query(groups)
        K = self.key(groups)
        V = self.value(groups)

        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.channels ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # (B, G*G, h*w, C)

        # Exact inverse: [b, gi, gj, hi, wi, c] -> [b, c, hi, gi, wi, gj] -> (B, C, H, W).
        attn_output = attn_output.view(B, G, G, h, w, C).permute(0, 5, 3, 1, 4, 2)
        return attn_output.reshape(B, C, H, W)


class RVTBlock(nn.Module):
    """One stage of the recurrent backbone.

    Pipeline (every step after the conv is residual):
    strided conv -> block self-attention + LayerNorm -> MLP
    -> grid self-attention -> MLP -> per-pixel LSTM.

    The LayerNorm, MLPs and LSTM act on the CHANNEL dimension of the
    (B, C, H, W) feature map. Channels are moved last before those layers and
    moved back afterwards.

    Args:
        stage: stage index used to look up CHANNELS / KERNELS / STRIDES.
        n_time_bins: input channel count; required for stage 0. Later stages
            take CHANNELS[model_type][stage - 1] input channels.
        model_type: key into CHANNELS.
        window_size: window side for BlockSelfAttention (default 7). The conv
            output's H and W must be divisible by it.
        grid_size: grid size for GridAttention (default 4), with the same
            divisibility requirement.

    Raises:
        ValueError: if stage == 0 and n_time_bins is None.
    """

    def __init__(self, stage, n_time_bins=None, model_type="T", window_size=7, grid_size=4):
        super().__init__()
        self.stage = stage
        self.model_type = model_type
        if stage == 0 and n_time_bins is None:
            raise ValueError("n_time_bins must be provided for stage 0")

        # Stage 0 consumes the input frames; later stages consume the previous stage's channels.
        input_channels = CHANNELS[model_type][stage - 1] if stage > 0 else n_time_bins
        output_channels = CHANNELS[model_type][stage]
        kernel_size = KERNELS[stage]
        stride = STRIDES[stage]

        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size,
                              stride=stride, padding=kernel_size // 2)

        # Block SA
        self.block_sa = BlockSelfAttention(channels=output_channels, window_size=window_size)

        # LayerNorm over channels (applied channels-last in forward)
        self.ln = nn.LayerNorm(output_channels)

        # MLP
        self.mlp1 = nn.Sequential(
            nn.Linear(output_channels, output_channels),
            nn.ReLU(),
            nn.Linear(output_channels, output_channels)
        )

        # Grid SA
        self.grid_sa = GridAttention(channels=output_channels, grid_size=grid_size)

        # MLP 2
        self.mlp2 = nn.Sequential(
            nn.Linear(output_channels, output_channels),
            nn.ReLU(),
            nn.Linear(output_channels, output_channels)
        )

        # LSTM, run independently at every pixel (one step per forward call)
        self.lstm = nn.LSTM(input_size=output_channels, hidden_size=output_channels, num_layers=1, batch_first=True)

    @staticmethod
    def _on_channels(layer, x):
        """Apply `layer` (which acts on the last dim) along the channel dim of a (B, C, H, W) tensor."""
        return layer(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x, c, h):
        """Run one timestep of this stage.

        Args:
            x: input feature map, (B, C_in, H_in, W_in).
            c: LSTM cell state. Either (1, B, C_out) per-sample (e.g. the zero
               state a caller initializes), which is broadcast to every pixel,
               or (1, B * H * W, C_out) as returned by a previous call.
            h: LSTM hidden state, same shape convention as `c`.

        Returns:
            (y, c, h): y is (B, C_out, H, W); c and h are the updated
            per-pixel states, each (1, B * H * W, C_out).

        Raises:
            ValueError: from the attention layers if the conv output's H or W
            is not divisible by the window / grid size.
        """
        x_conv = self.conv(x)

        x_bsa = self._on_channels(self.ln, self.block_sa(x_conv)) + x_conv
        x_mlp1 = self._on_channels(self.mlp1, x_bsa) + x_bsa
        x_gsa = self.grid_sa(x_mlp1) + x_mlp1
        x_mlp2 = self._on_channels(self.mlp2, x_gsa) + x_gsa

        # Per-pixel LSTM: every spatial location is an independent batch entry
        # with a length-1 sequence, so the recurrence runs across calls.
        B, C, H, W = x_mlp2.shape
        n_pix = H * W
        tokens = x_mlp2.permute(0, 2, 3, 1).reshape(B * n_pix, 1, C)
        if h.size(1) == B:
            # Per-sample state: give each pixel of sample b a copy of b's state
            # (pixel rows are laid out as b * n_pix + y * W + x).
            h = h.repeat_interleave(n_pix, dim=1)
            c = c.repeat_interleave(n_pix, dim=1)
        # nn.LSTM takes and returns its state in (h, c) order.
        out, (h, c) = self.lstm(tokens, (h, c))
        x_lstm = out.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        return x_lstm, c, h
    
class EventBasedObjectDetectionModel(nn.Module):
    """Four RVTBlock stages followed by a linear head predicting 2 values per timestep.

    Args:
        n_time_bins: channel count of each input frame (fed to stage 0).
        model_type: key into CHANNELS selecting the stage widths.
    """

    def __init__(self, n_time_bins, model_type='T'):
        super().__init__()
        self.n_time_bins = n_time_bins
        self.model_type = model_type

        # Define the RVT stages
        self.stages = nn.ModuleList([
            RVTBlock(stage=i, n_time_bins=n_time_bins if i == 0 else None, model_type=model_type)
            for i in range(4)
        ])

        # Head on the last stage's channel width (e.g. 256 for 'T') -> 2 outputs.
        self.output_layer = nn.Linear(CHANNELS[model_type][-1], 2)

    def forward(self, x):
        """Run the backbone over every timestep and predict 2 values per step.

        Args:
            x: 5-D tensor (B, T, C, H, W); each x[:, t] is fed to stage 0's
               Conv2d, so C should equal n_time_bins.

        Returns:
            Tensor of shape (B, T, 2).
        """
        batch_size, n_steps = x.size(0), x.size(1)

        # Zero per-sample LSTM state for every stage; each stage returns its updated state.
        lstm_states = [
            (torch.zeros(1, batch_size, ch, device=x.device),
             torch.zeros(1, batch_size, ch, device=x.device))
            for ch in CHANNELS[self.model_type][:len(self.stages)]
        ]

        if n_steps == 0:
            # torch.stack would fail on an empty list; return an empty sequence instead.
            return x.new_zeros(batch_size, 0, self.output_layer.out_features)

        outputs = []
        for t in range(n_steps):
            xt = x[:, t, :, :, :]
            for i, stage in enumerate(self.stages):
                xt, c, h = stage(xt, *lstm_states[i])
                lstm_states[i] = (c, h)

            # The last stage yields a (B, C, H, W) map; indexing xt[:, -1] would
            # select a channel, not a feature vector. Global-average-pool H and W
            # to get the (B, C) input the linear head expects.
            pooled = xt.mean(dim=(-2, -1))
            outputs.append(self.output_layer(pooled))

        return torch.stack(outputs, dim=1)  # (batch_size, seq_len, 2)

In [3]:
config_file = 'sliced_baseline.json'
# Read the JSON config and expose its keys as attributes (args.lr, args.device, ...).
config_path = os.path.join('./configs', config_file)
with open(config_path, 'r') as config_fp:
    config = json.load(config_fp)
args = argparse.Namespace(**config)

#### Set up the MLflow tracking server (local)

In [5]:
mlflow.set_tracking_uri(args.mlflow_path)
mlflow.set_experiment(experiment_name=args.experiment_name)

# Define your model, optimizer, and criterion.
# The input channel count comes from the config so it always matches the
# n_time_bins used by EventSlicesToVoxelGrid below (it was hard-coded to 3).
model = EventBasedObjectDetectionModel(n_time_bins=args.n_time_bins, model_type='T').to(args.device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

if args.loss == "mse":
    criterion = nn.MSELoss()
elif args.loss == "weighted_mse":
    # Per-coordinate weights (sensor_width / sensor_height, 1) -- presumably (x, y); see weighted_MSELoss.
    criterion = weighted_MSELoss(weights=torch.tensor((args.sensor_width/args.sensor_height, 1)).to(args.device), \
                                    reduction='mean')
else:
    raise ValueError("Invalid loss name")

print("Model parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

factor = args.spatial_factor # spatial downsample factor
# Labels arrive every 10000 us (100Hz); this keeps one in temp_subsample_factor (e.g. 5 -> 20Hz).
temp_subsample_factor = args.temporal_subsample_factor

# The original labels are spatially downsampled with 'factor', temporally subsampled, and normalized w.r.t width and height to [0,1]
label_transform = transforms.Compose([
    ScaleLabel(factor),
    TemporalSubsample(temp_subsample_factor),
    NormalizeLabel(pseudo_width=640*factor, pseudo_height=480*factor)
])

train_data_orig = ThreeETplus_Eyetracking(save_to=args.data_dir, split="train", \
                transform=transforms.Downsample(spatial_factor=factor), 
                target_transform=label_transform)
val_data_orig = ThreeETplus_Eyetracking(save_to=args.data_dir, split="val", \
                transform=transforms.Downsample(spatial_factor=factor),
                target_transform=label_transform)

slicing_time_window = args.train_length*int(10000/temp_subsample_factor) #microseconds
train_stride_time = int(10000/temp_subsample_factor*args.train_stride) #microseconds

train_slicer=SliceByTimeEventsTargets(slicing_time_window, overlap=slicing_time_window-train_stride_time, \
                seq_length=args.train_length, seq_stride=args.train_stride, include_incomplete=False)
# the validation set is sliced to non-overlapping sequences
val_slicer=SliceByTimeEventsTargets(slicing_time_window, overlap=0, \
                seq_length=args.val_length, seq_stride=args.val_stride, include_incomplete=False)

post_slicer_transform = transforms.Compose([
    SliceLongEventsToShort(time_window=int(10000/temp_subsample_factor), overlap=0, include_incomplete=True),
    EventSlicesToVoxelGrid(sensor_size=(int(640*factor), int(480*factor), 2), \
                            n_time_bins=args.n_time_bins, per_channel_normalize=args.voxel_grid_ch_normaization)
])

train_data = SlicedDataset(train_data_orig, train_slicer, transform=post_slicer_transform, metadata_path=f"./metadata/3et_train_tl_{args.train_length}_ts{args.train_stride}_ch{args.n_time_bins}")
val_data = SlicedDataset(val_data_orig, val_slicer, transform=post_slicer_transform, metadata_path=f"./metadata/3et_val_vl_{args.val_length}_vs{args.val_stride}_ch{args.n_time_bins}")

train_data = DiskCachedDataset(train_data, cache_path=f'./cached_dataset/train_tl_{args.train_length}_ts{args.train_stride}_ch{args.n_time_bins}')
val_data = DiskCachedDataset(val_data, cache_path=f'./cached_dataset/val_vl_{args.val_length}_vs{args.val_stride}_ch{args.n_time_bins}')

# Leave two cores for the main process. os.cpu_count() may return None and
# DataLoader rejects a negative num_workers, so clamp at 0 (load in-process).
num_workers = max(0, (os.cpu_count() or 1) - 2)
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, \
                            num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False, \
                        num_workers=num_workers)
def train(model, train_loader, val_loader, criterion, optimizer, args):
    """Train `model` for `args.num_epochs` epochs, logging to the active MLflow run.

    Every epoch runs `train_epoch` and logs its loss and metrics. When
    `args.val_interval` is positive, every `args.val_interval`-th epoch also
    runs `validate_epoch`. Whenever the validation loss beats the best seen
    so far, the state_dict is saved to the run's artifact directory and
    `top_k_checkpoints` is called on that directory.

    Returns:
        The model as returned by the last `train_epoch` call.
    """
    best_val_loss = float("inf")

    for epoch in range(args.num_epochs):
        model, train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer, args)
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        for metric_key in ("tr_p_acc_all", "tr_p_error_all"):
            mlflow.log_metrics(train_metrics[metric_key], step=epoch)

        run_validation = args.val_interval > 0 and (epoch + 1) % args.val_interval == 0
        if run_validation:
            val_loss, val_metrics = validate_epoch(model, val_loader, criterion, args)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                artifact_dir = mlflow.get_artifact_uri()
                # Checkpoint name embeds the epoch and the validation loss to 4 decimal places.
                checkpoint_name = f"model_best_ep{epoch}_val_loss_{val_loss:.4f}.pth"
                torch.save(model.state_dict(), os.path.join(artifact_dir, checkpoint_name))
                # DANGER zone: top_k_checkpoints deletes checkpoint files in the
                # MLflow artifact directory (presumably keeping only the best k).
                top_k_checkpoints(args, artifact_dir)

            print(f"[Validation] at Epoch {epoch+1}/{args.num_epochs}: Val Loss: {val_loss:.4f}")
            mlflow.log_metric("val_loss", val_loss, step=epoch)
            for metric_key in ("val_p_acc_all", "val_p_error_all"):
                mlflow.log_metrics(val_metrics[metric_key], step=epoch)

        print(f"Epoch {epoch+1}/{args.num_epochs}: Train Loss: {train_loss:.4f}")

    return model
# Start MLflow run
with mlflow.start_run(run_name=args.run_name):
    run_artifact_dir = mlflow.get_artifact_uri()

    # Logging this source file as an artifact (mlflow.log_artifact(__file__))
    # is disabled in the notebook version; the script version includes it.

    # Record every hyperparameter as an MLflow param and as args.json in the artifact dir.
    hyperparams = vars(args)
    mlflow.log_params(hyperparams)
    with open(os.path.join(run_artifact_dir, "args.json"), 'w') as args_file:
        json.dump(hyperparams, args_file)

    model = train(model, train_loader, val_loader, criterion, optimizer, args)

    # Save the weights from the final epoch.
    last_checkpoint_name = f"model_last_epoch{args.num_epochs}.pth"
    torch.save(model.state_dict(), os.path.join(run_artifact_dir, last_checkpoint_name))


Model parameters:  1969090
Metadata read from ./metadata/3et_train_tl_30_ts15_ch3/slice_metadata.h5.
Metadata read from ./metadata/3et_val_vl_30_vs30_ch3/slice_metadata.h5.
