# Libraries

In [1]:
import os
import subprocess

# Clone the project repository (sparse checkout of `src` only) on first run, or
# bring an existing clone up to date on later runs.
repo_url = 'https://github.com/alxxtexxr/Parkinsons_VGRF_Spatiotemporal_Diagnosis_v2.git'

# Extract the repo name from the URL.
# NOTE: `str.rstrip('.git')` removes any trailing run of the characters
# '.', 'g', 'i', 't' rather than the literal suffix (e.g. 'digit.git' -> 'd'),
# so the suffix is removed explicitly instead.
repo_name = repo_url.split('/')[-1]
if repo_name.endswith('.git'):
    repo_name = repo_name[:-len('.git')]

root_dir = '/content'
repo_dir = os.path.join(root_dir, repo_name)

if os.path.isdir(repo_dir):
    print(f"Repo '{repo_name}' already exists. Updating to the latest version...")
    # The working directory is left inside the repo; the later `from src...`
    # imports presumably rely on that -- confirm before changing.
    os.chdir(repo_dir)
    # Refresh the remote-tracking refs first: without a fetch, 'origin/master'
    # still points at whatever was downloaded last time, so the reset would
    # target a stale commit.
    subprocess.run(['git', 'fetch', 'origin'], check=True)
    subprocess.run(['git', 'reset', '--hard', 'origin/master'], check=True) # Discard local changes, match remote
    subprocess.run(['git', 'pull'], check=True) # Pull latest changes for the current branch
else:
    print(f"Repo '{repo_name}' not found. Cloning a fresh copy...")
    os.chdir(root_dir)
    # --no-checkout: download history only; files are materialised after the
    # sparse-checkout rules below are in place.
    subprocess.run(['git', 'clone', '--no-checkout', repo_url], check=True)
    os.chdir(repo_dir)
    print("Initializing sparse checkout for 'src' directory...")
    subprocess.run(['git', 'sparse-checkout', 'init', '--cone'], check=True)
    subprocess.run(['git', 'sparse-checkout', 'set', 'src'], check=True)
    subprocess.run(['git', 'checkout'], check=True)

Repo 'Parkinsons_VGRF_Spatiotemporal_Diagnosis_v2' not found. Cloning a fresh copy...
Initializing sparse checkout for 'src' directory...


In [2]:
# Install tsai quietly into the running kernel's environment.
# NOTE(review): tsai is not imported directly in the visible cells -- presumably
# it is a dependency of the `src.models` package; confirm. The version is not
# pinned, so a re-run may pick up a different release.
%pip install -q tsai

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.3/324.3 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.0/278.0 kB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m79.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m906.5/906.5 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m74.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m62.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31

In [3]:
import os
import random
import numpy as np
import torch
from datetime import datetime
from torch.utils.data import TensorDataset, DataLoader

from src.utils import (
    print_h, eval_window, eval_person_majority_voting,
    set_seed,
)
from src.models.RNNInceptionTime import RNNInceptionTime

# Training

In [None]:
def update_output(output, eval_output):
    """Record one fold's evaluation results in a running metric accumulator.

    For each of 'acc', 'f1', 'precision', 'recall' and 'cm', the value found in
    `eval_output` is appended to `output[<metric>]['folds']`. For every metric
    except 'cm', the 'avg' and 'std' entries are then recomputed over all folds
    recorded so far (via `np.mean` / `np.std`).

    Args:
        output: accumulator dict, mutated in place. Must hold a 'folds' list
            for each of the five metrics above.
        eval_output: dict holding this fold's value for each of those metrics.

    Returns:
        The same `output` object that was passed in.
    """
    tracked_metrics = ('acc', 'f1', 'precision', 'recall', 'cm')
    for name in tracked_metrics:
        record = output[name]
        record['folds'].append(eval_output[name])
        if name == 'cm':
            # Confusion matrices are only collected, never averaged.
            continue
        history = record['folds']
        record['avg'] = np.mean(history)
        record['std'] = np.std(history)
    return output

class HardMoE(torch.nn.Module):
    """Hard (top-1) mixture of experts.

    Every expert is run on the input; for each sample the output of the single
    expert selected by the gate (argmax over the gate's dim 1) is returned.

    Args:
        experts: iterable of modules. Their position in the iterable is the
            index the gate's argmax selects, so order matters.
        gate: module whose output has one score per expert along dim 1.
    """

    def __init__(self, experts, gate):
        # Zero-argument super(): the previous `super(MoE, self)` referenced a
        # name that does not exist and raised NameError on construction.
        super().__init__()
        self.experts = torch.nn.ModuleList(experts)
        self.gate = gate

    def forward(self, x):
        """Route each sample in `x` to the expert chosen by the gate.

        Assumes every expert returns a 2-D tensor of identical shape
        (batch, n_out) -- TODO confirm against the expert models in use.

        Returns:
            Tensor indexed as (batch, n_out): row b is the output of expert
            argmax(gate(x)[b]) for sample b.
        """
        gate_out = self.gate(x)
        gate_out_max_idxs = torch.argmax(gate_out, dim=1)
        # Stack on a new trailing axis -> (batch, n_out, n_experts).
        expert_outs = torch.stack([expert(x) for expert in self.experts], dim=-1)
        # Build the batch index on the same device as the tensor being indexed.
        batch_idxs = torch.arange(expert_outs.size(0), device=expert_outs.device)
        output = expert_outs[batch_idxs, :, gate_out_max_idxs]
        return output


# Backward-compatible alias: existing call sites construct the model as `MoE(...)`.
MoE = HardMoE

# ---- Run configuration -------------------------------------------------------
seed = 69

# Single place to repoint the notebook at another machine's copy of the project.
project_dir = '/home/mitlab/Documents/alxxtexxr/projects/classification_parkinsons_vgrf'
k_fold_root_dir = f'{project_dir}/datasets/preprocessed_mixed_loocv_v20240731_val_2'
train_output_dir = f'{project_dir}/outputs/train'

# Per-study directory holding one sub-directory per cross-validation fold.
k_fold_dir_map = {
    'Ga': f'{k_fold_root_dir}/kfold10_window500_stride500_feature16_Ga2',
    'Ju': f'{k_fold_root_dir}/kfold10_window500_stride500_feature16_Ju3',
    'Si': f'{k_fold_root_dir}/kfold10_window500_stride250_feature16_Si8',
}
# Per-study directory holding one pretrained expert checkpoint per fold.
expert_model_dir_map = {
    'Ga': f'{train_output_dir}/RNNInceptionTime_Ga_baseline_20240911161249',
    'Ju': f'{train_output_dir}/RNNInceptionTime_Ju_baseline_20240911143621',
    'Si': f'{train_output_dir}/RNNInceptionTime_Si_baseline_20240911144140',
}
# Directory holding one pretrained gate checkpoint per fold.
gate_model_dir = f'{train_output_dir}/RNNInceptionTime_Gate_baseline_20240912075120'

k_fold = 10
batch_size = 8
n_feature = 16
n_class = 4
window_size = 500
max_vgrf_data_len = 25_000
# Use the first GPU when there is one; fall back to CPU so the notebook still
# runs on a machine without CUDA instead of failing on the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# ---- Reproducibility ---------------------------------------------------------
# Export the seed for Python's hash randomisation, then seed the Python, NumPy
# and PyTorch random number generators with the same value.
os.environ['PYTHONHASHSEED'] = str(seed)
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# Seed the CUDA generators too (current device, then all devices) when a GPU is present.
if torch.cuda.is_available():
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(seed)

# cuDNN: turn off the auto-tuner and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Stricter global switch, currently disabled:
# torch.use_deterministic_algorithms(True)

def _empty_metric_log():
    """Return a fresh, independent per-fold metric accumulator.

    Scalar metrics carry a 'folds' list plus running 'avg' / 'std' slots;
    'cm' and 'val_loss' only carry the 'folds' list.
    """
    log = {
        metric: {'folds': [], 'avg': None, 'std': None}
        for metric in ('acc', 'f1', 'precision', 'recall')
    }
    for list_only_metric in ('cm', 'val_loss'):
        log[list_only_metric] = {'folds': []}
    return log


# One accumulator for the mixture model, one for the gate, one per study expert.
moe_output = _empty_metric_log()
gate_output = _empty_metric_log()
expert_outputs = {study: _empty_metric_log() for study in ('Ga', 'Ju', 'Si')}

# Gate targets: the integer id of the study a window came from. The same id is
# the position of that study's expert inside the mixture model, so the two must
# stay aligned. Loop-invariant, hence defined once outside the fold loop.
study_label_map = {
    'Ga': 0,
    'Ju': 1,
    'Si': 2,
}

# For every fold: load the per-study data and pretrained expert, evaluate each
# expert on its own study, evaluate the pretrained gate on the pooled windows,
# then evaluate the hard mixture of experts on the pooled person-level data.
for i_fold in range(k_fold):
    print_h(f"FOLD-{i_fold+1}", 128)

    # One freshly constructed expert per study; weights are loaded further down.
    expert_model_map = {
        study: RNNInceptionTime(c_in=n_feature, c_out=n_class, seq_len=window_size, bidirectional=True).to(device)
        for study in study_label_map
    }

    # Pooled (all-study) tensors, grown study by study in the loop below.
    X_train_window_GaJuSi = torch.empty(0, window_size, n_feature).float()
    y_train_window_GaJuSi = torch.empty(0).long()
    study_labels_train_window_GaJuSi = torch.empty(0).long()

    X_val_window_GaJuSi = torch.empty(0, window_size, n_feature).float()
    y_val_window_GaJuSi = torch.empty(0).long()
    study_labels_val_window_GaJuSi = torch.empty(0).long()

    X_test_window_GaJuSi = torch.empty(0, window_size, n_feature).float()
    y_test_window_GaJuSi = torch.empty(0).long()
    study_labels_test_window_GaJuSi = torch.empty(0).long()

    X_val_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feature).float()
    y_val_person_GaJuSi = torch.empty(0).long()

    X_test_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feature).float()
    y_test_person_GaJuSi = torch.empty(0).long()

    for study, k_fold_dir in k_fold_dir_map.items():
        # os.listdir() returns entries in arbitrary order, so sort before
        # indexing; otherwise "fold i" can differ between runs and between the
        # dataset and checkpoint directories.
        # NOTE(review): the sort is lexicographic (e.g. 'fold_10' < 'fold_2');
        # datasets and checkpoints are paired by the same rule -- verify this
        # matches the naming scheme actually used on disk.
        fold_i_dir_name = sorted(os.listdir(k_fold_dir))[i_fold]
        fold_i_dir = os.path.join(k_fold_dir, fold_i_dir_name)

        # Window-level splits, each tagged with the id of the study it came from.
        X_train_window = torch.tensor(np.load(os.path.join(fold_i_dir, 'X_train_window.npy'))).float()
        y_train_window = torch.tensor(np.load(os.path.join(fold_i_dir, 'y_train_window.npy'))).long()
        study_labels_train_window = torch.tensor([study_label_map[study]] * len(y_train_window)).long()
        X_train_window_GaJuSi = torch.cat((X_train_window_GaJuSi, X_train_window), dim=0)
        y_train_window_GaJuSi = torch.cat((y_train_window_GaJuSi, y_train_window), dim=0)
        study_labels_train_window_GaJuSi = torch.cat((study_labels_train_window_GaJuSi, study_labels_train_window), dim=0)

        X_val_window = torch.tensor(np.load(os.path.join(fold_i_dir, 'X_val_window.npy'))).float()
        y_val_window = torch.tensor(np.load(os.path.join(fold_i_dir, 'y_val_window.npy'))).long()
        study_labels_val_window = torch.tensor([study_label_map[study]] * len(y_val_window)).long()
        X_val_window_GaJuSi = torch.cat((X_val_window_GaJuSi, X_val_window), dim=0)
        y_val_window_GaJuSi = torch.cat((y_val_window_GaJuSi, y_val_window), dim=0)
        study_labels_val_window_GaJuSi = torch.cat((study_labels_val_window_GaJuSi, study_labels_val_window), dim=0)

        X_test_window = torch.tensor(np.load(os.path.join(fold_i_dir, 'X_test_window.npy'))).float()
        y_test_window = torch.tensor(np.load(os.path.join(fold_i_dir, 'y_test_window.npy'))).long()
        study_labels_test_window = torch.tensor([study_label_map[study]] * len(y_test_window)).long()
        X_test_window_GaJuSi = torch.cat((X_test_window_GaJuSi, X_test_window), dim=0)
        y_test_window_GaJuSi = torch.cat((y_test_window_GaJuSi, y_test_window), dim=0)
        study_labels_test_window_GaJuSi = torch.cat((study_labels_test_window_GaJuSi, study_labels_test_window), dim=0)

        # Person-level splits.
        # Presumably each array is (n_person, max_vgrf_data_len, n_feature), as
        # needed to concatenate with the pooled tensors above -- TODO confirm.
        X_val_person = torch.tensor(np.load(os.path.join(fold_i_dir, 'X_val_person.npy'))).float()
        y_val_person = torch.tensor(np.load(os.path.join(fold_i_dir, 'y_val_person.npy'))).long()
        X_val_person_GaJuSi = torch.cat((X_val_person_GaJuSi, X_val_person), dim=0)
        y_val_person_GaJuSi = torch.cat((y_val_person_GaJuSi, y_val_person), dim=0)

        X_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, 'X_test_person.npy'))).float()
        y_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, 'y_test_person.npy'))).long()
        X_test_person_GaJuSi = torch.cat((X_test_person_GaJuSi, X_test_person), dim=0)
        y_test_person_GaJuSi = torch.cat((y_test_person_GaJuSi, y_test_person), dim=0)

        # Only `val_person_dataset` is consumed by the evaluation below; the
        # other datasets/loaders are still built in case later cells use them.
        train_window_dataset = TensorDataset(X_train_window, y_train_window)
        val_window_dataset = TensorDataset(X_val_window, y_val_window)
        test_window_dataset = TensorDataset(X_test_window, y_test_window)

        val_person_dataset = TensorDataset(X_val_person, y_val_person)
        test_person_dataset = TensorDataset(X_test_person, y_test_person)

        train_dataloader = DataLoader(train_window_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_window_dataset, batch_size=batch_size, shuffle=False)
        test_dataloader = DataLoader(test_window_dataset, batch_size=batch_size, shuffle=False)

        expert_model = expert_model_map[study]

        # Load the pretrained expert for this fold (sorted for the same reason
        # as the fold directories above). map_location keeps the load working
        # when the checkpoint was saved on a device this machine does not have.
        expert_model_dir = expert_model_dir_map[study]
        expert_model_i_name = sorted(os.listdir(expert_model_dir))[i_fold]
        expert_model_i_path = os.path.join(expert_model_dir, expert_model_i_name)
        expert_model.load_state_dict(torch.load(expert_model_i_path, map_location=device))

        print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
        (
            _,
            acc_person_majority_voting,
            f1_person_majority_voting,
            precision_person_majority_voting,
            recall_person_majority_voting,
            cm_person_majority_voting,
        ) = eval_person_majority_voting(
            expert_model, val_person_dataset, criterion=None, average='weighted',
            window_size=window_size, debug=False,
        )
        print("acc:", acc_person_majority_voting)
        print("f1:", f1_person_majority_voting)
        print("precision:", precision_person_majority_voting)
        print("recall:", recall_person_majority_voting)
        print("cm:\n", np.array(cm_person_majority_voting))
        print()

        expert_outputs[study] = update_output(expert_outputs[study], {
            'acc': acc_person_majority_voting,
            'f1': f1_person_majority_voting,
            'precision': precision_person_majority_voting,
            'recall': recall_person_majority_voting,
            'cm': cm_person_majority_voting,
        })

    print_h("GATE", 96)

    # The gate is a study classifier: its targets are the study ids, not the
    # diagnosis labels.
    train_window_dataset_GaJuSi = TensorDataset(X_train_window_GaJuSi, study_labels_train_window_GaJuSi)
    val_window_dataset_GaJuSi = TensorDataset(X_val_window_GaJuSi, study_labels_val_window_GaJuSi)
    test_window_dataset_GaJuSi = TensorDataset(X_test_window_GaJuSi, study_labels_test_window_GaJuSi)

    train_dataloader_GaJuSi = DataLoader(train_window_dataset_GaJuSi, batch_size=batch_size, shuffle=True)
    val_dataloader_GaJuSi = DataLoader(val_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)
    test_dataloader_GaJuSi = DataLoader(test_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)

    gate_model = RNNInceptionTime(c_in=n_feature, c_out=len(study_label_map), seq_len=window_size, bidirectional=True).to(device)

    # Load the pretrained gate for this fold.
    gate_model_i_name = sorted(os.listdir(gate_model_dir))[i_fold]
    gate_model_i_path = os.path.join(gate_model_dir, gate_model_i_name)
    gate_model.load_state_dict(torch.load(gate_model_i_path, map_location=device))

    print_h("EVALUATION ON WINDOW DATA", 64)

    # NOTE(review): the gate is scored on the *test* windows while the experts
    # and the mixture model are scored on the *validation* persons -- confirm
    # this asymmetry is intended.
    _, acc_window, f1_window, precision_window, recall_window, cm_window = eval_window(
        gate_model, test_dataloader_GaJuSi, average='weighted',
    )

    print("acc:", acc_window)
    print("f1:", f1_window)
    print("precision:", precision_window)
    print("recall:", recall_window)
    print("cm:\n", np.array(cm_window))
    print()

    gate_output = update_output(gate_output, {
        'acc': acc_window,
        'f1': f1_window,
        'precision': precision_window,
        'recall': recall_window,
        'cm': cm_window,
    })

    print_h("MoE", 96)

    val_person_dataset_GaJuSi = TensorDataset(X_val_person_GaJuSi, y_val_person_GaJuSi)
    test_person_dataset_GaJuSi = TensorDataset(X_test_person_GaJuSi, y_test_person_GaJuSi)

    # Order the experts by study id so that expert k is the one trained on the
    # study the gate labels k, independent of dict insertion order.
    experts_by_study_label = [
        expert_model_map[study]
        for study in sorted(study_label_map, key=study_label_map.get)
    ]
    # The class defined in this notebook is `HardMoE`; the bare name `MoE` used
    # here previously is not defined by the class statement.
    moe_model = HardMoE(experts=experts_by_study_label, gate=gate_model)

    print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
    (
        _,
        acc_person_majority_voting,
        f1_person_majority_voting,
        precision_person_majority_voting,
        recall_person_majority_voting,
        cm_person_majority_voting,
    ) = eval_person_majority_voting(
        moe_model, val_person_dataset_GaJuSi, criterion=None, average='weighted',
        window_size=window_size, debug=False,
    )
    print("acc:", acc_person_majority_voting)
    print("f1:", f1_person_majority_voting)
    print("precision:", precision_person_majority_voting)
    print("recall:", recall_person_majority_voting)
    print("cm:\n", np.array(cm_person_majority_voting))
    print()

    moe_output = update_output(moe_output, {
        'acc': acc_person_majority_voting,
        'f1': f1_person_majority_voting,
        'precision': precision_person_majority_voting,
        'recall': recall_person_majority_voting,
        'cm': cm_person_majority_voting,
    })

    # break # Test for only 1 fold