## Setup

In [None]:
import gc
import re
import os
import json 
import math
import shutil
import random
import warnings
from os.path import join
from functools import partial
from tqdm.notebook import tqdm
from collections import defaultdict
from operator import methodcaller
from typing import Optional, Literal
from typing import Optional, Literal, Iterator
from itertools import pairwise, starmap, product

import torch
import optuna
import kagglehub 
import numpy as np
import pandas as pd
import polars as pl
from numpy import ndarray
from torch import nn, Tensor
from numpy.linalg import norm
import torch.nn.functional as F
from torch.optim import Optimizer
from pandas import DataFrame as DF
from optuna.trial import TrialState
from sklearn.metrics import f1_score
from optuna.pruners import BasePruner
from optuna.exceptions import TrialPruned
from torch.utils.data import TensorDataset
from scipy.spatial.transform import Rotation
import kaggle_evaluation.cmi_inference_server
from torch.utils.data import DataLoader as DL
from sklearn.model_selection import GroupKFold
from rich.progress import Progress, Task, track
from sklearn.model_selection import train_test_split
from numpy.lib.stride_tricks import sliding_window_view
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import ConstantLR, LRScheduler, _LRScheduler

from config import *
from model import mk_model
from training import CMIDataset

## Data

In [None]:
def record_target_feature(metrics:defaultdict, y_pred:Tensor, y_true:Tensor, preffix:str):
    for target_col_idx in range(y_pred.shape[1]):
        metrics[preffix + "_pred_" + str(target_col_idx)].append(y_pred[:, target_col_idx])
        metrics[preffix + "_true_" + str(target_col_idx)].append(y_true[:, target_col_idx])

def get_perf_and_seq_id(model:nn.Module, data_loader:DL, device:torch.device, seq_meta_data:DF) -> DF:
    """Evaluate ``model`` on every batch of ``data_loader`` and tabulate per-sample results.

    Args:
        model: network whose forward pass returns a 4-tuple
            ``(outputs, orient_outputs, bin_demos_output, reg_demos_output)``.
        data_loader: yields ``(x, y, orient_y, bin_demos_y, reg_demos_y, idx)``
            batches; ``idx`` holds positional row indices into ``seq_meta_data``.
        device: device the model runs on; every batch tensor is moved there.
        seq_meta_data: frame with a ``sequence_id`` column.

    Returns:
        One row per sample with columns:
        ``losses`` / ``orient_losses`` (per-sample cross-entropy),
        ``bin_demos_losses`` / ``reg_demos_losses`` (per-sample BCE / MSE
        averaged over their target columns),
        ``{y,orient,bin,reg}_pred_<i>`` and ``..._true_<i>`` (raw model outputs
        and targets, one column per target), and ``sequence_id``.
    """
    def to_numpy(t: Tensor) -> ndarray:
        return t.detach().cpu().numpy()

    metrics: defaultdict[str, list[ndarray]] = defaultdict(list)
    model.eval()
    with torch.no_grad():
        for batch_x, batch_y, batch_orient_y, batch_bin_demos_y, batch_reg_demos_y, idx in data_loader:
            # Clone so in-place ops inside the model cannot modify the loader's
            # tensors (presumably guards a dataset reused across calls — confirm).
            batch_x = batch_x.to(device).clone()
            # Every target must share the outputs' device or the losses raise.
            batch_y = batch_y.to(device)
            batch_orient_y = batch_orient_y.to(device)
            batch_bin_demos_y = batch_bin_demos_y.to(device)
            batch_reg_demos_y = batch_reg_demos_y.to(device)

            outputs, orient_outputs, bin_demos_output, reg_demos_output = model(batch_x)
            losses = nn.functional.cross_entropy(
                outputs,
                batch_y,
                label_smoothing=LABEL_SMOOTHING,
                reduction="none",
            )
            orient_losses = nn.functional.cross_entropy(
                orient_outputs,
                batch_orient_y,
                label_smoothing=LABEL_SMOOTHING,
                reduction="none",
            )
            bin_demos_losses = nn.functional.binary_cross_entropy_with_logits(
                bin_demos_output,
                batch_bin_demos_y,
                reduction="none",
            )
            reg_demos_losses = nn.functional.mse_loss(
                reg_demos_output,
                batch_reg_demos_y,
                reduction="none",
            )

            metrics["losses"].append(to_numpy(losses))
            metrics["orient_losses"].append(to_numpy(orient_losses))
            # The demographic losses are element-wise; average over the target
            # columns so each loss column holds exactly one value per sample.
            metrics["bin_demos_losses"].append(to_numpy(bin_demos_losses.mean(dim=1)))
            metrics["reg_demos_losses"].append(to_numpy(reg_demos_losses.mean(dim=1)))

            # Record host-side numpy arrays: np.concatenate below cannot consume
            # CUDA tensors. Argument order is always (prediction, target).
            record_target_feature(metrics, to_numpy(outputs), to_numpy(batch_y), "y")
            record_target_feature(metrics, to_numpy(orient_outputs), to_numpy(batch_orient_y), "orient")
            record_target_feature(metrics, to_numpy(bin_demos_output), to_numpy(batch_bin_demos_y), "bin")
            record_target_feature(metrics, to_numpy(reg_demos_output), to_numpy(batch_reg_demos_y), "reg")

            # .iloc needs host integer indices, not a (possibly device) tensor.
            row_idx = to_numpy(torch.as_tensor(idx))
            metrics["sequence_id"].append(seq_meta_data["sequence_id"].iloc[row_idx].to_numpy())

    # np.concatenate works on every NumPy version (np.concat is 2.0+ only).
    return DF({k: np.concatenate(v) for k, v in metrics.items()})

In [None]:
def record_models_outputs() -> DF:
    """Evaluate every fold's saved model on the dataset and stack the results.

    For each ``fold`` in ``range(N_FOLDS)``:
    - load ``models/model_fold_{fold}.pth`` into a fresh ``mk_model()``;
    - run it over ``CMIDataset`` with ``get_perf_and_seq_id``.

    The per-fold frames are concatenated into one frame. A ``fold`` column
    identifies which model produced each row, and the result has a fresh
    0..n-1 index.
    """
    device = torch.device("cuda")
    dataset = CMIDataset(device)
    data_loader = DL(dataset, batch_size=1024, shuffle=False)
    seq_meta_data = pd.read_parquet("preprocessed_dataset/sequences_meta_data.parquet")
    # NOTE(review): every fold model is evaluated on the whole CMIDataset; if it
    # includes that fold's training sequences, those rows are in-sample — confirm
    # before comparing folds.
    dfs = []
    for fold in range(N_FOLDS):
        model = mk_model()
        checkpoint = torch.load(
            join("models", f"model_fold_{fold}.pth"),
            map_location=device,
            weights_only=True,
        )
        model.load_state_dict(checkpoint)
        # load_state_dict copies weights into the existing parameters without
        # moving the module, so place it on the batches' device explicitly.
        model.to(device)
        fold_df = get_perf_and_seq_id(model, data_loader, device, seq_meta_data)
        dfs.append(fold_df.assign(fold=fold))

    # Each fold frame is indexed 0..n-1; ignore_index avoids duplicate labels.
    return pd.concat(dfs, ignore_index=True)

In [None]:
EPSILON = 1e-12


def preds_uncertainty(df:DF, preds_preffix:str, from_logits:bool = False) -> pd.Series:
    """Return the per-row entropy of the columns whose names start with ``preds_preffix``.

    Args:
        df: frame holding one prediction column per class.
        preds_preffix: literal column-name prefix (regex characters are escaped
            and the match is anchored at the start of the name).
        from_logits: if True, apply a row-wise softmax first; otherwise the
            columns are assumed to already be probabilities.

    Returns:
        Series aligned with ``df.index``; probabilities are clipped to
        ``[EPSILON, 1]`` before taking logs so zero entries stay finite.

    Raises:
        ValueError: if no column starts with ``preds_preffix``.
    """
    preds = df.filter(regex=f"^{re.escape(preds_preffix)}", axis="columns")
    if preds.shape[1] == 0:
        raise ValueError(f"no columns start with {preds_preffix!r}")
    probs = preds.to_numpy(dtype=np.float64)
    if from_logits:
        # Subtract the row max before exponentiating for numerical stability.
        probs = np.exp(probs - probs.max(axis=1, keepdims=True))
        probs /= probs.sum(axis=1, keepdims=True)
    clipped = np.clip(probs, EPSILON, 1.0)
    return pd.Series(-(clipped * np.log(clipped)).sum(axis=1), index=df.index)


def post_process_df(df:DF) -> DF:
    """Return a copy of ``df`` with per-row predictive-entropy columns added.

    ``y_uncertainty`` is computed from the ``y_pred_*`` columns and
    ``orient_uncertainty`` from the ``orient_pred_*`` columns. Both are treated
    as logits, as recorded by ``get_perf_and_seq_id`` (the raw outputs it feeds
    to ``cross_entropy``).
    """
    return df.assign(
        y_uncertainty=lambda frame: preds_uncertainty(frame, "y_pred_", from_logits=True),
        orient_uncertainty=lambda frame: preds_uncertainty(frame, "orient_pred_", from_logits=True),
    )

In [None]:
# Gather every fold model's per-sample outputs into a single DataFrame.
df: DF = record_models_outputs()