# Library

In [1]:
import os
import time
import numpy as np
import torch
from datetime import datetime
from torch.utils.data import TensorDataset, DataLoader

from src.utils import (
    set_seed, get_device, print_h, 
    eval_window, eval_person_severity_voting, eval_person_majority_voting, eval_person_max_severity,
    init_metrics, update_metrics, save_metrics_to_json,
)
from src.models import RNNInceptionTime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
import torch.nn.functional as F

# class Gate(nn.Module):
#     def __init__(self, input_dim, num_experts):
#         super().__init__()
#         self.linear = nn.Linear(input_dim, num_experts)

#     def forward(self, x):
#         logits = self.linear(x)  # [batch_size, num_experts]
#         weights = F.softmax(logits, dim=-1)
#         return weights

class MoE(nn.Module):
    """Mixture of experts with frozen, pretrained experts and a trainable top-k gate.

    The gate is an `RNNInceptionTime` with one output per expert. For every sample the
    `top_k` highest-scoring experts are selected and their outputs are summed (unweighted).
    Only the gate is meant to be trained: expert parameters are frozen and the experts are
    kept in eval mode at all times.

    Args:
        c_in: Number of input channels, forwarded to the gate network.
        seq_len: Input sequence length, forwarded to the gate network.
        experts: Iterable of expert modules (e.g. a list or `dict.values()`). All experts
            are expected to accept the same input as the gate and to return tensors of
            identical shape so that they can be stacked.
        top_k: Number of experts selected per sample.

    Raises:
        ValueError: If `top_k` is not in `[1, number of experts]`.
    """

    def __init__(self, c_in, seq_len, experts, top_k=1):
        super().__init__()

        # Materialize once so that any iterable (including one-shot iterators) can be sized.
        experts = list(experts)
        if not 1 <= top_k <= len(experts):
            raise ValueError(f"top_k must be in [1, {len(experts)}], got {top_k}")

        self.top_k = top_k
        self.experts = nn.ModuleList(experts)
        self.gate = RNNInceptionTime(c_in=c_in, seq_len=seq_len, c_out=len(experts), bidirectional=True)

        # The experts may already live on an accelerator; put the new gate next to them so
        # the module works even when the caller does not call `.to(device)` afterwards.
        first_expert_param = next(self.experts.parameters(), None)
        if first_expert_param is not None:
            self.gate.to(first_expert_param.device)

        # Freeze the experts: no gradient updates, and inference-mode behaviour from the start.
        for expert in self.experts:
            expert.eval()
            for param in expert.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Route `x` through the top-k experts chosen by the gate.

        Returns the sum of the selected experts' outputs, with the same shape as a single
        expert output (expected: [batch_size, output_dim]).
        """
        gate_logits = self.gate(x)  # expected: [batch_size, num_experts]
        gate_probs = F.softmax(gate_logits, dim=-1)

        # Hard top-k routing mask (1.0 for selected experts, 0.0 otherwise).
        _, topk_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        hard_mask = torch.zeros_like(gate_logits).scatter_(-1, topk_indices, 1.0)

        # Straight-through estimator: the forward value is exactly `hard_mask`
        # (`gate_probs - gate_probs.detach()` is zero), but the backward pass flows through
        # the softmax. Without this the mask is a constant and, with frozen experts, the
        # loss has no path to any trainable parameter, so `backward()` fails.
        dispatch_mask = hard_mask + (gate_probs - gate_probs.detach())
        dispatch_mask = dispatch_mask.unsqueeze(-1)  # [batch_size, num_experts, 1]

        # Every expert sees every sample; the mask zeroes out the non-selected ones.
        expert_stack = torch.stack([expert(x) for expert in self.experts], dim=1)  # [batch_size, num_experts, output_dim]

        return (dispatch_mask * expert_stack).sum(dim=1)  # [batch_size, output_dim]

    def train(self, mode=True):
        """Set the training mode of the gate while keeping the frozen experts in eval mode.

        Returns `self`, as `nn.Module.train` does, so that `model.train()` / `model.eval()`
        can be chained or assigned.
        """
        super().train(mode)
        for expert in self.experts:
            expert.eval()  # Ensure experts stay in eval mode
        return self

# Config

In [3]:
# Reproducibility and compute device
seed = 69
set_seed(seed)
device = get_device()
print("Device:", device)

# Preprocessed k-fold data directory for each study
k_fold_dir_map = dict(
    Ga='data/preprocessed/Ga_k10_w500_s500_v20250501004633',
    Ju='data/preprocessed/Ju_k10_w500_s500_v20250501004709',
    Si='data/preprocessed/Si_k10_w500_s500_v20250501213954',
)
# Pretrained expert checkpoint directory for each study
expert_model_dir_map = dict(
    Ga='checkpoints/RNNInceptionTime_bidirectional_Ga_k10_w500_s500_e20_v20250520224322',
    Ju='checkpoints/RNNInceptionTime_bidirectional_Ju_k10_w500_s500_e30_v20250520224556',
    Si='checkpoints/RNNInceptionTime_bidirectional_Si_k10_w500_s500_e20_v20250520224754',
)

# Training hyperparameters and data dimensions
n_epoch = 20
batch_size = 8
n_feat = 16
n_class = 4
window_size = 500
max_vgrf_data_len = 25_000
lr = 3e-4

# Run name tag: the directory name of every dataset with its trailing "_v<version>" removed,
# joined by underscores, followed by the epoch count.
dataset_tags = [path.rpartition('/')[2].rsplit('_v', 1)[0] for path in k_fold_dir_map.values()]
run_name_tag = f"{'_'.join(dataset_tags)}_e{n_epoch}"
print("Run name tag:", run_name_tag)

2025-06-04 00:56:31.747813: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-04 00:56:31.768357: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-04 00:56:31.774769: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-04 00:56:31.789059: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Random seed: 69
Device: cuda
Run name tag: Ga_k10_w500_s500_Ju_k10_w500_s500_Si_k10_w500_s500_e20


In [4]:
# Set run names: <model>_bidirectional_[<run_name_tag>_]v<timestamp>
v = datetime.now().strftime("%Y%m%d%H%M%S")
# The optional tag prefix is built outside the f-strings: reusing the enclosing quote
# character inside a replacement field is only valid syntax on Python 3.12+ (PEP 701).
run_name_prefix = f"{run_name_tag}_" if run_name_tag else ""
gate_run_name = f'RNNInceptionTimeGate_bidirectional_{run_name_prefix}v{v}'
moe_run_name = f'RNNInceptionTimeMoE_bidirectional_{run_name_prefix}v{v}'
print("Gate model run name:", gate_run_name)
print("MoE model run name:", moe_run_name)
print()

# Create save directories
# gate_save_dir = 'checkpoints/' + gate_run_name
# moe_save_dir = 'checkpoints/' + moe_run_name
# os.makedirs(gate_save_dir, exist_ok=True)
# os.makedirs(moe_save_dir, exist_ok=True)
# print("Gate model save directory:", gate_save_dir)
# print("MoE model save directory:", moe_save_dir)
# print()

# Initialize evaluation metrics (one accumulator per evaluation type; the commented-out
# entries are evaluation types that are currently disabled)
gate_metrics = {
    'window': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss', 'train_time']),
    # 'person_majority_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss']),
    # 'person_severity_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss']),
    # 'person_max_severity': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm', 'train_loss', 'val_loss']),
}
moe_metrics = {
    # 'window': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
    'person_majority_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
    # 'person_severity_voting': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
    # 'person_max_severity': init_metrics(['acc', 'f1', 'precision', 'recall', 'cm']),
}

# Integer label identifying the study each sample comes from
study_label_map = {
    'Ga': 0,
    'Ju': 1,
    'Si': 2,
}

def _load_split(fold_dir, split):
    """Load `X_<split>.npy` and `y_<split>.npy` from `fold_dir` as (float features, long labels)."""
    X = torch.tensor(np.load(os.path.join(fold_dir, f'X_{split}.npy'))).float()
    y = torch.tensor(np.load(os.path.join(fold_dir, f'y_{split}.npy'))).long()
    return X, y


# for i_fold in range(k_fold):
# Fold directory names are taken from the 'Ga' dataset; the other studies and the expert
# checkpoints are looked up under the same fold name.
for fold_i_dir_name in sorted(os.listdir(k_fold_dir_map['Ga'])):
    # ================================================================================================================================
    # FOLD
    # ================================================================================================================================
    print_h(fold_i_dir_name, 128)

    # One expert per study; the pretrained weights for this fold are loaded further below.
    expert_model_map = {
        study: RNNInceptionTime(c_in=n_feat, c_out=n_class, seq_len=window_size, bidirectional=True).to(device)
        for study in k_fold_dir_map
    }

    # Accumulators for the data of all studies combined. Starting from correctly shaped empty
    # tensors makes `torch.cat` fail early if a study's arrays have an unexpected shape.
    X_train_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
    y_train_window_GaJuSi = torch.empty(0).long()
    study_labels_train_window_GaJuSi = torch.empty(0).long()

    X_val_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
    y_val_window_GaJuSi = torch.empty(0).long()
    study_labels_val_window_GaJuSi = torch.empty(0).long()

    X_test_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
    y_test_window_GaJuSi = torch.empty(0).long()
    study_labels_test_window_GaJuSi = torch.empty(0).long()

    X_val_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feat).float()
    y_val_person_GaJuSi = torch.empty(0).long()
    # study_labels_val_person_GaJuSi = torch.empty(0).long()

    X_test_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feat).float()
    y_test_person_GaJuSi = torch.empty(0).long()
    # study_labels_test_person_GaJuSi = torch.empty(0).long()

    for study, k_fold_dir in k_fold_dir_map.items():
        # ================================================================================================
        # EXPERT MODEL
        # ================================================================================================
        print_h(f"EXPERT-{study} MODEL", 96)

        fold_i_dir = os.path.join(k_fold_dir, fold_i_dir_name)
        study_label = study_label_map[study]

        # Window-level splits: load, tag every window with its study label, and accumulate.
        X_train_window, y_train_window = _load_split(fold_i_dir, 'train_window')
        study_labels_train_window = torch.full((len(y_train_window),), study_label, dtype=torch.long)
        X_train_window_GaJuSi = torch.cat((X_train_window_GaJuSi, X_train_window), dim=0)
        y_train_window_GaJuSi = torch.cat((y_train_window_GaJuSi, y_train_window), dim=0)
        study_labels_train_window_GaJuSi = torch.cat((study_labels_train_window_GaJuSi, study_labels_train_window), dim=0)

        X_val_window, y_val_window = _load_split(fold_i_dir, 'val_window')
        study_labels_val_window = torch.full((len(y_val_window),), study_label, dtype=torch.long)
        X_val_window_GaJuSi = torch.cat((X_val_window_GaJuSi, X_val_window), dim=0)
        y_val_window_GaJuSi = torch.cat((y_val_window_GaJuSi, y_val_window), dim=0)
        study_labels_val_window_GaJuSi = torch.cat((study_labels_val_window_GaJuSi, study_labels_val_window), dim=0)

        X_test_window, y_test_window = _load_split(fold_i_dir, 'test_window')
        study_labels_test_window = torch.full((len(y_test_window),), study_label, dtype=torch.long)
        X_test_window_GaJuSi = torch.cat((X_test_window_GaJuSi, X_test_window), dim=0)
        y_test_window_GaJuSi = torch.cat((y_test_window_GaJuSi, y_test_window), dim=0)
        study_labels_test_window_GaJuSi = torch.cat((study_labels_test_window_GaJuSi, study_labels_test_window), dim=0)

        # Person-level splits
        X_val_person, y_val_person = _load_split(fold_i_dir, 'val_person')
        X_val_person_GaJuSi = torch.cat((X_val_person_GaJuSi, X_val_person), dim=0)
        y_val_person_GaJuSi = torch.cat((y_val_person_GaJuSi, y_val_person), dim=0)

        X_test_person, y_test_person = _load_split(fold_i_dir, 'test_person')
        X_test_person_GaJuSi = torch.cat((X_test_person_GaJuSi, X_test_person), dim=0)
        y_test_person_GaJuSi = torch.cat((y_test_person_GaJuSi, y_test_person), dim=0)

        # Per-study datasets and loaders. Only `test_person_dataset` is used in this cell;
        # the others belong to the current study alone and must not be confused with the
        # combined `*_GaJuSi` loaders used for MoE training.
        train_window_dataset = TensorDataset(X_train_window, y_train_window)
        val_window_dataset = TensorDataset(X_val_window, y_val_window)
        test_window_dataset = TensorDataset(X_test_window, y_test_window)

        val_person_dataset = TensorDataset(X_val_person, y_val_person)
        test_person_dataset = TensorDataset(X_test_person, y_test_person)

        train_dataloader = DataLoader(train_window_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_window_dataset, batch_size=batch_size, shuffle=False)
        test_dataloader = DataLoader(test_window_dataset, batch_size=batch_size, shuffle=False)

        expert_model = expert_model_map[study]

        # Load pretrained model
        model_i_path = os.path.join(expert_model_dir_map[study], fold_i_dir_name + '.pth')
        expert_model.load_state_dict(torch.load(model_i_path, map_location=device))

        # ================================================================
        # EXPERT MODEL EVALUATION ON PERSON DATA BY MAJORITY VOTING
        # ================================================================
        print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
        (
            _,
            acc_person_majority_voting,
            f1_person_majority_voting,
            precision_person_majority_voting,
            recall_person_majority_voting,
            cm_person_majority_voting,
            *_,
        ) = eval_person_majority_voting(expert_model, test_person_dataset, criterion=None, average='weighted',
                                        window_size=window_size, debug=False, seed=seed)
        print("acc:", acc_person_majority_voting)
        print("f1:", f1_person_majority_voting)
        print("precision:", precision_person_majority_voting)
        print("recall:", recall_person_majority_voting)
        print("cm:\n", np.array(cm_person_majority_voting))
        print()

    # ================================================================================================
    # MoE MODEL
    # ================================================================================================
    print_h("MoE MODEL", 96)

    # For MoE model training
    train_window_dataset_GaJuSi = TensorDataset(X_train_window_GaJuSi, y_train_window_GaJuSi)
    val_window_dataset_GaJuSi = TensorDataset(X_val_window_GaJuSi, y_val_window_GaJuSi)
    test_window_dataset_GaJuSi = TensorDataset(X_test_window_GaJuSi, y_test_window_GaJuSi)

    # For gate model training
    # train_window_dataset_GaJuSi = TensorDataset(X_train_window_GaJuSi, study_labels_train_window_GaJuSi)
    # val_window_dataset_GaJuSi = TensorDataset(X_val_window_GaJuSi, study_labels_val_window_GaJuSi)
    # test_window_dataset_GaJuSi = TensorDataset(X_test_window_GaJuSi, study_labels_test_window_GaJuSi)

    train_dataloader_GaJuSi = DataLoader(train_window_dataset_GaJuSi, batch_size=batch_size, shuffle=True)
    val_dataloader_GaJuSi = DataLoader(val_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)
    test_dataloader_GaJuSi = DataLoader(test_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)

    val_person_dataset_GaJuSi = TensorDataset(X_val_person_GaJuSi, y_val_person_GaJuSi)
    test_person_dataset_GaJuSi = TensorDataset(X_test_person_GaJuSi, y_test_person_GaJuSi)

    # ================================
    # TRAINING
    # ================================
    print_h("TRAINING", 64)
    # The gate network is created inside MoE, so the whole model must be moved to the
    # device of the inputs (the experts were moved there when they were built).
    moe_model = MoE(c_in=n_feat, seq_len=window_size, experts=expert_model_map.values()).to(device)

    # Initialize optimizer and loss function; only parameters that require gradients are
    # optimized (the expert parameters are frozen by MoE).
    trainable_params = [param for param in moe_model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    # Switch the model to training mode
    moe_model.train()

    # Per-epoch history
    global_val_loss_window_list = []
    global_val_loss_person_list = []
    global_train_loss_list = []
    global_train_time_list = []

    for epoch in range(n_epoch):
        start_time = time.time()

        # Sum of batch losses as a Python float, so no autograd graph is kept alive.
        train_loss_sum = 0.0

        # Loop training batches
        for X_train, y_train in train_dataloader_GaJuSi:
            # Flush the computed gradients
            optimizer.zero_grad()

            X_train = X_train.to(device)
            y_train = y_train.to(device)

            # Feed forward the model; the loaded windows are (batch, window, feature) and
            # are permuted to (batch, feature, window) before being passed to the model.
            X_train = X_train.permute(0, 2, 1)
            y_pred = moe_model(X_train)

            # Compute training loss
            train_loss = criterion(y_pred, y_train)

            # Backward pass the model
            train_loss.backward()

            # Update the model weights based on computed gradients
            optimizer.step()

            train_loss_sum += train_loss.item()

        # Compute training time
        train_time = time.time() - start_time
        global_train_time_list.append(train_time)

        # ================================
        # VALIDATION
        # ================================
        avg_val_loss_window, acc_window, f1_window, *_ = eval_window(moe_model, val_dataloader_GaJuSi, criterion, average='weighted')
        avg_val_loss_person, acc_person, f1_person, *_ = eval_person_majority_voting(moe_model, val_person_dataset_GaJuSi, criterion=criterion, average='weighted',
                                                                                     window_size=window_size)

        global_val_loss_window_list.append(avg_val_loss_window)
        global_val_loss_person_list.append(avg_val_loss_person)

        # Average training loss of the epoch, over the batches of the loader that was
        # actually iterated (the combined Ga+Ju+Si loader).
        avg_train_loss = train_loss_sum / len(train_dataloader_GaJuSi)
        global_train_loss_list.append(avg_train_loss)

        # ================================
        # LOGGING
        # ================================
        print(f"epoch: {epoch+1}, "
              f"train/loss: {avg_train_loss:.3f}, "
              f"val/loss_window: {avg_val_loss_window:.3f}, "
              f"val/acc_window: {acc_window:.3f}, "
              f"val/f1_window: {f1_window:.3f}, "
              f"val/loss_person: {avg_val_loss_person:.3f}, "
              f"val/acc_person: {acc_person:.3f}, "
              f"val/f1_person: {f1_person:.3f}, "
              f"train/time: {train_time:.1f}s"
        )

        # Switch the model back to training mode (presumably the evaluation helpers put it
        # in eval mode; verify against src.utils)
        moe_model.train()
    print()

    # ================================================================
    # MoE MODEL EVALUATION ON PERSON DATA BY MAJORITY VOTING
    # ================================================================
    print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
    (
        _,
        acc_person_majority_voting,
        f1_person_majority_voting,
        precision_person_majority_voting,
        recall_person_majority_voting,
        cm_person_majority_voting,
        *_,
    ) = eval_person_majority_voting(moe_model, test_person_dataset_GaJuSi, criterion=None, average='weighted',
                                    window_size=window_size, debug=False)
    print("acc:", acc_person_majority_voting)
    print("f1:", f1_person_majority_voting)
    print("precision:", precision_person_majority_voting)
    print("recall:", recall_person_majority_voting)
    print("cm:\n", np.array(cm_person_majority_voting))
    print()

    moe_in_metrics = {
        # 'window': {
        #     'acc': acc_window,
        #     'f1': f1_window,
        #     'precision': precision_window,
        #     'recall': recall_window,
        #     'cm': cm_window,
        # },
        'person_majority_voting': {
            'acc': acc_person_majority_voting,
            'f1': f1_person_majority_voting,
            'precision': precision_person_majority_voting,
            'recall': recall_person_majority_voting,
            'cm': cm_person_majority_voting,
        },
        # 'person_severity_voting': {
        #     'acc': acc_person_severity_voting,
        #     'f1': f1_person_severity_voting,
        #     'precision': precision_person_severity_voting,
        #     'recall': recall_person_severity_voting,
        #     'cm': cm_person_severity_voting,
        # },
        # 'person_max_severity': {
        #     'acc': acc_person_max_severity,
        #     'f1': f1_person_max_severity,
        #     'precision': precision_person_max_severity,
        #     'recall': recall_person_max_severity,
        #     'cm': cm_person_max_severity,
        # },
    }

    # Record this fold's results in the cross-fold accumulators
    for metric_type, fold_metrics in moe_in_metrics.items():
        update_metrics(moe_metrics[metric_type], fold_metrics)

    # ================================================================================================
    # MoE MODEL SAVING
    # ================================================================================================
    # moe_save_path = os.path.join(moe_save_dir, f'{fold_i_dir_name}.pth')
    # torch.save(moe_model.state_dict(), moe_save_path)

    # print(f"MoE model checkpoint for {fold_i_dir_name} is saved to:", moe_save_path)
    # print()

    # DEBUG: Test for only 1 fold (remove this `break` to run every fold)
    break

# save_metrics_to_json(moe_metrics, moe_save_dir, filename='_evaluation_metrics.json')
# print("MoE model evaluation metrics is saved in:", moe_save_dir)

Gate model run name: RNNInceptionTimeGate_bidirectional_Ga_k10_w500_s500_Ju_k10_w500_s500_Si_k10_w500_s500_e20_v20250604005633
MoE model run name: RNNInceptionTimeMoE_bidirectional_Ga_k10_w500_s500_Ju_k10_w500_s500_Si_k10_w500_s500_e20_v20250604005633

                                                            fold_01                                                             


                                        EXPERT-Ga MODEL                                         
          EVALUATION ON PERSON DATA BY MAJORITY VOTING          

acc: 0.9375
f1: 0.9299242424242424
precision: 0.9479166666666667
recall: 0.9375
cm:
 [[5 0 0 0]
 [0 6 0 0]
 [0 0 3 0]
 [1 0 0 1]]

                                        EXPERT-Ju MODEL                                         
          EVALUATION ON PERSON DATA BY MAJORITY VOTING          

acc: 0.9333333333333333
f1: 0.9303703703703704
precision: 0.9466666666666667
recall: 0.9333333333333333
cm:
 [[2 1 0 0]
 [0 4 0 0]
 [0 0 7 0]
 [0 0 0 1]]

                                        EXPERT-Si MODEL                                         
          EVALUATION ON PERSON DATA BY MAJORITY VOTING          

acc: 0.875
f1: 0.8694444444444445
precision: 0.9
recall: 0.875
cm:
 [[4 0 0 0]
 [1 2 0 0]
 [0 0 1 0]
 [0 0 0 0]]

                                           MoE MODEL                                            
              