In [None]:
import os

# The project modules imported below (`gnn_library`, `evaluate`, `params`, ...)
# are resolved relative to the parent directory of this notebook, so run from
# there. The guard keeps the cell idempotent: re-running it must not keep
# climbing one directory per execution.
if not os.path.isdir('gnn_library'):
    os.chdir('..')

In [None]:
import numpy as np
import pickle
import torch

from gnn_library.util import train, save, load, gen_train_input
from gnn_library.train import train_base_model
from evaluate import evaluate_model
from instance_generator import sample_instances
from util import _plot_approx_ratios_all
from params import TRAIN_CONFIGS, TEST_CONFIGS, REGIMES

%load_ext autoreload
%autoreload 2

In [None]:
# Pick the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print("PyTorch has version {}".format(torch.__version__))
print('Using device:', device)

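## Noise robustness experiment

Train one GNN per noise level. Then compare each model against the greedy, thresholded-greedy and LP-rounding baselines on every test graph configuration.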

#### Train one GNN per noise level

In [None]:
def _noise_model_args(noise_value, epochs):
    """Return a fresh hyperparameter dict for the GNN trained at one noise level.

    Only 'epochs' and 'noise' vary; everything else is shared by all models.
    A new dict is built per call in case `train_base_model` mutates its args.
    """
    return {
        'processor':         'GENConv',
        'head':              'regression',
        'num_layers':        3,
        'num_mlp_layers':    3,
        'aggr':              'max',
        'batch_size':        8,
        'node_feature_dim':  5,
        'edge_feature_dim':  1,
        'graph_feature_dim': 2,
        'hidden_dim':        32,
        'output_dim':        1,
        'dropout':           0.0306,
        'epochs':            epochs,
        'opt':               'adagrad',
        'opt_scheduler':     'none',
        'opt_restart':       0,
        'weight_decay':      5e-3,
        'lr':                0.0121,
        'device':            device,
        'noise':             noise_value,
    }


def create_noise_robust_models(noise_values, epochs=2):
    """Train one base GNN per noise level, in the order given.

    Each model is trained with `train_base_model` on the 'NOISE' regime under
    the name ``GNN_noise_<noise_value>`` (the name `get_models` loads from).

    Args:
        noise_values: iterable of noise levels; one model is trained per value.
        epochs: training epochs per model. The default of 2 matches the
            original setting, which appears to be a quick-iteration value (a
            note beside it said 64) — pass e.g. ``epochs=64`` for a full run.

    Returns:
        List of ``(label, model)`` pairs, ``label`` being ``f"GNN_{noise}"``.
    """
    models = []
    for noise_value in noise_values:
        print(f"Training model for noise {noise_value}")
        args = _noise_model_args(noise_value, epochs)

        gnn = train_base_model(
            regime_key='NOISE',
            train_config=TRAIN_CONFIGS['NOISE'],
            name=f'GNN_noise_{noise_value}',
            args=args
        )
        models.append((f"GNN_{args['noise']}", gnn))

    return models

In [None]:
def get_models(noise_values):
    """Load already-trained noise-level GNNs instead of retraining them.

    For each value in `noise_values`, calls `load` on the checkpoint named
    ``GNN_noise_<value>`` and labels the returned model with the noise level
    stored in its args (``args.noise``).

    Returns:
        List of ``(label, model)`` pairs in the same order as `noise_values`.
    """
    checkpoints = (load(f"GNN_noise_{noise}", device) for noise in noise_values)
    return [(f"GNN_{saved_args.noise}", gnn) for gnn, saved_args in checkpoints]

In [None]:
# One model per noise level: 21 evenly spaced values from 0 to 1 (step 0.05).
NOISE_VALUES = np.linspace(0, 1, 21)

# Set to False to reload models trained in an earlier session (via
# `get_models`) instead of retraining every one of them.
TRAIN_MODELS = True

if TRAIN_MODELS:
    models = create_noise_robust_models(NOISE_VALUES)
else:
    models = get_models(NOISE_VALUES)

#### Evaluate CRs for all graph configurations and noise levels

In [None]:

# A single seeded generator drives all instance sampling and evaluation below,
# so results are reproducible as long as the loops run in the same order.
rng = np.random.default_rng(seed=0)
num_trials = 100
batch_size = 500
graph_configs = TEST_CONFIGS['ALL']
node_configs = REGIMES['BASE_TEST']
# Per-baseline keyword arguments, forwarded to `evaluate_model` as **kwargs.
baselines_kwargs = {
    'greedy': {},
    'greedy_t': {'threshold': 0.35},
    'lp_rounding': {}
}
NOISE_RESULTS_PATH = "experiment_output/noise.pickle"

# zip() below would silently skip noise levels if `models` were shorter.
assert len(models) == len(NOISE_VALUES), (
    f"expected one model per noise level, got {len(models)} models "
    f"for {len(NOISE_VALUES)} noise values"
)
# Create the output directory before the (long) evaluation so the final write
# cannot fail with FileNotFoundError and throw the results away.
os.makedirs(os.path.dirname(NOISE_RESULTS_PATH), exist_ok=True)

# data[i][node_config] collects one `cr_ratios` entry per noise level, in
# NOISE_VALUES order, for graph_configs[i].
data = [{node_config: [] for node_config in node_configs} for _ in graph_configs]

for data_index, graph_config in enumerate(graph_configs):
    print(f"current graph {graph_config}")

    for node_config in node_configs:
        for noise_value, (model_name, model) in zip(NOISE_VALUES, models):
            print(f"Evaluating model for noise {noise_value}")

            # Instances are sampled with the noise level paired with this model.
            instances = sample_instances(
                *node_config,
                num_trials,
                rng,
                {'noise': noise_value},
                **graph_config
            )

            cr_ratios, _ = evaluate_model(
                meta_model=None,
                meta_model_type=None,
                base_models=[model],
                instances=instances,
                batch_size=batch_size,
                rng=rng,
                num_realizations=5,
                baselines=['greedy', 'greedy_t', 'lp_rounding'],
                **baselines_kwargs
            )

            data[data_index][node_config].append(cr_ratios)

with open(NOISE_RESULTS_PATH, 'wb') as handle:
    pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Generate noise generalization plots

In [None]:

# Reload the saved results so the plots can be regenerated without re-running
# the evaluation. NOTE: pickle.load can execute arbitrary code — only load a
# file this notebook wrote itself.
with open("experiment_output/noise.pickle", 'rb') as handle:
    data_copy = pickle.load(handle)

# Same configurations the evaluation cell iterated over, read directly so this
# cell does not depend on that cell's leftover globals.
plot_graph_configs = TEST_CONFIGS['ALL']
# Only the first node configuration is plotted.
plot_node_config = REGIMES['BASE_TEST'][0]

# zip() would silently drop graph configs if the pickle were stale.
assert len(data_copy) == len(plot_graph_configs), (
    f"pickle has {len(data_copy)} graph configs, expected {len(plot_graph_configs)}"
)

filtered_data = {
    frozenset(graph_cfg.items()): results[plot_node_config]
    for results, graph_cfg in zip(data_copy, plot_graph_configs)
}

# The title must come from the callback argument: formatting the evaluation
# loop's leftover `graph_config`/`node_config` labelled every plot with the
# last graph config and the last (not the plotted) node config.
# NOTE(review): assumes `_plot_approx_ratios_all` passes an identifier of the
# graph being plotted as `graph_type` — confirm against its implementation.
_plot_approx_ratios_all(
    NOISE_VALUES,
    filtered_data,
    lambda graph_type: f"noise {graph_type} {plot_node_config[1]}x{plot_node_config[0]}",
    x_axis_name="Noise standard deviation $\\rho$",
    confidence=0.95
)