In [1]:
import gymnasium as gym
import ale_py  # Ensure Atari environments work
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import collections
import random
from collections import deque
import torch.nn.functional as F
import cv2
from tqdm import tqdm
import wandb
from functools import partial
import os

from utils import get_env, record_video_oc, get_obj_classes, extract_objs

In [2]:
class DQN_OC(nn.Module):
    """Object-centric Q-network: a transformer encoder over a set of objects.

    The input is a float tensor of shape ``(batch, num_objects, 5)``.  Column 0
    holds an integer class id and columns 1-4 hold four box features (fed to
    ``xywh_proj``).  Class id 0 is treated as padding: it is the embedding's
    ``padding_idx`` and it is what zero-padded batches contain.
    NOTE(review): this assumes real object classes start at 1 -- confirm
    against ``extract_objs``.

    Args:
        num_classes: Size of the class-embedding table (including the padding id).
        hidden_dim: Transformer model width; must be divisible by 4 (``nhead``).
        action_dim: Number of discrete actions, i.e. the size of the Q-value output.
    """

    def __init__(self, num_classes, hidden_dim, action_dim):
        super(DQN_OC, self).__init__()

        self.class_embs = nn.Embedding(num_embeddings=num_classes, embedding_dim=hidden_dim, padding_idx=0)
        self.xywh_proj = nn.Linear(4, hidden_dim)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4, batch_first=True),
            num_layers=3,
            # Keep dense (batch, num_objects, hidden) tensors even when a
            # padding mask is supplied, so padded slots stay at fixed positions.
            enable_nested_tensor=False,
        )

        self.fc_layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, action_dim)  # Output Q-values for each action
        )

    def forward(self, x):
        """Return Q-values of shape ``(batch, action_dim)`` for a batch of object sets.

        Padded slots (class id 0) are masked out of self-attention and excluded
        from the max-pool, so the Q-values of a state do not depend on how much
        padding the rest of the batch forced onto it.
        """
        class_ids = x[:, :, 0].long()
        pad_mask = class_ids == 0  # True where the slot is padding

        # A row consisting only of padding would have every attention key
        # masked (-> NaN).  Keep its first slot visible so the output stays finite.
        pad_mask[:, 0] &= ~pad_mask.all(dim=1)

        obj_emb = self.class_embs(class_ids) + self.xywh_proj(x[:, :, 1:])
        encoded = self.transformer_encoder(obj_emb, src_key_padding_mask=pad_mask)

        # Padded slots must never win the max-pool.
        encoded = encoded.masked_fill(pad_mask.unsqueeze(-1), float("-inf"))
        pooled, _ = torch.max(encoded, dim=1)
        return self.fc_layers(pooled)

In [3]:
def select_action(env, model, state_objs, epsilon):
    """Pick an action epsilon-greedily.

    Args:
        env: Environment; only ``env.action_space.sample()`` is used (for exploration).
        model: Q-network mapping a ``(1, num_objects, features)`` tensor to Q-values.
        state_objs: Un-batched object tensor for the current state.
        epsilon: Probability of taking a uniformly random action.

    Returns:
        The chosen action as a Python int.
    """
    if random.random() < epsilon:
        return env.action_space.sample()  # Random action (exploration)

    # Run on whatever device the model lives on instead of relying on a
    # notebook-global ``device`` variable.
    first_param = next(model.parameters(), None)

    # Greedy selection must be deterministic: switch dropout off for the
    # forward pass and restore the previous mode so training is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            batch = state_objs.unsqueeze(0)  # Add batch dimension
            if first_param is not None:
                batch = batch.to(first_param.device)
            return model(batch).squeeze().argmax().item()
    finally:
        model.train(was_training)

def train(model, target_model, buffer, optimizer, batch_size, gamma,
          reward_scale=0.01, max_grad_norm=1.0):
    """Run one DQN optimisation step on a batch sampled from ``buffer``.

    Args:
        model: Online Q-network being optimised.
        target_model: Target Q-network used for the bootstrap value (never updated here).
        buffer: Replay buffer exposing ``size()`` and ``sample(batch_size)``; ``sample``
            must return ``(states, actions, rewards, next_states, dones)`` tensors.
        optimizer: Optimiser over ``model``'s parameters.
        batch_size: Number of transitions per update.
        gamma: Discount factor.
        reward_scale: Multiplier applied to raw rewards before forming the TD target.
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        The scalar MSE TD loss as a float, or ``0.0`` when the buffer does not
        yet hold ``batch_size`` transitions (no update is performed).
    """
    if buffer.size() < batch_size:
        return 0.0

    # Sample batch from experience replay
    state_objs, actions, rewards, next_state_objs, dones = buffer.sample(batch_size)

    # Use the model's own device rather than a notebook-global variable.
    # NOTE: target_model is assumed to live on the same device as model.
    model_device = next(model.parameters()).device
    state_objs = state_objs.to(model_device)
    actions = actions.to(model_device)
    rewards = rewards.to(model_device)
    dones = dones.to(model_device)
    next_state_objs = next_state_objs.to(model_device)

    # Q-values of the actions that were actually taken
    q = model(state_objs)
    q_values = q.gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrap value from the target network.  No autograd graph is needed,
    # and dropout is disabled so the targets are not randomly perturbed.
    target_was_training = target_model.training
    target_model.eval()
    try:
        with torch.no_grad():
            next_q_values = target_model(next_state_objs).max(1)[0]
    finally:
        target_model.train(target_was_training)

    # Terminal transitions have no future value.
    next_q_values = next_q_values.masked_fill(dones.to(torch.bool), 0.0)

    # Compute target Q-values
    scaled_rewards = reward_scale * rewards
    target_q_values = scaled_rewards + gamma * next_q_values

    dq_loss = F.mse_loss(q_values, target_q_values)

    # Backpropagation
    optimizer.zero_grad()
    dq_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # Gradient clipping
    optimizer.step()
    return dq_loss.item()


In [4]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for experience replay.

    Each transition is ``(state_objs, action, reward, next_state_objs, done)``
    where the two state entries are per-object tensors whose first dimension
    (number of objects) may differ between transitions.
    """

    def __init__(self, capacity):
        # A bounded deque drops the oldest transition once ``capacity`` is reached.
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, state_objs, action, reward, next_state_objs, done):
        """Append one transition.

        The reward is stored as a float: truncating it to ``int`` would
        silently turn fractional rewards (e.g. 0.5) into 0.
        """
        self.buffer.append((
            state_objs,
            action,
            float(reward),
            next_state_objs,
            bool(done)
        ))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)``.  States are
            zero-padded along the object dimension to the longest sequence in
            the batch; actions are int64; rewards and dones are float32.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        state_objs, action, reward, next_state_objs, done = zip(*batch)

        return (
            pad_sequence(state_objs, batch_first=True, padding_value=0.0),
            torch.tensor(action, dtype=torch.long),
            torch.tensor(reward, dtype=torch.float32),
            pad_sequence(next_state_objs, batch_first=True, padding_value=0.0),
            torch.tensor(done, dtype=torch.float32)
        )

    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def __len__(self):
        return len(self.buffer)

In [5]:
# Seed every RNG this notebook draws from so a run can be repeated.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Create the Atari environment
env = get_env(process=False, oc=True)

# Check Action / State space
obs, info = env.reset(seed=seed)
env.action_space.seed(seed)  # makes epsilon-greedy exploration samples repeatable

action_dim = env.action_space.n
print(f"Observation space: {env.observation_space}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One extra embedding row: class id 0 is reserved for padding.
num_classes = len(get_obj_classes()) + 1
dqn = DQN_OC(num_classes=num_classes, hidden_dim=64, action_dim=action_dim).to(device)
target_dqn = DQN_OC(num_classes=num_classes, hidden_dim=64, action_dim=action_dim).to(device)
target_dqn.load_state_dict(dqn.state_dict())
# The target network is only ever used for inference; keep dropout disabled.
target_dqn.eval()

lr = 0.0001
weight_decay = 1e-5
replay_buffer_size = 10000
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
replay_buffer = ReplayBuffer(replay_buffer_size)

num_train_iterations = 1000000
batch_size = 32
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.05
epsilon_decay = 0.999885  # multiplicative decay applied once per environment step
target_update_freq = 10000  # steps between target-network syncs / checkpoints / eval videos
rewards_list = []

# Used as the wandb run name and as the output directory for checkpoints/videos.
project_name = 'dqn_oc'

wandb.require("core")
wandb.login()
wandb.init(
      # Set the project where this run will be logged
      project="frogger",
      # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
      name=project_name,
      # Track hyperparameters and run metadata
      config={
      "lr": lr,
      "weight_decay": weight_decay,
      "batch_size": batch_size,
      "gamma": gamma,
      "epsilon": epsilon,
      "epsilon_min": epsilon_min,
      "epsilon_decay": epsilon_decay,
      "replay_buffer_size": replay_buffer_size,
      "variant": project_name,
      "num_train_iterations": num_train_iterations,
      "target_update_freq": target_update_freq,
      "seed": seed,
      })

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


Observation space: Box(0, 255, (210, 160, 3), uint8)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Main DQN training loop: one environment step and one gradient step per iteration.
os.makedirs(f'{project_name}/train', exist_ok=True)

# Fresh optimiser state, so re-running this cell does not reuse stale AdamW moments.
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
state, info = env.reset()
state_objs = extract_objs(env, return_tensor=True)
total_loss = 0      # summed training loss over the current episode
total_reward = 0    # undiscounted return of the current episode so far

for iteration in tqdm(range(num_train_iterations+1)):
    action = select_action(env, dqn, state_objs, epsilon)
    next_state, reward, terminated, truncated, info = env.step(action)
    next_state_objs = extract_objs(env, return_tensor=True)
    total_reward += reward
    # Record the running return *before* any episode reset below; otherwise
    # the reward earned on the final step of an episode is never recorded.
    rewards_list.append(total_reward)

    # Only true termination is stored as `done`: a time-limit truncation
    # should still bootstrap from the next state.
    replay_buffer.push(state_objs, action, reward, next_state_objs, terminated)

    if terminated or truncated:
        wandb.log({'train/loss': total_loss, 'train/episode_reward': total_reward})
        state, info = env.reset()
        state_objs = extract_objs(env, return_tensor=True)
        total_loss = 0
        total_reward = 0
    else:
        state = next_state
        state_objs = next_state_objs

    loss = train(dqn, target_dqn, replay_buffer, optimizer, batch_size, gamma)
    total_loss += loss
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    if iteration % target_update_freq == 0:
        # Sync the target network, checkpoint, and record a greedy evaluation episode.
        target_dqn.load_state_dict(dqn.state_dict())
        torch.save(dqn.state_dict(), f"{project_name}/train/frogger_dqn_iter_{iteration}.pth")
        # Distinct names so the evaluation results do not overwrite the loop's `reward`.
        eval_reward, eval_length, eval_time = record_video_oc(select_action=partial(select_action, model=dqn, epsilon=0), video_folder=f"{project_name}/videos", video_name=f"train_iter_{iteration}")
        wandb.log({'train/reward': eval_reward, 'train/length': eval_length, 'train/time': eval_time})


  0%|          | 0/1000001 [00:00<?, ?it/s]

iteration 0 action 1
iteration 1 action 1
iteration 2 action 1
iteration 3 action 1
iteration 4 action 1
iteration 5 action 1
iteration 6 action 1
iteration 7 action 1
iteration 8 action 1
iteration 9 action 1
iteration 10 action 1
iteration 11 action 1
iteration 12 action 1
iteration 13 action 1
iteration 14 action 1
iteration 15 action 1
iteration 16 action 1
iteration 17 action 1
iteration 18 action 1
iteration 19 action 1
iteration 20 action 1
iteration 21 action 1
iteration 22 action 1
iteration 23 action 1
iteration 24 action 1
iteration 25 action 1
iteration 26 action 1
iteration 27 action 1
iteration 28 action 1
iteration 29 action 1
iteration 30 action 1
iteration 31 action 1
iteration 32 action 1
iteration 33 action 1
iteration 34 action 1
iteration 35 action 1
iteration 36 action 1
iteration 37 action 1
iteration 38 action 1
iteration 39 action 1
iteration 40 action 1
iteration 41 action 1
iteration 42 action 1
iteration 43 action 1
iteration 44 action 1
iteration 45 action 

  0%|          | 0/1000001 [00:37<?, ?it/s]

MoviePy - Building video dqn_oc/videos/train_iter_0.mp4.
MoviePy - Writing video dqn_oc/videos/train_iter_0.mp4



  0%|          | 25/1000001 [01:08<543:26:28,  1.96s/it] 

MoviePy - Done !
MoviePy - video ready dqn_oc/videos/train_iter_0.mp4


  1%|          | 9997/1000001 [03:31<4:00:33, 68.59it/s]

iteration 0 action 1
iteration 1 action 1
iteration 2 action 1
iteration 3 action 1
iteration 4 action 1
iteration 5 action 1
iteration 6 action 1
iteration 7 action 1
iteration 8 action 1
iteration 9 action 1
iteration 10 action 1
iteration 11 action 1
iteration 12 action 1
iteration 13 action 1
iteration 14 action 1
iteration 15 action 2
iteration 16 action 1
iteration 17 action 2
iteration 18 action 1
iteration 19 action 1
iteration 20 action 1
iteration 21 action 1
iteration 22 action 1
iteration 23 action 1
iteration 24 action 4
iteration 25 action 1
iteration 26 action 2
iteration 27 action 1
iteration 28 action 1
iteration 29 action 1
iteration 30 action 1
iteration 31 action 1
iteration 32 action 1
iteration 33 action 1
iteration 34 action 1
iteration 35 action 1
iteration 36 action 1
iteration 37 action 1
iteration 38 action 1
iteration 39 action 1
iteration 40 action 1
iteration 41 action 1
iteration 42 action 1
iteration 43 action 1
iteration 44 action 1
iteration 45 action 

  1%|          | 9997/1000001 [03:36<4:00:33, 68.59it/s]

iteration 296 action 1
MoviePy - Building video dqn_oc/videos/train_iter_10000.mp4.
MoviePy - Writing video dqn_oc/videos/train_iter_10000.mp4



  1%|          | 10011/1000001 [03:41<78:05:01,  3.52it/s] 

MoviePy - Done !
MoviePy - video ready dqn_oc/videos/train_iter_10000.mp4


  2%|▏         | 19996/1000001 [06:12<4:10:26, 65.22it/s] 

iteration 0 action 1
iteration 1 action 1
iteration 2 action 4
iteration 3 action 2
iteration 4 action 1
iteration 5 action 4
iteration 6 action 1
iteration 7 action 3
iteration 8 action 3
iteration 9 action 2
iteration 10 action 1
iteration 11 action 2
iteration 12 action 2
iteration 13 action 3
iteration 14 action 4
iteration 15 action 1
iteration 16 action 1
iteration 17 action 3
iteration 18 action 2
iteration 19 action 1
iteration 20 action 2
iteration 21 action 0
iteration 22 action 4
iteration 23 action 1
iteration 24 action 2
iteration 25 action 1
iteration 26 action 2
iteration 27 action 1
iteration 28 action 1
iteration 29 action 1
iteration 30 action 1
iteration 31 action 1
iteration 32 action 3
iteration 33 action 1
iteration 34 action 1
iteration 35 action 1
iteration 36 action 1
iteration 37 action 1
iteration 38 action 4
iteration 39 action 4
iteration 40 action 2
iteration 41 action 2
iteration 42 action 1
iteration 43 action 1
iteration 44 action 3
iteration 45 action 

  2%|▏         | 19996/1000001 [06:17<4:10:26, 65.22it/s]

iteration 259 action 1
iteration 260 action 1
iteration 261 action 4
iteration 262 action 1
iteration 263 action 4
iteration 264 action 1
iteration 265 action 1
iteration 266 action 1
iteration 267 action 1
iteration 268 action 2
iteration 269 action 1
MoviePy - Building video dqn_oc/videos/train_iter_20000.mp4.
MoviePy - Writing video dqn_oc/videos/train_iter_20000.mp4



  2%|▏         | 20009/1000001 [06:21<77:58:00,  3.49it/s] 

MoviePy - Done !
MoviePy - video ready dqn_oc/videos/train_iter_20000.mp4


  3%|▎         | 29999/1000001 [08:54<4:17:35, 62.76it/s] 

iteration 0 action 4
iteration 1 action 1
iteration 2 action 1
iteration 3 action 4
iteration 4 action 3
iteration 5 action 1
iteration 6 action 1
iteration 7 action 3
iteration 8 action 1
iteration 9 action 1
iteration 10 action 1
iteration 11 action 1
iteration 12 action 4
iteration 13 action 1
iteration 14 action 1
iteration 15 action 1
iteration 16 action 3
iteration 17 action 1
iteration 18 action 0
iteration 19 action 1
iteration 20 action 1
iteration 21 action 1
iteration 22 action 1
iteration 23 action 3
iteration 24 action 1
iteration 25 action 1
iteration 26 action 1
iteration 27 action 0
iteration 28 action 4
iteration 29 action 1
iteration 30 action 1
iteration 31 action 1
iteration 32 action 1
iteration 33 action 1
iteration 34 action 1
iteration 35 action 1
iteration 36 action 4
iteration 37 action 1
iteration 38 action 1
iteration 39 action 1
iteration 40 action 1
iteration 41 action 1
iteration 42 action 1
iteration 43 action 1
iteration 44 action 1
iteration 45 action 

  3%|▎         | 29999/1000001 [08:58<4:17:35, 62.76it/s]

iteration 228 action 4
iteration 229 action 1
iteration 230 action 1
iteration 231 action 4
iteration 232 action 3
iteration 233 action 3
iteration 234 action 1
MoviePy - Building video dqn_oc/videos/train_iter_30000.mp4.
MoviePy - Writing video dqn_oc/videos/train_iter_30000.mp4



  3%|▎         | 30013/1000001 [09:02<64:24:26,  4.18it/s]

MoviePy - Done !
MoviePy - video ready dqn_oc/videos/train_iter_30000.mp4


  4%|▍         | 39996/1000001 [11:35<4:09:36, 64.10it/s] 

iteration 0 action 1
iteration 1 action 3
iteration 2 action 1
iteration 3 action 2
iteration 4 action 1
iteration 5 action 1
iteration 6 action 1
iteration 7 action 1
iteration 8 action 1
iteration 9 action 1
iteration 10 action 1
iteration 11 action 1
iteration 12 action 1
iteration 13 action 3
iteration 14 action 3
iteration 15 action 1
iteration 16 action 1
iteration 17 action 1
iteration 18 action 2
iteration 19 action 3
iteration 20 action 4
iteration 21 action 1
iteration 22 action 1
iteration 23 action 1
iteration 24 action 1
iteration 25 action 3
iteration 26 action 3
iteration 27 action 2
iteration 28 action 1
iteration 29 action 1
iteration 30 action 1
iteration 31 action 1
iteration 32 action 1
iteration 33 action 1
iteration 34 action 1
iteration 35 action 1
iteration 36 action 1
iteration 37 action 1
iteration 38 action 1
iteration 39 action 1
iteration 40 action 2
iteration 41 action 4
iteration 42 action 1
iteration 43 action 1
iteration 44 action 1
iteration 45 action 

  4%|▍         | 39996/1000001 [11:41<4:09:36, 64.10it/s]

iteration 326 action 1
iteration 327 action 1
iteration 328 action 1
iteration 329 action 1
iteration 330 action 1
iteration 331 action 1
iteration 332 action 3
MoviePy - Building video dqn_oc/videos/train_iter_40000.mp4.
MoviePy - Writing video dqn_oc/videos/train_iter_40000.mp4



  4%|▍         | 40010/1000001 [11:46<88:34:44,  3.01it/s] 

MoviePy - Done !
MoviePy - video ready dqn_oc/videos/train_iter_40000.mp4


  5%|▍         | 49996/1000001 [14:19<4:02:01, 65.42it/s] 

iteration 0 action 1
iteration 1 action 2
iteration 2 action 4
iteration 3 action 2
iteration 4 action 2
iteration 5 action 3
iteration 6 action 1
iteration 7 action 3
iteration 8 action 1
iteration 9 action 4
iteration 10 action 2
iteration 11 action 1
iteration 12 action 2
iteration 13 action 1
iteration 14 action 2
iteration 15 action 1
iteration 16 action 1
iteration 17 action 1
iteration 18 action 2
iteration 19 action 2
iteration 20 action 1
iteration 21 action 0
iteration 22 action 2
iteration 23 action 3
iteration 24 action 1
iteration 25 action 2
iteration 26 action 1
iteration 27 action 1
iteration 28 action 3
iteration 29 action 1
iteration 30 action 1
iteration 31 action 1
iteration 32 action 1
iteration 33 action 1
iteration 34 action 4
iteration 35 action 1
iteration 36 action 1
iteration 37 action 3
iteration 38 action 2
iteration 39 action 1
iteration 40 action 2
iteration 41 action 1
iteration 42 action 1
iteration 43 action 1
iteration 44 action 1
iteration 45 action 

  5%|▍         | 49996/1000001 [14:23<4:02:01, 65.42it/s]

iteration 235 action 1
MoviePy - Building video dqn_oc/videos/train_iter_50000.mp4.
MoviePy - Writing video dqn_oc/videos/train_iter_50000.mp4



  5%|▌         | 50010/1000001 [14:27<63:32:17,  4.15it/s]

MoviePy - Done !
MoviePy - video ready dqn_oc/videos/train_iter_50000.mp4


  6%|▌         | 60000/1000001 [17:01<4:05:11, 63.90it/s] 

iteration 0 action 4
iteration 1 action 1
iteration 2 action 1
iteration 3 action 0
iteration 4 action 4
iteration 5 action 1
iteration 6 action 1
iteration 7 action 1
iteration 8 action 1
iteration 9 action 1
iteration 10 action 4
iteration 11 action 4
iteration 12 action 3
iteration 13 action 1
iteration 14 action 1
iteration 15 action 1
iteration 16 action 1
iteration 17 action 0
iteration 18 action 3
iteration 19 action 0
iteration 20 action 1
iteration 21 action 3
iteration 22 action 1
iteration 23 action 1
iteration 24 action 1
iteration 25 action 1
iteration 26 action 3
iteration 27 action 1
iteration 28 action 3
iteration 29 action 1
iteration 30 action 3
iteration 31 action 3
iteration 32 action 1
iteration 33 action 1
iteration 34 action 3
iteration 35 action 3
iteration 36 action 1
iteration 37 action 3
iteration 38 action 0
iteration 39 action 1
iteration 40 action 1
iteration 41 action 3
iteration 42 action 1
iteration 43 action 1
iteration 44 action 1
iteration 45 action 

  6%|▌         | 60000/1000001 [17:04<4:05:11, 63.90it/s]

iteration 207 action 3
iteration 208 action 4
MoviePy - Building video dqn_oc/videos/train_iter_60000.mp4.
MoviePy - Writing video dqn_oc/videos/train_iter_60000.mp4



  6%|▌         | 60007/1000001 [17:07<78:11:33,  3.34it/s]

MoviePy - Done !
MoviePy - video ready dqn_oc/videos/train_iter_60000.mp4


  6%|▌         | 62483/1000001 [17:46<4:26:37, 58.60it/s] 


KeyboardInterrupt: 