In [1]:
# DRQN experiments — this notebook trains a stable-baselines3 PPO agent on SwingEnv

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import numpy.typing as npt
import torch
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.ppo.ppo import PPO

from agent import GraphExtractor
from graph import get_ba
from graph.utils import get_edge_list, directed2undirected
from swing_data import SwingData
from swing_env import SwingEnv

In [2]:
# --- Configuration ---------------------------------------------------------
num_nodes = 20
mean_degree = 4.0
# Seeded RNG; drives the initial phases and the power shuffle below.
random_engine = np.random.default_rng(42)
precision = np.float32

# The power vector splits nodes into two equal halves (+1 / -1). With an odd
# node count the original construction silently produced num_nodes - 1 entries,
# mismatching `phase`, `gamma` and `mass`, so fail loudly instead.
assert num_nodes % 2 == 0, "num_nodes must be even to split nodes into equal +1/-1 power halves"

# Network
# NOTE(review): get_ba is not passed `random_engine`, so the graph topology may
# not be reproducible across runs — confirm whether get_ba accepts a seed/rng.
g = get_ba(num_nodes, mean_degree)
num_edges = g.number_of_edges()
edge_list = get_edge_list(g)
# All-ones edge weights; used as `coupling` below and as GNN edge features later.
weights = np.ones(num_edges, dtype=precision)


# Swing parameters
# Initial phases drawn uniformly from [0, 2*pi).
phase = random_engine.uniform(0, 2 * np.pi, num_nodes).astype(precision, copy=False)
# Half the nodes get power +1 and half get -1 (sums to zero), in random order.
power = np.array([1] * (num_nodes // 2) + [-1] * (num_nodes // 2), dtype=precision)
random_engine.shuffle(power)
gamma = np.ones(num_nodes, dtype=precision)
mass = np.ones(num_nodes, dtype=precision)
dt = 0.001  # time step handed to SwingEnv (presumably the integration step)

swing_data = SwingData(
    edge_list=edge_list,
    phase=phase,
    dphase=np.zeros_like(phase),  # zero initial dphase for every node
    coupling=weights,
    power=power,
    gamma=gamma,
    mass=mass,
)

In [3]:
# RL environment
# NOTE(review): the meaning of `equilibrium_step` is defined inside SwingEnv;
# the value 123 is kept from the original — document why it was chosen.
equilibrium_step = 123
swing_env = SwingEnv(swing_data, dt, equilibrium_step=equilibrium_step)
# Stable-Baselines3 environment checker: warns/raises if the env does not follow
# the API SB3 expects (observation/action spaces, reset/step signatures).
check_env(swing_env)


In [4]:
# torch_geometric
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Graph tensors consumed by GraphExtractor. The weights are duplicated because
# directed2undirected presumably emits both directions of every edge.
# NOTE(review): assumes directed2undirected returns the original edges followed
# by their reverses, in the same order as `weights` — confirm.
edge_index = directed2undirected(edge_list, device)  # presumably (2, 2 * num_edges)
edge_attr = torch.tensor(
    np.concatenate([weights, weights]), device=device
).unsqueeze(-1)  # (2 * num_edges, 1)

n_steps = 2048  # environment steps collected per PPO update
n_iterations = 5  # number of rollout/update cycles to train for

policy_kwargs = {
    "features_extractor_class": GraphExtractor,
    "features_extractor_kwargs": {
        "edge_index": edge_index,
        "edge_attr": edge_attr,
        "hidden_dim": 16,
    },
}
model = PPO(
    "MultiInputPolicy",
    swing_env,
    policy_kwargs=policy_kwargs,
    verbose=2,
    device=device,
    n_steps=n_steps,
    seed=42,  # reproducible network init / rollouts; matches the numpy RNG seed
)

model.learn(total_timesteps=n_steps * n_iterations)

Using cuda:0 device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 7.87     |
|    ep_rew_mean     | 9.54     |
| time/              |          |
|    fps             | 252      |
|    iterations      | 1        |
|    time_elapsed    | 8        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 7.88        |
|    ep_rew_mean          | 9.63        |
| time/                   |             |
|    fps                  | 210         |
|    iterations           | 2           |
|    time_elapsed         | 19          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.014613959 |
|    clip_fraction        | 0.16        |
|    clip_range           | 0.2         |
|    entropy_loss

<stable_baselines3.ppo.ppo.PPO at 0x7f1b56e7bd30>

In [None]:
# TODO: evaluate the trained policy, e.g. `action, _ = model.predict(obs, deterministic=True)`