In [1]:
import wandb
import yaml
from utilities.utils import correct_type_of_entry
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
from copy import deepcopy
import seaborn as sn
# Show every row of the result tables below (no pandas truncation).
pd.options.display.max_rows = None

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset under study. It decides which metrics are pulled from W&B and which
# sweep configuration file describes the hyperparameter grid.
dataset = "amazon"

if "mnist" in dataset:
    # Plain "mnist" was swept with the pretraining config; other MNIST variants use the default one.
    sweep_config = "pretraining" if dataset == "mnist" else "default"
    values_to_fetch = ['complement_error', 'validation_error', 'test_error', 'test_loss']
elif "amazon" in dataset:
    sweep_config = "baseline_transformer"
    values_to_fetch = [
        'complement_error', 'validation_error', 'test_error',
        'complement_loss', 'validation_loss', 'test_loss',
    ]
else:
    # Any other dataset uses the forest sweep and only reports losses.
    sweep_config = "forest"
    values_to_fetch = ['complement_loss', 'validation_loss', 'test_loss']

In [3]:
# Load the sweep definition and keep every parameter that is swept over an
# explicit list of `values`. The tuple of list lengths is the shape of the
# hyperparameter part of the results grid.
sweep_config_name = f"./configs/sweep_configs/{sweep_config}.yaml"
with open(sweep_config_name) as config_file:
    sweep_configuration = yaml.safe_load(config_file)

hps = {
    name: correct_type_of_entry(spec['values'])
    for name, spec in sweep_configuration['parameters'].items()
    if spec.get('values') is not None
}
size_hyperparams = tuple(map(len, hps.values()))

In [4]:
# Public W&B API handle and every run logged in this dataset's baseline project.
api = wandb.Api()
entity = "mathieu-bazinet"
project = f"baseline_{dataset}"
runs = api.runs(f"{entity}/{project}")

In [5]:
# Build a (metric, *hyperparameter axes) array holding the final value of each
# metric for every W&B run whose config lies on the sweep grid.
# Cells with no matching run stay NaN instead of a fake 1.0, so missing runs
# cannot silently bias the means/stds computed later.
# If several runs share one grid point, the last one iterated wins.
results_matrix = np.full((len(values_to_fetch),) + size_hyperparams, np.nan)

hp_keys = list(hps.keys())
for run in runs:
    model_type = run.config['model_type']
    # Keep runs of the swept model type, plus any model that is not forest/tree.
    if model_type != sweep_config and model_type in ['forest', 'tree']:
        continue
    try:
        # Position of this run on the hyperparameter grid; identical for every metric.
        grid_idx = tuple(hps[key].index(run.config[key]) for key in hp_keys)
    except (KeyError, ValueError):
        # The run lacks a swept key or used a value outside the sweep grid.
        continue
    for val_to_fetch_idx, val_to_fetch in enumerate(values_to_fetch):
        try:
            value = run.summary[val_to_fetch]
        except KeyError:
            # Not in the run summary: fall back to the last value logged in the history.
            history = [row[val_to_fetch] for row in run.scan_history(keys=[val_to_fetch])]
            if not history:
                continue  # metric never logged for this run; leave the cell NaN
            value = history[-1]
        results_matrix[(val_to_fetch_idx,) + grid_idx] = value

In [6]:
# Enumerate every combination of the hyperparameters except the first one, in
# the same C order that flattening `results_matrix` uses. Each combination gets
# a display name (space-joined values) and its index tuple into those axes.
# NOTE(review): assumes the first key of `hps` is the axis averaged over in the
# tables below (presumably the seed); confirm against the sweep config.
hp_list = list(hps.values())[1:]
params_product = list(product(*hp_list))
name_list = [" ".join(str(p) for p in params) for params in params_product]
# product() over index ranges follows the same ordering as product() over the
# values, and stays exact even if a value list contains duplicates.
idx_list = list(product(*(range(len(values)) for values in hp_list)))

In [7]:
# Average each metric over hyperparameter axis 1 (the first swept parameter),
# then flatten the remaining hyperparameter axes into one row per configuration,
# in the same C order as `name_list`.
# `-1` lets numpy infer the row count. np.prod(()) would return the float 1.0
# and break reshape when only one hyperparameter is swept.
# np.mean propagates NaN: a configuration with any NaN entry shows NaN here.
reshaped_matrix = results_matrix.mean(axis=1).reshape(results_matrix.shape[0], -1).T
mean_df = pd.DataFrame(reshaped_matrix, index=name_list, columns=values_to_fetch)
mean_df

Unnamed: 0,complement_error,validation_error,test_error,complement_loss,validation_loss,test_loss
2 1e-06,0.0,0.048798,0.050146,2.7e-05,0.41288,0.419674
2 1e-07,0.0,0.049631,0.051344,3.1e-05,0.429045,0.435943
2 1e-08,0.03064,0.040807,0.041851,0.090995,0.115626,0.116354
5 1e-06,0.0,0.050011,0.050391,1.3e-05,0.453767,0.451666
5 1e-07,0.0,0.050605,0.051422,1e-05,0.485091,0.484729
5 1e-08,0.031271,0.042525,0.043428,0.093233,0.120747,0.121375


In [8]:
# Standard deviation of each metric over hyperparameter axis 1, laid out like
# `mean_df` (one row per configuration, same order as `name_list`).
# np.std defaults to ddof=0, i.e. the population standard deviation.
# `-1` lets numpy infer the row count, which also handles a single swept parameter.
reshaped_std = results_matrix.std(axis=1).reshape(results_matrix.shape[0], -1).T
std_df = pd.DataFrame(reshaped_std, index=name_list, columns=values_to_fetch)
std_df

Unnamed: 0,complement_error,validation_error,test_error,complement_loss,validation_loss,test_loss
2 1e-06,0.0,0.001013,0.00103,1.1e-05,0.021279,0.016675
2 1e-07,0.0,0.000969,0.000794,1.3e-05,0.034926,0.028359
2 1e-08,0.000284,0.000448,0.000137,0.000677,0.001411,0.000562
5 1e-06,0.0,0.000445,0.000503,7e-06,0.047466,0.040548
5 1e-07,0.0,0.000869,0.000585,5e-06,0.042048,0.036284
5 1e-08,0.000833,0.000345,0.000152,0.000909,0.001644,0.000284


In [9]:
# Model selection: pick the configuration with the lowest mean `val_of_interest`,
# then report mean±std of every fetched metric at that configuration.
if "mnist" in dataset or "amazon" in dataset:
    val_of_interest = "validation_error"
else:
    val_of_interest = "validation_loss"

if "mnist" in dataset and "mnist" != dataset:
    # Restrict the search to rows whose label (space-joined hyperparameter
    # values) contains the wanted model type as a substring.
    wanted_model_type = "mlp"
    model_df = mean_df[[wanted_model_type in idx for idx in mean_df.index]]
    std_model_df = std_df[[wanted_model_type in idx for idx in std_df.index]]
    selection_mean_df, selection_std_df = model_df, std_model_df
else:
    selection_mean_df, selection_std_df = mean_df, std_df

# Row label of the best configuration, computed once and reused for both lookups.
best_label = selection_mean_df[val_of_interest].idxmin()
best_params = correct_type_of_entry(best_label.split())
print(best_params)
best_val_arr = selection_mean_df.loc[best_label]
std_val_arr = selection_std_df.loc[best_label]


def format_metric(label, key, mean_row, std_row, as_percent):
    """Return "<label> mean±std" for metric `key`.

    Errors (`as_percent=True`) are scaled by 100 and shown with 2 decimals;
    losses are shown unscaled with 4 decimals.
    """
    scale, digits = (100, 2) if as_percent else (1, 4)
    mean, std = scale * mean_row[key], scale * std_row[key]
    # Keys are passed as variables, so no nested quotes: valid on Python < 3.12 too.
    return f"{label} {mean:.{digits}f}±{std:.{digits}f}"


error_metrics = [
    ("Complement error:", "complement_error", True),
    ("Validation error:", "validation_error", True),
    ("Test error :", "test_error", True),
]
loss_metrics = [
    ("Complement loss:", "complement_loss", False),
    ("Validation loss:", "validation_loss", False),
    ("Test loss :", "test_loss", False),
]
if "mnist" in dataset:
    report = error_metrics + [("Test loss :", "test_loss", False)]
elif "amazon" in dataset:
    report = error_metrics + loss_metrics
else:
    report = loss_metrics

# Metrics are separated by two spaces on a single line.
print("  ".join(
    format_metric(label, key, best_val_arr, std_val_arr, as_percent)
    for label, key, as_percent in report
))

[2.0, 1e-08]
Complement error: 3.06±0.03  Validation error: 4.08±0.04  Test error : 4.19±0.01  Complement loss: 0.0910±0.0007  Validation loss: 0.1156±0.0014  Test loss : 0.1164±0.0006
