# Libraries

In [1]:
import os
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from pprint import pprint

from src.utils import (
    # Old utils
    print_h, eval_person_majority_voting,

    # New utils
    set_seed, get_device, init_metrics, update_metrics, save_metrics_to_json,
)
from src.models import RNNInceptionTime

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [2]:
# ---------------------------------------------------------------- Project config
# Seed every RNG through the project helper first, then resolve the torch device.
seed = 69
set_seed(seed)
device = get_device()
print(f"Device: {device}")

# ---------------------------------------------------------------- Model and data config
# Keep exactly one (model_dir, k_fold_dir) pair uncommented below; the checkpoint
# must have been trained on that same K-fold split.
# TODO: Ensure there is an assert to check whether the K-fold data is matched with the model

# ================================================================
# Ga
# ================================================================
# ================================
# Ga, OVERLAPPING, n_epoch=1
# ================================
model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e1_v20250520220659'
k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# Ga, OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e5_v20250520220706'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# Ga, OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e10_v20250520220713'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# Ga, OVERLAPPING, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e20_v20250520224411'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# Ga, OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e30_v20250520224419'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# Ga, OVERLAPPING, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e50_v20250521140906'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# Ga, OVERLAPPING, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e100_v20250521140919'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e1_v20250520220425'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e5_v20250520220454'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e10_v20250520220522'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e20_v20250520224322'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e30_v20250520224334'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e50_v20250521141023'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e100_v20250521141032'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=20, k_fold=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k5_w500_s500_e20_v20250523150148'
# k_fold_dir = 'data/preprocessed/Ga_k5_w500_s500_v20250523145245'

# ================================
# Ga, NON-OVERLAPPING, n_epoch=20, k_fold=3
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k3_w500_s500_e20_v20250523150158'
# k_fold_dir = 'data/preprocessed/Ga_k3_w500_s500_v20250523145118'

# ================================================================
# Ju
# ================================================================
# ================================
# Ju, OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e1_v20250520221129'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# Ju, OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e5_v20250520221136'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# Ju, OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e10_v20250520221144'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# Ju, OVERLAPPING, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e20_v20250520224505'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# Ju, OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e30_v20250520224513'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# Ju, OVERLAPPING, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e50_v20250521141059'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# Ju, OVERLAPPING, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e100_v20250521141107'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# Ju, OVERLAPPING, W/ ANOMALY, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_w_anomaly_e1_v20250529213316'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_w_anomaly_v20250516211110'

# ================================
# Ju, OVERLAPPING, W/ ANOMALY, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_w_anomaly_e5_v20250529213323'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_w_anomaly_v20250516211110'

# ================================
# Ju, OVERLAPPING, W/ ANOMALY, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_w_anomaly_e10_v20250529213331'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_w_anomaly_v20250516211110'

# ================================
# Ju, OVERLAPPING, W/ ANOMALY, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_w_anomaly_e20_v20250529213338'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_w_anomaly_v20250516211110'

# ================================
# Ju, OVERLAPPING, W/ ANOMALY, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_w_anomaly_e30_v20250529213346'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_w_anomaly_v20250516211110'

# ================================
# Ju, OVERLAPPING, W/ ANOMALY, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_w_anomaly_e50_v20250529213354'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_w_anomaly_v20250516211110'

# ================================
# Ju, OVERLAPPING, W/ ANOMALY, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_w_anomaly_e100_v20250529213404'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_w_anomaly_v20250516211110'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e1_v20250520221239'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e5_v20250520221245'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e10_v20250520221252'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e20_v20250520224548'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e30_v20250520224556'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e50_v20250521141129'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e100_v20250521141136'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=30, k_fold=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k5_w500_s500_e30_v20250523150745'
# k_fold_dir = 'data/preprocessed/Ju_k5_w500_s500_v20250520202944'

# ================================
# Ju, NON-OVERLAPPING, n_epoch=30, k_fold=3
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k3_w500_s500_e30_v20250523150827'
# k_fold_dir = 'data/preprocessed/Ju_k3_w500_s500_v20250520203030'

# ================================
# Ju, NON-OVERLAPPING, W/ ANOMALY, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e1_v20250523203516'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================
# Ju, NON-OVERLAPPING, W/ ANOMALY, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e5_v20250529001340'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================
# Ju, NON-OVERLAPPING, W/ ANOMALY, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e10_v20250529001347'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================
# Ju, NON-OVERLAPPING, W/ ANOMALY, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e20_v20250529001355'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================
# Ju, NON-OVERLAPPING, W/ ANOMALY, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e30_v20250529001442'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================
# Ju, NON-OVERLAPPING, W/ ANOMALY, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e50_v20250529001404'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================
# Ju, NON-OVERLAPPING, W/ ANOMALY, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e100_v20250529001453'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================================================
# Si
# ================================================================
# ================================
# Si, OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e1_v20250520221423'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# Si, OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e5_v20250520221428'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# Si, OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e10_v20250520221436'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# Si, OVERLAPPING, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e20_v20250520224651'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# Si, OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e30_v20250520224702'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# Si, OVERLAPPING, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e50_v20250521141204'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# Si, OVERLAPPING, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e100_v20250521141211'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_e1_v20250529213546'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_e5_v20250529213554'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_e10_v20250529213602'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_e20_v20250529213620'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_e30_v20250529213629'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_e50_v20250529213641'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_e100_v20250529213647'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847'

# ================================
# Si, NON-OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e1_v20250520221557'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# Si, NON-OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e5_v20250520221603'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# Si, NON-OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e10_v20250520221611'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# Si, NON-OVERLAPPING, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e20_v20250520224754'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# Si, NON-OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e30_v20250520224801'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# Si, NON-OVERLAPPING, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e50_v20250521141245'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# Si, NON-OVERLAPPING, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e100_v20250521141250'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# Si, NON-OVERLAPPING, n_epoch=20, k_fold=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k5_w500_s500_e20_v20250523151306'
# k_fold_dir = 'data/preprocessed/Si_k5_w500_s500_v20250523145822'

# ================================
# Si, NON-OVERLAPPING, n_epoch=20, k_fold=3
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k3_w500_s500_e20_v20250523151334'
# k_fold_dir = 'data/preprocessed/Si_k3_w500_s500_v20250523145845'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, W/O SYNTH, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_wo_synth_e20_v20250523151200'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_wo_synth_v20250520191606'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e1_v20250523203129'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e5_v20250529001212'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e10_v20250529001223'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=20
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e20_v20250523151109'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e30_v20250529212025'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e50_v20250529001235'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e100_v20250529001243'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=80
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e80_v20250529212134'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=120
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e120_v20250529212145'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, NON-OVERLAPPING, W/ ANOMALY, n_epoch=150
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e150_v20250529212152'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# Si, OVERLAPPING, W/ ANOMALY, W/O SYNTH, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_w_anomaly_wo_synth_e30_v20250602190732'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_w_anomaly_wo_synth_v20250501004913'

# Derive run identifiers from the selected directories (normpath + basename
# also tolerates a trailing slash in the configured path).
model_name = os.path.basename(os.path.normpath(model_dir))
k_fold_name = os.path.basename(os.path.normpath(k_fold_dir))

# The study code is the token right before the first '_k<folds>' marker,
# e.g. 'RNNInceptionTime_bidirectional_Ga_k10_w500_...' -> 'Ga'.
study = model_name.split('_k', 1)[0].rsplit('_', 1)[-1]
print("Dataset study:", study)

# Make sure the checkpoint was trained on this K-fold data. The data dir is
# named '<data_config>_v<timestamp>' and the model dir embeds
# '_<data_config>_e<n_epoch>_v<timestamp>', so everything before the data
# dir's version stamp must appear verbatim in the model name.
data_config = k_fold_name.rsplit('_v', 1)[0]
assert f'_{data_config}_e' in model_name, (
    f"K-fold data '{k_fold_name}' does not match model '{model_name}': "
    f"expected '_{data_config}_e<n_epoch>' in the model name"
)

# Split parameters encoded in the data config: k<n_folds>, w<window>, s<stride>.
# Non-numeric tokens such as 'w' / 'anomaly' / 'wo' / 'synth' are ignored.
split_params = {
    token[0]: int(token[1:])
    for token in data_config.split('_')[1:]
    if token[:1] in ('k', 'w', 's') and token[1:].isdigit()
}
assert {'k', 'w'} <= split_params.keys(), f"Cannot parse k/w from '{data_config}'"

# Evaluation config. Fold count and window size follow the selected split
# instead of being hard-coded, so the k5/k3 options above stay consistent.
k_fold = split_params['k']
batch_size = 8
n_feat = 16
n_class = 4
window_size = split_params['w']
max_vgrf_data_len = 25_000

# Every fold is a sub-directory of k_fold_dir; their count must match k.
n_fold_dirs = sum(
    os.path.isdir(os.path.join(k_fold_dir, name)) for name in os.listdir(k_fold_dir)
)
assert n_fold_dirs == k_fold, (
    f"Expected {k_fold} fold directories in '{k_fold_dir}', found {n_fold_dirs}"
)

general_metrics_dir = f'evaluations/{model_name}/_general_metrics'

2025-06-02 22:24:38.995303: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-02 22:24:39.014974: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-02 22:24:39.021020: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-02 22:24:39.035879: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Random seed: 69
Device: cuda
Dataset study: Ga


# Evaluation

In [3]:
# Per-fold results. Only person-level majority voting is evaluated here; other
# families (e.g. 'person_severity_voting', 'person_max_severity', 'window')
# need both an init_metrics entry here and an evaluation call in the loop.
metrics = {
    'person_majority_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
}

# Study code -> integer label (not used by the single-expert evaluation below).
study_label_map = {
    'Ga': 0,
    'Ju': 1,
    'Si': 2,
}


def _load_person_dataset(fold_dir, split):
    """Load one fold's person-level split (e.g. split='test') as a TensorDataset.

    Reads `X_{split}_person.npy` (cast to float32) and `y_{split}_person.npy`
    (cast to int64) from `fold_dir`.
    """
    X_person = torch.tensor(np.load(os.path.join(fold_dir, f'X_{split}_person.npy'))).float()
    y_person = torch.tensor(np.load(os.path.join(fold_dir, f'y_{split}_person.npy'))).long()
    return TensorDataset(X_person, y_person)


def _load_expert_model(fold_dir_name):
    """Rebuild the bidirectional RNNInceptionTime expert and load one fold's weights.

    The checkpoint is read from `<model_dir>/<fold_dir_name>.pth` onto `device`.
    """
    expert = RNNInceptionTime(c_in=n_feat, c_out=n_class, seq_len=window_size, bidirectional=True).to(device)
    expert_path = os.path.join(model_dir, fold_dir_name + '.pth')
    expert.load_state_dict(torch.load(expert_path, map_location=device))
    # Switch to inference behaviour (dropout / batch-norm statistics, if the
    # architecture has any); a no-op if eval_person_majority_voting does it too.
    expert.eval()
    return expert


# Each fold lives in its own sub-directory; skip stray files (e.g. .DS_Store)
# so they are not mistaken for folds.
fold_dir_names = sorted(
    name for name in os.listdir(k_fold_dir)
    if os.path.isdir(os.path.join(k_fold_dir, name))
)
assert fold_dir_names, f"No fold directories found in '{k_fold_dir}'"

for fold_i_dir_name in fold_dir_names:
    # ================================================================================================================================
    # FOLD
    # ================================================================================================================================
    fold_i_dir = os.path.join(k_fold_dir, fold_i_dir_name)
    print_h(fold_i_dir_name, 128)

    # ================================================================================================
    # DATA
    # ================================================================================================
    # Only the person-level test split is consumed by majority voting, so the
    # train/val/window arrays are not loaded.
    test_person_dataset = _load_person_dataset(fold_i_dir, 'test')

    # ================================================================================================
    # MODEL
    # ================================================================================================
    print_h(f"EXPERT-{study} MODEL", 96)
    model = _load_expert_model(fold_i_dir_name)

    # ================================================================
    # EVALUATION ON PERSON DATA BY MAJORITY VOTING
    # ================================================================
    print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
    (
        _,
        acc_person_majority_voting,
        f1_person_majority_voting,
        precision_person_majority_voting,
        recall_person_majority_voting,
        cm_person_majority_voting,
        *_
    ) = eval_person_majority_voting(
        model,
        test_person_dataset,
        criterion=None,
        average='weighted',
        window_size=window_size,
        debug=False,
        seed=seed,
    )
    print("acc:", acc_person_majority_voting)
    print("f1:", f1_person_majority_voting)
    print("precision:", precision_person_majority_voting)
    print("recall:", recall_person_majority_voting)
    print("cm:\n", np.array(cm_person_majority_voting))
    print()

    in_metrics = {
        'person_majority_voting': {
            'acc': acc_person_majority_voting,
            'f1': f1_person_majority_voting,
            'precision': precision_person_majority_voting,
            'recall': recall_person_majority_voting,
            'cm': cm_person_majority_voting,
        },
    }

    for metric_type, fold_metrics in in_metrics.items():
        update_metrics(metrics[metric_type], fold_metrics)

                                                            fold_01                                                             
                                        EXPERT-Si MODEL                                         
          EVALUATION ON PERSON DATA BY MAJORITY VOTING          
Random seed: 69

acc: 0.7777777777777778
f1: 0.7283950617283951
precision: 0.6888888888888889
recall: 0.7777777777777778
cm:
 [[4 0 0 0]
 [1 3 0 0]
 [0 1 0 0]
 [0 0 0 0]]

                                                            fold_02                                                             
                                        EXPERT-Si MODEL                                         
          EVALUATION ON PERSON DATA BY MAJORITY VOTING          
Random seed: 69

acc: 0.4444444444444444
f1: 0.42328042328042326
precision: 0.4148148148148148
recall: 0.4444444444444444
cm:
 [[3 1 0 0]
 [2 1 1 0]
 [0 1 0 0]
 [0 0 0 0]]

                                                            fold_03      

## Metrics

In [4]:
# Persist the accumulated metrics dict as JSON, then echo it for inspection.
print_h("METRICS", 128)
metrics_file_name = f'_{study}.json'
save_metrics_to_json(metrics, general_metrics_dir, metrics_file_name)
pprint(metrics, sort_dicts=False)
print("\nEvaluation metrics is saved in:", general_metrics_dir)

                                                            METRICS                                                             
{'person_majority_voting': {'acc': {'folds': [0.7777777777777778,
                                              0.4444444444444444,
                                              0.6666666666666666,
                                              0.4444444444444444,
                                              0.7777777777777778,
                                              0.6666666666666666,
                                              0.7777777777777778,
                                              0.6666666666666666,
                                              0.3333333333333333,
                                              0.4444444444444444],
                                    'avg': 0.6,
                                    'std': 0.15869840952317446},
                            'f1': {'folds': [0.7283950617283951,
                                