In [None]:
import os

# Move one directory up, presumably so that the project modules imported in the
# next cell and the relative `experiment_output/` path resolve from the project
# root. The starting directory is recorded on the first run only, which makes
# the cell idempotent: re-running it (or "Run All" in a live kernel) does not
# keep climbing the directory tree.
if 'NOTEBOOK_START_DIR' not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
    os.chdir('..')

In [None]:
import torch
import pickle
import numpy as np

from gnn_library.util import load
from evaluate import evaluate_model
from instance_generator import sample_instances
from params import REGIMES, TEST_CONFIGS
from util import _plot_approx_ratios_all

%load_ext autoreload
%autoreload 2

In [None]:
# GPU ordinal to run on when this machine has enough GPUs. Otherwise fall back
# to the first GPU, or to the CPU when CUDA is unavailable. Requesting an
# ordinal that does not exist (e.g. cuda:6 on a single-GPU box) fails on first use.
GPU_INDEX = 6

if not torch.cuda.is_available():
    device = torch.device('cpu')
elif GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f'cuda:{GPU_INDEX}')
else:
    print(f'cuda:{GPU_INDEX} not present; falling back to cuda:0')
    device = torch.device('cuda:0')

print("PyTorch has version {}".format(torch.__version__))
print('Using device:', device)

# Load the saved GNN named 'GNN_large_10_6' onto `device`. `args` is an object
# whose `__dict__` is forwarded to `sample_instances` in the experiment cell.
GNN, args = load('GNN_large_10_6', device)

## Size generalization experiment

### Evaluate CRs for every graph configuration across graph sizes (total number of nodes $N$)

In [None]:
# --- Experiment settings -------------------------------------------------
# One seeded generator shared by instance sampling, evaluation and LP rounding,
# so the whole run is reproducible from seed 0.
rng = np.random.default_rng(seed=0)
num_trials = 150        # passed to sample_instances; presumably instances per config pair
batch_size = 500
num_realizations = 5    # forwarded to evaluate_model
graph_configs = TEST_CONFIGS['ALL']
node_configs = REGIMES['SIZE_GENERALIZATION']
results_path = os.path.join('experiment_output', 'size_generalization_main1.pickle')

# Each node config is a pair; its sum is the total node count N on the plot's x-axis.
sizes = [x + y for (x, y) in node_configs]

# Per-baseline keyword arguments; the dict keys double as the baseline list below,
# so the two cannot drift apart.
baselines_kwargs = {
    'greedy': {},
    'greedy_t': {'threshold': 0.35},
    'lp_rounding': {'rng': rng},
}

# data[j][k] holds the ratios returned for graph_configs[j] at node_configs[k].
data = [[] for _ in graph_configs]

# Create the output directory up front so the first checkpoint cannot fail
# after an expensive evaluation has already run.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

for node_config in node_configs:
    for data_index, graph_config in enumerate(graph_configs):
        print(node_config, graph_config)

        instances = sample_instances(
            *node_config,
            num_trials,
            rng,
            args.__dict__,
            **graph_config
        )

        cr_ratios, _ = evaluate_model(
            meta_model=None,
            meta_model_type=None,
            base_models=[GNN],
            instances=instances,
            batch_size=batch_size,
            rng=rng,
            num_realizations=num_realizations,
            baselines=list(baselines_kwargs),
            **baselines_kwargs
        )
        data[data_index].append(cr_ratios)

        # Checkpoint after every evaluation so an interrupted run keeps its
        # progress. The file is overwritten with everything collected so far,
        # which means it may hold partial per-size lists until the loop finishes.
        with open(results_path, 'wb') as handle:
            pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

        torch.cuda.empty_cache()

### Generate size-generalization plots

In [None]:

# Rebuild the plot inputs from `params` instead of relying on variables left
# over from the (expensive) experiment cell, so this cell works on a fresh
# kernel once the pickle exists. These mirror the experiment cell's definitions.
plot_graph_configs = TEST_CONFIGS['ALL']
plot_sizes = [x + y for (x, y) in REGIMES['SIZE_GENERALIZATION']]

# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open("experiment_output/size_generalization_main1.pickle", 'rb') as handle:
    results = pickle.load(handle)

# zip() below would silently drop configs on a length mismatch, and the
# experiment cell checkpoints incrementally, so reject partial/mismatched results.
assert len(results) == len(plot_graph_configs), (
    f"expected results for {len(plot_graph_configs)} graph configs, found {len(results)}"
)
assert all(len(per_size) == len(plot_sizes) for per_size in results), (
    "pickle holds partial results; re-run the experiment cell to completion"
)

# Key each graph config's per-size results by its (hashable) parameter set.
filtered_results = {
    frozenset(graph_config.items()): per_size
    for graph_config, per_size in zip(plot_graph_configs, results)
}
_plot_approx_ratios_all(plot_sizes, filtered_results, x_axis_name="Total number of nodes $N$", confidence=0.95)