# Libraries

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from pprint import pprint

from src.utils import (
    # Old utils
    print_h, eval_person_majority_voting,

    # New utils
    set_seed, get_device, init_metrics, update_metrics, save_metrics_to_json,
)
from src.models import RNNInceptionTime

# Config

In [None]:
# ---- Project config ------------------------------------------------------
# Seed the project's RNGs via the shared helper; the same `seed` value is also
# forwarded to the person-level evaluation helper further down.
seed = 69
set_seed(seed)

# Resolve the compute device once and report it.
device = get_device()
print("Device:", device)

# Model config
# TODO: Ensure there is an assert to check whether the K-fold data is matched with the model

# ================================================================
# EXPERT-Ga MODEL
# ================================================================
# ================================
# OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s250_e5_v20250501214214'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# NON-OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e5_v20250501005824'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s500_v20250501004633'

# ================================
# OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ga_k10_w500_s250_e1_v20250517103843'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ga_k10_w500_s250_e10_v20250517103944'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# OVERLAPPING, n_epoch=20 (BEST)
# ================================
# Active selection. The checkpoint directory and the K-fold data directory
# must describe the same preprocessing config (study / k / window / stride).
model_dir, k_fold_dir = (
    'checkpoints/RNNInceptionTime_Ga_k10_w500_s250_e20_v20250517104024',
    'data/preprocessed/Ga_k10_w500_s250_v20250501213826',
)

# ================================
# OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ga_k10_w500_s250_e30_v20250517104154'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# OVERLAPPING, n_epoch=50 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ga_k10_w500_s250_e50_v20250517104242'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================
# OVERLAPPING, n_epoch=100 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ga_k10_w500_s250_e100_v20250517104322'
# k_fold_dir = 'data/preprocessed/Ga_k10_w500_s250_v20250501213826'

# ================================================================
# EXPERT-Ju MODEL
# ================================================================
# ================================
# OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s250_e30_v20250501214326'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s250_v20250501213914'

# ================================
# NON-OVERLAPPING, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e30_v20250501010022'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# NON-OVERLAPPING, W/ ANOMALY, n_epoch=30
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_w_anomaly_e30_v20250501005623'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735'

# ================================
# NON-OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ju_k10_w500_s500_e1_v20250517110513'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# NON-OVERLAPPING, n_epoch=5 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ju_k10_w500_s500_e5_v20250517110531'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# NON-OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ju_k10_w500_s500_e10_v20250517110544'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# NON-OVERLAPPING, n_epoch=20 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ju_k10_w500_s500_e20_v20250517110601'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# NON-OVERLAPPING, n_epoch=50 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ju_k10_w500_s500_e50_v20250517110615'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================
# NON-OVERLAPPING, n_epoch=100
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Ju_k10_w500_s500_e100_v20250517110631'
# k_fold_dir = 'data/preprocessed/Ju_k10_w500_s500_v20250501004709'

# ================================================================
# EXPERT-Si MODEL
# ================================================================
# ================================
# NON-OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e10_v20250501214439'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# NON-OVERLAPPING, W/ ANOMALY, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_w_anomaly_e10_v20250516233733'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_w_anomaly_v20250516233616'

# ================================
# OVERLAPPING, n_epoch=10
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s250_e10_v20250501010222'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s250_v20250501004820'

# ================================
# NON-OVERLAPPING, n_epoch=1
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Si_k10_w500_s500_e1_v20250517111017'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# NON-OVERLAPPING, n_epoch=5
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Si_k10_w500_s500_e5_v20250517111032'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# NON-OVERLAPPING, n_epoch=20 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Si_k10_w500_s500_e20_v20250517111045'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# NON-OVERLAPPING, n_epoch=30 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Si_k10_w500_s500_e30_v20250517111054'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# NON-OVERLAPPING, n_epoch=50
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Si_k10_w500_s500_e50_v20250517111104'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# ================================
# NON-OVERLAPPING, n_epoch=100 (BEST)
# ================================
# model_dir = 'checkpoints/RNNInceptionTime_Si_k10_w500_s500_e100_v20250517111115'
# k_fold_dir = 'data/preprocessed/Si_k10_w500_s500_v20250501213954'

# Derive the dataset study code from the checkpoint directory name: it is the
# last '_'-separated token before the first '_k' (e.g.
# 'RNNInceptionTime_Ga_k10_...' -> 'Ga'). normpath/basename tolerate a trailing
# '/' on model_dir, which a plain split('/')[-1] would turn into ''.
model_name = os.path.basename(os.path.normpath(model_dir))
study = model_name.split('_k', 1)[0].rsplit('_', 1)[-1]
print("Dataset study:", study)

# Evaluation config
k_fold = 10  # expected fold count; re-derived from the checkpoints below
batch_size = 8
n_feat = 16
n_class = 4
window_size = 500
max_vgrf_data_len = 25_000

# One '.pth' checkpoint per fold; checkpoints are paired with fold data by
# their sorted position, so the listing is sorted explicitly.
model_paths = sorted(
    os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith('.pth')
)
assert model_paths, f"No '.pth' checkpoints found in {model_dir}"

# Lexical order only equals fold order when every file name has the same
# length (e.g. zero-padded fold numbers), so reject mixed-length names.
checkpoint_name_lengths = {len(os.path.basename(p)) for p in model_paths}
assert len(checkpoint_name_lengths) == 1, (
    f"Checkpoint names in {model_dir} have differing lengths "
    f"{sorted(checkpoint_name_lengths)}; sorted order may not match fold order"
)

# The number of folds evaluated is the number of checkpoints actually present;
# surface the override instead of silently replacing the configured value.
if len(model_paths) != k_fold:
    print(
        f"WARNING: expected {k_fold} checkpoints, found {len(model_paths)}; "
        f"evaluating {len(model_paths)} fold(s)."
    )
k_fold = len(model_paths)

# Make sure the K-fold data comes from the same preprocessing config as the
# model. Following the naming used by the paths in this cell, data dirs are
# '<config>_v<timestamp>' and model dirs contain '_<config>_e<epochs>'. Only
# the config tag is compared: the data version is not part of the model name.
data_config = os.path.basename(os.path.normpath(k_fold_dir)).rsplit('_v', 1)[0]
assert f'_{data_config}_e' in model_name, (
    f"K-fold data '{k_fold_dir}' (config '{data_config}') "
    f"does not match model '{model_name}'"
)

general_metrics_dir = f'evaluations/{model_name}/_general_metrics'

# Evaluation

In [None]:
# Accumulators for cross-fold metrics, keyed by evaluation strategy.
metrics = {
    'person_majority_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
    # 'person_severity_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
    # 'person_max_severity': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
    # 'window': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
}

# Integer code of each expert's study; also used to validate `study`.
study_label_map = {
    'Ga': 0,
    'Ju': 1,
    'Si': 2,
}
assert study in study_label_map, (
    f"Unknown study '{study}'; expected one of {list(study_label_map)}"
)

# Fold directories are paired with checkpoints (model_paths) by position.
# os.listdir() returns entries in arbitrary order, so the listing must be sorted
# explicitly; keying on (length, name) gives numeric order both for names like
# 'fold_1' ... 'fold_10' and for zero-padded names. Hidden entries (e.g.
# '.ipynb_checkpoints') and plain files are skipped.
fold_dir_names = sorted(
    (
        d for d in os.listdir(k_fold_dir)
        if not d.startswith('.') and os.path.isdir(os.path.join(k_fold_dir, d))
    ),
    key=lambda d: (len(d), d),
)
assert len(fold_dir_names) >= k_fold, (
    f"{k_fold_dir} has {len(fold_dir_names)} fold directories "
    f"but {k_fold} checkpoints were found"
)

for i_fold in range(k_fold):
    print_h(f"FOLD {i_fold+1}", 128)
    print_h(f"EXPERT-{study} MODEL", 96)

    fold_i_dir = os.path.join(k_fold_dir, fold_dir_names[i_fold])
    model_i_path = model_paths[i_fold]
    # Show the pairing so a data/checkpoint mismatch is visible in the output.
    print("Data fold:", fold_i_dir)
    print("Checkpoint:", model_i_path)

    # Majority-voting evaluation only needs the person-level test split.
    X_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, 'X_test_person.npy'))).float()
    y_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, 'y_test_person.npy'))).long()
    # Sanity-check the trailing dims against the configured sequence length and
    # feature count the model is built for.
    assert tuple(X_test_person.shape[1:]) == (max_vgrf_data_len, n_feat), (
        f"Unexpected X_test_person shape {tuple(X_test_person.shape)} in {fold_i_dir}"
    )
    test_person_dataset = TensorDataset(X_test_person, y_test_person)

    # Load pretrained expert model
    model = RNNInceptionTime(c_in=n_feat, c_out=n_class, seq_len=window_size, bidirectional=True).to(device)
    model.load_state_dict(torch.load(model_i_path, map_location=device))
    # Inference mode; harmless if the evaluation helper also switches modes itself.
    model.eval()

    print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
    (
        _,
        acc_person_majority_voting,
        f1_person_majority_voting,
        precision_person_majority_voting,
        recall_person_majority_voting,
        cm_person_majority_voting,
        *_
    ) = eval_person_majority_voting(
        model,
        test_person_dataset,
        criterion=None,
        average='weighted',
        window_size=window_size,
        debug=False,
        seed=seed,
    )
    print("acc:", acc_person_majority_voting)
    print("f1:", f1_person_majority_voting)
    print("precision:", precision_person_majority_voting)
    print("recall:", recall_person_majority_voting)
    print("cm:\n", np.array(cm_person_majority_voting))
    print()

    # This fold's results, keyed like `metrics` above.
    in_metrics = {
        'person_majority_voting': {
            'acc': acc_person_majority_voting,
            'f1': f1_person_majority_voting,
            'precision': precision_person_majority_voting,
            'recall': recall_person_majority_voting,
            'cm': cm_person_majority_voting,
        },
        # 'person_severity_voting': {...},
        # 'person_max_severity': {...},
        # 'window': {...},
    }

    for metric_type, fold_values in in_metrics.items():
        update_metrics(metrics[metric_type], fold_values)

    # DEBUG: Test for only one fold
    # break

## Metrics

In [None]:
# Report the accumulated cross-fold metrics and persist them as JSON; the
# file name carries the study code of the evaluated expert.
print_h("METRICS", 128)
metrics_json_name = f'_{study}.json'
save_metrics_to_json(metrics, general_metrics_dir, metrics_json_name)

pprint(metrics, sort_dicts=False)
print()
print("Evaluation metrics is saved in:", general_metrics_dir)