In [131]:
# gym imports
import gym
from gym import spaces

# helpers
import random
import numpy as np
import pandas as pd
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline

# torch imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchview import draw_graph

In [132]:
# Check for GPU availability and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [133]:
env = gym.make("CarRacing-v2", continuous=False)

In [134]:
class QNetwork(nn.Module):
    def __init__(self, observation_space: spaces.Box, feature_dim):
        super(QNetwork, self).__init__()
        n_channels = observation_space.shape[0]
        self.conv_layers = nn.Sequential(
            nn.Conv2d(n_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
        )

        with torch.no_grad():
            dummy_input = torch.zeros(1, *observation_space.shape)
            n_flatten = self.conv_layers(dummy_input).shape[1]


        self.fc_layers = nn.Sequential(
            nn.Linear(n_flatten, 128),
            nn.ReLU(),
            nn.Linear(128, feature_dim),
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

In [135]:
observation_space = gym.spaces.Box(low=0, high=255, shape=(3, 96, 96), dtype=np.float32)
feature_dim = 5

q_net = QNetwork(observation_space=observation_space, feature_dim=feature_dim).to(device)
target_net = QNetwork(observation_space=observation_space, feature_dim=feature_dim).to(device)
target_net.load_state_dict(q_net.state_dict())
target_net.eval()

QNetwork(
  (conv_layers): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (fc_layers): Sequential(
    (0): Linear(in_features=18432, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=5, bias=True)
  )
)

In [136]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(q_net.parameters())

In [137]:
# Parameters
GAMMA = 0.99
EPSILON = 1.0
BATCH_SIZE = 64
NUM_EPISODES = 10
EPSILON_DECAY = 1.005
MAX_TRANSITIONS = 100_000
LEARN_AFTER_STEPS = 4
TARGET_UPDATE_AFTER = 1000

In [138]:
REPLAY_BUFFER = deque(maxlen=MAX_TRANSITIONS)

In [139]:
def insert_transition(transition):
    REPLAY_BUFFER.append(transition)

In [140]:
def sample_transitions(batch_size=16):
    sampled = random.sample(REPLAY_BUFFER, batch_size)
    states, actions, rewards, next_states, dones = zip(*sampled)

    return (
        torch.tensor(states, dtype=torch.float32,device=device).permute(0,3,1,2),
        torch.tensor(actions, dtype=torch.int64,device=device),
        torch.tensor(rewards, dtype=torch.float32,device=device),
        torch.tensor(next_states, dtype=torch.float32,device=device).permute(0,3,1,2),
        torch.tensor(dones, dtype=torch.bool,device=device),
    )

In [141]:
def policy(state, explore=0.0):
    if np.random.rand() <= explore:
        return np.random.randint(5)

    with torch.no_grad():
        state_tensor = torch.tensor([state], dtype=torch.float32, device=device)
        state_tensor = state_tensor.permute(0,3,1,2)

        return q_net(state_tensor).argmax().item()

In [142]:
random_states = []
done = False
state = env.reset()[0]

for _ in range(20):
    if not done:
        random_states.append(state)
        state, _, terminated, truncated, _ = env.step(policy(state))
        done = terminated or truncated

random_states = torch.tensor(random_states, dtype=torch.float32, device=device)

  if not isinstance(terminated, (bool, np.bool8)):


In [143]:
def get_q_values(states):
    with torch.no_grad():
        states = states.permute(0,3,1,2)
        return q_net(states).max(dim=1)[0]

In [144]:
step_counter = 0
metric = {
    "episode": [],
    "length": [],
    "total_reward": [],
    "avg_q": [],
    "exploration": [],
}

In [145]:
for episode in range(NUM_EPISODES):
    state = env.reset()[0]
    done = False
    total_rewards = 0
    episode_length = 0

    while not done:
        action = policy(state=state, explore=EPSILON)
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        insert_transition((state, action, reward, next_state, done))
        state = next_state
        step_counter += 1

        if step_counter >= BATCH_SIZE and step_counter % LEARN_AFTER_STEPS == 0:
            states, actions, rewards, next_states, dones = sample_transitions(BATCH_SIZE)

            # Compute targets
            with torch.no_grad():
                next_action_values = target_net(next_states).argmax().item()
                targets = rewards + GAMMA * next_action_values * (~dones)

            # Compute Q values for the selected actions
            preds = q_net(states)
            current_values = preds.gather(1, actions.unsqueeze(1)).squeeze()

            # Compute loss and backprop
            loss = loss_fn(current_values, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if step_counter % TARGET_UPDATE_AFTER == 0:
            target_net.load_state_dict(q_net.state_dict())

        total_rewards += reward
        episode_length += 1

    # Save metrics
    avg_q = get_q_values(random_states).mean().item()
    metric["episode"].append(episode)
    metric["length"].append(episode_length)
    metric["total_reward"].append(total_rewards)
    metric["avg_q"].append(avg_q)
    metric["exploration"].append(EPSILON)

    EPSILON /= EPSILON_DECAY

    print(f"episode: {episode}, episode_length: {episode_length}, total_reward: {total_rewards}, avg_q: {avg_q}")

    pd.DataFrame(metric).to_csv("metric.csv", index=False)

env.close()
torch.save(q_net.state_dict(), "dqn_q_net.pth")

episode: 0, episode_length: 1000, total_reward: -66.8674698795188, avg_q: 0.503722608089447
episode: 1, episode_length: 1000, total_reward: -62.732919254659194, avg_q: 0.4370291233062744
episode: 2, episode_length: 1000, total_reward: -45.84837545126425, avg_q: 0.3730293810367584
episode: 3, episode_length: 1000, total_reward: -57.59717314487714, avg_q: 0.1820891946554184
episode: 4, episode_length: 1000, total_reward: -61.194029850747114, avg_q: 0.1603076308965683
episode: 5, episode_length: 1000, total_reward: -45.45454545454622, avg_q: 0.13089799880981445
episode: 6, episode_length: 1000, total_reward: -46.66666666666738, avg_q: 0.12478765100240707
episode: 7, episode_length: 1000, total_reward: -56.67870036101169, avg_q: 0.11879073828458786
episode: 8, episode_length: 1000, total_reward: -59.86622073578678, avg_q: 0.09934001415967941
episode: 9, episode_length: 1000, total_reward: -60.52631578947451, avg_q: 0.09613050520420074


# Inference
- With this speed it's going to take ages for DQN to get trained.
- Let's try continuous RL Algorithms such as A2C/A3C