In [None]:
# --- Run / hardware parameters ---
run_label = 0          # first run index to try for the log directory; also offsets the data-split seeds
gpu = 0                # CUDA device index used to build `device`
log_root = 'log/'      # root directory under which per-run result folders are created

# --- Model / data parameters ---
network_name = 'fcdrop'  # key into `network_classes` ('fc' or 'fcdrop')
network_size = 128       # passed to the network as feat_dim
use_adv = False          # forwarded to the base-learner training functions
normalize = False        # forwarded to regression.get_regression_datasets

In [None]:
import sys
sys.path.append('../../')
from functools import partial
import pickle 
import torch
from torch import optim
from torch.utils.data import DataLoader

from torchuq.transform.conformal import ConformalCalibrator
from torchuq.transform.naive import *
from torchuq.metric.distribution import *
from torchuq.dataset import regression

# Use the configured GPU when CUDA is available; otherwise fall back to the CPU
# so that Restart & Run All still works (slowly) on machines without a GPU.
if torch.cuda.is_available():
    device = torch.device('cuda:%d' % gpu)
else:
    print("CUDA is not available; falling back to CPU (requested gpu=%d)" % gpu)
    device = torch.device('cpu')

In [None]:
import time, os
start_time = time.time()

# Claim the first free run directory at or after the requested run_label.
# os.makedirs (without exist_ok) raises FileExistsError when the leaf directory
# already exists, so the create-and-check is a single step and two runs cannot
# both claim the same label. run_name/log_dir are recomputed for every label
# tried, so the printed run number always matches the directory in use.
while True:
    run_name = 'network=%s-%d-use_adv=%r-normalize=%r/run_label=%d' % (network_name, network_size, use_adv, normalize, run_label)
    log_dir = os.path.join(log_root, run_name)
    try:
        os.makedirs(log_dir)
        break
    except FileExistsError:
        run_label += 1

print("Run number = %d" % run_label)

In [None]:
from torchuq.models.network import NetworkFC, NetworkFCDrop
# Lookup table from the `network_name` config value to the network class it selects.
network_classes = {
    'fc': NetworkFC,
    'fcdrop': NetworkFCDrop,
}

def make_network(x_dim, out_dim):
    """Construct the network chosen by the global ``network_name``.

    The hidden feature width is taken from the global ``network_size``.
    An unknown ``network_name`` raises KeyError from the table lookup.
    """
    chosen_cls = network_classes[network_name]
    return chosen_cls(x_dim=x_dim, out_dim=out_dim, feat_dim=network_size)

In [None]:

# This function uses the notebook-level globals run_label, normalize, make_network and device.
def evaluate_performance(train_model, prediction_type, repeat=1, use_adv=False, debug_mode=False, verbose=False):
    """Train, conformally calibrate and score a base learner on every regression dataset.

    For each dataset in ``regression.dataset_names`` and each of ``repeat`` splits
    (seeded with ``run_label + rep``), ``train_model`` is trained. Its validation
    predictions fit a ``ConformalCalibrator``, and the calibrated test predictions
    are scored with CRPS, std, NLL and (debiased) ECE.

    Args:
        train_model: callable invoked as
            ``train_model(train, val, test, use_adv=, verbose=, network_class=, device=)``
            and returning ``(prediction_val, prediction_test)``.
        prediction_type: ``input_type`` for ``ConformalCalibrator``
            (this notebook uses 'point', 'distribution' and 'quantile').
        repeat: number of splits evaluated per dataset.
        use_adv: forwarded to ``train_model``.
        debug_mode: if True, skip any split whose training set has more than 1500 samples.
        verbose: forwarded to the dataset loader and to ``train_model``.

    Returns:
        dict mapping dataset name -> {'crps', 'std', 'nll', 'ece'} -> tensor of length
        ``repeat``. Splits skipped in debug mode hold NaN (not 0), so they cannot be
        mistaken for real scores.
    """
    results = {}

    for name in regression.dataset_names:
        # NaN marks "not evaluated"; zeros would read as perfect scores.
        results[name] = {metric: torch.full((repeat,), float('nan')) for metric in ('crps', 'std', 'nll', 'ece')}
        for rep in range(repeat):
            train_dataset, val_dataset, test_dataset = regression.get_regression_datasets(
                name, split_seed=run_label + rep, val_fraction=0.2,
                test_fraction=0.2, normalize=normalize, verbose=verbose)
            if debug_mode and len(train_dataset) > 1500:
                continue  # debug mode: only evaluate small datasets
            val_labels = val_dataset[:][1]
            test_labels = test_dataset[:][1]
            prediction_val, prediction_test = train_model(train_dataset, val_dataset, test_dataset, use_adv=use_adv,
                                                          verbose=verbose, network_class=make_network, device=device)

            # Fit the conformal calibrator on the validation predictions, then apply
            # it to the test predictions.
            calibrator = ConformalCalibrator(input_type=prediction_type, interpolation='linear')
            calibrator.train(prediction_val, val_labels)
            final_prediction = calibrator(prediction_test)

            results[name]['crps'][rep] = compute_crps(final_prediction, test_labels).mean()
            results[name]['std'][rep] = compute_std(final_prediction).mean()
            results[name]['nll'][rep] = compute_nll(final_prediction, test_labels).mean()
            results[name]['ece'][rep] = compute_ece(final_prediction, test_labels, debiased=True)
        print("Finished dataset %s, nll=%.4f, std=%.4f, crps=%.4f, ece=%.4f" %
              (name, results[name]['nll'].mean(), results[name]['std'].mean(), results[name]['crps'].mean(), results[name]['ece'].mean()))
    return results


In [None]:
from base_learner import train_point, train_normal, train_quantile

In [None]:
def train_ensemble(train_dataset, val_dataset, test_dataset, n_ensemble=10, use_adv=False, verbose=False, **kwargs):
    """Train ``n_ensemble`` independent ``train_normal`` base learners.

    Extra keyword arguments are forwarded to ``train_normal``. This includes
    ``network_class`` and ``device``, which ``evaluate_performance`` passes to every
    base learner, so this function accepts the same call as the other learners.

    Returns:
        (predictions_val, predictions_test): two lists of length ``n_ensemble`` with
        each member's validation / test predictions.
        NOTE(review): the members are NOT combined into a single prediction yet.
        ``evaluate_performance`` expects one prediction per split, so aggregate them
        (e.g. into a mixture) before using this function there. TODO.
    """
    predictions_val = []
    predictions_test = []
    for _ in range(n_ensemble):
        prediction_val, prediction_test = train_normal(train_dataset, val_dataset, test_dataset,
                                                       use_adv=use_adv, verbose=verbose, **kwargs)
        predictions_val.append(prediction_val)
        predictions_test.append(prediction_test)
    return predictions_val, predictions_test
    

In [None]:
# Point-prediction base learner: evaluate on all datasets and persist the metrics.
results_point = evaluate_performance(train_point, 'point', use_adv=use_adv)
with open(os.path.join(log_dir, 'results_point.pickle'), mode='wb') as results_file:
    pickle.Pickler(results_file).dump(results_point)

In [None]:
# Distributional base learner (train_normal): evaluate on all datasets and persist the metrics.
results_dist = evaluate_performance(train_normal, 'distribution', use_adv=use_adv)
with open(os.path.join(log_dir, 'results_dist.pickle'), mode='wb') as results_file:
    pickle.Pickler(results_file).dump(results_dist)

In [None]:
# Sweep over the number of predicted quantiles. Each sweep point is pickled to its
# own file and also kept in memory, keyed by the number of quantiles (previously
# the accumulator was overwritten every iteration, keeping only the last point).
results_quantile = {}
for n_quantiles in [2, 3, 5, 7, 10, 15, 20]:
    results_quantile[n_quantiles] = evaluate_performance(partial(train_quantile, n_quantiles=n_quantiles), 'quantile', use_adv=use_adv)
    with open(os.path.join(log_dir, 'results_quantile_%d.pickle' % n_quantiles), 'wb') as handle:
        pickle.dump(results_quantile[n_quantiles], handle)