In [None]:
import torch
import numpy as np

import os
# Switch the working directory to the parent folder before the project-local
# imports below (presumably so that they resolve from there — confirm).
# NOTE(review): the path is relative, so every re-run of this cell moves up one
# more directory; run it only once per kernel session.
os.chdir('..')
from torch_geometric.loader import DataLoader
from gnn_library.util import train, save, load
from evaluate import evaluate_model, pp_output
import instance_generator as ig
import torch_converter as tc
import evaluate as ev
import osmnx as ox
from util import Dataset
from gnn_library.OBM_threshold_greedy import OBM_Threshold_Greedy

%load_ext autoreload
%autoreload 2

In [None]:
# GPU index this notebook prefers when the host exposes enough devices.
PREFERRED_GPU_INDEX = 2

if torch.cuda.is_available():
    # 'cuda:2' only exists when at least three GPUs are visible; on smaller
    # hosts fall back to the first GPU instead of failing later with an
    # invalid device ordinal.
    gpu_index = PREFERRED_GPU_INDEX if torch.cuda.device_count() > PREFERRED_GPU_INDEX else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

print("PyTorch has version {}".format(torch.__version__))
print('Using device:', device)

In [None]:
# Hyperparameters (optimized for 10,6).
# This dict is handed to instance sampling, to train(...) and to save(...) in the
# cells below; 'device' is the torch.device chosen in the previous cell.
args = dict(
    processor='GENConv',
    head='regression',
    num_layers=3,
    num_mlp_layers=3,
    aggr='max',
    batch_size=8,
    node_feature_dim=5,
    edge_feature_dim=1,
    graph_feature_dim=2,
    hidden_dim=32,
    output_dim=1,
    dropout=0.0306,
    epochs=20,
    opt='adagrad',
    opt_scheduler='none',
    opt_restart=0,
    weight_decay=5e-3,
    lr=0.0121,
    device=device,
    noise=0.5,  # Set to 0 to not train on noisy features
)

In [None]:
# Number of instances sampled per graph family for the train and test splits.
train_num, test_num = 200, 100
# Unpacked as the leading positional arguments of ig.sample_instances below.
node_config = (30, 10)

# Keyword settings forwarded to ig.sample_instances, one dict per graph family.
er_config = dict(graph_type='ER', p=0.75, weighted=True)
ba_config = dict(graph_type='BA', ba_param=2, weighted=True)
geom_config = dict(graph_type='GEOM', q=0.15, d=2, weighted=True)

# Fixed seed so the sampled train/test instances are the same on every run.
rng = np.random.default_rng(10)

# Sample train instances first, then test instances, visiting the graph
# families in the order ER, BA, GEOM for each split. The shared generator is
# consumed in exactly that order.
train_instances = [
    instance
    for family_config in (er_config, ba_config, geom_config)
    for instance in ig.sample_instances(*node_config, train_num, rng, args, **family_config)
]

test_instances = [
    instance
    for family_config in (er_config, ba_config, geom_config)
    for instance in ig.sample_instances(*node_config, test_num, rng, args, **family_config)
]

# Convert the sampled instances into samples for the configured head and wrap
# them in the project's Dataset class.
train_data = Dataset(tc._instances_to_train_samples(train_instances, args['head']))
test_data = Dataset(tc._instances_to_train_samples(test_instances, args['head']))

# Training batches are reshuffled on every pass over the data.
train_loader = DataLoader(
    train_data,
    batch_size=args['batch_size'],
    shuffle=True,
    num_workers=4
)

# The held-out loader is iterated in a fixed order: shuffling adds nothing for
# evaluation and makes per-batch results differ from run to run.
# NOTE(review): presumes train(...) only evaluates on this loader — confirm.
test_loader = DataLoader(
    test_data,
    batch_size=args['batch_size'],
    shuffle=False,
    num_workers=4
)

In [None]:
# train (imported from gnn_library.util) returns five values; only the fourth is
# kept and bound to GNN (presumably the trained model — confirm in gnn_library.util).
_, _, _, GNN, _ = train(train_loader, test_loader, args)

In [None]:
# Store GNN together with its args under the name 'GNN2' using gnn_library.util.save.
save(GNN, args, 'GNN2')

In [None]:
# Load 'GNN2' for the selected device. This rebinds both GNN and args, so the
# cells below use the loaded values rather than the ones defined earlier.
GNN, args = load('GNN2', device)

## Box plot baseline evaluation

In [None]:
import gnn_library.util as util

# Settings for the baseline comparison.
num_trials = 40
threshold = 0.35
batch_size = 500

models = [("GNN", GNN)]
t_greedy = OBM_Threshold_Greedy(threshold)

# NOTE(review): this generator is unseeded, so the evaluation instances differ
# between runs.
rng = np.random.default_rng()

node_configs = util.node_configs_gnn1
graph_configs = util.graph_configs_standard

# Results are stored as data[node_config][graph_type] -> ratios from evaluate_model.
data = {node_config: {} for node_config in node_configs}

for node_config in node_configs:
    results_for_nodes = data[node_config]
    for graph_config in graph_configs:
        print(graph_config)
        instances = ig.sample_instances(*node_config, num_trials, rng, args, **graph_config)
        graph_type = graph_config['graph_type']

        # Keyword arguments common to both evaluate_model calls below.
        shared_eval_kwargs = dict(
            meta_model=None,
            meta_model_type=None,
            instances=instances,
            batch_size=batch_size,
            rng=rng,
            num_realizations=5,
        )

        for model_name, model in models:
            cr_ratios, _ = evaluate_model(
                base_models=[model],
                baselines=['greedy', 'lp_rounding'],
                **shared_eval_kwargs,
            )
            results_for_nodes[graph_type] = cr_ratios

            # Threshold greedy is evaluated as if it were a model; its 'learned'
            # ratios are stored next to the other entries under 'threshold_greedy'.
            t_greedy_ratio, _ = evaluate_model(
                base_models=[t_greedy],
                **shared_eval_kwargs,
            )
            cr_ratios['threshold_greedy'] = t_greedy_ratio['learned']


In [None]:
# Keep, for the first node configuration in util.node_configs_gnn, only the
# results whose graph configuration is listed in util.graph_configs_main.
#
# The evaluation cell stores results under graph_config['graph_type']. When that
# key is a string (e.g. 'ER'), dict(key) raises ValueError, so string keys are
# matched by graph type instead. Any other key form keeps the original
# "dict(key) in graph_configs_main" comparison.
# NOTE(review): assumes the entries of util.graph_configs_main are dicts with a
# 'graph_type' key, like the configs used for sampling — confirm.
filtered_data = {}
for graph_type, val in data[util.node_configs_gnn[0]].items():
    if isinstance(graph_type, str):
        is_main_graph = any(
            config.get('graph_type') == graph_type
            for config in util.graph_configs_main
        )
    else:
        is_main_graph = dict(graph_type) in util.graph_configs_main
    if is_main_graph:
        filtered_data[graph_type] = val

In [None]:
from util import _box_plots

# filtered_data holds the results for util.node_configs_gnn[0], so the plot label
# is built from that same configuration. The loop variable `node_config` left
# over from the evaluation cell may refer to a different configuration.
# (`util` here is gnn_library.util, imported in the evaluation cell; the
# `from util import ...` above does not rebind that name.)
# NOTE(review): the label says 'classify' while args['head'] is set to
# 'regression' at the top of the notebook — confirm the intended label.
plotted_node_config = util.node_configs_gnn[0]
_box_plots(
    filtered_data,
    lambda graph_type: f"GNN2_classify_{graph_type} {plotted_node_config[1]}x{plotted_node_config[0]}",
)