In [1]:
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor

pygame 2.6.0 (SDL 2.28.4, Python 3.8.19)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
from custom_env import CustomEnv
from wrappers import FullyObsSB3MLPWrapper
from callbacks import EvalSaveCallback



In [3]:
# Experiment layout: everything for this run lives under `session_dir`,
# with the evaluation/training logs in its `logs/` subfolder.
session_dir = "./experiments/small_corridor"
log_dir = os.path.join(session_dir, "logs")
# Base path of the saved model; the load cell looks for "<model_save_path>.zip".
model_save_path = os.path.join(session_dir, "latest_model")

for required_dir in (session_dir, log_dir):
    os.makedirs(required_dir, exist_ok=True)

# Per-episode step caps, passed as `max_steps` to the train / eval CustomEnv.
max_train_episode_steps = 1_000
max_eval_episode_steps = 500

# Training length and evaluation schedule (consumed by PPO.learn and EvalSaveCallback).
total_timesteps = 100_000
eval_freq = 5_000
n_eval_episodes = 10
eval_deterministic = False  # forwarded to EvalSaveCallback as `deterministic`

In [4]:
# Build the training environment through a factory. DummyVecEnv expects a list
# of callables that *create* environments. The previous `lambda: train_env`
# returned an already-built object through a name that this same cell rebinds
# twice, so it only worked because DummyVecEnv invokes its env functions inside
# its constructor. The factory builds a fresh, fully wrapped env on every call.
def make_train_env():
    """Create one training env.

    This is a CustomEnv on 'simple_test_corridor.txt' with display_mode="random"
    and random rotation/flip enabled, capped at `max_train_episode_steps` steps
    and wrapped with FullyObsSB3MLPWrapper.
    """
    env = CustomEnv(
        txt_file_path='simple_test_corridor.txt',
        display_size=6,
        display_mode="random",
        random_rotate=True,
        random_flip=True,
        custom_mission="Find the key and open the door.",
        max_steps=max_train_episode_steps,
    )
    return FullyObsSB3MLPWrapper(env)


# Vectorize (single env) and track episode statistics such as rewards
train_env = VecMonitor(DummyVecEnv([make_train_env]))

In [5]:
# Build the evaluation environment through a factory. DummyVecEnv expects a list
# of callables that *create* environments. The previous `lambda: eval_env`
# returned an already-built object through a name that this same cell rebinds
# twice, so it only worked because DummyVecEnv invokes its env functions inside
# its constructor. The factory builds a fresh, fully wrapped env on every call.
def make_eval_env():
    """Create one evaluation env.

    This is a CustomEnv on 'simple_test_corridor.txt' with display_mode="middle"
    and random rotation/flip disabled, capped at `max_eval_episode_steps` steps
    and wrapped with FullyObsSB3MLPWrapper.
    """
    env = CustomEnv(
        txt_file_path='simple_test_corridor.txt',
        display_size=6,
        display_mode="middle",
        random_rotate=False,
        random_flip=False,
        custom_mission="Find the key and open the door.",
        max_steps=max_eval_episode_steps,
    )
    return FullyObsSB3MLPWrapper(env)


# Vectorize (single env) and track episode statistics such as rewards
eval_env = VecMonitor(DummyVecEnv([make_eval_env]))

In [6]:
# Resume from a previously saved model when its archive exists on disk;
# otherwise start a fresh PPO agent with an MLP policy on the training env.
checkpoint_file = model_save_path + ".zip"
if not os.path.exists(checkpoint_file):
    model = PPO("MlpPolicy", train_env, verbose=1)
    print("Initialized new model.")
else:
    model = PPO.load(model_save_path, env=train_env)
    print("Loaded model from saved path.")

Using cuda device
Initialized new model.


In [7]:
# Callback handed to model.learn(). It receives the evaluation env and the
# evaluation schedule from the config cell. What it writes under `log_dir`
# is defined in callbacks.EvalSaveCallback (not visible here).
eval_callback_kwargs = dict(
    eval_env=eval_env,
    log_dir=log_dir,
    eval_freq=eval_freq,
    n_eval_episodes=n_eval_episodes,
    deterministic=eval_deterministic,
    verbose=1,
)
eval_save_callback = EvalSaveCallback(**eval_callback_kwargs)

In [8]:
# Train the model and log performance with the custom callback
model.learn(total_timesteps=total_timesteps, callback=eval_save_callback, progress_bar=True)

# Persist the trained weights to `model_save_path` so the "load or create" cell
# can resume from "<model_save_path>.zip" on the next run. Nothing else in this
# notebook writes that path: EvalSaveCallback is only given `log_dir`.
model.save(model_save_path)

Output()

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 139      |
|    ep_rew_mean     | 0.875    |
| time/              |          |
|    fps             | 415      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 151         |
|    ep_rew_mean          | 0.864       |
| time/                   |             |
|    fps                  | 369         |
|    iterations           | 2           |
|    time_elapsed         | 11          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.010126465 |
|    clip_fraction        | 0.115       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.94       |
|    explained_variance   | 0.0591      |
|    learning_rate        | 0.

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 127         |
|    ep_rew_mean          | 0.885       |
| time/                   |             |
|    fps                  | 312         |
|    iterations           | 3           |
|    time_elapsed         | 19          |
|    total_timesteps      | 6144        |
| train/                  |             |
|    approx_kl            | 0.010205676 |
|    clip_fraction        | 0.108       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.92       |
|    explained_variance   | 0.393       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00632    |
|    n_updates            | 20          |
|    policy_gradient_loss | -0.0103     |
|    value_loss           | 0.0486      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 101   

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 81          |
|    ep_rew_mean          | 0.927       |
| time/                   |             |
|    fps                  | 308         |
|    iterations           | 5           |
|    time_elapsed         | 33          |
|    total_timesteps      | 10240       |
| train/                  |             |
|    approx_kl            | 0.011434306 |
|    clip_fraction        | 0.138       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.85       |
|    explained_variance   | 0.188       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00673    |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0127     |
|    value_loss           | 0.0408      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 57.3  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 53.3        |
|    ep_rew_mean          | 0.952       |
| time/                   |             |
|    fps                  | 306         |
|    iterations           | 8           |
|    time_elapsed         | 53          |
|    total_timesteps      | 16384       |
| train/                  |             |
|    approx_kl            | 0.012124343 |
|    clip_fraction        | 0.146       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.7        |
|    explained_variance   | 0.195       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0218     |
|    n_updates            | 70          |
|    policy_gradient_loss | -0.0173     |
|    value_loss           | 0.0144      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 49.5  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 47.3        |
|    ep_rew_mean          | 0.957       |
| time/                   |             |
|    fps                  | 305         |
|    iterations           | 10          |
|    time_elapsed         | 67          |
|    total_timesteps      | 20480       |
| train/                  |             |
|    approx_kl            | 0.012683596 |
|    clip_fraction        | 0.158       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.64       |
|    explained_variance   | 0.241       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00793    |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.0189     |
|    value_loss           | 0.0102      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 39.4  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 19.1        |
|    ep_rew_mean          | 0.983       |
| time/                   |             |
|    fps                  | 307         |
|    iterations           | 13          |
|    time_elapsed         | 86          |
|    total_timesteps      | 26624       |
| train/                  |             |
|    approx_kl            | 0.014641704 |
|    clip_fraction        | 0.191       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.49       |
|    explained_variance   | 0.281       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0394     |
|    n_updates            | 120         |
|    policy_gradient_loss | -0.0164     |
|    value_loss           | 0.00681     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 19.1  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 20.3        |
|    ep_rew_mean          | 0.982       |
| time/                   |             |
|    fps                  | 306         |
|    iterations           | 15          |
|    time_elapsed         | 100         |
|    total_timesteps      | 30720       |
| train/                  |             |
|    approx_kl            | 0.012370806 |
|    clip_fraction        | 0.146       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.33       |
|    explained_variance   | 0.045       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0114     |
|    n_updates            | 140         |
|    policy_gradient_loss | -0.0182     |
|    value_loss           | 0.00547     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 15.6  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 12.8        |
|    ep_rew_mean          | 0.988       |
| time/                   |             |
|    fps                  | 307         |
|    iterations           | 18          |
|    time_elapsed         | 119         |
|    total_timesteps      | 36864       |
| train/                  |             |
|    approx_kl            | 0.012317064 |
|    clip_fraction        | 0.142       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.16       |
|    explained_variance   | 0.427       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0381     |
|    n_updates            | 170         |
|    policy_gradient_loss | -0.018      |
|    value_loss           | 0.0056      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 13.2  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 10.2        |
|    ep_rew_mean          | 0.991       |
| time/                   |             |
|    fps                  | 307         |
|    iterations           | 20          |
|    time_elapsed         | 133         |
|    total_timesteps      | 40960       |
| train/                  |             |
|    approx_kl            | 0.013152778 |
|    clip_fraction        | 0.171       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.12       |
|    explained_variance   | 0.563       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0204     |
|    n_updates            | 190         |
|    policy_gradient_loss | -0.0177     |
|    value_loss           | 0.00486     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 8.53  

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 11.8       |
|    ep_rew_mean          | 0.989      |
| time/                   |            |
|    fps                  | 306        |
|    iterations           | 22         |
|    time_elapsed         | 146        |
|    total_timesteps      | 45056      |
| train/                  |            |
|    approx_kl            | 0.02770732 |
|    clip_fraction        | 0.213      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.894     |
|    explained_variance   | 0.392      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0104    |
|    n_updates            | 210        |
|    policy_gradient_loss | -0.0172    |
|    value_loss           | 0.00269    |
----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 9.38       |
|    ep_rew_mean

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 8.02        |
|    ep_rew_mean          | 0.993       |
| time/                   |             |
|    fps                  | 303         |
|    iterations           | 25          |
|    time_elapsed         | 168         |
|    total_timesteps      | 51200       |
| train/                  |             |
|    approx_kl            | 0.024079997 |
|    clip_fraction        | 0.135       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.725      |
|    explained_variance   | 0.463       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0064     |
|    n_updates            | 240         |
|    policy_gradient_loss | -0.00606    |
|    value_loss           | 0.0027      |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 6.94    

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 7.3         |
|    ep_rew_mean          | 0.993       |
| time/                   |             |
|    fps                  | 298         |
|    iterations           | 27          |
|    time_elapsed         | 185         |
|    total_timesteps      | 55296       |
| train/                  |             |
|    approx_kl            | 0.030391172 |
|    clip_fraction        | 0.166       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.467      |
|    explained_variance   | 0.65        |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0533     |
|    n_updates            | 260         |
|    policy_gradient_loss | -0.0269     |
|    value_loss           | 0.00116     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6.31  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 7.5         |
|    ep_rew_mean          | 0.993       |
| time/                   |             |
|    fps                  | 296         |
|    iterations           | 30          |
|    time_elapsed         | 207         |
|    total_timesteps      | 61440       |
| train/                  |             |
|    approx_kl            | 0.031223869 |
|    clip_fraction        | 0.152       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.374      |
|    explained_variance   | 0.492       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0334     |
|    n_updates            | 290         |
|    policy_gradient_loss | -0.0222     |
|    value_loss           | 0.00129     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 5.58  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.69        |
|    ep_rew_mean          | 0.996       |
| time/                   |             |
|    fps                  | 295         |
|    iterations           | 32          |
|    time_elapsed         | 221         |
|    total_timesteps      | 65536       |
| train/                  |             |
|    approx_kl            | 0.028473133 |
|    clip_fraction        | 0.146       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.318      |
|    explained_variance   | 0.473       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0488     |
|    n_updates            | 310         |
|    policy_gradient_loss | -0.0227     |
|    value_loss           | 0.00228     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.88  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 5.21        |
|    ep_rew_mean          | 0.995       |
| time/                   |             |
|    fps                  | 291         |
|    iterations           | 35          |
|    time_elapsed         | 245         |
|    total_timesteps      | 71680       |
| train/                  |             |
|    approx_kl            | 0.023150593 |
|    clip_fraction        | 0.049       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.143      |
|    explained_variance   | 0.383       |
|    learning_rate        | 0.0003      |
|    loss                 | 0.00808     |
|    n_updates            | 340         |
|    policy_gradient_loss | -0.00997    |
|    value_loss           | 0.000298    |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.88  

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.37       |
|    ep_rew_mean          | 0.996      |
| time/                   |            |
|    fps                  | 288        |
|    iterations           | 37         |
|    time_elapsed         | 262        |
|    total_timesteps      | 75776      |
| train/                  |            |
|    approx_kl            | 0.05426146 |
|    clip_fraction        | 0.131      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.204     |
|    explained_variance   | 0.443      |
|    learning_rate        | 0.0003     |
|    loss                 | 0.0169     |
|    n_updates            | 360        |
|    policy_gradient_loss | -0.0282    |
|    value_loss           | 0.000649   |
----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.22        |
|    ep_rew_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 13.7        |
|    ep_rew_mean          | 0.988       |
| time/                   |             |
|    fps                  | 285         |
|    iterations           | 40          |
|    time_elapsed         | 286         |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.090145245 |
|    clip_fraction        | 0.486       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.454      |
|    explained_variance   | 0.565       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0586     |
|    n_updates            | 390         |
|    policy_gradient_loss | -0.0216     |
|    value_loss           | 0.000756    |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 6.64  

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.9        |
|    ep_rew_mean          | 0.996      |
| time/                   |            |
|    fps                  | 284        |
|    iterations           | 42         |
|    time_elapsed         | 302        |
|    total_timesteps      | 86016      |
| train/                  |            |
|    approx_kl            | 0.02850851 |
|    clip_fraction        | 0.137      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.208     |
|    explained_variance   | 0.692      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.00335   |
|    n_updates            | 410        |
|    policy_gradient_loss | -0.0188    |
|    value_loss           | 0.0056     |
----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.34       |
|    ep_rew_mean

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.04        |
|    ep_rew_mean          | 0.996       |
| time/                   |             |
|    fps                  | 282         |
|    iterations           | 44          |
|    time_elapsed         | 318         |
|    total_timesteps      | 90112       |
| train/                  |             |
|    approx_kl            | 0.060104527 |
|    clip_fraction        | 0.238       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.212      |
|    explained_variance   | 0.296       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0991     |
|    n_updates            | 430         |
|    policy_gradient_loss | -0.0204     |
|    value_loss           | 0.0011      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 4.21  

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4.06       |
|    ep_rew_mean          | 0.996      |
| time/                   |            |
|    fps                  | 280        |
|    iterations           | 47         |
|    time_elapsed         | 343        |
|    total_timesteps      | 96256      |
| train/                  |            |
|    approx_kl            | 0.03369914 |
|    clip_fraction        | 0.122      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.11      |
|    explained_variance   | 0.605      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0579    |
|    n_updates            | 460        |
|    policy_gradient_loss | -0.032     |
|    value_loss           | 0.000264   |
----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 8.38       |
|    ep_rew_mean

----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 4          |
|    ep_rew_mean          | 0.996      |
| time/                   |            |
|    fps                  | 279        |
|    iterations           | 49         |
|    time_elapsed         | 359        |
|    total_timesteps      | 100352     |
| train/                  |            |
|    approx_kl            | 0.14872073 |
|    clip_fraction        | 0.41       |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.323     |
|    explained_variance   | 0.359      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.104     |
|    n_updates            | 480        |
|    policy_gradient_loss | 0.0612     |
|    value_loss           | 0.000941   |
----------------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x183bb210550>