In [1]:
import wandb
import yaml
from utilities.utils import correct_type_of_entry
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sn
pd.set_option('display.max_rows', None)

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
dataset = "amazon"
sweep_config = "transformer"
values_to_fetch = ['kl_bound', 'binomial_approximation_shah', 'compression_set_size',
                   'complement_error', 'validation_error', 'test_error', 'CE_kl_bound', "CE_min_val_catoni_bound",
                   'complement_loss', 'validation_loss', 'test_loss', '_runtime']

sweep_config_name = "./configs/sweep_configs/" + sweep_config + ".yaml"
with open(sweep_config_name) as file:
    sweep_configuration = yaml.safe_load(file)
    
hps = {}
for key, item in sweep_configuration['parameters'].items():
    if item.get('values', None) is not None:
        hps[key] = correct_type_of_entry(item['values'])

size_hyperparams = tuple([len(l) for l in hps.values()])

In [16]:
api = wandb.Api()
entity = "mathieu-bazinet"
project = "sweep" + dataset
runs = api.runs(entity + "/" + project)

new_runs = []
for run in runs:
    if run.createdAt[:-4] > '2024-10-03T3:43':
        new_runs.append(run)
runs = new_runs

In [17]:
results_matrix = np.ones(((len(values_to_fetch),) + size_hyperparams))

for run_idx in range(len(runs)):
    run = runs[run_idx]
    if run.state == "finished":
        for val_to_fetch_idx in range(len(values_to_fetch)):
            matrix_idx = tuple([val_to_fetch_idx] + [hps[key].index(run.config[key]) for key in hps.keys()])
            val_to_fetch = values_to_fetch[val_to_fetch_idx]
            results_matrix[matrix_idx] = run.summary[val_to_fetch]
from itertools import product

hp_list = list(hps.values())[1:]
params_product = list(product(*hp_list))
name_list = []
idx_list = []
for params in params_product:
    name = ""
    for p in params:
        name += str(p) + " "
    name_list.append(name[:-1])
    idx = ()
    for p_idx in range(len(params)):
        p_key = list(hps.keys())[1:][p_idx]
        idx += (hps[p_key].index(params[p_idx]),)
    idx_list.append(tuple(idx))

In [18]:
reshaped_matrix = results_matrix.mean(1).reshape(results_matrix.shape[0],np.prod(results_matrix.shape[2:])).T

mean_df = pd.DataFrame(reshaped_matrix, index=name_list, columns=values_to_fetch)
mean_df

Unnamed: 0,kl_bound,binomial_approximation_shah,compression_set_size,complement_error,validation_error,test_error,CE_kl_bound,CE_min_val_catoni_bound,complement_loss,validation_loss,test_loss,_runtime
2 1e-06,0.736923,0.836869,531.4,0.645119,0.645545,0.644008,1.547068,1.614458,0.82631,0.827247,0.825272,8702.406852
2 1e-07,0.345396,0.436467,589.0,0.27663,0.279902,0.281002,0.993035,1.024934,0.431583,0.438555,0.439412,11560.148614
2 1e-08,0.139132,0.21847,1132.8,0.047288,0.054099,0.05598,0.859392,0.863772,0.119909,0.145947,0.147787,39436.586806
5 1e-06,0.514199,0.675601,672.0,0.39546,0.395631,0.397512,1.590304,1.691864,0.715004,0.719531,0.720913,12929.432608
5 1e-07,0.511449,0.640656,825.6,0.39965,0.402755,0.402114,1.619324,1.73076,0.66924,0.681019,0.682015,18903.68906
5 1e-08,0.156006,0.251868,1017.6,0.061651,0.068105,0.069413,0.914293,0.915383,0.168024,0.206368,0.208886,55197.383638


In [19]:
std_matrix = results_matrix.std(1).reshape(results_matrix.shape[0],np.prod(results_matrix.shape[2:])).T
std_df = pd.DataFrame(std_matrix, index=name_list, columns=values_to_fetch)
std_df

Unnamed: 0,kl_bound,binomial_approximation_shah,compression_set_size,complement_error,validation_error,test_error,CE_kl_bound,CE_min_val_catoni_bound,complement_loss,validation_loss,test_loss,_runtime
2 1e-06,0.15569,0.10015,265.489435,0.198209,0.197932,0.198375,0.280313,0.312035,0.097817,0.097381,0.098156,4387.949629
2 1e-07,0.33772,0.304975,306.282223,0.367991,0.36627,0.365711,0.189353,0.212791,0.297887,0.293945,0.293671,6884.191858
2 1e-08,0.027295,0.032868,337.808466,0.010945,0.010527,0.011932,0.162235,0.165393,0.011785,0.016286,0.018172,11994.913627
5 1e-06,0.115288,0.096339,0.0,0.113483,0.114815,0.113019,0.261117,0.23028,0.189021,0.190573,0.188458,1030.288578
5 1e-07,0.249237,0.208854,86.813824,0.259915,0.256227,0.256036,0.366457,0.382403,0.232898,0.226218,0.225358,1724.269451
5 1e-08,0.022496,0.031276,146.502696,0.01176,0.012467,0.012326,0.074488,0.073921,0.017683,0.018254,0.019893,11909.283575


In [20]:
# index = 0
# mean_arr = results_matrix.mean(1)[index]
# df = pd.DataFrame(mean_arr, index=hps['dropout_probability'], columns=hps['training_lr'])
# sn.heatmap(df, annot=True, cmap="Blues")
# plt.ylabel("Dropout probability")
# plt.xlabel("Training LR")
# # fig.suptitle(f"Comparison of the {values_to_fetch[index].replace("_", " ")} for the dataset {dataset}")
# # # Layout so plots do not overlap
# # fig.tight_layout()
# # fig.align_labels()

# plt.title(f"Comparison of the {values_to_fetch[index].replace("_", " ")} for the dataset {dataset}")
# plt.savefig(f"./results/amazon/{dataset}_{values_to_fetch[index]}_heatmaps.jpg",bbox_inches='tight')

In [23]:
val_of_interest = "kl_bound"
best_params = correct_type_of_entry(mean_df.index[mean_df[val_of_interest].argmin()].split())
print(best_params)
best_val_arr = mean_df.loc[mean_df.index[mean_df[val_of_interest].argmin()]]
std_val_arr = std_df.loc[mean_df.index[mean_df[val_of_interest].argmin()]]
print(f"complement error: {best_val_arr['complement_error']:.4f}±{std_val_arr['complement_error']:.4f} ",
    f"Validation error: {best_val_arr['validation_error']:.4f}±{std_val_arr['validation_error']:.4f} ",
      f'Test error : {best_val_arr['test_error']:.4f}±{std_val_arr['test_error']:.4f} ', 
      f"KL bound : {best_val_arr['kl_bound']:.4f}±{std_val_arr['kl_bound']:.4f} ",
    f'binomial : {best_val_arr['binomial_approximation_shah']:.4f}±{std_val_arr['binomial_approximation_shah']:.4f} ',
      f'compression set size : {best_val_arr['compression_set_size']:.4f}±{std_val_arr['compression_set_size']:.4f} ',
f"Complement loss: {best_val_arr['complement_loss']:.4f}±{std_val_arr['complement_loss']:.4f} "
    f"Validation loss: {best_val_arr['validation_loss']:.4f}±{std_val_arr['validation_loss']:.4f} "
      f'Test loss : {best_val_arr['test_loss']:.4f}±{std_val_arr['test_loss']:.4f} ', 
    f"CE KL bound: {best_val_arr['CE_kl_bound']:.4f}±{std_val_arr['CE_kl_bound']:.4f} "
      f'CE Catoni bound : {best_val_arr['CE_min_val_catoni_bound']:.4f}±{std_val_arr['CE_min_val_catoni_bound']:.4f} ',
    f"Runtime : {best_val_arr['_runtime']:.4f}±{std_val_arr['_runtime']:.4f}"
)

[2.0, 1e-08]
complement error: 0.0473±0.0109  Validation error: 0.0541±0.0105  Test error : 0.0560±0.0119  KL bound : 0.1391±0.0273  binomial : 0.2185±0.0329  compression set size : 1132.8000±337.8085  Complement loss: 0.1199±0.0118 Validation loss: 0.1459±0.0163 Test loss : 0.1478±0.0182  CE KL bound: 0.8594±0.1622 CE Catoni bound : 0.8638±0.1654  Runtime : 39436.5868±11994.9136
