# Libraries

In [1]:
import os
import numpy as np
import torch
from datetime import datetime
from torch.utils.data import TensorDataset, DataLoader

from src.utils import (
    set_seed, get_device, print_h, init_model,
    eval_window, eval_person_severity_voting, eval_person_majority_voting, eval_person_max_severity, 
    init_metrics, update_metrics, save_metrics_to_json,    
)

  from .autonotebook import tqdm as notebook_tqdm


# Config

In [2]:
seed = 69
set_seed(seed)
device = get_device()
print("Device:", device)

# Model config
model_name = 'RNNInceptionTimeStacked' # 'InceptionTime' | 'RNN' | 'InceptionTimeRNN' | 'RNNInceptionTime' | 'RNNInceptionTimeStacked'
bidirectional = True

# Data config
fold_i_dir_map = {
    'Ga': f'data/preprocessed/Ga_k10_w500_s500_v20250501004633/fold_06',
    'Ju': f'data/preprocessed/Ju_k10_w500_s500_w_anomaly_v20250501004735/fold_06',
    'Si': f'data/preprocessed/Si_k10_w500_s250_w_anomaly_v20250501004847/fold_06',
}
i_folds = [int(fold_i_dir.split('fold_')[-1]) for fold_i_dir in fold_i_dir_map.values()]
assert len(set(i_folds)) == 1, f"Fold numbers are inconsistent: {i_folds}"
i_fold = i_folds[0]
print("Fold number:", i_fold)

# Training config
batch_size = 8
n_feat = 16
n_class = 4
window_size = 500
max_vgrf_data_len = 25_000
lr = 3e-4
n_epoch = 20

# Generate name tag
run_name_tag = '_'.join([k_fold_dir.split('/')[-2].rsplit('_v', 1)[0] for k_fold_dir in fold_i_dir_map.values()]) + f'_fold_{i_fold:02}_e{n_epoch}'
print("Run name tag:", run_name_tag)

2025-06-05 15:06:43.037593: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-05 15:06:43.056754: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-05 15:06:43.062704: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-05 15:06:43.076807: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Random seed: 69
Device: cuda
Fold number: 6
Run name tag: Ga_k10_w500_s500_Ju_k10_w500_s500_w_anomaly_Si_k10_w500_s250_w_anomaly_fold_06_e20


# Training

In [3]:
# Set run name
run_name = f'{model_name}{'_bidirectional' if bidirectional else ''}_non_moe_{run_name_tag+'_' if run_name_tag else ''}v{datetime.now().strftime("%Y%m%d%H%M%S")}'
print("Run name:", run_name)

# Create save directory
save_dir = 'checkpoints/' + run_name
os.makedirs(save_dir, exist_ok=True)
print("Save directory:", save_dir)
print()

# Initialize evaluation metrics
metrics = {
    'person_majority_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss']),
    # 'person_severity_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss']),
    # 'person_max_severity': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss']),
    # 'window': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss']),
}

# ================================================================================================================================
# FOLD
# ================================================================================================================================
print_h(f"FOLD {i_fold}", 128)

# ================================================================================================
# DATA
# ================================================================================================
X_train_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
y_train_window_GaJuSi = torch.empty(0).long()
study_labels_train_window_GaJuSi = torch.empty(0).long()

X_val_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
y_val_window_GaJuSi = torch.empty(0).long()
study_labels_val_window_GaJuSi = torch.empty(0).long()

X_test_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
y_test_window_GaJuSi = torch.empty(0).long()
study_labels_test_window_GaJuSi = torch.empty(0).long()

X_val_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feat).float()
y_val_person_GaJuSi = torch.empty(0).long()
# study_labels_val_person_GaJuSi = torch.empty(0).long()

X_test_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feat).float()
y_test_person_GaJuSi = torch.empty(0).long()
# study_labels_test_person_GaJuSi = torch.empty(0).long()

study_label_map = {
    'Ga': 0,
    'Ju': 1,
    'Si': 2,
}

for study, fold_i_dir in fold_i_dir_map.items():
    X_train_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_train_window.npy'))).float()
    y_train_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_train_window.npy'))).long()
    study_labels_train_window = torch.tensor([study_label_map[study]] * len(y_train_window)).long()
    X_train_window_GaJuSi = torch.cat((X_train_window_GaJuSi, X_train_window), dim=0)
    y_train_window_GaJuSi = torch.cat((y_train_window_GaJuSi, y_train_window), dim=0)
    study_labels_train_window_GaJuSi = torch.cat((study_labels_train_window_GaJuSi, study_labels_train_window), dim=0)

    X_val_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_val_window.npy'))).float()
    y_val_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_val_window.npy'))).long()
    study_labels_val_window = torch.tensor([study_label_map[study]] * len(y_val_window)).long()
    X_val_window_GaJuSi = torch.cat((X_val_window_GaJuSi, X_val_window), dim=0)
    y_val_window_GaJuSi = torch.cat((y_val_window_GaJuSi, y_val_window), dim=0)
    study_labels_val_window_GaJuSi = torch.cat((study_labels_val_window_GaJuSi, study_labels_val_window), dim=0)

    X_test_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_test_window.npy'))).float()
    y_test_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_test_window.npy'))).long()
    study_labels_test_window = torch.tensor([study_label_map[study]] * len(y_test_window)).long()
    X_test_window_GaJuSi = torch.cat((X_test_window_GaJuSi, X_test_window), dim=0)
    y_test_window_GaJuSi = torch.cat((y_test_window_GaJuSi, y_test_window), dim=0)
    study_labels_test_window_GaJuSi = torch.cat((study_labels_test_window_GaJuSi, study_labels_test_window), dim=0)

    X_val_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_val_person.npy'))).float()
    y_val_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_val_person.npy'))).long()
    X_val_person_GaJuSi = torch.cat((X_val_person_GaJuSi, X_val_person), dim=0)
    y_val_person_GaJuSi = torch.cat((y_val_person_GaJuSi, y_val_person), dim=0)

    X_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_test_person.npy'))).float()
    y_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_test_person.npy'))).long()
    X_test_person_GaJuSi = torch.cat((X_test_person_GaJuSi, X_test_person), dim=0)
    y_test_person_GaJuSi = torch.cat((y_test_person_GaJuSi, y_test_person), dim=0)

train_window_dataset_GaJuSi = TensorDataset(X_train_window_GaJuSi, y_train_window_GaJuSi)
val_window_dataset_GaJuSi = TensorDataset(X_val_window_GaJuSi, y_val_window_GaJuSi)
test_window_dataset_GaJuSi = TensorDataset(X_test_window_GaJuSi, y_test_window_GaJuSi)

train_dataloader_GaJuSi = DataLoader(train_window_dataset_GaJuSi, batch_size=batch_size, shuffle=True)
val_dataloader_GaJuSi = DataLoader(val_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)
test_dataloader_GaJuSi = DataLoader(test_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)

val_person_dataset_GaJuSi = TensorDataset(X_val_person_GaJuSi, y_val_person_GaJuSi)
test_person_dataset_GaJuSi = TensorDataset(X_test_person_GaJuSi, y_test_person_GaJuSi)

# ================================================================================================
# TRAINING
# ================================================================================================
print_h("TRAINING", 96)

# Initialize model
model = init_model(model_name, device, c_in=n_feat, c_out=n_class, seq_len=window_size, bidirectional=bidirectional)

# Initialize optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion =  torch.nn.CrossEntropyLoss()

# Swith the model to training mode
model.train()

# Loop training epochs
global_val_loss_window_list = []
global_val_loss_person_list = []
global_train_loss_list = []
train_loss_list = []
# step = 0
for epoch in range(n_epoch):
    # Loop training batches
    for iter, (X_train, y_train) in enumerate(train_dataloader_GaJuSi):
        # Flush the computed gradients
        optimizer.zero_grad()
        
        X_train = X_train.to(device)
        y_train = y_train.to(device)
        
        # Feed forward the model
        X_train = X_train.permute(0, 2, 1)
        y_pred = model(X_train)

        # Compute training loss
        train_loss = criterion(y_pred, y_train)
        train_loss_list.append(train_loss)
        
        # if (iter+1) % 'step_siz']= 0:
        if iter+1 == len(train_dataloader_GaJuSi):
            # ================================================================
            # VALIDATION
            # ================================================================
            avg_val_loss_window, acc_window, f1_window, *_ = eval_window(model, val_dataloader_GaJuSi, criterion, average='weighted')
            # avg_val_loss_person, acc_person, f1_person, *_ = eval_person_majority_voting(model, val_person_dataset, criterion=criterion, average='weighted', 
            #                                                                              window_size=window_size)
            
            global_val_loss_window_list.append(avg_val_loss_window)
            # global_val_loss_person_list.append(avg_val_loss_person)
            
            # Compute the average training loss for each epoch
            avg_train_loss = sum(train_loss_list) / len(train_dataloader_GaJuSi)
            global_train_loss_list.append(avg_train_loss.item())
            train_loss_list = []
            
            # ================================================================
            # LOGGING
            # ================================================================
            print(f"epoch: {epoch+1}, "
                # f"iter: {iter+1}, "
                # f"step: {step+1}, "
                f"train/loss: {avg_train_loss:.3f}, "
                f"val/loss_window: {avg_val_loss_window:.3f}, "
                f"val/acc_window: {acc_window:.3f}, "
                f"val/f1_window: {f1_window:.3f}"
                # f"val/loss_person: {avg_val_loss_person:.3f}, "
                # f"val/acc_person: {acc_person:.3f}, "
                # f"val/f1_person: {f1_person:.3f}"
            )
            
            # Switch the model back to training mode
            model.train()
            
            # step += 1
        
        # Backward pass the model
        train_loss.backward()
        
        # Update the model weights based on computed gradients
        optimizer.step()
print()

# ================================================================================================
# EVALUATION
# ================================================================================================
print_h("EVALUATION", 96)

# ================================================================
# EVALUATION ON WINDOW DATA
# ================================================================
# print_h("EVALUATION ON WINDOW DATA", 64)
# (
#     _, 
#     acc_window, 
#     f1_window, 
#     precision_window, 
#     recall_window, 
#     cm_window
# ) = eval_window(
#     model, 
#     test_dataloader, 
#     average='weighted',
# )
# print("acc:", acc_window)
# print("f1:", f1_window)
# print("precision:", precision_window)
# print("recall:", recall_window)
# print("cm:\n", np.array(cm_window))
# print()

# ================================================================
# EVALUATION ON PERSON DATA BY MAJORITY VOTING
# ================================================================
print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
(
    _,
    acc_person_majority_voting,
    f1_person_majority_voting,
    precision_person_majority_voting,
    recall_person_majority_voting,
    cm_person_majority_voting,
    *_,
) = eval_person_majority_voting(
    model, 
    val_person_dataset_GaJuSi, 
    criterion=None, 
    average='weighted',
    window_size=window_size, 
    debug=False,
)
print("acc:", acc_person_majority_voting)
print("f1:", f1_person_majority_voting)
print("precision:", precision_person_majority_voting)
print("recall:", recall_person_majority_voting)
print("cm:\n", np.array(cm_person_majority_voting))
print()

# ================================================================
# EVALUATION ON PERSON DATA BY SEVERITY VOTING
# ================================================================
# print_h("EVALUATION ON PERSON DATA BY SEVERITY VOTING", 64)
# (
#     _, 
#     acc_person_severity_voting, 
#     f1_person_severity_voting, 
#     precision_person_severity_voting, 
#     recall_person_severity_voting, 
#     cm_person_severity_voting,
# ) = eval_person_severity_voting(
#     model, 
#     test_person_dataset, 
#     criterion=None, 
#     average='weighted',
#     window_size=window_size, 
#     debug=False,
# )
# print("acc:", acc_person_severity_voting)
# print("f1:", f1_person_severity_voting)
# print("precision:", precision_person_severity_voting)
# print("recall:", recall_person_severity_voting)
# print("cm:\n", np.array(cm_person_severity_voting))
# print()

# ================================================================
# EVALUATION ON PERSON DATA BY MAX. SEVERITY
# ================================================================
# print_h("EVALUATION ON PERSON DATA BY MAX. SEVERITY", 64)
# (
#     _, 
#     acc_person_max_severity, 
#     f1_person_max_severity, 
#     precision_person_max_severity, 
#     recall_person_max_severity, 
#     cm_person_max_severity,
# ) = eval_person_max_severity(
#     model, 
#     test_person_dataset, 
#     criterion=None, 
#     average='weighted',
#     window_size=window_size, 
#     debug=False,
# )
# print("acc:", acc_person_max_severity)
# print("f1:", f1_person_max_severity)
# print("precision:", precision_person_max_severity)
# print("recall:", recall_person_max_severity)
# print("cm:\n", np.array(cm_person_max_severity))
# print()

in_metrics = {
    'person_majority_voting': {
        'acc': acc_person_majority_voting,
        'f1': f1_person_majority_voting,
        'precision': precision_person_majority_voting,
        'recall': recall_person_majority_voting,
        'cm': cm_person_majority_voting,
    },
    # 'person_severity_voting': {
    #     'acc': acc_person_severity_voting,
    #     'f1': f1_person_severity_voting,
    #     'precision': precision_person_severity_voting,
    #     'recall': recall_person_severity_voting,
    #     'cm': cm_person_severity_voting,
    # },
    # 'person_max_severity': {
    #     'acc': acc_person_max_severity,
    #     'f1': f1_person_max_severity,
    #     'precision': precision_person_max_severity,
    #     'recall': recall_person_max_severity,
    #     'cm': cm_person_max_severity,
    # },
    # 'window': {
    #     'acc': acc_window,
    #     'f1': f1_window,
    #     'precision': precision_window,
    #     'recall': recall_window,
    #     'cm': cm_window,
    # },
}

for metric_type in in_metrics.keys():
    update_metrics(metrics[metric_type], in_metrics[metric_type])

# metrics['window']['train_loss']['folds'].append(global_train_loss_list)
# metrics['window']['val_loss']['folds'].append(global_val_loss_window_list)

# ================================================================================================
# CHECKPOINT SAVING
# ================================================================================================
save_path = os.path.join(save_dir, f'fold_{i_fold:02}.pth')
torch.save(model.state_dict(), save_path)
print(f"Checkpoint for fold {i_fold:02} is saved to:", save_path)

save_metrics_to_json(metrics, save_dir, filename='_evaluation_metrics.json')
print("Evaluation metrics is saved in:", save_dir)

Run name: RNNInceptionTimeStacked_bidirectional_non_moe_Ga_k10_w500_s500_Ju_k10_w500_s500_w_anomaly_Si_k10_w500_s250_w_anomaly_fold_06_e20_v20250605150644
Save directory: checkpoints/RNNInceptionTimeStacked_bidirectional_non_moe_Ga_k10_w500_s500_Ju_k10_w500_s500_w_anomaly_Si_k10_w500_s250_w_anomaly_fold_06_e20_v20250605150644

                                                             FOLD 6                                                             


                                            TRAINING                                            
epoch: 1, train/loss: 1.087, val/loss_window: 1.526, val/acc_window: 0.456, val/f1_window: 0.343
epoch: 2, train/loss: 0.995, val/loss_window: 1.287, val/acc_window: 0.500, val/f1_window: 0.410
epoch: 3, train/loss: 0.938, val/loss_window: 3.178, val/acc_window: 0.296, val/f1_window: 0.176
epoch: 4, train/loss: 0.891, val/loss_window: 2.166, val/acc_window: 0.409, val/f1_window: 0.335
epoch: 5, train/loss: 0.845, val/loss_window: 2.479, val/acc_window: 0.442, val/f1_window: 0.384
epoch: 6, train/loss: 0.798, val/loss_window: 2.581, val/acc_window: 0.405, val/f1_window: 0.330
epoch: 7, train/loss: 0.755, val/loss_window: 1.989, val/acc_window: 0.548, val/f1_window: 0.483
epoch: 8, train/loss: 0.717, val/loss_window: 1.751, val/acc_window: 0.558, val/f1_window: 0.507
epoch: 9, train/loss: 0.681, val/loss_window: 2.283, val/acc_window: 0.485, val/f1_window: 0.398
epoch: 10, train/loss: 0.645, 