In [1]:
import gymnasium as gym
import ale_py  # Ensure Atari environments work
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import collections
import random
from collections import deque
import torch.nn.functional as F
import cv2
from tqdm import tqdm
import wandb
from functools import partial
import os

from utils import get_env, wrap_recording, load_demonstrations, record_video, layer_init, load_llm_demonstrations

In [2]:
class DQN_CNN(nn.Module):
    """Convolutional Q-network with the standard Nature-DQN layer layout.

    Maps a stack of image frames to one Q-value per discrete action. Pixel
    values are assumed to be in [0, 255] and are rescaled to [0, 1] in `forward`.
    """

    def __init__(self, input_channels, action_dim, input_size=(84, 84)):
        """
        Args:
            input_channels: number of channels per observation (e.g. stacked frames).
            action_dim: number of discrete actions, i.e. size of the output layer.
            input_size: (height, width) of each frame. The flattened conv feature
                size is derived from it instead of being hard-coded, so frames other
                than 84x84 work; the default reproduces the original 64*7*7 features.
        """
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),  # 84x84 -> 32x20x20
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # -> 64x9x9
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # -> 64x7x7
            nn.ReLU(),
        )

        # Probe the conv stack once with a dummy input to size the first FC layer.
        with torch.no_grad():
            n_features = self.conv_layers(torch.zeros(1, input_channels, *input_size)).numel()

        self.fc_layers = nn.Sequential(
            nn.Linear(n_features, 512),  # flattened CNN features
            nn.ReLU(),
            nn.Linear(512, action_dim),  # one Q-value per action
        )

    def forward(self, x):
        """Return Q-values of shape (batch, action_dim) for a batch of frames."""
        x = self.conv_layers(x / 255.0)
        x = torch.flatten(x, start_dim=1)
        return self.fc_layers(x)

In [3]:
# Based on https://github.com/k4ntz/OC_Atari/blob/5386289258c6f240bd107dea0fe41512262281b2/ocatari/utils.py#L122
class DQN_MLP(nn.Module):
    """MLP Q-network over a stack of state vectors.

    Every frame's vector goes through the same two Linear layers
    (input_size -> 512 -> 256). The per-frame embeddings are then flattened
    together, and a hidden layer plus a linear head give one Q-value per action.
    The `Flatten` followed by `Linear(256 * framestack, ...)` implies inputs
    shaped like (batch, framestack, input_size).
    """

    def __init__(self, input_size, framestack, action_dim):
        super().__init__()
        per_frame_encoder = [
            layer_init(nn.Linear(input_size, 512)),
            nn.ReLU(),
            layer_init(nn.Linear(512, 256)),
            nn.ReLU(),
        ]
        stack_merger = [
            nn.Flatten(),
            layer_init(nn.Linear(256 * framestack, 512)),
            nn.ReLU(),
        ]
        # A single Sequential keeps the original module indices (and state_dict keys).
        self.network = nn.Sequential(*per_frame_encoder, *stack_merger)
        # Small init scale on the Q-value head.
        self.fc_layer = layer_init(nn.Linear(512, action_dim), std=0.01)

    def forward(self, x):
        """Scale inputs by 1/255 and return per-action Q-values."""
        return self.fc_layer(self.network(x / 255.0))

In [4]:
def select_action(env, model, state, epsilon, device=None):
    """Pick an action epsilon-greedily from `model`'s Q-values.

    Args:
        env: environment whose `action_space.sample()` supplies exploratory actions.
        model: Q-network mapping a batch of observations to per-action Q-values.
        state: a single (unbatched) observation; anything `np.array` accepts.
        epsilon: probability of returning `env.action_space.sample()` instead
            of the greedy action.
        device: device to run `model` on. Defaults to the device of the model's
            first parameter (CPU if it has none), so this no longer relies on a
            global `device` defined in a later cell.

    Returns:
        The greedy action index as an int, or the sampled random action.
    """
    if random.random() < epsilon:
        return env.action_space.sample()  # explore

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    batch = torch.as_tensor(np.array(state), dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        return model(batch).argmax().item()

def train(model, target_model, buffer, optimizer, batch_size, gamma, supervised, use_demo,
          margin=0.8, device=None, max_grad_norm=1.0):
    """Run one gradient step of DQN training, optionally with a DQfD-style margin loss.

    TD loss: MSE between Q(s, a) and r + gamma * max_a' Q_target(s', a'), with the
    bootstrap term zeroed for transitions flagged done. When `supervised` is set,
    a large-margin classification loss is added:

        mean_i[ max_a (Q(s_i, a) + l(a_i, a)) - Q(s_i, a_i) ],
        l(a_i, a) = 0 if a == a_i else margin,

    which pushes the demonstrated action's Q-value above all others by `margin`.

    Args:
        model: online Q-network being optimized.
        target_model: Q-network used for the bootstrap targets (not optimized here).
        buffer: replay buffer exposing `size()` and `sample(batch_size, percent_from_demo=...)`.
        optimizer: optimizer over `model`'s parameters.
        batch_size: transitions per update.
        gamma: discount factor.
        supervised: sample only demonstrations (percent_from_demo=1) and add the margin loss.
        use_demo: when not supervised, draw half of the batch from demonstrations.
        margin: margin for non-demonstrated actions in the supervised loss.
        device: device for the sampled batch; defaults to the device of `model`'s
            first parameter (previously a global `device` was required).
        max_grad_norm: total gradient-norm clipping threshold.

    Returns:
        The loss as a Python float, or 0 when not supervised and the buffer holds
        fewer than `batch_size` transitions.
    """
    if not supervised and buffer.size() < batch_size:
        return 0

    if supervised:
        percent_from_demo = 1
    elif use_demo:
        percent_from_demo = 0.5
    else:
        percent_from_demo = 0

    if device is None:
        device = next(model.parameters()).device

    states, actions, rewards, next_states, dones = (
        tensor.to(device)
        for tensor in buffer.sample(batch_size, percent_from_demo=percent_from_demo)
    )

    # Q(s, a) for the actions actually taken.
    q = model(states)
    q_values = q.gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrap targets carry no gradient.
    with torch.no_grad():
        next_q_values = target_model(next_states).max(dim=1).values
        next_q_values = next_q_values.masked_fill(dones.to(torch.bool), 0.0)  # no bootstrap past terminal
        target_q_values = rewards + gamma * next_q_values

    loss = F.mse_loss(q_values, target_q_values)

    if supervised:
        # Per-row margins: 0 at each row's demonstrated action, `margin` elsewhere.
        # (`l[:, actions] = 0` zeroed every action seen anywhere in the batch for every row.)
        margins = torch.full_like(q, margin)
        margins.scatter_(1, actions.unsqueeze(1), 0.0)
        supervised_loss = ((q + margins).max(dim=1).values - q_values).mean()
        loss = loss + supervised_loss

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()


In [5]:
class ReplayBuffer:
    """Fixed-capacity experience replay that can mix in demonstration transitions.

    Environment transitions live in a list used as a ring buffer: once
    `capacity` entries exist, each push overwrites the oldest entry. A list is
    used instead of `collections.deque` because `random.sample` indexes into the
    population, and deque indexing is O(n) away from its ends. Demonstrations are
    stored separately and are never evicted.
    """

    def __init__(self, capacity, demonstrations=None):
        """
        Args:
            capacity: maximum number of environment transitions kept (must be > 0).
            demonstrations: optional sequence of (state, action, reward, next_state, done)
                tuples sampled alongside environment transitions; None means none.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.buffer = []
        self._next_index = 0  # slot overwritten by the next push once the buffer is full
        self.demonstrations = list(demonstrations) if demonstrations is not None else []

    def push(self, state, action, reward, next_state, done):
        """Store one transition, evicting the oldest one when at capacity."""
        # float() rather than int(): int() silently truncated fractional rewards.
        transition = (state, action, float(reward), next_state, bool(done))
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self._next_index] = transition
        self._next_index = (self._next_index + 1) % self.capacity

    def sample(self, batch_size, percent_from_demo=0.5):
        """Sample a batch, drawing about `percent_from_demo` of it from demonstrations.

        With no demonstrations, the whole batch comes from environment transitions.

        Returns:
            (states, actions, rewards, next_states, dones) tensors; actions are
            int64, everything else float32.

        Raises:
            ValueError: (from `random.sample`) if a source holds fewer items than requested.
        """
        if self.demonstrations:
            n_from_env = int(batch_size * (1 - percent_from_demo))
            batch = random.sample(self.buffer, n_from_env)
            batch += random.sample(self.demonstrations, batch_size - n_from_env)
            random.shuffle(batch)
        else:
            batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        return (
            torch.as_tensor(np.array(states), dtype=torch.float32),
            torch.as_tensor(np.array(actions), dtype=torch.long),
            torch.as_tensor(np.array(rewards), dtype=torch.float32),
            torch.as_tensor(np.array(next_states), dtype=torch.float32),
            torch.as_tensor(np.array(dones), dtype=torch.float32),
        )

    def size(self):
        """Number of environment transitions currently stored (excludes demonstrations)."""
        return len(self.buffer)

In [6]:
oc = False
episodic = True
framestack = 4
use_demo = True
supervised = True
seed = 0
project_name = f'dqn{"_oc" if oc else ""}{"_demo" if use_demo else ""}{"_sup" if supervised else ""}'

# Seed Python, NumPy and torch (torch.manual_seed also seeds CUDA) for repeatable runs.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

env = get_env(oc=oc, framestack=framestack, episodic=episodic)

# Check Action / State space. Seeding reset() and the action space also makes
# environment resets and epsilon-greedy random actions reproducible.
obs, info = env.reset(seed=seed)
env.action_space.seed(seed)

action_dim = env.action_space.n

print(f"Observation space: {env.observation_space}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_q_network():
    """Return a freshly initialised Q-network for the configured observation type."""
    if oc:
        return DQN_MLP(input_size=len(env.ns_state), framestack=framestack, action_dim=action_dim)
    # One input channel per stacked frame (was hard-coded to 4, ignoring `framestack`).
    return DQN_CNN(framestack, action_dim)


dqn = build_q_network().to(device)
target_dqn = build_q_network().to(device)
target_dqn.load_state_dict(dqn.state_dict())

lr = 0.0001
weight_decay = 1e-5
replay_buffer_size = 1000000
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)

# Default to no demonstrations; otherwise `demonstrations` is undefined (NameError)
# below whenever both use_demo and supervised are False.
demonstrations = []
if use_demo or supervised:
    demonstrations = load_llm_demonstrations(oc=oc)
    print(f'Loaded {len(demonstrations)} demonstration steps')
replay_buffer = ReplayBuffer(capacity=replay_buffer_size, demonstrations=demonstrations)

num_pretraining_iterations = 30000
num_train_iterations = 1000000
batch_size = 32
gamma = 0.99
if supervised:
    # Supervised pretraining draws entire batches from the demonstrations.
    if len(demonstrations) < batch_size:
        raise ValueError(
            f"supervised pretraining needs at least {batch_size} demonstration steps, "
            f"got {len(demonstrations)}"
        )
    # Near-greedy from the start, since the supervised cell below pretrains the network.
    epsilon = 0.01
    epsilon_min = 0.01
    epsilon_decay = 1
else:
    epsilon = 1.0
    epsilon_min = 0.01
    epsilon_decay = 0.999954
target_update_freq = 10000
rewards_list = []

wandb.require("core")
wandb.login()
wandb.init(
    # Set the project where this run will be logged
    project="frogger",
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name=project_name,
    # Track hyperparameters and run metadata
    config={
        "lr": lr,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "gamma": gamma,
        "epsilon": epsilon,
        "epsilon_min": epsilon_min,
        "epsilon_decay": epsilon_decay,
        "replay_buffer_size": replay_buffer_size,
        "variant": project_name,
        "num_train_iterations": num_train_iterations,
        "target_update_freq": target_update_freq,
        "num_pretraining_iterations": num_pretraining_iterations,
        "framestack": framestack,
        "oc": oc,
        "episodic": episodic,
        "use_demo": use_demo,
        "supervised": supervised,
        "seed": seed,
    },
)

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


Observation space: Box(0, 255, (84, 336), uint8)




total actions: 280
total reward: 35.0
total length: 280
Loaded 280 demonstration steps


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
if supervised:
    os.makedirs(f'{project_name}/pretrained', exist_ok=True)

    for iteration in tqdm(range(num_pretraining_iterations+1)):
        # One supervised step: TD loss plus margin loss on demonstration batches.
        loss = train(
            dqn, target_dqn, replay_buffer, optimizer, batch_size, gamma,
            supervised=True, use_demo=use_demo,
        )
        wandb.log({"pretrain/loss": loss})

        if iteration % target_update_freq != 0:
            continue

        # Every `target_update_freq` steps: sync the target network, record a
        # greedy (epsilon=0) rollout on a fresh env, and checkpoint the weights.
        target_dqn.load_state_dict(dqn.state_dict())
        reward, length, time = record_video(
            env=get_env(oc=oc, framestack=framestack),
            select_action=partial(select_action, model=dqn, epsilon=0),
            video_folder=f"{project_name}/videos",
            episode_trigger=lambda x: True,
            name_prefix=f"pretrain_iter_{iteration}",
        )
        wandb.log({'pretrain/reward': reward, 'pretrain/length': length, 'pretrain/time': time})
        torch.save(dqn.state_dict(), f'{project_name}/pretrained/{iteration}.pth')

  logger.warn(
  logger.warn(
100%|██████████| 30001/30001 [04:03<00:00, 123.06it/s]


In [8]:
os.makedirs(f'{project_name}/train', exist_ok=True)

# Fresh optimizer for the RL phase; any pretraining optimizer state is discarded.
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
state, info = env.reset()
total_loss = 0
total_reward = 0

for iteration in tqdm(range(num_train_iterations+1)):
    action = select_action(env, dqn, state, epsilon)
    next_state, reward, terminated, truncated, info = env.step(action)
    total_reward += reward

    # Store `terminated`, not `terminated or truncated`: a time-limit truncation
    # still has future value, so its bootstrap target must not be zeroed.
    replay_buffer.push(state, action, reward, next_state, terminated)

    if terminated or truncated:
        # Episode finished: log its totals (the reward was previously never logged).
        wandb.log({'train/episode_loss': total_loss, 'train/episode_reward': total_reward})
        rewards_list.append(total_reward)  # one entry per completed episode
        state, info = env.reset()
        total_loss = 0
        total_reward = 0
    else:
        state = next_state

    loss = train(dqn, target_dqn, replay_buffer, optimizer, batch_size, gamma, supervised=False, use_demo=use_demo)
    wandb.log({'train/loss': loss, 'train/epsilon': epsilon})
    total_loss += loss

    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    if iteration % target_update_freq == 0:
        # Sync the target network, checkpoint, and record a greedy evaluation rollout.
        target_dqn.load_state_dict(dqn.state_dict())
        torch.save(dqn.state_dict(), f"{project_name}/train/frogger_iter_{iteration}.pth")
        reward, length, time = record_video(
            env=get_env(oc=oc, framestack=framestack),
            select_action=partial(select_action, model=dqn, epsilon=0),
            video_folder=f"{project_name}/videos",
            episode_trigger=lambda x: True,
            name_prefix=f"train_iter_{iteration}",
        )
        wandb.log({'train/reward': reward, 'train/length': length, 'train/time': time})


 43%|████▎     | 434857/1000001 [1:51:01<8:20:21, 18.82it/s] 