## Setup

In [1]:
import gc
import re
import os
import json 
import math
import shutil
import random
import warnings
from os.path import join
from functools import partial
from tqdm.notebook import tqdm
from collections import defaultdict
from operator import methodcaller
from typing import Optional, Literal
from typing import Optional, Literal, Iterator
from itertools import pairwise, starmap, product

import torch
import optuna
import kagglehub 
import numpy as np
import pandas as pd
import polars as pl
from numpy import ndarray
from torch import nn, Tensor
from numpy.linalg import norm
import torch.nn.functional as F
from torch.optim import Optimizer
from pandas import DataFrame as DF
from optuna.trial import TrialState
from sklearn.metrics import f1_score
from optuna.pruners import BasePruner
from optuna.exceptions import TrialPruned
from torch.utils.data import TensorDataset
from scipy.spatial.transform import Rotation
import kaggle_evaluation.cmi_inference_server
from torch.utils.data import DataLoader as DL
from sklearn.model_selection import GroupKFold
from rich.progress import Progress, Task, track
from sklearn.model_selection import train_test_split
from numpy.lib.stride_tricks import sliding_window_view
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ConstantLR, LRScheduler, _LRScheduler

from config import *
from model import mk_model
from training import CMIDataset
from utils import seed_everything
from preprocessing import get_meta_data

In [2]:
GATING_INPUT_FEATURES = [
    "bin_mae",
    "reg_mae",
    "y_uncertainty",
    "bin_uncertainty",
    "reg_uncertainty",
    "orient_uncertainty",
]
GATING_MODEL_BATCH_SIZE = 256
N_GATING_MODEL_EPOCHS = 5

In [3]:
seed_everything(SEED)

In [4]:
meta_data = get_meta_data()

## Data

In [5]:
def record_target_feature(metrics:defaultdict, y_pred:Tensor, y_true:Tensor, preffix:str):
    y_pred = y_pred.cpu().numpy()
    y_true = y_true.cpu().numpy()
    for target_col_idx in range(y_pred.shape[1]):
        metrics[preffix + "_pred_" + str(target_col_idx)].append(y_pred[:, target_col_idx])
        metrics[preffix + "_true_" + str(target_col_idx)].append(y_true[:, target_col_idx])

def record_model_outputs(model:nn.Module, data_loader:DL, device:torch.device, seq_meta_data:DF) -> DF:
    metrics:dict[list[ndarray]] = defaultdict(list)
    model = model.eval()
    tof_and_thm_idx = np.concatenate((meta_data["tof_idx"], meta_data["thm_idx"]))
    with torch.no_grad():
        for batch_x, batch_y, batch_orient_y, batch_bin_demos_y, batch_reg_demos_y, idx in data_loader:
            batch_x = batch_x.to(device).clone()
            # batch_x[:VALIDATION_BATCH_SIZE // 2, tof_and_thm_idx] = 0.0
            batch_y = batch_y.to(device)
            outputs, orient_outputs, bin_demos_output, reg_demos_output = model(batch_x)

            record_target_feature(metrics, outputs, batch_y, "y")
            record_target_feature(metrics, orient_outputs, batch_orient_y, "orient")
            record_target_feature(metrics, bin_demos_output, batch_bin_demos_y, "bin")
            record_target_feature(metrics, batch_reg_demos_y, reg_demos_output, "reg")
            metrics["sequence_id"].append(seq_meta_data["sequence_id"].iloc[idx].values)

    metrics = {k: np.concat(v) for k, v in metrics.items()}

    return DF(metrics)

In [6]:
EPSILON = 1e-12

# def post_process_df(df:DF) -> DF:
#     return (
#         df
#         .pipe(preds_uncertainty, "y")
#         .pipe(preds_uncertainty, "orient")
#         .pipe(preds_uncertainty, "bin")
#         .pipe(preds_uncertainty, "reg")
#         .pipe(mae, "y")
#         .pipe(mae, "orient")
#         .pipe(mae, "bin")
#         .pipe(mae, "reg")
#     )

def record_models_outputs() -> DF:
    device = torch.device("cuda")
    dataset = CMIDataset(device)
    data_loader = DL(dataset, batch_size=1024, shuffle=False)
    seq_meta_data = pd.read_parquet("preprocessed_dataset/sequences_meta_data.parquet")
    dfs = []
    for fold_idx in tqdm(range(N_FOLDS), total=N_FOLDS):
        model = mk_model(device=device)
        checkpoint = torch.load(
            join(
                "models",
                f"model_fold_{fold_idx}.pth"
            ),
            map_location=device,
            weights_only=True
        )
        model.load_state_dict(checkpoint)
        model.eval()
        df = record_model_outputs(model, data_loader, device, seq_meta_data)
        dfs.append(df.assign(fold=fold_idx))

    return pd.concat(dfs) #.pipe(post_process_df)

In [7]:
df = record_models_outputs()

  0%|          | 0/20 [00:00<?, ?it/s]

In [8]:
df.to_parquet("gating_df.parquet")

In [9]:
df.columns.to_list()

['y_pred_0',
 'y_true_0',
 'y_pred_1',
 'y_true_1',
 'y_pred_2',
 'y_true_2',
 'y_pred_3',
 'y_true_3',
 'y_pred_4',
 'y_true_4',
 'y_pred_5',
 'y_true_5',
 'y_pred_6',
 'y_true_6',
 'y_pred_7',
 'y_true_7',
 'y_pred_8',
 'y_true_8',
 'y_pred_9',
 'y_true_9',
 'y_pred_10',
 'y_true_10',
 'y_pred_11',
 'y_true_11',
 'y_pred_12',
 'y_true_12',
 'y_pred_13',
 'y_true_13',
 'y_pred_14',
 'y_true_14',
 'y_pred_15',
 'y_true_15',
 'y_pred_16',
 'y_true_16',
 'y_pred_17',
 'y_true_17',
 'orient_pred_0',
 'orient_true_0',
 'orient_pred_1',
 'orient_true_1',
 'orient_pred_2',
 'orient_true_2',
 'orient_pred_3',
 'orient_true_3',
 'bin_pred_0',
 'bin_true_0',
 'bin_pred_1',
 'bin_true_1',
 'reg_pred_0',
 'reg_true_0',
 'reg_pred_1',
 'reg_true_1',
 'reg_pred_2',
 'reg_true_2',
 'reg_pred_3',
 'reg_true_3',
 'sequence_id',
 'fold']

## Training

In [29]:
# def get_gating_features(df:DF) -> ndarray:
#     return (
#         df
#         .loc[:, GATING_INPUT_FEATURES + ["fold"]]
#         .pivot(columns="fold")
#         .values
#     )

def get_experts_preds(df:DF, preffix:str) -> ndarray:
    n_targets = df.filter(regex=f"^{preffix}_pred_").shape[1] # Ugly
    return (
        df
        .filter(regex=f"^{preffix}_pred_|^fold$")
        .pivot(columns="fold")
        .values
        .reshape(-1, N_FOLDS, n_targets)
    )

def get_gating_targets(df:DF, suffix:str) -> ndarray:
    return (
        df
        .query("fold == 0")
        .filter(regex=suffix + "_true_*")
        .values
    )

class GatingDataset(TensorDataset):
    def __init__(self, df: DF, device: torch.device):
        super().__init__(
            torch.from_numpy(get_experts_preds(df, "y")).to(device),
            torch.from_numpy(get_experts_preds(df, "orient")).to(device),
            torch.from_numpy(get_experts_preds(df, "bin")).to(device),
            torch.from_numpy(get_experts_preds(df, "reg")).to(device),
            torch.from_numpy(get_gating_targets(df, "bin")).to(device),
            torch.from_numpy(get_gating_targets(df, "reg")).to(device),
            torch.from_numpy(get_gating_targets(df, "y")).to(device),
        )

In [30]:
def evaluate_gating_model(dataset: Dataset, gating_model: nn.Module) -> dict:
    gating_model = gating_model.eval()
    data_loader = DL(dataset, batch_size=1024, shuffle=False)
    y_preds = []
    y_trues = []
    with torch.no_grad():
        for *gating_inputs, y in data_loader:
            y_preds.append(gating_model(*gating_inputs))
            y_trues.append(y)
    y_pred = torch.argmax(torch.concat(y_preds), dim=1).cpu().numpy()
    y_true = torch.argmax(torch.concat(y_trues), dim=1).cpu().numpy()
    model_is_true = y_pred == y_true
    binary_true = np.isin(y_true, BFRB_INDICES).astype(int)
    binary_pred = np.isin(y_pred, BFRB_INDICES).astype(int)
    metrics = {
        "accuracy": model_is_true.mean().item(),
        "binary_f1": f1_score(binary_true, binary_pred),
    }

    # Collapse non-BFRB gestures into a single class
    collapsed_true = np.where(
        np.isin(y_true, BFRB_INDICES),
        y_true,
        len(BFRB_GESTURES)  # Single non-BFRB class
    )
    collapsed_pred = np.where(
        np.isin(y_pred, BFRB_INDICES),
        y_pred,
        len(BFRB_GESTURES)  # Single non-BFRB class
    )

    # Macro F1 on collapsed classes
    metrics["macro_f1"] = f1_score(collapsed_true, collapsed_pred, average='macro')
    metrics["final_metric"] = (metrics["binary_f1"] + metrics["macro_f1"]) / 2

    return metrics

def train_model_on_single_epoch(data_loader: DL, gating_model: nn.Module, criterion: nn.Module, optimizer: Optimizer) -> dict:
    metrics = defaultdict(float)
    n_samples = 0
    gating_model = gating_model.train()
    for *gating_inputs, y_true in data_loader:
        optimizer.zero_grad()
        y_pred = gating_model(*gating_inputs)
        loss = criterion(y_pred, y_true)
        loss.backward()
        optimizer.step()
        n_samples += gating_inputs[0].shape[0]
        metrics["train_loss"] += loss.item() * gating_inputs[0].shape[0]

    metrics["train_loss"] /= n_samples

    return metrics

def train_model_on_all_epochs(dataset: Dataset, gating_model: nn.Module) -> DF:
    train_loader = DL(dataset, GATING_MODEL_BATCH_SIZE, shuffle=True)
    optimizer = torch.optim.AdamW(gating_model.parameters())
    metrics: list[dict] = []
    criterion = nn.CrossEntropyLoss()
    for epoch in range(N_GATING_MODEL_EPOCHS):
        train_metrics = train_model_on_single_epoch(train_loader, gating_model, criterion, optimizer)
        eval_metrics = evaluate_gating_model(dataset, gating_model)
        metrics.append({"epoch": epoch} | train_metrics | eval_metrics)

    return DF.from_records(metrics)

In [31]:
class MeanGate(nn.Module):
    def ___init__(self, device: torch.device):
        self.device = device

    def forward(self,
            y_preds: Tensor,
            orient_preds: Tensor,
            bin_demos_y_preds: Tensor,
            reg_demos_y_preds: Tensor,
            bin_demos_y_true: Tensor,
            reg_demos_y_true: Tensor
        ) -> Tensor:
        return y_preds.mean(dim=1)

In [44]:
def preds_uncertainty(y_preds: Tensor) -> Tensor:
    """
    y_preds: Tensor[batch, n_folds, y_targets]
    returns: Tensor[batch, n_folds]
    """
    clipped_preds = y_preds.clip(EPSILON, 1.0)
    return  -((clipped_preds * torch.log(clipped_preds)).sum(dim=2))

def mae(y_preds: Tensor, y_true: Tensor) -> Tensor:
    """
    y_preds: Tensor[batch, n_folds, y_targets]
    y_true:  Tensor[batch, y_targets]
    returns: Tensor[batch, n_folds]
    """
    return torch.abs(y_preds - y_true.unsqueeze(1)).mean(dim=2)

class LogisticRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_reg = nn.Sequential(
            nn.LazyLinear(N_FOLDS),
            nn.Sigmoid(),
        )

    def forward(
            self,
            y_preds: Tensor,
            orient_preds: Tensor,
            bin_demos_y_preds: Tensor,
            reg_demos_y_preds: Tensor,
            bin_demos_y_true: Tensor,
            reg_demos_y_true: Tensor
        ) -> Tensor:
        experts_preds_stats = torch.concatenate((
            preds_uncertainty(y_preds),
            preds_uncertainty(orient_preds),
            preds_uncertainty(bin_demos_y_preds),
            preds_uncertainty(reg_demos_y_preds),
            mae(bin_demos_y_preds, bin_demos_y_true),
            mae(reg_demos_y_preds, reg_demos_y_true),
            ),
            dim=1,
        )
        weights = self.log_reg(experts_preds_stats)
        weighted_y_preds = torch.einsum("be, bet -> bt", weights, y_preds)
        y_pred = weighted_y_preds / weights.sum(dim=1, keepdim=True)

        return y_pred

In [45]:
# Fold 0 targets
device = torch.device("cuda")
dataset = GatingDataset(df, device)
mean_gate_metrics = evaluate_gating_model(dataset, MeanGate())
print("mean_gate_metrics:", mean_gate_metrics)
logisticGate = LogisticRegression().to(device)
base_logistic_gate_metrics = evaluate_gating_model(dataset, logisticGate)
print("base logistic_gate_metrics:", base_logistic_gate_metrics)
training_metrics = train_model_on_all_epochs(dataset, logisticGate)
print("trained logistic_gate_metrics:", training_metrics.iloc[-1].to_dict())

mean_gate_metrics: {'accuracy': 0.055944055944055944, 'binary_f1': 0.5694044719946867, 'macro_f1': 0.12098403353544557, 'final_metric': 0.3451942527650661}
base logistic_gate_metrics: {'accuracy': 0.05778432094221568, 'binary_f1': 0.6131272401433692, 'macro_f1': 0.12842631399809165, 'final_metric': 0.3707767770707304}
trained logistic_gate_metrics: {'epoch': 4.0, 'train_loss': 2.812898015028245, 'accuracy': 0.1294319715372347, 'binary_f1': 0.6470266126382652, 'macro_f1': 0.1542235891215908, 'final_metric': 0.400625100879928}


In [46]:
torch.save(logisticGate.state_dict(), f"models/gating_model.pth")