# Part 1: Pong Tournament

In [1]:
!pip install gymnasium[atari]
!pip install gymnasium[accept-rom-license]
!pip install ale-py
!pip install autorom
!AutoROM --accept-license

AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.11/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/adventure.bin    
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/air_raid.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/alien.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/amidar.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/assault.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/asterix.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/asteroids.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/atlantis.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/atlantis2.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/backgammon.bin
Installed /usr/local/lib/python3.11/dist-packages/AutoROM/roms/bank_heist.bin


In [2]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import collections
import random
import time
import datetime
from collections import deque, namedtuple
import matplotlib.pyplot as plt
import cv2
import os
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print(f"Using Gymnasium version: {gym.__version__}")

# Define the Experience tuple
Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'new_state', 'done'])

# Fixed Environment Wrappers for Gymnasium
class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        # Check if fire action is available
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(2)
        if terminated or truncated:
            self.env.reset(**kwargs)
        return obs, {}

    def step(self, action):
        return self.env.step(action)

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        terminated = False
        truncated = False
        info = {}
        
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if terminated or truncated:
                break
        
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        self._obs_buffer.clear()
        obs, info = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs, info

class ProcessFrame84(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return self._process_frame(obs)

    def _process_frame(self, frame):
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            # Try to handle different frame sizes
            img = frame.astype(np.float32)
            if len(img.shape) == 3 and img.shape[2] == 3:
                pass  # Already in correct format
            else:
                img = img.reshape((210, 160, 3))
        
        # Convert to grayscale
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        
        # Resize
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)

class BufferWrapper(gym.ObservationWrapper):
    def __init__(self, env, n_steps):
        super().__init__(env)
        self.n_steps = n_steps
        old_space = env.observation_space
        new_shape = (n_steps * old_space.shape[0], old_space.shape[1], old_space.shape[2])
        
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=new_shape, dtype=old_space.dtype
        )
        self.buffer = deque(maxlen=n_steps)

    def reset(self, **kwargs):
        self.buffer.clear()
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.n_steps):
            self.buffer.append(obs)
        return self._get_obs(), info

    def observation(self, observation):
        self.buffer.append(observation)
        return self._get_obs()

    def _get_obs(self):
        return np.concatenate(self.buffer, axis=0)

class ImageToPyTorch(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        old_shape = self.observation_space.shape
        # Change from HWC to CHW
        new_shape = (old_shape[-1], old_shape[0], old_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=new_shape, dtype=np.float32
        )

    def observation(self, observation):
        # Move channels to first dimension
        return np.moveaxis(observation, -1, 0)

class ScaledFloatFrame(gym.ObservationWrapper):
    def observation(self, obs):
        return np.array(obs, dtype=np.float32) / 255.0

def make_env(env_name):
    env = gym.make(env_name, render_mode='rgb_array')
    env = MaxAndSkipEnv(env)
    env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = ImageToPyTorch(env)
    env = ScaledFloatFrame(env)
    env = BufferWrapper(env, 4)
    return env

# DQN Network
class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # Calculate the size of the flattened features
        with torch.no_grad():
            dummy_input = torch.zeros(1, *input_shape)
            conv_out = self.conv(dummy_input)
            n_flatten = conv_out.view(1, -1).size(1)

        self.fc = nn.Sequential(
            nn.Linear(n_flatten, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        # x shape: [batch_size, channels, height, width]
        conv_out = self.conv(x)
        flattened = conv_out.view(conv_out.size(0), -1)
        return self.fc(flattened)

# Experience Buffer
class ExperienceBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        experience = Experience(state, action, reward, next_state, done)
        self.buffer.append(experience)

    def sample(self, batch_size):
        if len(self.buffer) < batch_size:
            batch_size = len(self.buffer)
            
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(next_states, dtype=np.float32),
            np.array(dones, dtype=np.bool_)
        )

# Agent
class Agent:
    def __init__(self, env, exp_buffer):
        self.env = env
        self.exp_buffer = exp_buffer
        self.reset()

    def reset(self):
        self.state, _ = self.env.reset()
        self.total_reward = 0.0

    @torch.no_grad()
    def play_step(self, net, epsilon=0.0, device=device):
        done_reward = None

        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state_a = np.array([self.state], copy=False)
            state_v = torch.tensor(state_a, dtype=torch.float32).to(device)
            q_vals_v = net(state_v)
            _, act_v = torch.max(q_vals_v, dim=1)
            action = int(act_v.item())

        # Take step in environment
        new_state, reward, terminated, truncated, _ = self.env.step(action)
        done = terminated or truncated

        reward = np.clip(reward, -1.0, 1.0)
        # Store the ACTUAL step reward, not cumulative
        self.exp_buffer.push(self.state, action, reward, new_state, done)
        
        self.state = new_state
        self.total_reward += reward

        if done:
            done_reward = self.total_reward
            self.reset()
            
        return done_reward

# Loss calculation functions
def calc_loss(batch, net, tgt_net, device, gamma=0.99):
    states, actions, rewards, next_states, dones = batch
    
    # Convert to tensors
    states_v = torch.tensor(states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    next_states_v = torch.tensor(next_states).to(device)
    dones_mask = torch.tensor(dones, dtype=torch.bool).to(device)
    
    # Current Q values
    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    
    # Next Q values from target network
    with torch.no_grad():
        next_state_values = tgt_net(next_states_v).max(1)[0]
        next_state_values[dones_mask] = 0.0
        expected_state_action_values = next_state_values * gamma + rewards_v
    
    # Compute loss
    loss = nn.MSELoss()(state_action_values, expected_state_action_values)
    return loss

# Plotting function
def plot_rewards(rewards, ma_window=100):
    plt.figure(figsize=(12, 8))
    plt.title("Training Rewards")
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    
    if len(rewards) > 0:
        plt.plot(rewards, alpha=0.3, label='Raw rewards', color='blue')
        
        if len(rewards) >= ma_window:
            ma_rewards = []
            for i in range(len(rewards) - ma_window + 1):
                ma_rewards.append(np.mean(rewards[i:i+ma_window]))
            plt.plot(range(ma_window-1, len(rewards)), ma_rewards, 
                    label=f'Moving Average ({ma_window} episodes)', color='red', linewidth=2)
        
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

# Training function
def train_dqn():
    # Create models directory
    os.makedirs("models", exist_ok=True)
    
    # Parameters
    ENV_NAME = "PongNoFrameskip-v4"
    MEAN_REWARD_BOUND = 18.0  # Pong is solved when average reward > 18
    
    GAMMA = 0.99
    BATCH_SIZE = 32
    REPLAY_SIZE = 50000
    LEARNING_RATE = 1e-4
    SYNC_TARGET_FRAMES = 1000
    REPLAY_START_SIZE = 10000
    
    EPSILON_DECAY_LAST_FRAME = 50000
    EPSILON_START = 1.0
    EPSILON_FINAL = 0.02
    
    print(f"Training DQN on {ENV_NAME}")
    print(f"Device: {device}")
    
    # Create environment
    env = make_env(ENV_NAME)
    
    # Initialize networks
    net = DQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt_net = DQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt_net.load_state_dict(net.state_dict())
    
    print(f"Network: {net}")
    print(f"Input shape: {env.observation_space.shape}")
    print(f"Number of actions: {env.action_space.n}")
    
    # Initialize buffer and agent
    buffer = ExperienceBuffer(REPLAY_SIZE)
    agent = Agent(env, buffer)
    
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    total_rewards = []
    frame_idx = 0
    ts_frame = 0
    ts = time.time()
    best_mean_reward = None
    
    print("Starting training at:", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print("Filling replay buffer...")
    
    # Initial filling of replay buffer
    while len(buffer) < REPLAY_START_SIZE:
        frame_idx += 1
        epsilon = 1.0  # Always explore during initial filling
        reward = agent.play_step(net, epsilon, device)
        if reward is not None:
            total_rewards.append(reward)
            
        if frame_idx % 1000 == 0:
            print(f"Filled {len(buffer)}/{REPLAY_START_SIZE} experiences in replay buffer")
    
    print("Replay buffer filled. Starting training...")
    
    for episode in range(10000):
        frame_idx += 1
        epsilon = EPSILON_FINAL + (EPSILON_START - EPSILON_FINAL) * max(0, 1 - frame_idx / EPSILON_DECAY_LAST_FRAME)
        
        reward = agent.play_step(net, epsilon, device)
        if reward is not None:
            total_rewards.append(reward)
            mean_reward = np.mean(total_rewards[-100:]) if len(total_rewards) >= 100 else np.mean(total_rewards)
            
            # Calculate speed
            speed = (frame_idx - ts_frame) / (time.time() - ts)
            ts_frame = frame_idx
            ts = time.time()
            
            print(f"Frame {frame_idx}: Episode {len(total_rewards)}, Reward: {reward:.1f}, "
                  f"Mean Reward (last 100): {mean_reward:.3f}, Epsilon: {epsilon:.3f}, Speed: {speed:.2f} f/s")
            
            # Save best model
            if best_mean_reward is None or mean_reward > best_mean_reward:
                if best_mean_reward is not None:
                    print(f"New best mean reward: {best_mean_reward:.3f} -> {mean_reward:.3f}")
                best_mean_reward = mean_reward
                model_path = f"models/{ENV_NAME}_best.pth"
                torch.save(net.state_dict(), model_path)
                print(f"Model saved to {model_path}")
            
            # Check if solved
            if mean_reward >= MEAN_REWARD_BOUND:
                print(f"Solved at frame {frame_idx} with mean reward {mean_reward:.3f}!")
                break
        
        # Sync target network
        if frame_idx % SYNC_TARGET_FRAMES == 0:
            tgt_net.load_state_dict(net.state_dict())
            print(f"Target network updated at frame {frame_idx}")
        
        # Training step
        optimizer.zero_grad()
        batch = buffer.sample(BATCH_SIZE)
        loss_t = calc_loss(batch, net, tgt_net, device, GAMMA)
        loss_t.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()
        
        # Plot progress every 50 episodes
        if len(total_rewards) % 50 == 0 and len(total_rewards) > 0:
            plot_rewards(total_rewards)
    
    print("Training completed at:", datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    return total_rewards, net

Using device: cuda
Using Gymnasium version: 0.29.0


In [6]:
import torch.nn.functional as F

class PolicyNetwork(nn.Module):
    """REINFORCE Policy Network"""
    def __init__(self, input_shape, n_actions):
        super(PolicyNetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten()
        )
        
        with torch.no_grad():
            n_flatten = self.conv(torch.zeros(1, *input_shape)).shape[1]
            
        self.fc = nn.Sequential(
            nn.Linear(n_flatten, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )
    
    def forward(self, x):
        x = self.conv(x)
        return F.softmax(self.fc(x), dim=-1)

class REINFORCEAgent:
    def __init__(self, env):
        self.env = env
        self.net = PolicyNetwork(env.observation_space.shape, env.action_space.n).to(device)
        self.optimizer = optim.Adam(self.net.parameters(), lr=1e-4)
        self.saved_log_probs = []
        self.rewards = []
    
    def act(self, state):
        state_v = torch.tensor(np.array([state]), dtype=torch.float32).to(device)
        probs = self.net(state_v)
        m = torch.distributions.Categorical(probs)
        action = m.sample()
        self.saved_log_probs.append(m.log_prob(action))
        return action.item()
    
    def update_policy(self, gamma=0.99):
        returns = []
        R = 0
        for r in self.rewards[::-1]:
            R = r + gamma * R
            returns.insert(0, R)
        
        returns = torch.tensor(returns).to(device)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        
        policy_loss = []
        for log_prob, R in zip(self.saved_log_probs, returns):
            policy_loss.append(-log_prob * R)
        
        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        
        self.saved_log_probs = []
        self.rewards = []
    
    def train(self, episodes=1000):
        total_rewards = []
        
        for episode in range(episodes):
            state, _ = self.env.reset()
            episode_reward = 0
            done = False
            
            while not done:
                action = self.act(state)
                state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated
                self.rewards.append(reward)
                episode_reward += reward
            
            self.update_policy()
            total_rewards.append(episode_reward)
            
            if episode % 50 == 0:
                avg_reward = np.mean(total_rewards[-50:]) if len(total_rewards) >= 50 else np.mean(total_rewards)
                print(f"REINFORCE Episode {episode}, Reward: {episode_reward}, Avg: {avg_reward:.2f}")
        
        return total_rewards

In [7]:
def compare_models(dqn_rewards, reinforce_rewards):
    """Compare both models and determine the best one"""
    plt.figure(figsize=(12, 6))
    
    # Smooth rewards for better visualization
    window = 100
    if len(dqn_rewards) >= window:
        dqn_smooth = np.convolve(dqn_rewards, np.ones(window)/window, mode='valid')
        reinforce_smooth = np.convolve(reinforce_rewards, np.ones(window)/window, mode='valid')
        
        plt.plot(range(window-1, len(dqn_rewards)), dqn_smooth, label='DQN', linewidth=2)
        plt.plot(range(window-1, len(reinforce_rewards)), reinforce_smooth, label='REINFORCE', linewidth=2)
    else:
        plt.plot(dqn_rewards, label='DQN', alpha=0.7)
        plt.plot(reinforce_rewards, label='REINFORCE', alpha=0.7)
    
    plt.xlabel('Episode')
    plt.ylabel('Smoothed Reward')
    plt.title('DQN vs REINFORCE Performance Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('model_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Final analysis
    dqn_final_avg = np.mean(dqn_rewards[-100:]) if len(dqn_rewards) >= 100 else np.mean(dqn_rewards)
    reinforce_final_avg = np.mean(reinforce_rewards[-100:]) if len(reinforce_rewards) >= 100 else np.mean(reinforce_rewards)
    
    print("\n" + "="*60)
    print("FINAL MODEL COMPARISON")
    print("="*60)
    print(f"DQN Final Average Reward (last 100 episodes): {dqn_final_avg:.3f}")
    print(f"REINFORCE Final Average Reward (last 100 episodes): {reinforce_final_avg:.3f}")
    
    return "DQN" if dqn_final_avg > reinforce_final_avg else "REINFORCE"

In [5]:
if __name__ == "__main__":
    print("Starting Part 1: Solving Pong with DQN and REINFORCE")
    
    # Train DQN
    print("\n" + "="*50)
    print("TRAINING DQN")
    print("="*50)
    dqn_rewards, trained_dqn = train_dqn()
    
    # Train REINFORCE
    print("\n" + "="*50)
    print("TRAINING REINFORCE")
    print("="*50)
    env = make_env("PongNoFrameskip-v4")
    reinforce_agent = REINFORCEAgent(env)
    reinforce_rewards = reinforce_agent.train(episodes=1000)
    
    # Compare models
    print("\n" + "="*50)
    print("MODEL COMPARISON")
    print("="*50)
    best_model = compare_models(dqn_rewards, reinforce_rewards)
    
    print(f"\nðŸ“‹ Part 1 Summary:")
    print(f"- Environment: PongNoFrameskip-v4")
    print(f"- Models implemented: DQN and REINFORCE")
    print(f"- Best performing model: {best_model}")
    print(f"- Final DQN performance: {np.mean(dqn_rewards[-100:]):.3f} (last 100 episodes)")
    print(f"- Final REINFORCE performance: {np.mean(reinforce_rewards[-100:]):.3f} (last 100 episodes)")

Starting Part 1: Solving Pong with DQN and REINFORCE

TRAINING DQN
Training DQN on PongNoFrameskip-v4
Device: cuda


A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Network: DQN(
  (conv): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=3136, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=6, bias=True)
  )
)
Input shape: (4, 84, 84)
Number of actions: 6
Starting training at: 2025-11-27 14:41:38
Filling replay buffer...
Filled 1000/10000 experiences in replay buffer
Filled 2000/10000 experiences in replay buffer
Filled 3000/10000 experiences in replay buffer
Filled 4000/10000 experiences in replay buffer
Filled 5000/10000 experiences in replay buffer
Filled 6000/10000 experiences in replay buffer
Filled 7000/10000 experiences in replay buffer
Filled 8000/10000 experiences in replay buffer
Filled 9000/10000 experiences in replay buffer
Filled 10000/10000 experiences in r

NameError: name 'F' is not defined

## Data Preprocessing (Wrappers)

In [None]:
class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, term, trunc, _ = self.env.step(1)
        if term or trunc:
            self.env.reset(**kwargs)
        obs, _, term, trunc, _ = self.env.step(2)
        if term or trunc:
            self.env.reset(**kwargs)
        return obs, {}

    def step(self, action):
        return self.env.step(action)

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self.buffer = collections.deque(maxlen=2)
        self.skip = skip

    def step(self, action):
        total_reward = 0
        term = False
        trunc = False
        info = {}
        for _ in range(self.skip):
            obs, reward, term, trunc, info = self.env.step(action)
            self.buffer.append(obs)
            total_reward += reward
            if term or trunc:
                break
        max_frame = np.max(np.stack(self.buffer), axis=0)
        return max_frame, total_reward, term, trunc, info

    def reset(self, **kwargs):
        self.buffer.clear()
        obs, info = self.env.reset(**kwargs)
        self.buffer.append(obs)
        return obs, info


class BufferWrapper(gym.ObservationWrapper):
    def __init__(self, env, n_steps):
        super().__init__(env)
        self.n_steps = n_steps
        old = env.observation_space
        self.observation_space = gym.spaces.Box(
            low=old.low.repeat(n_steps,0),
            high=old.high.repeat(n_steps,0),
            dtype=old.dtype
        )
        self.buffer = deque(maxlen=n_steps)

    def reset(self, **kwargs):
        self.buffer.clear()
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.n_steps):
            self.buffer.append(obs)
        return self._get(), info

    def observation(self, obs):
        self.buffer.append(obs)
        return self._get()

    def _get(self):
        return np.stack(self.buffer, axis=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    def observation(self, obs):
        return np.array(obs, dtype=np.float32) / 255.0

def make_env(env_name):
    env = gym.make(env_name, render_mode='rgb_array')
    env = MaxAndSkipEnv(env)
    env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = BufferWrapper(env, 4)
    env = ScaledFloatFrame(env)
    return env

## DQN

In [None]:
class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, 8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
            nn.Flatten()
        )
        with torch.no_grad():
            n_flatten = self.conv(torch.zeros(1, *input_shape)).shape[1]

        self.fc = nn.Sequential(
            nn.Linear(n_flatten, 512), nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        return self.fc(self.conv(x))

In [None]:
class ExperienceBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append(Experience(state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch_size = min(batch_size, len(self.buffer))
        batch = random.sample(self.buffer, batch_size)
        s, a, r, ns, d = zip(*batch)
        return (
            np.array(s, dtype=np.float32),
            np.array(a, dtype=np.int64),
            np.array(r, dtype=np.float32),
            np.array(ns, dtype=np.float32),
            np.array(d, dtype=np.bool_)
        )


class Agent:
    def __init__(self, env, buffer):
        self.env = env
        self.buffer = buffer
        self.reset()

    def reset(self):
        self.state, _ = self.env.reset()
        self.total_reward = 0

    @torch.no_grad()
    def play_step(self, net, epsilon=0.0, device=device):
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            s = torch.tensor([self.state], dtype=torch.float32).to(device)
            q = net(s)
            action = int(torch.argmax(q, dim=1).item())

        new_state, reward, term, trunc, _ = self.env.step(action)
        done = term or trunc
        self.total_reward += reward

        self.buffer.push(self.state, action, reward, new_state, done)
        self.state = new_state

        if done:
            r = self.total_reward
            self.reset()
            return r
        return None



In [None]:
def calc_loss(batch, net, tgt_net, device, gamma=0.99):
    states, actions, rewards, next_states, dones = batch

    s = torch.tensor(states).to(device)
    a = torch.tensor(actions).to(device)
    r = torch.tensor(rewards).to(device)
    ns = torch.tensor(next_states).to(device)
    d = torch.tensor(dones, dtype=torch.bool).to(device)

    q_vals = net(s).gather(1, a.unsqueeze(-1)).squeeze(-1)

    with torch.no_grad():
        next_q = tgt_net(ns).max(1)[0]
        next_q[d] = 0.0
        expected = r + gamma * next_q

    return nn.MSELoss()(q_vals, expected)

In [None]:
def plot_rewards(rewards, ma_window=100):
    plt.figure(figsize=(12, 8))
    plt.title("Training Rewards")
    plt.plot(rewards, alpha=0.3)
    
    if len(rewards) >= ma_window:
        ma = [np.mean(rewards[i:i+ma_window]) 
              for i in range(len(rewards)-ma_window+1)]
        plt.plot(range(ma_window-1, len(rewards)), ma, linewidth=2)

    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.grid(True)
    plt.show()

## Training

In [None]:
def train_dqn():
    os.makedirs("models", exist_ok=True)

    ENV_NAME = "PongNoFrameskip-v4"
    MEAN_REWARD_BOUND = 18.0

    GAMMA = 0.99
    BATCH_SIZE = 32
    REPLAY_SIZE = 10000
    LR = 1e-4
    SYNC_FRAMES = 1000
    REPLAY_START = 10000

    EPS_LAST = 100000
    EPS_START = 1.0
    EPS_END = 0.02

    env = make_env(ENV_NAME)
    net = DQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt = DQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt.load_state_dict(net.state_dict())

    buffer = ExperienceBuffer(REPLAY_SIZE)
    agent = Agent(env, buffer)

    opt = optim.Adam(net.parameters(), lr=LR)
    rewards = []
    frame_idx = 0
    best_mean = None

    while len(buffer) < REPLAY_START:
        agent.play_step(net, epsilon=1.0)
        frame_idx += 1

    for episode in range(10000):
        frame_idx += 1
        epsilon = max(EPS_END, EPS_START - frame_idx / EPS_LAST)

        reward = agent.play_step(net, epsilon)
        if reward is not None:
            rewards.append(reward)
            mean = np.mean(rewards[-100:])
            print(f"Frame {frame_idx}, Episode {len(rewards)}, Reward {reward:.1f}, Mean {mean:.2f}, Îµ={epsilon:.3f}")

            if best_mean is None or mean > best_mean:
                best_mean = mean
                torch.save(net.state_dict(), f"models/{ENV_NAME}_best.pth")

            if mean >= MEAN_REWARD_BOUND:
                print("Solved!")
                break

        if frame_idx % SYNC_FRAMES == 0:
            tgt.load_state_dict(net.state_dict())

        opt.zero_grad()
        batch = buffer.sample(BATCH_SIZE)
        loss = calc_loss(batch, net, tgt, device, GAMMA)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        opt.step()

    return rewards, net

In [None]:
rewards, trained_net = train_dqn()

plot_rewards(rewards)

if len(rewards) >= 100:
    best100 = max(np.mean(rewards[i:i+100]) for i in range(len(rewards)-99))
    print(f"Best 100-episode mean reward: {best100:.3f}")
else:
    print(f"Mean reward: {np.mean(rewards):.3f}")

# REINFORCE