In [1]:
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor

pygame 2.6.0 (SDL 2.28.4, Python 3.8.19)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
from custom_env import CustomEnv
from wrappers import FullyObsSB3MLPWrapper
from callbacks import EvalSaveCallback



In [3]:
# ---------------------------------------------------------------------------
# Run configuration: every path and tunable used by the cells below.
# ---------------------------------------------------------------------------

# Filesystem layout: one session folder, with the logs kept in a sub-folder
# and the model checkpoint stem stored next to it.
session_dir = r"./experiments/door_key"
log_dir = os.path.join(session_dir, "logs")
model_save_path = os.path.join(session_dir, "latest_model")

# Create both folders up front (no error if they already exist).
os.makedirs(session_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# Training budget and per-episode step caps for the two environments.
total_timesteps = 200_000
max_train_episode_steps = 1000
max_eval_episode_steps = 500

# Evaluation settings forwarded to the evaluation callback.
eval_freq = 5_000
n_eval_episodes = 10
eval_deterministic = False

In [4]:
# Build the training environment in explicit stages, each bound to its own
# name. Only the final, fully wrapped object is called `train_env`, so the
# factory handed to DummyVecEnv can never resolve to the vectorized env itself
# (which is what a `lambda: train_env` closure would return if it were invoked
# after `train_env` had been rebound).

# Stage 1: the raw project environment, configured for training. The
# display/rotation/flip options are passed through unchanged to CustomEnv;
# see custom_env.py for their exact meaning.
train_base_env = CustomEnv(
    txt_file_path=r'./maps/door_key.txt',
    display_size=6,
    display_mode="random",
    random_rotate=True,
    random_flip=True,
    custom_mission="Find the key and open the door.",
    max_steps=max_train_episode_steps
)

# Stage 2: wrap with the project's FullyObsSB3MLPWrapper
# (presumably adapts observations for an MLP policy; confirm in wrappers.py).
train_wrapped_env = FullyObsSB3MLPWrapper(train_base_env)

# Stage 3: vectorize (single env) and add VecMonitor to track episode
# statistics such as rewards. The default argument pins the wrapped env at
# lambda-definition time instead of relying on a late-bound global lookup.
train_env = VecMonitor(DummyVecEnv([lambda env=train_wrapped_env: env]))

In [5]:
# Build the evaluation environment in explicit stages, each bound to its own
# name. Only the final, fully wrapped object is called `eval_env`, so the
# factory handed to DummyVecEnv can never resolve to the vectorized env itself
# (which is what a `lambda: eval_env` closure would return if it were invoked
# after `eval_env` had been rebound).

# Stage 1: the raw project environment, configured for evaluation: same map
# and mission as training, but with display_mode="middle", rotation/flip
# disabled, and the evaluation step cap.
eval_base_env = CustomEnv(
    txt_file_path=r'./maps/door_key.txt',
    display_size=6,
    display_mode="middle",
    random_rotate=False,
    random_flip=False,
    custom_mission="Find the key and open the door.",
    max_steps=max_eval_episode_steps
)

# Stage 2: wrap with the project's FullyObsSB3MLPWrapper
# (presumably adapts observations for an MLP policy; confirm in wrappers.py).
eval_wrapped_env = FullyObsSB3MLPWrapper(eval_base_env)

# Stage 3: vectorize (single env) and add VecMonitor to track episode
# statistics such as rewards. The default argument pins the wrapped env at
# lambda-definition time instead of relying on a late-bound global lookup.
eval_env = VecMonitor(DummyVecEnv([lambda env=eval_wrapped_env: env]))

In [6]:
# Resume from an existing checkpoint when one is present; otherwise start a
# fresh PPO agent. The checkpoint is looked up as "<model_save_path>.zip".
if not os.path.exists(model_save_path + ".zip"):
    # No checkpoint on disk: create a new PPO model with the MLP policy.
    model = PPO("MlpPolicy", train_env, verbose=1)
    print("Initialized new model.")
else:
    # Checkpoint found: restore it and attach the training environment.
    model = PPO.load(model_save_path, env=train_env)
    print("Loaded model from saved path.")

Using cuda device
Initialized new model.


In [7]:
# Build the project's EvalSaveCallback, which is handed to `model.learn`.
# Every setting comes from the configuration cell; judging by the argument
# names it evaluates on `eval_env` and writes under `log_dir`
# (NOTE(review): exact behavior lives in callbacks.py; confirm there).
eval_save_callback = EvalSaveCallback(
    # Where and how to evaluate.
    eval_env=eval_env,
    n_eval_episodes=n_eval_episodes,
    eval_freq=eval_freq,
    deterministic=eval_deterministic,
    # Output location and verbosity.
    log_dir=log_dir,
    verbose=1,
)

In [8]:
# Train the model and log performance with the custom callback.
#
# The load-or-create cell resumes from "<model_save_path>.zip", so that file
# has to be written here. Saving in `finally` means a manual interrupt
# (Kernel -> Interrupt / Ctrl+C) or a crash part-way through still leaves a
# checkpoint to resume from; the original exception is re-raised afterwards,
# so failures are not hidden.
try:
    model.learn(total_timesteps=total_timesteps, callback=eval_save_callback, progress_bar=True)
finally:
    # Stable-Baselines3 appends ".zip" when the path has no extension.
    model.save(model_save_path)
    print(f"Saved model to {model_save_path}.zip")

Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 340      |
|    ep_rew_mean     | -2.39    |
| time/              |          |
|    fps             | 280      |
|    iterations      | 1        |
|    time_elapsed    | 7        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 372         |
|    ep_rew_mean          | -2.71       |
| time/                   |             |
|    fps                  | 263         |
|    iterations           | 2           |
|    time_elapsed         | 15          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008843426 |
|    clip_fraction        | 0.0739      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.94       |
|    explained_variance   | 0.753       |
|    learning_rate        | 0.

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 470         |
|    ep_rew_mean          | -3.69       |
| time/                   |             |
|    fps                  | 193         |
|    iterations           | 3           |
|    time_elapsed         | 31          |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.010857217 |
|    clip_fraction        | 0.14        |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.92       |
|    explained_variance   | 0.67        |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0153     |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0109     |
|    value_loss           | 0.00838     |
-----------------------------------------
---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 418       

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 364         |
|    ep_rew_mean          | -2.63       |
| time/                   |             |
|    fps                  | 195         |
|    iterations           | 5           |
|    time_elapsed         | 52          |
|    total_timesteps      | 10240       |
| train/                  |             |
|    approx_kl            | 0.011949029 |
|    clip_fraction        | 0.13        |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.92       |
|    explained_variance   | 0.563       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00275     |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0119     |
|    value_loss           | 0.0114      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 357   

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 276         |
|    ep_rew_mean          | -1.75       |
| time/                   |             |
|    fps                  | 198         |
|    iterations           | 8           |
|    time_elapsed         | 82          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.011756517 |
|    clip_fraction        | 0.135       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.86       |
|    explained_variance   | 0.49        |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00745     |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0133     |
|    value_loss           | 0.0316      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 257   

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 237         |
|    ep_rew_mean          | -1.36       |
| time/                   |             |
|    fps                  | 197         |
|    iterations           | 10          |
|    time_elapsed         | 103         |
|    total_timesteps      | 20480       |
| train/                  |             |
|    approx_kl            | 0.010835327 |
|    clip_fraction        | 0.102       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.79       |
|    explained_variance   | 0.394       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00199     |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.0119     |
|    value_loss           | 0.0255      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 205   

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 129         |
|    ep_rew_mean          | -0.279      |
| time/                   |             |
|    fps                  | 202         |
|    iterations           | 13          |
|    time_elapsed         | 131         |
|    total_timesteps      | 26624       |
| train/                  |             |
|    approx_kl            | 0.012316141 |
|    clip_fraction        | 0.133       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.66       |
|    explained_variance   | 0.579       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0157     |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.0154     |
|    value_loss           | 0.0277      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 91.4  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 84.4        |
|    ep_rew_mean          | 0.166       |
| time/                   |             |
|    fps                  | 203         |
|    iterations           | 15          |
|    time_elapsed         | 150         |
|    total_timesteps      | 30720       |
| train/                  |             |
|    approx_kl            | 0.012542499 |
|    clip_fraction        | 0.112       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.61       |
|    explained_variance   | 0.494       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0301     |
|    n_updates            | 140         |
|    policy_gradient_loss | -0.0114     |
|    value_loss           | 0.0282      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 75.3  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 73.1        |
|    ep_rew_mean          | 0.279       |
| time/                   |             |
|    fps                  | 207         |
|    iterations           | 18          |
|    time_elapsed         | 177         |
|    total_timesteps      | 36864       |
| train/                  |             |
|    approx_kl            | 0.015906215 |
|    clip_fraction        | 0.163       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.49       |
|    explained_variance   | 0.329       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0234     |
|    n_updates            | 170         |
|    policy_gradient_loss | -0.0135     |
|    value_loss           | 0.0276      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 59.6  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.5        |
|    ep_rew_mean          | 0.515       |
| time/                   |             |
|    fps                  | 208         |
|    iterations           | 20          |
|    time_elapsed         | 196         |
|    total_timesteps      | 40960       |
| train/                  |             |
|    approx_kl            | 0.018091826 |
|    clip_fraction        | 0.139       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.48       |
|    explained_variance   | 0.589       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0319     |
|    n_updates            | 190         |
|    policy_gradient_loss | -0.0152     |
|    value_loss           | 0.0217      |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 52.7    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 53.1        |
|    ep_rew_mean          | 0.479       |
| time/                   |             |
|    fps                  | 209         |
|    iterations           | 22          |
|    time_elapsed         | 215         |
|    total_timesteps      | 45056       |
| train/                  |             |
|    approx_kl            | 0.013326206 |
|    clip_fraction        | 0.143       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.42       |
|    explained_variance   | 0.0658      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0244     |
|    n_updates            | 210         |
|    policy_gradient_loss | -0.00918    |
|    value_loss           | 0.0199      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 47          |
|    ep_rew_mean          | 0.54        |
| time/                   |             |
|    fps                  | 211         |
|    iterations           | 25          |
|    time_elapsed         | 241         |
|    total_timesteps      | 51200       |
| train/                  |             |
|    approx_kl            | 0.028062671 |
|    clip_fraction        | 0.215       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.42       |
|    explained_variance   | 0.47        |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0118     |
|    n_updates            | 240         |
|    policy_gradient_loss | -0.0193     |
|    value_loss           | 0.0201      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 40.9  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 35.8        |
|    ep_rew_mean          | 0.652       |
| time/                   |             |
|    fps                  | 213         |
|    iterations           | 27          |
|    time_elapsed         | 259         |
|    total_timesteps      | 55296       |
| train/                  |             |
|    approx_kl            | 0.020229608 |
|    clip_fraction        | 0.206       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.35       |
|    explained_variance   | 0.425       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0173     |
|    n_updates            | 260         |
|    policy_gradient_loss | -0.0213     |
|    value_loss           | 0.0198      |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 39.8    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 21.2        |
|    ep_rew_mean          | 0.798       |
| time/                   |             |
|    fps                  | 215         |
|    iterations           | 30          |
|    time_elapsed         | 285         |
|    total_timesteps      | 61440       |
| train/                  |             |
|    approx_kl            | 0.022415856 |
|    clip_fraction        | 0.186       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.12       |
|    explained_variance   | 0.358       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0052     |
|    n_updates            | 290         |
|    policy_gradient_loss | -0.0146     |
|    value_loss           | 0.0153      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 18.9  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 29.7        |
|    ep_rew_mean          | 0.713       |
| time/                   |             |
|    fps                  | 216         |
|    iterations           | 32          |
|    time_elapsed         | 303         |
|    total_timesteps      | 65536       |
| train/                  |             |
|    approx_kl            | 0.046620186 |
|    clip_fraction        | 0.259       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.04       |
|    explained_variance   | 0.327       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0384     |
|    n_updates            | 310         |
|    policy_gradient_loss | -0.0232     |
|    value_loss           | 0.0102      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 22.5  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 13.4        |
|    ep_rew_mean          | 0.876       |
| time/                   |             |
|    fps                  | 217         |
|    iterations           | 35          |
|    time_elapsed         | 329         |
|    total_timesteps      | 71680       |
| train/                  |             |
|    approx_kl            | 0.016301757 |
|    clip_fraction        | 0.191       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.731      |
|    explained_variance   | 0.098       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0224     |
|    n_updates            | 340         |
|    policy_gradient_loss | -0.0154     |
|    value_loss           | 0.00826     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 19.9  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 12.4        |
|    ep_rew_mean          | 0.886       |
| time/                   |             |
|    fps                  | 217         |
|    iterations           | 37          |
|    time_elapsed         | 347         |
|    total_timesteps      | 75776       |
| train/                  |             |
|    approx_kl            | 0.008065378 |
|    clip_fraction        | 0.0895      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.44       |
|    explained_variance   | -0.134      |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00911    |
|    n_updates            | 360         |
|    policy_gradient_loss | -0.00348    |
|    value_loss           | 0.00442     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 10.3  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 8.92        |
|    ep_rew_mean          | 0.921       |
| time/                   |             |
|    fps                  | 218         |
|    iterations           | 40          |
|    time_elapsed         | 375         |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.019766612 |
|    clip_fraction        | 0.185       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.616      |
|    explained_variance   | 0.609       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0787      |
|    n_updates            | 390         |
|    policy_gradient_loss | -0.00385    |
|    value_loss           | 0.00792     |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 10.5    

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 7.89       |
|    ep_rew_mean          | 0.931      |
| time/                   |            |
|    fps                  | 217        |
|    iterations           | 42         |
|    time_elapsed         | 395        |
|    total_timesteps      | 86016      |
| train/                  |            |
|    approx_kl            | 0.03765536 |
|    clip_fraction        | 0.221      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.563     |
|    explained_variance   | 0.708      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0172    |
|    n_updates            | 410        |
|    policy_gradient_loss | -0.0282    |
|    value_loss           | 0.00587    |
----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 7.65       |
|    ep_rew_mean

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 7.29       |
|    ep_rew_mean          | 0.937      |
| time/                   |            |
|    fps                  | 216        |
|    iterations           | 44         |
|    time_elapsed         | 415        |
|    total_timesteps      | 90112      |
| train/                  |            |
|    approx_kl            | 0.05361692 |
|    clip_fraction        | 0.192      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.372     |
|    explained_variance   | 0.497      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0322    |
|    n_updates            | 430        |
|    policy_gradient_loss | -0.0134    |
|    value_loss           | 0.00193    |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 7.25        |
|    ep_rew_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6.96        |
|    ep_rew_mean          | 0.94        |
| time/                   |             |
|    fps                  | 215         |
|    iterations           | 47          |
|    time_elapsed         | 447         |
|    total_timesteps      | 96256       |
| train/                  |             |
|    approx_kl            | 0.063053995 |
|    clip_fraction        | 0.28        |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.494      |
|    explained_variance   | 0.521       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0739     |
|    n_updates            | 460         |
|    policy_gradient_loss | -0.038      |
|    value_loss           | 0.00579     |
-----------------------------------------


KeyboardInterrupt: 