In [7]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import numpy as np

In [8]:
# Compatibility shim: NumPy deprecated the `np.bool8` alias in 1.24 and removed it
# in 2.0. Some libraries (presumably gym's environment checker here) still reference
# it, so re-create the alias only when the attribute lookup fails.
try:
    np.bool8
except AttributeError:
    np.bool8 = np.bool_

In [9]:
# Hyperparameters: every tunable used by the training cell lives here.
GAMMA = 0.99          # discount factor applied to the bootstrapped next-state value
LR = 1e-3             # Adam learning rate
EPS_START = 1.0       # initial epsilon for epsilon-greedy exploration
EPS_END = 0.05        # floor that epsilon never decays below
EPS_DECAY = 3_000     # epsilon decreases by 1/EPS_DECAY per environment frame
BATCH_SIZE = 64       # transitions per gradient step; learning starts once the buffer holds this many
MEM_SIZE = 10_000     # replay buffer capacity (oldest transitions are evicted first)
TARGET_UPDATE = 100   # copy online weights into the target network every N frames
FRAME_LIMIT = 5_000   # total environment frames for the whole run

In [10]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Dueling DQN Network ---
class DuelingDQN(nn.Module):
    """Dueling Q-network: a shared trunk feeding separate value and advantage heads.

    Q-values are recombined as ``Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')``;
    subtracting the mean advantage pins down how Q is split between V and A.

    Args:
        input_dim: size of the flat observation vector.
        output_dim: number of discrete actions (number of Q-values produced).
        hidden_dim: width of the hidden layers (default 128).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        self.feature = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU()
        )
        self.value = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
        )
        self.advantage = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        """Return Q-values with the action dimension last.

        Accepts a batch ``(batch, input_dim) -> (batch, output_dim)`` or a single
        observation ``(input_dim,) -> (output_dim,)``.
        """
        features = self.feature(x)
        value = self.value(features)
        advantage = self.advantage(features)
        # Reduce over the last (action) axis so batched and unbatched inputs both
        # work; for 2-D input this is identical to dim=1, which fails on 1-D input.
        return value + advantage - advantage.mean(dim=-1, keepdim=True)

In [11]:
# --- Replay Buffer ---
class ReplayBuffer:
    """Bounded FIFO memory of transitions, sampled uniformly without replacement.

    Args:
        capacity: maximum number of transitions kept; once full, each new
            transition evicts the oldest one.
    """

    def __init__(self, capacity):
        # deque's second positional argument is `maxlen`.
        self.buffer = deque([], capacity)

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)

    def push(self, transition):
        """Append one transition (the training loop stores tuples)."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` stored transitions drawn without replacement.

        ``random.sample`` raises ValueError if ``batch_size`` exceeds ``len(self)``,
        so callers check the length first.
        """
        return random.sample(self.buffer, k=batch_size)

In [12]:
def select_action(state, epsilon, model, action_space):
    """Choose an action epsilon-greedily.

    With probability ``epsilon`` the action is ``action_space.sample()``; otherwise
    it is the index of the largest Q-value ``model`` produces for ``state``,
    returned as a Python int.
    """
    explore = random.random() < epsilon
    if explore:
        return action_space.sample()

    # Greedy branch: add a batch axis of size 1 and evaluate without autograd tracking.
    with torch.no_grad():
        obs = torch.FloatTensor(state).unsqueeze(0).to(device)
        best_action = model(obs).argmax().item()
    return best_action

In [13]:
# --- Training: Double DQN targets on a Dueling network ---
SEED = 42  # fixed seed so Restart & Run All reproduces the same run
random.seed(SEED)        # drives select_action's exploration draw and ReplayBuffer.sample
np.random.seed(SEED)
torch.manual_seed(SEED)  # network weight initialisation

env = gym.make("CartPole-v1")
env.action_space.seed(SEED)  # exploratory actions come from env.action_space.sample()
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

policy_net = DuelingDQN(state_dim, action_dim).to(device)
target_net = DuelingDQN(state_dim, action_dim).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=LR)
loss_fn = nn.MSELoss()  # built once instead of on every gradient step
memory = ReplayBuffer(MEM_SIZE)


def _learn_step(batch):
    """Apply one Double-DQN gradient step to ``policy_net`` using ``batch``.

    ``batch`` is a list of ``(state, action, reward, next_state, terminated)``
    tuples as pushed by the loop below; ``terminated`` is True only when the
    episode genuinely ended, not when it was cut off by the time limit.
    """
    states, actions, rewards, next_states, terminals = zip(*batch)

    states = torch.tensor(np.array(states), dtype=torch.float32, device=device)
    actions = torch.tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float32, device=device)
    terminals = torch.tensor(terminals, dtype=torch.float32, device=device).unsqueeze(1)

    # Q(s, a) for the actions actually taken.
    q_values = policy_net(states).gather(1, actions)

    # Targets are constants for this step, so build them without an autograd graph.
    with torch.no_grad():
        # Double DQN: the online net chooses the next action, the target net scores it.
        next_actions = policy_net(next_states).argmax(dim=1, keepdim=True)
        next_q_values = target_net(next_states).gather(1, next_actions)
        # Bootstrap from next_state unless the episode truly terminated there.
        target_q_values = rewards + GAMMA * next_q_values * (1.0 - terminals)

    loss = loss_fn(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


frame_count = 0
episode = 0
epsilon = EPS_START

try:
    while frame_count < FRAME_LIMIT:
        # Seed only the first reset; later resets continue the seeded RNG stream.
        state, _ = env.reset(seed=SEED if episode == 0 else None)
        done = False
        episode_reward = 0

        while not done and frame_count < FRAME_LIMIT:
            # Linear schedule: epsilon drops by 1/EPS_DECAY per frame, floored at EPS_END.
            epsilon = max(EPS_END, EPS_START - frame_count / EPS_DECAY)
            action = select_action(state, epsilon, policy_net, env.action_space)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Store `terminated`, not `done`: a time-limit truncation is not a terminal
            # state, so the TD target must still bootstrap from next_state.
            memory.push((state, action, reward, next_state, terminated))
            state = next_state
            frame_count += 1
            episode_reward += reward

            # Learn once the buffer holds at least one full batch.
            if len(memory) >= BATCH_SIZE:
                _learn_step(memory.sample(BATCH_SIZE))

            # Hard-copy the online weights into the target network every TARGET_UPDATE frames.
            if frame_count % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())

        episode += 1
        print(f"Episode {episode}, Reward: {episode_reward}, Epsilon: {epsilon:.3f}, Frame: {frame_count}")
finally:
    # Release the environment even if the cell is interrupted mid-training.
    env.close()

Episode 1, Reward: 21.0, Epsilon: 0.993, Frame: 21
Episode 2, Reward: 54.0, Epsilon: 0.975, Frame: 75
Episode 3, Reward: 11.0, Epsilon: 0.972, Frame: 86
Episode 4, Reward: 42.0, Epsilon: 0.958, Frame: 128
Episode 5, Reward: 10.0, Epsilon: 0.954, Frame: 138
Episode 6, Reward: 14.0, Epsilon: 0.950, Frame: 152
Episode 7, Reward: 25.0, Epsilon: 0.941, Frame: 177
Episode 8, Reward: 40.0, Epsilon: 0.928, Frame: 217
Episode 9, Reward: 21.0, Epsilon: 0.921, Frame: 238
Episode 10, Reward: 22.0, Epsilon: 0.914, Frame: 260
Episode 11, Reward: 16.0, Epsilon: 0.908, Frame: 276
Episode 12, Reward: 12.0, Epsilon: 0.904, Frame: 288
Episode 13, Reward: 19.0, Epsilon: 0.898, Frame: 307
Episode 14, Reward: 37.0, Epsilon: 0.886, Frame: 344
Episode 15, Reward: 73.0, Epsilon: 0.861, Frame: 417
Episode 16, Reward: 11.0, Epsilon: 0.858, Frame: 428
Episode 17, Reward: 22.0, Epsilon: 0.850, Frame: 450
Episode 18, Reward: 38.0, Epsilon: 0.838, Frame: 488
Episode 19, Reward: 15.0, Epsilon: 0.833, Frame: 503
Episo