# Watch a Trained Agent Play Snake

Load a trained checkpoint and watch the agent play Snake in real time.

In [None]:
import sys
import os
sys.path.append('../src')

import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time

from environments import SnakeEnv
from agents import DQNAgent, PPODiscreteAgent
import yaml

print("Imports successful!")

In [None]:
# Load the experiment configuration: environment settings, training options
# and the per-algorithm hyperparameter sections used by the cells below.
config_path = '../configs/snake_config.yaml'
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

print("Config loaded!")

In [None]:
# Create environment from the `environment` section of the config
env_config = config['environment']
env = SnakeEnv(
    grid_size=env_config['grid_size'],
    state_representation=env_config['state_representation'],
    initial_length=env_config['initial_length'],
    reward_food=env_config['reward_food'],
    reward_death=env_config['reward_death'],
    reward_step=env_config['reward_step'],
    reward_distance=env_config['reward_distance'],
    render_mode=None,  # We'll render manually
)

# Get the state shape the agent's network expects.
# Gym/Gymnasium spaces all define `.shape`, but a `Discrete` space reports the
# empty tuple `()`. Testing hasattr() alone would therefore hand the agent a
# zero-dimensional shape instead of falling back to `(n,)`, so fall back
# whenever the shape is missing OR empty.
obs_space = env.observation_space
obs_shape = getattr(obs_space, 'shape', None)
state_shape = obs_shape if obs_shape else (obs_space.n,)

print(f"Environment created! State shape: {state_shape}")

In [None]:
# Load trained agent
# Update this path to your checkpoint
checkpoint_path = '../checkpoints/snake/best_model.pth'  # or 'final_model.pth'

# If the default checkpoint is missing, list what is available and let the
# user type a path. Pressing Enter keeps the default; the load cell then
# reports the failure instead of crashing.
if not os.path.exists(checkpoint_path):
    print(f"⚠️  Checkpoint not found at {checkpoint_path}")
    checkpoint_dir = '../checkpoints/snake'
    if os.path.isdir(checkpoint_dir):
        # Sorted so the listing is stable between runs (os.listdir order is arbitrary).
        available_checkpoints = sorted(
            os.path.join(checkpoint_dir, name)
            for name in os.listdir(checkpoint_dir)
            if name.endswith('.pth')
        )
    else:
        available_checkpoints = []

    if available_checkpoints:
        print("Available checkpoints:")
        for candidate in available_checkpoints:
            print(f"  - {candidate}")
    else:
        print(f"No .pth checkpoints found in {checkpoint_dir}")

    # strip() so a pasted path with stray spaces/newline still resolves.
    checkpoint_path = input("Enter checkpoint path: ").strip() or checkpoint_path

# Create agent: choose the class and the hyperparameter section that match the
# algorithm named in the config, so the network layout matches the checkpoint.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
algorithm = config['training']['algorithm'].lower()
use_dqn = algorithm == "dqn"

# Any algorithm name other than "dqn" is treated as PPO.
hparams = config['dqn'] if use_dqn else config['ppo']

if use_dqn:
    agent = DQNAgent(
        state_shape=state_shape,
        num_actions=env.action_space.n,
        learning_rate=hparams['learning_rate'],
        gamma=hparams['gamma'],
        epsilon_start=hparams['epsilon_start'],
        epsilon_end=hparams['epsilon_end'],
        epsilon_decay=hparams['epsilon_decay'],
        replay_buffer_size=hparams['replay_buffer_size'],
        batch_size=hparams['batch_size'],
        target_update_frequency=hparams['target_update_frequency'],
        hidden_sizes=hparams['network'],
        activation=hparams['activation'],
        state_representation=config['environment']['state_representation'],
        device=device,
    )
else:
    agent = PPODiscreteAgent(
        state_shape=state_shape,
        num_actions=env.action_space.n,
        learning_rate=hparams['learning_rate'],
        gamma=hparams['gamma'],
        gae_lambda=hparams['gae_lambda'],
        clip_epsilon=hparams['clip_epsilon'],
        value_coef=hparams['value_coef'],
        entropy_coef=hparams['entropy_coef'],
        max_grad_norm=hparams['max_grad_norm'],
        update_epochs=hparams['update_epochs'],
        batch_size=hparams['batch_size'],
        hidden_sizes=hparams['network'],
        activation=hparams['activation'],
        state_representation=config['environment']['state_representation'],
        device=device,
    )

# Load checkpoint. Failures are reported rather than raised so the remaining
# cells still run; the agent then keeps the weights it was constructed with.
# Note: such an agent is untrained, not random. The viewer below calls
# act(..., deterministic=True), so it will not sample random actions.
try:
    agent.load(checkpoint_path)
except Exception as e:  # notebook boundary: report any load failure, don't crash
    print(f"❌ Error loading checkpoint: {e}")
    print("Using untrained agent (checkpoint weights were not loaded)")
else:
    print(f"✅ Agent loaded from {checkpoint_path}")

In [None]:
# Watch agent play
def _select_action(agent, state):
    """Return the action `agent` chooses for `state` with deterministic=True.

    ``DQNAgent.act`` returns the action directly; any other agent's ``act``
    is expected to return a 3-tuple whose first element is the action.
    """
    if isinstance(agent, DQNAgent):
        return agent.act(state, deterministic=True)
    action, _, _ = agent.act(state, deterministic=True)
    return action


def _draw_frame(ax, env, episode, steps, total_reward):
    """Redraw the whole board (grid, snake, food, status line) onto `ax`.

    Reads ``env.grid_size``, ``env.snake`` (segments as ``(row, col)``,
    head first), ``env.food`` (``(row, col)`` or None/empty) and ``env.score``.
    """
    n = env.grid_size

    # ax.clear() resets limits, aspect and axis inversion, so they are
    # re-applied on every frame.
    ax.clear()
    ax.set_xlim(0, n)
    ax.set_ylim(0, n)
    ax.set_aspect('equal')
    ax.invert_yaxis()  # row 0 drawn at the top
    ax.set_xticks([])
    ax.set_yticks([])

    # Faint cell grid as two line collections instead of n*n Rectangle
    # patches, which made every frame cost O(grid_size**2) artists.
    grid_lines = range(n + 1)
    ax.vlines(grid_lines, 0, n, colors='lightgray', linewidths=0.5, alpha=0.3)
    ax.hlines(grid_lines, 0, n, colors='lightgray', linewidths=0.5, alpha=0.3)

    # Snake: dark head; body segments fade from 0.8 toward 0.5 opacity.
    snake_len = len(env.snake)
    for i, (row, col) in enumerate(env.snake):
        if i == 0:
            color, alpha = '#2E7D32', 1.0
        else:
            color, alpha = '#66BB6A', 0.8 - (i / snake_len) * 0.3
        ax.add_patch(plt.Rectangle((col, row), 1, 1,
                                   linewidth=1, edgecolor='darkgreen',
                                   facecolor=color, alpha=alpha))

    # Food. Explicit None/length check instead of bare truthiness so a NumPy
    # coordinate array does not raise "truth value ... is ambiguous".
    food = env.food
    if food is not None and len(food) > 0:
        row, col = food
        ax.add_patch(plt.Circle((col + 0.5, row + 0.5), 0.4,
                                color='red', alpha=0.8))

    status = (f"Episode {episode + 1} | "
              f"Score: {env.score} | "
              f"Steps: {steps} | "
              f"Length: {len(env.snake)} | "
              f"Reward: {total_reward:.1f}")
    ax.text(0.5, -0.05, status,
            transform=ax.transAxes, ha='center',
            fontsize=12, fontweight='bold')
    ax.set_title('Snake Game - Trained Agent', fontsize=16, fontweight='bold')


def watch_agent_play(env, agent, num_episodes=3, delay=0.1, max_steps=500):
    """
    Watch trained agent play Snake game.

    Each episode is rendered frame by frame in the notebook output, starting
    from the reset position; every new frame replaces the previous one. The
    agent is put in eval mode for the run and switched back to train mode
    afterwards, even if the run is interrupted (e.g. KeyboardInterrupt).
    Because the first frame of each episode clears the output, a recap of
    all episodes is printed once they have all finished.

    Args:
        env: Snake environment with a gym-style ``reset``/``step`` API that
            also exposes ``grid_size``, ``snake``, ``food`` and ``score``
        agent: Trained agent
        num_episodes: Number of episodes to watch
        delay: Delay between frames (seconds)
        max_steps: Maximum steps per episode
    """
    # Imported explicitly: `display` is only injected as a builtin inside an
    # IPython kernel and is a NameError anywhere else.
    from IPython.display import clear_output, display

    results = []
    agent.eval()
    try:
        for episode in range(num_episodes):
            state, _ = env.reset()
            done = False
            steps = 0
            total_reward = 0

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                while True:
                    _draw_frame(ax, env, episode, steps, total_reward)
                    fig.tight_layout()
                    # Clear first (deferred until the new output arrives), then
                    # display: the newest frame stays on screen. Clearing after
                    # display would let the next print wipe the final frame.
                    clear_output(wait=True)
                    display(fig)
                    time.sleep(delay)

                    if done or steps >= max_steps:
                        break

                    action = _select_action(agent, state)
                    state, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    total_reward += reward
                    steps += 1
            finally:
                # Close even on interrupt so the inline backend does not
                # re-render a stale figure at the end of the cell.
                plt.close(fig)

            # Final results
            print(f"\nEpisode {episode + 1} Complete!")
            print(f"  Final Score: {env.score}")
            print(f"  Total Steps: {steps}")
            print(f"  Total Reward: {total_reward:.2f}")
            print(f"  Snake Length: {len(env.snake)}")
            results.append((env.score, steps, total_reward, len(env.snake)))
    finally:
        agent.train()

    print(f"\n{'='*50}")
    print("All episodes complete!")
    for number, (score, ep_steps, ep_reward, length) in enumerate(results, start=1):
        print(f"  Episode {number}: score={score}, steps={ep_steps}, "
              f"reward={ep_reward:.2f}, length={length}")
    print(f"{'='*50}")

# Run the viewer: three deterministic episodes, pausing 0.1 s after each frame.
watch_agent_play(env=env, agent=agent, num_episodes=3, delay=0.1)