In [None]:
import gymnasium as gym
import simple_env
from simple_env.wrappers import RelativePosition, NormalizedObservation
from gymnasium.wrappers import TimeAwareObservation
from stable_baselines3 import DQN, A2C, PPO
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
import shutil
from typing import Callable

# Experiment identity and environment settings
name = '01_gridworld'
env_name = "simple_env/ContinuousWorld-v1"
world_size = 300

# Algorithm choice (one of DQN / A2C / PPO) and where it runs
agent = PPO
policy = 'MlpPolicy'
device = 'cpu'

# Output locations, all nested under the run directory ./<name>/
# NOTE: `dir` shadows the builtin dir(); kept because later cells reference it.
dir = f"./{name}"
tensorboard_log = f"{dir}/t_logs/"
log_path = f"{dir}/logs/"
# Directory that receives best_model.zip and the periodic checkpoints
best_model_save_path = f"{dir}/model/"
model_path = f"{best_model_save_path}best_model.zip"


def make_env(render_mode=None, max_episode_steps: int = 50):
    """Build one instance of the configured environment.

    Args:
        render_mode: forwarded to ``gym.make``; e.g. ``'human'`` to open a
            viewer window, ``None`` for headless training/evaluation.
        max_episode_steps: episode length limit passed to ``gym.make``.
            Defaults to 50, the value that used to be hard-coded here.

    Returns:
        The ``env_name`` environment of size ``world_size``, wrapped in
        ``TimeAwareObservation`` so every observation also carries the elapsed
        episode time (normalised because ``normalize_time=True``).
    """
    env = gym.make(
        env_name,
        render_mode=render_mode,
        size=world_size,
        max_episode_steps=max_episode_steps,
    )

    env = TimeAwareObservation(env, normalize_time=True)
    return env

    

# Test

In [None]:
# Execute Setup
# Re-run the notebook cell(s) tagged "setup" so this cell also works on a fresh kernel.
import nbformat
from IPython import get_ipython
with open("01_gridworld.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    if "tags" in cell.metadata and "setup" in cell.metadata.tags:
        exec(cell.source)

# Smoke-test the environment: play one episode with uniformly random actions
# in a rendered window and print every transition.
env = make_env(render_mode='human')
try:
    obs, info = env.reset()

    terminated = False
    truncated = False
    while not terminated and not truncated:
        action = env.action_space.sample()
        obs, rew, terminated, truncated, info = env.step(action)
        print(f"Action: {action}, Obs: {obs}, rew: {rew}")
finally:
    # Close the render window even if the episode is interrupted
    # (e.g. the kernel is interrupted while watching) or raises.
    env.close()

# Create agent



In [1]:
# Execute Setup
# Re-run every cell of this notebook tagged "setup" (imports, config, make_env)
# so this cell works on a fresh kernel without running the cells above it.
import nbformat
from IPython import get_ipython

with open("01_gridworld.ipynb", "r", encoding="utf-8") as nb_file:
    notebook = nbformat.read(nb_file, as_version=4)

for cell in notebook.cells:
    cell_tags = cell.metadata.get("tags", [])
    if "setup" not in cell_tags:
        continue
    exec(cell.source)

# Linear Schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that scales ``initial_value`` by the remaining progress.

    The returned callable maps ``progress_remaining`` to
    ``initial_value * progress_remaining``. Stable-Baselines3 calls it with
    ``progress_remaining`` going from 1 (start) to 0 (end), so the value
    decays linearly from ``initial_value`` to 0.
    """
    def schedule(progress_remaining: float) -> float:
        return initial_value * progress_remaining

    return schedule

# Create model
env = make_env()

# Constructor arguments shared by every supported algorithm.
common_kwargs = dict(
    verbose=0,
    device=device,
    tensorboard_log=tensorboard_log,
)

if agent is DQN:
    model = DQN(
        policy,
        env,
        exploration_fraction=0.5,
        # Learning rate decays linearly from 1e-4 towards 0 over training.
        learning_rate=linear_schedule(0.0001),
        **common_kwargs,
    )
elif agent is A2C or agent is PPO:
    # A2C and PPO use their library defaults apart from the shared arguments.
    model = agent(policy, env, **common_kwargs)
else:
    # Without this guard `model` would be undefined -- or, in a live kernel,
    # a stale model from an earlier run would be silently saved below.
    raise ValueError(f"Unsupported agent: {agent!r}; expected DQN, A2C or PPO")

# Save: wipe the previous run directory (models, checkpoints, tensorboard logs)
# so the Train cell starts from this freshly initialised model.
shutil.rmtree(dir, ignore_errors=True)
model.save(model_path)

# The Train cell reloads the model with its own env, so release this one.
env.close()



# Train

In [2]:
# Execute Setup
# Re-run the notebook cell(s) tagged "setup" so this cell also works on a fresh kernel.
import nbformat
from IPython import get_ipython
with open("01_gridworld.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    if "tags" in cell.metadata and "setup" in cell.metadata.tags:
        exec(cell.source)

# Training budget and callback frequencies, in environment steps.
# Integers (not 1e5 / 1e4 floats) because SB3 declares these parameters as int.
total_timesteps = 100_000
eval_freq = 10_000
checkpoint_freq = 10_000

# Env and model
train_env = make_env()
eval_env = make_env()

# model_path is <best_model_save_path>/best_model.zip, the same file EvalCallback
# overwrites with each new best model, so re-running this cell continues from
# the best model found so far (initially the untrained one from "Create agent").
model = agent.load(model_path, env=train_env, device=device)

# Callbacks
eval_callback = EvalCallback(
    eval_env,
    eval_freq=eval_freq,
    deterministic=True,
    n_eval_episodes=10,
    best_model_save_path=best_model_save_path,
    # Use the configured log directory (previously defined but never used).
    log_path=log_path,
)

checkpoint_callback = CheckpointCallback(
    save_freq=checkpoint_freq,
    save_path=best_model_save_path,
    name_prefix="checkpoint",
)

# Training
try:
    model.learn(
        total_timesteps=total_timesteps,
        progress_bar=True,
        # Keep the timestep counter stored in the loaded model so logs and
        # checkpoint names continue instead of restarting at 0.
        reset_num_timesteps=False,
        callback=[
            eval_callback,
            checkpoint_callback,
        ],
    )
finally:
    # No explicit model.save(): the callbacks already write best_model.zip and
    # the periodic checkpoints. Close both envs even if training is interrupted.
    train_env.close()
    eval_env.close()

Output()

# Display

In [5]:
# Execute Setup
# Re-run the notebook cell(s) tagged "setup" so this cell also works on a fresh kernel.
import os
import nbformat
from IPython import get_ipython
with open("01_gridworld.ipynb", "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)
for cell in notebook.cells:
    if "tags" in cell.metadata and "setup" in cell.metadata.tags:
        exec(cell.source)

# Model to watch: a periodic checkpoint written by CheckpointCallback, or
# "best_model.zip" for the best model found by EvalCallback.
model_file = "checkpoint_60000_steps.zip"
# Act greedily, the same way EvalCallback(deterministic=True) scored the models
# during training; set to False to watch the stochastic policy instead.
deterministic = True

# Load before opening the render window so a missing file leaves no stray window.
# os.path.join avoids the double slash of best_model_save_path + "/..."
# (best_model_save_path already ends with "/").
model = agent.load(os.path.join(best_model_save_path, model_file), device=device)

env = make_env(render_mode='human')
try:
    obs, info = env.reset()

    terminated = False
    truncated = False
    while not terminated and not truncated:
        action, _ = model.predict(obs, deterministic=deterministic)
        # predict() returns a NumPy value; int() unwraps it to a plain int
        # (assumes a Discrete action space, as the printed actions suggest).
        action = int(action)
        obs, rew, terminated, truncated, info = env.step(action)
        print(f"Action: {action}, Obs: {obs}, Rew: {rew}")
finally:
    # Close the render window even if the episode is interrupted or raises.
    env.close()

Action: 2, Obs: [-0.37666667  0.42        0.02      ], Rew: 0.1
Action: 2, Obs: [-0.34333333  0.42        0.04      ], Rew: 0.1
Action: 2, Obs: [-0.31  0.42  0.06], Rew: 0.1
Action: 2, Obs: [-0.27666667  0.42        0.08      ], Rew: 0.1
Action: 2, Obs: [-0.24333334  0.42        0.1       ], Rew: 0.1
Action: 2, Obs: [-0.21  0.42  0.12], Rew: 0.1
Action: 2, Obs: [-0.17666666  0.42        0.14      ], Rew: 0.1
Action: 2, Obs: [-0.14333333  0.42        0.16      ], Rew: 0.1
Action: 2, Obs: [-0.11  0.42  0.18], Rew: 0.1
Action: 1, Obs: [-0.11        0.38666666  0.2       ], Rew: 0.1
Action: 2, Obs: [-0.07666667  0.38666666  0.22      ], Rew: 0.1
Action: 1, Obs: [-0.07666667  0.35333332  0.24      ], Rew: 0.1
Action: 1, Obs: [-0.07666667  0.32        0.26      ], Rew: 0.1
Action: 1, Obs: [-0.07666667  0.28666666  0.28      ], Rew: 0.1
Action: 1, Obs: [-0.07666667  0.25333333  0.3       ], Rew: 0.1
Action: 1, Obs: [-0.07666667  0.22        0.32      ], Rew: 0.1
Action: 2, Obs: [-0.04333333  