In [None]:
!pip install --quiet "stable-baselines3==2.1.0"
!pip install --quiet "gymnasium[classic-control]"

In [None]:
import gymnasium as gym
import numpy as np
import random
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
import os
import torch as th

# Single seed shared by the RNGs below and by PPO / env construction later on.
SEED = 42

random.seed(SEED)        # stdlib `random`
np.random.seed(SEED)     # NumPy legacy global RNG
th.manual_seed(SEED)     # torch generators
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only influences
# processes launched afterwards; it does not change hashing in this kernel.
os.environ.update(PYTHONHASHSEED=str(SEED))


In [None]:
from gymnasium import spaces

class OneHotWrapper(gym.ObservationWrapper):
    """Expose a ``Discrete`` observation as a float32 one-hot ``Box`` vector.

    Each integer state is mapped to a length-``n`` vector holding a single 1.0
    at the state's position, which SB3's ``MlpPolicy`` can consume directly.
    """

    def __init__(self, env):
        super().__init__(env)
        space = env.observation_space
        assert isinstance(space, spaces.Discrete), "OneHotWrapper expects discrete observation space"
        self.n = int(space.n)
        # Discrete spaces may be offset (valid values start .. start+n-1); index relative
        # to that offset. getattr keeps this working on Discrete versions without `start`.
        self._start = int(getattr(space, "start", 0))
        # Box of shape (n,) with float32 - suitable for SB3 MLP policy
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(self.n,), dtype=np.float32)

    def observation(self, obs):
        """Return the one-hot float32 encoding of the integer state ``obs``.

        Raises:
            ValueError: if ``obs`` is outside the wrapped Discrete space.
        """
        idx = int(obs) - self._start
        # Explicit check: a negative index would otherwise silently wrap around.
        if not 0 <= idx < self.n:
            raise ValueError(f"observation {obs!r} outside Discrete({self.n}, start={self._start})")
        arr = np.zeros(self.n, dtype=np.float32)
        arr[idx] = 1.0
        return arr


In [None]:
def make_env(seed=None):
    """Create a Taxi-v3 environment whose observations are one-hot encoded.

    When ``seed`` is given, the env is reset once with it so its RNG is seeded
    before being handed back; with ``seed=None`` no reset is performed here.
    """
    wrapped = OneHotWrapper(gym.make("Taxi-v3"))
    if seed is None:
        return wrapped
    wrapped.reset(seed=seed)
    return wrapped


env = make_env(seed=SEED)
obs, info = env.reset()
print("Initial observation (one-hot) shape:", obs.shape, "  sum:", obs.sum())  # sum == 1.0
env.close()
# Enforce the one-hot invariant instead of relying on reading the printout.
assert obs.shape == env.observation_space.shape, "unexpected observation shape"
assert np.count_nonzero(obs) == 1 and obs.sum() == 1.0, "observation is not one-hot"


In [None]:
policy_kwargs = dict(net_arch=[128, 128])  # modest MLP: two hidden layers of 128 units


def train_ppo(env_fn, total_timesteps=200_000, learning_rate=3e-4, seed=SEED, model_name="ppo_model"):
    """Train a PPO agent on a freshly built env, save it, and return it.

    Args:
        env_fn: factory called as ``env_fn(seed=seed)`` returning a Gymnasium env.
        total_timesteps: number of env steps passed to ``model.learn``.
        learning_rate: PPO optimizer learning rate.
        seed: forwarded to both the env factory and PPO.
        model_name: filename prefix; the learning rate is appended as ``lr{lr:.0e}``.

    Returns:
        ``(model, model_path)``: the trained model and the ``.zip`` path it was saved to.
    """
    env = env_fn(seed=seed)
    try:
        # force CPU device to avoid CUDA complications
        model = PPO("MlpPolicy", env, learning_rate=learning_rate, verbose=1,
                    policy_kwargs=policy_kwargs, seed=seed, device="cpu")
        model.learn(total_timesteps=total_timesteps)
        # NOTE: `.0e` keeps one significant digit, so e.g. 3e-4 and 3.4e-4 both
        # produce "..._lr3e-04.zip" and the later run overwrites the earlier file.
        model_path = f"{model_name}_lr{learning_rate:.0e}.zip"
        model.save(model_path)
    finally:
        # Release the env even if training or saving raises.
        env.close()
    return model, model_path


In [None]:
def test_agent(model, env_fn, episodes=10, render=False):
    env = env_fn()
    rewards = []
    for ep in range(episodes):
        obs, info = env.reset()
        done = False
        total_reward = 0
        steps = 0
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            action = int(action)                      # important: Taxi expects int
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            total_reward += reward
            steps += 1
            if render:
                # text rendering (works in classic-control)
                try:
                    print(env.unwrapped.desc)  # may or may not be helpful; optional
                except Exception:
                    pass
        rewards.append(total_reward)
        print(f"Episode {ep+1} Total Reward: {total_reward}  (steps: {steps})")
    env.close()
    return rewards


In [None]:
default_lr = 3e-4  # same value as train_ppo's learning_rate default
default_timesteps = 200_000
print(f"Training PPO with default LR = {default_lr} ...")
model_default, path_default = train_ppo(make_env, total_timesteps=default_timesteps, learning_rate=default_lr, model_name="ppo_default")
print("Saved default model to:", path_default)


In [None]:
print("\n--- Testing Default PPO ---")
# Greedy evaluation of the default-LR agent over 20 episodes.
default_rewards = test_agent(
    model=model_default,
    env_fn=make_env,
    episodes=20,
)
print("Average reward (default):", np.mean(default_rewards))


In [None]:
aggressive_lr = 1e-3            # ~3x larger than the 3e-4 used for the default run
aggressive_timesteps = 200_000  # same step budget as the default run, so only the LR differs
print(f"Training PPO with aggressive LR = {aggressive_lr} ...")
model_aggr, path_aggr = train_ppo(
    make_env,
    total_timesteps=aggressive_timesteps,
    learning_rate=aggressive_lr,
    model_name="ppo_aggressive",
)
print("Saved aggressive model to:", path_aggr)


In [None]:
print("\n--- Testing Aggressive-LR PPO ---")
# Greedy evaluation of the aggressive-LR agent over the same number of episodes.
aggr_rewards = test_agent(
    model=model_aggr,
    env_fn=make_env,
    episodes=20,
)
print("Average reward (aggressive):", np.mean(aggr_rewards))


In [None]:
# Per-episode returns of both agents on one axis; averages computed once and reused.
default_avg = np.mean(default_rewards)
aggr_avg = np.mean(aggr_rewards)

fig, ax = plt.subplots(figsize=(10, 4))
# The label strings previously mixed quote styles (f"...'), which was a SyntaxError.
ax.plot(default_rewards, label=f"default lr=3e-4 (avg={default_avg:.2f})", marker="o")
ax.plot(aggr_rewards, label=f"aggressive lr={aggressive_lr} (avg={aggr_avg:.2f})", marker="x")
ax.set(xlabel="Test episode", ylabel="Total reward", title="Per-episode Total Reward (comparison)")
ax.legend()
ax.grid(True)
plt.show()

print("\nSummary:")
print("Default avg:", default_avg, "Aggressive avg:", aggr_avg)
