# Setup

## Imports

In [None]:
import pandas as pd
import os
import numpy as np

from bokeh.io import output_notebook, show
output_notebook()
from bokeh.palettes import Viridis256, Turbo256, Magma256, Cividis256, Inferno256

In [None]:
from src.bokeh_saving import save_figures_button
from src.mpc_dataclass import AMPC_data
from src.plotting import heatmap, get_figure_size

## Settings

In [None]:
# Use test folder 
TEST_RESULTS: bool = False # e.g. /Results/NH_AMPC_results_Test instead of /Results/NH_AMPC_results

# Which files to use in Results folder
FILE_START_ADD: list[str] = ['ASRTID_'] # e.g. for NH_AMPC_results_ASRTID_... -> 'ASRTID_'

# Use only top n cost results
USE_TOP_N: int | None = None # e.g. 5

USE_LATEX_STYLE: bool = True

In [None]:
RESULTS_DIR = os.path.abspath('Results')

# Test runs live in sibling folders carrying a "_Test" suffix; only the AMPC and
# NH-AMPC result folders are switched, the figure export folders are shared.
_RESULTS_SUFFIX = '_Test' if TEST_RESULTS else ''
AMPC_RESULTS_DIR = os.path.join(RESULTS_DIR, 'AMPC_results') + _RESULTS_SUFFIX
NH_AMPC_RESULTS_DIR = os.path.join(RESULTS_DIR, 'NH_AMPC_results') + _RESULTS_SUFFIX

SVG_RESULTS_DIR = os.path.join(RESULTS_DIR, 'SVGs')
PNG_RESULTS_DIR = os.path.join(RESULTS_DIR, 'PNGs')

In [None]:
# Figure sizes: LaTeX-derived (fraction of the text width via get_figure_size)
# or fixed (width, height) fallbacks when LaTeX styling is disabled.
if USE_LATEX_STYLE:
    FIGURE_SIZE_1_0 = get_figure_size(fraction=1.0)
    FIGURE_SIZE_0_8 = get_figure_size(fraction=0.8)
else:
    FIGURE_SIZE_1_0 = (1200, 800)
    FIGURE_SIZE_0_8 = (1000, 750)

# Data

### AMPC results extraction

In [None]:
# Load every matching AMPC baseline result and reduce them to one Series of
# per-column medians (used below as the reference for the time heatmap).
AMPC_FILE_STARTS = [f'AMPC_results_{fs_add}' for fs_add in FILE_START_ADD]

AMPC_results = []
ampc_file_paths = os.listdir(AMPC_RESULTS_DIR)
for file in ampc_file_paths:
    # Only '.ph' files whose name starts with one of the selected prefixes.
    if not file.startswith(tuple(AMPC_FILE_STARTS)) or not file.endswith('.ph'):
        continue
    file_path = os.path.join(AMPC_RESULTS_DIR, file)
    results = AMPC_data.load(file_path)
    # Times are scaled by 1e3 and plotted with unit 'ms' below
    # (assumes Acados_Time is stored in seconds — TODO confirm).
    AMPC_results.append({
                        'N_MPC': results.P.N_MPC,
                        'Cost': results.Cost,
                        'Mean_Time': np.mean(results.Acados_Time) * 1e3,
                        'Median_Time': np.median(results.Acados_Time) * 1e3,
                        # nanmax yields the largest time ignoring NaN
                        # (nanargmax would yield its position instead).
                        'Max_Time': np.nanmax(results.Acados_Time) * 1e3,
                    })

if not AMPC_results:
    # Fail early: an empty frame would only surface later as a KeyError on 'Median_Time'.
    raise FileNotFoundError(
        f'No AMPC result files (*.ph) starting with {AMPC_FILE_STARTS} found in {AMPC_RESULTS_DIR}'
    )

AMPC_results = pd.DataFrame(AMPC_results).median()

### NH-AMPC results extraction

In [None]:
# Load every matching NH-AMPC result file into one row each, indexed by
# (N_NN, N_hidden, Version) and sorted.
NH_AMPC_FILE_START = [f'NH_AMPC_results_{fs_add}' for fs_add in FILE_START_ADD]

NH_AMPC_results = []
file_paths = os.listdir(NH_AMPC_RESULTS_DIR)
for file in file_paths:
    # NOTE(review): unlike the AMPC loop, no '.ph' extension filter is applied
    # here — confirm this is intentional.
    if not file.startswith(tuple(NH_AMPC_FILE_START)):
        continue
    file_path = os.path.join(NH_AMPC_RESULTS_DIR, file)
    results = AMPC_data.load(file_path)

    # Times are scaled by 1e3 and plotted with unit 'ms' later
    # (assumes Acados_Time is stored in seconds — TODO confirm).
    NH_AMPC_results.append({
                    'N_NN': results.P.N_NN,
                    'N_hidden': results.P.N_hidden,
                    'acados_name': results.acados_name,
                    'Version': results.P.V_NN,
                    'Cost': results.Cost,
                    # np.mean / np.median propagate NaN, so such rows are
                    # removed by the later dropna.
                    'Mean_Time': np.mean(results.Acados_Time) * 1e3,
                    'Median_Time': np.median(results.Acados_Time) * 1e3,
                    # nanmax yields the largest time ignoring NaN
                    # (nanargmax would yield its position instead).
                    'Max_Time': np.nanmax(results.Acados_Time) * 1e3,
                })

if not NH_AMPC_results:
    # Fail early with a clear message instead of a KeyError from set_index.
    raise FileNotFoundError(
        f'No NH-AMPC result files starting with {NH_AMPC_FILE_START} found in {NH_AMPC_RESULTS_DIR}'
    )

NH_AMPC_results = pd.DataFrame(NH_AMPC_results).set_index(['N_NN', 'N_hidden', 'Version']).sort_index()

<div class="alert alert-block alert-warning">
<b>Attention:</b> Drops all failed results (every row containing a NaN value in any column).
</div>

In [None]:
# Remove every row that has a NaN in any column, then summarise what is left.
NH_AMPC_results = NH_AMPC_results.dropna(axis='index', how='any')
NH_AMPC_results.info()
NH_AMPC_results.head(n=20)

In [None]:
if USE_TOP_N is not None:
    # Keep only the USE_TOP_N lowest-cost rows inside each (N_NN, N_hidden) group;
    # group_keys=False keeps the original (N_NN, N_hidden, Version) index labels.
    best_rows = (
        NH_AMPC_results
        .groupby(['N_NN', 'N_hidden'], group_keys=False)['Cost']
        .nsmallest(n=USE_TOP_N)
        .index
    )
    NH_AMPC_results = NH_AMPC_results[NH_AMPC_results.index.isin(best_rows)]

### Clip cost to 150

In [None]:
# Cap every cost above 150 at 150 so outliers do not dominate the aggregated medians.
NH_AMPC_results['Cost'] = NH_AMPC_results['Cost'].clip(upper=150)

### Median cost and timing statistics across seeds

In [None]:
# Aggregate over the Version level: one row per (N_NN, N_hidden) holding the
# medians of the timing columns and of the cost.
seed_groups = NH_AMPC_results.groupby(['N_NN', 'N_hidden'])
median_cost = seed_groups['Cost'].median()
# NOTE: despite its name, mean_time holds per-group medians of the timing columns.
mean_time = seed_groups[['Mean_Time', 'Median_Time', 'Max_Time']].median()

mm_df = pd.concat([mean_time, median_cost], axis='columns')
mm_df = mm_df.sort_index().reset_index()
mm_df.head()

# Heatmaps

In [None]:
# Heatmap of the median solve time per (N_NN, N_hidden); cmap_cap is set to the
# AMPC baseline median time (presumably the colour-scale upper limit).
ampc_median_time = AMPC_results['Median_Time']
p_time = heatmap(
    mm_df,
    'N_NN',
    'N_hidden',
    'Median_Time',
    latex_style=USE_LATEX_STYLE,
    figure_size=FIGURE_SIZE_0_8,
    cmap_cap=ampc_median_time,
    cbar_unit='ms',
    color_palette=Turbo256,
)
show(p_time)

In [None]:
# Heatmap of the median (clipped) cost per (N_NN, N_hidden); cmap_cap is a fixed
# 110 (presumably the colour-scale upper limit).
p_cost = heatmap(
    mm_df,
    'N_NN',
    'N_hidden',
    'Cost',
    latex_style=USE_LATEX_STYLE,
    figure_size=FIGURE_SIZE_0_8,
    color_palette=Turbo256,
    cmap_cap=110.,
)
show(p_cost)

## Save Heatmaps

In [None]:
# (name, figure) pairs handed to the export widget; the names are presumably
# used as output file names in the SVG and PNG folders.
all_plots = list(zip(['heatmap_time', 'heatmap_cost'], [p_time, p_cost]))

save_figures_button(all_plots, SVG_RESULTS_DIR, PNG_RESULTS_DIR)