In [1]:
import os
# Project root containing `results/`. Set the OTCOST_ROOT_DIR environment variable
# to run this notebook outside the original cluster; the default is unchanged.
ROOT_DIR = os.environ.get(
    'OTCOST_ROOT_DIR',
    '/gpfs/commons/groups/gursoy_lab/aelhussein/ot_cost/otcost_fl_rebase',
)
import pandas as pd
import sys
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap

# Learning-rate search only

In [2]:
def load_results(DATASET, grid, root_dir=None):
    """Load the pickled hyperparameter-search results for one dataset.

    Args:
        DATASET: dataset name; used both as the sub-directory under
            ``<root_dir>/results/`` and as the file-name prefix.
        grid: if truthy, load ``<DATASET>_hyperparameter_search_personal.pkl``
            (the personalised-algorithm grid search); otherwise load
            ``<DATASET>_hyperparameter_search.pkl``.
        root_dir: project root directory. ``None`` (the default) uses the
            notebook-level ``ROOT_DIR``.

    Returns:
        The unpickled object. ``bootstrap_ci`` expects it to be nested as
        ``results[cost][architecture][optimizer][lr] -> sequence of scores``.

    Note:
        ``pickle.load`` can execute arbitrary code; only load result files
        written by this project's own pipeline.
    """
    if root_dir is None:
        root_dir = ROOT_DIR
    # The two searches differ only by this file-name suffix.
    suffix = '_personal' if grid else ''
    path = f'{root_dir}/results/{DATASET}/{DATASET}_hyperparameter_search{suffix}.pkl'
    with open(path, 'rb') as f:
        return pickle.load(f)
    
def _bootstrap_score(res, n_resamples, rng):
    """Return ``(3 * mean + ci_low + ci_high) / 5`` for one sequence of run scores.

    Degenerate inputs are handled explicitly rather than producing NaN or raising:
    an empty sequence scores NaN (a NaN never wins the strict ``>`` comparison in
    ``best_parameters``), and a sequence whose values are all identical scores
    exactly its mean, since the bootstrap interval collapses onto the mean while
    SciPy's default BCa method returns NaN for zero-spread data.
    """
    values = np.asarray(res, dtype=float).ravel()
    if values.size == 0:
        return float('nan')
    mean = np.mean(values)
    if np.unique(values).size < 2:
        return mean
    ci = bootstrap(
        (values,), statistic=np.mean, n_resamples=n_resamples, random_state=rng
    ).confidence_interval
    low, high = ci.low, ci.high
    if not (np.isfinite(low) and np.isfinite(high)):
        # BCa can still fail on near-degenerate samples; fall back to the plain mean.
        return mean
    # Blend: the mean carries weight 3/5, each CI bound 1/5.
    return np.mean([mean, mean, mean, low, high])


def bootstrap_ci(data, n_resamples=1000, seed=0):
    """Score every hyperparameter configuration with a CI-blended mean.

    Args:
        data: nested dict ``data[cost][architecture][optimizer][lr]`` whose leaves
            are sequences of scores (one per run).
        n_resamples: number of bootstrap resamples per configuration.
        seed: seed for the single random generator shared by all bootstrap calls,
            so the scores, and therefore the chosen hyperparameters, are
            reproducible. Pass ``None`` for a non-deterministic run.

    Returns:
        A dict with the same nesting whose leaves are
        ``(3 * mean + ci_low + ci_high) / 5``. ``(ci_low, ci_high)`` is the
        ``scipy.stats.bootstrap`` confidence interval for the mean, computed with
        SciPy's default method and confidence level.
    """
    rng = np.random.default_rng(seed)
    estimates = {}
    for c, architectures in data.items():
        estimates[c] = {}
        for arch, optimizers in architectures.items():
            estimates[c][arch] = {}
            for optim, lrs in optimizers.items():
                estimates[c][arch][optim] = {
                    lr: _bootstrap_score(res, n_resamples, rng)
                    for lr, res in lrs.items()
                }
    return estimates

def best_parameters(results_estimates):
    """Pick the highest-scoring optimizer/learning-rate for each cost and architecture.

    Args:
        results_estimates: nested dict ``[cost][architecture][optimizer][lr] -> score``.

    Returns:
        ``{cost: {architecture: "optimizer: lr: score"}}``, with the score formatted
        to three decimals. On ties the first configuration in iteration order wins.
        NaN and ``-inf`` scores are never selected; an architecture with no
        selectable score maps to ``""``. A cost with no architectures is omitted.
    """
    summary = {}
    for cost, architectures in results_estimates.items():
        for architecture, optimizers in architectures.items():
            # `> -inf` discards NaN and -inf, matching a running max seeded at -inf.
            candidates = [
                (optimizer, lr, value)
                for optimizer, lrs in optimizers.items()
                for lr, value in lrs.items()
                if value > -float('inf')
            ]
            if candidates:
                optimizer, lr, value = max(candidates, key=lambda item: item[2])
                label = f"{optimizer}: {lr}: {value:.3f}"
            else:
                label = ""
            summary.setdefault(cost, {})[architecture] = label
    return summary

def process_results(DATASET, grid=False):
    """Load one dataset's search results, score them, and pick the winners.

    Returns a ``(results_estimates, best_hyperparams)`` tuple: the per-configuration
    scores from ``bootstrap_ci`` and the selections made by ``best_parameters``.
    ``grid`` is forwarded to ``load_results`` to choose which result file is read.
    """
    estimates = bootstrap_ci(load_results(DATASET, grid))
    return estimates, best_parameters(estimates)

### Synthetic

In [18]:
DATASET = 'Synthetic'
results_estimates, best_hyperparams = process_results(DATASET)
# Display the winners so they are kept in the saved output (the next cell overwrites these names).
best_hyperparams

### Credit

In [108]:
DATASET = 'Credit'
results_estimates, best_hyperparams = process_results(DATASET)
# Display the winners so they are kept in the saved output (the next cell overwrites these names).
best_hyperparams

### Weather

In [121]:
DATASET = 'Weather'
results_estimates, best_hyperparams = process_results(DATASET)
# Display the winners so they are kept in the saved output (the next cell overwrites these names).
best_hyperparams

### EMNIST

In [24]:
DATASET = 'EMNIST'
results_estimates, best_hyperparams = process_results(DATASET)
# Display the winners so they are kept in the saved output (the next cell overwrites these names).
best_hyperparams

### CIFAR

In [30]:
DATASET = 'CIFAR'
results_estimates, best_hyperparams = process_results(DATASET)
# Display the winners so they are kept in the saved output (the next cell overwrites these names).
best_hyperparams

### IXITiny

In [34]:
DATASET = 'IXITiny'
results_estimates, best_hyperparams = process_results(DATASET)
# Display the winners so they are kept in the saved output (the next cell overwrites these names).
best_hyperparams

### ISIC

In [5]:
DATASET = 'ISIC'
results_estimates, best_hyperparams = process_results(DATASET)
# Display the winners so they are kept in the saved output (the next cell overwrites these names).
best_hyperparams

# Grid search over learning rate and regularization parameter (Ditto, pFedMe)

In [4]:
# True -> the cells below read the *_hyperparameter_search_personal.pkl result files.
grid = True

### Synthetic

In [5]:
DATASET = 'Synthetic'
# `grid` makes load_results read the personalised (pFedMe / Ditto) grid-search file.
results_estimates, best_hyperparams = process_results(DATASET, grid=grid)
best_hyperparams

{0.03: {'pfedme': 'ADM: (0.05, 0.001): 0.682',
  'ditto': 'ADM: (0.1, 0.01): 0.679'},
 0.4: {'pfedme': 'ADM: (0.1, 0.001): 0.644',
  'ditto': 'ADM: (0.1, 0.01): 0.641'},
 0.1: {'pfedme': 'ADM: (0.1, 0.001): 0.644',
  'ditto': 'ADM: (0.1, 0.01): 0.626'}}

### Credit

In [6]:
DATASET = 'Credit'
# `grid` makes load_results read the personalised (pFedMe / Ditto) grid-search file.
results_estimates, best_hyperparams = process_results(DATASET, grid=grid)
best_hyperparams

{0.12: {'pfedme': 'ADM: (0.1, 0.01): 0.952', 'ditto': 'ADM: (0.1, 1): 0.941'},
 0.4: {'pfedme': 'ADM: (0.1, 0.01): 0.920',
  'ditto': 'ADM: (0.05, 0.001): 0.911'},
 0.23: {'pfedme': 'ADM: (0.1, 0.001): 0.932',
  'ditto': 'ADM: (0.1, 0.001): 0.919'}}

### Weather

In [7]:
DATASET = 'Weather'
# `grid` makes load_results read the personalised (pFedMe / Ditto) grid-search file.
results_estimates, best_hyperparams = process_results(DATASET, grid=grid)
best_hyperparams

{0.11: {'pfedme': 'ADM: (0.1, 0.1): 0.911', 'ditto': 'ADM: (0.1, 0.5): 0.907'},
 0.4: {'pfedme': 'ADM: (0.1, 0.01): 0.920', 'ditto': 'ADM: (0.1, 0.5): 0.912'},
 0.3: {'pfedme': 'ADM: (0.1, 0.1): 0.917', 'ditto': 'ADM: (0.1, 0.5): 0.916'}}

### EMNIST

In [8]:
DATASET = 'EMNIST'
# `grid` makes load_results read the personalised (pFedMe / Ditto) grid-search file.
results_estimates, best_hyperparams = process_results(DATASET, grid=grid)
best_hyperparams

{0.11: {'pfedme': 'ADM: (0.05, 0.1): 0.962',
  'ditto': 'ADM: (0.1, 0.5): 0.942'},
 0.25: {'pfedme': 'ADM: (0.05, 0.01): 0.801',
  'ditto': 'ADM: (0.1, 0.001): 0.746'},
 0.34: {'pfedme': 'ADM: (0.05, 0.01): 0.830',
  'ditto': 'ADM: (0.1, 0.01): 0.746'},
 0.39: {'pfedme': 'ADM: (0.05, 0.01): 0.877',
  'ditto': 'ADM: (0.1, 0.01): 0.814'},
 0.19: {'pfedme': 'ADM: (0.05, 0.1): 0.834', 'ditto': 'ADM: (0.1, 1): 0.793'}}

### CIFAR

In [9]:
DATASET = 'CIFAR'
# `grid` makes load_results read the personalised (pFedMe / Ditto) grid-search file.
results_estimates, best_hyperparams = process_results(DATASET, grid=grid)
best_hyperparams

{0.08: {'pfedme': 'ADM: (0.01, 0.5): 0.901',
  'ditto': 'ADM: (0.005, 1): 0.874'}}

### IXITiny

In [None]:
DATASET = 'IXITiny'
results_estimates, best_hyperparams = process_results(DATASET, grid)
# Display the winners, as the other grid-search cells do.
best_hyperparams

### ISIC

In [None]:
DATASET = 'ISIC'
results_estimates, best_hyperparams = process_results(DATASET, grid)
# Display the winners, as the other grid-search cells do.
best_hyperparams