# CT-LTI: Multi-sample Training and Eval
In this notebook we train and evaluate controllers on a chosen graph (lattice, BA, or tree) over many initial-target state pairs.
We change the parametrization slightly from the single-sample notebook. We use Xavier normal instead of Kaiming initialization, and a higher learning-rate deceleration rate during training. Preliminary results from a few runs indicated that these choices lead to faster convergence on BA and tree graphs. Nevertheless, a more extensive hyper-parameter search would be preferable in future work to further improve performance.

Please make sure that the required data folder exists at the paths configured below (see `experiment_data_folder`).
You may generate the required data by running the Python script
```nodec_experiments/ct_lti/gen_parameters.py```.


## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
from torchdiffeq import odeint

import numpy as np
import pandas as pd
import networkx as nx

from tqdm.auto import tqdm

from nnc.controllers.baselines.ct_lti.dynamics import ContinuousTimeInvariantDynamics
from nnc.controllers.baselines.ct_lti.optimal_controllers import ControllabiltyGrammianController

from nnc.helpers.torch_utils.graphs import adjacency_tensor, drivers_to_tensor
from nnc.helpers.graph_helper import load_graph
from nnc.helpers.torch_utils.evaluators import FixedInteractionEvaluator
from nnc.helpers.torch_utils.losses import FinalStepMSE
from nnc.helpers.torch_utils.trainers import NODECTrainer


from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.helpers.torch_utils.nn_architectures.fully_connected import StackedDenseTimeControl

from plotly import graph_objects as go
from plotly.subplots import make_subplots

## Load graph and dynamics parameters

In [3]:
# Folder with the pre-generated parameters (graphs, drivers, initial/target states).
experiment_data_folder = '../../data/parameters/ct_lti/'
graph = 'tree'  # please use one of the following: lattice, ba, tree
# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (more slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Per-graph output folder for all evaluation logs written below.
results_data_folder = '../../data/results/ct_lti/multi_sample/' + graph + '/'
os.makedirs(results_data_folder, exist_ok=True)


In [4]:
# load graph data
# Tensors are first loaded onto the CPU (map_location='cpu') and only then
# moved to `device`. This lets files saved from a GPU session also be read on
# a machine without CUDA; the final placement is the same as before.

graph_folder = experiment_data_folder + graph + '/'
adj_matrix = torch.load(graph_folder + 'adjacency.pt', map_location='cpu').to(
    dtype=torch.float, device=device
)
n_nodes = adj_matrix.shape[0]
drivers = torch.load(graph_folder + 'drivers.pt', map_location='cpu')
n_drivers = len(drivers)
# Node positions, indexed by the 'index' column of pos.csv.
pos = pd.read_csv(graph_folder + 'pos.csv').set_index('index').values
driver_matrix = drivers_to_tensor(n_nodes, drivers).to(device)

# select dynamics type and initial-target states

dyn = ContinuousTimeInvariantDynamics(adj_matrix, driver_matrix)

# Target states are graph-specific; initial states are shared by all graphs.
target_states = torch.load(graph_folder + 'target_states.pt', map_location='cpu').to(device)
initial_states = torch.load(experiment_data_folder + 'init_states.pt', map_location='cpu').to(device)

# total time for control

total_time = 0.5


## Train NODEC and evaluate it against the optimal-control baseline

In [None]:
# For all sample indices: evaluate the optimal-control (OC) baseline, train a
# NODEC controller on the same initial/target pair, and log the evaluations of
# both controllers to `results_data_folder`.

# Training schedule: a short first phase (evaluated with 50 interactions),
# then a longer continuation (evaluated with 500 and 5000 interactions).
first_phase_epochs = 100
second_phase_epochs = 2400
total_epochs = first_phase_epochs + second_phase_epochs

# Numbers of control interactions used for the OC evaluations.
all_n_interactions = [50, 500, 5000]

# Shared train_best settings. The higher deceleration rate seemed to help
# NODEC converge faster on tree/BA graphs.
train_kwargs = dict(
    lr_acceleration_rate=0,
    lr_deceleration_rate=0.99,
    loss_variance_tolerance=10,
    verbose=True,
)


def make_evaluator(name, n_interactions, loss_fn, *,
                   preserve_controls=False, preserve_params=False):
    """Build a FixedInteractionEvaluator that logs to `results_data_folder`.

    All evaluators in this notebook use the same solver settings
    (ode_solver=None with method 'dopri5'). They do not preserve
    intermediate states, times, energies or losses. They differ only in the
    run name, the number of interactions, and whether intermediate controls
    and parameters are preserved.
    """
    return FixedInteractionEvaluator(
        name,
        log_dir=results_data_folder,
        n_interactions=n_interactions,
        loss_fn=loss_fn,
        ode_solver=None,
        ode_solver_kwargs={'method': 'dopri5'},
        preserve_intermediate_states=False,
        preserve_intermediate_controls=preserve_controls,
        preserve_intermediate_times=False,
        preserve_intermediate_energies=False,
        preserve_intermediate_losses=False,
        preserve_params=preserve_params,
    )


for current_sample_id in tqdm(range(initial_states.shape[0])):
    # Load the current sample, adding a leading dimension of size 1.
    x0 = initial_states[current_sample_id].unsqueeze(0)
    xstar = target_states[current_sample_id].unsqueeze(0)

    # Calculate the optimal control (controllability-Gramian based baseline).
    oc = ControllabiltyGrammianController(
        adj_matrix,
        driver_matrix,
        total_time,
        x0,
        xstar,
        simpson_evals=100,
        progress_bar=tqdm,
        use_inverse=False,
    )

    # OC evaluations for different interaction intervals (intermediate
    # controls are preserved for the OC runs).
    loss_fn = FinalStepMSE(xstar, total_time=total_time)
    for n_interactions in all_n_interactions:
        oc_evaluator = make_evaluator(
            'oc_sample' + str(current_sample_id) + '_ninter_' + str(n_interactions),
            n_interactions,
            loss_fn,
            preserve_controls=True,
        )
        oc_res = oc_evaluator.evaluate(dyn, oc, x0, total_time, epoch=0)
        oc_evaluator.write_to_file(oc_res)

    # Prepare the neural network controller with a fixed seed, so every sample
    # starts from the same random stream.
    torch.manual_seed(1)
    nn = StackedDenseTimeControl(
        n_nodes,
        n_drivers,
        n_hidden=0,
        hidden_size=15,
        activation=torch.nn.functional.elu,
        use_bias=True,
    ).to(x0.device)

    # Re-initialise weight matrices with Xavier normal, which seemed to help
    # NODEC converge faster on tree/BA graphs. This happens BEFORE the model is
    # handed to the trainer, so the trainer can never hold pre-init weights
    # (e.g. if it copies or snapshots the model when constructed).
    for param in nn.parameters():
        if param.dim() > 1:
            torch.nn.init.xavier_normal_(param)

    nndyn = NNCDynamics(dyn, nn).to(x0.device)

    nn_trainer = NODECTrainer(
        nndyn,
        x0,
        xstar,
        total_time,
        obj_function=None,
        optimizer_class=torch.optim.LBFGS,
        # max_iter=1 / max_eval=1: a single LBFGS iteration per optimizer step.
        optimizer_params=dict(
            lr=1.2,
            max_iter=1,
            max_eval=1,
            history_size=100,
        ),
        ode_solver_kwargs=dict(method='dopri5'),
        logger=None,
        closure=None,
        use_adjoint=False,
    )

    # First training phase, then evaluate with 50 interactions.
    nndyn = nn_trainer.train_best(epochs=first_phase_epochs, **train_kwargs)
    nn_logger_50 = make_evaluator(
        'nn_sample_' + str(current_sample_id) + '_train_50',
        50,
        loss_fn,
        preserve_params=True,
    )
    nn_res = nn_logger_50.evaluate(dyn, nndyn.nnc, x0, total_time, epoch=first_phase_epochs)
    nn_logger_50.write_to_file(nn_res)

    # Keep training, then evaluate with 500 and 5000 interactions; logs are
    # labelled with the cumulative epoch count.
    nndyn = nn_trainer.train_best(epochs=second_phase_epochs, **train_kwargs)

    nn_logger_500 = make_evaluator(
        'nn_sample_' + str(current_sample_id) + '_train_500',
        500,
        loss_fn,
        preserve_params=False,
    )
    nn_res = nn_logger_500.evaluate(dyn, nndyn.nnc, x0, total_time, epoch=total_epochs)
    nn_logger_500.write_to_file(nn_res)

    nn_logger_5000 = make_evaluator(
        'nn_sample_' + str(current_sample_id) + '_train_5000',
        5000,
        loss_fn,
        preserve_params=True,
    )
    nn_res = nn_logger_5000.evaluate(dyn, nndyn.nnc, x0, total_time, epoch=total_epochs)
    nn_logger_5000.write_to_file(nn_res)
    