# SIRX: Training RL

Baseline comparison in terms of total loss and energy.

To run this script:
1. Please make sure that the required data folder is available at the paths used by the script.
You may generate the required data by running the python script
```nodec_experiments/sirx/gen_parameters.py```.

2. The scripts below:
 - ```nodec_experiments/sirx/sirx.py```
 - ```nodec_experiments/sirx/rl_utils.py```
 - ```nodec_experiments/sirx/sirx_utils.py```
contain very important utilities for running the training, evaluation and plotting scripts. Please make sure that they are available in the Python path when running experiments.

Reinforcement Learning requires some significant time to train.

As neural network initialization is stochastic, please make sure that appropriate seeds are used, or expect some variance relative to the results reported in the paper.

## Imports

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append("../../") # append modules from parent dir

import time

import copy

import numpy as np
import gym
from gym.spaces import Box
import numpy as np
import torch
from torchdiffeq import odeint, odeint_adjoint
from nnc.controllers.neural_network.nnc_controllers import NNCDynamics
from nnc.helpers.torch_utils.graphs import drivers_to_tensor

In [None]:
from sirx import SIRDelta, neighborhood_mask, flat_to_channels, GCNNControl
from rl_utils import SIRXEnv, RLGCNN, Actor, Critic

import tianshou as ts
from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, to_torch
from tianshou.exploration import GaussianNoise

from torch.utils.tensorboard import SummaryWriter

In [None]:
# Prefer the first CUDA device, but fall back to the CPU when CUDA is not
# available so that the notebook can still run on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Floating-point type that loaded tensors and modules are cast to.
dtype = torch.float

### Graph parameters

In [None]:
# Name of the graph; it selects both the parameter and the result sub-folders.
graph = 'lattice'
parameters_folder = '../../../data/parameters/sirx/'
results_folder = '../../../results/sirx/'+graph+'/'

# Derive the folder from `graph` instead of repeating the literal graph name,
# so that changing `graph` switches the loaded parameters together with the
# results folder. The extra '/' is kept so the resulting string is unchanged
# for the current value of `graph`.
graph_parameters_folder = parameters_folder + '/' + graph + '/'

# Adjacency matrix, deserialized directly onto `device` and cast to `dtype`.
adjacency_matrix = torch.load(graph_parameters_folder + 'adjacency.pt', map_location=device).to(dtype)
# The node count is read from the last dimension of the adjacency matrix.
n_nodes = adjacency_matrix.shape[-1]
# Driver data is loaded on the CPU as long integers and then converted by
# drivers_to_tensor into a tensor with the configured dtype and device.
drivers = torch.load(graph_parameters_folder + 'drivers.pt', map_location='cpu').to(torch.long)
driver_matrix = drivers_to_tensor(n_nodes, drivers).to(dtype=dtype, device=device)
# Short aliases used when building the dynamics and the networks.
alpha = adjacency_matrix
beta = driver_matrix
# Integer part of sqrt(n_nodes); presumably the side length of a square
# lattice -- confirm before using it with other graph types.
side_size = int(np.sqrt(n_nodes))

### Dynamics Parameters

In [None]:
# Initial state. map_location makes deserialization independent of the device
# the file was saved from (consistent with how the adjacency matrix is
# loaded); the final placement and dtype are still set explicitly by .to().
x0 = torch.load(
    graph_parameters_folder + 'initial_state.pt',
    map_location=device,
).to(device=device, dtype=dtype)
# NOTE(review): the two loads below keep their original device handling
# because the intended placement of their contents is not visible here.
target_subgraph = torch.load(graph_parameters_folder + 'target_subgraph_nodes.pt')
dynamics_params = torch.load(graph_parameters_folder + 'dynamics_parameters.pt')
# budget and rates need to be chosen according to graph size
budget = dynamics_params['budget']
infection_rate = dynamics_params['infection_rate']
recovery_rate = dynamics_params['recovery_rate']
total_time = 5 # determined via no control testing

In [None]:
# Build the SIRX dynamics object from the graph matrices and the epidemic
# rates, then move it to the configured device and dtype.
sirx_dyn = SIRDelta(
    adjacency_matrix=alpha,
    driver_matrix=beta,
    infection_rate=infection_rate,
    recovery_rate=recovery_rate,
    k_0=0.0,
)
sirx_dyn = sirx_dyn.to(device=device, dtype=dtype)

In [None]:
# RL interaction frequency; handed to the environment as 'dt'.
rl_dt = 0.01

# Configuration shared by every SIRXEnv instance created in this notebook.
env_config = dict(
    sirx=sirx_dyn,
    target_nodes=target_subgraph.tolist(),
    dt=rl_dt,
    T=total_time,
    ode_solve_method='dopri5',
    reward_type='sum_to_max',
    x0=x0,
    budget=budget,
)

In [None]:
def _make_env():
    """Create a new SIRXEnv from the shared env_config."""
    return SIRXEnv(env_config)


# Each vectorized environment wraps two independently created SIRXEnv instances.
train_envs = ts.env.DummyVectorEnv([_make_env for _ in range(2)])
test_envs = ts.env.DummyVectorEnv([_make_env for _ in range(2)])

### RL Neural Networks
If you check the code, you will see that it has the same learnable parameters and structure as the network
used for NODEC before the decision layer.

In [None]:
# Neighborhood mask and indices computed from the adjacency matrix.
mask, ninds = neighborhood_mask(alpha)


def in_preprocessor(x):
    """Apply flat_to_channels to x with this graph's node count, mask and indices."""
    return flat_to_channels(x, n_nodes=n_nodes, mask=mask, inds=ninds)


# Network wrapped by the actor.
policy_net = RLGCNN(
    adjacency_matrix=alpha,
    driver_matrix=beta,
    input_preprocessor=in_preprocessor,
    in_channels=4,
    feat_channels=5,
    message_passes=4,
)

actor = Actor(model=policy_net, device=device).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=0.0003)


def _build_critic():
    """Return a Critic placed on `device` together with its Adam optimizer."""
    critic = Critic(1, 4096, 512, device=device).to(device)
    optimizer = torch.optim.Adam(critic.parameters(), lr=1e-4)
    return critic, optimizer


# Two identically configured critics, created in the same order as before.
critic1, critic1_optim = _build_critic()
critic2, critic2_optim = _build_critic()


In [None]:
# For transfer learning, a previously trained model can be loaded directly:
#actor.model.load_state_dict(torch.load('../sir/sirx_best.torch'))

# Current time rounded to whole seconds; it makes the log folder name unique per run.
secs = int(round(time.time()))
log_path = ''.join([results_folder, 'rl/td3/time_', str(secs)])
# Bare expression so that the notebook cell displays the path.
log_path

In [None]:
# Policy training procedure.
# Evaluation environment; its action space and time steps are read below.
env = SIRXEnv(env_config)

# Action bounds, taken from the first entry of the action space limits.
action_low = env.action_space.low[0]
action_high = env.action_space.high[0]

# You can change TD3 to SAC or any other continuous-action policy provided by tianshou.
policy = TD3Policy(
    actor=actor,
    actor_optim=actor_optim,
    critic1=critic1,
    critic1_optim=critic1_optim,
    critic2=critic2,
    critic2_optim=critic2_optim,
    tau=0.005,
    gamma=0.999,
    exploration_noise=GaussianNoise(0.01),
    policy_noise=0.001,
    update_actor_freq=5,
    noise_clip=0.5,
    action_range=[action_low, action_high],
    reward_normalization=True,
    ignore_done=False,
)


   
# Experience collectors: training transitions are stored in a replay buffer
# created with size 8000; the test collector is created without a buffer argument.
replay_buffer = ReplayBuffer(8000)
train_collector = Collector(policy, train_envs, replay_buffer)
test_collector = Collector(policy, test_envs)
# TensorBoard writer that logs into the run folder.
writer = SummaryWriter(log_path)

def save_fn(policy):
    """Save the policy's state dict as 'policy.pth' inside log_path.

    Handed to the trainer as its ``save_fn`` callback. The target folder is
    created first when it does not exist yet, so saving does not depend on
    some other component having created log_path beforehand.

    Args:
        policy: object exposing ``state_dict()``; its result is serialized
            with ``torch.save``.
    """
    os.makedirs(log_path, exist_ok=True)
    torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

# Trainer: the number of environment time steps is used both as the number of
# steps per epoch and as the batch size.
steps_per_episode = len(env.time_steps)

result = offpolicy_trainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=100,
    step_per_epoch=steps_per_episode,
    collect_per_step=1,
    episode_per_test=1,
    batch_size=steps_per_episode,
    save_fn=save_fn,
    writer=writer,
    log_interval=1,
    verbose=True,
)
