In [None]:
import gymnasium as gym
import simple_env
from simple_env.wrappers import RelativePosition, NormalizedObservation, DiscreteActions
from gymnasium.wrappers import TimeAwareObservation
from stable_baselines3 import DQN, A2C, PPO, DDPG
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallback, BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
import shutil
from typing import Callable
import numpy as np

# Notebook that holds the cells tagged "setup", and the experiment label used
# to build the output paths.
file = "03_steeringworld.ipynb"
name = '01_steering_world'

# Environment id plus the two values handed to gym.make by make_env().
env_name = "simple_env/SteeringWorld-v1"
world_size, max_episode_steps = 300, 500

# Algorithm class, policy name and device string passed to the SB3 constructors
# and loaders in the later cells.
agent, policy, device = PPO, 'MlpPolicy', 'cpu'

# Output locations, all rooted at ./<name>.
# NOTE: `dir` shadows the built-in dir(); the name is kept because the
# "Create agent" cell refers to it.
dir = f"./{name}"
best_model_save_path = f"./{name}/model/"
model_path = f"./{name}/model/best_model.zip"
tensorboard_log = f"./{name}/t_logs/"
log_path = f"./{name}/logs/"

# num_action = 20
# keys = [i for i in range(num_action)]
# values = [(i/num_action)*2 - 1 for i in range(num_action)]
# disc_to_cont = dict(zip(keys, values))
def make_env(render_mode = None):
    """Create one environment from the notebook-level settings.

    Args:
        render_mode: Forwarded unchanged to ``gym.make`` as ``render_mode``.

    Returns:
        The environment produced by ``gym.make`` for ``env_name``, built with
        ``size=world_size`` and ``max_episode_steps=max_episode_steps``.
    """
    env_kwargs = {
        "render_mode": render_mode,
        "size": world_size,
        "max_episode_steps": max_episode_steps,
    }

    # Wrappers that are currently switched off:
    #   env = DiscreteActions(env, disc_to_cont)
    #   env = TimeAwareObservation(env, normalize_time=True)
    return gym.make(env_name, **env_kwargs)

    

# Test

In [1]:
# Execute Setup
# Read this notebook from disk and exec the source of every cell whose
# metadata tags include "setup", so the cell can run without the setup cell
# having been executed by hand.
import nbformat
from IPython import get_ipython

with open("03_steeringworld.ipynb", "r", encoding="utf-8") as nb_file:
    notebook = nbformat.read(nb_file, as_version=4)

# exec() stays a plain top-level statement so the definitions land in the
# notebook namespace.
for cell in notebook.cells:
    if "setup" not in cell.metadata.get("tags", ()):
        continue
    exec(cell.source)

env = make_env(render_mode=None)
obs, info = env.reset()

terminated = truncated = False

# Step the environment with randomly sampled actions until the episode is
# terminated or truncated, printing every transition.
while not (terminated or truncated):
    action = env.action_space.sample()
    obs, rew, terminated, truncated, info = env.step(action)
    print(f"Action: {action}, Obs: {obs}, info {info}, rew: {rew}")

env.close()

Action: [0.9592223  0.25691098], Obs: [0.07230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [-0.82369816 -0.4674632 ], Obs: [0.07230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [0.24238512 0.2993225 ], Obs: [0.07230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [-0.10057948  0.90380514], Obs: [0.07230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [0.2548914  0.47157463], Obs: [0.08230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [-0.34541538  0.44920975], Obs: [0.07230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [-0.43157744 -0.5029821 ], Obs: [0.06230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [0.3234243 0.6543925], Obs: [0.06230948 0.        ], info {'direction': -0.7284310953063695}, rew: 0
Action: [-0.57064444 -0.03184291], Obs: [0.07230948 0.        ], info {'direction': -0.7284310953063695}, 

  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} is not within the observation space.")


# Create agent



In [None]:
# Execute Setup
# Read this notebook from disk and exec the source of every cell whose
# metadata tags include "setup".
import nbformat
from IPython import get_ipython

# `file` is itself defined by the setup cell, so on a fresh kernel it does not
# exist yet. Fall back to the notebook's filename so this cell can bootstrap
# the setup on its own instead of raising NameError.
notebook_path = globals().get("file", "03_steeringworld.ipynb")
with open(notebook_path, "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)

# exec() stays a plain top-level statement so the definitions land in the
# notebook namespace.
for cell in notebook.cells:
    if "setup" in cell.metadata.get("tags", []):
        exec(cell.source)

# Linear Schedule
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Build a schedule that scales ``initial_value`` by the remaining progress.

    Args:
        initial_value: Value returned when ``progress_remaining`` is 1.

    Returns:
        A function mapping ``progress_remaining`` to
        ``progress_remaining * initial_value``.
    """
    def schedule(progress_remaining: float) -> float:
        scaled = initial_value * progress_remaining
        return scaled

    return schedule

# Create model
env = make_env()

# Constructor arguments shared by every supported algorithm.
common_kwargs = dict(
    verbose=0,
    device=device,
    tensorboard_log=tensorboard_log,
)

# Extra constructor arguments per algorithm class. An `agent` without an entry
# here is rejected below instead of silently falling through.
agent_specific_kwargs = {
    DQN: dict(
        exploration_fraction=0.5,
        learning_rate=linear_schedule(0.0001),
    ),
    DDPG: dict(
        learning_rate=linear_schedule(0.0001),
    ),
    A2C: dict(),
    PPO: dict(),
}

if agent not in agent_specific_kwargs:
    # Fail before anything on disk is touched. Without this check `model`
    # would be undefined here, or a stale `model` left in the kernel by an
    # earlier run would be saved after the directory had been wiped.
    raise ValueError(
        f"Unsupported agent {agent!r}; expected one of DQN, DDPG, A2C or PPO."
    )

model = agent(policy, env, **common_kwargs, **agent_specific_kwargs[agent])

# Save
# Remove the whole experiment directory first (previous logs, checkpoints and
# models under `dir` are deleted), then write the freshly initialised model.
shutil.rmtree(dir, ignore_errors=True)
model.save(model_path)

# Train

In [None]:
# Execute Setup
# Read this notebook from disk and exec the source of every cell whose
# metadata tags include "setup".
import nbformat
from IPython import get_ipython

# `file` is itself defined by the setup cell, so on a fresh kernel it does not
# exist yet. Fall back to the notebook's filename so this cell can bootstrap
# the setup on its own instead of raising NameError.
notebook_path = globals().get("file", "03_steeringworld.ipynb")
with open(notebook_path, "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)

# exec() stays a plain top-level statement so the definitions land in the
# notebook namespace.
for cell in notebook.cells:
    if "setup" in cell.metadata.get("tags", []):
        exec(cell.source)
        
print(f"Start training with {agent}")
# Env and model
train_env = make_env()
eval_env = make_env()

try:
    # Load the model stored at model_path and attach the training environment.
    model = agent.load(model_path, env=train_env, device=device)

    # Callbacks
    # Step counts are written as ints (10_000 == 1e4): these parameters are
    # integer numbers of steps.
    eval_callback = EvalCallback(
        eval_env,
        eval_freq=10_000,
        deterministic=True,
        n_eval_episodes=10,
        best_model_save_path=best_model_save_path,
    )

    checkpoint_callback = CheckpointCallback(
        save_freq=10_000,
        save_path=best_model_save_path,
        name_prefix="checkpoint",
    )

    # Training
    model.learn(
        total_timesteps=50_000_000,
        progress_bar=True,
        reset_num_timesteps=False,
        callback=[
            eval_callback,
            checkpoint_callback,
        ],
    )

    # Save
    # NOTE(review): model_path is <best_model_save_path>/best_model.zip, which
    # looks like the file EvalCallback writes its best model to, so this final
    # save presumably replaces it with the last model. Confirm that is intended.
    model.save(model_path)
finally:
    # Close both environments, also when loading or training fails or the
    # run is interrupted.
    train_env.close()
    eval_env.close()

# Display

In [None]:
# Execute Setup
# Read this notebook from disk and exec the source of every cell whose
# metadata tags include "setup".
import nbformat
from IPython import get_ipython

# `file` is itself defined by the setup cell, so on a fresh kernel it does not
# exist yet. Fall back to the notebook's filename so this cell can bootstrap
# the setup on its own instead of raising NameError.
notebook_path = globals().get("file", "03_steeringworld.ipynb")
with open(notebook_path, "r", encoding="utf-8") as f:
    notebook = nbformat.read(f, as_version=4)

# exec() stays a plain top-level statement so the definitions land in the
# notebook namespace.
for cell in notebook.cells:
    if "setup" in cell.metadata.get("tags", []):
        exec(cell.source)
env = make_env(render_mode='human')

try:
    obs, info = env.reset()

    #model = agent.load(best_model_save_path + "best_model.zip")
    # Load on the configured device, as the training cell does.
    model = agent.load(best_model_save_path + "/checkpoint_160000_steps.zip", device=device)

    # int() only works for a single-valued action. The random rollout earlier
    # in this notebook shows two-component actions for this environment, for
    # which int(action) raises TypeError, so cast only for Discrete spaces.
    discrete_actions = isinstance(env.action_space, gym.spaces.Discrete)

    terminated = False
    truncated = False
    while not terminated and not truncated:

        action, _state = model.predict(obs)
        if discrete_actions:
            action = int(action)
        obs, rew, terminated, truncated, info = env.step(action)
        print(f"Action: {action}, Obs: {obs}, Rew: {rew}")
finally:
    # Close the environment created with render_mode='human' even if loading
    # or stepping fails.
    env.close()