In [4]:
import gymnasium as gym
import simple_env
from simple_env.wrappers import RelativePosition, NormalizedObservation
from gymnasium.wrappers import TimeAwareObservation
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, BaseCallback
import shutil
from typing import Callable

# --- Experiment configuration -------------------------------------------
name = '01_gridworld'   # experiment id; also names the output directory
grid_size = 10          # forwarded as `size` when the environment is built

policy = 'MlpPolicy'    # policy identifier handed to the DQN constructor
device = 'cpu'          # device used when creating / loading the model

# Every artefact lives below ./<name>/ ; the sub-paths are derived from it.
# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# other cells of this notebook read it.
dir = f"./{name}"
best_model_save_path = f"{dir}/model/"
model_path = f"{best_model_save_path}best_model.zip"
tensorboard_log = f"{dir}/t_logs/"
log_path = f"{dir}/logs/"

def make_env(render_mode=None, max_episode_steps=10):
    """Build the GridWorld environment used throughout this notebook.

    Args:
        render_mode: Forwarded unchanged to ``gym.make`` (``None`` by
            default; the interactive cells pass ``'human'``).
        max_episode_steps: Episode step limit forwarded to ``gym.make``.
            Defaults to 10, the value that used to be hard-coded here.

    Returns:
        The ``simple_env/GridWorld-v0`` environment of side ``grid_size``,
        wrapped in ``TimeAwareObservation`` with ``normalize_time=True``.
    """
    env = gym.make(
        "simple_env/GridWorld-v0",
        render_mode=render_mode,
        size=grid_size,
        max_episode_steps=max_episode_steps,
    )

    return TimeAwareObservation(env, normalize_time=True)



# Test

In [None]:
# Smoke test: play one episode with uniformly sampled actions and print every
# transition.
env = make_env(render_mode='human')
try:
    obs, info = env.reset()

    terminated = False
    truncated = False
    while not (terminated or truncated):
        action = env.action_space.sample()
        obs, rew, terminated, truncated, info = env.step(action)
        print(f"Obs: {obs}, rew: {rew}")
finally:
    # Release the environment even when a step raises or the cell is
    # interrupted, so the render window is not left open.
    env.close()

# Create agent



In [5]:
# Execute Setup
# Reads the notebook file from disk and executes every code cell tagged
# "setup" in the current namespace. Note that the version saved on disk is
# what runs, not unsaved edits in the editor.
import nbformat
from IPython import get_ipython
with open("01_gridworld.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    # Only code cells are executable; a tagged markdown/raw cell would make
    # exec() fail with a SyntaxError.
    if cell.cell_type == "code" and "setup" in cell.metadata.get("tags", []):
        exec(cell.source)

# Linear Schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that scales ``initial_value`` by the remaining progress.

    Args:
        initial_value: Value the schedule yields when the remaining progress
            is 1.0.

    Returns:
        A function mapping ``progress_remaining`` to
        ``progress_remaining * initial_value``.
    """
    def _scaled(progress_remaining: float) -> float:
        """Value of the schedule at the given remaining progress."""
        current_value = progress_remaining * initial_value
        return current_value

    return _scaled

# Build a fresh DQN agent on a non-rendering environment.
env = make_env()
lr_schedule = linear_schedule(0.0001)  # learning rate scaled by remaining progress
model = DQN(
    policy,
    env,
    verbose=0,
    device=device,
    tensorboard_log=tensorboard_log,
    exploration_fraction=0.5,
    learning_rate=lr_schedule,
)

# Start from a clean experiment directory, then store the untrained model.
shutil.rmtree(dir, ignore_errors=True)  # a missing directory is not an error
model.save(model_path)

# Train

In [6]:
# Execute Setup
# Reads the notebook file from disk and executes every code cell tagged
# "setup" in the current namespace. Note that the version saved on disk is
# what runs, not unsaved edits in the editor.
import nbformat
from IPython import get_ipython
with open("01_gridworld.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    # Only code cells are executable; a tagged markdown/raw cell would make
    # exec() fail with a SyntaxError.
    if cell.cell_type == "code" and "setup" in cell.metadata.get("tags", []):
        exec(cell.source)
        
# Env and model: one environment for training, a separate one for evaluation.
train_env = make_env()
eval_env = make_env()

try:
    # Resume from the model stored at model_path, attached to the training env.
    model = DQN.load(model_path, env=train_env, device=device)

    # Callbacks: evaluate and checkpoint every 10 000 callback calls. Both
    # write into best_model_save_path. Frequencies are ints, as documented.
    eval_callback = EvalCallback(
        eval_env,
        eval_freq=10_000,
        deterministic=True,
        n_eval_episodes=10,
        best_model_save_path=best_model_save_path,
    )

    checkpoint_callback = CheckpointCallback(
        save_freq=10_000,
        save_path=best_model_save_path,
        name_prefix="checkpoint",
    )

    # Training: continue the timestep counter of the loaded model.
    model.learn(
        total_timesteps=100_000,
        progress_bar=True,
        reset_num_timesteps=False,
        callback=[
            eval_callback,
            checkpoint_callback,
        ],
    )
finally:
    # Close both environments, also when training fails or is interrupted.
    train_env.close()
    eval_env.close()

Output()

# Display

In [32]:
# Execute Setup
# Reads the notebook file from disk and executes every code cell tagged
# "setup" in the current namespace. Note that the version saved on disk is
# what runs, not unsaved edits in the editor.
import nbformat
from IPython import get_ipython
with open("01_gridworld.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    # Only code cells are executable; a tagged markdown/raw cell would make
    # exec() fail with a SyntaxError.
    if cell.cell_type == "code" and "setup" in cell.metadata.get("tags", []):
        exec(cell.source)

# Load the policy first, so a missing checkpoint fails before a render window
# is opened. rstrip avoids a doubled "/" since the directory ends with one.
checkpoint_path = f"{best_model_save_path.rstrip('/')}/checkpoint_80000_steps.zip"
model = DQN.load(checkpoint_path, device=device)

env = make_env(render_mode='human')
try:
    obs, info = env.reset()

    terminated = False
    truncated = False
    while not (terminated or truncated):
        # deterministic=True matches the evaluation callback; without it
        # predict() may return exploratory (random) actions.
        action, _ = model.predict(obs, deterministic=True)
        action = int(action)
        obs, rew, terminated, truncated, info = env.step(action)
        print(f"Action: {action}, Obs: {obs}, Rew: {rew}")
finally:
    # Release the render window even when the rollout fails or is interrupted.
    env.close()

Action: 2, Obs: [-0.6 -0.4  0.1], Rew: 0.1
Action: 2, Obs: [-0.5 -0.4  0.2], Rew: 0.1
Action: 2, Obs: [-0.4 -0.4  0.3], Rew: 0.1
Action: 2, Obs: [-0.3 -0.4  0.4], Rew: 0.1
Action: 3, Obs: [-0.3 -0.3  0.5], Rew: 0.1
Action: 2, Obs: [-0.2 -0.3  0.6], Rew: 0.1
Action: 3, Obs: [-0.2 -0.2  0.7], Rew: 0.1
Action: 2, Obs: [-0.1 -0.2  0.8], Rew: 0.1
Action: 3, Obs: [-0.1 -0.1  0.9], Rew: 0.1
Action: 2, Obs: [ 0.  -0.1  1. ], Rew: 0.1
