In [1]:
import gymnasium as gym
import ale_py  # Ensure Atari environments work
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import collections
import random
from collections import deque
import torch.nn.functional as F
import cv2
from tqdm import tqdm
import wandb
from functools import partial
import os

from utils import get_env, record_video_oc, get_obj_classes

In [2]:
class DQN_OC(nn.Module):
    """Object-centric Q-network: encodes a sequence of object rows into Q-values.

    ``forward`` expects a batch-first tensor of shape ``(batch, num_objects, 6)``
    whose last dimension is sliced as:

    * column 0 -- time index, looked up in a 4-entry embedding table, so it
      must be an integer in ``[0, 4)``;
    * column 1 -- class index in ``[0, num_classes)``; index 0 is the
      embedding's ``padding_idx`` and therefore maps to a zero vector;
    * columns 2..5 -- four floats projected by ``xywh_proj`` (presumably a
      bounding box x/y/width/height -- confirm against the code that builds
      the object tensors).

    Args:
        num_classes: Size of the class embedding table (including index 0).
        hidden_dim: Width of all embeddings and of the GRU hidden state.
        action_dim: Number of discrete actions, i.e. the output width.
    """

    def __init__(self, num_classes, hidden_dim, action_dim):
        super().__init__()

        self.class_embs = nn.Embedding(num_embeddings=num_classes, embedding_dim=hidden_dim, padding_idx=0)
        self.time_embs = nn.Embedding(num_embeddings=4, embedding_dim=hidden_dim)
        self.xywh_proj = nn.Linear(4, hidden_dim)

        # batch_first=True is required: the input is (batch, objects, features).
        # With the GRU default (batch_first=False) the recurrence would run over
        # the *batch* axis and leak hidden state between unrelated samples.
        self.encoder = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=3,
            batch_first=True,
        )

        self.fc_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, action_dim)  # Output Q-values for each action
        )

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_dim)`` for object tensor ``x``."""
        time_emb = self.time_embs(x[:, :, 0].long())
        class_emb = self.class_embs(x[:, :, 1].long())
        xywh_emb = self.xywh_proj(x[:, :, 2:])

        # Per-object token = sum of the three embeddings.
        obj_emb = time_emb + class_emb + xywh_emb

        # Recurrent pass over the object axis, independently for every sample.
        encoded, _ = self.encoder(obj_emb)

        # Max-pool over the object axis to get one vector per sample.
        # NOTE(review): all-zero padding rows are not masked here; they still
        # receive the time embedding for index 0 and the xywh_proj bias, and
        # take part in the GRU and in this max.
        pooled, _ = torch.max(encoded, dim=1)
        return self.fc_layers(pooled)

In [3]:
def select_action(env, model, state_objs, epsilon):
    """Pick an action with an epsilon-greedy policy.

    Args:
        env: Environment whose ``action_space.sample()`` provides random actions.
        model: Q-network (an ``nn.Module``) mapping a batched object tensor to
            per-action Q-values.
        state_objs: Current observation; must expose ``to_tensor()``.
        epsilon: Probability of taking a random action instead of the greedy one.

    Returns:
        The sampled action when exploring, otherwise the index of the largest
        Q-value as a Python int.
    """
    if random.random() < epsilon:
        return env.action_space.sample()  # Random action (exploration)

    # Take the device from the model itself rather than from a notebook-level
    # global, so the function works no matter which cells have been executed.
    first_param = next(model.parameters(), None)

    with torch.no_grad():
        # Add batch dimension
        batch = state_objs.to_tensor().unsqueeze(0)
        if first_param is not None:
            batch = batch.to(first_param.device)
        return model(batch).squeeze().argmax().item()

def train(model, target_model, buffer, optimizer, batch_size, gamma, reward_scale=0.01):
    """Run one DQN optimisation step on a batch sampled from ``buffer``.

    Args:
        model: Online Q-network being optimised.
        target_model: Target Q-network used to bootstrap the next-state value.
        buffer: Replay buffer exposing ``size()`` and ``sample(batch_size)``;
            ``sample`` must return ``(states, actions, rewards, next_states,
            dones)`` tensors.
        optimizer: Optimizer over ``model``'s parameters.
        batch_size: Number of transitions per update.
        gamma: Discount factor.
        reward_scale: Multiplier applied to raw rewards before building the
            TD target (defaults to the previously hard-coded 0.01).

    Returns:
        ``0`` if the buffer holds fewer than ``batch_size`` transitions (no
        update is performed), otherwise the MSE TD loss as a Python float.
    """
    if buffer.size() < batch_size:
        return 0

    # Sample batch from experience replay
    state_objs, actions, rewards, next_state_objs, dones = buffer.sample(batch_size)

    # Use the device the online network lives on instead of a notebook global.
    train_device = next(model.parameters()).device
    state_objs = state_objs.to(train_device)
    actions = actions.to(train_device)
    rewards = rewards.to(train_device)
    dones = dones.to(train_device)
    next_state_objs = next_state_objs.to(train_device)

    # Q-values of the actions that were actually taken.
    q_values = model(state_objs).gather(1, actions.unsqueeze(1)).squeeze(1)

    # The TD target is a constant w.r.t. the optimisation, so build it without
    # recording an autograd graph through the target network.
    with torch.no_grad():
        next_q_values = target_model(next_state_objs).max(1)[0]  # Max Q-value of next state
        # Terminal transitions have no future value.
        next_q_values = next_q_values.masked_fill(dones.to(torch.bool), 0.0)
        target_q_values = reward_scale * rewards + gamma * next_q_values

    dq_loss = F.mse_loss(q_values, target_q_values)

    # Backpropagation
    optimizer.zero_grad()
    dq_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
    optimizer.step()
    return dq_loss.item()


In [4]:
class ReplayBuffer:
    """Fixed-capacity FIFO experience replay of object-centric transitions.

    States are stored as the objects handed to ``push`` and are only converted
    (via their ``to_tensor()`` method) and zero-padded to a common number of
    rows when a batch is sampled.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) silently evicts the oldest transition when full.
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, state_objs, action, reward, next_state_objs, done):
        """Append one transition; ``reward`` is stored as float, ``done`` as bool."""
        self.buffer.append((
            state_objs,
            action,
            float(reward),  # keep fractional rewards; int() would truncate them
            next_state_objs,
            bool(done)
        ))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` where the
            state tensors are batch-first and zero-padded along dim 1,
            ``actions`` is int64 and ``rewards`` / ``dones`` are float32.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        state_objs, actions, rewards, next_state_objs, dones = zip(*batch)

        # Materialise lists: pad_sequence is documented to take a list of tensors.
        states = pad_sequence(
            [objs.to_tensor() for objs in state_objs], batch_first=True, padding_value=0.0
        )
        next_states = pad_sequence(
            [objs.to_tensor() for objs in next_state_objs], batch_first=True, padding_value=0.0
        )

        return (
            states,
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            next_states,
            torch.tensor(dones, dtype=torch.float32),
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

In [None]:
# ---- Hyperparameters and run naming ----
project_name = 'dqn_oc'

lr = 0.0001
weight_decay = 1e-5
replay_buffer_size = 10000

num_train_iterations = 1000000
batch_size = 32
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.05
epsilon_decay = 0.999885
target_update_freq = 10000
rewards_list = []

# ---- Environment ----
# Object-centric Atari environment; one reset up front so the spaces can be inspected.
env = get_env(process=False, oc=True)
obs, info = env.reset()
action_dim = env.action_space.n

# ---- Networks ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _build_q_network():
    """Create a DQN_OC on ``device``; one extra class slot for embedding index 0."""
    return DQN_OC(
        num_classes=len(get_obj_classes()) + 1,
        hidden_dim=300,
        action_dim=action_dim,
    ).to(device)


dqn = _build_q_network()
target_dqn = _build_q_network()
# The target network starts as an exact copy of the online network.
target_dqn.load_state_dict(dqn.state_dict())

# ---- Optimiser and replay memory ----
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
replay_buffer = ReplayBuffer(replay_buffer_size)

# ---- Experiment tracking ----
# Hyperparameters and run metadata recorded with the wandb run.
run_config = {
    "lr": lr,
    "weight_decay": weight_decay,
    "batch_size": batch_size,
    "gamma": gamma,
    "epsilon": epsilon,
    "epsilon_min": epsilon_min,
    "epsilon_decay": epsilon_decay,
    "replay_buffer_size": replay_buffer_size,
    "variant": project_name,
    "num_train_iterations": num_train_iterations,
    "target_update_freq": target_update_freq,
}

wandb.require("core")
wandb.login()
# The run is logged under project "frogger" and named after `project_name`.
wandb.init(project="frogger", name=project_name, config=run_config)

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Main DQN training loop: one environment step and one optimisation step per
# iteration, with periodic target-network sync, checkpointing and evaluation.
os.makedirs(f'{project_name}/train', exist_ok=True)

# Re-create the optimizer so re-running this cell starts from fresh optimizer state.
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
state, info = env.reset()
total_loss = 0    # loss accumulated since the last episode boundary
total_reward = 0  # return accumulated in the current episode

for iteration in tqdm(range(num_train_iterations+1)):
    action = select_action(env, dqn, state, epsilon)
    next_state, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    # Record the running return *before* any episode reset below; appending
    # after the reset stored 0 on the last step and lost the final return.
    rewards_list.append(total_reward)

    # Only `terminated` marks a true terminal state for bootstrapping; a
    # truncated episode still has a meaningful next-state value.
    replay_buffer.push(state, action, reward, next_state, terminated)

    if terminated or truncated:
        state, info = env.reset()
        wandb.log({'train/loss': total_loss, 'train/episode_reward': total_reward})
        total_loss = 0
        total_reward = 0
    else:
        state = next_state

    loss = train(dqn, target_dqn, replay_buffer, optimizer, batch_size, gamma)
    total_loss += loss
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    if iteration % target_update_freq == 0:
        # Sync the target network and checkpoint the online network.
        target_dqn.load_state_dict(dqn.state_dict())
        torch.save(dqn.state_dict(), f"{project_name}/train/frogger_dqn_iter_{iteration}.pth")
        if iteration > 0:
            # Greedy (epsilon=0) evaluation rollout. Separate names keep the
            # results from overwriting the loop's `reward` or binding `time`.
            eval_reward, eval_length, eval_time = record_video_oc(select_action=partial(select_action, model=dqn, epsilon=0), video_folder=f"{project_name}/videos", video_name=f"train_iter_{iteration}")
            wandb.log({'train/reward': eval_reward, 'train/length': eval_length, 'train/time': eval_time})


  0%|          | 0/1000001 [00:00<?, ?it/s]

  0%|          | 1965/1000001 [00:21<2:56:45, 94.11it/s]