# CT-LTI: Single Sample Training
This notebook is an adaptation of the training script.
It can produce data for plotting, and it trains one NODEC model and one OC baseline on a given control setting.
This script has not been tested on CPU-only machines, so please use it with care and fix any GPU-related errors.

Furthermore, please make sure that the required data folder is available at the paths used by the script.
You may generate the required data by running the python script
```nodec_experiments/ct_lti/gen_parameters.py```.

## Imports

In [None]:
# IPython autoreload in mode 2: imported modules are reloaded automatically
# before each cell runs, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.sys.path.append('../../../')

from copy import deepcopy
import torch
from torchdiffeq import odeint

import numpy as np
import pandas as pd
import networkx as nx

from tqdm.cli import tqdm

from nnc.controllers.baselines.ct_lti.dynamics import ContinuousTimeInvariantDynamics
from nnc.controllers.baselines.ct_lti.optimal_controllers import ControllabiltyGrammianController

from nnc.helpers.torch_utils.graphs import adjacency_tensor, drivers_to_tensor
from nnc.helpers.graph_helper import load_graph
from nnc.helpers.torch_utils.evaluators import FixedInteractionEvaluator
from nnc.helpers.torch_utils.losses import FinalStepMSE
from nnc.helpers.torch_utils.trainers import NODECTrainer

from nnc.helpers.torch_utils.file_helpers import read_tensor_from_collection
from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.helpers.torch_utils.nn_architectures.fully_connected import StackedDenseTimeControl

from plotly import graph_objects as go
from plotly.subplots import make_subplots

## Loading the Data and Experiment Parameters

In [None]:
# Folder holding the pre-generated experiment parameters and the graph to use.
experiment_data_folder = '../../../../data/parameters/ct_lti/'
graph='lattice'

# Use the first CUDA device when one is available (speeds the experiments up
# by a lot) and fall back to the CPU otherwise, so that the notebook does not
# fail on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Folder that receives the evaluation and training logs; created on demand.
results_data_folder = '../../../../results/ct_lti/single_sample/'
os.makedirs(results_data_folder, exist_ok=True)


In [None]:
# --- graph data ---
graph_folder = f'{experiment_data_folder}{graph}/'


def _load_tensor(path, dtype=torch.float):
    """Load the tensor stored at `path` and move it to `device` as `dtype`."""
    return torch.load(path).to(dtype=dtype, device=device)


adj_matrix = _load_tensor(graph_folder + 'adjacency.pt')
n_nodes = adj_matrix.shape[0]
drivers = _load_tensor(graph_folder + 'drivers.pt', dtype=torch.long)
n_drivers = len(drivers)
pos = pd.read_csv(graph_folder + 'pos.csv').set_index('index').values
driver_matrix = drivers_to_tensor(n_nodes, drivers).to(dtype=torch.float, device=device)

# --- dynamics type and initial/target states ---
dyn = ContinuousTimeInvariantDynamics(adj_matrix, driver_matrix)

target_states = _load_tensor(graph_folder + 'target_states.pt')
initial_states = _load_tensor(experiment_data_folder + 'init_states.pt')

# Sample 24 is used because it looks nice.
current_sample_id = 24

# Initial and target state of that sample; `unsqueeze(0)` adds a singleton
# batch dimension in front.
x0 = initial_states[current_sample_id].unsqueeze(0)
xstar = target_states[current_sample_id].unsqueeze(0)

# Total time available for control.
total_time = 0.5


In [None]:
# Plot the initial and target states with plotly. A square lattice can be
# embedded directly on a heatmap, one cell per node.

# Side length of the lattice, derived from the state size instead of being
# hard-coded to 32; `view` below raises if the state size is not a perfect
# square.
lattice_side = int(round(x0.shape[-1] ** 0.5))

# Left panel: initial state, drawn without its own colour bar.
initial_state_fig = go.Heatmap(
    z=x0.view(lattice_side, lattice_side).cpu().numpy(),
    zmin=-1, zmax=1,
    colorscale='Plasma',
    colorbar=None, showscale=False, showlegend=False,
)
# Right panel: target state, carrying the colour bar shared by both panels
# (both heatmaps use the same zmin/zmax and colour scale).
target_state_fig = go.Heatmap(
    z=xstar.view(lattice_side, lattice_side).cpu().numpy(),
    zmin=-1, zmax=1,
    colorscale='Plasma',
    colorbar=dict(title='State Value'),
)

fig = make_subplots(cols=2, subplot_titles=("Initial State", "Target State"))
fig.add_trace(initial_state_fig, row=1, col=1)
fig.add_trace(target_state_fig, row=1, col=2)

fig.update_layout(dict(
                       width = 500,
                       height = 200,
                       margin = dict(t=20, b=2, l=2, r=50),
                      )
                 )
# The axes carry no information for a lattice embedding, so hide them.
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)
fig.data[0].showscale = False
# Last expression of the cell: displays the figure in the notebook.
fig

## Optimal Control 
### Calculate Optimal Control Parameters

In [None]:
# Keyword settings of the optimal-control (OC) baseline.
oc_settings = dict(
    simpson_evals=100,  # number of Simpson evaluations
    progress_bar=tqdm,  # progress bar shown over the Simpson evaluations
    use_inverse=False,  # whether to use torch.inverse or torch.solve for the Gramian calculation
)

# Positional arguments, in the paper's notation: A, B, T, x(0), x^*.
oc = ControllabiltyGrammianController(
    adj_matrix, driver_matrix, total_time, x0, xstar, **oc_settings
)

### Evaluate Optimal Control

In [None]:
# Evaluate the optimal-control baseline.
# Loss function tracked during evaluation.
loss_fn = FinalStepMSE(xstar, total_time=total_time)

# Number of interactions per run; dividing T by these values gives the
# interaction intervals reported in the paper.
all_n_interactions = [50, 500, 5000]

# What every OC evaluator keeps: controls, energies and losses only.
oc_preserve_flags = dict(
    preserve_intermediate_states=False,
    preserve_intermediate_controls=True,
    preserve_intermediate_times=False,
    preserve_intermediate_energies=True,
    preserve_intermediate_losses=True,
    preserve_params=False,
)

for n_interactions in all_n_interactions:
    oc_evaluator = FixedInteractionEvaluator(
        f'oc_sample_ninter_{n_interactions}',
        log_dir=results_data_folder,
        n_interactions=n_interactions,
        loss_fn=loss_fn,
        ode_solver=None,
        ode_solver_kwargs=dict(method='dopri5'),  # fresh dict per evaluator
        **oc_preserve_flags,
    )
    oc_res = oc_evaluator.evaluate(dyn, oc, x0, total_time, epoch=0)
    oc_evaluator.write_to_file(oc_res)

## Neural Network
### Initialize Neural Network

In [None]:
# Build the neural-network controller. The seed is fixed in an effort to
# improve reproducibility.
torch.manual_seed(1)

# Architecture settings of the controller network.
nn_architecture = dict(
    n_hidden=0,  # one layer is still created for 0
    hidden_size=15,
    activation=torch.nn.functional.elu,
    use_bias=True,
)
nn = StackedDenseTimeControl(n_nodes, n_drivers, **nn_architecture).to(x0.device)

# Dynamics wrapper that allows gradients to flow.
nndyn = NNCDynamics(dyn, nn).to(x0.device)

# Evaluator used as a logger while training: it preserves only the parameters.
nn_logger_options = dict(
    log_dir=results_data_folder,
    n_interactions=500,
    loss_fn=loss_fn,
    ode_solver=None,
    ode_solver_kwargs={'method': 'dopri5'},
    preserve_intermediate_states=False,
    preserve_intermediate_controls=False,
    preserve_intermediate_times=False,
    preserve_intermediate_energies=False,
    preserve_intermediate_losses=False,
    preserve_params=True,
)
nn_logger = FixedInteractionEvaluator('nn_sample_train', **nn_logger_options)

# L-BFGS settings: `max_iter=1` and `max_eval=1` limit each optimizer step to
# a single iteration and a single function evaluation.
lbfgs_params = dict(lr=1.2, max_iter=1, max_eval=1, history_size=100)

# Trainer following Algorithm 3 from the paper appendix.
nn_trainer = NODECTrainer(
    nndyn, x0, xstar, total_time,
    obj_function=None,
    optimizer_class=torch.optim.LBFGS,
    optimizer_params=lbfgs_params,
    ode_solver_kwargs={'method': 'dopri5'},
    logger=nn_logger,
    closure=None,
    use_adjoint=False,
)

In [None]:
# Neural-network parameter initialisation. Xavier and Kaiming were both tested;
# for the current example, scaled-down Kaiming weights were reported to yield
# better models more often. Feel free to change this when evaluating a single
# example case.
# NOTE(review): the earlier comment mentioned a division by 1000 while the code
# divided by 100; the code's factor is kept here -- confirm which is intended.
# NOTE(review): the scaling previously had no effect (see below), so results
# may differ from runs made before this fix.
torch.manual_seed(1)
with torch.no_grad():
    for name, param in nn.named_parameters():
        # Only parameters with more than one dimension are re-initialised;
        # one-dimensional parameters keep their existing values.
        if param.dim() > 1:
            torch.nn.init.kaiming_normal_(param)  # or torch.nn.init.xavier_normal_(param)
            # Scale in place. `param = param / 100.0` merely rebound the loop
            # variable to a new tensor and left the network weights unscaled.
            param.div_(100.0)

### Train  NODEC

In [None]:
%%time
# The training process... May take a lot of time without gpu
nndyn = nn_trainer.train_best(epochs=2500, 
                              lr_acceleration_rate=0,
                              lr_deceleration_rate=0.9,
                              loss_variance_tolerance=10,
                              verbose=True
                             )
print()

### Evaluate NODEC
First we evaluate the trained model for 500 and 5000 interactions; then we evaluate it for fewer interactions by loading the parameters of an earlier epoch.

In [None]:
# Evaluate the trained controller with the same evaluator setup as for OC.
loss_fn = FinalStepMSE(xstar, total_time=total_time)

# 50 interactions are skipped here: that case uses weights from an earlier epoch.
all_n_interactions = [500, 5000]

# What every NODEC evaluator keeps: controls, energies and losses only.
nn_preserve_flags = dict(
    preserve_intermediate_states=False,
    preserve_intermediate_controls=True,
    preserve_intermediate_times=False,
    preserve_intermediate_energies=True,
    preserve_intermediate_losses=True,
    preserve_params=False,
)

for n_interactions in all_n_interactions:
    nn_evaluator = FixedInteractionEvaluator(
        f'eval_nn_sample_ninter_{n_interactions}',
        log_dir=results_data_folder,
        n_interactions=n_interactions,
        loss_fn=loss_fn,
        ode_solver=None,
        ode_solver_kwargs=dict(method='dopri5'),  # fresh dict per evaluator
        **nn_preserve_flags,
    )
    nn_res = nn_evaluator.evaluate(dyn, nndyn.nnc, x0, total_time, epoch=0)
    nn_evaluator.write_to_file(nn_res)

Evaluate and save the results for the largest interaction interval, $10^{-2}$, which corresponds to 50 interactions.

In [None]:
# Work on a copy so that `nndyn` keeps the weights returned by training.
nndyn2 = deepcopy(nndyn)
n_interactions = 50
high_interval_epoch = 100

# Read the parameters logged during training for the selected epoch.
params = read_tensor_from_collection(
    f'{results_data_folder}nn_sample_train/epochs',
    f'nodec_params/ep_{high_interval_epoch}.pt',
)
nndyn2.nnc.load_state_dict(params)

# Same evaluator configuration as for the other interaction counts.
early_epoch_evaluator_options = dict(
    log_dir=results_data_folder,
    n_interactions=n_interactions,
    loss_fn=loss_fn,
    ode_solver=None,
    ode_solver_kwargs={'method': 'dopri5'},
    preserve_intermediate_states=False,
    preserve_intermediate_controls=True,
    preserve_intermediate_times=False,
    preserve_intermediate_energies=True,
    preserve_intermediate_losses=True,
    preserve_params=False,
)
nn_evaluator = FixedInteractionEvaluator(
    f'eval_nn_sample_ninter_{n_interactions}',
    **early_epoch_evaluator_options,
)
nn_res = nn_evaluator.evaluate(dyn, nndyn2.nnc, x0, total_time, epoch=0)
nn_evaluator.write_to_file(nn_res)