In [None]:
import torch
import numpy as np
import os
os.chdir('..')

import torch_converter as tc
import instance_generator as ig
from gnn_library.util import train, save, load, gen_train_input
from evaluate import evaluate_model
import evaluate as ev


%load_ext autoreload
%autoreload 2

In [None]:
# Pick the GPU at index 7 when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:7')
else:
    device = torch.device('cpu')

print("PyTorch has version {}".format(torch.__version__))
print('Using device:', device)

## Train meta-GNN

In [None]:
# Pre-trained base models, passed as `base_models` to meta-GNN training and evaluation below.
(GNN1, args1), (GNN2, args2) = [
    load(model_name, device)
    for model_name in ('GNN1_hyperparam_tuned', 'GNN2_hyperparam_tuned')
]

In [None]:
# Training distribution for the meta-GNN: how many train/val samples to draw,
# which graph families to draw them from, and the node-count regimes.
# `regimes` are presumably (num_online, num_offline) pairs — TODO confirm
# against gen_train_input.
META_TRAIN_CONFIG = dict(
    train_num=150,
    val_num=50,
    configs=[
        dict(graph_type='ER', p=1, weighted=True),
        dict(graph_type='BA', ba_param=4, weighted=False),
        dict(graph_type='GEOM', q=0.25, d=2, weighted=True),
    ],
    regimes=[(6, 10), (8, 8), (10, 6)],
)

In [None]:
# Hyperparameters for the meta-GNN: model architecture, optimisation settings,
# and the device to train on.
args = dict(
    processor='DeeperGCN',
    head='meta',
    num_layers=4,
    num_mlp_layers=2,
    aggr='max',
    batch_size=6,
    node_feature_dim=7,
    edge_feature_dim=1,
    graph_feature_dim=2,
    hidden_dim=8,
    output_dim=2,
    head_mlp_dim=8,
    dropout=0,
    epochs=35,
    opt='adam',
    opt_scheduler='none',
    opt_restart=0,
    weight_decay=5e-3,
    lr=0.001,
    device=device,
)

# Build train/val data for the meta-GNN from the two base GNNs, train it,
# and persist it under the name 'META_GNN'.
train_loader, val_loader = gen_train_input(
    META_TRAIN_CONFIG,
    args,
    seed=0,
    base_models=[GNN1, GNN2],
)
# train() returns five values; only the fourth (the trained model) is kept.
_, _, _, META_GNN, _ = train(train_loader, val_loader, args)
save(META_GNN, args, 'META_GNN')


## Evaluate meta-GNN

In [None]:
# Reload the saved meta-GNN so evaluation does not depend on the training cell
# having run in this kernel. This also rebinds `args` to the saved args, which
# the experiment loop later passes to ig.sample_instances.
META_GNN, args = load(
    'META_GNN',
    device,
)

In [None]:
import pickle
import matplotlib.pyplot as plt

def graph_config_to_string(config):
    """Return a short label for a graph-generation config.

    The label keys the results in ``load_meta_experiments`` and names the
    ``experiments/meta_<label>.pickle`` files, so it must stay stable. Only the
    family parameter is encoded: GEOM's ``d`` and any ``weighted`` flag are
    not part of the label, so configs that differ only in those map to the
    same label.

    Args:
        config: dict with a ``'graph_type'`` key plus the family parameter:
            ``'p'`` (ER), ``'ba_param'`` (BA), ``'q'`` (GEOM),
            ``'location'`` (OSMNX); GM needs nothing else.

    Returns:
        The label, e.g. ``"ER_0.25"``, ``"BA_4"`` or ``"GM"``.

    Raises:
        ValueError: if ``graph_type`` is not a supported family.
        KeyError: if the family parameter is missing from ``config``.
    """
    graph_type = config['graph_type']
    if graph_type == 'ER':
        return f"ER_{config['p']}"
    if graph_type == 'BA':
        return f"BA_{config['ba_param']}"
    if graph_type == 'GEOM':
        return f"GEOM_{config['q']}"
    if graph_type == 'GM':
        return "GM"
    if graph_type == 'OSMNX':
        return f"OSMNX_{config['location']}"
    # Falling through used to return None, which silently sent every unknown
    # config to the same "meta_None.pickle" file.
    raise ValueError(
        f"Unsupported graph_type {graph_type!r}; "
        "expected one of 'ER', 'BA', 'GEOM', 'GM', 'OSMNX'"
    )

def save_meta_experiment(graph_str, data):
    """Pickle ``data`` to ``experiments/meta_<graph_str>.pickle``.

    The file is opened in ``'wb'`` mode, so any existing results for
    ``graph_str`` are replaced rather than merged.
    """
    target_path = f"experiments/meta_{graph_str}.pickle"
    with open(target_path, 'wb') as out_file:
        pickle.dump(data, out_file, protocol=pickle.HIGHEST_PROTOCOL)

def _merge_meta_results(current_data, new_data, filepath):
    """Return a new dict with ``new_data``'s trials appended to ``current_data``.

    Validates that both have the same keys and, per model, the same number of
    node configs before merging, so a mismatch never produces a partially
    merged result.
    """
    if set(current_data) != set(new_data):
        raise ValueError(
            f"Cannot merge into {filepath}: stored keys {sorted(current_data)} "
            f"differ from new keys {sorted(new_data)}"
        )
    merged = {}
    for key, stored in current_data.items():
        if key == 'num_trials':
            merged[key] = stored + new_data[key]
            continue
        incoming = new_data[key]
        if len(stored) != len(incoming):
            raise ValueError(
                f"Cannot merge into {filepath}: model {key!r} has {len(stored)} "
                f"stored node configs but {len(incoming)} new ones"
            )
        # list(...) so both lists and array-like entries concatenate.
        merged[key] = [list(old) + list(new) for old, new in zip(stored, incoming)]
    return merged


def _atomic_pickle_dump(obj, filepath):
    """Pickle ``obj`` to ``filepath`` via a temp file so a failed dump never truncates the old file."""
    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, 'wb') as handle:
            pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, filepath)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def upload_meta_experiment(graph_str, data):
    """Append the trials in ``data`` to ``experiments/meta_<graph_str>.pickle``.

    If the file does not exist yet, ``data`` is written as-is. Otherwise the
    stored results are merged with ``data``:
    - ``num_trials`` is summed.
    - For every other key, the i-th list of competitive ratios (one per node
      config) is extended with the i-th list from ``data``.

    Args:
        graph_str: label from ``graph_config_to_string``; selects the file.
        data: dict with a ``'num_trials'`` count and, per model name, a list
            holding one sequence of competitive ratios per node config.

    Raises:
        ValueError: if the stored results and ``data`` have different keys,
            or a model has a different number of node configs. The stored
            file is left untouched in that case.

    Errors reading an existing file (e.g. a corrupt pickle) now propagate.
    The old bare ``except:`` caught them, and merge errors too, then
    overwrote the stored results with only the new data.

    Note: the file is read with ``pickle.load``; only use on trusted, locally
    produced files.
    """
    filepath = f"experiments/meta_{graph_str}.pickle"
    try:
        with open(filepath, 'rb') as handle:
            current_data = pickle.load(handle)
    except FileNotFoundError:
        merged = data
    else:
        merged = _merge_meta_results(current_data, data, filepath)
    _atomic_pickle_dump(merged, filepath)


def load_meta_experiments(configs):
    """Load the pickled meta-experiment results for each graph config.

    Args:
        configs: iterable of graph-generation configs.

    Returns:
        dict mapping ``graph_config_to_string(config)`` to the object
        unpickled from ``experiments/meta_<label>.pickle``. A missing file
        raises (no default is substituted).
    """
    def _read_pickle(label):
        with open(f"experiments/meta_{label}.pickle", 'rb') as experiment_file:
            return pickle.load(experiment_file)

    return {
        label: _read_pickle(label)
        for label in map(graph_config_to_string, configs)
    }


import scipy.stats as st 

def _mean_with_ci(samples, confidence):
    """Return ``(mean, lower, upper)`` for ``samples``.

    The interval is a normal approximation centred on the sample mean, with
    the standard error of the mean (``scipy.stats.sem``) as its scale.
    """
    mean = np.mean(samples)
    # `confidence` is passed positionally. SciPy 1.9 renamed this argument of
    # `interval` from `alpha` to `confidence` and later removed `alpha`, so the
    # old keyword call breaks on current SciPy.
    lower, upper = st.norm.interval(confidence, loc=mean, scale=st.sem(samples))
    return mean, lower, upper


def _plot_approx_ratios(ratios, data, naming_function = lambda graph_type: graph_type, x_axis_name = "# online / # offline", confidence = 0.99):
    """Plot the average competitive ratio against ``ratios``, one figure per graph.

    Args:
        ratios: x-axis values; one per entry in each model's list in ``data``.
        data: ``{graph_label: {'num_trials': int, model_name: [crs, ...]}}``,
            the structure saved by the experiment loop and returned by
            ``load_meta_experiments``. Each ``crs`` holds the raw competitive
            ratios for one x value.
        naming_function: maps a graph label to the figure title. The title,
            with spaces replaced by underscores, is also the PNG file name.
        x_axis_name: x-axis label.
        confidence: confidence level of the shaded interval around each mean.
            This was previously ignored in favour of a hard-coded 0.95.

    Each figure is saved to ``data/<title>.png``, shown, then closed.
    """
    os.makedirs("data", exist_ok=True)
    for graph_type, graph_data in data.items():
        fig, ax = plt.subplots(figsize=(8, 6))
        for model, cr_by_ratio in graph_data.items():
            if model == 'num_trials':  # bookkeeping entry, not a model
                continue
            stats = [_mean_with_ci(raw_crs, confidence) for raw_crs in cr_by_ratio]
            competitive_ratios = [s[0] for s in stats]
            ci_lbs = [s[1] for s in stats]
            ci_ubs = [s[2] for s in stats]
            ax.plot(ratios, competitive_ratios, label=model)
            ax.fill_between(ratios, ci_lbs, ci_ubs, alpha=0.2)

        title = f"{naming_function(graph_type)}"
        ax.set_title(title, fontsize=18)
        ax.set_xlabel(x_axis_name, fontsize=15)
        ax.set_ylabel('Average Competitive Ratio', fontsize=15)
        ax.legend()
        fig.savefig(f"data/{title.replace(' ', '_')}.png")
        plt.show()
        # Release the figure; one is created per graph family in this loop.
        plt.close(fig)

In [None]:
# Evaluation graph families: three parameter values each for ER / BA / GEOM
# (all weighted), two OSMNX road-network locations, and GM.
graph_configs = (
    [{'graph_type': 'ER', 'p': p, 'weighted': True} for p in (0.25, 0.5, 0.75)]
    + [{'graph_type': 'BA', 'ba_param': b, 'weighted': True} for b in (4, 6, 8)]
    + [{'graph_type': 'GEOM', 'q': q, 'd': 2, 'weighted': True} for q in (0.15, 0.25, 0.5)]
    + [
        {'graph_type': 'OSMNX', 'location': place}
        for place in ('Piedmont, California, USA', 'Fremont, California, USA')
    ]
    + [{'graph_type': 'GM'}]
)

In [None]:
num_trials = 50
node_configs = [(x, 16) for x in np.arange(8, 65, 4)]
batch_size = 500

# Extra keyword arguments forwarded to `evaluate_model` for the baseline run.
# NOTE(review): 'greedy_t' maps to the *set* {0.35}, not a dict of keyword
# arguments like the other entries, and 'greedy_t' is not in the `baselines`
# list used below. Confirm what evaluate_model expects here.
baselines_kwargs = {
    'greedy': {},
    'greedy_t': {0.35},
    'lp_rounding': {}
}

# x-axis values for the plots: ratio of the two node counts in each config.
ratios = [x/y for (x,y) in node_configs]
print(ratios)


def _init_data():
    """Return a fresh results dict: the trial count plus one empty list per method.

    Each method list receives one entry (that node config's competitive
    ratios) per node config, in `node_configs` order.
    """
    return {
        "num_trials": num_trials,
        "meta_with_greedy": [],
        "greedy": [],
        "lp_rounding": [],
        "meta_no_greedy": [],
        "meta_threshold": []
    }


def _evaluate_on(instances, seed, meta_model, meta_model_type, **extra_kwargs):
    """Run `evaluate_model` on `instances` and return its competitive-ratio dict.

    Every call builds its own generator from `seed`, so each method starts
    from the same RNG state. Previously one rng was shared across calls, and
    later methods drew from an already-advanced stream. Base models, batch
    size and the number of realizations are shared by all methods.
    """
    crs, _ = evaluate_model(
        meta_model=meta_model,
        meta_model_type=meta_model_type,
        base_models=[GNN1, GNN2],
        instances=instances,
        batch_size=batch_size,
        rng=np.random.default_rng(seed),
        num_realizations=5,
        **extra_kwargs
    )
    return crs


for graph_config in graph_configs:
    data = _init_data()
    graph_str = graph_config_to_string(graph_config)

    for node_config in node_configs:
        print(graph_config, node_config)
        # Not seeded: results are appended to the stored pickle by
        # upload_meta_experiment, so a fixed seed would re-append identical
        # instances on every run.
        seed = np.random.randint(0, 500000)
        instances = ig.sample_instances(
            *node_config, num_trials, np.random.default_rng(seed), args, **graph_config
        )

        crs = _evaluate_on(
            instances, seed, META_GNN, 'gnn',
            baselines=['greedy', 'lp_rounding'],
            **baselines_kwargs
        )
        # The original passed `meta_model=META`, an undefined name (NameError on
        # a fresh kernel); META_GNN is the only meta model in this notebook.
        # NOTE(review): confirm this run is not meant to use a different model.
        no_greedy_crs = _evaluate_on(instances, seed, META_GNN, 'gnn')
        threshold_crs = _evaluate_on(instances, seed, None, 'threshold')

        data['meta_with_greedy'].append(crs['learned'])
        data['greedy'].append(crs['greedy'])
        data['lp_rounding'].append(crs['lp_rounding'])
        data['meta_no_greedy'].append(no_greedy_crs['learned'])
        data['meta_threshold'].append(threshold_crs['learned'])

    upload_meta_experiment(graph_str, data)

        