Refactored from: https://github.com/alxxtexxr/Parkinsons_VGRF_Spatiotemporal_Diagnosis/blob/master/eval_RNNInceptionTimeMoe_kfold10.ipynb

# Libraries

In [1]:
import os
import subprocess

def sparse_clone_repo(repo_url, sparse_dirs, root_dir='.'):
    """
    Clones or updates a Git repository with sparse checkout for specified directories.
    Parameters:
    - repo_url (str): The URL of the Git repository.
    - sparse_dirs (list of str): List of directories to retrieve using sparse checkout.
    - root_dir (str): The root directory where the repository will be stored (default is '/content').
    """
    repo_name = repo_url.rstrip('.git').split('/')[-1] # Extract repo name from URL
    repo_dir = os.path.join(root_dir, repo_name)

    if os.path.isdir(repo_dir):
        print(f"Repo '{repo_name}' already exists. Updating to the latest version...")
        os.chdir(repo_dir)
        subprocess.run(['git', 'reset', '--hard', 'origin/master'], check=True) # Reset to match remote
        subprocess.run(['git', 'pull'], check=True) # Pull latest changes
    else:
        print(f"Repo '{repo_name}' not found. Cloning a fresh copy...")
        os.chdir(root_dir)
        subprocess.run(['git', 'clone', '--no-checkout', repo_url], check=True)
        os.chdir(repo_dir)
        print(f"Initializing sparse checkout for directories: {', '.join(sparse_dirs)}...")
        subprocess.run(['git', 'sparse-checkout', 'init', '--cone'], check=True)
        subprocess.run(['git', 'sparse-checkout', 'set', *sparse_dirs], check=True)
        subprocess.run(['git', 'checkout'], check=True)
    
    print("Now working at repo directory:", os.getcwd())

sparse_clone_repo(
    repo_url='https://github.com/alxxtexxr/Parkinsons_VGRF_Spatiotemporal_Diagnosis_v2.git',
    sparse_dirs=['src', 'data', 'checkpoints'],
    root_dir='/content'
)

Repo 'Parkinsons_VGRF_Spatiotemporal_Diagnosis_v2' already exists. Updating to the latest version...
Now working at repo directory: /content/Parkinsons_VGRF_Spatiotemporal_Diagnosis_v2


In [2]:
%%capture
%pip install tsai # Required library for the models

In [4]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

from src.utils import (
    print_h, eval_window, eval_person_majority_voting,
    set_seed,
)
from src.models import RNNInceptionTime, HardMoE

# Config

In [5]:
# Project config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = 69
set_seed(seed)

# Model config
k_fold_dir_map = {
    'Ga': 'data/preprocessed_LOOCV/Ga_k_fold_10_window_500_stride_500_n_feat_16',
    'Ju': 'data/preprocessed_LOOCV/Ju_k_fold_10_window_500_stride_500_n_feat_16',
    'Si': 'data/preprocessed_LOOCV/Si_k_fold_10_window_500_stride_250_n_feat_16',
}
expert_model_dir_map = {
    'Ga': 'checkpoints/RNNInceptionTime_Ga_k_fold_10',
    'Ju': 'checkpoints/RNNInceptionTime_Ju_k_fold_10',
    'Si': 'checkpoints/RNNInceptionTime_Si_k_fold_10'
}
gate_model_dir = 'checkpoints/RNNInceptionTime_Gate_k_fold_10'

# Training config
k_fold = 10
batch_size = 8
n_feat = 16
n_class = 4
window_size = 500
max_vgrf_data_len = 25_000

Random seed set to: 69


# Training

In [6]:
def init_metrics():
    metric_names = ['acc', 'f1', 'precision', 'recall']
    metrics = {metric_name: {'folds': [], 'avg': None, 'std': None} for metric_name in metric_names}
    metrics.update({'cm': {'folds': []}, 'val_loss': {'folds': []}})
    return metrics

def update_metrics(metrics, in_metrics):
    for metric_name in ['acc', 'f1', 'precision', 'recall', 'cm']:
        metrics[metric_name]['folds'] += [in_metrics[metric_name]]
        if metric_name != 'cm':
            metrics[metric_name]['avg'] = np.mean(metrics[metric_name]['folds'])
            metrics[metric_name]['std'] = np.std(metrics[metric_name]['folds'])
    return metrics

moe_metrics = init_metrics()
gate_metrics = init_metrics()
expert_outputs = {
    'Ga': init_metrics(),
    'Ju': init_metrics(),
    'Si': init_metrics(),
}

In [11]:
study_label_map = {
    'Ga': 0,
    'Ju': 1,
    'Si': 2,
}

for i_fold in range(k_fold):
    print_h(f"FOLD-{i_fold+1}", 128)
    
    expert_model_map = {
        'Ga': RNNInceptionTime(c_in=n_feat, c_out=n_class, seq_len=window_size, bidirectional=True).to(device),
        'Ju': RNNInceptionTime(c_in=n_feat, c_out=n_class, seq_len=window_size, bidirectional=True).to(device),
        'Si': RNNInceptionTime(c_in=n_feat, c_out=n_class, seq_len=window_size, bidirectional=True).to(device),
    }

    X_train_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
    y_train_window_GaJuSi = torch.empty(0).long()
    study_labels_train_window_GaJuSi = torch.empty(0).long()
    
    X_val_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
    y_val_window_GaJuSi = torch.empty(0).long()
    study_labels_val_window_GaJuSi = torch.empty(0).long()

    X_test_window_GaJuSi = torch.empty(0, window_size, n_feat).float()
    y_test_window_GaJuSi = torch.empty(0).long()
    study_labels_test_window_GaJuSi = torch.empty(0).long()

    X_val_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feat).float()
    y_val_person_GaJuSi = torch.empty(0).long()
    # study_labels_val_person_GaJuSi = torch.empty(0).long()

    X_test_person_GaJuSi = torch.empty(0, max_vgrf_data_len, n_feat).float()
    y_test_person_GaJuSi = torch.empty(0).long()
    # study_labels_test_person_GaJuSi = torch.empty(0).long()

    for study, k_fold_dir in k_fold_dir_map.items():
        fold_i_dir_name = os.listdir(k_fold_dir)[i_fold]
        fold_i_dir = os.path.join(k_fold_dir, fold_i_dir_name)

        print(os.path.join(fold_i_dir, f'X_train_window.npy'))
        X_train_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_train_window.npy'))).float()
        y_train_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_train_window.npy'))).long()
        study_labels_train_window = torch.tensor([study_label_map[study]] * len(y_train_window)).long()
        X_train_window_GaJuSi = torch.cat((X_train_window_GaJuSi, X_train_window), dim=0)
        y_train_window_GaJuSi = torch.cat((y_train_window_GaJuSi, y_train_window), dim=0)
        study_labels_train_window_GaJuSi = torch.cat((study_labels_train_window_GaJuSi, study_labels_train_window), dim=0)

        X_val_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_val_window.npy'))).float()
        y_val_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_val_window.npy'))).long()
        study_labels_val_window = torch.tensor([study_label_map[study]] * len(y_val_window)).long()
        X_val_window_GaJuSi = torch.cat((X_val_window_GaJuSi, X_val_window), dim=0)
        y_val_window_GaJuSi = torch.cat((y_val_window_GaJuSi, y_val_window), dim=0)
        study_labels_val_window_GaJuSi = torch.cat((study_labels_val_window_GaJuSi, study_labels_val_window), dim=0)

        X_test_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_test_window.npy'))).flo2t()
        y_test_window = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_test_window.npy'))).long()
        study_labels_test_window = torch.tensor([study_label_map[study]] * len(y_test_window)).long()
        X_test_window_GaJuSi = torch.cat((X_test_window_GaJuSi, X_test_window), dim=0)
        y_test_window_GaJuSi = torch.cat((y_test_window_GaJuSi, y_test_window), dim=0)
        study_labels_test_window_GaJuSi = torch.cat((study_labels_test_window_GaJuSi, study_labels_test_window), dim=0)

        X_val_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_val_person.npy'))).float()
        y_val_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_val_person.npy'))).long()
        X_val_person_GaJuSi = torch.cat((X_val_person_GaJuSi, X_val_person), dim=0)
        y_val_person_GaJuSi = torch.cat((y_val_person_GaJuSi, y_val_person), dim=0)

        X_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'X_test_person.npy'))).float()
        y_test_person = torch.tensor(np.load(os.path.join(fold_i_dir, f'y_test_person.npy'))).long()
        X_test_person_GaJuSi = torch.cat((X_test_person_GaJuSi, X_test_person), dim=0)
        y_test_person_GaJuSi = torch.cat((y_test_person_GaJuSi, y_test_person), dim=0)

        train_window_dataset = TensorDataset(X_train_window, y_train_window)
        val_window_dataset = TensorDataset(X_val_window, y_val_window)
        test_window_dataset = TensorDataset(X_test_window, y_test_window)
        
        val_person_dataset = TensorDataset(X_val_person, y_val_person)
        test_person_dataset = TensorDataset(X_test_person, y_test_person)

        train_dataloader = DataLoader(train_window_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_window_dataset, batch_size=batch_size, shuffle=False)
        test_dataloader = DataLoader(test_window_dataset, batch_size=batch_size, shuffle=False)
        
        expert_model = expert_model_map[study]

        # Load pretrained expert model
        expert_model_dir = expert_model_dir_map[study]
        expert_model_i_name = os.listdir(expert_model_dir)[i_fold]
        expert_model_i_path = os.path.join(expert_model_dir, expert_model_i_name)
        expert_model.load_state_dict(torch.load(expert_model_i_path))

        print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
        _, acc_person_majority_voting, f1_person_majority_voting, precision_person_majority_voting, recall_person_majority_voting, cm_person_majority_voting = eval_person_majority_voting(expert_model, val_person_dataset, criterion=None, average='weighted',
                                                                                                                                                                                                window_size=window_size, debug=False)
        print("acc:", acc_person_majority_voting)
        print("f1:", f1_person_majority_voting)
        print("precision:", precision_person_majority_voting)
        print("recall:", recall_person_majority_voting)
        print("cm:\n", np.array(cm_person_majority_voting))
        print()

        expert_outputs[study] = update_metrics(expert_outputs[study], {
            'acc': acc_person_majority_voting,
            'f1': f1_person_majority_voting,
            'precision': precision_person_majority_voting,
            'recall': recall_person_majority_voting,
            'cm': cm_person_majority_voting,
        })

    print_h("GATE", 96)

    # train_window_dataset_GaJuSi = TensorDataset(X_train_window_GaJuSi, y_train_window_GaJuSi)
    # val_window_dataset_GaJuSi = TensorDataset(X_val_window_GaJuSi, y_val_window_GaJuSi)
    # test_window_dataset_GaJuSi = TensorDataset(X_test_window_GaJuSi, y_test_window_GaJuSi)

    train_window_dataset_GaJuSi = TensorDataset(X_train_window_GaJuSi, study_labels_train_window_GaJuSi)
    val_window_dataset_GaJuSi = TensorDataset(X_val_window_GaJuSi, study_labels_val_window_GaJuSi)
    test_window_dataset_GaJuSi = TensorDataset(X_test_window_GaJuSi, study_labels_test_window_GaJuSi)

    train_dataloader_GaJuSi = DataLoader(train_window_dataset_GaJuSi, batch_size=batch_size, shuffle=True)
    val_dataloader_GaJuSi = DataLoader(val_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)
    test_dataloader_GaJuSi = DataLoader(test_window_dataset_GaJuSi, batch_size=batch_size, shuffle=False)

    gate_model = RNNInceptionTime(c_in=n_feat, c_out=len(study_label_map.keys()), seq_len=window_size, bidirectional=True).to(device)

    # Load pretrained gate model
    gate_model_i_name = os.listdir(gate_model_dir)[i_fold]
    gate_model_i_path = os.path.join(gate_model_dir, gate_model_i_name)
    gate_model.load_state_dict(torch.load(gate_model_i_path))

    print_h("EVALUATION ON WINDOW DATA", 64)
    
    _, acc_window, f1_window, precision_window, recall_window, cm_window = eval_window(gate_model, test_dataloader_GaJuSi, average='weighted')

    print("acc:", acc_window)
    print("f1:", f1_window)
    print("precision:", precision_window)
    print("recall:", recall_window)
    print("cm:\n", np.array(cm_window))
    print()

    gate_metrics = update_metrics(gate_metrics, {
        'acc': acc_window,
        'f1': f1_window,
        'precision': precision_window,
        'recall': recall_window,
        'cm': cm_window,
    })

    print_h("MoE", 96)

    val_person_dataset_GaJuSi = TensorDataset(X_val_person_GaJuSi, y_val_person_GaJuSi)
    test_person_dataset_GaJuSi = TensorDataset(X_test_person_GaJuSi, y_test_person_GaJuSi)

    moe_model = HardMoE(experts=expert_model_map.values(), gate=gate_model)

    print_h("EVALUATION ON PERSON DATA BY MAJORITY VOTING", 64)
    _, acc_person_majority_voting, f1_person_majority_voting, precision_person_majority_voting, recall_person_majority_voting, cm_person_majority_voting = eval_person_majority_voting(moe_model, val_person_dataset_GaJuSi, criterion=None, average='weighted',
                                                                                                                                                                                        window_size=window_size, debug=False)
    print("acc:", acc_person_majority_voting)
    print("f1:", f1_person_majority_voting)
    print("precision:", precision_person_majority_voting)
    print("recall:", recall_person_majority_voting)
    print("cm:\n", np.array(cm_person_majority_voting))
    print()

    moe_metrics = update_metrics(moe_metrics, {
        'acc': acc_person_majority_voting,
        'f1': f1_person_majority_voting,
        'precision': precision_person_majority_voting,
        'recall': recall_person_majority_voting,
        'cm': cm_person_majority_voting,
    })

    # for metric in ['acc', 'f1', 'precision', 'recall', 'cm']:
    #     moe_metrics[metric]['folds'] += [moe_eval_output[metric]]
    #     if metric != 'cm':
    #         moe_metrics[metric]['avg'] = np.mean(moe_metrics[metric]['folds'])
    #         moe_metrics[metric]['std'] = np.std(moe_metrics[metric]['folds'])
    
    # output['window']['train_loss']['folds'].append(global_train_loss_list)
    # output['window']['val_loss']['folds'].append(global_val_loss_window_list)

    # break # Test for only 1 fold

                                                             FOLD-1                                                             
data/preprocessed_LOOCV/Ga_k_fold_10_window_500_stride_500_n_feat_16/fold_3/X_train_window.npy


EOFError: No data left in file