# Pruned Network generation for NH-AMPC

## Imports

In [None]:
import os
import pandas as pd
import torch

from itertools import product
from pprint import pprint
from tqdm.notebook import tqdm_notebook

In [None]:
from src.neural_horizon import NN_for_casadi, load_NN
from src.parameters import NH_AMPC_Param
from src.utils import get_features_and_labels
from src.torch_utils import count_parameters

from src.pruning.prun_dataclasses import Node_Prun_LTH, Node_Prun_Finetune

## Settings

In [None]:
# Directory layout: everything lives below ./Results (resolved against the current working directory).
RESULTS_DIR = os.path.abspath('Results')
MPC_DATASETS_DIR, NNS_DIR, PRUNED_NNS_DIR = (
    os.path.join(RESULTS_DIR, sub_dir)
    for sub_dir in ('MPC_data_gen', 'Trained_Networks', 'Prun_Networks')
)

# Number of dataset samples (forwarded to NH_AMPC_Param as DS_samples).
NUM_SAMPLES = 30_000

# Run-control flags
RETRAIN_NNS = False         # when False, training is skipped for networks whose file already exists
PRUNE_NNS = True            # set to False to only compute the scores of already pruned networks
NUM_NNS = 50                # number of network versions (V_NN = 0 .. NUM_NNS-1) per option tuple

USE_CUDA = True             # use cuda:0 when a GPU is available

DROP_OLD_R2SCORES = False   # when False, new scores are merged into the existing score pickle

# Extra keyword arguments forwarded to every NH_AMPC_Param.
MPC_PARAM_DICT = dict(
    T_sim=5,  # length of the closed-loop simulation (in seconds)
)

# Option tuples (one per configuration), each laid out as:
#   (N_MPC, N_NN, N_DATASET, TRAIN_DATASET_VERSION, TEST_DATASET_VERSION,
#    USE_BEGINING_OF_DATASET, DS_FEATURE_IF_FIXED, DATASET_NAME_ADD,
#    HIDDEN_NEURONS, END_HIDDEN_SIZES)
# itertools.product already yields these tuples, so the iterator is materialized directly.
NH_AMPC_OPTIONS = list(product(
    (8, ),                                      # N_MPC
    (17, ),                                     # N_NN -> if USE_BEGINING_OF_DATASET != 'begin', it has to be 70-max(N_MPCs)
    (70, ),                                     # N_DATASET
    (5, ),                                      # TRAIN_DATASET_VERSION
    (6, ),                                      # TEST_DATASET_VERSION
    ('fixed', ),                                # USE_BEGINING_OF_DATASET ('begin', 'fixed', '') if '' use DS_FEATURE of N_MPC
    (8, ),                                      # DS_FEATURE_IF_FIXED
    ('RTI_PCHPIPM_DISCRETE', ),                 # DATASET_NAME_ADD ('RTI_PCHPIPM_DISCRETE', 'RTI_PCHPIPM_DISCRETE_50ITER', 'RTI_PCHPIPM_ROBUST_DISCRETE')
    (48, ),                                     # HIDDEN_NEURONS
    (24, ),                                     # END_HIDDEN_SIZES (left from pruning)
))

NH_AMPC_OPTIONS.extend(product(
    (8, ),                                      # N_MPC
    (22, ),                                     # N_NN -> if USE_BEGINING_OF_DATASET != 'begin', it has to be 70-max(N_MPCs)
    (70, ),                                     # N_DATASET
    (5, ),                                      # TRAIN_DATASET_VERSION
    (6, ),                                      # TEST_DATASET_VERSION
    ('fixed', ),                                # USE_BEGINING_OF_DATASET ('begin', 'fixed', '') if '' use DS_FEATURE of N_MPC
    (8, ),                                      # DS_FEATURE_IF_FIXED
    ('RTI_PCHPIPM_DISCRETE', ),                 # DATASET_NAME_ADD ('RTI_PCHPIPM_DISCRETE', 'RTI_PCHPIPM_DISCRETE_50ITER', 'RTI_PCHPIPM_ROBUST_DISCRETE')
    (64, ),                                     # HIDDEN_NEURONS
    (24, 32, ),                                 # END_HIDDEN_SIZES (left from pruning)
))



In [None]:
# Show the selected option tuples and the unpruned / pruned network directories.
pprint(NH_AMPC_OPTIONS)
for dir_path in (NNS_DIR, PRUNED_NNS_DIR):
    print(dir_path)

In [None]:
# Parameters needed for the neural-horizon acados MPC: one NH_AMPC_Param per
# (network version, option tuple) pair. The version loop is the outer loop, so
# the list holds all option tuples for V_NN=0, then all for V_NN=1, and so on.
NH_AMPC_PARAMS = [
    NH_AMPC_Param(
        # Horizon lengths (N is set to N_NN + N_MPC)
        N_MPC=n_mpc,
        N=n_nn + n_mpc,

        # Dataset settings
        N_DS=n_dset,
        TRAIN_V_DS=train_ds_version,
        TEST_V_DS=test_ds_version,
        DS_begin=use_begin,
        DS_samples=NUM_SAMPLES,
        DS_opts_name=ds_opt_name,
        DS_feature=ds_features,

        # Network settings
        V_NN=nn_version,
        N_hidden=n_hidden,
        N_hidden_end=end_n_sizes,

        # Remaining shared MPC parameters
        **MPC_PARAM_DICT
    )
    for nn_version in range(NUM_NNS)
    for (n_mpc, n_nn, n_dset, train_ds_version, test_ds_version, use_begin,
         ds_features, ds_opt_name, n_hidden, end_n_sizes) in NH_AMPC_OPTIONS
]

## Train Networks

In [None]:
# Run on the first GPU only when one is present AND the user enabled CUDA; otherwise CPU.
if torch.cuda.is_available() and USE_CUDA:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
dtype = torch.float32
print(device)

In [None]:
# Train one unpruned network per parameter set and save it to NNS_DIR.
# Networks whose file name (NN_name) is already present in NNS_DIR are skipped
# unless RETRAIN_NNS is True.
nn_paths = os.listdir(NNS_DIR) if os.path.exists(NNS_DIR) else []

with tqdm_notebook(total=len(NH_AMPC_PARAMS), unit='Networks', desc='Network train progress: ') as tqdm_handle:
    for nh_ampc_params in NH_AMPC_PARAMS:
        # Human-readable label for the progress bar; the pruned size (NhP) is only
        # shown when a pruning target N_hidden_end is set.
        hidden_part = (
            f'{nh_ampc_params.N_hidden}Nh'
            if nh_ampc_params.N_hidden_end is None
            else f'{nh_ampc_params.N_hidden}Nh_{nh_ampc_params.N_hidden_end}NhP'
        )
        name = f'{nh_ampc_params.N_MPC}M_{nh_ampc_params.N_NN}N {hidden_part}_{nh_ampc_params.V_NN}v'
        tqdm_handle.set_description_str(f'Train NN:\n{name}')

        # skip already existent NNs
        if nh_ampc_params.NN_name in nn_paths and not RETRAIN_NNS:
            tqdm_handle.update(1)
            continue

        mpc_dataset_file = os.path.join(MPC_DATASETS_DIR, nh_ampc_params.train_DS_name)

        features, labels = get_features_and_labels(nh_ampc_params)
        Unpruned_NN_fc = NN_for_casadi(
            mpc_dataset_file,
            nh_ampc_params,
            features=features,
            labels=labels,
            device=device,
            dtype=dtype
        )
        Unpruned_NN_fc.NNcompile(show_tqdm=False, n_neurons=nh_ampc_params.N_hidden)
        Unpruned_NN_fc.NNsave(file=nh_ampc_params.NN_name, filedir=NNS_DIR)
        # Record the freshly saved file so that a later parameter set mapping to the
        # same NN_name (e.g. if NN_name does not encode N_hidden_end -- not visible
        # here) is not trained again in this run when RETRAIN_NNS is False.
        nn_paths.append(nh_ampc_params.NN_name)

        tqdm_handle.update(1)

## Prune and retrain Networks

In [None]:
# Prune every unpruned network (loaded from NNS_DIR) down to N_hidden_end neurons
# and save the result to PRUNED_NNS_DIR. Skipped entirely when PRUNE_NNS is False.
if PRUNE_NNS:
    with tqdm_notebook(total=len(NH_AMPC_PARAMS), unit='Networks', desc='Network prun progress: ') as tqdm_handle:
        for nh_ampc_params in NH_AMPC_PARAMS:
            # Human-readable label for the progress bar.
            hidden_part = (
                f'{nh_ampc_params.N_hidden}Nh'
                if nh_ampc_params.N_hidden_end is None
                else f'{nh_ampc_params.N_hidden}Nh_{nh_ampc_params.N_hidden_end}NhP'
            )
            name = f'{nh_ampc_params.N_MPC}M_{nh_ampc_params.N_NN}N {hidden_part}_{nh_ampc_params.V_NN}v'
            tqdm_handle.set_description_str(f'Prune NN:\n{name}')

            # Pruning needs a target size; fail early with a clear message instead of
            # a TypeError from the subtraction below (after loading the network).
            if nh_ampc_params.N_hidden_end is None:
                raise ValueError(f'No pruning target (N_hidden_end) set for network {name}')

            Pruned_NN_fc = load_NN(nh_ampc_params, NNS_DIR, MPC_DATASETS_DIR, device, dtype, force_load_unpruned=True)

            # Number of neurons to remove so that n_neurons[1] ends up at N_hidden_end.
            amount = Pruned_NN_fc.NN.n_neurons[1] - nh_ampc_params.N_hidden_end
            if amount < 0:
                raise ValueError(
                    f'Pruning target N_hidden_end={nh_ampc_params.N_hidden_end} exceeds the '
                    f'current size n_neurons[1]={Pruned_NN_fc.NN.n_neurons[1]} for network {name}'
                )
            prun_params = Node_Prun_LTH(1, amount, dim=1)
            Pruned_NN_fc.NNprunCasadi(prun_params, show_tqdm=False)
            Pruned_NN_fc.NNsave(file=nh_ampc_params.Pruned_NN_name, filedir=PRUNED_NNS_DIR)

            tqdm_handle.update(1)

## Network evaluations 

In [None]:
# Evaluate every pruned network on its test dataset and collect one record per network.
NN_evals = []

for nh_ampc_params in NH_AMPC_PARAMS:
    print('#' + '=' * 100)
    Pruned_NN_fc = load_NN(nh_ampc_params, PRUNED_NNS_DIR, MPC_DATASETS_DIR, device, dtype)

    test_datasets_file = os.path.join(MPC_DATASETS_DIR, nh_ampc_params.test_DS_name)
    r2_score, relative_error = Pruned_NN_fc.evaluate_NN(test_datasets_file)

    # Relative-error statistics are stored as percentages.
    record = dict(
        N_NN=nh_ampc_params.N_NN,
        N_hidden=nh_ampc_params.N_hidden,
        N_hidden_end=nh_ampc_params.N_hidden_end,
        Version=nh_ampc_params.V_NN,
        R2_score=r2_score,
        Rel_err_mean=100 * relative_error.mean(),
        Rel_err_std=100 * relative_error.std(),
        NN_param_size=count_parameters(Pruned_NN_fc.NN),
    )
    NN_evals.append(record)

In [None]:
# One row per evaluated network, indexed by its configuration and version.
score_index_cols = ['N_NN', 'N_hidden', 'N_hidden_end', 'Version']
scores = pd.DataFrame(NN_evals).set_index(score_index_cols)
scores.head(10)

## Save evaluations 

In [None]:
# Persist the scores to RESULTS_DIR/PrunedR2scores.pkl. Unless DROP_OLD_R2SCORES is
# set, new scores are merged into the existing file; for index entries present in
# both, the newly computed row wins (keep='last').
scores_path = os.path.join(RESULTS_DIR, 'PrunedR2scores.pkl')

if not DROP_OLD_R2SCORES and os.path.exists(scores_path):
    existing_scores = pd.read_pickle(scores_path)
    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported replacement.
    scores = pd.concat([existing_scores, scores])
    scores = scores[~scores.index.duplicated(keep='last')]

scores.to_pickle(scores_path)