# DQN from Memory

I need:
* [x] The actual Q network
* [x] A replay buffer
* [x] Action selection
* [x] Training loop

The process:
1. Set up online Q network and target network
1. Initialize (vectorized) environment
1. Perform training loop (Optional: With logging)
1. Evaluate and record an episode

## The Q network

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
class QNetwork(nn.Module):
    """Convolutional Q-network for stacked image observations.

    Three 3x3 convolutions (16, 32 and 64 output channels), each followed by
    2x2 max pooling, feed a two-layer MLP that emits one Q-value per action.

    Args:
        obs_shape: Shape of a *batch* of observations. Per the original
            author's note the layout is (batch, width, channels, height);
            only entries 1-3 are read.
        n_actions: Size of the output layer (one Q-value per action).
        n_hiddens: Width of the hidden fully connected layer.
    """

    def __init__(self, obs_shape, n_actions, n_hiddens=128):
        super().__init__()
        # Observations are indexed batch x width x channels (frame stack) x height;
        # the convolutions consume channels x height x width.
        n_channels, height, width = obs_shape[2], obs_shape[3], obs_shape[1]

        # Convolutional layers
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        # Size the first fully connected layer by pushing a single dummy
        # observation through the conv stack. The probe carries an explicit
        # batch dimension of 1 and runs without autograd, since only the
        # output element count is needed.
        with torch.no_grad():
            probe = torch.zeros(1, n_channels, height, width)
            n_input = self.convs(probe).numel()

        # Fully connected layers
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_input, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for raw observations.

        Args:
            x: Array-like batch of observations in the same axis order as
               ``obs_shape``. Values are divided by 255 before use.
        """
        # Build the input on whatever device the weights currently live on.
        device = next(self.parameters()).device
        x = torch.tensor(np.array(x) / 255.0, dtype=torch.float32, device=device)
        # (batch, width, channels, height) -> (batch, channels, height, width)
        x = x.permute(0, 2, 3, 1)

        x = self.convs(x)
        return self.mlp(x)

In [3]:
class PufferQNetwork(nn.Module):
    """Small MLP Q-network for flat (vector) observations.

    Args:
        n_input: Number of features per observation.
        n_actions: Size of the output layer (one Q-value per action).
        n_hiddens: Width of the single hidden layer.
    """

    def __init__(self, n_input, n_actions, n_hiddens=128):
        super().__init__()
        # Fully connected layers
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_input, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_actions),
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions).

        Args:
            x: Array-like batch of observations (NumPy array or a sequence of
               per-sample arrays); it is converted to a float32 tensor.
        """
        # Use the device the weights live on instead of a notebook-level
        # global, so the module works wherever it has been moved.
        device = next(self.parameters()).device
        x = torch.tensor(np.array(x), dtype=torch.float32, device=device)
        return self.mlp(x)

## Replay Buffer

* Needs to be able to store some number of transitions
* Needs to automatically kick out old transitions
* Needs `store`
* Needs `extend`
* Needs `__len__()`
* Needs `sample()`

In [4]:
from collections import deque
import random

In [5]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` items are held, every new item evicts the oldest one.
    Sampling draws with replacement, so a batch may contain duplicates.
    """

    def __init__(self, capacity: int):
        # A deque with maxlen discards from the left automatically when full.
        self.queue = deque(maxlen=capacity)
        self.capacity = capacity

    def store(self, transition):
        """Append a single transition, evicting the oldest if full."""
        self.queue.append(transition)

    def extend(self, transitions):
        """Append every transition from an iterable, in order."""
        self.queue.extend(transitions)

    def sample(self, n_samples: int):
        """Return a list of ``n_samples`` transitions drawn independently."""
        pool = self.queue
        batch = []
        for _ in range(n_samples):
            batch.append(random.choice(pool))
        return batch

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.queue)

    def __repr__(self):
        return repr(self.queue)

## Action Selection

In [6]:
def select_actions(q_values, epsilon):
    """Pick epsilon-greedy actions independently for every row of a batch.

    Each environment explores (uniform random action) with probability
    ``epsilon`` and otherwise takes the action with the highest Q-value. The
    decision is made per row, so parallel environments do not all explore or
    all exploit on the same step.

    Args:
        q_values: Tensor of shape (batch_size, n_actions).
        epsilon: Exploration probability in [0, 1].

    Returns:
        CPU int64 tensor of shape (batch_size,) holding the chosen actions.
    """
    batch_size, n_actions = q_values.shape
    greedy = torch.argmax(q_values, dim=1).cpu()
    # torch.rand samples from [0, 1): epsilon=0 never explores, epsilon=1 always does.
    explore = torch.rand(batch_size) < epsilon
    random_actions = torch.randint(0, n_actions, (batch_size,))
    return torch.where(explore, random_actions, greedy)

## Training Loop

1. [x] Initialize replay buffer
1. [x] Reset `env`
1. [x] Initialize `returns` array
1. [x] Initialize `episodes` counter
1. [x] For `n_steps`
    1. [x] Take step in environment
    1. [x] (Optional) Add `rewards` to `returns` array
    1. [x] (Optional) Add `np.sum(dones | truncateds)` to episodes
    1. [x] (Optional) Log `returns` for done/truncated episodes in W&B
    1. [x] Store transition in buffer
    1. [x] If enough transitions: Update online network
    1. [x] If enough steps: Update target network

In [7]:
from torch import optim
from tqdm.auto import tqdm

from collections.abc import Iterable

In [8]:
import wandb

In [9]:
def train(
    net: nn.Module,
    target_net: nn.Module,
    optimizer: optim.Optimizer,
    loss_fn,
    env,
    epsilon_start: float,
    epsilon_end: float,
    epsilon_decay: int,
    gamma: float,
    buffer_size: int,
    update_batch_size: int,
    target_update_steps: int,
    n_steps: int,
):
    """Train ``net`` with DQN on a vectorized environment.

    Every step acts epsilon-greedily in all environments, stores the resulting
    transitions in a replay buffer, performs one gradient update on a sampled
    batch (once the buffer holds at least ``update_batch_size`` transitions),
    and periodically copies the online weights into ``target_net``. When a
    W&B run is active, episode returns, epsilon and the loss are logged.

    Args:
        net: Online Q-network; called with a batch of observations.
        target_net: Target Q-network used for the bootstrap value.
        optimizer: Optimizer over ``net``'s parameters.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        env: Vectorized environment. ``reset()`` must return ``(obs, info)`` and
            ``step(actions)`` must return
            ``(obs, rewards, dones, truncateds, info)`` with a leading
            environment axis. The environment is closed before returning.
        epsilon_start: Exploration rate at step 0.
        epsilon_end: Exploration rate after ``epsilon_decay`` steps.
        epsilon_decay: Number of steps over which epsilon is linearly annealed.
        gamma: Discount factor.
        buffer_size: Replay buffer capacity in transitions.
        update_batch_size: Number of transitions sampled per update.
        target_update_steps: Interval (in steps) between target-network syncs.
        n_steps: Total number of vectorized environment steps.
    """
    assert update_batch_size <= buffer_size, "Buffer must be large enough to hold at least one batch"
    device = next(net.parameters()).device
    buffer = ReplayBuffer(capacity=buffer_size)
    epsilon = epsilon_start

    obs, _ = env.reset()
    obs = obs.copy()

    num_envs = obs.shape[0]
    returns = np.zeros([num_envs])  # running (undiscounted) return per environment
    episodes = 0

    try:
        for step in tqdm(range(1, n_steps + 1), desc="Steps"):
            with torch.no_grad():
                q_values = net(obs)
            actions = select_actions(q_values, epsilon)
            next_obs, rewards, dones, truncateds, _ = env.step(actions)

            # Copy before storing: the rows placed in the buffer are views, and
            # the environment may reuse its observation array between steps
            # (presumably why `obs` was copied already) -- TODO confirm for the
            # backend in use.
            next_obs = next_obs.copy()
            finished = np.asarray(dones | truncateds, dtype=bool)

            buffer.extend(zip(obs, actions, rewards, next_obs, finished))
            obs = next_obs

            returns += rewards
            episodes += int(np.sum(finished))

            # Log the returns of episodes that ended on this step, then reset them.
            if wandb.run is not None and np.any(finished):
                wandb.log({
                    "avg_return": np.mean(returns[finished]),
                    "epsilon": epsilon,
                    "episodes": episodes,
                })
            returns[finished] = 0

            if len(buffer) >= update_batch_size:
                batch = buffer.sample(update_batch_size)
                obs_b, actions_b, rewards_b, next_obs_b, dones_b = zip(*batch)

                # Q(s, a) of the actions that were actually taken.
                actions_t = torch.tensor(actions_b, device=device).unsqueeze(1)
                q_taken = net(obs_b).gather(1, actions_t).squeeze(1)

                with torch.no_grad():
                    rewards_t = torch.tensor(rewards_b, dtype=torch.float32, device=device)
                    dones_t = torch.tensor(dones_b, dtype=torch.bool, device=device)
                    next_q = target_net(next_obs_b).max(dim=1).values
                    # No bootstrapping from the next state once the episode ended.
                    target = rewards_t + ~dones_t * gamma * next_q

                # PyTorch losses take (input, target).
                loss = loss_fn(q_taken, target)

                if wandb.run is not None:
                    wandb.log({
                        "loss": loss.item(),
                    })

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if step % target_update_steps == 0:
                target_net.load_state_dict(net.state_dict())

            # Linear annealing from epsilon_start to epsilon_end; a non-positive
            # decay length jumps straight to epsilon_end.
            progress = 1.0 if epsilon_decay <= 0 else min(1.0, step / epsilon_decay)
            epsilon = epsilon_start + (epsilon_end - epsilon_start) * progress

    except KeyboardInterrupt:
        print("Training stopped manually.")
    finally:
        # Always release the W&B run and the environment, even when training
        # aborts with an unexpected error.
        if wandb.run is not None:
            wandb.unwatch()
            wandb.finish()

        env.close()

## Train an agent

In [10]:
import numpy as np
import pufferlib
import pufferlib.vector

from pufferlib.environments import atari

### Experiment Configuration

In [11]:
# --- Naming -----------------------------------------------------------------
env_name = "CartPole"
project = f"Puffer-{env_name.title()}-DQN"  # passed to wandb.init as the project

# --- Optimisation -----------------------------------------------------------
learning_rate = 3e-2
gamma = 1.0               # discount factor
update_batch_size = 4     # transitions sampled per gradient update
buffer_size = 10_000      # replay buffer capacity
target_update_steps = 50  # steps between target-network syncs

# --- Environment / schedule -------------------------------------------------
num_envs = 2 * 12
n_steps = 100_000

# --- Exploration: linear decay over the first half of training ---------------
epsilon_start = 0.2
epsilon_end = 0.01
epsilon_decay = n_steps // 2

In [12]:
# Start a W&B run and record the hyperparameters of this experiment with it.
wandb.init(
    project=project,
    config=dict(
        lr=learning_rate,
        eps_start=epsilon_start,
        eps_end=epsilon_end,
        eps_decay=epsilon_decay,
        gamma=gamma,
        buffer_size=buffer_size,
        update_batch_size=update_batch_size,
        target_update_steps=target_update_steps,
        n_steps=n_steps,
    ),
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
from math import ceil
from pufferlib.environments import classic_control

# Environment factory for CartPole-v1 from PufferLib's classic-control module.
env_creator = classic_control.env_creator("CartPole-v1")
# Alternative environments that were tried (kept for reference):
#env_creator = atari.env_creator(env_name)
#from pufferlib.ocean.sanity import Squared

# Build the vectorized environment: `num_envs` copies on the multiprocessing
# backend with 12 workers (presumably num_envs / 12 environments per worker --
# confirm against the PufferLib docs).
vecenv = pufferlib.vector.make(
    env_creator,
    backend=pufferlib.vector.Multiprocessing,
    num_envs=num_envs,
    num_workers=12,
)

# Single-environment alternative matching the Squared import above:
#env = Squared()

Process Process-4:
Process Process-8:
Process Process-1:
Process Process-9:
Process Process-12:
Process Process-2:
Process Process-11:
Process Process-10:
Process Process-3:
Process Process-5:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314

In [14]:
# Reset the vectorized environment to get a sample batch of observations; the second return value is discarded.
obs, _ = vecenv.reset()

In [15]:
# Sanity check: show which container type the vectorized environment returns for observations.
print(type(obs))

<class 'numpy.ndarray'>


In [16]:
# Network dimensions, read from a single (non-vectorized) environment's spaces:
# the first entry of the observation shape and the number of discrete actions.
n_input = vecenv.single_observation_space.shape[0]
n_actions = vecenv.single_action_space.n

# Bare expression so the notebook displays both values.
n_input, n_actions

(4, 2)

In [17]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
# Online network: the one that is optimised.
net = PufferQNetwork(n_input, n_actions).to(device)

# Target network: starts as an exact copy of the online weights and is kept in eval mode.
target_net = PufferQNetwork(n_input, n_actions).to(device)
target_net.load_state_dict(net.state_dict())
target_net.eval()

# Show the architecture followed by the total parameter count.
param_count = sum(p.numel() for p in net.parameters())
print(net, f"\n{param_count:,}")

PufferQNetwork(
  (mlp): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=4, out_features=128, bias=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=2, bias=True)
  )
) 
898


In [19]:
# Only when a W&B run is active: register the online network with wandb.watch (log="all").
if wandb.run is not None:
    wandb.watch(net, log="all")

In [20]:
# Adam over the online network's parameters only; the target network is never optimised directly.
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

In [21]:
# Smooth L1 (Huber-style) loss between predicted Q-values and TD targets.
loss_fn = nn.SmoothL1Loss()

In [22]:
# Run DQN training on the vectorized environment using the hyperparameters from
# the experiment configuration cell. train() closes `vecenv` before returning.
train(
    net=net,
    target_net=target_net,
    optimizer=optimizer,
    loss_fn=loss_fn,
    env=vecenv,
    epsilon_start=epsilon_start,
    epsilon_end=epsilon_end,
    epsilon_decay=epsilon_decay,
    gamma=gamma,
    buffer_size=buffer_size,
    update_batch_size=update_batch_size,
    target_update_steps=target_update_steps,
    n_steps=n_steps,
)

Steps:   0%|          | 0/100000 [00:00<?, ?it/s]

Training stopped manually.


0,1
avg_return,▄▂▂▂▃█▁▁▂▂▂▁▄▂▃▂▄▁▃▁▁▂▂▁▃▂▂▃▄▁▁▁▃▂▅▅▃▆▁▁
epsilon,██▇▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁
loss,▂▁▁▁▂▁▁▅▁▁▁▁▄▂▃▁▁▁▁▁▂▁▃▂▂▁▁▂▄▁▆▁▃▁▂▁▁█▁▁

0,1
avg_return,13.0
epsilon,0.19589
loss,1.1967


## Agent Evaluation