# SIRX: NODEC Training


Baseline comparison in terms of total loss and energy.

To run this script:
1. Please make sure that the required data folder is available at the paths used by the script.
You may generate the required data by running the python script
```nodec_experiments/sirx/gen_parameters.py```.

2. The scripts below:
 - ```nodec_experiments/sirx/sirx.py```
 - ```nodec_experiments/sirx/rl_utils.py```
 - ```nodec_experiments/sirx/sirx_utils.py```
contain utilities required for running the training, evaluation and plotting scripts. Please make sure that they are available on the Python path when running experiments.

As neural network initialization is stochastic, please make sure that appropriate seeds are used, or expect some variance relative to the paper results.

SIRX experiments are briefly presented in this notebook.
This code was developed after CT-LTI, and there was not enough time to include it in the main repository.
It has been tested only on large networks.
Please run it either in the provided lattice setting or on a generated graph that is large enough.
For a smaller graph you might need to pick smaller recovery and infection rate parameters:
e.g. an infection rate gamma = 8 is too high for a 16-node graph.
The experimental code will be moved to the main nnc repository soon.

## Imports

In [None]:
# IPython autoreload: mode 2 reloads imported modules (e.g. sirx, sirx_utils)
# before executing code, so edits to them are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.sys.path.append('../../')

import numpy as np
import pandas as pd
import torch
from torchdiffeq import odeint, odeint_adjoint

import networkx as nx

from plotly import graph_objects as go
from plotly import figure_factory as ff

from sirx import SIRDelta, GCNNControl, flat_to_channels, neighborhood_mask

import random

from tqdm.auto import tqdm

import torch

import copy
import timeit
from tqdm.cli import tqdm

# Here we use a custom trajectory evaluator that was used as the basis of the FixedIntervalEvaluator
from sirx_utils import trajectory_eval
from plotly import express as px


from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.helpers.torch_utils.graphs import drivers_to_tensor


## Parameters and data
Please change to ```'cuda:0'``` to use a GPU or ```'cpu'``` to use the CPU (by default the first GPU is used when available).
Here we load the adjacency matrix of a square lattice with $1024$ nodes.

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU so the
# notebook also runs on GPU-less machines. Assign 'cpu' / 'cuda:0' here to force one.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = torch.float

### Graph parameters

In [None]:
graph = 'lattice'
parameters_folder = '../../../data/parameters/sirx/'
results_folder = '../../../results/sirx/' + graph + '/'

# Per-graph input folder, derived from ``graph`` so that switching graphs also
# switches the loaded data. The trailing '' keeps a trailing separator because
# file names are appended to this path by plain concatenation below.
graph_parameters_folder = os.path.join(parameters_folder, graph, '')

adjacency_matrix = torch.load(graph_parameters_folder + 'adjacency.pt', map_location=device).to(dtype)
n_nodes = adjacency_matrix.shape[-1]
# Driver node indices are loaded on the CPU and converted to a dense driver matrix.
drivers = torch.load(graph_parameters_folder + 'drivers.pt', map_location='cpu').to(torch.long)
driver_matrix = drivers_to_tensor(n_nodes, drivers).to(dtype=dtype, device=device)
# Short aliases used when building the dynamics and the controller.
alpha = adjacency_matrix
beta = driver_matrix
# Side length of the square lattice; assumes n_nodes is a perfect square.
side_size = int(np.sqrt(n_nodes))

### Dynamics Parameters

In [None]:
# Initial state, target nodes and dynamics constants, all read from the graph folder.
x0 = torch.load(graph_parameters_folder + 'initial_state.pt', map_location=device)
x0 = x0.to(device=device, dtype=dtype)
target_subgraph = torch.load(graph_parameters_folder + 'target_subgraph_nodes.pt', map_location=device)
dynamics_params = torch.load(graph_parameters_folder + 'dynamics_parameters.pt', map_location=device)

# The control budget and the epidemic rates have to be chosen to fit the graph size.
budget, infection_rate, recovery_rate = (
    dynamics_params['budget'],
    dynamics_params['infection_rate'],
    dynamics_params['recovery_rate'],
)
total_time = 5  # horizon determined by testing the uncontrolled dynamics

In [None]:
# Controlled SIRX dynamics on the loaded graph; ``beta`` is the driver matrix
# built from drivers.pt.
sirx_dyn = SIRDelta(
    adjacency_matrix=alpha,
    driver_matrix=beta,
    infection_rate=infection_rate,
    recovery_rate=recovery_rate,
    k_0=0.0,
)
sirx_dyn = sirx_dyn.to(device=device, dtype=dtype)

### Neural network initialization

In [None]:
mask, ninds = neighborhood_mask(alpha)


def in_preprocessor(x):
    """Convert the flat state ``x`` to the channel layout expected by GCNNControl.

    Delegates to ``flat_to_channels`` with the graph's node count and the
    neighborhood mask/indices computed from ``alpha``.
    """
    return flat_to_channels(x, n_nodes=n_nodes, mask=mask, inds=ninds)


# Graph-convolutional controller (4 input channels, 5 feature channels),
# constrained by ``budget``, wrapped into the controlled dynamics ``cd``.
cnet = GCNNControl(
    alpha,
    beta,
    input_preprocessor=in_preprocessor,
    budget=budget,
    in_channels=4,
    feat_channels=5,
)
cnet = cnet.to(device=device, dtype=dtype)
cd = NNCDynamics(sirx_dyn, cnet)
cd = cd.to(device=device, dtype=dtype)

#### Training

In [None]:
sample_points = 50  # number of time points at which the trajectory is sampled for the loss

# Independent copy of the best controller, so that later optimizer steps on the
# live network can never overwrite the snapshot.
best_model = [copy.deepcopy(cnet)]
best_loss = [np.inf]
c_model = [cnet]

lr = 0.07
optim = torch.optim.Adam(cnet.parameters(), lr=lr)

# Loop invariants: the criterion and the evaluation time grid never change.
crit = torch.nn.MSELoss()
t_eval = torch.linspace(0, total_time, sample_points).to(device, dtype=dtype)

learning_curve = []
pbar = tqdm(range(100), postfix=dict(peak_position=0, peak_infection=0, now_loss=0))
for i in pbar:
    now_loss = []

    def closure():
        """Simulate the controlled dynamics and return the peak-time loss.

        The trajectory is sampled at ``t_eval``. The time index at which the mean
        state over ``target_subgraph`` is maximal is selected, and the loss is the
        MSE of the target-subgraph state at that time against zero. Gradients are
        accumulated with ``backward`` for ``optim.step``.
        """
        optim.zero_grad()
        allx = odeint(cd, x0, t=t_eval, method='dopri5')

        # Max over the time axis (dim 0) of the mean target-subgraph state.
        peak_value, peak_index = torch.max(allx[:, :, target_subgraph].mean(-1), 0)
        x_reached = allx[peak_index].squeeze(0)
        target_state = x_reached[:, target_subgraph]
        loss = crit(target_state, torch.zeros_like(target_state))

        loss_value = loss.item()
        learning_curve.append(loss_value)
        now_loss.append(loss_value)

        # ``loss`` was produced by the parameters *before* optim.step updates them,
        # so the snapshot must be taken here for model and loss to match.
        if loss_value < best_loss[0]:
            best_model[0] = copy.deepcopy(cd.nnc.neural_net)
            best_loss[0] = loss_value

        loss.backward()
        pbar.set_postfix(dict(
            peak_position=peak_index.item(),
            peak_infection=peak_value.item(),
            now_loss=loss_value,
        ))
        return loss

    try:
        optim.step(closure)
    except AssertionError:
        # Solver failures surface as AssertionError here (presumably step-size
        # underflow in the adaptive solver); treat them as a diverging controller.
        print('possible instability due to stiffness!')
        # Restart from a *copy* of the best controller (keeping the snapshot intact)
        # and install it at the same attribute that is snapshotted above, so the
        # dynamics actually use the network the new optimizer trains.
        # NOTE(review): assumes NNCDynamics evaluates cd.nnc.neural_net — confirm in nnc.
        cnet = copy.deepcopy(best_model[0])
        cd.nnc.neural_net = cnet
        lr /= 2  # learning-rate adaptation
        optim = torch.optim.Adam(cnet.parameters(), lr=lr)

    # ``device`` may carry an index ('cuda:0'), so compare the device type only.
    if torch.device(device).type == 'cuda':
        torch.cuda.empty_cache()

### Saving the best model and learning curve

In [None]:
# Neither DataFrame/Series.to_json nor torch.save create missing folders.
os.makedirs(results_folder, exist_ok=True)
lrc = pd.Series(learning_curve, name='learning_curve')
lrc.to_json(os.path.join(results_folder, 'nodec_learning_curve.json'), orient='records')
# Only the parameters (state_dict) of the best controller are stored.
torch.save(best_model[0].state_dict(), os.path.join(results_folder, 'nodec_best.pt'))