!["Logo"](../assets/logo.png)

Created by **Domonkos Nagy**

[<img src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/Fortuz/rl_education/blob/main/10.%20Off-policy%20Control/atari_breakout.ipynb)

# Atari Breakout

Breakout is a famous Atari game. The dynamics are similar to Pong: you move a paddle and hit the ball into a brick wall at the top of the screen. Your goal is to destroy the brick wall. You can try to break through the wall and let the ball wreak havoc on the other side, all on its own! You have five lives.

<img src="assets/breakout.gif" width="300"/>

This environment runs the actual Atari game in an emulator. The observation space is `Box(0, 255, (210, 160, 3), np.uint8)`, meaning we receive a 210x160 RGB image at every time step. There are 4 possible actions: NOOP, FIRE, RIGHT and LEFT, where FIRE is only used to start a new round after the ball has fallen past the paddle. You score points by destroying bricks in the wall, and the reward depends on the color of the brick. For a more detailed description of the game, see the AtariAge page linked below.
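
As a rough illustration of the numbers above, the sketch below spells out the four actions as an enum and computes the size of one raw observation. The names `Action` and `RAW_SHAPE` are introduced here for illustration only; the index order is the one listed in the Gymnasium documentation for Breakout linked below.

```python
from enum import IntEnum
from math import prod


class Action(IntEnum):
    """Hypothetical helper enum mirroring Breakout's discrete action indices."""
    NOOP = 0
    FIRE = 1   # only needed to launch the ball at the start of a round
    RIGHT = 2
    LEFT = 3


# Shape of one raw observation: height x width x RGB channels
RAW_SHAPE = (210, 160, 3)

# Each value is a uint8, so one value takes one byte
bytes_per_frame = prod(RAW_SHAPE)
print(f"{len(Action)} actions, {bytes_per_frame:,} bytes per raw frame")
```

Frames of this size are expensive to store and process, which is one reason the preprocessing later in the notebook shrinks them before they reach the network.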

This notebook uses *Deep Q-Learning (DQN)* to train an agent to play Breakout. Our implementation closely follows the original DQN paper (*Human-level control through deep reinforcement learning*, Mnih et al.),
and thus **it is highly recommended to read the paper before getting started with the notebook!** It is also recommended to run this notebook on Colab for much faster training times.

- Read the DQN paper here: https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf
- Documentation for the Atari Breakout environment: https://gymnasium.farama.org/environments/atari/breakout/
- Description of the game: https://atariage.com/manual_html_page.php?SoftwareID=889

In [1]:
# Install dependencies if running in Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    !pip install gymnasium[atari,accept-rom-license]==0.29.0
    !pip install gymnasium==0.29.0

In [2]:
from torch import nn
import torch
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStack, RecordVideo
from collections import deque
import numpy as np
from tqdm.notebook import trange
import torch.nn.functional as F
import random
import ipywidgets as widgets

## Hyperparameters

These parameters are set up by default so that you can train a decent agent in about 3 hours using the free GPUs provided by Colab. This code block more or less mirrors page 10 of the paper.

In [4]:
FRAME_SKIP = 4  # Repeat each action selected for this many frames
N_STEPS = 10_000_000 // FRAME_SKIP  # Number of training time steps
BATCH_SIZE = 32  # SGD batch size
BUFFER_SIZE = 1_000_000 // FRAME_SKIP  # Size of the replay memory
HISTORY_LENGTH = 4  # Number of frames given to the Q-network
GAMMA = 0.99  # Discount factor
OPTIM_FREQ = 4  # Optimize after this many action selections
ALPHA = 5e-5  # Learning rate
EPSILON = 1.0  # Initial exploration
EPSILON_MIN = 0.01  # Final exploration
EPSILON_DECAY = (EPSILON - EPSILON_MIN) / (1_000_000 // FRAME_SKIP)  # Exploration decay rate
MIN_REPLAY_SIZE = 50_000 // FRAME_SKIP  # Minimum size of the replay memory
NOOP_MAX = 30  # Maximum number of NOOP actions taken by the agent at the start of each episode
TARGET_UPDATE_FREQ = 250 * OPTIM_FREQ  # Update the target network with the online network's weights this frequently
LOG_FREQ = N_STEPS // 100  # Progress log frequency
N_RECORDINGS = 5  # Number of episodes to record
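
To make the exploration schedule concrete, here is a minimal sketch that evaluates the linear epsilon decay in closed form. It assumes epsilon is lowered by `EPSILON_DECAY` once per agent step and clamped at `EPSILON_MIN`, as the hyperparameters above suggest; `epsilon_at`, `EPSILON_START` and `DECAY_STEPS` are names introduced here for illustration.

```python
FRAME_SKIP = 4
EPSILON_START = 1.0   # same value as EPSILON above
EPSILON_MIN = 0.01
DECAY_STEPS = 1_000_000 // FRAME_SKIP  # 250,000 agent steps
EPSILON_DECAY = (EPSILON_START - EPSILON_MIN) / DECAY_STEPS


def epsilon_at(step):
    """Exploration rate in effect at a given agent step (hypothetical helper)."""
    return max(EPSILON_START - step * EPSILON_DECAY, EPSILON_MIN)


for step in (0, 125_000, 250_000, 1_000_000):
    print(f"step {step:>9,}: epsilon = {epsilon_at(step):.3f}")
```

With these settings, exploration falls from 1.0 to 0.01 over the first 250,000 agent steps (one million emulator frames) and stays at 0.01 for the rest of training.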

## Nature CNN

This class defines the network architecture described in the paper, commonly called the "Nature CNN".
The `act` method runs a forward pass on the state it receives as an argument, and returns the argmax
of the resulting action values.

In [5]:
class Net(nn.Module):
    def __init__(self, env, device):
        super(Net, self).__init__()

        self.device = device
        # Input and output shape
        in_channels = env.observation_space.shape[0]
        out_channels = int(env.action_space.n)
        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        # Linear layers
        self.fc1 = nn.LazyLinear(512)  # Input size is inferred on the first forward pass
        self.fc2 = nn.Linear(512, out_channels)

    def forward(self, x):
        # Convolution 1
        x = self.conv1(x)
        x = F.relu(x)
        # Convolution 2
        x = self.conv2(x)
        x = F.relu(x)
        # Convolution 3
        x = self.conv3(x)
        x = F.relu(x)
        # Flattening, linear layers
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)

        return x

    # Return the action with the highest Q-value for the input state
    def act(self, state):
        with torch.no_grad():
            state_t = torch.as_tensor(np.array(state), dtype=torch.float32, device=self.device)
            q_values = self(state_t.unsqueeze(0))
        return torch.argmax(q_values[0]).item()
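
The size of the flattened feature vector entering the first linear layer can be checked by hand. The sketch below applies the standard output-size formula for an unpadded convolution to the three layers above, assuming 84x84 input frames (the default screen size of `AtariPreprocessing`); `conv_out` is a helper name introduced here.

```python
def conv_out(size, kernel, stride):
    """Output width/height of a convolution without padding or dilation."""
    return (size - kernel) // stride + 1


size = 84  # assumed input resolution after preprocessing
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_out(size, kernel, stride)
    print(f"kernel={kernel}, stride={stride} -> {size}x{size}")

# The last convolution has 64 output channels
flat_features = 64 * size * size
print(f"flattened features: {flat_features}")
```

Under these assumptions the feature maps shrink from 84 to 20, 9 and finally 7 pixels per side, giving 64 * 7 * 7 = 3136 inputs to the first linear layer.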

In [5]:
class ReplayMemory():
    def __init__(self, env, min_size, max_size, device):
        self.device = device
        self.transitions = [None] * max_size

        self.idx = 0
        self.min_size = min_size
        self.max_size = max_size
        self.full = False
        self._initialize(env)

    def _initialize(self, env):
        state, _ = env.reset()

        for _ in trange(self.min_size):
            action = env.action_space.sample()

            new_state, reward, terminated, truncated, _ = env.step(action)

            transition = (state, action, reward, new_state, terminated)
            self.append(transition)
            state = new_state

            if terminated or truncated:
                state, _ = env.reset()

    def append(self, transition):
        self.transitions[self.idx] = transition

        self.idx = (self.idx + 1) % self.max_size
        self.full = self.full or self.idx == 0

    def sample(self, batch_size):
        size = self.max_size if self.full else self.idx
        indices = np.random.randint(0, high=size, size=(batch_size,))

        states = np.array([self.transitions[i][0] for i in indices])
        actions = np.array([[self.transitions[i][1]] for i in indices])
        rewards = np.array([[self.transitions[i][2]] for i in indices])
        new_states = np.array([self.transitions[i][3] for i in indices])
        terminateds = np.array([[self.transitions[i][4]] for i in indices])

        states_t = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions_t = torch.as_tensor(actions, dtype=torch.int64, device=self.device)
        rewards_t = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        new_states_t = torch.as_tensor(new_states, dtype=torch.float32, device=self.device)
        terminateds_t = torch.as_tensor(terminateds, dtype=torch.float32, device=self.device)


        return states_t, actions_t, rewards_t, new_states_t, terminateds_t
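
The replay memory above is a fixed-size circular buffer: once it is full, each new transition overwrites the oldest one. The toy class below isolates that indexing logic with plain integers instead of transitions. `RingBuffer` is a name introduced for this sketch and is not part of the notebook.

```python
class RingBuffer:
    """Toy circular buffer using the same index arithmetic as ReplayMemory."""

    def __init__(self, max_size):
        self.items = [None] * max_size
        self.max_size = max_size
        self.idx = 0        # next slot to write
        self.full = False   # becomes True after the first wrap-around

    def append(self, item):
        self.items[self.idx] = item
        self.idx = (self.idx + 1) % self.max_size
        self.full = self.full or self.idx == 0

    def __len__(self):
        # Only slots that have been written are valid to sample from
        return self.max_size if self.full else self.idx


buf = RingBuffer(3)
for t in range(5):
    buf.append(t)

print(buf.items, len(buf))  # [3, 4, 2] 3
```

Items 0 and 1 have been overwritten by 3 and 4, while the number of valid slots stays capped at the buffer size. This is why `sample` draws indices from `max_size` once the buffer is full and from `idx` before that.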

In [6]:
rec_steps = np.linspace(-1, N_STEPS-10_000, num=N_RECORDINGS, dtype=int)
rec_episodes = [0]
trigger = lambda ep: ep in rec_episodes

In [7]:
def initialize_env(recording=False):
    env = gym.make('ALE/Breakout-v5', render_mode='rgb_array',
               frameskip=1, repeat_action_probability=0)
    env.metadata['render_fps'] = 60
    if recording:
        env = RecordVideo(env, video_folder="./videos", episode_trigger=trigger, disable_logger=True)
    env = AtariPreprocessing(env, noop_max=NOOP_MAX,
                                  frame_skip=FRAME_SKIP, scale_obs=True)
    env = FrameStack(env, HISTORY_LENGTH)

    return env
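
The `FrameStack` wrapper gives the network a short history instead of a single frame, so that motion (for example the direction of the ball) can be inferred. The sketch below mimics this with string labels in place of images. `ToyFrameStack` is a hypothetical stand-in, and it assumes the stack is filled with copies of the first frame on reset.

```python
from collections import deque

HISTORY_LENGTH = 4


class ToyFrameStack:
    """Keeps the most recent `num_stack` frames, oldest first."""

    def __init__(self, num_stack):
        self.frames = deque(maxlen=num_stack)

    def reset(self, first_frame):
        # Assumption: on reset the stack is filled with the first frame
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return tuple(self.frames)

    def step(self, new_frame):
        # Appending to a full deque silently drops the oldest frame
        self.frames.append(new_frame)
        return tuple(self.frames)


stack = ToyFrameStack(HISTORY_LENGTH)
print(stack.reset("f0"))  # ('f0', 'f0', 'f0', 'f0')
print(stack.step("f1"))   # ('f0', 'f0', 'f0', 'f1')
print(stack.step("f2"))   # ('f0', 'f0', 'f1', 'f2')
```

In the real environment each label would be one preprocessed grayscale frame, so a stacked observation has `HISTORY_LENGTH` as its first dimension. That is the value `Net` reads as `in_channels`.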

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
env = initialize_env()
replay_memory = ReplayMemory(env, MIN_REPLAY_SIZE, BUFFER_SIZE, device)
env = initialize_env(recording=True)


In [9]:
online_net = Net(env, device).to(device)
target_net = Net(env, device).to(device)
target_net.load_state_dict(online_net.state_dict())
print(device)

cuda


In [10]:
optimizer = torch.optim.Adam(online_net.parameters(), lr=ALPHA)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.8)

In [11]:
env = initialize_env(True)
state, _ = env.reset()
action = 0
rew_buffer = deque([0.], maxlen=LOG_FREQ * 3 // 800)
episode_reward = 0.
loss_buffer = deque([0.], maxlen=LOG_FREQ * 3 // OPTIM_FREQ)
n_eps = 0
online_net.train()

for step in trange(N_STEPS):
    if np.random.rand() < EPSILON:
        action = env.action_space.sample()
    else:
        action = online_net.act(state)

    new_state, reward, terminated, truncated, _ = env.step(action)

    transition = (state, action, reward, new_state, terminated)
    replay_memory.append(transition)

    episode_reward += reward
    state = new_state

    if step in rec_steps:
        rec_episodes.append(n_eps + 1)

    if terminated or truncated:
        rew_buffer.append(episode_reward)
        episode_reward = 0.
        n_eps += 1
        state, _ = env.reset()

    EPSILON = max(EPSILON - EPSILON_DECAY, EPSILON_MIN)

    if step % OPTIM_FREQ == 0:
        # Optimization step
        states, actions, rewards, new_states, terminateds = replay_memory.sample(BATCH_SIZE)

        # Compute targets
        with torch.no_grad():
            target_q_values = target_net(new_states)
            max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0]
            targets = rewards + GAMMA * (1 - terminateds) * max_target_q_values

        # Compute loss
        q_values = online_net(states)
        action_q_values = torch.gather(input=q_values, dim=1, index=actions)

        #print(action_q_values)
        #print(targets)

        loss = F.smooth_l1_loss(action_q_values, targets)
        loss_buffer.append(loss.item())

        # Gradient descent step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Update target network
    if step % TARGET_UPDATE_FREQ == 0:
        target_net.load_state_dict(online_net.state_dict())
        #scheduler.step()

    # Logging
    if step % LOG_FREQ == 0:
        print(f'Step {step:,}/{N_STEPS:,}:\n\tAvg. Rew: {np.mean(rew_buffer)}\n\tAvg. Loss: {np.mean(loss_buffer)}\n\tEpisodes: {n_eps}')
        model_scripted = torch.jit.script(online_net)  # Export to TorchScript
        model_scripted.save('model_scripted.pt')  # Save

Step 0/2,500,000:
	Avg. Rew: 0.0
	Avg. Loss: 0.0007584651466459036
	Episodes: 0
Step 25,000/2,500,000:
	Avg. Rew: 1.3870967741935485
	Avg. Loss: 0.003761217329411684
	Episodes: 129
Step 50,000/2,500,000:
	Avg. Rew: 1.118279569892473
	Avg. Loss: 0.0038503264800945174
	Episodes: 266
Step 75,000/2,500,000:
	Avg. Rew: 1.4623655913978495
	Avg. Loss: 0.004079032092438777
	Episodes: 396
Step 100,000/2,500,000:
	Avg. Rew: 2.247311827956989
	Avg. Loss: 0.003959366893396558
	Episodes: 511
Step 125,000/2,500,000:
	Avg. Rew: 2.946236559139785
	Avg. Loss: 0.0038263056703951833
	Episodes: 614
Step 150,000/2,500,000:
	Avg. Rew: 3.956989247311828
	Avg. Loss: 0.003629148761924977
	Episodes: 700
Step 175,000/2,500,000:
	Avg. Rew: 5.419354838709677
	Avg. Loss: 0.00390175898233739
	Episodes: 772
Step 200,000/2,500,000:
	Avg. Rew: 7.204301075268817
	Avg. Loss: 0.004383385069508416
	Episodes: 834
Step 225,000/2,500,000:
	Avg. Rew: 9.505376344086022
	Avg. Loss: 0.004748695286490644
	Episodes: 884
Step 250,00
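
The core of each optimization step in the training loop is the one-step TD target and the Huber (smooth L1) loss. The sketch below reproduces both for a single transition in plain Python, assuming the default `beta=1.0` of `F.smooth_l1_loss`. The helper names `td_target` and `huber` and the example Q-values are made up for illustration.

```python
GAMMA = 0.99


def td_target(reward, terminated, next_q_values, gamma=GAMMA):
    """r + gamma * max_a Q_target(s', a), with no bootstrap on terminal states."""
    return reward + gamma * (1.0 - float(terminated)) * max(next_q_values)


def huber(prediction, target, beta=1.0):
    """Smooth L1 loss: quadratic for small errors, linear for large ones."""
    diff = abs(prediction - target)
    if diff < beta:
        return 0.5 * diff * diff / beta
    return diff - 0.5 * beta


# Made-up target-network outputs for the four actions in the next state
next_q = [0.2, 0.5, 0.1, 0.4]

print(td_target(1.0, False, next_q))  # bootstraps from the best next action
print(td_target(1.0, True, next_q))   # terminal: target is just the reward
print(huber(0.0, 0.5), huber(0.0, 3.0))
```

For a non-terminal transition the target is about 1 + 0.99 * 0.5 = 1.495, while for a terminal one the `(1 - terminated)` factor removes the bootstrap term and the target is the reward alone. The linear tail of the Huber loss keeps large TD errors from producing very large gradients.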

In [12]:
children = [widgets.Video.from_file(f'./videos/rl-video-episode-{episode}.mp4', autoplay=False, loop=False, width=500) for episode in rec_episodes]
tab = widgets.Tab()
tab.children = children
titles = tuple([f'Episode {episode+1}' for episode in rec_episodes])
for i in range(len(children)):
    tab.set_title(i, titles[i])
display(tab)

