### DQN (Deep Q-Learning)
This notebook implements the DQN algorithm, which learns to solve the CartPole environment, using PyTorch.

In [21]:
# Hyperparameters

# -- Learning --
GAMMA = 0.99                 # discount factor applied to the bootstrapped target
BATCH_SIZE = 32              # transitions sampled from the replay buffer per gradient step
TARGET_UPDATE_FREQ = 1_000   # steps between copies of the online weights into the target net
N_STEPS = 25_000             # total number of training steps

# -- Replay buffer --
BUFFER_SIZE = 50_000         # maximum number of stored transitions
MIN_REPLAY_SIZE = 1_000      # random-action transitions collected before training starts

# -- Epsilon-greedy exploration --
EPS_START = 1.0              # epsilon at step 0
EPS_END = 0.02               # epsilon once the decay is finished
EPS_DECAY = 10_000           # number of steps over which epsilon is interpolated

In [3]:
import torch
import torch.nn as nn
import gymnasium as gym
from collections import deque
import itertools
import numpy as np
import random
import tqdm

In [15]:
class Network(nn.Module):
    """Small fully connected Q-network with one hidden Tanh layer of 64 units.

    Maps a flattened observation to one Q-value per discrete action.

    Args:
        env: Environment-like object. Only ``env.observation_space.shape`` and
            ``env.action_space.n`` are read, to size the input and output
            layers.
    """

    def __init__(self, env):
        super().__init__()
        # Collapse the observation shape into a single input width.
        in_features = int(np.prod(env.observation_space.shape))
        n_actions = int(env.action_space.n)
        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.Tanh(),
            nn.Linear(64, n_actions)
        )

    def forward(self, x):
        """Return the Q-values for ``x``, one output per action."""
        return self.net(x)

    def act(self, state):
        """Return the greedy action for a single, unbatched state.

        Args:
            state: One observation (array-like). It is converted to a float
                tensor and given a leading batch dimension of size 1.

        Returns:
            int: Index of the action with the highest Q-value.
        """
        state = torch.as_tensor(state).float().unsqueeze(0)
        # Action selection never needs gradients; disabling autograd here
        # avoids building a graph on every environment step.
        with torch.no_grad():
            q_values = self.forward(state)
        max_q_idx = torch.argmax(q_values, dim=1)[0]
        return max_q_idx.item()

In [22]:
# Training environment plus the bookkeeping containers used by the later cells.
env = gym.make("CartPole-v1")

# Bounded FIFO buffers: once maxlen is reached the oldest entry is dropped.
replay_buffer = deque(maxlen=BUFFER_SIZE)
reward_buffer = deque(maxlen=100)
reward_buffer.append(0.0)  # start with a single 0.0 entry

# Running return of the episode currently being played.
episode_reward = 0.0

In [23]:
# Two Q-networks built from the same environment spec; the target network
# starts out as an exact copy of the online network's weights.
online_net, target_net = Network(env), Network(env)

target_net.load_state_dict(online_net.state_dict())

<All keys matched successfully>

In [24]:
# Adam is given only the online network's parameters.
optimizer = torch.optim.Adam(
    params=online_net.parameters(),
    lr=5e-4,
)

In [25]:
# Initialize the replay buffer
# Collect MIN_REPLAY_SIZE transitions using randomly sampled actions.
obs = env.reset()[0]
for _ in range(MIN_REPLAY_SIZE):
    action = env.action_space.sample()
    new_obs, reward, terminated, truncated, _ = env.step(action)
    transition = (obs, action, reward, terminated, truncated, new_obs)
    replay_buffer.append(transition)
    # Continue from the next observation, or from a fresh episode once the
    # current one has ended.
    obs = env.reset()[0] if (terminated or truncated) else new_obs

In [26]:
# Training Loop
# Start from a fresh episode. episode_reward is reset here as well so that
# re-running this cell never carries a partial return over from an earlier run.
obs = env.reset()[0]
episode_reward = 0.0
for step in tqdm.tqdm(range(N_STEPS)):
    # Epsilon-greedy action selection. eps is linearly interpolated from
    # EPS_START to EPS_END over the first EPS_DECAY steps; np.interp clamps
    # it to EPS_END for every later step.
    eps = np.interp(step, [0, EPS_DECAY], [EPS_START, EPS_END])
    if random.random() <= eps:
        action = env.action_space.sample()
    else:
        action = online_net.act(obs)
    new_obs, reward, terminated, truncated, _ = env.step(action)
    # Keep the 6-tuple layout (obs, action, reward, terminated, truncated,
    # new_obs) so every entry of the replay buffer has the same structure.
    transition = (obs, action, reward, terminated, truncated, new_obs)
    replay_buffer.append(transition)
    obs = new_obs
    episode_reward += reward

    if terminated or truncated:
        obs = env.reset()[0]
        reward_buffer.append(episode_reward)
        episode_reward = 0.0

    # Start the gradient step
    # Sample BATCH_SIZE transitions
    transitions = random.sample(replay_buffer, BATCH_SIZE)
    observations = np.asarray([ts[0] for ts in transitions])
    actions = np.asarray([ts[1] for ts in transitions])
    rewards = np.asarray([ts[2] for ts in transitions])
    terminateds = np.asarray([ts[3] for ts in transitions])
    new_observations = np.asarray([ts[5] for ts in transitions])
    # Convert to tensors; the per-transition scalars get a trailing dimension
    # so they line up with the (BATCH_SIZE, 1) Q-value columns below.
    observations = torch.as_tensor(observations, dtype=torch.float32)
    actions = torch.as_tensor(actions, dtype=torch.int64).unsqueeze(-1)
    rewards = torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(-1)
    terminateds = torch.as_tensor(terminateds, dtype=torch.float32).unsqueeze(-1)
    new_observations = torch.as_tensor(new_observations, dtype=torch.float32)

    # The TD target is a fixed regression label: compute it without autograd
    # so loss.backward() does not back-propagate into the target network.
    with torch.no_grad():
        target_q_values = target_net(new_observations)
        # Get the max of q_values per observation
        max_target_q_values = target_q_values.max(dim=1, keepdim=True)[0]
        # Calculate the y_j.
        # Only a true termination removes the bootstrap term. A truncated
        # episode was cut off by the environment rather than reaching a
        # terminal state, so the value of its next state still counts.
        targets = rewards + GAMMA * (1 - terminateds) * max_target_q_values

    q_values = online_net(observations)
    # Pick, for each sampled transition, the Q-value of the action that was taken.
    action_q_values = torch.gather(input=q_values, dim=1, index=actions)

    loss = nn.functional.smooth_l1_loss(action_q_values, targets)
    # Gradient Descent step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Update the target network
    if (step % TARGET_UPDATE_FREQ == 0):
        target_net.load_state_dict(online_net.state_dict())
    # Report the mean of the recent episode returns every 1000 steps.
    if step % 1000 == 0:
        print(f"Step {step}, Average Reward : {np.mean(reward_buffer)}")

  0%|                                                                              | 38/25000 [00:00<01:07, 369.06it/s]

Step 0, Average Reward : 0.0


  4%|███▎                                                                        | 1107/25000 [00:02<00:48, 490.72it/s]

Step 1000, Average Reward : 23.0


  8%|██████▎                                                                     | 2087/25000 [00:04<00:45, 499.04it/s]

Step 2000, Average Reward : 20.13131313131313


 12%|█████████▎                                                                  | 3055/25000 [00:06<00:45, 485.96it/s]

Step 3000, Average Reward : 21.22


 16%|████████████▎                                                               | 4057/25000 [00:08<00:44, 472.36it/s]

Step 4000, Average Reward : 25.01


 20%|███████████████▎                                                            | 5053/25000 [00:10<00:40, 487.96it/s]

Step 5000, Average Reward : 31.38


 24%|██████████████████▍                                                         | 6073/25000 [00:12<00:39, 476.19it/s]

Step 6000, Average Reward : 38.52


 28%|█████████████████████▍                                                      | 7037/25000 [00:14<00:38, 466.56it/s]

Step 7000, Average Reward : 45.96


 32%|████████████████████████▍                                                   | 8054/25000 [00:17<00:38, 438.77it/s]

Step 8000, Average Reward : 53.91


 36%|███████████████████████████▌                                                | 9074/25000 [00:19<00:35, 447.04it/s]

Step 9000, Average Reward : 62.34


 40%|██████████████████████████████▎                                            | 10095/25000 [00:21<00:32, 459.10it/s]

Step 10000, Average Reward : 71.0


 44%|█████████████████████████████████▏                                         | 11083/25000 [00:23<00:31, 446.18it/s]

Step 11000, Average Reward : 79.93


 48%|████████████████████████████████████▏                                      | 12056/25000 [00:26<00:28, 453.76it/s]

Step 12000, Average Reward : 86.52


 52%|███████████████████████████████████████▏                                   | 13064/25000 [00:28<00:26, 457.68it/s]

Step 13000, Average Reward : 98.47


 56%|██████████████████████████████████████████▏                                | 14076/25000 [00:30<00:23, 466.65it/s]

Step 14000, Average Reward : 107.24


 60%|█████████████████████████████████████████████▏                             | 15056/25000 [00:32<00:20, 494.75it/s]

Step 15000, Average Reward : 114.91


 64%|████████████████████████████████████████████████▏                          | 16073/25000 [00:34<00:18, 470.36it/s]

Step 16000, Average Reward : 124.87


 68%|███████████████████████████████████████████████████▏                       | 17068/25000 [00:36<00:16, 479.09it/s]

Step 17000, Average Reward : 133.28


 72%|██████████████████████████████████████████████████████▏                    | 18048/25000 [00:39<00:16, 414.86it/s]

Step 18000, Average Reward : 142.51


 76%|█████████████████████████████████████████████████████████                  | 19036/25000 [00:41<00:13, 451.96it/s]

Step 19000, Average Reward : 150.29


 80%|████████████████████████████████████████████████████████████▏              | 20050/25000 [00:43<00:10, 459.25it/s]

Step 20000, Average Reward : 159.39


 84%|███████████████████████████████████████████████████████████████▏           | 21071/25000 [00:45<00:09, 429.56it/s]

Step 21000, Average Reward : 168.04


 88%|██████████████████████████████████████████████████████████████████▏        | 22073/25000 [00:48<00:06, 454.96it/s]

Step 22000, Average Reward : 172.87


 92%|█████████████████████████████████████████████████████████████████████▏     | 23075/25000 [00:50<00:04, 447.97it/s]

Step 23000, Average Reward : 182.9


 96%|████████████████████████████████████████████████████████████████████████▏  | 24075/25000 [00:52<00:02, 454.84it/s]

Step 24000, Average Reward : 191.87


100%|███████████████████████████████████████████████████████████████████████████| 25000/25000 [00:54<00:00, 457.84it/s]


In [28]:
# Watch the trained agent act greedily in a rendered environment.
# NOTE(review): `display` is not used in this cell; the import is kept in case
# other cells rely on it.
from IPython import display

N_EVAL_STEPS = 1000  # number of environment steps to render

# A separate name keeps the training `env` intact, so earlier cells can be
# re-run after this one without picking up a closed, rendering environment.
eval_env = gym.make("CartPole-v1", render_mode="human")
try:
    state = eval_env.reset()[0]
    for _ in range(N_EVAL_STEPS):
        # No explicit render() call: with render_mode="human" Gymnasium draws
        # the frame itself on reset()/step().
        action = online_net.act(state)
        state, reward, terminated, truncated, _ = eval_env.step(action)
        if terminated or truncated:
            state = eval_env.reset()[0]
finally:
    # Always release the render window, even if the loop is interrupted.
    eval_env.close()