# Comparison Plots of Hyperparameter Tests

In [None]:
# VARIABLES TO CHANGE
# Directory that CONTAINS the hp_tests folder (not hp_tests itself): the
# prediction CSVs are read from {filepath}/hp_tests/hp_v<version>/.
filepath = "."

In [None]:
# imports
import h5py
import xarray as xr
from xhistogram.xarray import histogram
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.units as munits
from matplotlib.offsetbox import AnchoredText

In [None]:
def plot_all_diff_interval(hp_list, hyperparam, ylow=-0.06, yhigh=0.08, savefig=False):
    '''Plot a grid of level-wise mean/std intervals of prediction differences.

    One subplot is drawn per entry of ``hp_list``. For each entry the file
    ``{filepath}/hp_tests/hp_v{version}/preds_v{version}_-1times.csv`` is read
    (``filepath`` is the module-level variable), the difference
    ``my_pred - og_pred`` is grouped by ``lev``, and the per-level mean is
    plotted with the per-level standard deviation as error bars.

    Parameters
    ----------
    hp_list : list
        Non-empty list of dictionaries, one per model. Each dictionary must
        contain the key ``'version'`` (used to locate the CSV) and the key
        named by ``hyperparam`` (shown in the subplot title).
    hyperparam : str
        Name of the hyperparameter being varied, e.g. ``'Batch Size'`` or
        ``'Learning Rate'``. Used as a dictionary key, in the titles and in
        the output file name.
    ylow : float
        Minimum y-axis value (default -0.06)
    yhigh : float
        Maximum y-axis value (default 0.08)
    savefig : bool
        If True, save the figure to ``hp_plots/all_diff_vary_{hyperparam}.png``
        (the ``hp_plots`` directory is created if missing).

    Returns
    -------
    N/A

    Raises
    ------
    ValueError
        If ``hp_list`` is empty.
    '''
    import os  # local import: only needed to create the output directory

    if not hp_list:
        raise ValueError("hp_list must contain at least one model configuration")

    # NOTE: plt.rc changes matplotlib's global defaults, so these sizes also
    # affect any figure created after this call.
    plt.rc('font', size=20) # controls default text sizes
    plt.rc('axes', titlesize=22)
    plt.rc('axes', labelsize=22)
    plt.rc('xtick', labelsize=14)
    plt.rc('ytick', labelsize=14)
    plt.rc('legend', fontsize=16)
    path = "{filepath}/hp_tests/hp_v{vers}/preds_v{vers}_-1times.csv"

    # Round the row count UP so an odd number of configurations (or a single
    # one) still gets an axis each; squeeze=False keeps `axes` a 2-D array
    # regardless of the grid shape.
    ncols = 2
    nrows = -(-len(hp_list) // ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    fig.set_size_inches(16, 6 * nrows) # width, height
    fig.subplots_adjust(wspace=0.4)
    fig.subplots_adjust(hspace=0.4)
    axes = axes.flatten()

    for ax, model in zip(axes, hp_list):
        df = pd.read_csv(path.format(filepath=filepath, vers=model['version']))
        errors = pd.DataFrame({
            'my-og': df['my_pred'] - df['og_pred'],
            'lev': df['lev']
        })

        # Single groupby yields mean and std on the same (sorted) level index,
        # so the two series are guaranteed to line up.
        stats = errors.groupby('lev')['my-og'].agg(['mean', 'std'])
        ax.errorbar(stats.index, stats['mean'], yerr=stats['std'], fmt='-o', capsize=2)
        ax.axhline(y=0, color='r', linestyle='-')
        ax.set_ylim(ylow, yhigh)

        ax.set_title(f"{hyperparam}: {model[hyperparam]}")
        ax.set_xlabel('Level')
        ax.set_ylabel('Prediction Differences\n(New - Wnet-prior)')
        ax.grid(True)

    # Hide the spare axis left over when len(hp_list) is odd.
    for ax in axes[len(hp_list):]:
        ax.set_visible(False)

    fig.suptitle(rf"Mean $\pm$ Std of Prediction Differences at Varying {hyperparam}")

    fig.show()

    if savefig:
        # savefig does not create missing directories on its own.
        os.makedirs("hp_plots", exist_ok=True)
        fig.savefig(f"hp_plots/all_diff_vary_{hyperparam}.png")

In [None]:
# Batch-size sweep: the learning rate is fixed at 0.0001 and only the batch
# size changes. Each pair below is (batch size, model version to plot).
hp_batch = [
    {'Learning Rate': 0.0001, 'Batch Size': batch_size, 'version': version}
    for batch_size, version in (
        (1024, 2),
        (2048, 5),
        (4096, 9),   # try 8 or 9
        (8192, 10),  # try 10 or 11
    )
]
plot_all_diff_interval(hp_batch, 'Batch Size', savefig=True)

In [None]:
# Learning-rate sweep: the batch size is fixed at 2048 and only the learning
# rate changes. Each pair below is (learning rate, model version to plot).
hp_lr = [
    {'Learning Rate': learning_rate, 'Batch Size': 2048, 'version': version}
    for learning_rate, version in (
        (0.000105, 14),
        (0.0001, 17),
        (0.000095, 20),
        (0.00009, 25),
        (0.000085, 27),  # 27, 28
        (0.00008, 33),   # 30, 32, 33
    )
]
plot_all_diff_interval(hp_lr, 'Learning Rate', savefig=True)