In [1]:
# Absolute path of the project checkout; load_results() reads pickles from
# f'{ROOT_DIR}/results/'. Hard-coded, so it must be edited when running elsewhere.
ROOT_DIR = '/gpfs/commons/groups/gursoy_lab/aelhussein/ot_cost/otcost_fl_rebase'
import pandas as pd
import sys
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap

In [2]:
def load_results(DATASET):
    """Load the pickled hyperparameter-search results for one dataset.

    Args:
        DATASET: Dataset name used to build the file name
            ``{ROOT_DIR}/results/{DATASET}_hyperparameter_search.pkl``.

    Returns:
        Whatever object was pickled into that file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    pickle_path = f'{ROOT_DIR}/results/{DATASET}_hyperparameter_search.pkl'
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)
    
def _bootstrap_score(values, n_resamples=1000):
    """Return the CI-weighted score for one non-empty 1-D array of run results."""
    mean = np.mean(values)
    # A single run, or runs that are all identical, has no spread: the
    # confidence interval collapses onto the mean. Handle this explicitly
    # because scipy's bootstrap rejects samples with fewer than two
    # observations and its default BCa interval is NaN for constant data,
    # which would make the setting silently unselectable downstream.
    if values.size == 1 or np.all(values == values[0]):
        return mean
    bs_reps = bootstrap((values,), statistic=np.mean, n_resamples=n_resamples)
    low, high = bs_reps.confidence_interval[0:2]
    # Weighted blend: 3/5 mean, 1/5 lower CI bound, 1/5 upper CI bound.
    return np.mean([mean, mean, mean, low, high])


def bootstrap_ci(data, n_resamples=1000):
    """Score every hyperparameter setting from its repeated-run results.

    The score of a setting is ``0.6 * mean + 0.2 * ci_low + 0.2 * ci_high``,
    where the confidence interval of the mean comes from
    ``scipy.stats.bootstrap``. Settings with a single run, or whose runs are
    all equal, are scored by their mean directly (the interval is degenerate).

    Args:
        data: Nested mapping ``data[c][arch][optim][lr]`` whose leaves are
            sequences of numeric run results.
        n_resamples: Number of bootstrap resamples per setting.

    Returns:
        A nested dict with the same four key levels as ``data`` and one
        numeric score per leaf.

    Raises:
        ValueError: If a leaf contains no results.
    """
    estimates = {}
    for c, archs in data.items():
        estimates[c] = {}
        for arch, optims in archs.items():
            estimates[c][arch] = {}
            for optim, lrs in optims.items():
                estimates[c][arch][optim] = {}
                for lr, res in lrs.items():
                    values = np.asarray(res, dtype=float).ravel()
                    if values.size == 0:
                        raise ValueError(
                            f"no results for {c!r}/{arch!r}/{optim!r}/{lr!r}"
                        )
                    estimates[c][arch][optim][lr] = _bootstrap_score(
                        values, n_resamples
                    )
    return estimates

def best_parameters(results_estimates):
    """Pick the highest-scoring optimizer/learning-rate per cost and architecture.

    Args:
        results_estimates: Nested mapping
            ``results_estimates[cost][architecture][optimizer][lr] -> score``.

    Returns:
        ``{cost: {architecture: "optimizer: lr: score"}}`` with the score
        formatted to three decimals. On ties the first setting encountered
        wins; an architecture with no score greater than ``-inf`` maps to
        an empty string.
    """
    winners = {}
    for cost_key, per_arch in results_estimates.items():
        for arch_key, per_optim in per_arch.items():
            top_score = float('-inf')
            label = ""
            # Flatten the optimizer -> lr -> score levels, keeping dict order.
            candidates = (
                (optim_name, rate, score)
                for optim_name, per_lr in per_optim.items()
                for rate, score in per_lr.items()
            )
            for optim_name, rate, score in candidates:
                # Strict comparison: ties keep the earlier entry, NaN never wins.
                if not score > top_score:
                    continue
                top_score = score
                label = f"{optim_name}: {rate}: {score:.3f}"
            winners.setdefault(cost_key, {})[arch_key] = label
    return winners

def process_results(DATASET):
    """Load, score and rank the hyperparameter-search results of one dataset.

    Args:
        DATASET: Dataset name passed to ``load_results``.

    Returns:
        A tuple ``(results_estimates, best_hyperparams)``: the output of
        ``bootstrap_ci`` on the loaded results, and the output of
        ``best_parameters`` on those estimates.
    """
    estimates = bootstrap_ci(load_results(DATASET))
    return estimates, best_parameters(estimates)

### Synthetic

In [110]:
# Synthetic: score every setting and pick the best optimizer/learning rate.
# Rebinds the notebook-level DATASET, results_estimates and best_hyperparams.
DATASET = 'Synthetic'
results_estimates, best_hyperparams = process_results(DATASET)

### Credit

In [108]:
# Credit: score every setting and pick the best optimizer/learning rate.
# Rebinds the notebook-level DATASET, results_estimates and best_hyperparams.
DATASET = 'Credit'
results_estimates, best_hyperparams = process_results(DATASET)

### Weather

In [121]:
# Weather: score every setting and pick the best optimizer/learning rate.
# Rebinds the notebook-level DATASET, results_estimates and best_hyperparams.
DATASET = 'Weather'
results_estimates, best_hyperparams = process_results(DATASET)

### EMNIST

In [115]:
# EMNIST: score every setting and pick the best optimizer/learning rate.
# Rebinds the notebook-level DATASET, results_estimates and best_hyperparams.
DATASET = 'EMNIST'
results_estimates, best_hyperparams = process_results(DATASET)

### CIFAR

In [49]:
# CIFAR: score every setting and pick the best optimizer/learning rate.
# Rebinds the notebook-level DATASET, results_estimates and best_hyperparams.
DATASET = 'CIFAR'
results_estimates, best_hyperparams = process_results(DATASET)

### IXITiny

In [34]:
# IXITiny: score every setting and pick the best optimizer/learning rate.
# Rebinds the notebook-level DATASET, results_estimates and best_hyperparams.
DATASET = 'IXITiny'
results_estimates, best_hyperparams = process_results(DATASET)

### ISIC

In [5]:
# ISIC: score every setting and pick the best optimizer/learning rate.
# Rebinds the notebook-level DATASET, results_estimates and best_hyperparams.
DATASET = 'ISIC'
results_estimates, best_hyperparams = process_results(DATASET)