## Setup

In [1]:
import gc
import re
import os
import json 
import math
import shutil
import random
import warnings
from os.path import join
from functools import partial
from tqdm.notebook import tqdm
from collections import defaultdict
from operator import methodcaller
from typing import Optional, Literal
from typing import Optional, Literal, Iterator
from itertools import pairwise, starmap, product

import torch
import optuna
import kagglehub 
import numpy as np
import pandas as pd
import polars as pl
from numpy import ndarray
from torch import nn, Tensor
from numpy.linalg import norm
import torch.nn.functional as F
from torch.optim import Optimizer
from pandas import DataFrame as DF
from optuna.trial import TrialState
from sklearn.metrics import f1_score
from optuna.pruners import BasePruner
from optuna.exceptions import TrialPruned
from torch.utils.data import TensorDataset
from scipy.spatial.transform import Rotation
import kaggle_evaluation.cmi_inference_server
from torch.utils.data import DataLoader as DL
from sklearn.model_selection import GroupKFold
from rich.progress import Progress, Task, track
from sklearn.model_selection import train_test_split
from numpy.lib.stride_tricks import sliding_window_view
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ConstantLR, LRScheduler, _LRScheduler

from config import *
from model import mk_model
from utils import seed_everything
from preprocessing import get_meta_data
from training import split_dataset, move_cmi_dataset

In [2]:
# Names of per-expert statistics, presumably the inputs intended for the gating model.
# NOTE(review): this list is not referenced by any other cell visible in this notebook.
GATING_INPUT_FEATURES = [
    "bin_mae",
    "reg_mae",
    "y_uncertainty",
    "bin_uncertainty",
    "reg_uncertainty",
    "orient_uncertainty",
]
# Mini-batch size of the shuffled DataLoader used to train the gating model.
GATING_MODEL_BATCH_SIZE = 256
# Number of training epochs run when fitting the gating model.
N_GATING_MODEL_EPOCHS = 5

In [3]:
# Call the project's seeding helper (from utils) with the configured SEED;
# presumably this seeds every RNG in use so that runs are repeatable.
seed_everything(SEED)

In [4]:
# Project metadata; record_model_outputs below reads its "tof_idx" and "thm_idx" entries.
meta_data = get_meta_data()

## Data

In [13]:
def record_model_outputs(model:nn.Module, data_loader:DL, device:torch.device) -> tuple[Tensor, ...]:
    """
    Run `model` in eval mode over every batch of `data_loader` and gather its outputs.

    Before each forward pass, a copy of the batch input has the features indexed by
    `meta_data["tof_idx"]` and `meta_data["thm_idx"]` (along dim 1) zeroed for its
    leading rows. The number of zeroed rows is half of the loader's nominal batch
    size, so a shorter final batch has the same leading-row count masked as every
    other batch.

    Args:
        model: network returning a sequence of tensors per batch.
        data_loader: loader whose batches start with the model input; any further
            items in a batch are ignored.
        device: device the input is moved to before the forward pass.

    Returns:
        One tensor per model output, each the concatenation (along dim 0) of that
        output over all batches. Empty tuple if the loader yields no batches.
    """
    model = model.eval()
    tof_and_thm_idx = np.concatenate((meta_data["tof_idx"], meta_data["thm_idx"]))
    # Derive the masked-row count from the loader instead of repeating the
    # caller's batch size as a literal; `batch_size` is None only when the
    # loader was built with a custom batch_sampler.
    nominal_batch_size = data_loader.batch_size
    batches_outputs: list[tuple[Tensor, ...]] = []
    with torch.no_grad():
        for x, *_ in data_loader:
            # Clone so the zeroing below never writes into the dataset's own storage.
            x = x.to(device).clone()
            if nominal_batch_size is not None:
                n_masked = nominal_batch_size // 2
            else:
                n_masked = x.shape[0] // 2
            x[:n_masked, tof_and_thm_idx] = 0.0
            batches_outputs.append(model(x))
    # Transpose "list of per-batch output tuples" into "per-output lists of batches".
    return tuple(map(torch.concat, zip(*batches_outputs)))

def mk_gating_model_dataset(dataset: TensorDataset) -> TensorDataset:
    """
    Build the gating model's dataset from the outputs of every fold model.

    Each of the N_FOLDS checkpoints under "models" is loaded into a fresh model
    and run over `dataset` (unshuffled, so rows stay aligned across folds). For
    every model output, the per-fold results are stacked along a new dim 1.

    Returns:
        TensorDataset holding the stacked model outputs, followed by the last two
        tensors of `dataset`, followed by `dataset.tensors[1]` (the consumers in
        this notebook unpack that final tensor as the target).
    """
    cuda = torch.device("cuda")
    loader = DL(dataset, batch_size=1024, shuffle=False)
    per_fold_outputs: list[tuple[Tensor]] = []
    for fold_idx in tqdm(range(N_FOLDS), total=N_FOLDS):
        checkpoint_path = join("models", f"model_fold_{fold_idx}.pth")
        fold_model = mk_model(device=cuda)
        state_dict = torch.load(checkpoint_path, map_location=cuda, weights_only=True)
        fold_model.load_state_dict(state_dict)
        per_fold_outputs.append(record_model_outputs(fold_model, loader, cuda))

    # Regroup by output kind, then stack the folds of each kind along dim 1.
    stack_folds = partial(torch.stack, dim=1)
    stacked_outputs = tuple(stack_folds(outputs) for outputs in zip(*per_fold_outputs))

    return TensorDataset(*stacked_outputs, *dataset.tensors[-2:], dataset.tensors[1])

def mk_gating_model_dataset_splits() -> dict[str, TensorDataset]:
    """
    Create one gating-model dataset per split returned by `split_dataset()`.

    Every split's dataset is first moved to CUDA (the second element of each
    split's value is discarded); only then are the fold models run over them.

    Returns:
        Mapping from split name to its gating-model TensorDataset.
    """
    cuda = torch.device("cuda")
    cuda_splits = {}
    for split_name, (cmi_dataset, _) in split_dataset().items():
        cuda_splits[split_name] = move_cmi_dataset(cmi_dataset, cuda)

    gating_splits = {}
    for split_name, cuda_dataset in cuda_splits.items():
        gating_splits[split_name] = mk_gating_model_dataset(cuda_dataset)
    return gating_splits

In [14]:
# Gating datasets keyed by split name; the cells below read the
# "gating_train" and "validation" entries. Runs every fold model over every split.
dataset_splits = mk_gating_model_dataset_splits()

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

## Training

In [15]:
def evaluate_gating_model(dataset: Dataset, gating_model: nn.Module) -> dict:
    """
    Score `gating_model` on `dataset`.

    Each dataset row is unpacked as `(*gating_inputs, y)`. Predicted and true
    classes are the argmax over dim 1 of the model output and of `y`.

    Returns:
        dict with, in this order:
        - "accuracy": fraction of rows whose predicted class equals the true class.
        - "binary_f1": F1 of "class is in BFRB_INDICES" vs. not.
        - "macro_f1": macro F1 after mapping every class outside BFRB_INDICES
          to the single class id `len(BFRB_GESTURES)`.
        - "final_metric": mean of "binary_f1" and "macro_f1".
    """
    gating_model = gating_model.eval()
    loader = DL(dataset, batch_size=1024, shuffle=False)
    batch_preds, batch_targets = [], []
    with torch.no_grad():
        for *gating_inputs, y in loader:
            batch_preds.append(gating_model(*gating_inputs))
            batch_targets.append(y)

    def to_class_ids(batches: list) -> ndarray:
        # Concatenate the batches and reduce the class axis to integer class ids.
        return torch.argmax(torch.concat(batches), dim=1).cpu().numpy()

    y_pred = to_class_ids(batch_preds)
    y_true = to_class_ids(batch_targets)

    true_is_bfrb = np.isin(y_true, BFRB_INDICES)
    pred_is_bfrb = np.isin(y_pred, BFRB_INDICES)

    # All non-BFRB gestures share one class id for the macro F1.
    non_bfrb_class = len(BFRB_GESTURES)
    collapsed_true = np.where(true_is_bfrb, y_true, non_bfrb_class)
    collapsed_pred = np.where(pred_is_bfrb, y_pred, non_bfrb_class)

    accuracy = (y_pred == y_true).mean().item()
    binary_f1 = f1_score(true_is_bfrb.astype(int), pred_is_bfrb.astype(int))
    macro_f1 = f1_score(collapsed_true, collapsed_pred, average='macro')

    return {
        "accuracy": accuracy,
        "binary_f1": binary_f1,
        "macro_f1": macro_f1,
        "final_metric": (binary_f1 + macro_f1) / 2,
    }

def train_model_on_single_epoch(data_loader: DL, gating_model: nn.Module, criterion: nn.Module, optimizer: Optimizer) -> dict:
    """
    Train `gating_model` for one pass over `data_loader`.

    Each batch is unpacked as `(*gating_inputs, y_true)`; the model is called with
    the unpacked inputs and `criterion(y_pred, y_true)` is back-propagated.

    Args:
        data_loader: loader yielding the gating inputs followed by the target.
        gating_model: model to train; switched to train mode.
        criterion: loss taking `(y_pred, y_true)`.
        optimizer: optimizer stepping the model's parameters.

    Returns:
        {"train_loss": per-sample average of the batch losses}, where each batch
        loss is weighted by the batch's row count.

    Raises:
        ValueError: if `data_loader` yields no samples, since no average loss exists.
    """
    gating_model = gating_model.train()
    loss_sum = 0.0
    n_samples = 0
    for *gating_inputs, y_true in data_loader:
        optimizer.zero_grad()
        y_pred = gating_model(*gating_inputs)
        loss = criterion(y_pred, y_true)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a shorter final batch does not skew the average.
        batch_size = gating_inputs[0].shape[0]
        n_samples += batch_size
        loss_sum += loss.item() * batch_size

    if n_samples == 0:
        # Previously surfaced as a bare ZeroDivisionError from `0.0 / 0`.
        raise ValueError("data_loader yielded no samples; cannot compute an average training loss")

    return {"train_loss": loss_sum / n_samples}

def train_model_on_all_epochs(dataset: Dataset, gating_model: nn.Module) -> DF:
    """
    Fit `gating_model` on `dataset` for N_GATING_MODEL_EPOCHS epochs.

    Uses AdamW with its default hyper-parameters, cross-entropy loss and a
    shuffled loader of GATING_MODEL_BATCH_SIZE rows. After every epoch the model
    is evaluated on the same `dataset` it is trained on.

    Returns:
        DataFrame with one row per epoch: the epoch index, the training metrics
        and the evaluation metrics of that epoch.
    """
    loader = DL(dataset, GATING_MODEL_BATCH_SIZE, shuffle=True)
    optimizer = torch.optim.AdamW(gating_model.parameters())
    criterion = nn.CrossEntropyLoss()

    history: list[dict] = []
    for epoch in range(N_GATING_MODEL_EPOCHS):
        epoch_record = {"epoch": epoch}
        epoch_record |= train_model_on_single_epoch(loader, gating_model, criterion, optimizer)
        epoch_record |= evaluate_gating_model(dataset, gating_model)
        history.append(epoch_record)

    return DF.from_records(history)

In [16]:
class MeanGate(nn.Module):
    """
    Parameter-free baseline gate: averages the experts' predictions uniformly.

    It accepts the same forward arguments as the learned gate so both can be
    evaluated through the same code path, but only `y_preds` is used.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """
        Args:
            device: optional device stored on the instance; not used by `forward`.
                Defaults to None so that `MeanGate()` keeps working.
        """
        # The constructor was previously misspelled `___init__` (three leading
        # underscores), so it never ran and nn.Module was never told about it.
        super().__init__()
        self.device = device

    def forward(self,
            y_preds: Tensor,
            orient_preds: Tensor,
            bin_demos_y_preds: Tensor,
            reg_demos_y_preds: Tensor,
            bin_demos_y_true: Tensor,
            reg_demos_y_true: Tensor
        ) -> Tensor:
        """Return the mean of `y_preds` over dim 1; every other argument is ignored."""
        return y_preds.mean(dim=1)

In [17]:
def preds_uncertainty(y_preds: Tensor) -> Tensor:
    """
    Entropy-style score of each expert's prediction vector.

    Values are clamped to [EPSILON, 1.0] before `-sum(p * log(p))` is taken over
    the last axis, so the logarithm never sees zero.
    NOTE(review): this is an entropy only if `y_preds` holds probabilities — confirm.

    y_preds: Tensor[batch, n_folds, y_targets]
    returns: Tensor[batch, n_folds]
    """
    bounded = torch.clamp(y_preds, EPSILON, 1.0)
    per_target_terms = bounded * bounded.log()
    return -per_target_terms.sum(dim=2)

def mae(y_preds: Tensor, y_true: Tensor) -> Tensor:
    """
    Mean absolute error of every expert against the shared ground truth.

    y_preds: Tensor[batch, n_folds, y_targets]
    y_true:  Tensor[batch, y_targets]
    returns: Tensor[batch, n_folds]
    """
    # Insert a fold axis so the targets broadcast against every expert.
    targets_per_fold = y_true[:, None]
    absolute_errors = (y_preds - targets_per_fold).abs()
    return absolute_errors.mean(dim=2)

class LogisticRegression(nn.Module):
    """
    Learned gate that mixes the experts' predictions with per-sample weights.

    Despite the class name, the gate is a small two-layer network
    (LazyLinear -> ReLU -> Linear -> Sigmoid) emitting one weight in (0, 1) per
    expert. Its input size is inferred on the first forward pass.
    """

    def __init__(self):
        super().__init__()
        hidden_size = 64
        gate_layers = [
            nn.LazyLinear(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, N_FOLDS),
            nn.Sigmoid(),
        ]
        self.gate = nn.Sequential(*gate_layers)

    def forward(
            self,
            y_preds: Tensor,
            orient_preds: Tensor,
            bin_demos_y_preds: Tensor,
            reg_demos_y_preds: Tensor,
            bin_demos_y_true: Tensor,
            reg_demos_y_true: Tensor
        ) -> Tensor:
        """
        Return the weight-normalised combination of `y_preds` over the expert axis.

        The gate input is the concatenation along dim 1 of the uncertainty of each
        of the four prediction tensors, followed by the MAE of the binary and of
        the regression demographic predictions against their ground truth.
        """
        uncertainty_stats = [
            preds_uncertainty(preds)
            for preds in (y_preds, orient_preds, bin_demos_y_preds, reg_demos_y_preds)
        ]
        error_stats = [
            mae(bin_demos_y_preds, bin_demos_y_true),
            mae(reg_demos_y_preds, reg_demos_y_true),
        ]
        experts_preds_stats = torch.cat(uncertainty_stats + error_stats, dim=1)

        weights = self.gate(experts_preds_stats)
        # Sigmoid weights do not sum to one, hence the explicit normalisation.
        total_weight = weights.sum(dim=1, keepdim=True)
        weighted_y_preds = torch.einsum("be, bet -> bt", weights, y_preds)

        return weighted_y_preds / total_weight

In [18]:
def train_and_eval_gating_model(dataset: Dataset, device: torch.device) -> tuple[nn.Module, DF]:
    """
    Train the learned gate on `dataset` and print how it compares to the baselines.

    Three metric dicts are printed, in order: the uniform `MeanGate`, the
    untrained learned gate, and the learned gate after its last training epoch.

    Returns:
        The trained gating model and the per-epoch metrics DataFrame.
    """
    print("mean_gate_metrics:", evaluate_gating_model(dataset, MeanGate()))

    gating_model = LogisticRegression().to(device)
    # This evaluation is also the first forward pass through the model's
    # LazyLinear layer, and it happens before the optimizer is created.
    print("base logistic_gate_metrics:", evaluate_gating_model(dataset, gating_model))

    training_metrics = train_model_on_all_epochs(dataset, gating_model)
    final_epoch_metrics = training_metrics.iloc[-1].to_dict()
    print("trained logistic_gate_metrics:", final_epoch_metrics)

    return gating_model, training_metrics

In [22]:
# Fit the gate on the "gating_train" split, then compare it with the uniform
# MeanGate baseline on the "validation" split.
device = torch.device("cuda")
print("training gating model")
# NOTE(review): `trainin_metrics` looks like a typo for `training_metrics`; it is not
# read again in the visible cells, so it is left as-is here.
gating_model, trainin_metrics = train_and_eval_gating_model(dataset_splits["gating_train"], device)
print("eval mixture of experts")
# First line printed: learned gate; second line: uniform-mean baseline.
print(evaluate_gating_model(dataset_splits["validation"], gating_model))
print(evaluate_gating_model(dataset_splits["validation"], MeanGate()))

training gating model
mean_gate_metrics: {'accuracy': 0.7242647058823529, 'binary_f1': 0.989247311827957, 'macro_f1': 0.6767708694258678, 'final_metric': 0.8330090906269124}
base logistic_gate_metrics: {'accuracy': 0.7144607843137255, 'binary_f1': 0.9882583170254403, 'macro_f1': 0.6685528668700385, 'final_metric': 0.8284055919477393}
trained logistic_gate_metrics: {'epoch': 4.0, 'train_loss': 0.8646772901217142, 'accuracy': 0.7242647058823529, 'binary_f1': 0.990234375, 'macro_f1': 0.6808694091281057, 'final_metric': 0.8355518920640528}
eval mixture of experts
{'accuracy': 0.727391874180865, 'binary_f1': 0.99375, 'macro_f1': 0.6511029947195968, 'final_metric': 0.8224264973597983}
{'accuracy': 0.727391874180865, 'binary_f1': 0.9947970863683663, 'macro_f1': 0.650693004961457, 'final_metric': 0.8227450456649117}


## Upload model ensemble

In [None]:
# Save the gating model's weights into the same "models" directory the fold
# checkpoints are loaded from, so the whole directory can be uploaded together.
torch.save(gating_model.state_dict(), f"models/gating_model.pth")
# NOTE(review): input() blocks "Restart & Run All"; the upload is an explicit opt-in.
# Only the answer "yes" (case-insensitive) triggers the upload.
user_input = input("Upload model ensemble?: ").lower()
if user_input == "yes":
    kagglehub.model_upload(
        # Handle is built as <username>/<MODEL_NAME>/pyTorch/<MODEL_VARIATION>.
        handle=join(
            kagglehub.whoami()["username"],
            MODEL_NAME,
            "pyTorch",
            MODEL_VARIATION,
        ),
        local_model_dir="models",
        # Second interactive prompt, shown only once the upload was confirmed.
        version_notes=input("Please provide model version notes: ")
    )
elif user_input == "no":
    print("Model has not been uploaded to kaggle.")
else:
    # Any answer other than "yes"/"no" is treated as a refusal.
    print("User input was not understood, model has not been uploaded to kaggle.")

Model has not been uploaded to kaggle.
