In [None]:
import os

# Work from the parent of the directory the kernel started in; the project
# modules imported in the next cell are expected to be resolvable from there.
# The starting directory is remembered on first execution so that re-running
# this cell is idempotent -- a bare `os.chdir('..')` would climb one more level
# on every re-run and break those imports.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.abspath(os.path.join(NOTEBOOK_DIR, os.pardir)))

In [None]:
import torch
import numpy as np
from instance_generator import sample_instances
import evaluate as ev
from gnn_library.OBM_greedy import OBM_Greedy
from params import TEST_CONFIGS, REGIMES
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [None]:
# Pick the compute device: the GPU if PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Report the environment so the run's setup is visible in the output.
print(f"PyTorch has version {torch.__version__}")
print('Using device:', device)

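## Tune the threshold of the greedy-t baseline

Sweep the threshold over a grid in [0, 1], evaluate each thresholded greedy model on a fixed set of sampled instances, and keep the threshold with the highest mean CR.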

In [None]:
# Candidate thresholds: 101 evenly spaced values from 0 to 1 inclusive
# (step 0.01), in ascending order.
THRESHOLDS = np.linspace(0, 1, 101)

# One thresholded-greedy baseline per candidate, keyed by its threshold;
# insertion order follows THRESHOLDS.
thresholded_greedy_models = dict(
    (threshold, OBM_Greedy(threshold)) for threshold in THRESHOLDS
)

In [None]:
# Sample a fixed pool of tuning instances, score every candidate threshold on
# that same pool, and report the best one.
rng = np.random.default_rng(seed=0)
node_configs = REGIMES['BASE_TEST']
graph_configs = TEST_CONFIGS['ALL']
train_num = 40          # instances sampled per (graph_config, node_config) pair
BATCH_SIZE = 50
NUM_REALIZATIONS = 5
# Seed of the evaluation stream. A fresh generator with this seed is built for
# every threshold so all candidates are scored under the same random draws
# (common random numbers). Sharing one advancing `rng` across the sweep gave
# each threshold different draws, so sampling noise -- and sweep order -- could
# decide the winner, and a threshold's score could not be reproduced on its own.
EVAL_SEED = 1

# Sample instances for every (graph_config, node_config) pair and flatten the
# per-pair lists into a single list. Call order (graph outer, node inner) and
# therefore `rng` consumption match a plain nested loop.
train_instances = [
    instance
    for graph_config in graph_configs
    for node_config in node_configs
    for instance in sample_instances(
        *node_config,
        train_num,
        rng,
        {'noise': 0},
        **graph_config
    )
]
assert train_instances, "no tuning instances were sampled"

greedy_ratios = {}
for threshold, model in tqdm(thresholded_greedy_models.items(), desc='thresholds'):
    ratio, _ = ev.evaluate_model(
        meta_model=None,
        meta_model_type=None,
        base_models=[model],
        instances=train_instances,
        batch_size=BATCH_SIZE,
        rng=np.random.default_rng(seed=EVAL_SEED),
        num_realizations=NUM_REALIZATIONS
    )
    # ratio['learned'] presumably holds the per-instance ratios of the single
    # base model evaluated here -- confirm against evaluate.evaluate_model.
    greedy_ratios[threshold] = np.mean(ratio['learned'])

# Ties resolve to the first threshold in the dict's iteration order.
max_threshold = max(greedy_ratios, key=greedy_ratios.get)
print(f"Best threshold value: {max_threshold:.4g} achieves CR: {greedy_ratios[max_threshold]:.4f}")