In [1]:
import torch
from utils import prepare_data
from train import get_dataloaders
import pandas as pd
import numpy as np
from params import PARAMS
from sklearn.model_selection import train_test_split
from models import GCNAuto
from main import init_model_params

from train import train_model_2
from main import model_predict, print_classification_report

In [2]:
import pickle
from utils import plot_adj

def kernel_picker(kernel_name, random_seed, results_path, device):
    """Build, initialise and place a ``GCNAuto`` model for one kernel variant.

    Args:
        kernel_name: Kernel type forwarded to ``GCNAuto(kernel_type=...)``.
            If the name contains ``'ab'``, ``model.init_node_embeddings()``
            is also called.
        random_seed: Seed forwarded to ``init_model_params``.
        results_path: Directory where, for ``'ab'`` kernels, the untrained
            adjacency is saved as ``untrained_adj.pickle`` and plotted to
            ``untrained_adj.png``. Pass ``None`` to skip saving. The directory
            is assumed to already exist (it is not created here).
        device: Device passed to ``GCNAuto`` and that the model is moved to.

    Returns:
        The initialised model after ``.to(device)``.
    """
    model = GCNAuto(
        kernel_type=kernel_name,
        in_features=PARAMS['SEQ_LEN'],
        n_nodes=PARAMS['N_CHANNELS'],
        num_classes=PARAMS['N_CLASSES'],
        hidden_sizes=PARAMS['GCNAUTO_HIDDEN_SIZES'],
        dropout_p=PARAMS['GCNAUTO_DROPOUT_P'],
        # Use the caller's device (not PARAMS['DEVICE']) so the model's own
        # device setting agrees with the `.to(device)` call below.
        device=device,
    )
    model = init_model_params(model, random_seed=random_seed)

    if 'ab' in kernel_name:
        model.init_node_embeddings()
        if results_path is not None:
            # Materialise the adjacency once and reuse it for both outputs.
            untrained_adj = model.adj.detach().cpu().numpy()
            # Context manager closes the file instead of leaking the handle.
            with open(f'{results_path}/untrained_adj.pickle', 'wb') as adj_file:
                pickle.dump(untrained_adj, adj_file)
            plot_adj(untrained_adj, f'{results_path}/untrained_adj.png')

    model = model.to(device)
    return model

In [3]:
from datetime import datetime
from tqdm import tqdm
import os
from main import prepare_datasets, show_metrics, run_model


# Sweep configuration: every kernel is trained on every dataset split with every seed.
kernel_names = ['a', 'b', 'c', 'd']
dataset_names = [f'cross_subject_data_{i}_5_subjects' for i in range(5)]
random_seeds = [0]

# All artefacts of this run go under output/<timestamp>/.
time_now = datetime.now().strftime('%Y-%m-%d-%H-%M')
run_dir = os.path.join('output', time_now)
os.makedirs(run_dir, exist_ok=True)

# Record the configuration once, before any training starts, instead of
# rewriting the same file on every (kernel, dataset, seed) iteration.
# NOTE(review): assumes PARAMS is not mutated by the calls inside the sweep.
with open(os.path.join(run_dir, 'params.txt'), 'w') as params_file:
    params_file.write(str(PARAMS))

for kernel_name in tqdm(kernel_names, desc='kernels'):
    model_name = f'gcn-{kernel_name}'
    for dataset_name in dataset_names:
        for random_seed in random_seeds:
            results_path = os.path.join(run_dir, model_name, dataset_name, str(random_seed))
            os.makedirs(results_path, exist_ok=True)

            dataloaders = prepare_datasets(random_seed, dataset_name, results_path)
            model = kernel_picker(kernel_name, random_seed, results_path, device=PARAMS['DEVICE'])
            run_model(random_seed, dataloaders, model, results_path)

# show_metrics receives the per-model directory names ('gcn-<kernel>') used above.
kernel_names = [f'gcn-{name}' for name in kernel_names]
final_results = show_metrics(time_now, kernel_names, dataset_names, random_seeds)

100%|██████████| 4/4 [45:01<00:00, 675.50s/it]

            accuracy  precision_macro  precision_weighted  recall_macro  \
model_name                                                                
gcn-a       0.375625         0.377317            0.377781      0.374891   
gcn-b       0.288750         0.289932            0.291385      0.291742   
gcn-c       0.285625         0.292845            0.296160      0.285452   
gcn-d       0.240000         0.163706            0.160699      0.251899   

            recall_weighted     AUROC  n_params  accuracy_std  \
model_name                                                      
gcn-a              0.375625  0.583108  358404.0      0.071412   
gcn-b              0.288750  0.527639  358404.0      0.048038   
gcn-c              0.285625  0.523931  358404.0      0.033263   
gcn-d              0.240000  0.501137  358404.0      0.020658   

            precision_macro_std  precision_weighted_std  recall_macro_std  \
model_name                                                                  
gcn-




            accuracy  precision_macro  precision_weighted  recall_macro  \
model_name                                                                
gcn-a       0.375625         0.377317            0.377781      0.374891   
gcn-b       0.288750         0.289932            0.291385      0.291742   
gcn-c       0.285625         0.292845            0.296160      0.285452   
gcn-d       0.240000         0.163706            0.160699      0.251899   

            recall_weighted     AUROC  n_params  accuracy_std  \
model_name                                                      
gcn-a              0.375625  0.583108  358404.0      0.071412   
gcn-b              0.288750  0.527639  358404.0      0.048038   
gcn-c              0.285625  0.523931  358404.0      0.033263   
gcn-d              0.240000  0.501137  358404.0      0.020658   

            precision_macro_std  precision_weighted_std  recall_macro_std  \
model_name                                                                  
gcn-a                  0.068864                0.070144          0.070409   
gcn-b                  0.052426                0.055291          0.047405   
gcn-c                  0.020333                0.022615          0.029750   
gcn-d                  0.099775                0.101857          0.018610   

            recall_weighted_std  
model_name                       
gcn-a                  0.071412  
gcn-b                  0.048038  
gcn-c                  0.033263  
gcn-d                  0.020658  