In [1]:
from cell_env import CellEnv
# Use sb3 env checker:
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN

from stable_baselines3.common.vec_env import DummyVecEnv
from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import EvalCallback


In [2]:

# Keyword arguments shared by every CellEnv built in this notebook, so the
# training and evaluation environments are configured identically.
env_args = dict(
    max_timesteps=200,
    alpha_mem=1.0,
    dt=0.5,
    frame_stack=2,
)


In [3]:
# Build the training environment and validate it against the SB3/Gymnasium
# interface before wrapping it.
env = CellEnv(**env_args)
check_env(env)

# Wrap with Monitor so episode statistics are recorded, plus the "c" entry
# of the step info dict.
# NOTE: info_keywords must be a tuple of keys. ("c") is just the string "c";
# it only appeared to work because iterating a one-character string yields
# that character. The trailing comma makes it a real 1-tuple.
env = Monitor(env, info_keywords=("c",))

# Separate, identically configured environment used only for evaluation.
eval_env = CellEnv(**env_args)
eval_env = Monitor(eval_env, info_keywords=("c",))

In [4]:

# Periodic evaluation on the held-out eval_env: every 5_000 steps, run
# 10 deterministic episodes, log to ./rl-logs/ and save the best model.
eval_callback = EvalCallback(
    eval_env,
    n_eval_episodes=10,
    eval_freq=5_000,
    deterministic=True,
    render=False,
    best_model_save_path='./rl-models-sde/',
    log_path='./rl-logs/',
)

# DQN agent with an MLP policy, trained on the monitored env wrapped in a
# single-environment DummyVecEnv.
model = DQN(
    "MlpPolicy",
    DummyVecEnv([lambda: env]),
    # optimisation
    learning_rate=0.0015,
    batch_size=16,
    gradient_steps=2,
    # replay buffer / warm-up
    buffer_size=10_000,
    learning_starts=1000,
    # exploration and target network
    exploration_fraction=0.2,
    target_update_interval=2000,
    # logging
    verbose=4,
    tensorboard_log="./rl-logs/",
)

# Train for 100k environment steps; TensorBoard runs are named "dqn_<n>".
model.learn(
    total_timesteps=100_000,
    callback=eval_callback,
    tb_log_name="dqn",
)


Using cuda device
Logging to ./rl-logs/dqn_18
{'c': 1086.4190084455197, 'res_fraction': array([0.        , 0.00157254, 0.00472786, ..., 0.57916521, 0.58359159,
       0.58799472])}
c
{'c': 3591.6911758813603, 'res_fraction': array([0.        , 0.00157254, 0.00472786, ..., 0.41678588, 0.42172022,
       0.42664903])}
c
{'c': 3430.000765973453, 'res_fraction': array([0.        , 0.        , 0.        , ..., 0.73334284, 0.73666641,
       0.73995957])}
c
{'c': 4377.321175409317, 'res_fraction': array([0.        , 0.00157254, 0.00472786, ..., 0.37833986, 0.38329585,
       0.38825117])}
c
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 200      |
|    ep_rew_mean      | -361     |
|    exploration_rate | 0.962    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 34       |
|    time_elapsed     | 23       |
|    total_timesteps  | 800      |
----------------------------------
{'c': 2535.227906499760