## Setup

In [1]:
import gc
import re
import os
import json 
import math
import shutil
import random
import warnings
from os.path import join
from functools import partial
from tqdm.notebook import tqdm
from collections import defaultdict
from operator import methodcaller
from typing import Optional, Literal
from typing import Optional, Literal, Iterator
from itertools import pairwise, starmap, product

import torch
import optuna
import kagglehub 
import numpy as np
import pandas as pd
import polars as pl
from numpy import ndarray
from torch import nn, Tensor
from numpy.linalg import norm
import torch.nn.functional as F
from torch.optim import Optimizer
from pandas import DataFrame as DF
from optuna.trial import TrialState
from sklearn.metrics import f1_score
from optuna.pruners import BasePruner
from optuna.exceptions import TrialPruned
from torch.utils.data import TensorDataset
from scipy.spatial.transform import Rotation
import kaggle_evaluation.cmi_inference_server
from torch.utils.data import DataLoader as DL
from sklearn.model_selection import GroupKFold
from rich.progress import Progress, Task, track
from sklearn.model_selection import train_test_split
from numpy.lib.stride_tricks import sliding_window_view
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ConstantLR, LRScheduler, _LRScheduler

from config import *
from model import mk_model
from training import CMIDataset

In [None]:
# Number of passes over the gating dataset when fitting the gating model.
N_GATING_MODEL_EPOCHS = 4
# Mini-batch size of the gating model's training loader.
GATING_MODEL_BATCH_SIZE = 256
# Per-fold columns of the recorded outputs that are fed to the gating model.
# Kept as a list because it is concatenated with ["fold"] further down.
GATING_INPUT_FEATURES = [
    "bin_mae", "reg_mae",
    "y_uncertainty", "bin_uncertainty", "reg_uncertainty", "orient_uncertainty",
]

## Data

In [10]:
def record_target_feature(metrics:defaultdict, y_pred:Tensor, y_true:Tensor, preffix:str):
    """Append every column of a batch's predictions and targets to ``metrics``.

    For column ``i`` the prediction column is appended under
    ``"<preffix>_pred_<i>"`` and the target column under ``"<preffix>_true_<i>"``.

    Args:
        metrics: Mapping of metric name to a list of per-batch arrays; mutated in place.
        y_pred: 2-D predictions; its second dimension sets the number of columns recorded.
        y_true: Targets, indexed with the same column positions as ``y_pred``.
        preffix: Task name used as the key prefix.
    """
    pred_array = y_pred.cpu().numpy()
    true_array = y_true.cpu().numpy()
    n_columns = pred_array.shape[1]
    for column in range(n_columns):
        suffix = str(column)
        # Prediction first, then target, so the key insertion order is stable.
        metrics[preffix + "_pred_" + suffix].append(pred_array[:, column])
        metrics[preffix + "_true_" + suffix].append(true_array[:, column])

def get_perf_and_seq_id(model:nn.Module, data_loader:DL, device:torch.device, seq_meta_data:DF) -> DF:
    """Run ``model`` over ``data_loader`` and collect per-sample losses, outputs and ids.

    The model is called on each input batch and is expected to return four
    outputs (main, orientation, binary demographics, regression demographics).
    Outputs are recorded exactly as returned by the model; no softmax or
    sigmoid is applied here.

    Args:
        model: Network to evaluate; switched to eval mode.
        data_loader: Yields ``(x, y, orient_y, bin_demos_y, reg_demos_y, idx)`` batches.
        device: Device the batches are moved to before the forward pass.
        seq_meta_data: Frame with a ``sequence_id`` column, indexed positionally by ``idx``.

    Returns:
        One row per sample with the columns ``losses``, ``orient_losses``,
        ``<task>_pred_<i>`` / ``<task>_true_<i>`` for the tasks ``y``, ``orient``,
        ``bin`` and ``reg``, and ``sequence_id``.
    """
    metrics: dict[str, list[ndarray]] = defaultdict(list)
    model.eval()
    with torch.no_grad():
        for batch_x, batch_y, batch_orient_y, batch_bin_demos_y, batch_reg_demos_y, idx in data_loader:
            batch_x = batch_x.to(device).clone()
            # Every target has to live on the same device as the model outputs
            # for the loss computations below (a no-op when already there).
            batch_y = batch_y.to(device)
            batch_orient_y = batch_orient_y.to(device)
            batch_bin_demos_y = batch_bin_demos_y.to(device)
            batch_reg_demos_y = batch_reg_demos_y.to(device)

            outputs, orient_outputs, bin_demos_output, reg_demos_output = model(batch_x)
            losses = nn.functional.cross_entropy(
                outputs,
                batch_y,
                label_smoothing=LABEL_SMOOTHING,
                reduction="none",
            )
            orient_losses = nn.functional.cross_entropy(
                orient_outputs,
                batch_orient_y,
                label_smoothing=LABEL_SMOOTHING,
                reduction="none",
            )
            metrics["losses"].append(losses.cpu().numpy())
            metrics["orient_losses"].append(orient_losses.cpu().numpy())
            record_target_feature(metrics, outputs, batch_y, "y")
            record_target_feature(metrics, orient_outputs, batch_orient_y, "orient")
            record_target_feature(metrics, bin_demos_output, batch_bin_demos_y, "bin")
            # Predictions first, targets second, like the other tasks.
            record_target_feature(metrics, reg_demos_output, batch_reg_demos_y, "reg")
            # Positional lookup needs host-side integers, whatever device idx is on.
            row_positions = torch.as_tensor(idx).cpu().numpy()
            metrics["sequence_id"].append(seq_meta_data["sequence_id"].iloc[row_positions].values)

    # np.concatenate is available on every NumPy version (np.concat is 2.0+ only).
    return DF({name: np.concatenate(batches) for name, batches in metrics.items()})

In [57]:
# Lower clipping bound that keeps log() finite for zero-valued predictions.
EPSILON = 1e-12

def preds_uncertainty(df:DF, task_preffix:str) -> DF:
    """Add a ``<task_preffix>_uncertainty`` column holding the entropy of a task's predictions.

    The ``<task_preffix>_pred_<i>`` columns are clipped to ``[EPSILON, 1.0]`` and
    ``-sum(p * log(p))`` is computed across them for every row.

    NOTE(review): the clipping presumes the prediction columns hold
    probabilities; raw logits would saturate at the clip bounds. Confirm what
    the producer of these columns records.

    Args:
        df: Frame holding the prediction columns; modified in place.
        task_preffix: Task name, e.g. ``"y"``.

    Returns:
        The same ``df`` object, with the new column added.
    """
    # Anchored pattern: "y" must not pick up e.g. "xy_pred_0", and the suffix
    # must be a column index.
    pattern = rf"^{re.escape(task_preffix)}_pred_\d+$"
    clipped_preds = df.filter(regex=pattern, axis="columns").clip(EPSILON, 1.0)
    entropy = -((clipped_preds * np.log(clipped_preds)).sum(axis=1))
    # Assign positionally so a non-unique index cannot trigger re-alignment.
    df[task_preffix + "_uncertainty"] = entropy.to_numpy()
    return df

def mae(df:DF, task_preffix:str) -> DF:
    """Add a ``<task_preffix>_mae`` column: row-wise mean absolute error of a task.

    The error is taken between the ``<task_preffix>_pred_<i>`` and
    ``<task_preffix>_true_<i>`` columns, paired by their order in ``df``.

    Args:
        df: Frame holding the prediction and target columns; modified in place.
        task_preffix: Task name, e.g. ``"reg"``.

    Returns:
        The same ``df`` object, with the new column added.

    Raises:
        ValueError: If the number of prediction and target columns differ.
    """
    # Anchored patterns so that "y" cannot match columns of another task.
    task = re.escape(task_preffix)
    df_pred = df.filter(regex=rf"^{task}_pred_\d+$", axis="columns")
    df_true = df.filter(regex=rf"^{task}_true_\d+$", axis="columns")
    if df_pred.shape[1] != df_true.shape[1]:
        # Without this check NumPy could broadcast a lone column silently.
        raise ValueError(
            f"{task_preffix}: {df_pred.shape[1]} prediction columns "
            f"but {df_true.shape[1]} target columns"
        )
    df[task_preffix + "_mae"] = np.abs(df_pred.values - df_true.values).mean(axis=1)

    return df

def post_process_df(df:DF) -> DF:
    """Derive the uncertainty and MAE columns of every task.

    ``preds_uncertainty`` is applied to the tasks ``y``, ``orient``, ``bin`` and
    ``reg`` (in that order), followed by ``mae`` for the same tasks.

    Args:
        df: Recorded model outputs.

    Returns:
        The frame returned by the last processing step.
    """
    tasks = ("y", "orient", "bin", "reg")
    # All uncertainties first, then all MAEs.
    for step in (preds_uncertainty, mae):
        for task in tasks:
            df = step(df, task)
    return df

def record_models_outputs() -> DF:
    """Evaluate every fold's checkpoint on the dataset and stack the results.

    For each fold index in ``range(N_FOLDS)`` a model is built with
    ``mk_model``, its weights are loaded from ``models/model_fold_<i>.pth``, and
    ``get_perf_and_seq_id`` records its outputs, tagged with a ``fold`` column.

    Returns:
        The per-fold frames concatenated (original row labels kept) and passed
        through ``post_process_df``.
    """
    device = torch.device("cuda")
    data_loader = DL(CMIDataset(device), batch_size=1024, shuffle=False)
    seq_meta_data = pd.read_parquet("preprocessed_dataset/sequences_meta_data.parquet")

    def fold_outputs(fold_idx: int) -> DF:
        # Fresh model per fold, restored from that fold's checkpoint.
        model = mk_model(device=device)
        weights_path = join("models", f"model_fold_{fold_idx}.pth")
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        fold_df = get_perf_and_seq_id(model, data_loader, device, seq_meta_data)
        return fold_df.assign(fold=fold_idx)

    per_fold = [fold_outputs(fold_idx) for fold_idx in tqdm(range(N_FOLDS), total=N_FOLDS)]
    return post_process_df(pd.concat(per_fold))

In [58]:
# Record the outputs of every fold's model once; the cells below build the
# gating dataset from this frame.
df = record_models_outputs()

  0%|          | 0/20 [00:00<?, ?it/s]

## Training

In [None]:
def get_gating_features(df:DF) -> ndarray:
    """Build the gating model's input matrix from the recorded outputs.

    Keeps the ``GATING_INPUT_FEATURES`` columns plus ``fold`` and pivots on
    ``fold``, so each row label of ``df`` yields one row holding every
    feature for every fold.

    Args:
        df: Recorded outputs whose row labels repeat once per fold.

    Returns:
        The pivoted values as a 2-D array.
    """
    selected_columns = GATING_INPUT_FEATURES + ["fold"]
    per_fold = df[selected_columns].pivot(columns="fold")
    return per_fold.values

def get_experts_outputs(df:DF) -> ndarray:
    """Gather every fold's main-task predictions as ``(n_samples, n_folds, n_targets)``.

    The ``y_pred_<i>`` columns are pivoted on ``fold``. The pivoted columns are
    ordered target-major (all folds of ``y_pred_0``, then all folds of
    ``y_pred_1``, ...), so the flat values are first viewed as
    ``(n_samples, n_targets, n_folds)`` and then transposed. Reshaping straight
    to ``(n_samples, n_folds, n_targets)`` would mix up experts and targets.

    Args:
        df: Recorded outputs whose row labels repeat once per fold and that
            hold a ``fold`` column next to the ``y_pred_<i>`` columns.

    Returns:
        Contiguous array where ``out[s, f, t]`` is the prediction of fold ``f``
        for target ``t`` on sample ``s``. The fold and target counts are read
        from the pivoted columns.
    """
    pivoted = df.filter(regex=r"^y_pred_|^fold$").pivot(columns="fold")
    n_targets = pivoted.columns.get_level_values(0).nunique()
    n_folds = pivoted.columns.get_level_values(1).nunique()
    target_major = pivoted.values.reshape(-1, n_targets, n_folds)
    # Contiguous copy so the result can be handed to torch.from_numpy safely.
    return np.ascontiguousarray(target_major.transpose(0, 2, 1))

def get_gating_targets(df:DF) -> ndarray:
    """Return the main-task targets, one row per sample.

    Only the rows of fold 0 are kept; this presumes every fold recorded the
    same targets for the same samples (TODO confirm).

    Args:
        df: Recorded outputs with a ``fold`` column and ``y_true_<i>`` columns.

    Returns:
        2-D array of the ``y_true_<i>`` columns for the fold-0 rows.
    """
    return (
        df
        .query("fold == 0")
        # Anchored like the pattern used for the expert outputs, so that
        # columns merely containing "y_tru" are not picked up.
        .filter(regex=r"^y_true_")
        .values
    )

class GatingDataset(TensorDataset):
    """Tensor dataset yielding ``(gating features, experts outputs, targets)``."""

    def __init__(self, df: DF, device: torch.device):
        """Build the three tensors from ``df`` and move each one to ``device``.

        Args:
            df: Recorded per-fold model outputs.
            device: Device on which the tensors are stored.
        """
        # The order defines the layout of each sample tuple.
        builders = (get_gating_features, get_experts_outputs, get_gating_targets)
        tensors = [torch.from_numpy(build(df)).to(device) for build in builders]
        super().__init__(*tensors)

Unnamed: 0_level_0,bin_mae,bin_mae,bin_mae,bin_mae,bin_mae,bin_mae,bin_mae,bin_mae,bin_mae,bin_mae,...,orient_uncertainty,orient_uncertainty,orient_uncertainty,orient_uncertainty,orient_uncertainty,orient_uncertainty,orient_uncertainty,orient_uncertainty,orient_uncertainty,orient_uncertainty
fold,0,1,2,3,4,5,6,7,8,9,...,10,11,12,13,14,15,16,17,18,19
0,5.373200,3.263342,4.640149,4.598254,4.851898,5.196218,4.149457,4.164123,4.707287,3.074975,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11
1,6.369648,6.881545,9.064594,7.586702,7.586257,7.663782,7.625151,8.253765,8.368605,8.049928,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11
2,8.297356,9.460062,5.012723,5.532870,6.969646,8.566780,7.147378,8.186001,8.710774,5.342514,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11
3,5.853631,6.073956,5.834551,7.092682,5.035452,5.440889,6.324386,5.682080,5.922688,4.963815,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11
4,6.361362,6.090506,6.644010,4.834425,6.041402,5.543854,4.901386,8.600010,7.591352,7.141142,...,8.289306e-11,8.289306e-11,8.289306e-11,2.628955e-01,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,3.664069e-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8146,3.854278,4.744507,4.921330,6.221528,5.334171,5.877177,3.079997,5.552984,4.114785,4.454160,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11
8147,5.839823,4.647733,5.976115,4.286340,5.743880,4.078864,5.254677,3.685849,5.469068,3.769079,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11
8148,4.513684,4.714145,9.269878,4.908830,6.493989,5.337626,7.587003,9.525738,6.214049,5.937212,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11
8149,4.107200,5.895021,6.839435,5.656706,4.276178,4.535508,5.213759,5.066491,5.985665,4.973286,...,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11,8.289306e-11


In [None]:
def evaluate_gating_model(dataset: Dataset, gating_model: nn.Module) -> dict:
    """Measure the accuracy of the gate-weighted combination of expert predictions.

    For every batch the gating model maps the gating features to one weight
    per expert, and the prediction is the weighted average of the experts'
    outputs: ``sum(weights * experts) / sum(weights)``.

    Args:
        dataset: Yields ``(features, experts_outputs, targets)`` samples, with
            ``experts_outputs`` of shape ``(n_experts, n_targets)`` and
            ``targets`` of shape ``(n_targets,)``.
        gating_model: Maps a feature batch to ``(batch, n_experts)`` weights;
            switched to eval mode.

    Returns:
        ``{"accuracy": float}``: the share of samples whose arg-max prediction
        equals the arg-max target (NaN for an empty dataset).

    Raises:
        ValueError: If the gating model does not return one weight per expert.
    """
    gating_model = gating_model.eval()
    data_loader = DL(dataset, batch_size=1024, shuffle=False)
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for x, sub_y_preds, y_true in data_loader:
            weights = gating_model(x)
            if weights.shape != sub_y_preds.shape[:2]:
                raise ValueError(
                    f"gating weights have shape {tuple(weights.shape)}, "
                    f"expected {tuple(sub_y_preds.shape[:2])}"
                )
            # Clamp the normaliser so all-zero weights cannot divide by zero.
            total_weight = weights.sum(dim=1, keepdim=True).clamp_min(1e-12)
            y_pred = (weights.unsqueeze(-1) * sub_y_preds).sum(dim=1) / total_weight
            n_correct += (y_pred.argmax(dim=1) == y_true.argmax(dim=1)).sum().item()
            n_samples += x.shape[0]

    metrics = {
        "accuracy": n_correct / n_samples if n_samples else float("nan"),
    }

    return metrics

def train_model_on_single_epoch(data_loader: DL, gating_model: nn.Module, criterion: nn.Module, optimizer: Optimizer) -> dict:
    """Train the gating model for one pass over ``data_loader``.

    For every batch the gating model maps the gating features to one weight
    per expert; the prediction fed to ``criterion`` is the weighted average of
    the experts' outputs, ``sum(weights * experts) / sum(weights)``.

    Args:
        data_loader: Yields ``(features, experts_outputs, targets)`` batches,
            with ``experts_outputs`` of shape ``(batch, n_experts, n_targets)``.
        gating_model: Maps a feature batch to ``(batch, n_experts)`` weights;
            switched to train mode.
        criterion: Loss called as ``criterion(prediction, targets)``.
        optimizer: Optimizer over the gating model's parameters.

    Returns:
        ``{"train_loss": float}``: the sample-weighted mean loss of the epoch
        (NaN when the loader yields no sample).

    Raises:
        ValueError: If the gating model does not return one weight per expert.
    """
    total_loss = 0.0
    n_samples = 0
    gating_model = gating_model.train()
    for x, sub_y_preds, y_true in data_loader:
        optimizer.zero_grad()
        weights = gating_model(x)
        if weights.shape != sub_y_preds.shape[:2]:
            raise ValueError(
                f"gating weights have shape {tuple(weights.shape)}, "
                f"expected {tuple(sub_y_preds.shape[:2])}"
            )
        # Clamp the normaliser so all-zero weights cannot divide by zero.
        total_weight = weights.sum(dim=1, keepdim=True).clamp_min(1e-12)
        y_pred = (weights.unsqueeze(-1) * sub_y_preds).sum(dim=1) / total_weight
        loss = criterion(y_pred, y_true)
        loss.backward()
        optimizer.step()
        batch_size = x.shape[0]
        n_samples += batch_size
        # Weight by batch size so a smaller last batch does not skew the mean.
        total_loss += loss.item() * batch_size

    return {"train_loss": total_loss / n_samples if n_samples else float("nan")}

def train_model_on_all_epochs(dataset: Dataset, gating_model: nn.Module, lr: float = 1e-3) -> DF:
    """Fit the gating model for ``N_GATING_MODEL_EPOCHS`` epochs and log metrics.

    After every training epoch the model is evaluated on the same ``dataset``
    it is trained on.

    Args:
        dataset: Yields ``(features, experts_outputs, targets)`` samples.
        gating_model: Trainable gating model, updated in place.
        lr: Learning rate of the Adam optimizer.

    Returns:
        One row per epoch holding ``epoch``, the training metrics and the
        evaluation metrics.
    """
    train_loader = DL(dataset, GATING_MODEL_BATCH_SIZE, shuffle=True)

    # Lazy layers only get real parameters after a first forward pass, and the
    # optimizer must be built on materialised parameters.
    has_lazy_params = any(
        isinstance(param, nn.parameter.UninitializedParameter)
        for param in gating_model.parameters()
    )
    if has_lazy_params:
        first_batch = next(iter(train_loader), None)
        if first_batch is not None:
            with torch.no_grad():
                gating_model(first_batch[0])

    optimizer = torch.optim.Adam(gating_model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    metrics: list[dict] = []
    for epoch in range(N_GATING_MODEL_EPOCHS):
        train_metrics = train_model_on_single_epoch(train_loader, gating_model, criterion, optimizer)
        eval_metrics = evaluate_gating_model(dataset, gating_model)
        metrics.append({"epoch": epoch} | train_metrics | eval_metrics)

    return DF.from_records(metrics)

In [None]:
class MeanGate(nn.Module):
    """Baseline gate giving every expert the same weight.

    Under a weighted average, uniform weights reduce to the plain mean of the
    experts' predictions.
    """

    def __init__(self, n_experts: Optional[int] = None):
        """
        Args:
            n_experts: Number of experts to weight; defaults to ``N_FOLDS``.
        """
        super().__init__()
        self.n_experts = N_FOLDS if n_experts is None else n_experts

    def forward(self, experts_preds: Tensor) -> Tensor:
        """Return ``(batch, n_experts)`` weights that are all one.

        Args:
            experts_preds: Batch handed to the gate. Only its first dimension
                (batch size), dtype and device are used, so the gating
                features passed by the evaluation loop work as well.
        """
        return experts_preds.new_ones((experts_preds.shape[0], self.n_experts))
    
class LogisticRegression(nn.Sequential):
    """Linear gate followed by a sigmoid: one weight in ``(0, 1)`` per expert."""

    def __init__(self, n_experts: Optional[int] = None):
        """
        Args:
            n_experts: Number of weights to emit; defaults to ``N_FOLDS`` so
                that every fold's model gets its own weight.
        """
        super().__init__(
            # Lazy layer: the input width is inferred on the first forward pass.
            nn.LazyLinear(N_FOLDS if n_experts is None else n_experts),
            nn.Sigmoid(),
        )

In [None]:
# Compare the uniform-weight baseline with a logistic gate before and after training.
device = torch.device("cuda")
dataset = GatingDataset(df, device)
mean_gate_metrics = evaluate_gating_model(dataset, MeanGate())
print("mean_gate_metrics:", mean_gate_metrics)
# The gate has to live on the same device as the dataset tensors.
logisticGate = LogisticRegression().to(device)
base_logistic_gate_metrics = evaluate_gating_model(dataset, logisticGate)
print("base logistic_gate_metrics:", base_logistic_gate_metrics)
# Train the model itself (not its metrics) and report the last epoch's row.
training_metrics = train_model_on_all_epochs(dataset, logisticGate)
print("trained logistic_gate_metrics:", training_metrics.iloc[-1])