In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import cv2
from collections import deque
import matplotlib.pyplot as plt

In [2]:
# --- Persisted artifacts -----------------------------------------------------
CHECKPOINT_DIR = "checkpoints"
BEST_CHECKPOINT = os.path.join(CHECKPOINT_DIR, "best_checkpoint.pth")
BEST_SNAPSHOT = os.path.join("snapshots", "best_snapshot.png")

# --- Observation -------------------------------------------------------------
# 4 stacked preprocessed 84x84 frames, channels first.
INPUT_SHAPE = (4, 84, 84)

# --- Learning ----------------------------------------------------------------
GAMMA = 0.99                  # discount factor
LR = 0.00025                  # Adam learning rate
BATCH_SIZE = 64
REPLAY_BUFFER_SIZE = 200_000
TARGET_UPDATE = 100           # target net is synced every TARGET_UPDATE episodes

# --- Environment interaction -------------------------------------------------
FRAME_SKIP = 4                # env.step calls per agent action
NUM_EPISODES = 500

# --- Exploration (linear decay: epsilon -= EPSILON_DECAY per update) ---------
EPSILON_START = 1.0
EPSILON_END = 0.001
EPSILON_DECAY = 0.000001

In [3]:
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs("snapshots", exist_ok=True)


def preprocess_state(state):
    """Turn one RGB frame into an 84x84 grayscale float32 image scaled by 1/255.

    Args:
        state: RGB image array accepted by ``cv2.cvtColor(..., COLOR_RGB2GRAY)``
            (presumably the uint8 ALE observation, giving values in [0, 1] —
            confirm if other inputs are used).

    Returns:
        ``np.ndarray`` of shape (84, 84), dtype float32.
    """
    grayscale = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
    # INTER_AREA averages source pixels, which suits downscaling.
    downsized = cv2.resize(grayscale, (84, 84), interpolation=cv2.INTER_AREA)
    return downsized.astype(np.float32) / 255.0

In [4]:
class SumTree:
    """Binary sum tree used for proportional (prioritized) sampling.

    The tree lives in a flat array of size ``2 * capacity - 1``. Internal nodes
    occupy indices ``[0, capacity - 1)`` and leaves occupy
    ``[capacity - 1, 2 * capacity - 1)``. Leaf ``capacity - 1 + i`` holds the
    priority of ``data[i]``, and every internal node stores the sum of its two
    children, so ``tree[0]`` is the total priority mass.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = [None] * capacity
        # Ring-buffer position of the next slot to overwrite.
        self.write = 0
        # Number of slots holding real data; saturates at ``capacity``.
        self.n_entries = 0

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of node ``idx``, up to and including the root.

        When ``idx`` is already the root (capacity == 1), there is nothing to propagate.
        """
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative-priority interval contains ``s``."""
        last = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= last:
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1

    def total(self):
        """Return the total priority mass (the root value)."""
        return self.tree[0]

    def add(self, p, data):
        """Store ``data`` with priority ``p``, overwriting the oldest slot once full."""
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, p)

        self.write = (self.write + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def update(self, idx, p):
        """Set the priority of tree node ``idx`` (a leaf index) to ``p`` and fix up ancestor sums."""
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Return ``(tree_index, priority, data)`` for the leaf selected by cumulative value ``s``."""
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return idx, self.tree[idx], self.data[data_idx]


class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay (priority = (|error| + eps) ** alpha).

    ``beta`` controls importance-sampling correction and is annealed towards
    1.0 by ``beta_increment_per_sampling`` on every ``sample`` call.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.5, beta_increment_per_sampling=1e-4):
        self.alpha = alpha
        self.beta = beta
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.capacity = capacity
        self.tree = SumTree(capacity)
        # Keeps every priority strictly positive so no transition becomes unsampleable.
        self.epsilon = 1e-5

    def __len__(self):
        # Count of stored transitions. ``tree.write`` wraps to 0 when the ring
        # buffer is full, so it cannot be used as the size.
        return self.tree.n_entries

    def _priority(self, error):
        """Map a TD error to a sampling priority."""
        return (abs(error) + self.epsilon) ** self.alpha

    def add(self, error, sample):
        """Insert ``sample`` (a transition tuple) with a priority derived from ``error``."""
        self.tree.add(self._priority(error), sample)

    def sample(self, n):
        """Draw ``n`` transitions by stratified proportional sampling.

        Returns:
            ``(states, actions, rewards, next_states, dones, idxs, is_weight)``,
            where ``idxs`` are tree indices for ``update`` and ``is_weight`` are
            importance-sampling weights normalized so their max is 1.

        Raises:
            ValueError: if the buffer is empty.
        """
        size = len(self)
        if size == 0:
            raise ValueError("cannot sample from an empty PrioritizedReplayBuffer")

        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

        total = self.tree.total()
        segment = total / n
        batch, idxs, priorities = [], [], []

        for i in range(n):
            s = random.uniform(segment * i, segment * (i + 1))
            idx, p, data = self.tree.get(s)
            if data is None:
                # Accumulated float error in the internal sums can route the
                # search into an empty leaf; fall back to a random filled slot.
                data_idx = random.randrange(size)
                idx = data_idx + self.capacity - 1
                p = self.tree.tree[idx]
                data = self.tree.data[data_idx]
            batch.append(data)
            idxs.append(idx)
            priorities.append(p)

        sampling_probabilities = np.asarray(priorities, dtype=np.float64) / total
        is_weight = np.power(size * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()

        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.uint8),
            idxs,
            is_weight.astype(np.float32)
        )

    def update(self, idx, error):
        """Refresh the priority of tree node ``idx`` from a new TD ``error``."""
        self.tree.update(idx, self._priority(error))

In [5]:
class DQNetwork(nn.Module):
    """Convolutional Q-network: frame stack in, one Q-value per action out.

    Three conv+ReLU stages feed a 512-unit hidden layer and a linear head.
    The attribute names ``conv``/``fc`` and the layer order inside them define
    the ``state_dict`` keys saved in checkpoints, so they must stay stable.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )
        flat_features = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512), nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def _get_conv_out(self, shape):
        """Return the flattened feature count ``self.conv`` yields for one input of ``shape``."""
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, *shape))
        # Batch size is 1, so the element count equals the per-sample feature count.
        return int(probe.numel())

    def forward(self, x):
        """Map a batch ``x`` of frame stacks to a ``(batch, num_actions)`` Q-value tensor."""
        features = self.conv(x)
        return self.fc(features.view(len(features), -1))

In [6]:
class DQNAgent:
    """Double-DQN agent backed by a prioritized replay buffer.

    Owns an online network (optimized with Adam) and a target network that is
    only changed by ``update_target_network``. Exploration is epsilon-greedy,
    with epsilon reduced linearly by ``update_epsilon``.
    """

    def __init__(self, input_shape, num_actions):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.online_net = DQNetwork(input_shape, num_actions).to(self.device)
        self.target_net = DQNetwork(input_shape, num_actions).to(self.device)
        # Start both networks from identical weights.
        self.target_net.load_state_dict(self.online_net.state_dict())

        self.optimizer = optim.Adam(self.online_net.parameters(), lr=LR)

        self.replay_buffer = PrioritizedReplayBuffer(
            capacity=REPLAY_BUFFER_SIZE,
            alpha=0.6,
            beta=0.5,
            beta_increment_per_sampling=1e-4
        )

        self.epsilon = EPSILON_START
        self.epsilon_decay = EPSILON_DECAY
        self.epsilon_min = EPSILON_END
        self.best_reward = -float("inf")
        self.num_actions = num_actions

    def select_action(self, state):
        """Choose an action for one unbatched ``state``, epsilon-greedily.

        With probability ``self.epsilon`` returns a uniform random action in
        ``[0, num_actions)``; otherwise the argmax of the online Q-values.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.num_actions - 1)
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            q_values = self.online_net(state_t)
            return q_values.argmax().item()

    def learn(self):
        """Run one prioritized Double-DQN update; does nothing until BATCH_SIZE transitions exist.

        The loss is the importance-sampling-weighted squared TD error. Afterwards
        the sampled transitions get fresh priorities from their new TD errors.
        """
        if len(self.replay_buffer) < BATCH_SIZE:
            return

        (states, actions, rewards, next_states, dones,
         idxs, is_weight) = self.replay_buffer.sample(BATCH_SIZE)

        states_t = torch.FloatTensor(states).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        rewards_t = torch.FloatTensor(rewards).to(self.device)
        next_states_t = torch.FloatTensor(next_states).to(self.device)
        dones_t = torch.FloatTensor(dones).to(self.device)
        is_weight_t = torch.FloatTensor(is_weight).to(self.device)

        q_values = self.online_net(states_t)
        current_q = q_values.gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # Double DQN target: the online net picks the next action, the target
        # net evaluates it. Nothing here needs gradients, so skip building a graph.
        with torch.no_grad():
            next_actions = self.online_net(next_states_t).argmax(dim=1)
            next_q_target = self.target_net(next_states_t)
            target_q = next_q_target.gather(1, next_actions.unsqueeze(1)).squeeze(1)
            expected_q = rewards_t + (1 - dones_t) * GAMMA * target_q

        td_errors = current_q - expected_q
        loss = (is_weight_t * td_errors.pow(2)).mean()

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.online_net.parameters(), max_norm=10)
        self.optimizer.step()

        # One device->host transfer for the whole batch instead of one per sample.
        new_errors = td_errors.detach().cpu().numpy()
        for idx, err in zip(idxs, new_errors):
            self.replay_buffer.update(idx, err)

    def update_epsilon(self):
        """Decrease epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon - self.epsilon_decay)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save_best_model(self, reward):
        """Checkpoint networks, optimizer and epsilon to BEST_CHECKPOINT if ``reward`` is a new best."""
        if reward > self.best_reward:
            self.best_reward = reward
            torch.save({
                'online_state_dict': self.online_net.state_dict(),
                'target_state_dict': self.target_net.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                # Plain Python floats (rather than possible numpy scalars) keep the
                # checkpoint to simple types, friendlier to torch.load(weights_only=True).
                'epsilon': float(self.epsilon),
                'reward': float(reward)
            }, BEST_CHECKPOINT)
            print(f"Best model saved with reward: {self.best_reward}")


In [7]:
def save_best_snapshot(env, reward):
    """Render the env's current frame and save it to BEST_SNAPSHOT, titled with ``reward``.

    This needs a frame-returning render mode such as 'rgb_array'. If
    ``env.render()`` returns None (presumably the case with
    ``render_mode='human'``, which draws its own window), the snapshot is
    skipped with a message instead of crashing the training loop.
    """
    frame = env.render()
    if frame is None:
        print(f"Snapshot skipped: env.render() returned no frame (reward: {reward})")
        return

    # Use a dedicated figure so no existing pyplot figure is drawn over,
    # and always release it even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.imshow(frame)
        ax.set_title(f"Best Reward: {reward}")
        ax.axis("off")
        fig.savefig(BEST_SNAPSHOT)
    finally:
        plt.close(fig)
    print(f"Best snapshot saved with reward: {reward}")


In [8]:
def _initial_td_error(agent, state, action, reward, next_state, terminated):
    """Return the TD error of a new transition, used as its initial replay priority.

    Uses the same Double-DQN target as ``DQNAgent.learn``: the online net picks
    the next action and the target net evaluates it. Bootstrapping is disabled
    only when ``terminated`` is true.
    """
    state_t = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
    next_state_t = torch.FloatTensor(next_state).unsqueeze(0).to(agent.device)
    with torch.no_grad():
        current_q = agent.online_net(state_t)[0, action]
        next_action = agent.online_net(next_state_t).argmax(dim=1).item()
        next_q = agent.target_net(next_state_t)[0, next_action]
        td_target = reward + (1.0 - float(terminated)) * GAMMA * next_q
    return (current_q - td_target).cpu().numpy()


def train(env, agent, start_episode=1):
    """Train ``agent`` on ``env`` for ``NUM_EPISODES`` episodes.

    Args:
        env: environment with a 2-tuple ``reset`` ``(obs, info)`` and a 5-tuple
            ``step`` ``(obs, reward, terminated, truncated, info)``.
        agent: the ``DQNAgent`` to train; its ``best_reward`` seeds best-so-far tracking.
        start_episode: number of the first episode. It only affects episode
            numbering, target-sync timing and log output.

    Each agent step repeats the chosen action for up to ``FRAME_SKIP`` env
    steps. The transition is pushed into the prioritized buffer with its current
    TD error as priority, and ``agent.learn()`` runs every 4th agent step.
    An episode ends on ``terminated`` or ``truncated``, but only ``terminated``
    is stored as the transition's done flag. A time-limit truncation is not a
    real terminal state, so the TD target should still bootstrap from it.
    """
    best_reward = agent.best_reward
    last_episode = start_episode + NUM_EPISODES - 1

    for episode in range(start_episode, last_episode + 1):
        obs, _ = env.reset()
        state = np.stack([preprocess_state(obs)] * 4, axis=0)
        total_reward = 0
        done = False
        step = 0

        while not done:
            action = agent.select_action(state)
            reward = 0
            terminated = truncated = False

            # NOTE(review): ALE "-v5" environments are documented to apply their
            # own frameskip (default 4), so this loop likely repeats each action
            # for FRAME_SKIP * 4 emulator frames — confirm this is intended.
            for _ in range(FRAME_SKIP):
                obs, r, terminated, truncated, _ = env.step(action)
                reward += r
                if terminated or truncated:
                    break
            done = terminated or truncated

            # Slide the frame window: drop the oldest frame, append the newest.
            next_state_stack = np.roll(state, shift=-1, axis=0)
            next_state_stack[-1] = preprocess_state(obs)

            td_error = _initial_td_error(agent, state, action, reward,
                                         next_state_stack, terminated)
            # NOTE(review): transitions hold float32 frame stacks (~113 KB each at
            # INPUT_SHAPE). At REPLAY_BUFFER_SIZE entries this is on the order of
            # 20+ GB of RAM, so consider storing uint8 frames instead.
            agent.replay_buffer.add(
                td_error, (state, action, reward, next_state_stack, terminated))

            state = next_state_stack
            total_reward += reward
            step += 1

            if step % 4 == 0:
                agent.learn()

        # NOTE(review): epsilon is decayed once per *episode* by EPSILON_DECAY
        # (1e-6), so over NUM_EPISODES episodes it barely moves from
        # EPSILON_START. The constant looks sized for per-step decay, so the
        # schedule should be confirmed.
        agent.update_epsilon()

        if episode % TARGET_UPDATE == 0:
            agent.update_target_network()

        if total_reward > best_reward:
            best_reward = total_reward
            agent.save_best_model(best_reward)
            save_best_snapshot(env, best_reward)

        print(f"Ep {episode}/{last_episode} | "
              f"Reward: {total_reward} | "
              f"Epsilon: {agent.epsilon:.3f} | "
              f"Best: {best_reward}")


def load_best_model(agent):
    """Restore networks, optimizer, epsilon and best reward from BEST_CHECKPOINT.

    If no checkpoint file exists, the agent is left unchanged and a notice is printed.
    """
    if not os.path.exists(BEST_CHECKPOINT):
        print("No best checkpoint found. Starting fresh.")
        return

    ckpt = torch.load(BEST_CHECKPOINT, map_location=agent.device)
    agent.online_net.load_state_dict(ckpt['online_state_dict'])
    agent.target_net.load_state_dict(ckpt['target_state_dict'])
    agent.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    # Resume exploration where the checkpointed run left off.
    agent.epsilon = ckpt['epsilon']
    agent.best_reward = ckpt['reward']
    print(f"Loaded best model with reward: {agent.best_reward} and resumed epsilon={agent.epsilon}.")



: 

In [None]:
if __name__ == "__main__":
    # NOTE(review): with render_mode='human', env.render() is expected to return
    # None (the env draws its own window), so best-frame snapshots are skipped.
    # Use 'rgb_array' if snapshots or faster training are wanted.
    env = gym.make('ALE/Breakout-v5', render_mode='human')
    try:
        agent = DQNAgent(INPUT_SHAPE, env.action_space.n)
        load_best_model(agent)
        train(env, agent)
    finally:
        # Close the env (and its render window) even when training raises or is
        # interrupted, e.g. by stopping the notebook cell (KeyboardInterrupt).
        env.close()

  checkpoint = torch.load(BEST_CHECKPOINT, map_location=agent.device)
  if not isinstance(terminated, (bool, np.bool8)):


Loaded best model with reward: 45.0 and resumed epsilon=0.01.
Ep 1/500 | Reward: 25.0 | Epsilon: 0.010 | Best: 45.0
Ep 2/500 | Reward: 7.0 | Epsilon: 0.010 | Best: 45.0
Ep 3/500 | Reward: 7.0 | Epsilon: 0.010 | Best: 45.0
