In [None]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from config import RLConfig, SwingConfig
from environment.grid import Grid
from environment.swing.solver import swing_solver
from environment.trajectory import is_stable


In [None]:
def find_steady_state(
    grid: Grid, rng: np.random.Generator
) -> list[npt.NDArray[np.float32]]:
    """Integrate the swing equation from rest and record the dphase trajectory.

    Every node starts at phase = 0 and dphase = 0. The swing solver is then
    stepped (step size ``SwingConfig._dt``) until ``is_stable(dphase)`` holds
    or ``RLConfig.steady_time`` worth of steps have been taken.

    Args
        grid: Grid to simulate; its ``num_nodes``, ``weighted_adjacency_matrix``
            and ``params`` are used.
        rng: Currently unused (random-phase initialisation is disabled below);
            kept so existing callers keep working.
    Return
        dphases: list of [N, ] dphase (angular frequency) arrays, one per solver
            step. Entry k is the state after k + 1 steps, i.e. t = (k + 1) * dt.
            The last entry is the first stable state, or the final state if
            stability was never reached within ``steady_time``.
    """
    dt = SwingConfig._dt
    # Integer step budget instead of accumulating `time += dt`: repeated float
    # addition can drift just below steady_time and run one extra step, making
    # the trajectory longer than np.arange(0, steady_time, dt). This count
    # equals that arange's length, ceil(steady_time / dt).
    max_steps = int(np.ceil(RLConfig.steady_time / dt))

    # Deterministic initial state at rest. To start from a random phase instead:
    # phase = rng.uniform(-np.pi, np.pi, grid.num_nodes).astype(np.float32)
    phase = np.zeros(grid.num_nodes, dtype=np.float32)
    dphase = np.zeros_like(phase)

    dphases: list[npt.NDArray[np.float32]] = []
    solver = swing_solver(grid.weighted_adjacency_matrix, grid.params)
    for _ in range(max_steps):
        phase, dphase = solver(phase=phase, dphase=dphase)
        dphases.append(dphase)
        if is_stable(dphase):
            break

    return dphases

# Seed handed to Grid; presumably fixes its random generator (grid.rng, used
# below) so repeated Restart & Run All gives the same grid — TODO confirm.
GRID_SEED = 0
grid = Grid(rng=GRID_SEED)


In [None]:
# Uncomment to resample the grid before searching for a steady state.
# grid.reset_nodes()
# grid.reset_node_types()
# grid.reset_graph()

grid.info()

# [T, N]: row k holds every node's dphase after k + 1 solver steps.
dphases = np.stack(find_steady_state(grid, grid.rng))
# Nominal full-length time grid over [0, steady_time).
time = np.arange(0.0, RLConfig.steady_time, SwingConfig._dt)
# Time axis matched to the recorded trajectory: row k sits at t = (k + 1) * dt.
# Deriving it from len(dphases) guarantees x and y have equal length, even if
# the integration ran one step more than `time` has entries.
t_traj = SwingConfig._dt * np.arange(1, len(dphases) + 1)

# The loop also ends when steady_time is exhausted, so check before claiming steady.
status = "Steady" if is_stable(dphases[-1]) else "Not steady"
print(
    f"{status} at t={len(dphases) * SwingConfig._dt}, max_dphase={np.abs(dphases[-1]).max()}"
)

fig, ax = plt.subplots()
for node, dphase in enumerate(dphases.T):
    ax.plot(t_traj, dphase, label=node)
ax.set(
    xlim=(0, RLConfig.steady_time),
    xlabel="time",
    ylabel="dphase (angular frequency)",
    title="dphase trajectory until steady state",
)
ax.legend()
plt.show()