# Show combined boxplots of pruned and unpruned NH-AMPC results

## Imports

In [None]:
import os 
import pandas as pd
import numpy as np

from typing import Optional
from bokeh.io import output_notebook, show
from bokeh.models import Range1d
from bokeh.palettes import Category10_10, Category20_20
# Configure Bokeh so that `show(...)` renders figures inline in this notebook.
output_notebook()

In [None]:
from src.utils import load_results
from src.bokeh_saving import save_figures_button
from src.mpc_dataclass import AMPC_data
from src.plotting import boxplot, get_figure_size, scatter, histogram_pdf

## Settings

In [None]:
# Which result files to use from the Results folder: only files whose name starts with
# 'NH_AMPC_results_' + one of these prefixes are loaded,
# e.g. NH_AMPC_results_ASRTID_... -> 'ASRTID_'.
FILE_START_ADD: list[str] = ['ASRTID_']

# Keep only the USE_TOP_N lowest-cost results per network-size group (None = keep all), e.g. 5.
USE_TOP_N: Optional[int] = None

# Upper bound applied to the 'Cost' column before plotting (None = no clipping).
CLIP_COST: Optional[float] = None

# (column, value): keep only rows whose `column` equals `value` (None = no filter),
# e.g. ('N_NN', 17). The value is compared with `==` against integer size columns,
# so it is an int, not a str.
DF_FILTER: Optional[tuple[str, int]] = ('N_NN', 22)

# Passed as `latex_style` to the plotting helpers; also selects get_figure_size()
# instead of the fixed (900, 600) figure size.
USE_LATEX_STYLE: bool = True

In [None]:
# Every input and output location lives below the absolute ./Results directory.
RESULTS_DIR = os.path.abspath('Results')

# Inputs: folder with the NH-AMPC result files and the R2-score files of the
# original and the pruned networks.
NH_AMPC_RESULTS_DIR, ORIG_R2_SCORES_PATH, PRUN_R2_SCORES_PATH = (
    os.path.join(RESULTS_DIR, entry)
    for entry in ('NH_AMPC_results_prun', 'OriginalR2scores.pkl', 'PrunedR2scores.pkl')
)

# Outputs: target folders for the exported figures.
SVG_RESULTS_DIR, PNG_RESULTS_DIR = (
    os.path.join(RESULTS_DIR, sub_dir) for sub_dir in ('SVGs', 'PNGs')
)

In [None]:
# Figure size shared by all plots: taken from the LaTeX layout helper when
# USE_LATEX_STYLE is set, otherwise a fixed (900, 600).
# NOTE(review): the name says 0.8 but fraction=1.0 is passed — confirm which is intended.
if USE_LATEX_STYLE:
    FIGURE_SIZE_0_8 = get_figure_size(fraction=1.0)
else:
    FIGURE_SIZE_0_8 = (900, 600)

## Data Extraction

### R2 Scores

In [None]:

# R2 scores of the original (unpruned) networks, re-indexed by network configuration
# and version so they line up with the NH-AMPC result index used for the join.
orig_r2_scores = (
    load_results(ORIG_R2_SCORES_PATH)
    .reset_index()
    .set_index(['N_NN', 'N_hidden', 'N_hidden_end', 'Version'])
    .sort_index()
)
orig_r2_scores.head()

In [None]:
# R2 scores of the pruned networks, used exactly as loaded.
# NOTE(review): the inner join with the pruned NH-AMPC results assumes this is already
# indexed by ['N_NN', 'N_hidden', 'N_hidden_end', 'Version'] — confirm.
prun_r2_scores = load_results(PRUN_R2_SCORES_PATH)
prun_r2_scores.head()

### NH-AMPC results

#### Original NH-AMPC results

In [None]:
NH_AMPC_FILE_START = [f'NH_AMPC_results_{fs_add}' for fs_add in FILE_START_ADD]

# Load every original (unpruned) NH-AMPC result file: the name starts with one of the
# configured prefixes, does NOT contain 'prun' and ends in '.ph'.
# The directory listing is sorted so files are always read in the same order.
orig_NH_AMPC_results = []
for file in sorted(os.listdir(NH_AMPC_RESULTS_DIR)):
    if (
        not file.startswith(tuple(NH_AMPC_FILE_START))
        or 'prun' in file
        or not file.endswith('.ph')
    ):
        continue
    results = AMPC_data.load(os.path.join(NH_AMPC_RESULTS_DIR, file))

    orig_NH_AMPC_results.append({
        'N_NN': results.P.N_NN,
        'N_hidden': results.P.N_hidden,
        # Original networks are not pruned: their final hidden size is recorded as N_hidden.
        'N_hidden_end': results.P.N_hidden,
        'Version': results.P.V_NN,
        'Cost': results.Cost,
        # Scaled by 1e3; the time plots label these values in 'ms'.
        'Mean_Time': np.mean(results.Time) * 1e3,
        'Median_Time': np.median(results.Time) * 1e3,
    })

# Without any matching file, set_index below would fail with an unhelpful KeyError.
if not orig_NH_AMPC_results:
    raise FileNotFoundError(
        f'No original NH-AMPC result files (prefix in {NH_AMPC_FILE_START}, without "prun", '
        f'ending in ".ph") found in {NH_AMPC_RESULTS_DIR}'
    )

orig_NH_AMPC_results = (
    pd.DataFrame(orig_NH_AMPC_results)
    .set_index(['N_NN', 'N_hidden', 'N_hidden_end', 'Version'])
    .sort_index()
)

In [None]:
# Find results that failed (are NaN) and drop them
dropped_orig = orig_NH_AMPC_results[orig_NH_AMPC_results.isna().any(axis=1)].reset_index()
orig_NH_AMPC_results.dropna(axis=0, inplace=True)
print(f'Dropped results:\n{dropped_orig}')

orig_results = orig_NH_AMPC_results.join(orig_r2_scores, how='inner')

#### Pruned NH-AMPC results

In [None]:
# Load every pruned NH-AMPC result file: the name starts with one of the configured
# prefixes (NH_AMPC_FILE_START), contains 'prun' and ends in '.ph'.
# The directory listing is sorted so files are always read in the same order.
prun_NH_AMPC_results = []
for file in sorted(os.listdir(NH_AMPC_RESULTS_DIR)):
    if (
        not file.startswith(tuple(NH_AMPC_FILE_START))
        or 'prun' not in file
        or not file.endswith('.ph')
    ):
        continue
    results = AMPC_data.load(os.path.join(NH_AMPC_RESULTS_DIR, file))

    prun_NH_AMPC_results.append({
        'N_NN': results.P.N_NN,
        'N_hidden': results.P.N_hidden,
        'N_hidden_end': results.P.N_hidden_end,
        'Version': results.P.V_NN,
        'Cost': results.Cost,
        # Scaled by 1e3; the time plots label these values in 'ms'.
        'Mean_Time': np.mean(results.Time) * 1e3,
        'Median_Time': np.median(results.Time) * 1e3,
    })

# Without any matching file, set_index below would fail with an unhelpful KeyError.
if not prun_NH_AMPC_results:
    raise FileNotFoundError(
        f'No pruned NH-AMPC result files (prefix in {NH_AMPC_FILE_START}, containing "prun", '
        f'ending in ".ph") found in {NH_AMPC_RESULTS_DIR}'
    )

prun_NH_AMPC_results = (
    pd.DataFrame(prun_NH_AMPC_results)
    .set_index(['N_NN', 'N_hidden', 'N_hidden_end', 'Version'])
    .sort_index()
)

In [None]:
# Runs with any NaN entry count as failed: report them, then discard them.
failed_prun = prun_NH_AMPC_results.isna().any(axis=1)
dropped_prun = prun_NH_AMPC_results.loc[failed_prun].reset_index()
prun_NH_AMPC_results = prun_NH_AMPC_results.dropna(axis=0)
print(f'Dropped results:\n{dropped_prun}')

# Attach the R2 scores; only configurations present in both tables survive (inner join).
prun_results = prun_NH_AMPC_results.join(prun_r2_scores, how='inner')

### Keep only the `USE_TOP_N` lowest-cost samples

In [None]:
def _keep_top_n_by_cost(results: pd.DataFrame, n: int) -> pd.DataFrame:
    """Return the rows of *results* holding the *n* lowest 'Cost' values per configuration.

    Rows are grouped by the 'N_NN', 'N_hidden' and 'N_hidden_end' index levels, so each
    pruning level (N_hidden_end) of a pruned network gets its own top-n instead of
    competing with the other levels. For the original networks N_hidden_end equals
    N_hidden, so this matches grouping by ('N_NN', 'N_hidden').
    A stable sort breaks ties by original row order (like ``nsmallest(keep='first')``),
    and the original row order of *results* is preserved in the output.
    """
    group_levels = ['N_NN', 'N_hidden', 'N_hidden_end']
    keep_idx = (
        results.sort_values('Cost', kind='stable')
        .groupby(level=group_levels, sort=False)
        .head(n)
        .index
    )
    return results[results.index.isin(keep_idx)]


if USE_TOP_N is not None:
    orig_results = _keep_top_n_by_cost(orig_results, USE_TOP_N)
    prun_results = _keep_top_n_by_cost(prun_results, USE_TOP_N)

### Clip cost to `CLIP_COST`

In [None]:
# Cap 'Cost' at CLIP_COST (NaN stays NaN). Series.clip plus reassignment replaces the
# in-place `.loc[...] =` write, which warns (SettingWithCopyWarning) when the frame is
# the result of an earlier boolean selection such as the top-N filter.
if CLIP_COST is not None:
    orig_results = orig_results.assign(Cost=orig_results['Cost'].clip(upper=CLIP_COST))
    prun_results = prun_results.assign(Cost=prun_results['Cost'].clip(upper=CLIP_COST))

### Reset DataFrame indexes and apply `DF_FILTER`

In [None]:
# Flatten the MultiIndex into ordinary columns, then apply DF_FILTER if one is set.
orig_df = orig_results.reset_index()
if DF_FILTER is not None:
    orig_df = orig_df.loc[orig_df[DF_FILTER[0]].eq(DF_FILTER[1])]

orig_df.info()
orig_df.head()

In [None]:
# Flatten the MultiIndex into ordinary columns, then apply DF_FILTER if one is set.
df_pruned = prun_results.reset_index()
if DF_FILTER is not None:
    df_pruned = df_pruned.loc[df_pruned[DF_FILTER[0]].eq(DF_FILTER[1])]

df_pruned.info()
df_pruned.head()

### Find relevant MPC results
Filter the original NH-AMPC results so that only configurations that also appear in the pruned results remain.

In [None]:
# Every hidden size that appears in the pruned data, either as N_hidden_end or as
# N_hidden, in order of first appearance.
hidden_sizes = pd.concat([df_pruned['N_hidden_end'], df_pruned['N_hidden']])
unique_N_hidden_end = hidden_sizes.unique()
print(unique_N_hidden_end)

# The N_NN values that appear in the pruned data.
unique_N_NN = df_pruned['N_NN'].unique()
print(unique_N_NN)

In [None]:
# Keep only original results whose hidden size and N_NN also occur in the pruned data.
is_relevant = orig_df['N_hidden'].isin(unique_N_hidden_end) & orig_df['N_NN'].isin(unique_N_NN)
orig_df = orig_df.loc[is_relevant].reset_index(drop=True)
orig_df.head()

#### Join pruned and original

In [None]:
# Tag every row with its origin ('Pruned' / 'Original') so the plots can split each
# category by it, then stack both tables with a fresh 0..n-1 index.
# `assign` builds new frames instead of writing a column into df_pruned, which may be a
# filtered slice of prun_results (a direct column write raises SettingWithCopyWarning).
sub_cat = 'Subcategory'
df_pruned = df_pruned.assign(**{sub_cat: 'Pruned'})
orig_df = orig_df.assign(**{sub_cat: 'Original'})
df = pd.concat((df_pruned, orig_df), ignore_index=True)
df.head()

# Boxplots

### Cost - N_hidden_end

In [None]:
# Boxplot of Cost per N_hidden_end, split into 'Pruned' / 'Original'.
# Other value columns: Version, Cost, Mean_Time, Median_Time, R2_score,
# Rel_err_mean, Rel_err_std, NN_param_size. Other categories: N_NN, N_hidden.
value_name = 'Cost'
category_name = 'N_hidden_end'

p_cost = boxplot(
    df,
    [category_name, sub_cat],
    value_name,
    legend_category=sub_cat,
    latex_style=USE_LATEX_STYLE,
    figure_size=FIGURE_SIZE_0_8,
    show_outliers=True,
    show_non_outliers=True,
    hover_tooltips=['NN_param_size'],
    # Fixed y-window (presumably the visible axis range) — adjust for other data sets.
    y_range=(103., 120.),
    box_colors=[Category20_20[1], Category20_20[3]],
    scatter_colors=[Category20_20[0], Category20_20[2]],
)

show(p_cost)

### Mean_Time - N_hidden_end

In [None]:
# Boxplot of the mean solve time (in ms) per N_hidden_end, split into 'Pruned' / 'Original'.
# Other value columns: Version, Cost, Mean_Time, Median_Time, R2_score,
# Rel_err_mean, Rel_err_std. Other categories: N_NN, N_hidden.
value_name = 'Mean_Time'
category_name = 'N_hidden_end'

p_time = boxplot(
    df,
    [category_name, sub_cat],
    value_name,
    legend_category=sub_cat,
    latex_style=USE_LATEX_STYLE,
    figure_size=FIGURE_SIZE_0_8,
    y_unit='ms',
    show_outliers=True,
    show_non_outliers=True,
    hover_tooltips=['NN_param_size'],
    box_colors=[Category20_20[1], Category20_20[3]],
    scatter_colors=[Category20_20[0], Category20_20[2]],
)

show(p_time)

### R2_score - N_hidden_end

In [None]:
# Boxplot of the R2 score per N_hidden_end, split into 'Pruned' / 'Original'.
# Other value columns: Version, Cost, Mean_Time, Median_Time, R2_score,
# Rel_err_mean, Rel_err_std. Other categories: N_NN, N_hidden.
value_name = 'R2_score'
category_name = 'N_hidden_end'

p_r2 = boxplot(
    df,
    [category_name, sub_cat],
    value_name,
    legend_category=sub_cat,
    latex_style=USE_LATEX_STYLE,
    figure_size=FIGURE_SIZE_0_8,
    show_outliers=True,
    show_non_outliers=True,
    hover_tooltips=['NN_param_size', 'q2'],
    box_colors=[Category20_20[1], Category20_20[3]],
    scatter_colors=[Category20_20[0], Category20_20[2]],
)

show(p_r2)

# Scatter

In [None]:
# Scatter plot relating Cost and R2_score; 'N_hidden' and 'N_hidden_end' are passed as
# the further column arguments of src.plotting.scatter.
# NOTE(review): y_range=(103., 110.) matches the Cost range of the cost boxplot rather
# than R2 scores — confirm which column scatter() places on the y-axis.
p_cost_to_r2 = scatter(
    df,
    'Cost',
    'R2_score',
    'N_hidden',
    'N_hidden_end',
    markers=['circle', 'inverted_triangle'],
    y_range=(103., 110.),
    latex_style=USE_LATEX_STYLE,
    figure_size=FIGURE_SIZE_0_8,
)
show(p_cost_to_r2)

In [None]:
def _hidden_sizes_label(group_key):
    """Join the distinct values of *group_key* in descending order with '->'.

    Presumably receives the (N_hidden, N_hidden_end) key of a group, yielding labels
    like '32->16', or a single number when both sizes are equal.
    """
    return '->'.join(str(size) for size in sorted(set(group_key))[::-1])


# Histogram / PDF of Cost per (N_hidden, N_hidden_end) combination.
p_hist = histogram_pdf(
    df,
    'Cost',
    ['N_hidden', 'N_hidden_end'],
    legend_label_callable=_hidden_sizes_label,
    color_palette=Category10_10,
    bins=13,
    cap_value=150.,
    latex_style=USE_LATEX_STYLE,
    figure_size=FIGURE_SIZE_0_8,
)
show(p_hist)

# Save Plots

In [None]:
# Short tags that encode the DF_FILTER column in the exported file names.
key_entries = {
    'N_NN': 'N',
    'N_MPC': 'M',
    'N_hidden': 'Nh',
    'N_hidden_end': 'Nhe',
}

# File-name suffix such as '_22N' for DF_FILTER = ('N_NN', 22).
# DF_FILTER may be None per the settings; indexing it would then raise TypeError,
# so the suffix is simply empty in that case. A column without a short tag falls
# back to its own name instead of raising KeyError.
if DF_FILTER is not None:
    filter_column, filter_value = DF_FILTER
    filter_suffix = f'_{filter_value}{key_entries.get(filter_column, filter_column)}'
else:
    filter_suffix = ''

all_plots = [
    (f'boxplot_prun_time{filter_suffix}', p_time),
    (f'boxplot_prun_cost{filter_suffix}', p_cost),
    (f'boxplot_prun_r2{filter_suffix}', p_r2),
    (f'scatter_prun_cost_to_r2{filter_suffix}', p_cost_to_r2),
    (f'hist_prun_cost{filter_suffix}', p_hist),
]

save_figures_button(all_plots, SVG_RESULTS_DIR, PNG_RESULTS_DIR)