# Setup

## Imports

In [None]:
import os
import numpy as np

from typing import Literal, Callable, Optional
from bokeh.io import output_notebook, show
output_notebook()
from bokeh.palettes import Category20_20, Category10_10

In [None]:
from src.utils import get_masked_results, add_and_or_str
from src.bokeh_saving import save_figures_button
from src.means import get_mean_of_results
from src.plotting import plot_MPC_results, get_figure_size
from src.mpc_dataclass import dataclass_group_by, AMPC_data, find_top_costs, MPC_data

## Settings

In [None]:
# Which MPC result families to load (one folder per type under Results/)
MPC_TYPES: list[Literal['AMPC', 'NH_AMPC', 'CMPC', 'NH_CMPC']] = ['AMPC', 'NH_AMPC']

# Use test folder
TEST_RESULTS: bool = False  # e.g. /Results/NH_AMPC_results_Test instead of /Results/NH_AMPC_results

# Use prun folder results
PRUN_RESULTS: bool = False  # e.g. /Results/NH_AMPC_results_prun instead of /Results/NH_AMPC_results

# Which files to use in Results folder
FILE_START_ADD: list[str] = []  # e.g. for NH_AMPC_results_ASRTID_... -> 'ASRTID_'

# Use only the top n cost results per group (None keeps every result)
USE_TOP_N: Optional[int] = None

# Time aggregation passed to get_mean_of_results as `time_fun`:
# NaN-ignoring 75th percentile (alternative: np.nanmean)
TIME_AVG_FUN: Optional[Callable] = lambda x, **kwargs: np.nanpercentile(x, 75, **kwargs)  # np.nanmean

# Cost aggregation passed to get_mean_of_results as `cost_fun`: NaN-ignoring median
COST_AVG_FUN: Optional[Callable] = np.nanmedian

# Additional plots (one shared alias so both settings accept the same names)
AdditionalPlot = Literal['Iterations', 'Prep_Time', 'Fb_Time', 'Prep_Iterations', 'Fb_Iterations']
ADD_PLOTS: list[AdditionalPlot] = []
# NOTE(review): not passed to plot_MPC_results in the cells shown here — confirm whether it should be.
ADD_PLOTS_OPTIONS: dict[AdditionalPlot, dict] = {}

# Result filters applied via get_masked_results before plotting
AND_FILTER_DICT: Optional[dict[str, object]] = None
OR_FILTER_DICT: Optional[dict[str, object]] = {'N_NN': [0, 22]}

# Use Latex style plots
USE_LATEX_STYLE: bool = True

In [None]:
# Root folder for all result files, resolved against the current working directory.
RESULTS_DIR = os.path.abspath('Results')
# Exported figures go into one subfolder per image format.
SVG_RESULTS_DIR, PNG_RESULTS_DIR = (
    os.path.join(RESULTS_DIR, fmt_dir) for fmt_dir in ('SVGs', 'PNGs')
)

In [None]:
# Full-width figure size: LaTeX-derived via get_figure_size, otherwise a fixed (width, height).
if USE_LATEX_STYLE:
    FIGURE_SIZE_1_0 = get_figure_size(fraction=1.0, ratio=5.)
else:
    FIGURE_SIZE_1_0 = (1200, 200)

## Grouping key (`by_callable`)

In [None]:
# Select the grouping-key function once, based on the settings above. Results that map
# to the same returned string are grouped together (for top-n, averaging and plotting).
if (
    len(MPC_TYPES) > 1
    and (any('CMPC' in t for t in MPC_TYPES) or any('AMPC' in t for t in MPC_TYPES))
    and not PRUN_RESULTS
):
    # Several MPC families loaded: prefix the key with the family name.
    def by_callable(results: AMPC_data | MPC_data):
        """Return a key of the form '<family>_<N_MPC>M[...]' for `results`."""
        result_type = type(results)  # exact type match, not isinstance
        if result_type == MPC_data:
            params = results.P
            if params.N_NN == 0:
                return f'CMPC_{params.N_MPC}M'
            return f'NH_CMPC_{params.N_MPC}M_{params.N}N'
        if result_type == AMPC_data:
            params = results.P
            if params.N_NN == 0:
                return f'AMPC_{params.N_MPC}M'
            return f'NH_AMPC_{params.N_MPC}M_{params.N}N_{params.N_hidden}Nh'
        # Any other type gets no key.
        return None

elif len(FILE_START_ADD) > 1 and not PRUN_RESULTS and 'NH_AMPC' in MPC_TYPES:
    # Several file prefixes selected: distinguish them through the acados solver name.
    def by_callable(results: AMPC_data):
        """Return '<acados_name minus last 4 chars>_<N_MPC>M_<N>N_<N_hidden>Nh'."""
        # NOTE(review): [:-4] presumably strips a fixed-length suffix — confirm.
        name_stem = results.acados_name[:-4]
        params = results.P
        return f'{name_stem}_{params.N_MPC}M_{params.N}N_{params.N_hidden}Nh'

elif 'NH_AMPC' in MPC_TYPES and PRUN_RESULTS:
    # Pruning results: also include N_hidden_end when it is set.
    def by_callable(results: AMPC_data):
        """Return '<N_MPC>M_<N>N_<N_hidden>Nh', plus '_<N_hidden_end>Nhe' if set."""
        params = results.P
        key = f'{params.N_MPC}M_{params.N}N_{params.N_hidden}Nh'
        if params.N_hidden_end is not None:
            key = f'{key}_{params.N_hidden_end}Nhe'
        return key

else:
    # Fallback: group on P.N_MPC only.
    def by_callable(results: AMPC_data | MPC_data):
        """Return '<N_MPC>M' for `results`."""
        n_mpc = results.P.N_MPC
        return f'{n_mpc}M'

# Data

In [None]:
# Load every '.ph' result file for the selected MPC types into one flat list.
MPC_Results: list[AMPC_data | MPC_data] = []

for mpc_type in MPC_TYPES:
    # Folder layout: Results/<type>_results[_prun][_Test]
    mpc_results_dir = os.path.join(RESULTS_DIR, f'{mpc_type}_results')
    if PRUN_RESULTS:
        mpc_results_dir = f'{mpc_results_dir}_prun'
    if TEST_RESULTS:
        mpc_results_dir = f'{mpc_results_dir}_Test'

    # Accepted file-name prefixes (a tuple, so a single str.startswith call checks them all).
    file_start = tuple(
        f'{mpc_type}_results_{fs_add}' for fs_add in FILE_START_ADD
    ) if FILE_START_ADD else (f'{mpc_type}_results',)

    # Choose the loader from the folder's MPC type, not the file name: a CMPC file whose
    # name happens to contain 'AMPC' (e.g. via FILE_START_ADD) must still load as MPC_data.
    loader = AMPC_data.load if 'AMPC' in mpc_type else MPC_data.load

    # os.listdir order is arbitrary; sort so load/group/plot order is reproducible.
    for file_name in sorted(os.listdir(mpc_results_dir)):
        if not (file_name.startswith(file_start) and file_name.endswith('.ph')):
            continue
        MPC_Results.append(loader(os.path.join(mpc_results_dir, file_name)))

print(f'Results length: {len(MPC_Results)}')

In [None]:
# Optionally keep only the USE_TOP_N results selected by find_top_costs within each group.
if USE_TOP_N is not None:
    MPC_Results = [
        top_result
        for _key, group in dataclass_group_by(MPC_Results, by=by_callable)
        for top_result in find_top_costs(group, USE_TOP_N)
    ]

In [None]:
# Collapse each group into one aggregated result, using COST_AVG_FUN / TIME_AVG_FUN
# and carrying over the 'acados_name' field.
Mean_MPC_Results: list[AMPC_data | MPC_data] = [
    get_mean_of_results(
        list(group),
        cost_fun=COST_AVG_FUN,
        time_fun=TIME_AVG_FUN,
        keep_fields=['acados_name'],
    )
    for _key, group in dataclass_group_by(MPC_Results, by=by_callable)
]

print(f'Mean results length: {len(Mean_MPC_Results)}')

# Result Plots

## Plot All

In [None]:
# Plot every individual run (after the AND/OR filter), grouped by `by_callable`.
p_res = plot_MPC_results(
    get_masked_results(MPC_Results, AND_FILTER_DICT, OR_FILTER_DICT),
    plot_mpc_trajectories=False,
    xbnd=1.5,
    group_by=by_callable,
    thickness=[3] * len(MPC_Results),
    # 7 solid then 7 dashed line styles (presumably one entry per group — TODO confirm).
    dash=['solid'] * 7 + ['dashed'] * 7,
    # solver_time_scale='linear',
    # Same palette rule as the averages plot: Category10_10 has only 10 colors, so switch
    # to Category20_20 when there are more groups (one mean result per group).
    cols=Category10_10 if len(Mean_MPC_Results) <= 10 else Category20_20,
    additional_plots=ADD_PLOTS,
    width=FIGURE_SIZE_1_0[0],
    height=FIGURE_SIZE_1_0[1],
    latex_style=USE_LATEX_STYLE,
)
show(p_res)

## Plot Averages

In [None]:
# Plot one aggregated curve per group, using the same AND/OR filter as the all-runs plot.
masked_mean_results = get_masked_results(Mean_MPC_Results, AND_FILTER_DICT, OR_FILTER_DICT)
# Category10_10 offers 10 colors; use Category20_20 beyond that.
mean_palette = Category20_20 if len(Mean_MPC_Results) > 10 else Category10_10
p_mean = plot_MPC_results(
    masked_mean_results,
    plot_mpc_trajectories=False,
    xbnd=1.5,
    group_by=by_callable,
    thickness=[3] * len(Mean_MPC_Results),
    dash=['solid'] * 7 + ['dashed'] * 7,
    # solver_time_scale='linear',
    cols=mean_palette,
    additional_plots=ADD_PLOTS,
    width=FIGURE_SIZE_1_0[0],
    height=FIGURE_SIZE_1_0[1],
    latex_style=USE_LATEX_STYLE,
)
show(p_mean)

## Save Plots

In [None]:
# File-name base built from the selected MPC types, e.g. 'AMPC_NH_AMPC'.
mpc_types_str = '_'.join(MPC_TYPES)
# Each figure's name is extended with the active filters via add_and_or_str.
# NOTE(review): filters are passed here as (OR, AND) but to get_masked_results as
# (AND, OR) — confirm the expected argument order of add_and_or_str in src.utils.
all_plots = [
    (add_and_or_str(base_name, OR_FILTER_DICT, AND_FILTER_DICT), figure)
    for base_name, figure in (
        (f'{mpc_types_str}_results', p_res),
        (f'{mpc_types_str}_mean_results', p_mean),
    )
]

save_figures_button(all_plots, SVG_RESULTS_DIR, PNG_RESULTS_DIR)