# SIRX: Figure 7
Comparison of peak infection and energy used across baselines.
Other comparisons are also provided for all state variables.

To run this script:
1. Please make sure that the required data folder is available at the paths used by the script.
You may generate the required data by running the Python script
```nodec_experiments/sirx/gen_parameters.py```.

2. The plots use the training results.
Please also make sure that the training procedures for both RL and NODEC have produced results in the corresponding paths used by the plot and table scripts.
Running ```nodec_experiments/sirx/nodec_train.ipynb``` and the corresponding RL training notebook with default paths is expected to generate results at the required locations for the plot and table scripts in each folder.

3. The plots also require a sample evaluation across all baselines, which is produced by running the following notebook beforehand:
`nodec_experiments/sirx/eval_baselines.ipynb`

4. Extra scripts for experiments that did not produce good results may not be provided, for the sake of space and brevity.

5. The scripts below:
 - ```nodec_experiments/sirx/sirx.py```
 - ```nodec_experiments/sirx/rl_utils.py```
 - ```nodec_experiments/sirx/sirx_utils.py```
contain very important utilities for running the training, evaluation and plotting scripts. Please make sure that they are available on the Python path when running experiments.

Reinforcement learning requires a significant amount of time to train.

As neural network initialization is stochastic, please make sure that appropriate seeds are used, or expect some variance relative to the paper's results.

In [None]:
# Automatically reload edited modules (e.g. sirx_utils) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
sys.path.append('../../')
import numpy as np
import torch
import pandas as pd

import networkx as nx

from sirx_utils import comparison, sirx_curves, stack_plot_grid, heats_for_steps

from plotly import graph_objects as go
from plotly import figure_factory as ff
import plotly.express as px

from nnc.helpers.plot_helper import base_layout, sci_notation

In [None]:
# All tensors are kept on the CPU in single precision.
device, dtype = 'cpu', torch.float

### Graph parameters

In [None]:
# Which graph to analyse and where its parameters / evaluation results live.
graph = 'lattice'
parameters_folder = '../../../data/parameters/sirx/'
results_folder = '../../../results/sirx/' + graph + '/'
evaluation_results_folder = results_folder + 'eval/'

# Derive the parameter folder from `graph` so the results and parameter paths
# always refer to the same graph (parameters_folder already ends with '/').
graph_parameters_folder = parameters_folder + graph + '/'

adjacency_matrix = torch.load(graph_parameters_folder + 'adjacency.pt', map_location=device).to(dtype)
n_nodes = adjacency_matrix.shape[-1]
# Side length of the grid; assumes n_nodes is a perfect square (square lattice) — confirm for other graphs.
side_size = int(np.sqrt(n_nodes))


### Dynamics Parameters

In [None]:
# Node sets used by the dynamics: the initially infected nodes and the target
# subgraph whose states are averaged in the plots below.
# map_location matches the adjacency-matrix load, so files saved from a GPU
# session can still be opened on a CPU-only machine.
initial_infection_nodes = torch.load(graph_parameters_folder + 'initial_infection_nodes.pt', map_location=device)
target_subgraph = torch.load(graph_parameters_folder + 'target_subgraph_nodes.pt', map_location=device)

## Evaluation Data

In [None]:
# Control signals recorded during evaluation, one file per controller,
# loaded in order: NODEC, constant control, random control, TD3.
all_unn, all_ucc, all_urn, all_url = (
    np.load(evaluation_results_folder + signal_file)
    for signal_file in (
        'nodec_control_signal.npy',
        'constant_control_signal.npy',
        'random_control_signal.npy',
        'td3_control_signal.npy',
    )
)

In [None]:
# State trajectories recorded during evaluation, loaded in order:
# no control, NODEC, constant control, random control, TD3.
all_xnc, all_xnn, all_xcc, all_xrn, all_xrl = (
    np.load(evaluation_results_folder + state_file)
    for state_file in (
        'no_control_states.npy',
        'nodec_states.npy',
        'constant_control_states.npy',
        'random_control_states.npy',
        'td3_control_states.npy',
    )
)


## Plots
Infection, energy and other comparisons

In [None]:
# Time horizon of the evaluated trajectories.
T = 5
# Infected-fraction level drawn as the dashed "Hospital Capacity" reference line.
hospital_capacity = 0.2

# Mean over the target subgraph of the first n_nodes state entries
# (plotted as the infected fraction) for each controller.
dat = { 'NNC' : all_xnn[:, :n_nodes][:, target_subgraph].mean(-1),
        'TCC' : all_xcc[:, :n_nodes][:, target_subgraph].mean(-1),
        'F'   : all_xnc[:, :n_nodes][:, target_subgraph].mean(-1),
        'RND' : all_xrn[:, :n_nodes][:, target_subgraph].mean(-1),
        'RL'  : all_xrl[:, :n_nodes][:, target_subgraph].mean(-1)
}

# Line colour per controller; the later comparison plots reuse this mapping.
cols = { 'NNC' : '#1b9e77',
         'TCC' : '#d95f02',
         'F'   : '#75b270',
         'RND' : '#bebada',
         'RL'  : '#000000'
}
# Time grid with one point per stored state, spanning [0, T].
tim = np.linspace(0, T, all_xnn.shape[0])
name = 'Mean Infected Fraction'

fig = comparison(name, tim, dat, cols)
fig.layout.width = 220
fig.layout.height = 220
fig.update_layout(base_layout)
fig.update_layout(legend=dict(orientation="v",
                              x=0.6,
                              y=1.01,
                              bgcolor="rgba(0,0,0,0)",
                              bordercolor="Black",
                              borderwidth=0,
                              font = dict(size=10)
                             ),
                  margin = dict(t=10),
                  font = dict(size=15)
                 )

# Dashed capacity threshold across the whole time horizon (added exactly once).
fig.add_shape(type="line",
    x0=0, y0=hospital_capacity, x1=T, y1=hospital_capacity,
    line=dict(
        color="Red",
        width=1,
        dash="dash",
    )
)

fig.add_annotation(
            x=1,
            y=0.3,
            text='Hospital Capacity',
            xref="paper",
            yref="paper",
            showarrow=False,
            font_size=10,
)
fig.update_xaxes(nticks=7)
fig.update_yaxes(tickvals=[0, 0.1, 0.2, 0.3, 0.4, 0.5], dtick=0.1)

fig

In [None]:
# Mean over the target subgraph of the second n_nodes-wide state block
# (plotted as the susceptible fraction), per controller.
dat = {
    label: states[:, n_nodes:2 * n_nodes][:, target_subgraph].mean(-1)
    for label, states in (
        ('NNC', all_xnn),
        ('TCC', all_xcc),
        ('F', all_xnc),
        ('RND', all_xrn),
        ('RL', all_xrl),
    )
}
name = 'Mean Susceptible Fraction'

fig = comparison(name, tim, dat, cols)
fig.update_layout(width=220, height=220)
fig.update_layout(base_layout)
fig.update_layout(
    legend={
        'orientation': 'v',
        'x': 0.6,
        'y': 1.02,
        'bgcolor': 'rgba(0,0,0,0)',
        'bordercolor': 'Black',
        'borderwidth': 0,
        'font': {'size': 10},
    },
    margin={'t': 10},
    font={'size': 15},
)
fig

In [None]:
# Mean over the target subgraph of the third n_nodes-wide state block
# (plotted as the recovered fraction), per controller.
dat = {
    label: states[:, 2 * n_nodes:3 * n_nodes][:, target_subgraph].mean(-1)
    for label, states in (
        ('NNC', all_xnn),
        ('TCC', all_xcc),
        ('F', all_xnc),
        ('RND', all_xrn),
        ('RL', all_xrl),
    )
}

name = 'Mean Recovered Fraction'

fig = comparison(name, tim, dat, cols)
fig.update_layout(width=220, height=220)
fig.update_layout(base_layout)
# Legend anchored at the bottom of the plot area.
fig.update_layout(
    legend={
        'orientation': 'v',
        'x': 0.6,
        'y': 0,
        'bgcolor': 'rgba(0,0,0,0)',
        'bordercolor': 'Black',
        'borderwidth': 0,
        'font': {'size': 10},
    },
    margin={'t': 10},
    font={'size': 15},
)

fig

In [None]:
# Mean over the target subgraph of the fourth n_nodes-wide state block
# (plotted as the effective containment fraction), per controller.
dat = {
    label: states[:, 3 * n_nodes:4 * n_nodes][:, target_subgraph].mean(-1)
    for label, states in (
        ('NNC', all_xnn),
        ('TCC', all_xcc),
        ('F', all_xnc),
        ('RND', all_xrn),
        ('RL', all_xrl),
    )
}

name = 'Mean Eff. Containment Fraction'

fig = comparison(name, tim, dat, cols)
fig.update_layout(width=220, height=220)
fig.update_layout(base_layout)
fig.update_layout(
    legend={
        'orientation': 'v',
        'x': 0.6,
        'y': 0.9,
        'bgcolor': 'rgba(0,0,0,0)',
        'bordercolor': 'Black',
        'borderwidth': 0,
        'font': {'size': 10},
    },
    margin={'t': 10},
    font={'size': 15},
)
fig

In [None]:
# Cumulative control energy per controller: squared control signal scaled by
# 0.001, summed over the last axis, then accumulated with cumsum.
# NOTE(review): assumes control arrays are (time, inputs) and that 0.001 is the
# energy weight / integration step used in evaluation — confirm against eval_baselines.
# The uncontrolled baseline 'F' has no control signal, so it has no energy curve.
dat = {'NNC' :((all_unn**2)*0.001).sum(-1).cumsum(-1),
       'TCC' :((all_ucc**2)*0.001).sum(-1).cumsum(-1),
       'RND' :((all_urn**2)*0.001).sum(-1).cumsum(-1),
       'RL' :((all_url**2)*0.001).sum(-1).cumsum(-1)
}

name = 'Total Energy'

fig = comparison(name, tim, dat, cols)

# Most of these settings are overridden by the explicit calls that follow; only the
# legend traceorder and font family/colour are not set again there (base_layout may
# still override them).
fig.update_layout(         height = 300, width =300,
                          legend=dict(
                                x=-0.0,
                                y=1.5,
                               orientation = 'h',
                                traceorder="normal",
                                font=dict(
                                    family="Times New Roman",
                                    color="black"
                                ),
                                bgcolor=None,
                                bordercolor='rgba(0,0,0,0)',
                                borderwidth=1
                            )
                         )
fig.layout.width = 220
fig.layout.height = 220
fig.update_layout(base_layout)
fig.update_layout(legend=dict(orientation="v",
                              x=0.1,
                              y=1,
                              bgcolor="rgba(0,0,0,0)",
                              bordercolor="Black",
                              borderwidth=0,
                              font = dict(size=10)
                             ),
                  margin = dict(t=10),
                  font = dict(size=15)
                 )
fig.update_xaxes(nticks=7)

# One tick list drives both positions and labels, so they cannot drift apart.
energy_ticks = list(range(0, 14001, 2000))
fig.update_yaxes(
    tickvals=energy_ticks,
    ticktext=[f'{d // 1000}k' if d > 0 else '0' for d in energy_ticks],
)

fig