In [None]:
import os
# Move the working directory up one level before the project imports below.
# Presumably the notebook lives in a subdirectory of the project root, so the
# later imports and the relative "experiment_output/" path resolve from there
# (TODO confirm against the repository layout).
# NOTE(review): this call is relative to the *current* directory, so running
# this cell a second time in the same kernel moves up one more level.
os.chdir('..')

In [None]:
import torch
import numpy as np
import pickle
from gnn_library.util import load
from evaluate import evaluate_model
from instance_generator import sample_instances
from params import TEST_CONFIGS
from util import _box_plots

%load_ext autoreload
%autoreload 2

In [None]:
# Pick the compute device. GPU index 4 is the preferred card; if CUDA is
# available but the host exposes fewer than five GPUs, fall back to GPU 0
# instead of building a device handle that fails on first use.
PREFERRED_GPU_INDEX = 4
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU_INDEX if PREFERRED_GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print("PyTorch has version {}".format(torch.__version__))
print('Using device:', device)

# Display name -> value returned by `load`. The evaluation cell unpacks each
# value as a (model, args) pair.
models = {
    'GENConv': load('GNN_large_10_6', device),
    'DeeperGCN': load('other_GNN_deeperGCN', device),
    'GATv2Conv': load('other_GNN_Gatv2Conv', device),
    'GraphConv': load('other_GNN_GraphConv', device),
    'GCNConv': load('other_GNN_GCNConv', device),
}

# Reference args used when sampling instances. Reuse the checkpoint that was
# already loaded above rather than reading 'GNN_large_10_6' a second time.
_, args = models['GENConv']

## Box plot baseline evaluation

#### Evaluate CRs on all graph configurations

In [None]:
# Evaluate every loaded GNN on the same sampled instances, for each
# (node configuration, graph configuration) pair, and checkpoint the
# collected ratios to disk after each graph configuration.

# Experiment settings.
SEED = None  # set to an int for a reproducible run; None keeps fresh OS entropy
rng = np.random.default_rng(SEED)
num_trials = 150
batch_size = 500
num_realizations = 20
node_configs = [(20, 10)]
graph_configs = TEST_CONFIGS['MAIN']
baselines_kwargs = {}
output_path = "experiment_output/box_plots_multiple_gnns.pickle"

# data[node_config][str(graph_config)][model_name] -> cr_ratios['learned']
data = {
    node_config: {str(graph_config): {} for graph_config in graph_configs}
    for node_config in node_configs
}

for node_config in node_configs:
    for graph_config in graph_configs:
        print(node_config, graph_config)
        # Instances are sampled once per configuration and shared by all
        # models, using the reference `args` defined with the model loading.
        instances = sample_instances(
            *node_config,
            num_trials,
            rng,
            args.__dict__,
            **graph_config
        )
        # The per-model args are bound to a throwaway name on purpose:
        # unpacking them into `args` would overwrite the reference args and
        # change how every later configuration is sampled.
        for model_name, (GNN, _model_args) in models.items():
            cr_ratios, _ = evaluate_model(
                meta_model=None,
                meta_model_type=None,
                base_models=[GNN],
                instances=instances,
                batch_size=batch_size,
                rng=rng,
                num_realizations=num_realizations,
                baselines=[],
                **baselines_kwargs
            )

            data[node_config][str(graph_config)][model_name] = cr_ratios['learned']

        # Rewrite the full results file after each graph configuration so a
        # partial run still leaves usable output on disk.
        with open(output_path, 'wb') as handle:
            pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)


#### Generate box plots

In [None]:
# Plot the per-graph-type results for one node configuration.
# The configuration is bound explicitly so the label and the plotted data
# always refer to the same entry, instead of relying on the `node_config`
# loop variable left over from the evaluation cell.
plot_node_config = node_configs[0]

# Shallow copy: graph-type string -> {model name -> ratios}.
results = dict(data[plot_node_config])

# The callable builds a per-graph-type label for `_box_plots`
# (presumably a title or file name — confirm in util._box_plots).
_box_plots(
    results,
    lambda graph_type: f"GNN2_classify_{graph_type} {plot_node_config[1]}x{plot_node_config[0]}"
)