In [1]:
import gym
import numpy as np
import rich

from tqdm import tqdm

from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

ENV_NAME = "LunarLander-v2"
POLICY_TYPE = "MlpPolicy"
N_ENV = 8
LEARNING_ALG_NAME = "ppo"
CKPT_PATH = lambda n_steps: f"{ENV_NAME}_{LEARNING_ALG_NAME}.pkl"

LEARNING_ALGS = dict(dqn=DQN, ppo=PPO, a2c=A2C)
LEARNING_ALGS_KWARGS = dict(
    dqn=dict(
        max_grad_norm=1,
        tau=0.01,
    )
)

# Parallel environments
env = make_vec_env(ENV_NAME, n_envs=N_ENV)

In [2]:
model = LEARNING_ALGS[LEARNING_ALG_NAME](
    POLICY_TYPE, 
    env, 
    verbose=0,
    **LEARNING_ALGS_KWARGS.get(LEARNING_ALG_NAME, {})
)
print("Training...")

TOTAL = 1E6
SPLIT = 50

for i in tqdm(range(SPLIT)):
    model = model.learn(total_timesteps=int(TOTAL // SPLIT) , progress_bar=True)
    mean_reward, std_reward = evaluate_policy(
        model, 
        env, 
        n_eval_episodes=25,
    )

    rich.print(f"[bold blue]Learning alg:[/] {LEARNING_ALGS[LEARNING_ALG_NAME].__name__}")
    rich.print(f"[bold blue]Policy type:[/]  {POLICY_TYPE}")
    rich.print(f"[bold blue]Mean reward:[/]  {mean_reward:.2f} +/- {std_reward}")
    rich.print("\n" + "#" * 80)


Training...


  0%|          | 0/50 [00:00<?, ?it/s]

Output()

  2%|▏         | 1/50 [00:32<26:45, 32.76s/it]

Output()

  4%|▍         | 2/50 [01:09<28:06, 35.13s/it]

Output()

  6%|▌         | 3/50 [02:01<33:25, 42.68s/it]

Output()

  8%|▊         | 4/50 [03:14<42:05, 54.91s/it]

Output()

 10%|█         | 5/50 [04:43<50:19, 67.09s/it]

Output()

 12%|█▏        | 6/50 [06:15<55:26, 75.59s/it]

Output()

 14%|█▍        | 7/50 [07:29<53:51, 75.15s/it]

Output()

 16%|█▌        | 8/50 [08:30<49:23, 70.57s/it]

Output()

 18%|█▊        | 9/50 [09:26<44:58, 65.82s/it]

Output()

 20%|██        | 10/50 [10:19<41:22, 62.06s/it]

Output()

 22%|██▏       | 11/50 [11:14<38:52, 59.80s/it]

Output()

 24%|██▍       | 12/50 [12:04<35:56, 56.74s/it]

Output()

 26%|██▌       | 13/50 [12:52<33:23, 54.15s/it]

Output()

 28%|██▊       | 14/50 [13:41<31:40, 52.79s/it]

Output()

 30%|███       | 15/50 [14:29<29:56, 51.32s/it]

Output()

 32%|███▏      | 16/50 [15:18<28:37, 50.53s/it]

Output()

 34%|███▍      | 17/50 [16:05<27:09, 49.38s/it]

Output()

 36%|███▌      | 18/50 [16:50<25:44, 48.25s/it]

Output()

 38%|███▊      | 19/50 [17:34<24:09, 46.76s/it]

Output()

 40%|████      | 20/50 [18:16<22:40, 45.36s/it]

Output()

 42%|████▏     | 21/50 [18:57<21:19, 44.12s/it]

Output()

 44%|████▍     | 22/50 [19:37<20:02, 42.95s/it]

Output()

 46%|████▌     | 23/50 [20:18<19:02, 42.30s/it]

Output()

 48%|████▊     | 24/50 [20:59<18:10, 41.94s/it]

Output()

 50%|█████     | 25/50 [21:40<17:20, 41.63s/it]

Output()

 52%|█████▏    | 26/50 [22:19<16:17, 40.74s/it]

Output()

 54%|█████▍    | 27/50 [23:01<15:47, 41.18s/it]

Output()

 56%|█████▌    | 28/50 [23:42<15:04, 41.12s/it]

Output()

 58%|█████▊    | 29/50 [24:21<14:09, 40.47s/it]

Output()

 60%|██████    | 30/50 [24:59<13:15, 39.77s/it]

Output()

 62%|██████▏   | 31/50 [25:39<12:34, 39.71s/it]

Output()

 64%|██████▍   | 32/50 [26:17<11:47, 39.32s/it]

Output()

 66%|██████▌   | 33/50 [26:56<11:08, 39.34s/it]

Output()

 68%|██████▊   | 34/50 [27:33<10:17, 38.61s/it]

Output()

 70%|███████   | 35/50 [28:08<09:23, 37.56s/it]

Output()

 72%|███████▏  | 36/50 [28:43<08:34, 36.76s/it]

Output()

 74%|███████▍  | 37/50 [29:19<07:53, 36.40s/it]

Output()

 76%|███████▌  | 38/50 [29:55<07:15, 36.32s/it]

Output()

 78%|███████▊  | 39/50 [30:32<06:41, 36.54s/it]

Output()

 80%|████████  | 40/50 [31:10<06:11, 37.13s/it]

Output()

 82%|████████▏ | 41/50 [31:46<05:30, 36.73s/it]

Output()

 84%|████████▍ | 42/50 [32:22<04:50, 36.32s/it]

Output()

 86%|████████▌ | 43/50 [32:58<04:14, 36.36s/it]

Output()

 88%|████████▊ | 44/50 [33:34<03:37, 36.30s/it]

Output()

 90%|█████████ | 45/50 [34:09<02:58, 35.68s/it]

Output()

 92%|█████████▏| 46/50 [34:43<02:21, 35.44s/it]

Output()

 94%|█████████▍| 47/50 [35:17<01:44, 34.85s/it]

Output()

 96%|█████████▌| 48/50 [35:54<01:11, 35.55s/it]

Output()

 98%|█████████▊| 49/50 [36:30<00:35, 35.54s/it]

Output()

100%|██████████| 50/50 [37:06<00:00, 44.53s/it]


In [3]:
mean_reward, std_reward = evaluate_policy(
    model, 
    env, 
    n_eval_episodes=100,
)

print(f"[bold blue]Learning alg:[/] {LEARNING_ALGS[LEARNING_ALG_NAME].__name__}")
print(f"[bold blue]policy type:[/]  {POLICY_TYPE}")
print(f"[bold blue]mean reward:[/]  {mean_reward:.2f} +/- {std_reward}")

[bold blue]Learning alg:[/] PPO
[bold blue]policy type:[/]  MlpPolicy
[bold blue]mean reward:[/]  266.52 +/- 32.63089626909642


In [4]:
model.save(CKPT_PATH)

TypeError: ('Path parameter has invalid type.', <class 'io.BufferedIOBase'>)

In [None]:
CKPT_PATH