In [1]:
import wandb
import yaml
from utilities.utils import correct_type_of_entry
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
from copy import deepcopy
import seaborn as sn
pd.set_option('display.max_rows', None)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Dataset whose baseline sweep is analysed; it drives which metrics are
# fetched and which sweep configuration file is read.
dataset = "amazon"

error_metrics = ['complement_error', 'validation_error', 'test_error']
loss_metrics = ['complement_loss', 'validation_loss', 'test_loss']

if "mnist" in dataset:
    # Plain "mnist" uses the pretraining sweep; every other name containing
    # "mnist" uses the default sweep.
    values_to_fetch = error_metrics + ['test_loss']
    sweep_config = "pretraining" if dataset == "mnist" else "default"
elif "amazon" in dataset:
    values_to_fetch = error_metrics + loss_metrics
    sweep_config = "baseline_transformer"
else:
    # Any other dataset only reports losses.
    values_to_fetch = list(loss_metrics)
    sweep_config = "forest"

In [7]:
# Read the sweep definition that produced the runs under analysis.
sweep_config_name = f"./configs/sweep_configs/{sweep_config}.yaml"
with open(sweep_config_name) as config_file:
    sweep_configuration = yaml.safe_load(config_file)

# Keep only the parameters that declare an explicit list of 'values';
# insertion order of the sweep file is preserved and later used as the axis
# order of the results matrix.
hps = {
    key: correct_type_of_entry(item['values'])
    for key, item in sweep_configuration['parameters'].items()
    if item.get('values') is not None
}
# Number of candidate values per swept hyperparameter, in the same order.
size_hyperparams = tuple(len(values) for values in hps.values())

In [8]:
# Weights & Biases location of the baseline runs for the selected dataset.
entity = "mathieu-bazinet"
project = f"baseline_{dataset}"

api = wandb.Api()
runs = api.runs(f"{entity}/{project}")

In [9]:
# One cell per (metric, hyperparameter combination). Axis 0 follows
# values_to_fetch; the remaining axes follow the order of hps.
# Cells that no run fills keep the initial value 1.0.
results_matrix = np.ones(((len(values_to_fetch),) + size_hyperparams))

for run in runs:
    # Skip 'forest'/'tree' runs unless that is exactly the sweep being analysed.
    # A run without a 'model_type' entry is treated as not being forest/tree.
    model_type = run.config.get('model_type')
    if model_type != sweep_config and model_type in ['forest', 'tree']:
        continue

    # The hyperparameter position does not depend on the metric, so it is
    # resolved once per run. Runs whose config lacks a swept key (KeyError) or
    # holds a value outside the sweep grid (ValueError) are ignored.
    try:
        hp_idx = tuple(hps[key].index(run.config[key]) for key in hps)
    except (KeyError, ValueError):
        continue

    for val_to_fetch_idx, val_to_fetch in enumerate(values_to_fetch):
        try:
            value = run.summary[val_to_fetch]
        except KeyError:
            # Fall back to the last value logged in the run history.
            history = [row[val_to_fetch] for row in run.scan_history(keys=[val_to_fetch])]
            if not history:
                # The metric was never logged for this run; keep the default.
                continue
            value = history[-1]
        results_matrix[(val_to_fetch_idx,) + hp_idx] = value

In [10]:
# Every hyperparameter except the first one in hps defines a configuration;
# enumerate all their combinations in row-major order.
swept_keys = list(hps)[1:]
hp_list = list(hps.values())[1:]
params_product = list(product(*hp_list))

# Human-readable label per combination: the values separated by single spaces.
name_list = [" ".join(str(p) for p in params) for params in params_product]

# Position of each combination inside the hyperparameter grid.
idx_list = [
    tuple(hps[key].index(value) for key, value in zip(swept_keys, params))
    for params in params_product
]

In [11]:
# Average every metric over axis 1 of results_matrix, i.e. over the first
# hyperparameter listed in hps (presumably the seed; confirm against the sweep
# config), then flatten the remaining hyperparameter axes so that each row is
# one configuration and each column one metric.
# reshape(..., -1) lets NumPy infer the number of configurations; it also
# works when no hyperparameter axis is left, where np.prod(()) would return
# the float 1.0 and be rejected by reshape.
reshaped_matrix = results_matrix.mean(1).reshape(results_matrix.shape[0], -1).T
mean_df = pd.DataFrame(reshaped_matrix, index=name_list, columns=values_to_fetch)
mean_df

Unnamed: 0,complement_error,validation_error,test_error,complement_loss,validation_loss,test_loss
2 0.1 1e-06,0.0,0.039356,0.049822,1e-05,0.370166,0.45731
2 0.1 1e-07,0.0,0.040722,0.051265,1e-05,0.417975,0.507113
2 0.1 1e-08,0.023393,0.032428,0.042477,0.071929,0.097224,0.122323
2 0.2 1e-06,0.0,0.039356,0.049822,1e-05,0.370166,0.45731
2 0.2 1e-07,0.0,0.040722,0.051265,1e-05,0.417975,0.507113
2 0.2 1e-08,0.023393,0.032428,0.042477,0.071929,0.097224,0.122323
5 0.1 1e-06,0.0,0.037265,0.049454,1e-05,0.331698,0.427432
5 0.1 1e-07,0.0,0.037638,0.050417,1.8e-05,0.359962,0.452604
5 0.1 1e-08,0.022476,0.032484,0.044137,0.069232,0.097731,0.130811
5 0.2 1e-06,0.0,0.037265,0.049454,1e-05,0.331698,0.427432


In [12]:
# Standard deviation of every metric over axis 1 of results_matrix (the first
# hyperparameter listed in hps), laid out as one row per configuration and
# one column per metric.
# reshape(..., -1) lets NumPy infer the number of configurations; it also
# works when no hyperparameter axis is left, where np.prod(()) would return
# the float 1.0 and be rejected by reshape.
reshaped_std = results_matrix.std(1).reshape(results_matrix.shape[0], -1).T
std_df = pd.DataFrame(reshaped_std, index=name_list, columns=values_to_fetch)

In [15]:
# Metric whose mean decides which hyperparameter configuration is the best.
if "mnist" in dataset or "amazon" in dataset:
    val_of_interest = "validation_error"
else:
    val_of_interest = "validation_loss"

# For dataset names that contain "mnist" but are not exactly "mnist", only
# the configurations whose label contains the wanted model type compete.
if "mnist" in dataset and "mnist" != dataset:
    wanted_model_type = "cnn"
    model_df = mean_df[[wanted_model_type in idx for idx in mean_df.index]]
    std_model_df = std_df[[wanted_model_type in idx for idx in std_df.index]]
    candidate_mean_df, candidate_std_df = model_df, std_model_df
else:
    candidate_mean_df, candidate_std_df = mean_df, std_df

# Select the configuration with the lowest mean of the metric of interest and
# read BOTH its means and its standard deviations from that same row, so the
# reported "mean±std" pairs describe a single configuration.
best_name = candidate_mean_df.index[candidate_mean_df[val_of_interest].argmin()]
best_params = correct_type_of_entry(best_name.split())
print(best_params)
best_val_arr = candidate_mean_df.loc[best_name]
std_val_arr = candidate_std_df.loc[best_name]

# Report "mean±std" for each fetched metric of the selected configuration.
# Only double-quoted f-strings are used so the cell also parses on Python
# versions older than 3.12, which reject reusing the outer quote inside a
# replacement field.
if "mnist" in dataset:
    print(
        f"Complement error: {best_val_arr['complement_error']:.4f}±{std_val_arr['complement_error']:.4f} ",
        f"Validation error: {best_val_arr['validation_error']:.4f}±{std_val_arr['validation_error']:.4f} ",
        f"Test error : {best_val_arr['test_error']:.4f}±{std_val_arr['test_error']:.4f} ",
        f"Test loss : {best_val_arr['test_loss']:.4f}±{std_val_arr['test_loss']:.4f}"
    )
elif "amazon" in dataset:
    print(
        f"Complement error: {best_val_arr['complement_error']:.4f}±{std_val_arr['complement_error']:.4f} ",
        f"Validation error: {best_val_arr['validation_error']:.4f}±{std_val_arr['validation_error']:.4f} ",
        f"Test error : {best_val_arr['test_error']:.4f}±{std_val_arr['test_error']:.4f} ",
        f"Complement loss: {best_val_arr['complement_loss']:.4f}±{std_val_arr['complement_loss']:.4f} ",
        f"Validation loss: {best_val_arr['validation_loss']:.4f}±{std_val_arr['validation_loss']:.4f} ",
        f"Test loss : {best_val_arr['test_loss']:.4f}±{std_val_arr['test_loss']:.4f}"
    )
else:
    print(
        f"Complement loss: {best_val_arr['complement_loss']:.4f}±{std_val_arr['complement_loss']:.4f} ",
        f"Validation loss: {best_val_arr['validation_loss']:.4f}±{std_val_arr['validation_loss']:.4f} ",
        f"Test loss : {best_val_arr['test_loss']:.4f}±{std_val_arr['test_loss']:.4f} "
    )

[2.0, 0.1, 1e-08]
Complement error: 0.0234±0.0079  Validation error: 0.0324±0.0082  Test error : 0.0425±0.0006  Complement error: 0.0719±0.0182  Validation error: 0.0972±0.0178  Test loss : 0.1223±0.0067
