In [None]:
import pandas as pd
import os
import numpy as np
import sys

from itertools import product
from bokeh.io import output_notebook, show
from bokeh.palettes import Category10_10
# Set to True to silence Python warnings for the whole session. Warning options given
# explicitly on the interpreter command line (sys.warnoptions) always take precedence.
SILENCE_WARNINGS = False
if SILENCE_WARNINGS and not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

# Render bokeh figures inline in this notebook.
output_notebook()

In [None]:
from src.parameters import MPC_Param
from src.plotting import plot_MPC_results, get_figure_size
from src.mpc_dataclass import AMPC_data
from src.bokeh_saving import save_figures_button

## Settings

In [None]:
# MPC prediction horizons (steps) of the datasets to load.
HORIZONS = (30,)
# Number of datapoints each dataset was generated with.
DATAPOINTS = (5_000,)
# Dataset file versions to load (the `{version}v` suffix of the file name).
VERSIONS = (9,)
# Acados solver configurations to load. Alternatives that can be added to the tuple:
# 'RTI_PCHPIPM_DISCRETE', 'ASRTID_FCH'
ACADOS_NAMES = ('SQP_PCHPIPM_DISCRETE',)

# Maximum number of randomly drawn rows plotted per dataset.
PLOT_SAMPLE_NUM = 200

# Size figures via get_figure_size (LaTeX style) instead of a fixed pixel size.
USE_LATEX_STYLE = True

In [None]:
# All inputs/outputs live under ./Results, resolved against the current working directory.
RESULTS_DIR = os.path.abspath('Results')
MPC_DATASETS_DIR, SVG_RESULTS_DIR, PNG_RESULTS_DIR = (
    os.path.join(RESULTS_DIR, sub_dir) for sub_dir in ('MPC_data_gen', 'SVGs', 'PNGs')
)

# Full-width figure: LaTeX-style sizing from get_figure_size, otherwise a fixed pixel size.
if USE_LATEX_STYLE:
    FIGURE_SIZE_1_0 = get_figure_size(fraction=1.0, ratio=5.)
else:
    FIGURE_SIZE_1_0 = (1200, 200)

## Data loading

In [None]:
# One entry per combination of the settings:
# (horizon, datapoints, version, acados_name, dataset_file_name).
dataset_names = [
    (
        horizon,
        datapoints,
        version,
        acados_name,
        f'MPC_data_{horizon}steps_{datapoints}datapoints_{acados_name}_{version}v'
    ) for horizon, datapoints, version, acados_name in product(HORIZONS, DATAPOINTS, VERSIONS, ACADOS_NAMES)
]
# Extra datasets can be appended; entries must be 5-tuples to match the unpacking below, e.g.
# dataset_names.append((30, 20_000, None, None, 'inverted_pendulum_20k_30steps'))

# Seed for df.sample so the plotted subset is reproducible on Restart & Run All;
# set to None to draw a fresh random subset on every run.
PLOT_SAMPLE_SEED = 0

dataset_results: list = []
for horizon, datapoints, version, acados_name, dataset_file_name in dataset_names:

    print('File {} loading'.format(dataset_file_name))
    df_file = os.path.join(MPC_DATASETS_DIR, dataset_file_name + '.csv')

    if not os.path.exists(df_file):
        print('\t-> dont exist')
        continue

    df = pd.read_csv(df_file)

    # Regular datasets load their MPC parameters from the JSON file next to the CSV;
    # 'inverted_*' datasets (presumably without a JSON file) fall back to N_MPC=30.
    if not dataset_file_name.startswith('inverted_'):
        json_file = os.path.join(MPC_DATASETS_DIR, dataset_file_name + '.json')
        mpc_param = MPC_Param.load(json_file)
    else:
        mpc_param = MPC_Param(N_MPC=30)
    # Each plotted sample is a single simulated step (N_sim = 1, T_sim = Ts).
    mpc_param.N_sim = 1
    mpc_param.T_sim = mpc_param.Ts

    # CSV columns: initial state/input (prediction step 0) and the full predicted trajectories.
    starting_states = [f'{x}_p0' for x in mpc_param.xlabel]
    starting_input = 'u_p0'
    all_states = [f'{x}_p{i}' for x in mpc_param.xlabel for i in range(mpc_param.N_MPC + 1)]
    all_inputs = [f'u_p{i}' for i in range(mpc_param.N_MPC)]

    # Bound by the actual row count, not the nominal `datapoints`: a CSV with fewer rows
    # than requested would otherwise make df.sample raise ValueError.
    n_plot = min(PLOT_SAMPLE_NUM, len(df))

    for _, row in df.sample(n_plot, random_state=PLOT_SAMPLE_SEED).iterrows():
        results = AMPC_data(mpc_param)
        results.X = row[starting_states].to_numpy().reshape((mpc_param.nx, 1))
        results.U = np.array([[row[starting_input]]])
        results.Time = np.zeros((1,))
        results.X_traj = row[all_states].to_numpy().reshape((1, mpc_param.nx, mpc_param.N_MPC + 1))
        results.U_traj = row[all_inputs].to_numpy().reshape((1, mpc_param.nu, mpc_param.N_MPC))
        # Deliberate misuse of acados_name: it carries the dataset file name so the
        # plot can group trajectories per dataset.
        results.acados_name = dataset_file_name

        results.freeze()
        dataset_results.append(results)

## Plot data

In [None]:
# Fail early with a clear message instead of an obscure plotting error.
assert dataset_results, 'No datasets loaded: check MPC_DATASETS_DIR and the settings cell.'

# One plot group per distinct acados_name (the loading cell stores the dataset file name there).
# Style lists get at least one entry per group and never fewer than 5.
# NOTE(review): assumes plot_MPC_results indexes alpha/thickness per group — confirm.
n_style_entries = max(5, len({result.acados_name for result in dataset_results}))

p = plot_MPC_results(
    dataset_results,
    time_type=None,
    xbnd=1.4,
    cols=Category10_10,
    plot_mpc_trajectories=True,
    plot_Ts=False,
    theta_bnd=(-1.2*np.pi, 2.8*np.pi),
    width=FIGURE_SIZE_1_0[0],
    height=FIGURE_SIZE_1_0[1],
    alpha=[0.4] * n_style_entries,
    thickness=[3] * n_style_entries,
    group_by=lambda x: x.acados_name,
    latex_style=USE_LATEX_STYLE
)
show(p)

In [None]:
# Export button for the figure (SVG and PNG target directories), named after the first
# configured horizon and datapoint count.
# NOTE(review): 'M' labels the horizon and 'steps' labels the datapoint count, whereas dataset
# files use '{horizon}steps_{datapoints}datapoints' — confirm the intended suffixes.
figure_name = f'datasets_{HORIZONS[0]}M_{DATAPOINTS[0]}steps'
save_figures_button([(figure_name, p)], SVG_RESULTS_DIR, PNG_RESULTS_DIR)