In [1]:
import gymnasium as gym
import numpy as np
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
# from stable_baselines3 import DDPG
from sbx import DQN
import pde_opt
from pde_opt.numerics.domains import Domain


In [2]:
def reset_func(domain, seed=0):
    """Build the initial field: a uniform 0.5 background plus small Gaussian noise.

    Args:
        domain: Object exposing a ``points`` attribute, which is used directly
            as the shape of the returned array.
        seed: Integer used to create the JAX PRNG key; the same seed always
            yields the same field.

    Returns:
        A JAX array of shape ``domain.points`` equal to ``0.5`` everywhere
        plus standard-normal noise scaled by ``0.01``.
    """
    shape = domain.points
    key = random.PRNGKey(seed)
    noise = random.normal(key, shape)
    background = jnp.ones(shape)
    return 0.5 * background + 0.01 * noise

In [10]:
# Square grid: 128 points along each axis.
Nx = Ny = 128
# Side lengths scale with the point count (0.01 length units per grid point).
Lx, Ly = 0.01 * Nx, 0.01 * Ny
# The second argument is presumably the per-axis (lower, upper) bounds; they
# are symmetric about zero. TODO confirm against the Domain signature.
domain = Domain(
    (Nx, Ny),
    ((-Lx / 2, Lx / 2), (-Ly / 2, Ly / 2)),
    "dimensionless",
)

In [11]:
# Keyword arguments forwarded to the environment constructor via gym.make.
# The meaning of each entry is defined by the environment, not by this cell;
# the notes below are read off the names only. TODO confirm units/semantics.
params = dict(
    reset_func=reset_func,            # callable that builds the initial field
    diffusion_coefficient=0.1,
    max_control_step=0.1,
    end_time=1.0,
    step_dt=0.05,                     # presumably the time span of one env step
    numeric_dt=0.0001,                # presumably the inner solver time step
    domain=domain,
    field_dim=1,
    reward_function=lambda x: np.var(x),  # reward is the variance of its argument
    discrete_action_space=True,
)

In [12]:
# Create the environment with id 'AdvectionDiffusion-v0', passing every entry
# of `params` as a keyword argument.
# NOTE(review): the id is presumably registered as a side effect of
# `import pde_opt` — confirm, since nothing in this notebook registers it.
env = gym.make('AdvectionDiffusion-v0', **params)

In [13]:
# DQN agent using the "CnnPolicy" policy on `env`; verbose=1 enables the
# setup/info messages printed when the model is constructed.
model = DQN(policy="CnnPolicy", env=env, verbose=1)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [14]:
# Train the agent for 10 environment timesteps, showing a progress bar.
# The call is the last expression of the cell, so its return value (the model
# object) is echoed as the cell output.
# NOTE(review): 10 timesteps reads as a smoke test rather than a real training
# budget — confirm the intended value before relying on the rollout below.
model.learn(total_timesteps=10, progress_bar=True)

Output()

  self.pbar = tqdm(total=self.locals["total_timesteps"] - self.model.num_timesteps)


<sbx.dqn.dqn.DQN at 0x147c6840e230>

In [None]:

# Roll the policy out for 1000 steps on the environment held by the model.
# get_env() returns the model's own (wrapped, vectorized) environment rather
# than the raw `env` created earlier.
vec_env = model.get_env()
obs = vec_env.reset()
for _ in range(1000):
    # NOTE(review): no render mode was passed when the env was created;
    # confirm that render() actually displays anything here.
    vec_env.render()
    # deterministic=True requests the policy's deterministic action; the
    # second return value is not used in this cell.
    action, _states = model.predict(obs, deterministic=True)
    # reward, done and info are overwritten every step and never inspected.
    # NOTE(review): `done` is not checked, so this loop relies on the
    # vectorized env resetting finished episodes on its own — confirm.
    obs, reward, done, info = vec_env.step(action)

# Release the environment's resources once the rollout is finished.
vec_env.close()