# Training an Agent on CartPole 🏋️‍♂️

In this notebook, we’ll train an agent using **Deep Q-Learning (DQN)** on the popular `CartPole-v1` environment from OpenAI Gym.

**Goals:**
- Understand the CartPole environment.
- Implement a Deep Q-Network agent using PyTorch.
- Train and evaluate performance.
- Visualize the learning progress.

## 1. Setup and Imports

In [None]:
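# The code below relies on the Gym >= 0.26 API: reset() returns (obs, info), step() returns five values.
!pip install "gym>=0.26" torch numpy matplotlib --quiet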

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

## 2. Define the DQN Model

In [None]:
class DQN(nn.Module):
    """Small MLP that maps a state vector to one Q-value per action."""

    def __init__(self, state_size, action_size):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_size),
        )

    def forward(self, x):
        return self.fc(x)
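
To get a feel for the size of this network, the sketch below counts its trainable parameters by hand. It assumes CartPole's 4-dimensional observation and 2 discrete actions; `linear_params`, `n_obs`, `n_actions`, and `hidden` are names introduced here for illustration, not part of PyTorch.

```python
def linear_params(in_features, out_features):
    """Parameters in one fully connected layer: a weight matrix plus a bias vector."""
    return in_features * out_features + out_features

# Assumed sizes for CartPole-v1: 4 observation values, 2 actions, 128 hidden units.
n_obs, n_actions, hidden = 4, 2, 128

layer_shapes = [(n_obs, hidden), (hidden, hidden), (hidden, n_actions)]
total_params = sum(linear_params(i, o) for i, o in layer_shapes)
print(total_params)  # 640 + 16512 + 258 = 17410
```

Once the model class is defined, `sum(p.numel() for p in DQN(4, 2).parameters())` should report the same number.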

## 3. Define Replay Memory and Helper Functions

In [None]:
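class ReplayBuffer:
    """Fixed-size FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity=10000):
        # A deque with maxlen drops the oldest transition once capacity is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, batch_size):
        # Uniform random minibatch, sampled without replacement.
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)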
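
A quick way to see how the buffer behaves once it is full: the sketch below uses a bare `deque` with the same `maxlen` mechanism, and integers as placeholders for real transitions. Both are simplifications made here for illustration.

```python
import random
from collections import deque

random.seed(0)  # fixed seed so the sampled batch is repeatable

buffer = deque(maxlen=3)        # tiny capacity to make eviction visible
for transition in range(5):     # integers stand in for (s, a, r, s', done) tuples
    buffer.append(transition)

print(list(buffer))             # [2, 3, 4]: the two oldest entries were dropped

batch = random.sample(list(buffer), 2)  # uniform sampling without replacement
print(batch)
```

Sampling uniformly from old and new transitions breaks up the correlation between consecutive steps, which is the main reason DQN trains from a buffer rather than from the latest transition only.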

## 4. Initialize Environment and Parameters

In [None]:
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

q_net = DQN(state_size, action_size).to(device)
target_net = DQN(state_size, action_size).to(device)
target_net.load_state_dict(q_net.state_dict())

optimizer = optim.Adam(q_net.parameters(), lr=1e-3)
memory = ReplayBuffer(10000)

batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_decay = 0.995
epsilon_min = 0.05
target_update = 10
episodes = 500
reward_history = []
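
It is worth checking how these exploration settings interact with the episode budget. The sketch below assumes epsilon is multiplied by `epsilon_decay` once per episode, as the training loop does, and computes where epsilon ends up after 500 episodes and how many episodes it would take to reach `epsilon_min`. The `eps_*` names are local to this example.

```python
import math

eps_start, eps_decay, eps_min, n_episodes = 1.0, 0.995, 0.05, 500

# Epsilon after n episodes of multiplicative decay, clipped at the floor.
eps_final = max(eps_min, eps_start * eps_decay ** n_episodes)

# Smallest n with eps_start * eps_decay**n <= eps_min.
episodes_to_floor = math.ceil(math.log(eps_min / eps_start) / math.log(eps_decay))

print(f"epsilon after {n_episodes} episodes: {eps_final:.3f}")     # about 0.082
print(f"episodes needed to reach the floor: {episodes_to_floor}")  # 598
```

With these values the floor of 0.05 is not reached within 500 episodes, so the agent still takes a random action roughly 8% of the time at the end of training. A faster decay (for example 0.99) is one option if less late exploration is wanted.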

## 5. Training Loop

In [None]:
def select_action(state, epsilon):
    """Epsilon-greedy: random action with probability epsilon, else the greedy one."""
    if np.random.rand() < epsilon:
        return random.randrange(action_size)
    state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        return q_net(state).argmax().item()

for ep in range(episodes):
    state, _ = env.reset()
    total_reward = 0
    done = False

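    while not done:
        action = select_action(state, epsilon)
        # Gym >= 0.26 returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated  # end the episode on failure or at the time limit
        # Store `terminated` only: a time-limit cut-off should still bootstrap.
        memory.push((state, action, reward, next_state, terminated))
        state = next_state
        total_reward += reward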

        if len(memory) >= batch_size:
            transitions = memory.sample(batch_size)
            states, actions, rewards, next_states, dones = zip(*transitions)

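            # Stack with NumPy first; building a tensor from a tuple of arrays is slow.
            states = torch.as_tensor(np.array(states), dtype=torch.float32, device=device)
            actions = torch.as_tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
            rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device).unsqueeze(1)
            next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=device)
            dones = torch.as_tensor(dones, dtype=torch.float32, device=device).unsqueeze(1)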

            q_values = q_net(states).gather(1, actions)
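            # The target network only supplies bootstrap values, so skip gradient tracking.
            with torch.no_grad():
                next_q_values = target_net(next_states).max(1)[0].unsqueeze(1)
                target_values = rewards + gamma * next_q_values * (1 - dones)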

            loss = nn.MSELoss()(q_values, target_values)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    reward_history.append(total_reward)

    if ep % target_update == 0:
        target_net.load_state_dict(q_net.state_dict())

    if (ep + 1) % 50 == 0:
        print(f"Episode {ep+1}/{episodes}, Avg Reward (last 50): {np.mean(reward_history[-50:]):.2f}, Epsilon: {epsilon:.2f}")
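
The update inside the loop regresses `Q(s, a)` toward the one-step target `r + gamma * max_a' Q_target(s', a')`, with the bootstrap term zeroed when the episode has ended. A small worked example in plain Python, using made-up numbers for illustration:

```python
gamma = 0.99

# Two hypothetical transitions: the second one ended the episode.
rewards = [1.0, 1.0]
dones = [0.0, 1.0]
next_q = [[2.0, 3.0],   # target-network Q-values for each action in s'
          [5.0, 4.0]]

targets = [
    r + gamma * max(q_row) * (1.0 - d)
    for r, q_row, d in zip(rewards, next_q, dones)
]
print(targets)  # first is about 1 + 0.99 * 3 = 3.97; second is just the reward, 1.0
```

The second target ignores the next-state values entirely, which is what the `(1 - dones)` factor achieves in the batched tensor version.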

## 6. Visualize Training Performance

In [None]:
plt.plot(reward_history)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN Training on CartPole')
plt.show()
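
Per-episode rewards are noisy, so a smoothed curve is often easier to read. The helper below (`moving_average`, a name chosen here, not part of matplotlib) is one simple way to compute a sliding-window mean.

```python
def moving_average(values, window=50):
    """Mean over a sliding window; returns len(values) - window + 1 points."""
    if window < 1:
        raise ValueError("window must be at least 1")
    if len(values) < window:
        return []
    running = sum(values[:window])
    averages = [running / window]
    for i in range(window, len(values)):
        running += values[i] - values[i - window]  # slide the window by one
        averages.append(running / window)
    return averages

print(moving_average([1, 2, 3, 4, 5], window=2))  # [1.5, 2.5, 3.5, 4.5]
```

For example, `plt.plot(range(49, len(reward_history)), moving_average(reward_history, 50))` would overlay a 50-episode average aligned with the raw curve.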

## 7. Test the Trained Agent

In [None]:
state, _ = env.reset()
done = False
total_reward = 0
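while not done:
    action = select_action(state, epsilon=0.0)
    # Gym >= 0.26 returns (obs, reward, terminated, truncated, info).
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated  # stop on failure or at the 500-step limit
    total_reward += reward
    state = next_state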
print(f'Total reward during test: {total_reward}')
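
A single test episode can be lucky or unlucky, so averaging over several gives a steadier estimate. The sketch below is one way to do that. `evaluate` and the stand-in `CountdownEnv` are hypothetical names introduced here so the block runs on its own, and the environment is only assumed to follow the Gym 0.26+ `reset`/`step` signatures.

```python
class CountdownEnv:
    """Toy stand-in environment: every episode lasts exactly `length` steps."""

    def __init__(self, length=5):
        self.length = length
        self.steps = 0

    def reset(self):
        self.steps = 0
        return 0.0, {}  # (observation, info)

    def step(self, action):
        self.steps += 1
        terminated = self.steps >= self.length
        # (observation, reward, terminated, truncated, info)
        return float(self.steps), 1.0, terminated, False, {}


def evaluate(env, policy, n_episodes=10):
    """Run `policy` for several episodes and return the mean total reward."""
    returns = []
    for _ in range(n_episodes):
        state, _ = env.reset()
        done, total = False, 0.0
        while not done:
            state, reward, terminated, truncated, _ = env.step(policy(state))
            total += reward
            done = terminated or truncated
        returns.append(total)
    return sum(returns) / len(returns)


print(evaluate(CountdownEnv(length=5), policy=lambda s: 0, n_episodes=3))  # 5.0
```

With the trained agent, this could be called as `evaluate(env, lambda s: select_action(s, epsilon=0.0), n_episodes=20)`.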

### ✅ Summary
- Implemented a **Deep Q-Network (DQN)** for CartPole.
- Used **experience replay** and **target networks** for stability.
- Visualized the training progress.

DQN is the direct basis for value-based extensions such as **Double DQN (DDQN)** and **Dueling DQN**, and a useful stepping stone toward policy-gradient methods like **PPO**.