In [6]:
# Install dependencies for the virtual display (if needed)
!apt-get install -y xvfb
!pip install pyvirtualdisplay

# Set up the virtual display.
# An off-screen (visible=0) 1400x900 display is started here, before any
# environment is created further down in the notebook.
# NOTE: this binds the name `display`; IPython's display function is imported
# below under the alias `ipy_display` so the two names do not collide.
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import ale_py
import gymnasium as gym
from collections import deque
import random
import math
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display as ipy_display

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common
The following NEW packages will be installed:
  libfontenc1 libxfont2 libxkbfile1 x11-xkb-utils xfonts-base xfonts-encodings xfonts-utils
  xserver-common xvfb
0 upgraded, 9 newly installed, 0 to remove and 20 not upgraded.
Need to get 7,815 kB of archives.
After this operation, 11.9 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libfontenc1 amd64 1:1.1.4-1build3 [14.7 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxfont2 amd64 1:2.0.5-1build1 [94.5 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy/main amd64 libxkbfile1 amd64 1:1.1.0-1build3 [71.8 kB]
Get:4 http://archive.ubuntu.com/ubuntu jammy/main amd64 x11-xkb-utils amd64 7.7+5build4 [172 kB]
Get:5 http://archiv

In [7]:
class DQNConv(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQNConv, self).__init__()

        print(f"Input shape: {input_shape}")  # Debug print

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        conv_out_size = self._get_conv_out(input_shape)
        print(f"Conv out size: {conv_out_size}\n")  # Debug print

        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Each stored transition is the tuple ``(state, action, reward, next_state, done)``.
    ``state`` and ``next_state`` must be tensors that already carry a leading batch
    dimension, because ``sample`` joins them with ``torch.cat``.

    Args:
        capacity: Maximum number of transitions kept; once full, the oldest
            transition is dropped on each ``push`` (``deque(maxlen=capacity)``).
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` where states /
            next_states are concatenated along dim 0, actions are ``int64``,
            rewards are ``float32`` and dones are ``bool``.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        # Explicit dtypes: relying on inference made the reward tensor int64 whenever
        # an environment returned integer rewards, and left the result dependent on
        # the Python/NumPy scalar types that happened to be stored.
        return (torch.cat(states),
                torch.tensor(actions, dtype=torch.int64),
                torch.tensor(rewards, dtype=torch.float32),
                torch.cat(next_states),
                torch.tensor(dones, dtype=torch.bool))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

class DQNAgent:
    """DQN agent with epsilon-greedy exploration, uniform replay and a target network.

    Args:
        env: Environment exposing ``reset()``, ``step(action)``, ``render()``,
            ``action_space.n`` and ``observation_space.shape``; ``step`` must
            return the 5-tuple ``(obs, reward, terminated, truncated, info)``.
        memory_size: Capacity of the replay buffer, in transitions.
        batch_size: Number of transitions per optimisation step.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon_start: Exploration rate at ``frame_idx == 0``.
        epsilon_final: Exploration rate approached as ``frame_idx`` grows.
        epsilon_decay: Time constant (in frames) of the exponential epsilon decay.
        target_update: The target network is synced whenever ``frame_idx`` is a
            multiple of this value.
    """

    def __init__(self, env, memory_size=10000, batch_size=32, gamma=0.99,
                 epsilon_start=1.0, epsilon_final=0.01, epsilon_decay=10000,
                 target_update=1000):

        self.env = env
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.n_actions = env.action_space.n
        # NOTE(review): the observation shape is forwarded as-is and DQNConv uses its
        # first entry as the channel count. The recorded run shows (210, 160, 3), which
        # looks channel-last -- confirm whether a transpose to channel-first is intended.
        self.state_shape = env.observation_space.shape

        self.memory = ReplayBuffer(memory_size)
        self.batch_size = batch_size

        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_final = epsilon_final
        self.epsilon_decay = epsilon_decay
        self.target_update = target_update
        self.frame_idx = 0  # number of environment steps taken during training

        self.policy_net = DQNConv(self.state_shape, self.n_actions).to(self.device)
        self.target_net = DQNConv(self.state_shape, self.n_actions).to(self.device)

        # The target network starts as an exact copy and is only ever updated by copying.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=3e-4)

    def get_epsilon(self):
        """Return the current exploration rate.

        Decays exponentially from ``epsilon_start`` towards ``epsilon_final`` as
        ``frame_idx`` increases, with time constant ``epsilon_decay``.
        """
        return self.epsilon_final + (self.epsilon_start - self.epsilon_final) * \
               math.exp(-1. * self.frame_idx / self.epsilon_decay)

    def _greedy_action(self, state):
        """Return the index of the highest policy-network Q-value for a batch-of-one state."""
        with torch.no_grad():
            q_values = self.policy_net(state.to(self.device))
        return q_values.max(1)[1].item()

    def select_action(self, state, evaluate=False):
        """Choose an action for ``state``.

        Args:
            state: Preprocessed observation with a leading batch dimension of 1.
            evaluate: If True, always act greedily; otherwise act epsilon-greedily
                using ``get_epsilon()``.

        Returns:
            The chosen action as a Python ``int``.
        """
        # `evaluate` short-circuits, so no random number is consumed in evaluation mode.
        if evaluate or random.random() > self.get_epsilon():
            return self._greedy_action(state)
        return random.randrange(self.n_actions)

    def train_step(self):
        """Run one optimisation step on a replay batch.

        Returns:
            The scalar loss as a ``float``, or ``None`` when the replay buffer
            holds fewer than ``batch_size`` transitions (no update is done).
        """
        if len(self.memory) < self.batch_size:
            return None

        states, actions, rewards, next_states, dones = (
            tensor.to(self.device) for tensor in self.memory.sample(self.batch_size)
        )

        # Q(s, a) of the policy network for the actions that were actually taken.
        current_q_values = self.policy_net(states).gather(1, actions.long().unsqueeze(1))

        # Basic DQN target: reward + gamma * max_a Q_target(next_state, a).
        # The target is a constant w.r.t. the optimisation, so it is computed with
        # autograd disabled instead of building a graph and detaching it afterwards.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0].unsqueeze(1)
            not_done = 1 - dones.float().unsqueeze(1)  # no bootstrapping past terminal states
            expected_q_values = rewards.float().unsqueeze(1) + not_done * self.gamma * next_q_values

        loss = F.smooth_l1_loss(current_q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def update_target_network(self):
        """Copy policy weights into the target network on every ``target_update``-th frame.

        The copy happens whenever ``frame_idx % target_update == 0``, which includes
        ``frame_idx == 0``.
        """
        if self.frame_idx % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def preprocess_state(self, state):
        """Convert a raw observation to a float32 tensor with a leading batch dimension of 1.

        No scaling or axis reordering is applied.
        """
        return torch.FloatTensor(state).unsqueeze(0)

    def train_episode(self, render=False):
        """Play one episode with epsilon-greedy actions, learning after every step.

        Args:
            render: If True, call ``env.render()`` before each step (the return
                value is discarded).

        Returns:
            The sum of rewards collected during the episode.
        """
        state, _ = self.env.reset()
        state = self.preprocess_state(state)
        done = False
        truncated = False
        total_reward = 0

        while not (done or truncated):
            if render:
                self.env.render()

            action = self.select_action(state)
            next_state, reward, done, truncated, _ = self.env.step(action)
            next_state = self.preprocess_state(next_state)

            # Only `done` (termination) is stored; a truncated episode still bootstraps.
            self.memory.push(state, action, reward, next_state, done)

            self.train_step()
            self.update_target_network()

            state = next_state
            total_reward += reward
            self.frame_idx += 1

        return total_reward

    def play_episode(self, render=True):
        """Play one greedy episode without learning or touching the replay buffer.

        Args:
            render: If True, call ``env.render()`` before each step.

        Returns:
            The sum of rewards collected during the episode.
        """
        state, _ = self.env.reset()
        state = self.preprocess_state(state)
        done = False
        truncated = False
        total_reward = 0

        while not (done or truncated):
            if render:
                self.env.render()

            action = self.select_action(state, evaluate=True)
            next_state, reward, done, truncated, _ = self.env.step(action)
            state = self.preprocess_state(next_state)
            total_reward += reward

        return total_reward

# --- Visualization using inline captured frames ---

def play_episode_capture(agent):
    """Run one greedy episode on ``agent.env`` and record a rendered frame per step.

    A frame is taken from ``agent.env.render()`` before every action, so the list
    holds one frame per step taken and the state after the final step is not
    captured. The environment is expected to render to arrays
    (``render_mode="rgb_array"``).

    Returns:
        ``(frames, total_reward)``: the captured frames in order and the sum of
        rewards over the episode.
    """
    observation, _ = agent.env.reset()
    state = agent.preprocess_state(observation)

    captured = []
    episode_return = 0
    finished = False

    while not finished:
        # Capture frame (env was created with render_mode="rgb_array")
        captured.append(agent.env.render())

        chosen = agent.select_action(state, evaluate=True)
        observation, reward, terminated, truncated, _ = agent.env.step(chosen)
        state = agent.preprocess_state(observation)
        episode_return += reward
        finished = terminated or truncated

    return captured, episode_return

def visualize_frames(frames, interval=50):
    """Turn a sequence of image frames into an inline JavaScript animation.

    Args:
        frames: Non-empty sequence of images accepted by ``imshow``; the first
            frame initialises the image artist.
        interval: Delay passed to ``FuncAnimation`` (milliseconds between frames).

    Returns:
        An ``IPython.display.HTML`` object wrapping ``anim.to_jshtml()``.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.imshow(frames[0])
    ax.axis('off')

    def draw_frame(index):
        # Swap the pixel data in place instead of creating a new artist per frame.
        image.set_data(frames[index])
        return image,

    anim = animation.FuncAnimation(fig, draw_frame, frames=len(frames), interval=interval)
    # Close the figure so the notebook does not also show it as a static image.
    plt.close(fig)
    return HTML(anim.to_jshtml())

def visualize_agent(agent, env_name, episodes=1):
    """Play greedy episodes in a fresh ``rgb_array`` environment and display them inline.

    The agent is temporarily pointed at a newly created rendering environment.
    When the function returns (or raises), the agent's original environment is
    put back and the rendering environment is closed, so the agent is never left
    holding a closed environment.

    Args:
        agent: Agent exposing an ``env`` attribute plus the methods used by
            ``play_episode_capture``.
        env_name: Gymnasium environment id to create for rendering.
        episodes: Number of episodes to play and display.
    """
    original_env = agent.env
    render_env = gym.make(env_name, render_mode="rgb_array")
    agent.env = render_env

    try:
        for episode in range(episodes):
            frames, total_reward = play_episode_capture(agent)
            print(f"Visualization Episode {episode}, Reward: {total_reward:.2f}")
            ipy_display(visualize_frames(frames))
    finally:
        # Restore first so the agent stays usable even if close() itself fails.
        agent.env = original_env
        render_env.close()

def create_atari_env(env_name, render_mode=None):
    """Create the Gymnasium environment ``env_name`` with the given ``render_mode``."""
    return gym.make(env_name, render_mode=render_mode)

def train(env_name, num_episodes=1000, render=False):
    """Train a new ``DQNAgent`` on ``env_name`` and report progress after every episode.

    Args:
        env_name: Gymnasium environment id; the environment is created with
            ``render_mode=None``.
        num_episodes: Number of training episodes to run.
        render: Forwarded to ``DQNAgent.train_episode``.

    Returns:
        ``(agent, rewards)``: the trained agent and the list of per-episode
        total rewards, in episode order.
    """
    agent = DQNAgent(create_atari_env(env_name, render_mode=None))
    episode_returns = []

    for episode_idx in range(num_episodes):
        episode_return = agent.train_episode(render=render)
        episode_returns.append(episode_return)
        # Running average over at most the ten most recent episodes.
        recent_mean = np.mean(episode_returns[-10:])
        print(f"Episode {episode_idx}, Reward: {episode_return:.2f}, Avg Reward (last 10): {recent_mean:.2f}")

    return agent, episode_returns

In [8]:
# Gymnasium environment id shared by the training and visualization calls below.
env_name = "ALE/Pong-v5"

# Training
# Single-episode run; `rewards` receives one total reward per episode played.
agent, rewards = train(env_name, num_episodes=1, render=False)

# Visualizing
# Plays one greedy episode in a separate rgb_array environment and shows it
# inline as an animation.
visualize_agent(agent, env_name, episodes=1)

Input shape: (210, 160, 3)
Conv out size: 5120

Input shape: (210, 160, 3)
Conv out size: 5120

Episode 0, Reward: -20.00, Avg Reward (last 10): -20.00
Visualization Episode 0, Reward: -21.00
