# Space Invaders with Deep Q-Networks

In this notebook I will try to create a DQN that surpasses human ability in playing Space Invaders.
Lets see how it goes...

In [1]:
#!pip install ale-py
#!pip install torch

In [2]:
import torch
print("CUDA" if torch.cuda.is_available() else "cpu")

CUDA


## SCORING in Space invaders 

The SPACE INVADERS are worth 5, 10, 15, 20, 25, 30 points in
the first through sixth rows respectively. (See diagram.) The
point value of each target stays the same as it drops lower on
the screen. Each complete set of SPACE INVADERS is worth 630
points.


taken from https://atariage.com/manual_html_page.php?SoftwareLabelID=460

In [3]:
import logging
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
import ale_py
import random
import math
import numpy as np
from collections import defaultdict

gym.register_envs(ale_py)

# Training configuration
training_period = 250           # Record video every 250 episodes
num_training_episodes = 5_000  # Total training episodes
env_name = "ALE/SpaceInvaders-v5" # has a default obs tpye of rgb, 4 frames are skipped and the repeat action propability is 0.25

# Set up logging for episode statistics
logging.basicConfig(level=logging.INFO, format='%(message)s')

# Create environment with periodic video recording
# possible to activate full action space with full_action_space=True
# using grayscaling to reduce input of Q-Network
env = gym.make(env_name, render_mode="rgb_array", obs_type="grayscale")

# Record videos periodically (every 250 episodes)
env = RecordVideo(
    env,
    video_folder="space_invaders",
    name_prefix="training",
    episode_trigger=lambda x: x % training_period == 0  # Only record every 250th episode
)

# Track statistics for every episode (lightweight)
env = RecordEpisodeStatistics(env)

print(f"Training for {num_training_episodes} episodes")
print(f"Videos will be recorded every {training_period} episodes")
print(f"Videos saved to: space_invaders/")


Training for 5000 episodes
Videos will be recorded every 250 episodes
Videos saved to: space_invaders/


A.L.E: Arcade Learning Environment (version 0.11.2+ecc1138)
[Powered by Stella]
  logger.warn(


In [4]:
from gymnasium.wrappers import FrameStackObservation
FRAME_STACK_SIZE = 4


stacked_env = FrameStackObservation(env, stack_size=FRAME_STACK_SIZE)

In [5]:
import torch.nn as nn


class QNetwork(nn.Module):

    def __init__(
        self,
        in_channels,
        action_space
        ):
        super().__init__()

        self.conv_stack = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=4 ),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=2, stride=1),
            nn.Flatten(),
        )

        dummy = torch.zeros(1, in_channels, 210, 160)
        dummy_out = self.conv_stack(dummy)
        self.num_features = dummy_out.numel()

        self.fc = nn.Sequential(
            nn.Linear(self.num_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, action_space)
        )

        self.model = self.conv_stack.extend(self.fc)


    def forward(self, input):
        return self.model(input)




In [6]:
from collections import namedtuple, deque
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [7]:

class SpaceInvaderAgent:

    def __init__(
        self,
        env: gym.Env,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
        frame_stacking: int = FRAME_STACK_SIZE,
    ):
        self.env = env
        
        # Q-Network to represent the current policy
        self.policy_network = QNetwork(in_channels=frame_stacking, action_space=env.action_space.n)

        
        self.target_network = QNetwork(in_channels=frame_stacking, action_space=env.action_space.n)
        self.target_network.load_state_dict(self.policy_network.state_dict())

        self.optimizer = torch.optim.Adam(self.policy_network.parameters(), lr=learning_rate)

        self.discount_factor = discount_factor  # How much we care about future rewards also known as gamma

        # Exploration parameters
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        # initialize a replay buffer
        self.memory = ReplayMemory(10_000)
        self.batch_size = 120

        # Track learning progress
        self.training_error = []

    # currently using a random action
    def get_action(self, obs)-> int:
        # With probability epsilon: explore (random action)
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()

        # With probability (1-epsilon): exploit (best known action)
        else:
            obs = torch.FloatTensor(obs).unsqueeze(0)
            with torch.no_grad():
                return torch.argmax(self.policy_network(obs)).item()

        
    def update(self):
        if len(self.memory) < 5000:
            return

        batch = self.memory.sample(self.batch_size)
        state_batch, action_batch, reward_batch, next_state_batch = zip(*batch)

        state_batch = torch.FloatTensor(state_batch)
        action_batch = torch.LongTensor(action_batch).unsqueeze(1)
        reward_batch = torch.FloatTensor(reward_batch)
        next_state_batch = torch.FloatTensor(next_state_batch)


        # Compute Q-values for current states
        q_values = self.policy_network(state_batch).gather(1, action_batch).squeeze().to(device)

        # Compute target Q-values using the target network
        with torch.no_grad():
            max_next_q_values = self.target_network(next_state_batch).max(1)[0]
            target_q_values = reward_batch + self.discount_factor * max_next_q_values * (1 - done_batch)

        loss = nn.MSELoss()(q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


    

    def decay_epsilon(self):
        """Reduce exploration rate after each episode."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

    

    


In [8]:
# Training hyperparameters
learning_rate = 0.01        # How fast to learn (higher = faster but less stable)
n_episodes = 100_000        # Number of hands to practice
start_epsilon = 1.0         # Start with 100% random actions
epsilon_decay = start_epsilon / (n_episodes / 2)  # Reduce exploration over time
final_epsilon = 0.1         # Always keep some exploration


agent = SpaceInvaderAgent(
    env=stacked_env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)

In [9]:
for episode in range(n_episodes):
    state, info = env.reset()
    step_counter = 0
    done = False
    
    while not done:
        action = agent.get_action(state)
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        step_counter += 1
        agent.memory.push(state, action, reward, next_state)

        agent.update()

        state = next_state

        if step_counter % 10000 == 0:
            agent.target_network.load_state_dict(agent.policy_network.state_dict())





  state_batch = torch.FloatTensor(state_batch)


RuntimeError: Given groups=1, weight of size [16, 4, 8, 8], expected input[1, 120, 210, 160] to have 4 channels, but got 120 channels instead

In [None]:
import matplotlib.pyplot as plt

def plot_training_results(agent, env):
    # Create a figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # --- Plot 1: Episode Rewards ---
    # RecordEpisodeStatistics stores rewards in env.return_queue
    if hasattr(env, 'return_queue') and len(env.return_queue) > 0:
        rewards = list(env.return_queue)
        ax1.plot(rewards, label='Reward per Episode', alpha=0.3, color='blue')
        
        # Add a rolling average to see the trend through the noise
        if len(rewards) >= 10:
            rolling_avg = np.convolve(rewards, np.ones(10)/10, mode='valid')
            ax1.plot(rolling_avg, label='Rolling Avg (10 ep)', color='darkblue')
            
        ax1.set_title("Episode Rewards")
        ax1.set_xlabel("Episode")
        ax1.set_ylabel("Total Reward")
        ax1.legend()

    # --- Plot 2: Training Loss ---
    if len(agent.training_error) > 0:
        # RL loss is often very "spiky", so we plot it on a log scale or with transparency
        ax2.plot(agent.training_error, label='Loss', alpha=0.3, color='orange')
        
        # Add rolling average for loss
        if len(agent.training_error) >= 50:
            rolling_loss = np.convolve(agent.training_error, np.ones(50)/50, mode='valid')
            ax2.plot(rolling_loss, label='Rolling Avg (50 updates)', color='red')
            
        ax2.set_title("Training Loss")
        ax2.set_xlabel("Update Step")
        ax2.set_ylabel("Loss (MSE)")
        ax2.legend()
    else:
        ax2.text(0.5, 0.5, "No loss data yet.\nDid you remember to append to training_error?", 
                 ha='center', va='center')

    plt.tight_layout()
    plt.show()

# Call this after your training loop or inside the loop every X episodes
plot_training_results(agent, env)