In [1]:
import gymnasium as gym
import ale_py  # Ensure Atari environments work
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import collections
import random
from gymnasium.wrappers import GrayscaleObservation, ResizeObservation, RecordVideo, RecordEpisodeStatistics
from collections import deque
import torch.nn.functional as F
import cv2
from tqdm import tqdm
import wandb

In [2]:
import numpy as np
from collections import deque
from gymnasium import spaces

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)

    def reset(self, seed=None, options=None):
        ob, info = self.env.reset(seed=seed, options=options)
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob(), info

    def step(self, action):
        ob, reward, done, truncated, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, truncated, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))

class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0

class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.

        This object should only be converted to numpy array before being passed to the model.

        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.stack(self._frames)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

    def count(self):
        frames = self._force()
        return frames.shape[frames.ndim - 1]

    def frame(self, i):
        return self._force()[..., i]


def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari.
    """
    if episode_life:
        env = EpisodicLifeEnv(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, 4)
    return env

In [3]:
class DQN_CNN(nn.Module):
    def __init__(self, input_channels, action_dim):
        super(DQN_CNN, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),  # Output: (32, 20, 20)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # Output: (64, 9, 9)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # Output: (64, 7, 7)
            nn.ReLU()
        )


        self.fc_layers = nn.Sequential(
            nn.Linear(64*7*7, 512),  # Flattened CNN features
            nn.ReLU(),
            nn.Linear(512, action_dim)  # Output Q-values for each action
        )

    def forward(self, x):
        x = self.conv_layers(x)

        x = x.view(x.size(0), -1)  # Flatten

        x = self.fc_layers(x)
        return x

In [4]:
def select_action(env, model, state, epsilon):
    if random.random() < epsilon:
        return env.action_space.sample()  # Random action (exploration)

    state = torch.FloatTensor(state).unsqueeze(0) / 255.0  # Normalize pixels
    state = state.to(device)
    with torch.no_grad():
        return model(state).argmax().item()

def train(model, target_model, buffer, optimizer, batch_size, gamma, use_supervised_loss=False):
    # Sample batch from experience replay
    states, actions, rewards, next_states, dones = buffer.sample(batch_size)

    states = states.to(device)
    actions = actions.to(device)
    rewards = rewards.to(device)
    next_states = next_states.to(device)
    dones = dones.to(device)

    # Compute Q-values for current states
    q = model(states)
    # print('q.shape:', q.shape)
    q_values = q.gather(1, actions.unsqueeze(1)).squeeze(1)  # Select Q-values of taken actions

    # Compute next Q-values from the target network
    next_q_values = target_model(next_states).max(1)[0].detach()  # Max Q-value of next state

    dones = dones.to(torch.bool)
    # Zero next_q_values for terminal states
    next_q_values[dones] = 0.0

    # Compute target Q-values
    target_q_values = rewards + gamma * next_q_values

    # Compute loss using Huber loss (smooth_l1_loss)
    dq_loss = F.smooth_l1_loss(q_values, target_q_values.detach())

    if use_supervised_loss:
        l = torch.full_like(q, 0.8)
        l[:, actions] = 0
        # print('actions', actions)
        # print('l', l)
        # print('q', q)
        # print('q_values', q_values)
        # print('q_values.shape:', q_values.shape)
        supervised_loss = torch.mean((q + l).max(dim=-1)[0] - q_values)

    loss = dq_loss + supervised_loss

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
    optimizer.step()
    if use_supervised_loss:
        return dq_loss.item(), supervised_loss.item()
    else:
        return loss.item()


In [5]:
class ReplayBuffer:
    def __init__(self, capacity, demonstrations):
        self.buffer = collections.deque(maxlen=capacity)
        self.demonstrations = demonstrations

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((
            state,
            action,
            int(reward),
            next_state,
            bool(done)
        ))

    def sample(self, batch_size):
        if len(self.buffer) < batch_size // 2:
            batch = random.sample(self.demonstrations, batch_size)
        else:
            batch = random.sample(self.buffer, batch_size // 2)
            batch += random.sample(self.demonstrations, batch_size // 2)
            random.shuffle(batch)
        state, action, reward, next_state, done = zip(*batch)

        return (
            torch.FloatTensor(np.array(state)) / 255.0,  # Normalize pixels
            torch.LongTensor(action),
            torch.FloatTensor(reward),
            torch.FloatTensor(np.array(next_state)) / 255.0,
            torch.FloatTensor(done)
        )

    def size(self):
        return len(self.buffer)

In [6]:
import pickle

# load expert demonstrations
trace = []
for i in range(2,3):
    with open(f'traces/trace_{i}.pkl', 'rb') as f:
        trace += pickle.load(f)

def process_frame(frame):
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    frame = cv2.resize(
        frame, (84, 84), interpolation=cv2.INTER_AREA
    )
    return frame

demonstrations = []
for obj in trace:
    demonstrations.append((process_frame(obj['state']), obj['action'], obj['reward'], process_frame(obj['next_state']), obj['done']))

demonstrations = [(
    np.stack([demonstrations[j][0] for j in range(i-3,i+1)]),
    demonstrations[i][1],
    demonstrations[i][2],
    np.stack([demonstrations[j][3] for j in range(i-3,i+1)]),
    demonstrations[i][4])
    for i in range(4, len(demonstrations))]


In [7]:
def get_env():
    env = gym.make("ALE/Frogger-v5", render_mode="rgb_array")  # Create Atari env
    env = GrayscaleObservation(env, keep_dim=False)
    env = ResizeObservation(env, (84, 84))
    env = wrap_deepmind(env, episode_life=False, clip_rewards=False, frame_stack=True, scale=False)
    return env

In [8]:
def record_video(name_prefix):
    env = get_env()
    env = RecordVideo(env, video_folder="dqfd/videos", episode_trigger=lambda x: True, name_prefix=name_prefix)
    env = RecordEpisodeStatistics(env, buffer_length=1)
    state, info = env.reset()
    while True:
        action = select_action(env, dqn, state, epsilon)
        next_state, reward, done, truncated, info = env.step(action)
        state = next_state
        if done:
            break
    env.close()
    return env.return_queue[0], env.length_queue[0], env.time_queue[0]

In [9]:
# Create the Atari environment
env = get_env()

# Check Action / State space
obs, info = env.reset()

action_dim = env.action_space.n
print(f"Observation space: {env.observation_space}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dqn = DQN_CNN(4, action_dim).to(device)
target_dqn = DQN_CNN(4, action_dim).to(device)
target_dqn.load_state_dict(dqn.state_dict())

lr = 0.0001
weight_decay = 1e-5
replay_buffer_size = 10000
optimizer = optim.AdamW(dqn.parameters(), lr=lr, weight_decay=weight_decay)
replay_buffer = ReplayBuffer(replay_buffer_size, demonstrations)

num_pretraining_iterations = 100000
num_train_iterations = 1000
batch_size = 32
gamma = 0.99
epsilon = 0.01
target_update_freq = 10000
rewards_list = []

wandb.require("core")
wandb.login()
wandb.init(
      # Set the project where this run will be logged
      project="frogger",
      # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
      name=f"dqfd",
      # Track hyperparameters and run metadata
      config={
      "lr": lr,
      "weight_decay": weight_decay,
      "batch_size": batch_size,
      "gamma": gamma,
      "epsilon": epsilon,
      "replay_buffer_size": replay_buffer_size,
      "variant": "dqfd",
      "num_pretraining_iterations": num_pretraining_iterations,
      "num_train_iterations": num_train_iterations,
      "target_update_freq": target_update_freq,
      })

A.L.E: Arcade Learning Environment (version 0.10.2+c9d4b19)
[Powered by Stella]


Observation space: Box(0, 255, (84, 336), uint8)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [20]:
# Supervised pretraining
for iteration in tqdm(range(num_pretraining_iterations)):
    dq_loss, s_loss = train(dqn, target_dqn, replay_buffer, optimizer, batch_size, gamma, use_supervised_loss=True)
    wandb.log({"pretrain/loss": dq_loss + s_loss, "pretrain/dq_loss": dq_loss, "pretrain/supervised_loss": s_loss})
    if iteration % target_update_freq == 0:
        target_dqn.load_state_dict(dqn.state_dict())
    if iteration % 1000 == 0:
        print(f"DQ Loss {dq_loss:.5f} S Loss {s_loss:.5f}")
        reward, length, time = record_video(name_prefix=f"pretrain_iter_{iteration}")
        print(f'Episode total rewards: {reward}, lengths: {length}, time taken: {time}')
        wandb.log({'pretrain/reward': reward, 'pretrain/length': length, 'pretrain/time': time})


  0%|          | 0/100000 [00:00<?, ?it/s]

DQ Loss 0.03156 S Loss 0.76866


  state = torch.FloatTensor(state).unsqueeze(0) / 255.0  # Normalize pixels
  0%|          | 4/100000 [00:10<54:08:45,  1.95s/it] 

Episode total rewards: 4.0, lengths: 1329, time taken: 8.837131


  logger.warn(


DQ Loss 0.03860 S Loss 0.02591


  1%|          | 1004/100000 [00:59<17:10:05,  1.60it/s]

Episode total rewards: 4.0, lengths: 1257, time taken: 7.256957


  2%|▏         | 1998/100000 [01:43<59:53, 27.28it/s]   

DQ Loss 0.01150 S Loss 0.82132


  2%|▏         | 2003/100000 [01:45<5:23:14,  5.05it/s]

Episode total rewards: 9.0, lengths: 265, time taken: 1.607035


  3%|▎         | 2998/100000 [02:24<1:04:01, 25.25it/s]

DQ Loss 0.07003 S Loss 0.02982


  3%|▎         | 3004/100000 [02:37<25:27:13,  1.06it/s]

Episode total rewards: 3.0, lengths: 1753, time taken: 11.324477


  4%|▍         | 4000/100000 [03:17<1:11:29, 22.38it/s] 

DQ Loss 0.02147 S Loss 0.01311


  4%|▍         | 4003/100000 [03:20<9:36:00,  2.78it/s]

Episode total rewards: 8.0, lengths: 425, time taken: 2.554721


  5%|▌         | 5000/100000 [04:04<1:00:30, 26.17it/s]

DQ Loss 0.01908 S Loss 0.01625


  5%|▌         | 5007/100000 [04:08<7:50:25,  3.37it/s] 

Episode total rewards: 9.0, lengths: 625, time taken: 3.503478


  6%|▌         | 5998/100000 [04:46<1:16:08, 20.58it/s]

DQ Loss 0.01861 S Loss 0.05268


  6%|▌         | 6004/100000 [04:49<5:59:49,  4.35it/s]

Episode total rewards: 12.0, lengths: 377, time taken: 2.204538


  7%|▋         | 6999/100000 [05:23<58:54, 26.31it/s]  

DQ Loss 4.30416 S Loss 0.00227


  7%|▋         | 7005/100000 [05:27<6:25:52,  4.02it/s]

Episode total rewards: 12.0, lengths: 381, time taken: 2.518098


  8%|▊         | 7998/100000 [06:06<57:53, 26.49it/s]  

DQ Loss 2.23699 S Loss 0.00167


  8%|▊         | 8004/100000 [06:09<5:43:42,  4.46it/s]

Episode total rewards: 11.0, lengths: 381, time taken: 2.256722


  9%|▉         | 8999/100000 [06:47<54:23, 27.88it/s]  

DQ Loss 0.01358 S Loss 0.00004


  9%|▉         | 9005/100000 [06:50<5:35:09,  4.53it/s]

Episode total rewards: 8.0, lengths: 457, time taken: 2.232832


 10%|█         | 10000/100000 [07:28<48:01, 31.23it/s] 

DQ Loss 0.01290 S Loss 0.00922


 10%|█         | 10004/100000 [07:31<5:19:43,  4.69it/s]

Episode total rewards: 12.0, lengths: 358, time taken: 1.909942


 11%|█         | 10998/100000 [08:08<47:08, 31.47it/s]  

DQ Loss 2.02860 S Loss 0.01501


 11%|█         | 11005/100000 [08:11<4:38:54,  5.32it/s]

Episode total rewards: 17.0, lengths: 495, time taken: 2.343634


 12%|█▏        | 11997/100000 [08:41<44:07, 33.24it/s]  

DQ Loss 0.00943 S Loss 0.00579


 12%|█▏        | 12005/100000 [08:45<4:52:52,  5.01it/s]

Episode total rewards: 18.0, lengths: 621, time taken: 2.763642


 13%|█▎        | 13000/100000 [09:15<40:30, 35.80it/s]  

DQ Loss 1.30665 S Loss 0.00777


 13%|█▎        | 13008/100000 [09:19<5:32:05,  4.37it/s]

Episode total rewards: 8.0, lengths: 697, time taken: 3.32441


 14%|█▍        | 14000/100000 [09:49<41:28, 34.56it/s]  

DQ Loss 0.00378 S Loss 0.00040


 14%|█▍        | 14004/100000 [09:52<6:58:51,  3.42it/s]

Episode total rewards: 35.0, lengths: 632, time taken: 3.05099


 15%|█▍        | 14999/100000 [10:22<42:54, 33.01it/s]  

DQ Loss 0.00580 S Loss 0.00000


 15%|█▌        | 15003/100000 [10:26<6:30:18,  3.63it/s]

Episode total rewards: 16.0, lengths: 541, time taken: 2.843194


 16%|█▌        | 16000/100000 [10:57<41:46, 33.51it/s]  

DQ Loss 1.03982 S Loss 0.00000


 16%|█▌        | 16004/100000 [11:00<5:41:06,  4.10it/s]

Episode total rewards: 14.0, lengths: 500, time taken: 2.271989


 17%|█▋        | 16999/100000 [11:34<50:00, 27.66it/s]  

DQ Loss 0.00407 S Loss 0.00000


 17%|█▋        | 17006/100000 [11:38<6:07:00,  3.77it/s]

Episode total rewards: 79.0, lengths: 671, time taken: 3.314191


 18%|█▊        | 17999/100000 [12:09<45:11, 30.25it/s]  

DQ Loss 0.00804 S Loss 0.00138


 18%|█▊        | 18007/100000 [12:12<3:28:55,  6.54it/s]

Episode total rewards: 12.0, lengths: 385, time taken: 1.98006


 19%|█▉        | 18997/100000 [12:43<39:21, 34.30it/s]  

DQ Loss 1.40181 S Loss 0.01644


 19%|█▉        | 19005/100000 [12:46<3:49:03,  5.89it/s]

Episode total rewards: 18.0, lengths: 459, time taken: 2.250636


 20%|█▉        | 19999/100000 [13:16<40:51, 32.63it/s]  

DQ Loss 0.67936 S Loss 0.00000


 20%|██        | 20003/100000 [13:20<6:14:49,  3.56it/s]

Episode total rewards: 20.0, lengths: 603, time taken: 2.80124


 21%|██        | 20997/100000 [13:50<38:11, 34.48it/s]  

DQ Loss 0.03072 S Loss 0.00000


 21%|██        | 21004/100000 [13:53<4:43:46,  4.64it/s]

Episode total rewards: 30.0, lengths: 555, time taken: 2.841188


 22%|██▏       | 21997/100000 [14:23<37:50, 34.36it/s]  

DQ Loss 0.03814 S Loss 0.00039


 22%|██▏       | 22005/100000 [14:26<3:56:37,  5.49it/s]

Episode total rewards: 13.0, lengths: 493, time taken: 2.476319


 23%|██▎       | 22997/100000 [14:56<43:41, 29.37it/s]  

DQ Loss 0.05883 S Loss 0.02940


 23%|██▎       | 23005/100000 [14:59<3:54:05,  5.48it/s]

Episode total rewards: 20.0, lengths: 541, time taken: 2.443799


 24%|██▍       | 23997/100000 [15:29<37:09, 34.09it/s]  

DQ Loss 0.02589 S Loss 0.00000


 24%|██▍       | 24005/100000 [15:31<3:00:04,  7.03it/s]

Episode total rewards: 12.0, lengths: 357, time taken: 1.847406


 25%|██▌       | 25000/100000 [16:00<33:52, 36.89it/s]  

DQ Loss 0.00998 S Loss 0.00279


 25%|██▌       | 25008/100000 [16:03<3:37:43,  5.74it/s]

Episode total rewards: 21.0, lengths: 505, time taken: 2.356721


 26%|██▌       | 25998/100000 [16:38<44:54, 27.46it/s]  

DQ Loss 0.06316 S Loss 0.00000


 26%|██▌       | 26005/100000 [16:41<5:14:44,  3.92it/s]

Episode total rewards: 26.0, lengths: 516, time taken: 3.039076


 27%|██▋       | 26998/100000 [17:16<48:00, 25.34it/s]  

DQ Loss 0.07422 S Loss 0.00000


 27%|██▋       | 27001/100000 [17:20<8:36:44,  2.35it/s]

Episode total rewards: 68.0, lengths: 529, time taken: 3.231191


 28%|██▊       | 27997/100000 [17:58<39:56, 30.05it/s]  

DQ Loss 0.00402 S Loss 0.00000


 28%|██▊       | 28004/100000 [18:02<4:33:31,  4.39it/s]

Episode total rewards: 14.0, lengths: 516, time taken: 2.945987


 29%|██▉       | 29000/100000 [18:38<47:20, 24.99it/s]  

DQ Loss 0.00857 S Loss 0.00000


In [1]:
state, info = env.reset()
total_loss = 0
total_reward = 0

for iteration in range(num_train_iterations):
    action = select_action(env, dqn, state, epsilon)
    next_state, reward, terminated, truncated, info = env.step(action)
    total_reward += reward

    replay_buffer.push(state, action, reward, next_state, terminated)

    if terminated:
        state, info = env.reset()
        total_loss = 0
        total_reward = 0
        wandb.log({'train/loss': total_loss, 'train/reward': total_reward})
    else:
        state = next_state

    loss = train(dqn, target_dqn, replay_buffer, optimizer, batch_size, gamma)
    total_loss += loss

    rewards_list.append(total_reward)

    if iteration % target_update_freq == 0:
        target_dqn.load_state_dict(dqn.state_dict())
        torch.save(dqn.state_dict(), f"frogger_dqfd_iter_{iteration}.pth")
        reward, length, time = record_video(name_prefix=f"train_iter_{iteration}")

wandb.finish()

plt.plot(rewards_list)
plt.xlabel("Iteration")
plt.ylabel("Total Reward")
plt.title("DQN Training Performance on Frogger")
plt.show()

# store the model
torch.save(dqn.state_dict(), "frogger_dqfd_model.pth")
# save the rewards_list in a txt file with comma separated
np.savetxt("frogger_dqn_rewards.txt", rewards_list, delimiter=",")

NameError: name 'env' is not defined