In [4]:
class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float32)
        self.data = np.empty(capacity, dtype=object)
        self.ptr = 0
        self.n_entries = 0

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def update(self, idx, priority):
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        self._propagate(idx, change)

    def add(self, priority, data):
        if data is None:
            priority = 0.0
        idx = self.ptr + self.capacity - 1
        self.data[self.ptr] = data
        self.update(idx, priority)
        self.ptr = (self.ptr + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def _retrieve(self, idx, value):
        left = 2 * idx + 1
        right = left + 1
        if left >= len(self.tree):
            return idx
        if value <= self.tree[left]:
            return self._retrieve(left, value)
        else:
            return self._retrieve(right, value - self.tree[left])

    def sample(self, batch_size):
        segment = self.tree[0] / batch_size
        batch_idx, priorities, data = [], [], []
        for i in range(batch_size):
            a, b = segment * i, segment * (i + 1)
            value = random.uniform(a, b)
            idx = self._retrieve(0, value)
            tree_idx = idx - self.capacity + 1
            batch_idx.append(tree_idx)
            priorities.append(self.tree[idx])
            data.append(self.data[tree_idx])
        return batch_idx, np.array(priorities, dtype=np.float32), data
    @property
    def total(self):
        return self.tree[0]


In [5]:
import os
import random
from collections import deque, namedtuple
import numpy as np
import cv2
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from typing import Tuple
import ale_py
class AtariPreprocess:
    def __init__(self, width=84, height=84):
        self.width = width
        self.height = height

    def __call__(self, obs):
        img = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return np.ascontiguousarray(img, dtype=np.uint8)

class FrameStack:
    def __init__(self, k):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        for _ in range(self.k):
            self.frames.append(frame)
        return self._get_obs()

    def append(self, frame):
        self.frames.append(frame)
        return self._get_obs()

    def _get_obs(self):
        return np.stack(self.frames, axis=0)

class SkipEnvWrapper(gym.Wrapper):
    def __init__(self, env, skip=4):
        super().__init__(env)
        self.skip = skip

    def step(self, action):
        total_reward = 0.0
        done = False
        for i in range(self.skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            done = terminated or truncated
            if done:
                break
        return obs, total_reward, terminated, truncated, info

def make_env(env_id="ALE/air_raid-v5", seed=0, skip=4, stack=4):
    raw = gym.make(env_id, render_mode=None)
    raw.reset(seed=seed)
    pre = AtariPreprocess()
    skipw = SkipEnvWrapper(raw, skip=skip)
    stacker = FrameStack(stack)
    class EnvObj:
        def __init__(self, e, pre, stacker):
            self.e = e
            self.pre = pre
            self.stacker = stacker
            self.action_space = e.action_space
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=(stacker.k, pre.height, pre.width), dtype=np.uint8
            )

        def reset(self):
            obs, info = self.e.reset()
            obs_proc = self.pre(obs)
            stacked = self.stacker.reset(obs_proc)
            return stacked, info

        def step(self, action):
            obs, reward, terminated, truncated, info = self.e.step(action)
            obs_proc = self.pre(obs)
            stacked = self.stacker.append(obs_proc)
            reward = np.clip(reward, -1.0, 1.0)
            return stacked, reward, terminated, truncated, info

        def render(self, *args, **kwargs):
            return self.e.render(*args, **kwargs)

        def close(self):
            return self.e.close()

    return EnvObj(skipw, pre, stacker)

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

class PrioritizedReplayBuffer:
    def __init__(self, capacity, device, alpha=0.6, beta=0.4, beta_increment=1e-6, eps=1e-6):
        self.tree = SumTree(capacity)
        self.device = device
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.eps = eps
        self.max_priority = 1.0

    def push(self, *args):
        self.tree.add(self.max_priority, Transition(*args))

    def sample(self, batch_size):
        idxs, priorities, batch = [], [], []
        while len(batch) < batch_size:
            i, p, d = self.tree.sample(1)
            if d[0] is None:
                continue
            idxs.extend(i)
            priorities.extend(p)
            batch.extend(d)
            if len(batch) == batch_size:
                break
        probs = np.array(priorities, dtype=np.float32) / self.tree.total
        self.beta = min(1.0, self.beta + self.beta_increment)
        weights = (probs * self.tree.n_entries) ** (-self.beta)
        weights /= weights.max()
        states = torch.from_numpy(np.stack([b.state for b in batch])).float().to(self.device) / 255.0
        actions = torch.tensor([b.action for b in batch], dtype=torch.long, device=self.device)
        rewards = torch.tensor([b.reward for b in batch], dtype=torch.float32, device=self.device)
        next_states = torch.from_numpy(np.stack([b.next_state for b in batch])).float().to(self.device) / 255.0
        dones = torch.tensor([float(b.done) for b in batch], dtype=torch.float32, device=self.device)
        weights = torch.from_numpy(weights.astype(np.float32)).to(self.device)
        return states, actions, rewards, next_states, dones, weights, idxs

    def update_priorities(self, idxs, td_errors):
        for idx, td in zip(idxs, td_errors):
            priority = (abs(td) + self.eps) ** self.alpha
            self.tree.update(idx + self.tree.capacity - 1, priority)
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        return self.tree.n_entries

class DQN(nn.Module):
    def __init__(self, in_channels, n_actions):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self._conv_out = self._get_conv_out(in_channels)
        self.fc1 = nn.Linear(self._conv_out, 512)
        self.fc2 = nn.Linear(512, n_actions)

    def _get_conv_out(self, in_channels):
        x = torch.zeros(1, in_channels, 84, 84)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return int(np.prod(x.size()))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class DQNAgent:
    def __init__(self, in_channels, n_actions, device, lr=1e-4, gamma=0.99,
                 buffer_size=100000, batch_size=32, target_update=1000,alpha=0.6,beta=0.4):
        self.device = device
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update = target_update
        self.policy_net = DQN(in_channels, n_actions).to(device)
        self.target_net = DQN(in_channels, n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.replay = PrioritizedReplayBuffer(buffer_size, device,alpha=alpha, beta=beta)
        self.steps_done = 0

    def select_action(self, state, epsilon):
        if random.random() < epsilon:
            return random.randrange(self.n_actions)
        else:
            with torch.no_grad():
                s = torch.from_numpy(state).float().unsqueeze(0).to(self.device) / 255.0
                q = self.policy_net(s)
                return int(q.argmax(dim=1).item())

    def optimize(self):
        if len(self.replay) < self.batch_size:
            return None
        states, actions, rewards, next_states, dones, weights, idxs = self.replay.sample(self.batch_size)
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_actions = self.policy_net(next_states).argmax(1)
            next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            target_q = rewards + (1.0 - dones) * self.gamma * next_q
        td_errors = (q_values - target_q).detach().cpu().numpy()
        loss = (weights * F.smooth_l1_loss(q_values, target_q, reduction='none')).mean()
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10.0)
        self.optimizer.step()
        self.replay.update_priorities(idxs, td_errors)
        self.steps_done += 1
        if self.steps_done % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())
        return loss.item()

    def save(self, path):
        torch.save({
            'policy_state_dict': self.policy_net.state_dict(),
            'target_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load(self, path):
        ck = torch.load(path, map_location=self.device)
        self.policy_net.load_state_dict(ck['policy_state_dict'])
        self.target_net.load_state_dict(ck['target_state_dict'])
        self.optimizer.load_state_dict(ck['optimizer_state_dict'])

def train(env_id='ALE/air_raid-v5',
          seed=42,
          total_steps=2_000_000,
          start_learning=50_000,
          eval_interval=50_000,
          save_path='dqn_airraid.pth'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = make_env(env_id, seed=seed, skip=4, stack=4)
    eval_env = make_env(env_id, seed=seed+1, skip=4, stack=4)
    obs, _ = env.reset()
    in_channels = obs.shape[0]
    n_actions = env.action_space.n
    agent = DQNAgent(in_channels, n_actions, device,
                     lr=1e-4, gamma=0.99,
                     buffer_size=500000, batch_size=64, target_update=1000)
    epsilon_start = 1.0
    epsilon_final = 0.1
    epsilon_decay = 1_000_000
    state, _ = env.reset()
    episode_reward = 0.0
    episode_cnt = 0
    total_step = 0
    losses = []

    while total_step < total_steps:
        eps = epsilon_final + (epsilon_start - epsilon_final) * max(0, (epsilon_decay - total_step) / epsilon_decay)
        action = agent.select_action(state, eps)
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        agent.replay.push(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward
        total_step += 1
        if len(agent.replay) >= start_learning:
            loss = agent.optimize()
            if loss is not None:
                losses.append(loss)
        if done:
            state, _ = env.reset()
            episode_cnt += 1
            print(f"Step {total_step} | Episode {episode_cnt} ended | Reward {episode_reward:.2f} | Eps {eps:.3f}")
            episode_reward = 0.0
        if total_step % eval_interval == 0:
            avg_score = evaluate(agent, eval_env, episodes=5)
            print(f"== Eval at step {total_step}: avg score = {avg_score:.2f} ==")
    agent.save(save_path)
    env.close()
    eval_env.close()

def evaluate(agent:DQNAgent, env, episodes=5, render=False):
    device = agent.device
    scores = []
    for ep in range(episodes):
        state, _ = env.reset()
        done = False
        total = 0.0
        while not done:
            action = agent.select_action(state, epsilon=0.001)  # near-greedy
            state, reward, terminated, truncated, info = env.step(action)
            total += reward
            done = terminated or truncated
            if render:
                env.render()
        scores.append(total)
    return float(np.mean(scores))


if __name__ == "__main__":
    train(env_id="ALE/AirRaid-v5",
          seed=123,
          total_steps=2_000_000,
          start_learning=50_000,
          eval_interval=50_000,
          save_path="dqn_airraid_final_improving.pth")


Step 156 | Episode 1 ended | Reward 5.00 | Eps 1.000
Step 252 | Episode 2 ended | Reward 0.00 | Eps 1.000
Step 347 | Episode 3 ended | Reward 5.00 | Eps 1.000
Step 473 | Episode 4 ended | Reward 8.00 | Eps 1.000
Step 666 | Episode 5 ended | Reward 10.00 | Eps 0.999
Step 837 | Episode 6 ended | Reward 8.00 | Eps 0.999
Step 1010 | Episode 7 ended | Reward 5.00 | Eps 0.999
Step 1111 | Episode 8 ended | Reward 4.00 | Eps 0.999
Step 1246 | Episode 9 ended | Reward 9.00 | Eps 0.999
Step 1405 | Episode 10 ended | Reward 5.00 | Eps 0.999
Step 1657 | Episode 11 ended | Reward 10.00 | Eps 0.999
Step 1826 | Episode 12 ended | Reward 9.00 | Eps 0.998
Step 2085 | Episode 13 ended | Reward 11.00 | Eps 0.998
Step 2318 | Episode 14 ended | Reward 8.00 | Eps 0.998
Step 2568 | Episode 15 ended | Reward 8.00 | Eps 0.998
Step 2738 | Episode 16 ended | Reward 9.00 | Eps 0.998
Step 2936 | Episode 17 ended | Reward 8.00 | Eps 0.997
Step 3104 | Episode 18 ended | Reward 9.00 | Eps 0.997
Step 3205 | Episode 19

In [7]:
import gymnasium as gym
import torch
import numpy as np
import ale_py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = make_env(env_id="ALE/AirRaid-v5", seed=114514, skip=4, stack=4)
env.e = gym.make("ALE/AirRaid-v5", render_mode="human")
num_actions = env.action_space.n
obs, _ = env.reset()
in_channels = obs.shape[0]
policy_net = DQN(in_channels,num_actions).to(device)
checkpoint = torch.load("models/dqn_airraid_final_improving.pth", map_location=device)
policy_net.load_state_dict(checkpoint['policy_state_dict'])
policy_net.eval()
episodes = 3
for ep in range(episodes):
    state, _ = env.reset()
    done = False
    score = 0
    while not done:
        s = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) / 255.0
        with torch.no_grad():
            q_values = policy_net(s)
            action = q_values.argmax(1).item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        state = next_state
        score += reward
        env.render()
    print(f"Episode {ep+1} finished, score = {score}")

env.close()

  checkpoint = torch.load("dqn_airraid_final_improving.pth", map_location=device)


Episode 1 finished, score = 6.0
Episode 2 finished, score = 18.0
Episode 3 finished, score = 3.0
