In [1]:
import gymnasium as gym
import ale_py  # Ensure Atari environments work
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import collections
import random
from collections import deque
import torch.nn.functional as F
import cv2
from tqdm import tqdm
import wandb
from functools import partial
import os

from utils import get_env, wrap_recording, load_demonstrations, record_video

In [2]:
class DQN_CNN(nn.Module):
    def __init__(self, input_channels, action_dim):
        super(DQN_CNN, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),  # Output: (32, 20, 20)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # Output: (64, 9, 9)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # Output: (64, 7, 7)
            nn.ReLU()
        )


        self.fc_layers = nn.Sequential(
            nn.Linear(64*7*7, 512),  # Flattened CNN features
            nn.ReLU(),
            nn.Linear(512, action_dim)  # Output Q-values for each action
        )

    def forward(self, x):
        x = self.conv_layers(x)

        x = x.view(x.size(0), -1)  # Flatten

        x = self.fc_layers(x)
        return x

In [3]:
def select_action(env, model, state, epsilon):
    if random.random() < epsilon:
        return env.action_space.sample()  # Random action (exploration)

    state = torch.FloatTensor(np.array(state)).unsqueeze(0) / 255.0  # Normalize pixels
    state = state.to(device)
    with torch.no_grad():
        return model(state).argmax().item()

def train(model, target_model, buffer, optimizer, batch_size, gamma):
    if buffer.size() < batch_size:
        return 0
    
    # Sample batch from experience replay
    states, actions, rewards, next_states, dones = buffer.sample(batch_size)

    states = states.to(device)
    actions = actions.to(device)
    rewards = rewards.to(device)
    next_states = next_states.to(device)
    dones = dones.to(device)

    # Compute Q-values for current states
    q = model(states)
    # print('q.shape:', q.shape)
    q_values = q.gather(1, actions.unsqueeze(1)).squeeze(1)  # Select Q-values of taken actions

    # Compute next Q-values from the target network
    next_q_values = target_model(next_states).max(1)[0].detach()  # Max Q-value of next state

    dones = dones.to(torch.bool)
    # Zero next_q_values for terminal states
    next_q_values[dones] = 0.0

    # Compute target Q-values
    scaled_rewards = 0.01 * rewards
    target_q_values = scaled_rewards + gamma * next_q_values

    dq_loss = F.mse_loss(q_values, target_q_values.detach())

    # Backpropagation
    optimizer.zero_grad()
    dq_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
    optimizer.step()
    return dq_loss.item()


In [4]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((
            state,
            action,
            int(reward),
            next_state,
            bool(done)
        ))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)

        return (
            torch.FloatTensor(np.array(state)) / 255.0,  # Normalize pixels
            torch.LongTensor(action),
            torch.FloatTensor(reward),
            torch.FloatTensor(np.array(next_state)) / 255.0,
            torch.FloatTensor(done)
        )

    def size(self):
        return len(self.buffer)

In [5]:
# Create the Atari environment
env = get_env()

# Check Action / State space
obs, info = env.reset()

action_dim = env.action_space.n
print(f"Observation space: {env.observation_space}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dqn = DQN_CNN(4, action_dim).to(device)
target_dqn = DQN_CNN(4, action_dim).to(device)
target_dqn.load_state_dict(dqn.state_dict())

lr = 0.0001
weight_decay = 1e-5
replay_buffer_size = 10000
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
replay_buffer = ReplayBuffer(replay_buffer_size)

num_train_iterations = 1000000
batch_size = 32
gamma = 0.99
epsilon = 0.01
target_update_freq = 10000
rewards_list = []

wandb.require("core")
wandb.login()
wandb.init(
      # Set the project where this run will be logged
      project="frogger",
      # We pass a run name (otherwise itâ€™ll be randomly assigned, like sunshine-lollypop-10)
      name=f"dqn",
      # Track hyperparameters and run metadata
      config={
      "lr": lr,
      "weight_decay": weight_decay,
      "batch_size": batch_size,
      "gamma": gamma,
      "epsilon": epsilon,
      "replay_buffer_size": replay_buffer_size,
      "variant": "dqn",
      "num_train_iterations": num_train_iterations,
      "target_update_freq": target_update_freq,
      })

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


Observation space: Box(0, 255, (84, 336), uint8)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
os.makedirs('dqn/train', exist_ok=True)

optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
state, info = env.reset()
total_loss = 0
total_reward = 0

for iteration in range(num_train_iterations+1):
    action = select_action(env, dqn, state, epsilon)
    next_state, reward, terminated, truncated, info = env.step(action)
    total_reward += reward

    replay_buffer.push(state, action, reward, next_state, terminated)

    if terminated:
        state, info = env.reset()
        wandb.log({'train/loss': total_loss})
        total_loss = 0
        total_reward = 0
    else:
        state = next_state

    loss = train(dqn, target_dqn, replay_buffer, optimizer, batch_size, gamma)
    total_loss += loss

    rewards_list.append(total_reward)

    if iteration % target_update_freq == 0:
        target_dqn.load_state_dict(dqn.state_dict())
        torch.save(dqn.state_dict(), f"dqn/train/frogger_dqn_iter_{iteration}.pth")
        reward, length, time = record_video(select_action=partial(select_action, model=dqn, epsilon=0), video_folder="dqn/videos", episode_trigger=lambda x: True, name_prefix=f"train_iter_{iteration}")
        wandb.log({'train/reward': reward, 'train/length': length, 'train/time': time})


  logger.warn(
