In [1]:
from typing import cast

import numpy as np
import torch
from agent.graph_extractor import GraphExtractor
from environment.environment import Environment
from environment.grid import Grid
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from torch_geometric.nn import summary

In [2]:
# Configuration: everything a reader might tune lives here.
SEED = 0
N_STEPS = 128  # PPO rollout length (`n_steps`) collected before each policy update
RUN_ENV_CHECK = False  # set True to run SB3's `check_env` on the environment before training

# Fall back to CPU so "Restart Kernel -> Run All" also works on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
rng = np.random.default_rng(SEED)
# The numpy Generator above does not cover torch; seed torch's global RNG as well
# so torch-side randomness in this notebook is reproducible.
torch.manual_seed(SEED)

# Grid and Environment share the same Generator so the whole simulation derives from SEED.
grid = Grid(rng=rng)
env = Environment(grid, rng=rng, verbose=2)
if RUN_ENV_CHECK:
    check_env(env)

policy_kwargs = {"features_extractor_class": GraphExtractor}
model = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=2,
    device=device,
    n_steps=N_STEPS,
)

# Sanity-check the custom feature extractor on a single observation: each observation
# entry is converted to a float32 tensor on `device` and given a leading batch dim of 1.
feature_extractor = cast(GraphExtractor, model.policy.features_extractor)
obs = {
    key: torch.tensor(value, dtype=torch.float32, device=device).unsqueeze(0)
    for key, value in env.observe().items()
}
print(summary(feature_extractor, obs))


Initialization: Found steady state after 1 trials
Initialization: perturbation will be [ 0  0  0  0  0  0  0  0  0 -1]
Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
+-------------------------+-------------------+----------------+----------+
| Layer                   | Input Shape       | Output Shape   | #Param   |
|-------------------------+-------------------+----------------+----------|
| GraphExtractor          |                   | [1, 640]       | 9,001    |
| ├─(node_type)Embedding  | [1, 10]           | [1, 10, 8]     | 32       |
| ├─(phase)MLP            | [1, 10, 2]        | [1, 10, 8]     | 96       |
| │    └─(mlp)Sequential  | [1, 10, 2]        | [1, 10, 8]     | 96       |
| │    │    └─(0)Linear   | [1, 10, 2]        | [1, 10, 8]     | 24       |
| │    │    └─(1)Identity | [1, 10, 8]        | [1, 10, 8]     | --       |
| │    │    └─(2)ELU      | [1, 10, 8]        | [1, 10, 8]     | --       |
| │    │    └─(3)Identit

In [3]:
# Train for a fixed number of PPO rollouts. Deriving the budget from `model.n_steps`
# keeps it consistent with the rollout length the model was configured with.
# The trailing `;` suppresses the repr of the PPO object that `learn` returns.
N_ROLLOUTS = 10
model.learn(total_timesteps=model.n_steps * N_ROLLOUTS);


Reset: Found steady state after 2 trials
Reset: perturbation will be [ 0  0  0  0  0  0  0  0  0 -1]
Step: after pre-processing, action=[0.   0.44 0.   0.   0.89 1.   0.7  0.85 0.   0.  ]
Step: successfully finished rebalancing, reward=-2.75e-02
Step: next perturbation will be [0 0 0 0 0 0 0 0 0 1]
Step: after pre-processing, action=[0.   0.48 0.   0.   0.35 1.   1.   0.27 0.   0.  ]
Step: No steady state after rebalancing. reward=-1.00e+00
Step: next perturbation will be [0 0 0 0 0 0 0 0 0 1]
Reset: Found steady state after 1 trials
Reset: perturbation will be [ 0  0  0  0  0  0  0  0  0 -1]
Step: after pre-processing, action=[0.   0.68 0.   0.   0.21 0.55 0.16 0.81 0.   0.  ]
Step: No steady state after rebalancing. reward=-1.00e+00
Step: next perturbation will be [0 0 0 0 0 0 0 0 0 1]
Reset: Found steady state after 2 trials
Reset: perturbation will be [1 0 0 0 0 0 0 0 0 0]
Step: after pre-processing, action=[0.   0.44 0.   0.   0.13 0.98 0.73 0.22 0.   0.  ]
Step: successfully fini

<stable_baselines3.ppo.ppo.PPO at 0x7f9e935b73a0>