# CT-LTI: Figure 4.
This notebook reproduces the plots of Figures 4a, 4b and 4c, which show training metrics over epochs.

Please make sure that the required data folder is available at the paths used by the script.
You may generate the required data by running the python script
```nodec_experiments/ct_lti/gen_parameters.py```.

Please also make sure that the training and evaluation procedures have produced results in the corresponding paths used below.
Running ```nodec_experiments/ct_lti/single_sample/train.ipynb``` and 
```nodec_experiments/ct_lti/single_sample/figure_4_evaluate.ipynb```
with the default paths is expected to generate the results at the required location.

As neural network initialization is stochastic, please make sure that appropriate seeds are used, or expect some variance relative to the results reported in the paper.

## Imports

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import os
os.sys.path.append('../../../')

import torch
import numpy as np

import pandas as pd

from tqdm.cli import tqdm

import plotly
import plotly.express as px

from nnc.controllers.baselines.ct_lti.dynamics import ContinuousTimeInvariantDynamics
from nnc.controllers.baselines.ct_lti.optimal_controllers import ControllabiltyGrammianController

from nnc.helpers.torch_utils.graphs import adjacency_tensor, drivers_to_tensor
from nnc.helpers.graph_helper import load_graph
from nnc.helpers.torch_utils.evaluators import FixedInteractionEvaluator
from nnc.helpers.torch_utils.losses import FinalStepMSE
from nnc.helpers.torch_utils.trainers import NODECTrainer
from nnc.helpers.torch_utils.file_helpers import read_tensor_from_collection

from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.helpers.torch_utils.nn_architectures.fully_connected import StackedDenseTimeControl
from nnc.helpers.plot_helper import base_layout, sci_notation, ColorRegistry

## Load sample and parameters
Here we load the sample that we trained NODEC on as well as the parameters for the dynamics.

In [None]:
device = 'cpu'
graph = 'lattice'

# Locations of the generated experiment parameters and of the training/evaluation results.
experiment_data_folder = '../../../../data/parameters/ct_lti/'
results_data_folder = '../../../../results/ct_lti/single_sample/'

# Load graph data. map_location makes torch.load place tensors on `device` directly,
# so files saved from a GPU session can also be opened on a CPU-only machine.
graph_folder = experiment_data_folder + graph + '/'
adj_matrix = torch.load(graph_folder + 'adjacency.pt', map_location=device).to(dtype=torch.float, device=device)
n_nodes = adj_matrix.shape[0]
drivers = torch.load(graph_folder + 'drivers.pt', map_location=device)
n_drivers = len(drivers)
# Node positions indexed by the CSV's 'index' column; not used by the plots below.
pos = pd.read_csv(graph_folder + 'pos.csv').set_index('index').values
driver_matrix = drivers_to_tensor(n_nodes, drivers).to(device)

# Target states are stored per graph; initial states at the experiment root folder.
target_states = torch.load(graph_folder + 'target_states.pt', map_location=device).to(device)
initial_states = torch.load(experiment_data_folder + 'init_states.pt', map_location=device).to(device)

# Index of the single sample NODEC was trained on.
current_sample_id = 24

# Select the initial/target state of that sample and add a leading dimension of size 1.
x0 = initial_states[current_sample_id].unsqueeze(0)
xstar = target_states[current_sample_id].unsqueeze(0)

# Total time for control.
total_time = 0.5

# Dynamics built from the adjacency and driver matrices.
dyn = ContinuousTimeInvariantDynamics(adj_matrix, driver_matrix)

# Helper: rebuild a NODEC controller from stored parameters (e.g. the ones saved at a
# given epoch) and evaluate it on the current sample.
def check_for_params(params, n_interactions, logdir=None, epoch=0):
    """Load ``params`` into a freshly built controller network and evaluate it.

    The architecture constructed here must match the one ``params`` was saved
    from, since ``load_state_dict`` is called with its default (strict) mode.

    Relies on notebook globals: ``n_nodes``, ``n_drivers``, ``dyn``, ``x0``,
    ``xstar`` and ``total_time``.

    Args:
        params: state dict loaded into the controller network.
        n_interactions: number of interactions handed to the evaluator; it is
            also appended to the evaluator's name.
        logdir: forwarded to the evaluator as ``log_dir``.
        epoch: forwarded to ``evaluate``.

    Returns:
        ``(evaluator, result)``: the ``FixedInteractionEvaluator`` instance and
        whatever its ``evaluate`` call returned.
    """
    run_device = x0.device

    controller = StackedDenseTimeControl(
        n_nodes,
        n_drivers,
        n_hidden=0,
        hidden_size=15,
        activation=torch.nn.functional.elu,
        use_bias=True,
    ).to(run_device)

    controlled_dynamics = NNCDynamics(dyn, controller).to(run_device)
    controlled_dynamics.nnc.load_state_dict(params)

    evaluator_name = 'early_eval_nn_sample_ninter_' + str(n_interactions)
    evaluator = FixedInteractionEvaluator(
        evaluator_name,
        log_dir=logdir,
        n_interactions=n_interactions,
        loss_fn=FinalStepMSE(xstar, total_time=total_time),
        ode_solver=None,
        ode_solver_kwargs={'method': 'dopri5'},
        # Do not keep intermediate states/controls/times/energies/losses or params.
        preserve_intermediate_states=False,
        preserve_intermediate_controls=False,
        preserve_intermediate_times=False,
        preserve_intermediate_energies=False,
        preserve_intermediate_losses=False,
        preserve_params=False,
    )
    result = evaluator.evaluate(dyn, controlled_dynamics.nnc, x0, total_time, epoch=epoch)
    return evaluator, result

# Epoch numbers logged during NODEC training; the parameter-norm cell iterates over them.
_train_epoch_metadata = pd.read_csv(results_data_folder + 'nn_sample_train/epoch_metadata.csv')
all_epochs = _train_epoch_metadata['epoch']


##  Generating Figure 4a.
Here we need the losses obtained by evaluating the parameters stored at every epoch for all 3 interaction intervals $10^{-2}, 10^{-3}, 10^{-4}$ (these are plotted in Figure 4c). Since this evaluation is costly, the cell below reloads the precomputed results from an existing file.

### Getting the data

In [None]:
losses_df = pd.read_csv(results_data_folder + 'nn_sample_train/losses_interactions_training.csv',
                        engine='python')
# Every column header after the first (presumably 'Epoch') is a number of interactions;
# convert those headers to int. A new column list is assigned instead of writing into
# `losses_df.columns.values`: pandas Index objects are meant to be immutable, and
# in-place writes may be ignored or leave cached label lookups stale depending on the
# pandas version.
losses_df.columns = [column if i == 0 else int(column)
                     for i, column in enumerate(losses_df.columns)]

In [None]:
# Interaction counts evaluated during training (one loss column each).
interaction_counts = [50, 500, 5000]
# Depending on the pandas version, the count columns may be labelled with ints or with
# strings; match them by their string form so both cases work, keeping the order above.
_labels_by_text = {str(column): column for column in losses_df.columns}
interaction_columns = [_labels_by_text[str(count)] for count in interaction_counts]

losses_melted = losses_df.reset_index().melt(id_vars='Epoch', value_vars=interaction_columns,
                                             var_name='Interaction Interval',
                                             value_name='Total Loss')
# From the total number of interactions to the interval length: total_time / n_interactions.
losses_melted['Interaction Interval'] = total_time / losses_melted['Interaction Interval'].astype(float)

In [None]:
# Render each interaction interval in scientific notation (used as legend labels).
losses_melted['Interaction Interval'] = losses_melted['Interaction Interval'].map(sci_notation)

## Plotting the figure

In [None]:
# make_subplots lives in this submodule; `import plotly` alone does not guarantee it is loaded.
import plotly.subplots

train_file = results_data_folder + 'nn_sample_train/'
evaluation_files = dict(
    oc_50=results_data_folder + 'oc_sample_ninter_50/',
    oc_500=results_data_folder + 'oc_sample_ninter_500/',
    oc_5000=results_data_folder + 'oc_sample_ninter_5000/',
    nodec_50=results_data_folder + 'eval_nn_sample_ninter_50/',
    nodec_500=results_data_folder + 'eval_nn_sample_ninter_500/',
    nodec_5000=results_data_folder + 'eval_nn_sample_ninter_5000/',
)

# Optimal-control (OC) baseline at 500 interactions: its first logged loss/energy values
# are drawn below as horizontal reference lines across the training epoch range.
oc_500_df = pd.read_csv(evaluation_files['oc_500'] + 'epoch_metadata.csv')
nn_500_df = pd.read_csv(evaluation_files['nodec_500'] + 'epoch_metadata.csv')  # not plotted; kept for inspection
oc_500_loss_val = oc_500_df['final_loss'].values[0]
oc_500_energy_val = oc_500_df['total_energy'].values[0]

# train_file already ends with '/', so no extra separator is added.
df_training = pd.read_csv(train_file + 'epoch_metadata.csv')
epoch_range = [df_training['epoch'].min(), df_training['epoch'].max()]


def _style_trace(trace, name, color, dash=None):
    """Name, colour and show ``trace`` in the legend; optionally set its dash style."""
    trace.name = name
    trace.line.color = color
    trace.showlegend = True
    if dash is not None:
        trace.line.dash = dash
    return trace


# Solid lines: final loss (primary y axis); dotted lines: total energy (secondary y axis).
nodec_500_loss = _style_trace(
    px.line(df_training[['epoch', 'final_loss']], x='epoch', y='final_loss').data[0],
    'NODEC Loss', ColorRegistry.nodec)
oc_500_loss = _style_trace(
    px.line(x=epoch_range, y=[oc_500_loss_val, oc_500_loss_val]).data[0],
    'OC Loss', ColorRegistry.oc)
nodec_500_energy = _style_trace(
    px.line(df_training[['total_energy', 'epoch', 'final_loss']], x='epoch', y='total_energy').data[0],
    'NODEC Energy', ColorRegistry.nodec, dash='dot')
oc_500_energy = _style_trace(
    px.line(x=epoch_range, y=[oc_500_energy_val, oc_500_energy_val]).data[0],
    'OC Energy', ColorRegistry.oc, dash='dot')

fig_epoch_comparison = plotly.subplots.make_subplots(1, 1, specs=[[{"secondary_y": True}]])

fig_epoch_comparison.add_trace(nodec_500_energy, secondary_y=True)
fig_epoch_comparison.add_trace(oc_500_energy, secondary_y=True)

fig_epoch_comparison.add_trace(nodec_500_loss)
fig_epoch_comparison.add_trace(oc_500_loss)

fig_epoch_comparison.update_layout(base_layout)
fig_epoch_comparison.update_yaxes(type='log', exponentformat='power', showgrid=False)
# Compact 210x210 size; the update_layout at the end of the cell enlarges the figure.
fig_epoch_comparison.update_layout(
    width=210,
    height=210,
    margin=dict(t=50, b=0, l=0, r=20),
    legend=dict(
        orientation="h",
        x=0.0,
        y=1.4,
        bgcolor="rgba(0,0,0,0)",
        bordercolor="Black",
        borderwidth=0,
    ),
)
fig_epoch_comparison.layout.yaxis.title = 'Final Loss'
fig_epoch_comparison.layout.yaxis2.title = 'Total Energy'
fig_epoch_comparison.layout.xaxis.title = 'Epoch'

# Primary (loss) axis: SI exponents, overriding the 'power' format applied to all y axes above.
fig_epoch_comparison.layout.yaxis.exponentformat = 'SI'
fig_epoch_comparison.layout.yaxis.nticks = 7

fig_epoch_comparison.update_layout(width=400, height=300)

fig_epoch_comparison

## Generating Figure 4b
For this figure we load the NODEC parameters stored at every training epoch and compute their squared $L_2$ norm $\|w\|_2^2$.

In [None]:
def _squared_l2_norm(state_dict):
    """Return the sum of the squared entries of every tensor in ``state_dict``."""
    return sum((tensor ** 2).sum().item() for tensor in state_dict.values())


# Squared L2 norm of all NODEC parameters for every epoch in all_epochs.
_params_collection = results_data_folder + 'nn_sample_train/epochs'
param_squared_norms = [
    _squared_l2_norm(
        read_tensor_from_collection(_params_collection, 'nodec_params/ep_' + str(epoch) + '.pt')
    )
    for epoch in tqdm(all_epochs)
]

In [None]:
# Dark24-based colour palette with 0.3 opacity, also used by Figure 4c.
# NOTE(review): plotly's Dark24 entries appear to be hex strings ('#RRGGBB'); if so, the
# 'rgb' -> 'rgba' rewrite below never matches and the colours stay fully opaque.
# Confirm whether translucency was intended before changing it (it alters the figures).
vcolors = np.array(plotly.colors.qualitative.Dark24)
vcolors = [col.replace('rgb', 'rgba').replace(')', ',0.3)') for col in vcolors]

# x = the epochs the norms were computed for. losses_df comes from a different CSV, and
# its row positions only coincide with epoch numbers if it has one row per epoch
# starting at 0.
fig = px.line(y=param_squared_norms, x=all_epochs.to_numpy())
fig.data[0].line.color = vcolors[0]
fig.layout.xaxis.title = 'Epoch'
fig.layout.yaxis.title = r'$||w||_2^2$'

fig.update_layout(base_layout)
# Compact 160x160 size; the final update_layout below enlarges the figure.
fig.update_layout(
    width=160,
    height=160,
    margin=dict(t=0, b=20, l=20, r=0),
    legend=dict(
        orientation="h",
        font=dict(size=8),
        x=0,
        y=1.35,
        bgcolor="rgba(0,0,0,0)",
        bordercolor="Black",
        borderwidth=0,
        title=dict(side='top'),
    ),
)
fig.update_xaxes(tickangle=45)
fig.update_layout(width=400, height=300)
fig


## Generating Figure 4c
This figure shows the NODEC loss per epoch for each interaction interval.

In [None]:
# Figure 4c: NODEC loss per epoch, one line per interaction interval, log-scaled y axis.
fig = px.line(
    losses_melted,
    x='Epoch',
    y='Total Loss',
    color='Interaction Interval',
    log_y=True,
    color_discrete_sequence=vcolors,
    render_mode='svg',
)
fig.update_layout(base_layout)

# Dot the third trace (presumably the 5000-interaction / 1e-4 interval, given the melt order).
fig.data[2].line.dash = 'dot'

fig.update_yaxes(exponentformat='power')
# Compact 160x195 size; the final update_layout below enlarges the figure.
fig.update_layout(
    width=160,
    height=195,
    margin=dict(t=35, b=0, l=0, r=0),
    legend=dict(
        orientation="h",
        font=dict(size=8),
        x=0,
        y=1.45,
        bgcolor="rgba(0,0,0,0)",
        bordercolor="Black",
        borderwidth=0,
        title=dict(side='top'),
    ),
)

# Y-axis ticks: SI exponent labels (replacing 'power' above), small font, ~7 ticks.
fig.layout.yaxis.exponentformat = 'SI'
fig.layout.yaxis.tickfont = dict(size=9)
fig.layout.yaxis.nticks = 7
fig.layout.yaxis.tickmode = 'auto'

fig.update_layout(width=400, height=300)
fig
