# In [None]:
import gymnasium as gym
import time
import pygame
from stable_baselines3 import PPO

import config
from wrappers import CliffWalkingStateWrapper

def _quit_requested():
    """Drain the pygame event queue and report whether the user asked to quit.

    Closing the window (``pygame.QUIT``) or pressing ESC counts as a quit
    request. Returns True as soon as one is seen, otherwise False.
    """
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            return True
        if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
            print("\nESC pressed. Exiting...")
            return True
    return False


def watch_agent(
    model_path="ppo_cliffwalking_rlhf",
    n_episodes=5,
    step_delay=0.05,
    max_steps=None,
):
    """Replay a trained PPO agent in a human-rendered environment window.

    Loads the PPO policy saved at ``model_path``, then plays ``n_episodes``
    episodes of ``config.ENV_ID`` (wrapped in ``CliffWalkingStateWrapper``)
    with deterministic actions, pausing ``step_delay`` seconds after each
    step so the rendering can be followed. Pressing ESC, closing the window,
    or Ctrl+C stops the whole simulation. The environment and pygame are
    always shut down on exit.

    Args:
        model_path: Path of the saved model, with or without the ``.zip``
            suffix (``PPO.load`` tries the ``.zip`` suffix itself).
        n_episodes: Number of episodes to play.
        step_delay: Pause in seconds after every environment step.
        max_steps: Optional per-episode step cap. ``None`` (the default)
            runs each episode until the environment reports
            terminated/truncated. Set a cap to avoid looping forever when a
            deterministic policy never reaches a terminal state.
    """
    # Load the model BEFORE opening the render window: if the file is
    # missing we return without leaving an unclosed env / orphaned window.
    try:
        model = PPO.load(model_path)
    except FileNotFoundError:
        shown = str(model_path)
        if not shown.endswith(".zip"):
            shown += ".zip"
        print(f"Error: Could not find '{shown}'. Did you run main.py?")
        return
    print("Loaded trained PPO model.")

    env = CliffWalkingStateWrapper(gym.make(config.ENV_ID, render_mode="human"))

    print("Starting simulation...")
    print("Press [ESC] or close the window to quit the simulation.")

    try:
        for episode in range(n_episodes):
            print(f"--- Episode {episode + 1} ---")
            # Reset at the start of each episode (inside try/finally) so no
            # extra episode is started after the last one and a failing
            # reset still closes the environment.
            obs, _ = env.reset()
            done = False
            total_steps = 0

            while not done:
                if _quit_requested():
                    print("\nStopped by user.")
                    return  # `finally` below still closes env/pygame.

                action, _ = model.predict(obs, deterministic=True)
                # predict() returns a numpy array; the env expects a plain int.
                obs, _reward, terminated, truncated, _info = env.step(action.item())

                done = terminated or truncated
                total_steps += 1
                time.sleep(step_delay)

                if not done and max_steps is not None and total_steps >= max_steps:
                    print(f"Step limit of {max_steps} reached; ending episode.")
                    break

            print(f"Episode finished in {total_steps} steps.")

    except KeyboardInterrupt:
        # Ctrl+C in the terminal.
        print("\nStopped by user.")
    finally:
        env.close()
        pygame.quit()

if __name__ == "__main__":
    # Script entry point: launch the viewer only when this file is executed
    # directly, never as a side effect of importing it.
    watch_agent()