# DQN from Memory in Breakout

I need:
* [x] The actual Q network
* [x] A replay buffer
* [x] Action selection
* [x] Training loop

The process:
1. Set up online Q network and target network
1. Initialize (vectorized) environment
1. Perform training loop (Optional: With logging)
1. Evaluate and record an episode

## The Q network

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
class QNetwork(nn.Module):
    """Convolutional Q-network mapping stacked frames to one Q-value per action.

    Args:
        obs_shape: Batched observation shape from the vectorized env, laid out
            as (batch, width, stack, height) -- e.g. (8, 80, 4, 105) for
            pufferlib Breakout with ``framestack=4``.
        n_actions: Number of discrete actions (size of the output layer).
        n_hiddens: Width of the hidden fully connected layer.
    """

    def __init__(self, obs_shape, n_actions, n_hiddens=128):
        super().__init__()
        # The env yields (batch, width, stack, height); Conv2d wants channels
        # first, so the network works on (batch, stack, height, width).
        in_shape = (obs_shape[0], obs_shape[2], obs_shape[3], obs_shape[1])

        # Three conv + 2x2 max-pool stages; each pool halves the spatial dims.
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=in_shape[1], out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        # Infer the flattened conv output size from one dummy sample; no
        # autograd graph is needed for this probe.
        with torch.no_grad():
            dummy_out = self.convs(torch.zeros(1, *in_shape[1:]))
        n_input = dummy_out[0].numel()

        # Fully connected head: flatten -> hidden layer -> one output per action.
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_input, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions).

        ``x`` may be a numpy array, a sequence of per-env arrays (as produced
        by unzipping replay samples) or a tensor on any device. Its layout is
        (batch, width, stack, height), with pixel values assumed to be in
        [0, 255].
        """
        device = next(self.parameters()).device
        if isinstance(x, torch.Tensor):
            x = x.to(device=device, dtype=torch.float32)
        else:
            # np.asarray stacks a sequence of arrays into one batch array.
            x = torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)
        # Scale to [0, 1] and reorder to (batch, stack, height, width).
        x = (x / 255.0).permute(0, 2, 3, 1)
        return self.mlp(self.convs(x))

### Testing things

In [3]:
import numpy as np

In [4]:
import pufferlib
from pufferlib.environments import atari

In [5]:
# Build a factory for the Breakout Atari environment via pufferlib.
env_name = "breakout"
env_creator = atari.env_creator(env_name)

In [6]:
import pufferlib.vector

In [7]:
# 8 Breakout envs, each stacking 4 frames, using pufferlib's multiprocessing backend.
vecenv = pufferlib.vector.make(
    env_creator,
    env_kwargs={
        "framestack": 4,
    },
    backend=pufferlib.vector.Multiprocessing,
    num_envs=8
)

A.L.E: Arcade Learning Environment (version 0.9.0+750d7f9)
[Powered by Stella]


In [8]:
# Batched observation shape as reported by the env (see output below).
obs_shape = vecenv.obs_batch_shape
obs_shape

(8, 80, 4, 105)

In [9]:
# Number of discrete actions, read from the batched action space's nvec.
n_actions = vecenv.action_space.nvec[0]
n_actions

4

In [10]:
# Check the axis permutation QNetwork.forward applies: (0, 2, 3, 1).
obs, _ = vecenv.reset()
obs.transpose(0,2,3,1).shape

(8, 4, 105, 80)

In [11]:
# Build a throwaway network and print its layers and total parameter count.
test_network = QNetwork(obs_shape, n_actions)
print(test_network, f"\n{sum(p.numel() for p in test_network.parameters()):,}")

QNetwork(
  (convs): Sequential(
    (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (mlp): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=8320, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=4, bias=True)
  )
) 
1,089,332


In [12]:
# Fresh observations for the forward-pass smoke test below.
obs, _ = vecenv.reset()

In [13]:
# Forward pass on raw env observations. The rows in the output are identical,
# presumably because every env starts from the same reset frame.
q_values = test_network(obs)
q_values

tensor([[-0.0334,  0.0302, -0.0121, -0.0764],
        [-0.0334,  0.0302, -0.0121, -0.0764],
        [-0.0334,  0.0302, -0.0121, -0.0764],
        [-0.0334,  0.0302, -0.0121, -0.0764],
        [-0.0334,  0.0302, -0.0121, -0.0764],
        [-0.0334,  0.0302, -0.0121, -0.0764],
        [-0.0334,  0.0302, -0.0121, -0.0764],
        [-0.0334,  0.0302, -0.0121, -0.0764]], grad_fn=<AddmmBackward0>)

## Replay Buffer

* Needs to be able to store some number of transitions
* Needs to automatically kick out old transitions
* Needs `store`
* Needs `extend`
* Needs `__len__()`
* Needs `sample()`

In [14]:
from collections import deque
import random

In [15]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    When ``capacity`` items are already held, adding another one evicts the
    oldest. Sampling is uniform and *with replacement*, so one sample may
    contain the same transition more than once.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        # A bounded deque discards from the left automatically once full.
        self.queue = deque(maxlen=capacity)

    def store(self, transition):
        """Add a single transition, evicting the oldest one if at capacity."""
        self.queue.append(transition)

    def extend(self, transitions):
        """Add every transition from an iterable, in iteration order."""
        self.queue.extend(transitions)

    def sample(self, n_samples: int):
        """Return a list of ``n_samples`` transitions drawn with replacement."""
        pick = random.choice
        pool = self.queue
        return [pick(pool) for _ in range(n_samples)]

    def __len__(self):
        return len(self.queue)

    def __repr__(self):
        return repr(self.queue)

### Testing things

In [16]:
# Tiny capacity so eviction is easy to see.
buf = ReplayBuffer(5)

In [17]:
# Store 10 items in a capacity-5 buffer: only the last 5 survive.
for i in range(10):
    buf.store(i)

buf, buf.sample(2)

(deque([5, 6, 7, 8, 9], maxlen=5), [5, 7])

In [18]:
# extend() with growing ranges; the last call, range(9), leaves items 4..8.
for i in range(10):
    buf.extend(range(i))

buf, buf.sample(2)

(deque([4, 5, 6, 7, 8], maxlen=5), [6, 8])

## Action Selection

In [19]:
def select_actions(q_values, epsilon):
    """Pick an epsilon-greedy action for every environment in the batch.

    Each row of ``q_values`` (shape ``(batch, n_actions)``) gets its own
    exploration coin flip. With probability ``epsilon`` the row gets a
    uniformly random action; otherwise it gets its argmax. A single flip for
    the whole batch would make all vectorized envs explore (or exploit)
    together.

    Args:
        q_values: Tensor of Q-values, one row per environment.
        epsilon: Exploration probability in [0, 1].

    Returns:
        A CPU int64 tensor of shape ``(batch,)``.
    """
    batch_size, n_actions = q_values.shape
    greedy = torch.argmax(q_values, dim=1).cpu()
    random_actions = torch.randint(0, n_actions, (batch_size,))
    # torch.rand is in [0, 1), so epsilon=1 always explores and epsilon=0 never does.
    explore = torch.rand(batch_size) < epsilon
    return torch.where(explore, random_actions, greedy)

### Testing things

In [20]:
# Expected (num_envs, n_actions).
q_values.shape

torch.Size([8, 4])

In [21]:
# Epsilon-greedy actions for the test Q-values above.
select_actions(q_values, 0.5)

tensor([1, 1, 1, 1, 1, 1, 1, 1])

## Training Loop

1. [x] Initialize replay buffer
1. [x] Reset `env`
1. [x] Initialize `returns` array
1. [x] Initialize `episodes` counter
1. [x] For `n_steps`
    1. [x] Take step in environment
    1. [x] (Optional) Add `rewards` to `returns` array
    1. [x] (Optional) Add `np.sum(dones | truncateds)` to episodes
    1. [x] (Optional) Log `returns` for done/truncated episodes in W&B
    1. [x] Store transition in buffer
    1. [x] If enough transitions: Update online network
    1. [x] If enough steps: Update target network

In [22]:
from torch import optim
from tqdm.auto import tqdm

In [23]:
import wandb

In [24]:
def _dqn_update(net, target_net, optimizer, loss_fn, batch, gamma, device):
    """Run one DQN gradient step on ``batch`` and return the loss as a float.

    ``batch`` is a list of ``(obs, action, reward, next_obs, done)`` tuples as
    stored by ``train``. Targets are ``reward + gamma * max_a Q_target(next_obs, a)``,
    with no bootstrapping past transitions flagged ``done``.
    """
    batch_obs, batch_actions, batch_rewards, batch_next_obs, batch_dones = zip(*batch)
    actions_t = torch.as_tensor(np.asarray(batch_actions), dtype=torch.int64, device=device)
    rewards_t = torch.as_tensor(np.asarray(batch_rewards), dtype=torch.float32, device=device)
    not_done_t = ~torch.as_tensor(np.asarray(batch_dones), dtype=torch.bool, device=device)

    # Online-network Q(s, a) for the actions that were actually taken.
    q_pred = net(batch_obs).gather(1, actions_t.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target_net(batch_next_obs).max(dim=1).values
        td_target = rewards_t + not_done_t * gamma * next_q
    loss = loss_fn(q_pred, td_target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def train(
    net: nn.Module,
    target_net: nn.Module,
    optimizer: optim.Optimizer,
    loss_fn,
    env,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay: int,
    gamma: float,
    buffer_size: int,
    update_batch_size: int,
    target_update_steps: int,
    n_steps: int,
):
    """Train ``net`` with DQN on the vectorized environment ``env``.

    Each of the ``n_steps`` iterations steps every environment once and stores
    one transition per env. Once the buffer holds a batch, it does one gradient
    update. Every ``target_update_steps`` iterations it copies ``net``'s weights
    into ``target_net``. Epsilon decays linearly from ``epsilon_start`` to
    ``epsilon_end`` over ``epsilon_decay`` iterations.

    When a W&B run is active, it logs returns of finished episodes, epsilon,
    the episode count and the loss, and finishes the run at the end. The
    environment is always closed on exit, including on KeyboardInterrupt,
    which stops training gracefully.

    Note: terminated and truncated episodes are both treated as terminal for
    bootstrapping. The env resets automatically, so ``next_obs`` at an episode
    boundary belongs to the new episode.
    """
    assert update_batch_size < buffer_size, "Buffer must be large enough to hold at least one batch"
    buffer = ReplayBuffer(capacity=buffer_size)
    epsilon = epsilon_start
    device = next(net.parameters()).device
    use_wandb = wandb.run is not None

    obs, _ = env.reset()
    # Defensive copy: the vectorized env may hand out views into buffers it
    # reuses on the next step, which would silently corrupt stored transitions.
    obs = np.copy(obs)

    num_envs = obs.shape[0]
    returns = np.zeros([num_envs])
    episodes = 0

    try:
        for step in tqdm(range(1, n_steps + 1), desc="Steps"):
            with torch.no_grad():
                q_values = net(obs)
            actions = select_actions(q_values, epsilon).numpy()
            next_obs, rewards, dones, truncateds, _ = env.step(actions)
            next_obs = np.copy(next_obs)
            finished = np.asarray(dones | truncateds, dtype=bool)

            buffer.extend(zip(obs, actions, rewards, next_obs, finished))
            obs = next_obs

            returns += rewards
            episodes += int(finished.sum())
            if use_wandb and finished.any():
                wandb.log({
                    "avg_return": np.mean(returns[finished]),
                    "epsilon": epsilon,
                    "episodes": episodes,
                })
            returns[finished] = 0

            if len(buffer) >= update_batch_size:
                loss = _dqn_update(
                    net, target_net, optimizer, loss_fn,
                    buffer.sample(update_batch_size), gamma, device,
                )
                if use_wandb:
                    wandb.log({
                        "loss": loss,
                    })

            if step % target_update_steps == 0:
                target_net.load_state_dict(net.state_dict())

            epsilon = epsilon_start + (epsilon_end - epsilon_start) * min(1.0, (step / epsilon_decay))

    except KeyboardInterrupt:
        print("Training stopped manually.")
    finally:
        if wandb.run is not None:
            wandb.unwatch()
            wandb.finish()
        env.close()

### Testing things

In [25]:
# Take one random step to inspect the reward array.
_, rewards, *_ = vecenv.step(vecenv.action_space.sample())

In [26]:
# Same accumulation pattern train() uses for per-env returns.
np.zeros(rewards.shape) + rewards

array([0., 0., 0., 0., 0., 0., 0., 0.])

In [27]:
# Store (obs, reward, done) triples from one random step in the test buffer.
obs, rewards, dones, truncateds, _ = vecenv.step(vecenv.action_space.sample())
buf.extend(zip(obs, rewards, (dones | truncateds)))

In [28]:
# Unpack one sampled transition.
obs, rewards, dones = buf.sample(1)[0]

In [29]:
# Per-env observation shape, scalar reward and done flag.
obs.shape, rewards, dones,

((80, 4, 105), 0.0, False)

In [30]:
# Toy rewards/dones for checking the masking idioms used in train().
rewards = np.array([1, 1, 1, 1, 1, 1, 1, 1])
dones = np.array([True, False, True, False, True, False, True, False])

In [31]:
# ~dones zeroes the entries of finished episodes (bootstrap masking).
rewards * ~dones

array([0, 1, 0, 1, 0, 1, 0, 1])

In [32]:
# Indices of finished entries.
done_indices = np.where(dones)
done_indices

(array([0, 2, 4, 6]),)

In [33]:
# Rewards before the in-place reset below.
rewards

array([1, 1, 1, 1, 1, 1, 1, 1])

In [34]:
# Zero finished entries in place, as train() does with per-env returns.
rewards[done_indices] = 0
rewards

array([0, 1, 0, 1, 0, 1, 0, 1])

In [35]:
# zip(*rows) transposes a list of records into one tuple per field.
list(zip(*[["a", 1, False], ["b", 2, True]]))

[('a', 'b'), (1, 2), (False, True)]

In [36]:
# Transpose a sampled batch into per-field arrays.
obs, rewards, dones = zip(*buf.sample(2))
np.array(obs).shape, np.array(rewards), np.array(dones)

((2, 80, 4, 105), array([0., 0.], dtype=float32), array([False, False]))

## Train an agent

In [37]:
import numpy as np
import pufferlib
import pufferlib.vector

from pufferlib.environments import atari

### Experiment Configuration

In [38]:
# Hyperparameters for the DQN run.
project = "Puffer-Breakout-DQN"
env_name = "breakout"
learning_rate = 1e-4  # Adam step size
num_envs = 8  # parallel environments; each training step adds num_envs transitions
epsilon_start = 0.75
epsilon_end = 0.01
n_steps = int(1e6)  # vectorized env steps (iterations of the training loop), not single transitions
# Linear epsilon decay reaches epsilon_end after 90% of the training steps.
epsilon_decay = 9 * n_steps // 10
# NOTE(review): gamma = 1.0 is undiscounted; DQN typically uses gamma < 1 (e.g. 0.99)
# to keep bootstrapped targets bounded -- confirm this is intended.
gamma = 1.0
buffer_size = int(1e5)  # max transitions held in the replay buffer
update_batch_size = 32
target_update_steps = int(1e3)  # copy online weights to the target net every this many steps

In [39]:
# Start a W&B run that records every hyperparameter needed to reproduce it.
wandb.init(
    project=project,
    config={
        "env_name": env_name,
        "num_envs": num_envs,
        "lr": learning_rate,
        "eps_start": epsilon_start,
        "eps_end": epsilon_end,
        "eps_decay": epsilon_decay,
        "gamma": gamma,
        "buffer_size": buffer_size,
        "update_batch_size": update_batch_size,
        "target_update_steps": target_update_steps,
        "n_steps": n_steps,
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [40]:
# Vectorized Breakout envs (4-frame stacks) for training, using the multiprocessing backend.
env_creator = atari.env_creator(env_name)

vecenv = pufferlib.vector.make(
    env_creator,
    env_kwargs={
        "framestack": 4,
    },
    backend=pufferlib.vector.Multiprocessing,
    num_envs=num_envs
)

In [41]:
# Batched observation shape, passed to QNetwork below.
obs_shape = vecenv.obs_batch_shape
obs_shape

(8, 80, 4, 105)

In [42]:
# Number of discrete actions, read from the batched action space's nvec.
n_actions = vecenv.action_space.nvec[0]
n_actions

4

In [43]:
# Use the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [44]:
# Online and target networks start from identical weights; the target net is put in eval mode.
q_net = QNetwork(obs_shape, n_actions).to(device=device)
target_net = QNetwork(obs_shape, n_actions).to(device=device)
target_net.load_state_dict(q_net.state_dict())
target_net.eval()
print(q_net, f"\n{sum(p.numel() for p in q_net.parameters()):,}")

QNetwork(
  (convs): Sequential(
    (0): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (mlp): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=8320, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=4, bias=True)
  )
) 
1,089,332


In [45]:
# Have W&B track the online network's gradients and parameters when a run is active.
if wandb.run is not None:
    wandb.watch(q_net, log="all")

In [46]:
# Only the online network is optimized; the target net is updated by weight copies.
optimizer = optim.Adam(q_net.parameters(), lr=learning_rate)

In [47]:
# Huber-style loss on TD errors: quadratic for small errors, linear for large ones.
loss_fn = nn.SmoothL1Loss()

In [None]:
# Run DQN training with the configuration above; a KeyboardInterrupt stops it gracefully.
train(
    net=q_net,
    target_net=target_net,
    optimizer=optimizer,
    loss_fn=loss_fn,
    env=vecenv,
    epsilon_start=epsilon_start,
    epsilon_end=epsilon_end,
    epsilon_decay=epsilon_decay,
    gamma=gamma,
    buffer_size=buffer_size,
    update_batch_size=update_batch_size,
    target_update_steps=target_update_steps,
    n_steps=n_steps,
)

Steps:   0%|          | 0/1000000 [00:00<?, ?it/s]