In [1]:
import numpy as np

from stable_baselines3 import PPO
from PO_grid_world import PO_GridWorld
from notebook_env_wrapper import NotebookEnvWrapper
from stable_baselines3.common.evaluation import evaluate_policy

In [2]:
env_po = PO_GridWorld(partially_observable=True)
env_notebook = NotebookEnvWrapper(PO_GridWorld(partially_observable=True), notebook_size=8)

In [3]:
arrows = ["↑", "↓", "←", "→"]

def print_policy(model):
    for i in range(6):
        for j in range(6):
            obs = i*6 + j
            pred = model.predict(obs, deterministic=True)[0]
            print(arrows[pred], end=" ")
        print()

def print_policy_po(po_model):
    for i in range(6):
        for j in range(6):
            obs = (i//3)*2 + (j//3)
            pred = po_model.predict(obs, deterministic=True)[0]
            print(arrows[pred], end=" ")
        print()

In [4]:
model_po = PPO.load("models_cmp/ppo_gridworld_po_2")
print_policy_po(model_po)

↑ ↑ ↑ ↑ ↑ ↑ 
↑ ↑ ↑ ↑ ↑ ↑ 
↑ ↑ ↑ ↑ ↑ ↑ 
→ → → ↑ ↑ ↑ 
→ → → ↑ ↑ ↑ 
→ → → ↑ ↑ ↑ 


In [14]:
n = 3

mean_episode_lengths_po = []
std_episode_lengths_po = []
mean_rewards_po = []
std_rewards_po = []

for i in range(n):
    model = PPO.load(f"models_cmp/ppo_gridworld_po_{i}")
    rewards, lengths = evaluate_policy(model, env_po, n_eval_episodes=1000, return_episode_rewards=True)

    mean_episode_length = np.mean(lengths)
    std_episode_length = np.std(lengths)
    mean_reward = np.mean(rewards)
    std_reward = np.std(rewards)

    mean_episode_lengths_po.append(mean_episode_length)
    std_episode_lengths_po.append(std_episode_length)
    mean_rewards_po.append(mean_reward)
    std_rewards_po.append(std_reward)
    print(f"PO Model {i}: {mean_reward:.2f} +/- {std_reward:.2f}, mean episode length: {mean_episode_length:.2f} +/- {std_episode_length:.2f}")



PO Model 0: 0.73 +/- 0.68, mean episode length: 79.10 +/- 65.56
PO Model 1: 0.75 +/- 0.66, mean episode length: 77.41 +/- 63.00
PO Model 2: 0.78 +/- 0.62, mean episode length: 75.93 +/- 59.46


In [15]:
n = 3

mean_episode_lengths_notebook = []
std_episode_lengths_notebook = []
mean_rewards_notebook = []
std_rewards_notebook = []

for i in range(n):
    model = PPO.load(f"models_cmp/ppo_gridworld_notebook_{i}")
    rewards, lengths = evaluate_policy(model, env_notebook, n_eval_episodes=1000, return_episode_rewards=True)
    
    mean_episode_length = np.mean(lengths)
    std_episode_length = np.std(lengths)
    mean_reward = np.mean(rewards)
    std_reward = np.std(rewards)

    mean_episode_lengths_notebook.append(mean_episode_length)
    std_episode_lengths_notebook.append(std_episode_length)
    mean_rewards_notebook.append(mean_reward)
    std_rewards_notebook.append(std_reward)
    print(f"Notebook Model {i}: {mean_reward:.2f} +/- {std_reward:.2f}, mean episode length: {mean_episode_length:.2f} +/- {std_episode_length:.2f}")

Notebook Model 0: 0.87 +/- 0.49, mean episode length: 29.10 +/- 34.52
Notebook Model 1: 0.86 +/- 0.52, mean episode length: 31.15 +/- 40.49
Notebook Model 2: 0.84 +/- 0.54, mean episode length: 30.25 +/- 37.03


In [16]:
print("PO models")
print(f"Mean episode length: {np.mean(mean_episode_lengths_po):.2f} +/- {np.mean(std_episode_lengths_po):.2f}")
print(f"Mean reward: {np.mean(mean_rewards_po):.2f} +/- {np.mean(std_rewards_po):.2f}")
print("Notebook models")
print(f"Mean episode length: {np.mean(mean_episode_lengths_notebook):.2f} +/- {np.mean(std_episode_lengths_notebook):.2f}")
print(f"Mean reward: {np.mean(mean_rewards_notebook):.2f} +/- {np.mean(std_rewards_notebook):.2f}")

PO models
Mean episode length: 77.48 +/- 62.68
Mean reward: 0.76 +/- 0.65
Notebook models
Mean episode length: 30.17 +/- 37.35
Mean reward: 0.86 +/- 0.51


In [10]:

model = PPO.load(f"models_cmp/ppo_gridworld_notebook_0")
rewards = evaluate_policy(model, env_notebook, n_eval_episodes=10, return_episode_rewards=True)

print(rewards)

([np.float64(1.0), np.float64(1.0), np.float64(1.0), np.float64(1.0), np.float64(1.0), np.float64(1.0), np.float64(1.0), np.float64(-1.0), np.float64(1.0), np.float64(1.0)], [np.int64(14), np.int64(31), np.int64(14), np.int64(20), np.int64(15), np.int64(87), np.int64(75), np.int64(14), np.int64(143), np.int64(11)])


