# SIRX: Figure 8
Control signal comparison of NODEC, RL and TCC on the lattice graph.

To run this script:
1. Make sure that the required data folder is available at the paths used by the script.
You can generate the required data by running the Python script
```nodec_experiments/sirx/gen_parameters.py```.

2. The plots use the training results.
Make sure that the training procedures for both RL and NODEC have produced results in the corresponding paths used by the plot and table scripts.
Running ```nodec_experiments/sirx/nodec_train.ipynb``` and the corresponding RL training notebook with default paths is expected to generate the results at the locations required by the plot and table scripts in each folder.

3. The plots also require sample evaluation across all baselines, which is done by running the following script beforehand:
`nodec_experiments/sirx/eval_baselines.ipynb`

4. Extra scripts for experiments that did not produce good results may not be provided, for the sake of space and brevity.

5. The scripts below:
 - ```nodec_experiments/sirx/sirx.py```
 - ```nodec_experiments/sirx/rl_utils.py```
 - ```nodec_experiments/sirx/sirx_utils.py```
contain important utilities for running the training, evaluation and plotting scripts. Make sure that they are available on the Python path when running experiments.
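Before running the notebook, it can help to check that the inputs from steps 1 to 3 are in place. The sketch below is a hypothetical helper (`missing_files` is not part of the repository); the file names are the ones this notebook loads later, and the folders are the notebook's default relative paths, which may differ on your setup.

```python
from pathlib import Path

def missing_files(folder, filenames):
    """Return the names in `filenames` that do not exist as files inside `folder`."""
    folder = Path(folder)
    return [name for name in filenames if not (folder / name).is_file()]

# File names as loaded further down in this notebook.
PARAMETER_FILES = ["adjacency.pt", "initial_infection_nodes.pt", "target_subgraph_nodes.pt"]
EVAL_FILES = [
    "nodec_control_signal.npy",
    "constant_control_signal.npy",
    "random_control_signal.npy",
    "td3_control_signal.npy",
]

# Default relative locations used by the notebook (adjust if you moved the data).
for folder, names in [
    ("../../../data/parameters/sirx/lattice/", PARAMETER_FILES),
    ("../../../results/sirx/lattice/eval/", EVAL_FILES),
]:
    missing = missing_files(folder, names)
    print(folder, "OK" if not missing else f"missing: {missing}")
```

If anything is reported missing, rerun the corresponding generation, training or evaluation step first.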

Reinforcement learning takes a significant amount of time to train.

As neural network initialization is stochastic, make sure that appropriate seeds are used, or expect some variance relative to the paper results.
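One common way to fix the seeds is to seed every random number generator involved before training. The sketch below assumes a hypothetical helper `seed_everything` and an arbitrary seed value; the original training notebooks may seed differently, and seeding alone does not guarantee bit-identical results on GPUs with non-deterministic kernels.

```python
import random

import numpy as np

try:
    import torch
except ImportError:  # lets the sketch run where PyTorch is not installed
    torch = None

def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and (if available) PyTorch generators."""
    random.seed(seed)
    np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)  # seeds the default PyTorch generators

# Re-seeding reproduces the same draws.
seed_everything(0)
first = np.random.rand(3)
seed_everything(0)
second = np.random.rand(3)
print(np.array_equal(first, second))
```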

## Imports

In [None]:
#%load_ext autoreload
#%autoreload 2

In [None]:
import os
import sys
sys.path.append("../../") # append modules from parent dir
sys.path.append("./")  # directory containing sirx_utils.py (sys.path entries must be directories, not files)
import numpy as np
import torch
import pandas as pd

import networkx as nx

from sirx_utils import comparison, sirx_curves, stack_plot_grid, heats_for_steps

from plotly import graph_objects as go
from plotly import figure_factory as ff
import plotly.express as px
import random

from tqdm.auto import tqdm

import copy

from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.helpers.torch_utils.graphs import drivers_to_tensor

In [None]:
device = 'cpu'
dtype = torch.float


### Graph parameters

In [None]:
graph = 'lattice'
parameters_folder = '../../../data/parameters/sirx/'
results_folder = '../../../results/sirx/'+graph+'/'
evaluation_results_folder = results_folder + 'eval/'

graph_parameters_folder = parameters_folder + graph + '/'

adjacency_matrix = torch.load(graph_parameters_folder + 'adjacency.pt', map_location=device).to(dtype)
n_nodes = adjacency_matrix.shape[-1]
side_size = int(np.sqrt(n_nodes))
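`side_size = int(np.sqrt(n_nodes))` silently assumes that the graph is a square lattice. The sketch below builds a square lattice adjacency matrix and recovers its side length with a check that fails loudly otherwise. The 4-neighbour, non-periodic layout and both helper names are assumptions for illustration only; the notebook itself loads the actual adjacency from `adjacency.pt`.

```python
import math

import numpy as np

def square_lattice_adjacency(side: int) -> np.ndarray:
    """Adjacency of a side x side grid with 4-neighbour, non-periodic links (assumed layout)."""
    n = side * side
    A = np.zeros((n, n), dtype=np.float32)
    for row in range(side):
        for col in range(side):
            i = row * side + col  # row-major node index
            if col + 1 < side:    # link to the right neighbour
                A[i, i + 1] = A[i + 1, i] = 1
            if row + 1 < side:    # link to the neighbour below
                A[i, i + side] = A[i + side, i] = 1
    return A

def lattice_side(n_nodes: int) -> int:
    """Side length of a square lattice; raise if n_nodes is not a perfect square."""
    side = math.isqrt(n_nodes)
    if side * side != n_nodes:
        raise ValueError(f"{n_nodes} nodes do not form a square lattice")
    return side

A = square_lattice_adjacency(32)
print(lattice_side(A.shape[-1]))
```

Using `math.isqrt` with an explicit check avoids both floating-point rounding and silently truncating a non-square node count.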

### Dynamics Parameters

In [None]:
initial_infection_nodes = torch.load(graph_parameters_folder + 'initial_infection_nodes.pt')
target_subgraph = torch.load(graph_parameters_folder + 'target_subgraph_nodes.pt')

# Figure 8
Plot of the initial and target states.
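The plot encodes target nodes as -1, initially infected nodes as +1 and all other nodes as 0, then lays the vector out on the lattice in row-major order, so node `i` lands at row `i // side`, column `i % side`. The NumPy sketch below mirrors that encoding on a toy 4 x 4 grid; `state_grid` is a hypothetical helper, not part of `sirx_utils`.

```python
import numpy as np

def state_grid(n_nodes, initial_nodes, target_nodes):
    """Encode targets as -1 and initial infections as +1 on a square, row-major grid."""
    side = int(round(np.sqrt(n_nodes)))
    x = np.zeros(n_nodes)
    x[target_nodes] = -1
    x[initial_nodes] = 1          # written last, so it wins where the two sets overlap
    return x.reshape(side, side)  # node i -> (i // side, i % side)

grid = state_grid(16, initial_nodes=[0, 1], target_nodes=[10, 15])
print(grid)
```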

In [None]:
## Our initial-target scheme: target nodes are -1, initially infected nodes are +1, all others 0
xshow = torch.zeros([n_nodes])
xshow[target_subgraph] = -1
xshow[initial_infection_nodes] = 1

fig = px.imshow(xshow.view(side_size, side_size))

fig = go.Figure(fig.data[0])
fig.layout.coloraxis.colorscale = ['#f1a340','#ffffff', '#998ec3']
fig.layout.width = 200
fig.layout.height = 200
fig.layout.margin = dict(t=0, b=0, r=0, l=0)

fig.layout.coloraxis.showscale = False
fig.layout.xaxis.visible = False
fig.layout.yaxis.visible = False

fig.data[0].showscale = False
fig
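The three-entry colorscale places its colours evenly at positions 0, 0.5 and 1. Assuming the colour range spans the data from -1 to 1 (which holds when both target and infected nodes are present), target nodes map to `#f1a340`, unaffected nodes to white and initially infected nodes to `#998ec3`. The plain-Python sketch below reproduces that lookup for values that sit exactly on a colour stop; `discrete_color` is a hypothetical helper, and Plotly itself interpolates between stops for intermediate values.

```python
def discrete_color(value, zmin, zmax, colors):
    """Return the evenly spaced colour stop nearest to value's normalised position."""
    t = (value - zmin) / (zmax - zmin)   # normalise to [0, 1]
    idx = round(t * (len(colors) - 1))   # nearest stop index
    return colors[idx]

COLORS = ['#f1a340', '#ffffff', '#998ec3']
legend = {v: discrete_color(v, -1, 1, COLORS) for v in (-1, 0, 1)}
print(legend)
```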

Plot of the control signals for the different controllers. High control values are clipped so that a few outliers do not dominate the colorscale.

In [None]:
all_unn = np.load(evaluation_results_folder + "nodec_control_signal.npy")
all_ucc = np.load(evaluation_results_folder + 'constant_control_signal.npy')
all_urn = np.load(evaluation_results_folder + 'random_control_signal.npy')
all_url = np.load(evaluation_results_folder + 'td3_control_signal.npy')

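# clip the top 0.5% of values (99.5th percentile) for better color visualization;
# the random baseline (all_urn) is not included when computing the threshold
clipmax = np.quantile(np.concatenate([all_unn, all_ucc, all_url]), 0.995)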
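A quantile of 0.995 corresponds to the 99.5th percentile, so roughly 0.5% of the values lie above `clipmax` and are clipped. The sketch below checks this on synthetic, heavy-tailed data (an exponential distribution chosen only as a stand-in for control magnitudes, not the actual signals).

```python
import numpy as np

rng = np.random.default_rng(0)
signals = rng.exponential(scale=1.0, size=100_000)  # synthetic stand-in data

clipmax = np.quantile(signals, 0.995)   # 99.5th percentile
clipped = np.clip(signals, 0, clipmax)  # returns a new array; `signals` is unchanged

frac_clipped = np.mean(signals > clipmax)
print(frac_clipped)  # close to 0.005, i.e. 0.5 %
```

Because `np.clip` returns a new array, clipping does not modify the loaded signals in place.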

Plot the control signals for NODEC, RL and TCC (the random baseline is computed but not included in the grid).
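Each heatmap row shows the control signal at a few selected time steps laid out on the lattice. The internals of `heats_for_steps` are not shown here; the sketch below only illustrates the idea of picking snapshots and reshaping them into grids. It assumes a hypothetical helper `snapshots_for_steps` and a signal of shape `(n_timesteps, n_nodes)` with row-major node ordering; the real layout of the saved `*_control_signal.npy` files may differ.

```python
import numpy as np

def snapshots_for_steps(u, steps, side):
    """Return one side x side grid per requested time step (assumes u is (T, n_nodes))."""
    u = np.asarray(u)
    if max(steps) >= u.shape[0]:
        raise IndexError("requested step exceeds the number of stored time steps")
    return [u[t].reshape(side, side) for t in steps]

# Synthetic signal: 4000 time steps on a 4 x 4 grid.
u = np.arange(4000 * 16, dtype=float).reshape(4000, 16)
grids = snapshots_for_steps(u, [300, 600, 1200, 1600, 2000, 3500], side=4)
print(len(grids), grids[0].shape)
```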

In [None]:
steps = [300, 600, 1200, 1600, 2000, 3500]
# np.clip returns a new array, so the loaded signals are left untouched
unn = np.clip(all_unn, 0, clipmax)
ucc = np.clip(all_ucc, 0, clipmax)
urn = np.clip(all_urn, 0, clipmax)
url = np.clip(all_url, 0, clipmax)

# shared color range and tick values for all heatmaps
heat_kwargs = dict(zmax=4.75, ztickvals=np.linspace(0, 4.75, 6))
figs_rl = heats_for_steps('control_rl', url, px.colors.sequential.Viridis, steps, **heat_kwargs)
figs_nn = heats_for_steps('control_nn', unn, px.colors.sequential.Viridis, steps, **heat_kwargs)
figs_rn = heats_for_steps('control_rn', urn, px.colors.sequential.Viridis, steps, **heat_kwargs)
figs_cc = heats_for_steps('control_cc', ucc, px.colors.sequential.Viridis, steps, **heat_kwargs)

stack_fig = stack_plot_grid('u', [figs_nn, figs_rl, figs_cc], 
                            colorscale=px.colors.sequential.Viridis)
stack_fig.update_layout(
    width=700,
    height=200 + 200/3,
)
#stack_fig.write_image('latest_plots/controls.pdf')

stack_fig
