In [1]:
import gymnasium as gym
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
from stable_baselines3 import DDPG, DQN, PPO
from stable_baselines3.common.callbacks import ProgressBarCallback
# from sbx import DQN
import pde_opt
from pde_opt.numerics.domains import Domain


In [2]:
def reset_func(domain, seed=0):
    return 0.5 * jnp.ones(domain.points) + 0.01 * random.normal(random.PRNGKey(seed), domain.points)

In [3]:
Nx, Ny = 64, 64
Lx = 0.02 * Nx
Ly = 0.02 * Ny
domain = Domain((Nx, Ny), ((-Lx / 2, Lx / 2), (-Ly / 2, Ly / 2)), "dimensionless")

In [4]:
params = {
    "reset_func": reset_func,
    "diffusion_coefficient": 0.1,
    "max_control_step": 0.1,
    "end_time": 1.0,
    "step_dt": 0.05,
    "numeric_dt": 0.0001,
    "domain": domain,
    "field_dim": 1,
    "reward_function": lambda x: np.linalg.norm(x[32,32]),
    "discrete_action_space": False
}

In [5]:
env = gym.make('AdvectionDiffusion-v0', **params)

In [6]:
# model = DDPG("CnnPolicy", env, verbose=1, learning_starts=2, train_freq=1)
model = PPO("CnnPolicy", env, verbose=1)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
model.learn(total_timesteps=1, callback=ProgressBarCallback())

Output()

KeyboardInterrupt: 

In [25]:
model.save("model")

In [26]:
# model = DQN.load("../model_1000")

In [27]:
# jax.clear_caches()

In [28]:
obs, info = env.reset()
observations = []
for i in range(30):
    print(f"i: {i}", end=", ")
    action, _states = model.predict(obs, deterministic=True)
    # obs, reward, terminated, truncated, info = env.step(int(action))
    obs, reward, terminated, truncated, info = env.step(action)
    observations.append(obs)

i: 0, i: 1, i: 2, i: 3, i: 4, i: 5, i: 6, i: 7, i: 8, i: 9, i: 10, i: 11, i: 12, i: 13, i: 14, i: 15, i: 16, i: 17, i: 18, i: 19, i: 20, i: 21, i: 22, i: 23, i: 24, i: 25, i: 26, i: 27, i: 28, i: 29, 

In [29]:
from IPython.display import HTML
import matplotlib.animation as animation

fig, ax = plt.subplots(figsize=(4,4))

ims = []
for i in range(0, len(observations), 2):
    im = ax.imshow(observations[i][0], animated=True, 
                   vmin=0.0, vmax=255,
                   extent=[domain.box[0][0], domain.box[0][1], 
                          domain.box[1][0], domain.box[1][1]])
    ims.append([im])

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True)

plt.title('Cahn-Hilliard Evolution')
plt.xlabel('x')
plt.ylabel('y')

plt.close()

HTML(ani.to_jshtml())