In [3]:
import random
import collections
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim

In [4]:
# --- Hyperparameters ---
# Learning-rule settings.
GAMMA = 0.99          # discount applied to the bootstrapped next-state value
LR = 0.001            # step size handed to the Adam optimizer
BATCH_SIZE = 64       # transitions drawn from the replay buffer per update
BUFFER_SIZE = 10_000  # maximum number of transitions kept in the replay buffer

# Exploration schedule: epsilon decays exponentially from EPS_START towards
# EPS_END as a function of the global step count, with time constant EPS_DECAY.
EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY = 500

# Loop control.
TARGET_UPDATE = 10    # the target network is synced every TARGET_UPDATE episodes
NUM_EPISODES = 500    # number of training episodes

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [5]:
# --- Neural Network for Q-values ---
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action.

    The architecture is Linear -> ReLU -> Linear -> ReLU -> Linear, stored in
    ``self.net`` so parameter names stay ``net.0.*``, ``net.2.*``, ``net.4.*``
    (checkpoints saved with the default width remain loadable).

    Args:
        state_dim: Size of the input state vector.
        action_dim: Number of discrete actions (size of the output vector).
        hidden_dim: Width of both hidden layers. Defaults to 128.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        """Return the Q-values for ``x``.

        Args:
            x: Float tensor whose last dimension has size ``state_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``action_dim``.
        """
        return self.net(x)

In [6]:
# --- Replay Buffer ---
class ReplayBuffer:
    """Fixed-capacity store of (state, action, reward, next_state, done) tuples.

    Backed by a ``deque`` with ``maxlen=capacity``, so once the buffer is full
    each new transition pushes out the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` transitions."""
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple ``(states, actions, rewards, next_states, dones)``.
            ``states`` and ``next_states`` are stacked into numpy arrays;
            ``actions``, ``rewards`` and ``dones`` are plain tuples in the
            same order as the rows of ``states``.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one sequence per field.
        state_col, action_col, reward_col, next_state_col, done_col = zip(*picked)
        stacked_states = np.array(state_col)
        stacked_next_states = np.array(next_state_col)
        return (stacked_states, action_col, reward_col, stacked_next_states, done_col)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [7]:
# --- Epsilon Greedy Policy ---
def select_action(state, policy_net, steps_done, num_actions=None):
    """Pick an action with an exponentially decaying epsilon-greedy rule.

    Epsilon starts at ``EPS_START`` and decays towards ``EPS_END`` as
    ``exp(-steps_done / EPS_DECAY)``. With probability epsilon a uniformly
    random action is returned; otherwise the action with the largest Q-value
    predicted by ``policy_net`` is returned.

    Args:
        state: A single observation (array-like of floats, no batch dimension).
        policy_net: Callable mapping a ``(1, state_dim)`` float tensor to a
            ``(1, n_actions)`` tensor of Q-values.
        steps_done: Number of environment steps taken so far.
        num_actions: Size of the discrete action set used for the random
            branch. When omitted, the notebook-level global ``action_dim`` is
            used, so that global must exist before the first exploratory call.

    Returns:
        The chosen action as a Python ``int``.
    """
    epsilon = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * steps_done / EPS_DECAY)
    if random.random() < epsilon:
        if num_actions is None:
            # Backward-compatible fallback to the global set in the setup cell.
            num_actions = action_dim
        return random.randrange(num_actions)

    # Add a leading batch dimension of 1 and place the tensor on DEVICE.
    state_batch = torch.as_tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    with torch.no_grad():
        return policy_net(state_batch).argmax(dim=1).item()

In [8]:
# --- Training Setup ---
# Create the CartPole-v1 environment and read the state/action sizes from its spaces.
env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]
# action_dim is also read as a module-level global by select_action.
action_dim = env.action_space.n

# policy_net is the network being trained; target_net supplies the bootstrap targets.
policy_net = QNetwork(state_dim, action_dim).to(DEVICE)
target_net = QNetwork(state_dim, action_dim).to(DEVICE)
# Start the target network as an exact copy of the policy network's weights.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Only policy_net's parameters are handed to the optimizer.
optimizer = optim.Adam(policy_net.parameters(), lr=LR)
replay_buffer = ReplayBuffer(BUFFER_SIZE)

# Global environment-step counter; passed to select_action to compute epsilon.
steps_done = 0

In [10]:
# Write the policy network's weights to "dqn_cartpole.pth".
# NOTE(review): this cell's execution count (10) is higher than the training
# cell's (9), yet it sits above the training loop in the file. Run top to
# bottom, this line executes before training and would write the freshly
# initialised weights — confirm the intended cell order.
torch.save(policy_net.state_dict(), "dqn_cartpole.pth")

In [9]:
# MSELoss keeps no per-step state, so build it once rather than on every update.
loss_fn = nn.MSELoss()

for episode in range(NUM_EPISODES):
    state, _ = env.reset()
    total_reward = 0

    while True:
        action = select_action(state, policy_net, steps_done)
        next_state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        # Store only true termination as the "done" flag. A time-limit
        # truncation is not a terminal state, so the target for that
        # transition must still bootstrap from next_state.
        replay_buffer.push(state, action, reward, next_state, terminated)

        state = next_state
        steps_done += 1

        # Start learning once the buffer holds more than one batch.
        if len(replay_buffer) > BATCH_SIZE:
            states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)

            states = torch.as_tensor(states, dtype=torch.float32, device=DEVICE)
            # (batch, 1) index tensor, as required by gather along dim 1.
            actions = torch.as_tensor(actions, dtype=torch.int64, device=DEVICE).unsqueeze(1)
            rewards = torch.as_tensor(rewards, dtype=torch.float32, device=DEVICE)
            next_states = torch.as_tensor(next_states, dtype=torch.float32, device=DEVICE)
            dones = torch.as_tensor(dones, dtype=torch.float32, device=DEVICE)

            # Q(s, a) for the actions that were actually taken -> shape (batch,).
            # squeeze(1) drops only the gather dimension, never the batch dimension.
            q_values = policy_net(states).gather(1, actions).squeeze(1)

            # Bootstrapped target from the frozen target network; no gradient
            # flows through it. (1 - dones) zeroes the bootstrap term for
            # terminal transitions.
            with torch.no_grad():
                next_q_values = target_net(next_states).max(1)[0]
                target = rewards + (1 - dones) * GAMMA * next_q_values

            loss = loss_fn(q_values, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # The episode ends on either termination or time-limit truncation.
        if terminated or truncated:
            print(f"Episode {episode+1}: Total Reward = {total_reward}")
            break

    # Update target network
    if episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

# Save the trained weights here so that a top-to-bottom run of the notebook
# leaves a trained checkpoint on disk for the evaluation cell to load.
torch.save(policy_net.state_dict(), "dqn_cartpole.pth")

env.close()

Episode 1: Total Reward = 10.0
Episode 2: Total Reward = 16.0
Episode 3: Total Reward = 10.0
Episode 4: Total Reward = 25.0
Episode 5: Total Reward = 17.0
Episode 6: Total Reward = 16.0
Episode 7: Total Reward = 10.0
Episode 8: Total Reward = 15.0
Episode 9: Total Reward = 12.0
Episode 10: Total Reward = 17.0
Episode 11: Total Reward = 11.0
Episode 12: Total Reward = 15.0
Episode 13: Total Reward = 13.0
Episode 14: Total Reward = 16.0
Episode 15: Total Reward = 32.0
Episode 16: Total Reward = 43.0
Episode 17: Total Reward = 13.0
Episode 18: Total Reward = 38.0
Episode 19: Total Reward = 10.0
Episode 20: Total Reward = 14.0
Episode 21: Total Reward = 13.0
Episode 22: Total Reward = 18.0
Episode 23: Total Reward = 23.0
Episode 24: Total Reward = 16.0
Episode 25: Total Reward = 21.0
Episode 26: Total Reward = 14.0
Episode 27: Total Reward = 27.0
Episode 28: Total Reward = 11.0
Episode 29: Total Reward = 95.0
Episode 30: Total Reward = 20.0
Episode 31: Total Reward = 13.0
Episode 32: Total

In [16]:
# --- Load Environment ---
env = gym.make("CartPole-v1", render_mode="human")  # 'human' opens a window
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

try:
    # --- Load Trained Model ---
    # The demo model stays on the CPU. map_location="cpu" lets a checkpoint
    # that was saved from a CUDA device load on a machine without a GPU.
    policy_net = QNetwork(state_dim, action_dim)
    policy_net.load_state_dict(torch.load("dqn_cartpole.pth", map_location="cpu"))
    policy_net.eval()

    # --- Run Visualization ---
    for episode in range(5):  # run 5 demo episodes
        state, _ = env.reset()
        total_reward = 0
        while True:
            # Purely greedy policy: no exploration during the demo.
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action = policy_net(state_tensor).argmax(dim=1).item()
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            # Stop on either termination or time-limit truncation.
            if terminated or truncated:
                print(f"Episode {episode+1} Reward: {total_reward}")
                break
finally:
    # Always release the render window, even if loading or stepping fails.
    env.close()

Episode 1 Reward: 156.0
Episode 2 Reward: 165.0
Episode 3 Reward: 163.0
Episode 4 Reward: 157.0
Episode 5 Reward: 152.0
Episode 6 Reward: 159.0
Episode 7 Reward: 152.0
Episode 8 Reward: 161.0
Episode 9 Reward: 175.0
Episode 10 Reward: 168.0
Episode 11 Reward: 172.0
Episode 12 Reward: 169.0
Episode 13 Reward: 155.0
Episode 14 Reward: 161.0
Episode 15 Reward: 166.0
Episode 16 Reward: 155.0
Episode 17 Reward: 165.0
Episode 18 Reward: 160.0
Episode 19 Reward: 156.0
Episode 20 Reward: 158.0
Episode 21 Reward: 164.0
Episode 22 Reward: 170.0
Episode 23 Reward: 154.0
Episode 24 Reward: 160.0
Episode 25 Reward: 160.0
Episode 26 Reward: 158.0
Episode 27 Reward: 150.0
Episode 28 Reward: 160.0
Episode 29 Reward: 169.0
Episode 30 Reward: 165.0
Episode 31 Reward: 151.0
Episode 32 Reward: 164.0
Episode 33 Reward: 172.0
Episode 34 Reward: 162.0
Episode 35 Reward: 160.0
Episode 36 Reward: 162.0
Episode 37 Reward: 176.0
Episode 38 Reward: 162.0
Episode 39 Reward: 175.0
Episode 40 Reward: 150.0
Episode 4