In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.data import Data

In [2]:
from env import *
from gnn_model import GraphGNNModel

In [3]:
import os
import sys

# Location of the local rlkit checkout. Set the RLKIT_PATH environment variable
# to point somewhere else; the fallback is the path this notebook was written with.
rlkit_path = os.environ.get('RLKIT_PATH', '/home/victorialena/rlkit')
# Guarded so that re-running this cell does not pile duplicate entries onto sys.path.
if rlkit_path not in sys.path:
    sys.path.append(rlkit_path)

import rlkit

In [4]:
from any_replay_buffer import anyReplayBuffer
from policies import *

from rlkit.samplers.data_collector import MdpPathCollector
# from rlkit.torch.dqn.dqn import DQNTrainer
from dqn import DQNTrainer
from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

import pdb

In [5]:
# The two sizes passed to the CovidSEIR constructor.
n_nodes = 100
n_communities = 4

# Build the environment and keep whatever the first reset() returns as `x`.
env = CovidSEIR(n_nodes, n_communities)
x = env.reset()

In [6]:
# Network widths are taken from the enum sizes: the input width (feature space)
# is the number of NodeState members, the output width the number of Measure members.
in_channels, out_channels = len(NodeState), len(Measure)

In [7]:
# Pick the current CUDA device index when CUDA is available, else the string 'cpu'.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = 'cpu'

```python 
def format_input(x, num_classes=None):
    """One-hot encode int-convertible node states as a float32 tensor.

    Args:
        x: array-like whose elements can each be passed to ``int``.
        num_classes: size of the one-hot axis. Defaults to ``len(NodeState)``.

    Returns:
        A ``torch.float32`` tensor shaped like ``x`` with one extra trailing
        axis of size ``num_classes``.
    """
    if num_classes is None:
        num_classes = len(NodeState)
    # Pinning otypes yields int64 indices directly (no float32 round trip) and
    # lets np.vectorize accept an empty input instead of raising.
    indices = torch.as_tensor(np.vectorize(int, otypes=[np.int64])(x))
    # Reached through the existing `torch` import, so no separate `F` alias is needed.
    return torch.nn.functional.one_hot(indices, num_classes=num_classes).to(torch.float32)


def format_data(x):
    """Split a graph sample into the tensors the model consumes, on ``device``.

    Args:
        x: object exposing ``x``, ``edge_index`` and ``edge_attr`` attributes.

    Returns:
        Tuple ``(one_hot_node_features, edge_index, edge_attr)``, each moved to
        the notebook-level ``device``.
    """
    return (format_input(x.x).to(device),
            x.edge_index.to(device),
            torch.Tensor(x.edge_attr).to(device))
```

```python 
# Smoke test: build a model with 64 as the middle constructor argument
# (presumably the hidden width — TODO confirm against gnn_model) and run a
# single forward pass on the state returned by env.reset().
model = GraphGNNModel(in_channels, 64, out_channels)
# s, edge_index, edge_weight = format_data(x)
# The tuple from format_data is unpacked positionally; the environment's
# community labels are passed as the final argument.
out = model(*format_data(x), env.community_labels)
```

In [8]:
# Online Q-network and its target network, built with identical constructor arguments.
qf = GraphGNNModel(in_channels, 256, out_channels)
target_qf = GraphGNNModel(in_channels, 256, out_channels)
# Start the target network as an exact copy of the online network. Without this
# the two are initialised independently, so the first bootstrapped targets come
# from unrelated random weights.
# NOTE(review): assumes GraphGNNModel is an nn.Module — confirm against gnn_model.
target_qf.load_state_dict(qf.state_dict())

In [9]:
# Mean-squared-error loss; passed to the trainer below as `qf_criterion`.
qf_criterion = nn.MSELoss()
# Both policies wrap the same online network `qf` and receive the environment's
# community labels (positionally for the eval policy, via `args` for the
# exploration policy, which additionally takes the action space).
eval_policy = argmaxDiscretePolicy(qf, env.community_labels)
expl_policy = epsilonGreedyPolicy(qf, env.action_space, args=env.community_labels)

In [10]:
# One rollout collector per policy. Both are given the same `env` object, so
# exploration and evaluation rollouts share a single environment instance.
expl_path_collector = MdpPathCollector(env, expl_policy)
eval_path_collector = MdpPathCollector(env, eval_policy)

In [11]:
# Experiment configuration.
#   - "replay_buffer_size" is read when the replay buffer is created.
#   - "algorithm_kwargs" is unpacked into TorchBatchRLAlgorithm.
#   - "trainer_kwargs" is unpacked into DQNTrainer.
#   - "layer_size" is not read by any cell visible here (the networks pass 256 directly).
variant = {
    "algorithm": "DQN",
    "version": "normal",
    "layer_size": 256,
    "replay_buffer_size": 1_000_000,
    "algorithm_kwargs": {
        "num_epochs": 3000,
        "num_eval_steps_per_epoch": 5000,
        "num_trains_per_train_loop": 1000,
        "num_expl_steps_per_train_loop": 1000,
        "min_num_steps_before_training": 1000,
        "max_path_length": 1000,
        "batch_size": 256,
    },
    "trainer_kwargs": {
        "discount": 0.99,
        "learning_rate": 3e-4,
    },
}

```python
# Scaled-down configuration: same keys as the full-size `variant`, with the
# epoch count, step budgets and path length each one tenth as large.
variant = {
    "algorithm": "DQN",
    "version": "normal",
    "layer_size": 256,
    "replay_buffer_size": 1_000_000,
    "algorithm_kwargs": {
        "num_epochs": 300,
        "num_eval_steps_per_epoch": 500,
        "num_trains_per_train_loop": 100,
        "num_expl_steps_per_train_loop": 100,
        "min_num_steps_before_training": 100,
        "max_path_length": 100,
        "batch_size": 256,
    },
    "trainer_kwargs": {
        "discount": 0.99,
        "learning_rate": 3e-4,
    },
}
```

In [12]:
# Build the trainer from the project-local `dqn` module (the rlkit DQNTrainer
# import is commented out above). It receives the online and target networks,
# the loss, the environment's community labels via `args`, and the discount and
# learning rate from variant['trainer_kwargs'].
trainer = DQNTrainer(
    qf=qf,
    target_qf=target_qf,
    qf_criterion=qf_criterion,
    args=env.community_labels,
    **variant['trainer_kwargs'])

In [13]:
# Project-local replay buffer, constructed with variant['replay_buffer_size']
# as its only argument (presumably the maximum capacity — TODO confirm).
replay_buffer = anyReplayBuffer(variant['replay_buffer_size'])

In [14]:
# Assemble the batch RL loop from the pieces built above. The same `env` object
# is passed as both the exploration and the evaluation environment; epoch
# counts, step budgets and batch size come from variant['algorithm_kwargs'].
algorithm = TorchBatchRLAlgorithm(
    trainer=trainer,
    exploration_env=env,
    evaluation_env=env,
    exploration_data_collector=expl_path_collector,
    evaluation_data_collector=eval_path_collector,
    replay_buffer=replay_buffer,
    **variant['algorithm_kwargs']
)

In [None]:
# Move the algorithm to the selected device, then start training.
# NOTE(review): the captured output below reports roughly 1050 s per epoch
# (most of it under time/training), so the configured 3000 epochs amounts to
# about five weeks of wall-clock time.
algorithm.to(device)
algorithm.train()

evaluation sampling
exploration sampling
data storing
qf loss: 43104.99
training
2022-02-14 16:18:20.057202 PST | Epoch 0 finished
-----------------------------  ----------------
epoch                               0
replay_buffer/size               2000
trainer/QF Loss                 43105
trainer/Y Predictions Mean          1.34584
trainer/Y Predictions Std           0.0168273
trainer/Y Predictions Max           1.39327
trainer/Y Predictions Min           1.29518
expl/num steps total             2000
expl/num paths total                2
expl/path length Mean            1000
expl/path length Std                0
expl/path length Max             1000
expl/path length Min             1000
expl/Rewards Mean                 201.89
expl/Rewards Std                   57.5488
expl/Rewards Max                  300
expl/Rewards Min                    0
expl/Returns Mean              201890
expl/Returns Std                    0
expl/Returns Max               201890
expl/Returns Min           

expl/Returns Mean              293273
expl/Returns Std                    0
expl/Returns Max               293273
expl/Returns Min               293273
expl/Actions Mean                   1.96
expl/Actions Std                    0.251794
expl/Actions Max                    2
expl/Actions Min                    0
expl/Num Paths                      1
expl/Average Returns           293273
eval/num steps total            20000
eval/num paths total               20
eval/path length Mean            1000
eval/path length Std                0
eval/path length Max             1000
eval/path length Min             1000
eval/Rewards Mean                 300
eval/Rewards Std                    0
eval/Rewards Max                  300
eval/Rewards Min                  300
eval/Returns Mean              300000
eval/Returns Std                    0
eval/Returns Max               300000
eval/Returns Min               300000
eval/Actions Mean                   2
eval/Actions Std                    0
ev

eval/Returns Min               300000
eval/Actions Mean                   2
eval/Actions Std                    0
eval/Actions Max                    2
eval/Actions Min                    2
eval/Num Paths                      5
eval/Average Returns           300000
time/data storing (s)               0.000706581
time/evaluation sampling (s)      165.479
time/exploration sampling (s)      31.8539
time/logging (s)                    0.0110178
time/saving (s)                     0.000123004
time/training (s)                 871.256
time/epoch (s)                   1068.6
time/total (s)                   7389.46
Epoch                               6
-----------------------------  ----------------
evaluation sampling
exploration sampling
data storing
qf loss: 79195.625
training
2022-02-14 18:21:17.333604 PST | Epoch 7 finished
-----------------------------  ----------------
epoch                               7
replay_buffer/size               9000
trainer/QF Loss                 79195.6
tr

epoch                              10
replay_buffer/size              12000
trainer/QF Loss                 81618.7
trainer/Y Predictions Mean          3.66016
trainer/Y Predictions Std           0.891438
trainer/Y Predictions Max           4
trainer/Y Predictions Min           4.74268e-15
expl/num steps total            12000
expl/num paths total               12
expl/path length Mean            1000
expl/path length Std                0
expl/path length Max             1000
expl/path length Min             1000
expl/Rewards Mean                 291.756
expl/Rewards Std                   38.7404
expl/Rewards Max                  300
expl/Rewards Min                   19
expl/Returns Mean              291756
expl/Returns Std                    0
expl/Returns Max               291756
expl/Returns Min               291756
expl/Actions Mean                   1.9505
expl/Actions Std                    0.284692
expl/Actions Max                    2
expl/Actions Min                    0
expl

expl/Actions Std                    0.284949
expl/Actions Max                    2
expl/Actions Min                    0
expl/Num Paths                      1
expl/Average Returns           292248
eval/num steps total            70000
eval/num paths total               70
eval/path length Mean            1000
eval/path length Std                0
eval/path length Max             1000
eval/path length Min             1000
eval/Rewards Mean                 300
eval/Rewards Std                    0
eval/Rewards Max                  300
eval/Rewards Min                  300
eval/Returns Mean              300000
eval/Returns Std                    0
eval/Returns Max               300000
eval/Returns Min               300000
eval/Actions Mean                   2
eval/Actions Std                    0
eval/Actions Max                    2
eval/Actions Min                    2
eval/Num Paths                      5
eval/Average Returns           300000
time/data storing (s)               0.00067

eval/Num Paths                      5
eval/Average Returns           300000
time/data storing (s)               0.000679171
time/evaluation sampling (s)      165.907
time/exploration sampling (s)      33.8953
time/logging (s)                    0.0114579
time/saving (s)                     0.000121868
time/training (s)                 848.094
time/epoch (s)                   1047.91
time/total (s)                  17900.9
Epoch                              16
-----------------------------  ----------------
evaluation sampling
exploration sampling
data storing
qf loss: 80993.17
training
2022-02-14 21:16:21.360983 PST | Epoch 17 finished
-----------------------------  ----------------
epoch                              17
replay_buffer/size              19000
trainer/QF Loss                 80993.2
trainer/Y Predictions Mean          3.65234
trainer/Y Predictions Std           0.833991
trainer/Y Predictions Max           4
trainer/Y Predictions Min           6.24706e-17
expl/num steps to