In [1]:
import pickle
from utils import plot_adj
from params import PARAMS
from models import GCNAuto
from main import init_model_params
from datetime import datetime
from tqdm import tqdm
import os
from main import  prepare_datasets, show_metrics, run_model

In [2]:
def kernel_picker(kernel_name, random_seed, results_path, device):
    """Build and seed a ``GCNAuto`` model for one kernel variant, snapshotting its adjacency.

    The adjacency matrix of the freshly initialised (untrained) model is saved
    to ``results_path`` both as ``untrained_adj.pickle`` and as a plot
    ``untrained_adj.png``.

    Args:
        kernel_name: Kernel identifier passed to ``GCNAuto`` as ``kernel_type``
            (the experiment loops in this notebook use ``'a'`` … ``'e'``).
        random_seed: Seed forwarded to ``init_model_params``.
        results_path: Existing directory that receives the adjacency snapshots.
        device: Device the model is constructed for and moved to.

    Returns:
        The initialised ``GCNAuto`` model, placed on ``device``.
    """
    model = GCNAuto(
        kernel_type=kernel_name,
        in_features=PARAMS['SEQ_LEN'],
        n_nodes=PARAMS['N_CHANNELS'],
        num_classes=PARAMS['N_CLASSES'],
        hidden_sizes=PARAMS['GCNAUTO_HIDDEN_SIZES'],
        dropout_p=PARAMS['GCNAUTO_DROPOUT_P'],
        # Use the caller's device consistently: previously the constructor
        # always got PARAMS['DEVICE'] while the return moved the model to
        # `device`, so the two could silently disagree.
        device=device,
    )

    model = init_model_params(model, random_seed=random_seed)
    # Exact membership test; `kernel_name in 'bc'` was a substring check that
    # also matched '' and 'bc'.
    if kernel_name in ('b', 'c'):
        model.init_adj_diag()

    untrained_adj = model.adj.cpu().detach().numpy()
    # Context manager so the pickle file handle is flushed and closed promptly.
    with open(os.path.join(results_path, 'untrained_adj.pickle'), 'wb') as f:
        pickle.dump(untrained_adj, f)
    plot_adj(untrained_adj, os.path.join(results_path, 'untrained_adj.png'))

    return model.to(device)

In [3]:
def main(kernel_names=None, dataset_names=None, random_seeds=None):
    """Train and evaluate every GCN kernel variant on the cross-subject datasets.

    For each (kernel, dataset, seed) combination this prepares the data, builds
    the model via ``kernel_picker`` and trains/evaluates it with ``run_model``.
    Results go under ``output/<timestamp>/gcn-<kernel>/<dataset>/<seed>``.
    ``PARAMS`` is dumped once to ``output/<timestamp>/params.txt``.

    NOTE: a later cell in this notebook redefines ``main`` (for ``'mnist_data'``),
    so after Run All this definition is shadowed.

    Args:
        kernel_names: Kernel identifiers to run; defaults to ``['a', 'b', 'c', 'd', 'e']``.
        dataset_names: Dataset names to run. Names for folds 0-4
            (``cross_subject_data_<i>_new``) are generated, but by default only
            the first is run; pass an explicit list to run more.
        random_seeds: Seeds to run; defaults to ``[0]``.

    Returns:
        Whatever ``show_metrics`` returns for this run (previously discarded).
    """
    kernel_names = list(kernel_names) if kernel_names is not None else ['a', 'b', 'c', 'd', 'e']
    if dataset_names is None:
        dataset_names = [f'cross_subject_data_{i}_new' for i in range(5)][:1]
    else:
        dataset_names = list(dataset_names)
    random_seeds = list(random_seeds) if random_seeds is not None else [0]

    time_now = datetime.now().strftime('%Y-%m-%d-%H-%M')
    run_dir = os.path.join('output', time_now)
    os.makedirs(run_dir, exist_ok=True)
    # The parameter dump is identical for every run, so write it once rather
    # than on every inner-loop iteration.
    with open(os.path.join(run_dir, 'params.txt'), 'w') as f:
        f.write(str(PARAMS))

    for kernel_name in tqdm(kernel_names):
        for dataset_name in dataset_names:
            for random_seed in random_seeds:
                results_path = os.path.join(run_dir, f'gcn-{kernel_name}', dataset_name, str(random_seed))
                os.makedirs(results_path, exist_ok=True)

                dataloaders = prepare_datasets(random_seed, dataset_name, results_path)

                model = kernel_picker(kernel_name, random_seed, results_path, device=PARAMS['DEVICE'])
                run_model(random_seed, dataloaders, model, results_path)

    # Model names must match the 'gcn-<kernel>' directory names used above.
    model_names = [f'gcn-{name}' for name in kernel_names]
    return show_metrics(time_now, model_names, dataset_names, random_seeds)

In [4]:
def main(kernel_names=None, dataset_names=None, random_seeds=None):
    """Train and evaluate every GCN kernel variant on the ``'mnist_data'`` dataset.

    For each (kernel, dataset, seed) combination this prepares the data, builds
    the model via ``kernel_picker`` and trains/evaluates it with ``run_model``.
    Results go under ``output/<timestamp>/gcn-<kernel>/<dataset>/<seed>``.
    ``PARAMS`` is dumped once to ``output/<timestamp>/params.txt``.

    NOTE: this cell redefines (and shadows) the cross-subject ``main`` from the
    previous cell.

    Args:
        kernel_names: Kernel identifiers to run; defaults to ``['a', 'b', 'c', 'd', 'e']``.
        dataset_names: Dataset names to run; defaults to ``['mnist_data']``.
        random_seeds: Seeds to run; defaults to ``[0]``.

    Returns:
        Whatever ``show_metrics`` returns for this run (previously discarded).
    """
    kernel_names = list(kernel_names) if kernel_names is not None else ['a', 'b', 'c', 'd', 'e']
    dataset_names = list(dataset_names) if dataset_names is not None else ['mnist_data']
    random_seeds = list(random_seeds) if random_seeds is not None else [0]

    time_now = datetime.now().strftime('%Y-%m-%d-%H-%M')
    run_dir = os.path.join('output', time_now)
    os.makedirs(run_dir, exist_ok=True)
    # The parameter dump is identical for every run, so write it once rather
    # than on every inner-loop iteration.
    with open(os.path.join(run_dir, 'params.txt'), 'w') as f:
        f.write(str(PARAMS))

    for kernel_name in tqdm(kernel_names):
        for dataset_name in dataset_names:
            for random_seed in random_seeds:
                results_path = os.path.join(run_dir, f'gcn-{kernel_name}', dataset_name, str(random_seed))
                os.makedirs(results_path, exist_ok=True)

                dataloaders = prepare_datasets(random_seed, dataset_name, results_path)

                model = kernel_picker(kernel_name, random_seed, results_path, device=PARAMS['DEVICE'])
                run_model(random_seed, dataloaders, model, results_path)

    # Model names must match the 'gcn-<kernel>' directory names used above.
    model_names = [f'gcn-{name}' for name in kernel_names]
    return show_metrics(time_now, model_names, dataset_names, random_seeds)

In [5]:
# Run the experiment grid of the most recently defined `main` and keep its
# return value (if any) so later cells can inspect it without re-training.
final_results = main()

100%|██████████| 5/5 [01:08<00:00, 13.70s/it]

            accuracy  precision_macro  precision_weighted  recall_macro  \
model_name                                                                
gcn-a       0.902444         0.903049            0.904020      0.902056   
gcn-b       0.878005         0.877341            0.878903      0.877425   
gcn-c       0.766627         0.765180            0.767747      0.764876   
gcn-d       0.808293         0.808524            0.810822      0.807580   
gcn-e       0.884215         0.892124            0.893010      0.883504   

            recall_weighted     AUROC   n_params  accuracy_std  \
model_name                                                       
gcn-a              0.902444  0.945613  1209114.0           NaN   
gcn-b              0.878005  0.931944  1209114.0           NaN   
gcn-c              0.766627  0.869487  1209114.0           NaN   
gcn-d              0.808293  0.893153  1209114.0           NaN   
gcn-e              0.884215  0.935324  1209114.0           NaN   

           


