# In [1]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
import os
from collections import deque
import random

# Configure SDL (the library PyGame renders through) before any window exists.
# 'windib' is a Windows video driver, so it is only selected when running on
# Windows; forcing it on other platforms would request a driver that is not
# there. Window centering is platform-independent and is always requested.
if os.name == 'nt':
    os.environ['SDL_VIDEODRIVER'] = 'windib'  # Use Windows driver
os.environ['SDL_WINDOW_CENTERED'] = '1'

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Each stored item is the tuple ``(state, action, reward, next_state, done)``.
    Once ``capacity`` items are held, adding another drops the oldest one.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) performs the oldest-first eviction for us.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append a single transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions without replacement.

        Returns:
            A 5-tuple of NumPy arrays, one per field, in the order
            (states, actions, rewards, next_states, dones).

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                exceeds the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one sequence per field.
        states, actions, rewards, next_states, dones = zip(*batch)
        fields = (states, actions, rewards, next_states, dones)
        return tuple(np.array(column) for column in fields)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class QNetwork(nn.Module):
    """Fully connected network producing one Q-value per discrete action.

    Architecture: Linear(state_dim, 256) -> ReLU -> Linear(256, 256) -> ReLU
    -> Linear(256, action_dim), stored as ``self.net``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_units = 256
        # Layers are created in forward order so parameter names
        # (net.0, net.2, net.4) and initialisation order stay stable.
        layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """Return the network output for ``state``."""
        q_values = self.net(state)
        return q_values

class DQN:
    """DQN agent that handles a continuous action vector by discretizing it.

    Every action dimension is discretized into ``discretization`` evenly
    spaced values in [-1, 1]; the Q-network has one output per combination,
    i.e. ``discretization ** action_dim`` outputs in total.

    Args:
        state_dim: Size of the observation vector fed to the Q-network.
        action_dim: Number of continuous action dimensions.
        discretization: Number of discrete levels per action dimension.
        lr: Adam learning rate.
        buffer_capacity: Maximum number of transitions kept for replay.
        batch_size: Number of transitions sampled per gradient step.
        gamma: Discount factor used in the TD target.
        epsilon: Initial exploration probability.
        epsilon_min: Lower bound for the exploration probability.
        epsilon_decay: Multiplicative decay applied to epsilon once per
            finished episode.
        target_update_freq: Stored for the training loop, which decides when
            to copy the online weights into the target network.
        learning_starts: Minimum number of stored transitions before any
            gradient step is taken.
    """

    def __init__(self, state_dim, action_dim, discretization=5, *,
                 lr=1e-3, buffer_capacity=100000, batch_size=128, gamma=0.99,
                 epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
                 target_update_freq=10, learning_starts=1000):
        self.action_dim = action_dim
        self.discretization = discretization
        self.discrete_actions = self._create_discrete_actions()

        # Calculate total number of discrete actions
        self.total_actions = discretization ** action_dim

        # Online network and a target network that starts as an exact copy.
        self.q_network = QNetwork(state_dim, self.total_actions)
        self.target_network = QNetwork(state_dim, self.total_actions)
        self.target_network.load_state_dict(self.q_network.state_dict())

        # Only the online network is optimized.
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)

        # Replay buffer
        self.replay_buffer = ReplayBuffer(buffer_capacity)

        # Hyperparameters
        self.batch_size = batch_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.target_update_freq = target_update_freq
        self.learning_starts = learning_starts

        # Training history (not written by this class itself).
        self.rewards_history = []
        self.avg_rewards_history = []

    def _create_discrete_actions(self):
        """Create a discrete action space by discretizing each continuous dimension.

        Returns:
            Array of shape (discretization ** action_dim, action_dim) holding
            every combination of the per-dimension levels.
        """
        action_space = np.linspace(-1, 1, self.discretization)
        # Create all combinations of discrete actions
        action_combinations = np.array(np.meshgrid(*[action_space] * self.action_dim))
        return action_combinations.T.reshape(-1, self.action_dim)

    def _discrete_to_continuous(self, discrete_action):
        """Convert discrete action index to continuous action vector."""
        return self.discrete_actions[discrete_action]

    def get_action(self, state, training=True):
        """Pick an action epsilon-greedily (or greedily when not training).

        Args:
            state: Observation for a single environment step.
            training: When True, a uniformly random action index is chosen
                with probability ``self.epsilon``.

        Returns:
            Tuple ``(continuous_action, action_idx)``.
        """
        if training and random.random() < self.epsilon:
            action_idx = random.randrange(self.total_actions)
        else:
            state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.q_network(state)
                action_idx = q_values.argmax().item()

        return self._discrete_to_continuous(action_idx), action_idx

    def _td_loss(self, states, actions, rewards, next_states, dones):
        """Return the MSE between Q(s, a) and the one-step TD target.

        All arguments are batched tensors; ``actions`` holds action indices
        and ``dones`` holds 1.0 for terminal transitions, 0.0 otherwise.
        """
        # squeeze(1) removes only the gather column, so a batch of one keeps
        # shape (1,) and lines up with the target instead of becoming 0-d.
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        return nn.MSELoss()(current_q_values, target_q_values)

    def _decay_epsilon(self):
        """Multiply epsilon by the decay factor, never going below epsilon_min."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update(self, state, action, reward, next_state, done):
        """Store one transition and, once warmed up, take one gradient step.

        Args:
            state: Observation before the step.
            action: Index of the discrete action that was taken.
            reward: Reward received for the step.
            next_state: Observation after the step.
            done: Whether the episode ended with this step.
        """
        # Store transition in replay buffer
        self.replay_buffer.push(state, action, reward, next_state, done)

        # Wait for the warm-up period, and never sample more than is stored.
        if len(self.replay_buffer) < max(self.learning_starts, self.batch_size):
            return

        # Sample from replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        loss = self._td_loss(
            torch.FloatTensor(states),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(next_states),
            torch.FloatTensor(dones),
        )

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()

        # Decay exploration once per finished episode. Decaying on every
        # gradient step with the default 0.995 drives epsilon from 1.0 to the
        # 0.01 floor in roughly 920 steps, which ends exploration almost
        # immediately after learning starts.
        if done:
            self._decay_epsilon()

def train(render_every=20, total_episodes=1000):
    """Train a DQN agent on BipedalWalker-v3 and return it.

    Training runs in a headless environment. On every ``render_every``-th
    episode a second, human-rendered environment is stepped alongside it with
    the greedy policy so progress can be watched.

    Side effects: prints per-episode statistics, overwrites
    ``training_progress.png`` every 10 episodes and writes ``best_model.pth``
    whenever the 100-episode average reward improves.

    Args:
        render_every: Show a rendered rollout every this many episodes.
            A falsy value (0 or None) disables rendering.
        total_episodes: Maximum number of training episodes.

    Returns:
        The trained ``DQN`` agent (also returned after a KeyboardInterrupt).
    """
    # Create environment
    env = gym.make('BipedalWalker-v3')

    # Initialize agent
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    agent = DQN(state_dim, action_dim)

    # Training loop
    best_reward = float('-inf')
    episode_rewards = []
    update_count = 0
    # Holds the rendered environment while one is open, otherwise None, so
    # cleanup never has to probe locals().
    display_env = None

    try:
        for episode in range(total_episodes):
            state, _ = env.reset()
            episode_reward = 0

            # Create display environment if needed
            if render_every and episode % render_every == 0:
                display_env = gym.make('BipedalWalker-v3', render_mode='human')
                display_state, _ = display_env.reset()

            while True:
                # Get action from agent
                action, action_idx = agent.get_action(state)

                # Take step in environment
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Update agent
                agent.update(state, action_idx, reward, next_state, done)
                update_count += 1

                # Update target network
                if update_count % agent.target_update_freq == 0:
                    agent.target_network.load_state_dict(agent.q_network.state_dict())

                episode_reward += reward
                state = next_state

                # Step the rendered rollout, if one is still running. Its
                # outcome is kept in separate names and must not end the
                # training episode: when the rendered episode finishes first
                # we just stop rendering.
                if display_env is not None:
                    display_action, _ = agent.get_action(display_state, training=False)
                    display_state, _, display_terminated, display_truncated, _ = (
                        display_env.step(display_action)
                    )
                    if display_terminated or display_truncated:
                        display_env.close()
                        display_env = None

                if done:
                    break

            # The training episode may finish before the rendered one; close
            # the window now so it is not leaked or overwritten later.
            if display_env is not None:
                display_env.close()
                display_env = None

            # Store reward
            episode_rewards.append(episode_reward)
            avg_reward = np.mean(episode_rewards[-100:])

            # Print progress
            print(f"Episode {episode + 1}")
            print(f"Reward: {episode_reward:.2f}")
            print(f"Average Reward (last 100): {avg_reward:.2f}")
            print(f"Epsilon: {agent.epsilon:.3f}")
            print("-" * 50)

            # Plot progress
            if (episode + 1) % 10 == 0:
                plt.figure(figsize=(10, 5))
                plt.plot(episode_rewards)
                plt.title("Training Progress")
                plt.xlabel("Episode")
                plt.ylabel("Reward")
                plt.savefig("training_progress.png")
                plt.close()

            # Save best model
            if avg_reward > best_reward:
                best_reward = avg_reward
                torch.save({
                    'q_network_state_dict': agent.q_network.state_dict(),
                    'target_network_state_dict': agent.target_network.state_dict(),
                    'reward': best_reward
                }, 'best_model.pth')

            # Early stopping
            if avg_reward >= 300:
                print("Environment solved!")
                break

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")

    finally:
        env.close()
        # Still open only if we were interrupted in the middle of a rendered episode.
        if display_env is not None:
            display_env.close()

    return agent

def evaluate(agent, episodes=5):
    """Run the agent greedily in a rendered BipedalWalker-v3 and print rewards.

    Args:
        agent: Object exposing ``get_action(state, training=False)`` that
            returns ``(action, action_idx)``.
        episodes: Number of episodes to play.
    """
    env = gym.make('BipedalWalker-v3', render_mode='human')

    try:
        for episode_idx in range(episodes):
            observation, _ = env.reset()
            episode_return = 0

            while True:
                chosen_action, _ = agent.get_action(observation, training=False)
                observation, step_reward, terminated, truncated, _ = env.step(chosen_action)
                episode_return += step_reward
                time.sleep(0.01)  # Slow down visualization
                if terminated or truncated:
                    break

            print(f"Episode {episode_idx + 1} Reward: {episode_return:.2f}")

    finally:
        # Always release the render window, even if a step raises.
        env.close()

if __name__ == "__main__":
    # Seed every random number generator the script uses with the same value.
    seed = 42
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    # Train first; a rendered rollout is shown every 20 episodes.
    print("Starting training...")
    agent = train(render_every=20)

    # Then watch the trained agent play.
    print("\nEvaluating trained agent...")
    evaluate(agent)

# Notebook output from the run above:
# Starting training...
#
# DependencyNotInstalled: Box2D is not installed, you can install it by run `pip install swig` followed by `pip install "gymnasium[box2d]"`