# SIRX: Performance Comparison

Here we evaluate how each model performs under fixed interaction intervals.

Baseline comparison in terms of total loss and energy.

To run this script:
1. Please make sure that the required data folder is available at the paths used by the script.
You may generate the required data by running the python script
```nodec_experiments/sirx/gen_parameters.py```.

2. The plots use the training results.
Please also make sure that the training procedures for both RL and NODEC have produced results in the corresponding paths used in the plot and table scripts.
Running ```nodec_experiments/sirx/nodec_train.ipynb``` and the corresponding RL training notebook with default paths is expected to generate the results at the required location for the plot and table scripts in each folder.

3. The scripts below:
 - ```nodec_experiments/sirx/sirx.py```
 - ```nodec_experiments/sirx/rl_utils.py```
 - ```nodec_experiments/sirx/sirx_utils.py```
contain very important utilities for running training, evaluation and plotting scripts. Please make sure that they are available in the python path when running experiments.

Reinforcement Learning requires some significant time to train.

As neural network initialization is stochastic, please make sure that appropriate seeds are used, or expect some variance relative to the paper results.

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
sys.path.append('../../')

import numpy as np
import pandas as pd
import torch
from torchdiffeq import odeint, odeint_adjoint

import networkx as nx

from plotly import graph_objects as go
from plotly import figure_factory as ff

from sirx import SIRDelta, GCNNControl, flat_to_channels, neighborhood_mask

import random

from tqdm.auto import tqdm

import torch

import copy
import timeit
from tqdm.auto import tqdm

# Here we use a custom trajectory evaluator that was used as the basis of the FixedIntervalEvaluator
from sirx_utils import trajectory_eval
from plotly import express as px


from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.helpers.torch_utils.graphs import drivers_to_tensor


from sirx import SIRDelta, neighborhood_mask, flat_to_channels, GCNNControl
from rl_utils import SIRXEnv, RLGCNN, Actor, Critic, transform_u

import tianshou as ts
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, to_torch
from tianshou.exploration import GaussianNoise

## Parameters and data
Please change to ```'cuda:0'``` to use a gpu or ```'cpu'``` to use the cpu.
Here we load the adjacency matrix of a square lattice with $1024$ nodes.

In [None]:
# Compute device for every tensor and model in this notebook. Prefer the first
# CUDA device and fall back to the CPU when CUDA is unavailable, so the notebook
# still runs on CPU-only machines. Override manually to force a specific device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Floating-point type that loaded tensors and models are cast to below.
dtype = torch.float

In [None]:
# Graph parameters

# Name of the graph; selects both the parameter folder and the results folder.
graph = 'lattice'
parameters_folder = '../../../data/parameters/sirx/'
training_results_folder = '../../../results/sirx/' + graph + '/'
results_folder = '../../../results/sirx/' + graph + '/'

# Derive the parameter folder from `graph` (instead of a hard-coded name) so that
# changing `graph` above consistently switches parameters and results.
# `parameters_folder` already ends with a separator.
graph_parameters_folder = parameters_folder + graph + '/'
evaluation_results_folder = results_folder + 'eval/'
os.makedirs(evaluation_results_folder, exist_ok=True)

# Adjacency matrix is loaded straight onto the compute device and cast to `dtype`.
adjacency_matrix = torch.load(graph_parameters_folder + 'adjacency.pt', map_location=device).to(dtype)
n_nodes = adjacency_matrix.shape[-1]
# Driver indices stay on the CPU as integers; they are used for indexing and
# for the sanity checks in the evaluation cells.
drivers = torch.load(graph_parameters_folder + 'drivers.pt', map_location='cpu').to(torch.long)
driver_matrix = drivers_to_tensor(n_nodes, drivers).to(dtype=dtype, device=device)
# Short aliases used when constructing dynamics and controllers.
alpha = adjacency_matrix
beta = driver_matrix
# Side length of the grid, assuming a square lattice with n_nodes nodes.
side_size = int(np.sqrt(n_nodes))

In [None]:
# Dynamics Parameters

# Pass `map_location` explicitly (as done for the adjacency matrix and drivers)
# so that files saved from a CUDA session can still be loaded on a CPU-only machine.
x0 = torch.load(graph_parameters_folder + 'initial_state.pt', map_location=device).to(device=device, dtype=dtype)
# Kept on the CPU: it is used to index the evaluated state arrays and is
# converted with `.tolist()` in later cells.
target_subgraph = torch.load(graph_parameters_folder + 'target_subgraph_nodes.pt', map_location='cpu')
dynamics_params = torch.load(graph_parameters_folder + 'dynamics_parameters.pt', map_location='cpu')
budget = dynamics_params['budget']
infection_rate = dynamics_params['infection_rate']
recovery_rate = dynamics_params['recovery_rate']
# Total simulated time passed as `T` to the trajectory evaluations.
total_time = 5

In [None]:
# Interval passed as `dt` to the NODEC, no-control, constant-control and
# random-control trajectory evaluations below.
interaction_interval = 1e-3

### NODEC Evaluation
We use the best model observed during training

In [None]:
# Neighbourhood mask/indices derived from the adjacency matrix; they are consumed
# by `flat_to_channels`, which preprocesses the state before it enters the controller.
mask, ninds = neighborhood_mask(alpha)
in_preprocessor = lambda x: flat_to_channels(x, n_nodes=n_nodes, mask=mask, inds=ninds)
nodec_controller = GCNNControl(
    alpha,
    beta,
    input_preprocessor=in_preprocessor,
    budget=budget,
    in_channels=4,
    feat_channels=5,
).to(device=device, dtype=dtype)
# `map_location` makes the checkpoint loadable regardless of the device it was
# saved from (e.g. trained on GPU, evaluated on CPU).
nodec_controller.load_state_dict(
    torch.load(training_results_folder + 'nodec_best.pt', map_location=device)
)
nodec_dynamics = SIRDelta(
    adjacency_matrix=alpha,
    infection_rate=infection_rate,
    recovery_rate=recovery_rate,
    driver_matrix=beta,
    k_0=0,
).to(device=device)

all_xnn, all_unn = trajectory_eval(
    nodec_dynamics,
    x_0=x0,
    model=nodec_controller,
    method='dopri5',
    T=total_time,
    dt=interaction_interval,
)

# Sanity checks that NODEC does not somehow cheat: at every recorded interaction
# the control must be positive exactly on the driver nodes, non-negative
# everywhere, and sum to the budget.
driver_indices = drivers.tolist()
for step_control in all_unn:
    assert np.where(step_control > 0)[0].tolist() == driver_indices
    assert np.all(step_control >= 0)
    # less or equal would also work, but apparently the nn assigns all of the budget
    assert np.isclose(step_control.sum(), budget)

np.save(evaluation_results_folder + 'nodec_states.npy', all_xnn)
np.save(evaluation_results_folder + 'nodec_control_signal.npy', all_unn)
# Cell output: maximum over the first axis of the mean over the target-subgraph
# columns (presumably the peak mean infection in the target subgraph -- confirm
# against the state layout used by trajectory_eval).
all_xnn[:, target_subgraph].mean(-1).max()

## No control Evaluation
Evolution without control

In [None]:
# Uncontrolled baseline: same dynamics parameters, evaluated without a controller.
no_control_dynamics = SIRDelta(adjacency_matrix=alpha,
                               infection_rate=infection_rate,
                               recovery_rate=recovery_rate,
                               driver_matrix=beta,
                               k_0=0)

# `model=None`: no control model is handed to the evaluator.
all_xnc, all_unc = trajectory_eval(
    no_control_dynamics,
    x_0=x0,
    model=None,
    method='dopri5',
    T=total_time,
    dt=interaction_interval,
)

# Persist the evaluated states and the recorded control signal.
np.save(evaluation_results_folder + 'no_control_states.npy', all_xnc)
np.save(evaluation_results_folder + 'no_control_control_signal.npy', all_unc)

# Cell output: maximum of the per-row mean over the target-subgraph columns.
all_xnc[:, target_subgraph].mean(-1).max()

## Constant Control Evaluation
We allocate constant control to all driver nodes in target subgraph

In [None]:
# Find eligible driver nodes inside the target subgraph.
target_subgraph_drivers = sorted(
    set(target_subgraph.tolist()) & set(drivers.cpu().tolist())
)
if not target_subgraph_drivers:
    # Without this check the budget split below fails with an opaque ZeroDivisionError.
    raise ValueError('No driver nodes found inside the target subgraph.')

# Split the budget evenly over those drivers; every other node gets zero control.
# Created with the notebook-wide dtype so it matches the other tensors.
control = torch.zeros([n_nodes], device=device, dtype=dtype)
control[target_subgraph_drivers] = budget / len(target_subgraph_drivers)

# The constant allocation is passed to the dynamics as `k_0`.
constant_control_dynamics = SIRDelta(
    adjacency_matrix=alpha,
    infection_rate=infection_rate,
    recovery_rate=recovery_rate,
    driver_matrix=beta,
    k_0=control,
)

# Sanity check that we do the control assignment correctly.
assert torch.where(control > 0)[0].tolist() == target_subgraph_drivers
assert torch.all(control >= 0)
# Compare on the same device and dtype as `control`.
assert torch.isclose(
    control.sum(),
    torch.tensor(float(budget), device=control.device, dtype=control.dtype),
)

all_xcc, all_ucc = trajectory_eval(
    constant_control_dynamics,
    x_0=x0,
    model=None,
    method='dopri5',
    T=total_time,
    dt=interaction_interval,
)
# No model is passed, so the stored control signal is rebuilt here by repeating
# the constant allocation once per interaction.
# NOTE(review): the row count int(T/dt)-1 is assumed to match what
# trajectory_eval records for the other controllers -- confirm.
all_ucc = control.repeat(int(total_time/interaction_interval)-1, 1)

np.save(evaluation_results_folder + 'constant_control_states.npy', all_xcc)
np.save(evaluation_results_folder + 'constant_control_signal.npy', all_ucc.cpu().detach().numpy())
# Cell output: maximum of the per-row mean over the target-subgraph columns.
all_xcc[:, target_subgraph].mean(-1).max()

## Random Control Evaluation
We allocate control randomly across nodes per interaction

In [None]:
def random_control(x, t, beta, budget):
    """Return a random control allocation that sums to ``budget``.

    Uniform random weights are drawn per node, multiplied by the row sums of
    ``beta`` and rescaled so that the allocation sums to ``budget``. Nodes whose
    row in ``beta`` sums to zero therefore receive no control.

    The number of nodes and the target device are taken from ``beta`` itself,
    so the function does not rely on notebook-level globals.

    Args:
        x: Current state; unused, accepted to match the ``model(x, t)``
            controller call signature.
        t: Current time; unused, accepted for the same reason.
        beta: Driver matrix; ``beta.sum(-1)`` must yield one weight per node.
        budget: Total amount of control to distribute.

    Returns:
        Tensor of shape ``[1, n_nodes]`` on ``beta``'s device.
    """
    node_weights = beta.sum(-1)
    # Sample on the CPU and then move, as before, so the default CPU random
    # generator is still the source of randomness.
    u = torch.rand([1, node_weights.shape[-1]]).to(device=beta.device) * node_weights
    # keepdim keeps the normalisation row-wise and broadcast-safe.
    return budget * (u / u.sum(-1, keepdim=True))


random_control_dynamics = SIRDelta(
    adjacency_matrix=alpha,
    infection_rate=infection_rate,
    recovery_rate=recovery_rate,
    driver_matrix=beta,
    k_0=0,
)
# Adapt `random_control` to the `model(x, t)` signature used by trajectory_eval.
random_controller = lambda x, t: random_control(x, t, beta=beta, budget=budget)
all_xrn, all_urn = trajectory_eval(
    random_control_dynamics,
    x_0=x0,
    model=random_controller,
    method='dopri5',
    T=total_time,
    dt=interaction_interval,
)

# Check if random control somehow managed to cheat: at every recorded interaction
# the control must be positive exactly on the driver nodes, non-negative
# everywhere, and sum to the budget.
driver_indices = drivers.cpu().tolist()
for step_control in all_urn:
    assert np.where(step_control > 0)[0].tolist() == driver_indices
    assert np.all(step_control >= 0)
    assert np.isclose(step_control.sum(), budget)

np.save(evaluation_results_folder + 'random_control_states.npy', all_xrn)
np.save(evaluation_results_folder + 'random_control_signal.npy', all_urn)
# Cell output: maximum of the per-row mean over the target-subgraph columns.
all_xrn[:, target_subgraph].mean(-1).max()

## RL Evaluation
We use the best model saved during TD3 training, since other methods did not perform that well.

In [None]:
# Recompute the neighbourhood mask/indices for the RL policy network. The mask is
# bound to `mask`, the name that `in_preprocessor` reads, so this cell no longer
# relies on a `mask` left over from the NODEC evaluation cell.
mask, ninds = neighborhood_mask(alpha)
in_preprocessor = lambda x: flat_to_channels(x, n_nodes=n_nodes, mask=mask, inds=ninds)

# Graph network used as the actor's model.
policy_net = RLGCNN(
    adjacency_matrix=alpha,
    driver_matrix=beta,
    input_preprocessor=in_preprocessor,
    in_channels=4,
    feat_channels=5,
    message_passes=4,
)

# The optimizers are required to construct the TD3 policy object below.
actor = Actor(model=policy_net, device=device).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=0.0003)

critic1 = Critic(1, 4096, 512, device=device).to(device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=1e-4)

critic2 = Critic(1, 4096, 512, device=device).to(device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=1e-4)

# Interaction interval handed to the RL environment (RL interaction frequency).
rl_dt = 1e-2

# Dynamics instance owned by the environment.
rl_dynamics = SIRDelta(adjacency_matrix=alpha,
                       infection_rate=infection_rate,
                       recovery_rate=recovery_rate,
                       driver_matrix=beta,
                       k_0=0).to(device=device)

# Environment configuration; the environment is needed below to read the
# action-space bounds when the TD3 policy is constructed.
env_config = {
    'sirx': rl_dynamics,
    'target_nodes': target_subgraph.tolist(),
    'dt': rl_dt,
    'T': total_time,
    'ode_solve_method': 'dopri5',
    'reward_type': 'sum_to_max',
    'x0': x0,
    'budget': budget,
}

env = SIRXEnv(env_config)

# Rebuild the TD3 policy object so that the saved weights can be loaded into it.
policy = TD3Policy(
    actor=actor,
    actor_optim=actor_optim,
    critic1=critic1,
    critic1_optim=critic1_optim,
    critic2=critic2,
    critic2_optim=critic2_optim,
    tau=0.005,
    gamma=0.999,
    exploration_noise=GaussianNoise(0.01),
    policy_noise=0.001,
    update_actor_freq=5,
    noise_clip=0.5,
    # Action bounds are read from the first entries of the environment's action space.
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    reward_normalization=True,
    ignore_done=False,
)

# `map_location` makes the checkpoint loadable regardless of the device it was
# saved from (e.g. trained on GPU, evaluated on CPU).
policy.load_state_dict(
    torch.load(training_results_folder + 'rl/td3/time_1608770137/policy.pth',
               map_location=device)
)

# Fresh dynamics instance for the evaluation run (rebinds `rl_dynamics`).
rl_dynamics = SIRDelta(adjacency_matrix=alpha,
                       infection_rate=infection_rate,
                       recovery_rate=recovery_rate,
                       driver_matrix=beta,
                       k_0=0)


def model(x, t):
    # Controller in the `model(x, t)` form: take the first element of the actor
    # output and pass it through `transform_u` with the driver matrix and budget.
    return transform_u(policy.actor(x)[0], driver_matrix=beta, budget=budget)


# NOTE(review): unlike the other evaluations, `method`, `T` and `dt` are not passed
# here, so trajectory_eval's own defaults apply -- confirm that this is intended.
all_xrl, all_url = trajectory_eval(rl_dynamics, x_0=x0, model=model)

# Persist the evaluated states and the recorded control signal.
np.save(evaluation_results_folder + 'td3_control_states.npy', all_xrl)
np.save(evaluation_results_folder + 'td3_control_signal.npy', all_url)

# Cell output: maximum of the per-row mean over the target-subgraph columns.
all_xrl[:, target_subgraph].mean(-1).max()