# Cart Pole

<img src="https://gymnasium.farama.org/_images/cart_pole.gif" style="margin:auto"/>

This environment is part of the Classic Control environments; the Classic Control overview page contains general information about them.

|                   |                                                                                                                                                |
|-------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
| Action Space      | `Discrete(2)`                                                                                                                                  |
| Observation Space | `Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)` |
| import            | `gymnasium.make("CartPole-v1")`                                                                                                                |

## Description

This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson in ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problems"](https://ieeexplore.ieee.org/document/6313077). A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The pole is placed upright on the cart, and the goal is to balance it by applying forces to the cart in the left or right direction.

## Action Space

The action is an integer in `{0, 1}` (the action space is `Discrete(2)`) indicating the direction of the fixed-magnitude force with which the cart is pushed:
* 0: Push cart to the left
* 1: Push cart to the right

**Note**: The change in velocity produced by the applied force is not fixed; it depends on the angle at which the pole is pointing, because the pole's center of gravity changes the amount of energy needed to move the cart underneath it.
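
To make the note concrete, here is a sketch of a single simulation step using the frictionless cart-pole equations of motion and the constants that, to our understanding, Gymnasium's `CartPoleEnv` uses (gravity 9.8, cart mass 1.0, pole mass 0.1, pole half-length 0.5, force magnitude 10.0, explicit Euler step of 0.02 s). The function name `cartpole_step` is hypothetical; treat this as an illustration of the dynamics, not as the library's code.

```python
import math

# Constants assumed to match Gymnasium's CartPole implementation.
GRAVITY = 9.8
MASS_CART = 1.0
MASS_POLE = 0.1
TOTAL_MASS = MASS_CART + MASS_POLE
HALF_LENGTH = 0.5                      # half of the pole's length
POLE_MASS_LENGTH = MASS_POLE * HALF_LENGTH
FORCE_MAG = 10.0
TAU = 0.02                             # seconds between state updates


def cartpole_step(state, action):
    """One explicit-Euler step; state = (x, x_dot, theta, theta_dot)."""
    x, x_dot, theta, theta_dot = state
    force = FORCE_MAG if action == 1 else -FORCE_MAG  # 1 pushes right, 0 pushes left
    cos_t, sin_t = math.cos(theta), math.sin(theta)

    temp = (force + POLE_MASS_LENGTH * theta_dot**2 * sin_t) / TOTAL_MASS
    theta_acc = (GRAVITY * sin_t - cos_t * temp) / (
        HALF_LENGTH * (4.0 / 3.0 - MASS_POLE * cos_t**2 / TOTAL_MASS)
    )
    # The cart's acceleration depends on the pole angle through theta_acc and cos_t
    x_acc = temp - POLE_MASS_LENGTH * theta_acc * cos_t / TOTAL_MASS

    # Euler update: positions use the old velocities, velocities use the accelerations
    return (
        x + TAU * x_dot,
        x_dot + TAU * x_acc,
        theta + TAU * theta_dot,
        theta_dot + TAU * theta_acc,
    )


upright = cartpole_step((0.0, 0.0, 0.0, 0.0), 1)
tilted = cartpole_step((0.0, 0.0, 0.2, 0.0), 1)
print(upright[1], tilted[1])  # same push, different cart velocity: about 0.195 vs 0.192
```

The same push yields a different change in cart velocity depending on the pole angle, which is exactly the effect the note describes.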

## Observation Space

The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:

| Num | Observation           | Min                 | Max               |
|-----|-----------------------|---------------------|-------------------|
| 0   | Cart Position         | -4.8                | 4.8               |
| 1   | Cart Velocity         | -Inf                | Inf               |
| 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
| 3   | Pole Angular Velocity | -Inf                | Inf               |

**Note**: While the ranges above denote the possible values of each element of the observation space, they do not reflect the values allowed in a non-terminated episode. In particular:
* The cart x-position (index 0) can take values in `(-4.8, 4.8)`, but the episode terminates if the cart leaves the `(-2.4, 2.4)` range.
* The pole angle can be observed in `(-.418, .418)` radians (or **±24°**), but the episode terminates if the pole angle leaves the range `(-.2095, .2095)` (or **±12°**).

## Rewards

Since the goal is to keep the pole upright for as long as possible, a reward of `+1` is given for every step taken, including the termination step. The episode return is therefore the number of steps survived, capped by the time limit at 500 for v1 and 200 for v0.

## Starting State

All observations are assigned a uniformly random value in `(-0.05, 0.05)`.

## Episode End

The episode ends if any one of the following occurs:
1. Termination: the pole angle leaves the ±12° range
2. Termination: the cart position leaves the ±2.4 range (the center of the cart reaches the edge of the display)
3. Truncation: the episode length reaches 500 steps (200 for v0)
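
The three end conditions can be sketched as a small check on the state `(x, x_dot, theta, theta_dot)`. The helper `episode_status` and its constant names are hypothetical; the thresholds come from the text above. Note that the observation-space limits in the table (4.8 and about 0.418 rad) are exactly twice these termination thresholds.

```python
import math

X_THRESHOLD = 2.4                          # cart position limit
THETA_THRESHOLD = 12 * 2 * math.pi / 360   # 12 degrees in radians (about 0.2095)
MAX_EPISODE_STEPS = 500                    # CartPole-v1 time limit (200 for v0)


def episode_status(state, step_count, max_steps=MAX_EPISODE_STEPS):
    """Return (terminated, truncated) for a state reached after `step_count` steps."""
    x, _, theta, _ = state
    # Termination: the cart or the pole left its allowed range
    terminated = abs(x) > X_THRESHOLD or abs(theta) > THETA_THRESHOLD
    # Truncation: the time limit was reached, regardless of the state
    truncated = step_count >= max_steps
    return terminated, truncated


print(episode_status((0.0, 0.0, 0.25, 0.0), step_count=10))  # pole past 12 degrees
```

Keeping `terminated` and `truncated` separate matters for learning: a truncated state is not a failure, so its value should still be bootstrapped.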

## Arguments
```python
import gymnasium as gym
gym.make('CartPole-v1')
```
On reset, the `options` parameter allows the user to change the bounds used to determine the new random state.
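
A minimal numpy sketch of the reset behaviour: by default each of the four state variables is drawn uniformly from `(-0.05, 0.05)`, and bounds passed through `options` replace those defaults. To our understanding Gymnasium reads the keys `"low"` and `"high"` from `options` (e.g. `env.reset(options={"low": -0.1, "high": 0.1})`), but check this against your installed version; `sample_initial_state` is a hypothetical helper, not part of Gymnasium.

```python
import numpy as np


def sample_initial_state(rng, options=None, default_low=-0.05, default_high=0.05):
    """Hypothetical mirror of CartPole's reset: draw all 4 state variables uniformly."""
    options = options or {}
    # Assumed option keys "low" and "high"; fall back to the default bounds
    low = options.get("low", default_low)
    high = options.get("high", default_high)
    if low > high:
        raise ValueError("low must not exceed high")
    return rng.uniform(low=low, high=high, size=(4,)).astype(np.float32)


rng = np.random.default_rng(seed=0)
print(sample_initial_state(rng))                                # default (-0.05, 0.05)
print(sample_initial_state(rng, {"low": -0.2, "high": 0.2}))    # wider start distribution
```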

# Deep Q-Network (DQN) Overview

DQN is an extension of Q-Learning where a neural network is used to approximate the Q-values. It combines reinforcement learning with deep learning, allowing the agent to handle complex environments with high-dimensional state spaces.

## Key Concepts

1. **Q-Value (Quality)**: The expected discounted sum of future rewards obtained by taking an action in a given state and following the policy thereafter.
2. **Q-Network**: A neural network that takes a state as input and outputs one Q-value per action.
3. **Experience Replay**: The agent's experiences (state, action, reward, next state, done) are stored in a replay buffer and sampled randomly during training, which breaks the correlation between consecutive samples and improves learning stability.
4. **Target Network**: A separate copy of the Q-network, synchronized with it only periodically, that computes the Q-value targets; keeping the targets fixed between synchronizations stabilizes training.
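
Experience replay can be sketched with a bounded `deque` and uniform sampling, similar in spirit to the memory used by the agent further down. The class name `ReplayBuffer` and its methods are hypothetical, chosen for illustration.

```python
import random
from collections import deque


class ReplayBuffer:
    """Fixed-size FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity, seed=None):
        self.memory = deque(maxlen=capacity)  # oldest experiences are evicted first
        self.rng = random.Random(seed)

    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniformly sample a batch without replacement and regroup it by field."""
        batch = self.rng.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.memory)


buffer = ReplayBuffer(capacity=3, seed=0)
for t in range(5):
    buffer.push(t, 0, 1.0, t + 1, False)  # only the 3 most recent survive
print(len(buffer), buffer.sample(2)[0])
```

Sampling uniformly from a long history, rather than training on the latest transition, is what decorrelates consecutive updates.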

## Algorithm

1. Initialize the Q-network and target network with random weights.
2. Initialize the replay buffer.
3. For each episode:
   * Initialize the starting state.
   * For each step of the episode:
     * Choose an action using an ε-greedy policy.
     * Take the action, observe the reward and next state.
     * Store the experience in the replay buffer.
     * Sample a random batch from the replay buffer.
     * Compute the target Q-value using the target network.
     * Update the Q-network weights using gradient descent.
     * Periodically update the target network weights with the Q-network weights.
4. Repeat until the Q-values converge.
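
The loop above can be sketched end to end without a neural network by swapping the Q-network for a lookup table, so the control flow (ε-greedy action, replay buffer, targets from a frozen copy, gradient step, periodic sync) is easy to follow. Everything here is an assumption for illustration: a hypothetical 5-state chain where action 1 moves right, action 0 moves left, and reaching state 4 pays +1 and ends the episode; the hyperparameters are arbitrary.

```python
import random
from collections import deque

import numpy as np

N_STATES, N_ACTIONS, GOAL = 5, 2, 4  # hypothetical chain environment


def chain_step(state, action):
    next_state = min(state + 1, GOAL) if action == 1 else max(state - 1, 0)
    done = next_state == GOAL
    return next_state, float(done), done


def train_tabular_dqn(n_episodes=300, gamma=0.9, lr=0.1, batch_size=16,
                      sync_every=50, max_steps=50, seed=0):
    rng = random.Random(seed)
    q = np.zeros((N_STATES, N_ACTIONS))   # 1. "Q-network" (a table here)
    q_target = q.copy()                   #    and its target copy
    buffer = deque(maxlen=10_000)         # 2. replay buffer
    epsilon, updates = 1.0, 0
    for _ in range(n_episodes):           # 3. for each episode
        state = 0
        for _ in range(max_steps):
            # epsilon-greedy action choice
            if rng.random() < epsilon:
                action = rng.randrange(N_ACTIONS)
            else:
                action = int(np.argmax(q[state]))
            next_state, reward, done = chain_step(state, action)
            buffer.append((state, action, reward, next_state, done))
            if len(buffer) >= batch_size:
                for s, a, r, s2, d in rng.sample(buffer, batch_size):
                    # Target from the frozen table; no bootstrapping past terminal states
                    y = r + (0.0 if d else gamma * q_target[s2].max())
                    # Gradient step on 0.5 * (Q(s, a) - y)^2 for a tabular Q
                    q[s, a] -= lr * (q[s, a] - y)
                updates += 1
                if updates % sync_every == 0:
                    q_target = q.copy()   # periodic target-network sync
            state = next_state
            if done:
                break
        epsilon = max(0.05, epsilon * 0.99)
    return q


q = train_tabular_dqn()
print(np.argmax(q[:GOAL], axis=1))  # greedy action per non-terminal state
```

With a real DQN the table lookup becomes a network forward pass and the per-entry update becomes an optimizer step on the batch loss, but the order of operations is the same.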

## Update Rule

$$ \ell = (Q(s, a) - (r + \gamma \max_{a'} Q'(s', a')))^2 $$

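Where
* $Q$ is the Q-network and $(s, a)$ is the current state-action pair
* $Q'$ is the target network
* $r$ is the reward
* $\gamma$ is the discount factor
* $s'$ is the next state, and the maximum is taken over all actions $a'$ available in $s'$

If $s'$ is terminal, the bootstrap term is dropped and the target is just $r$.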
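
For a mini-batch, the update rule becomes the mean squared TD error. This numpy sketch uses made-up numbers; the `(1 - done)` mask mirrors the `~dones` factor in the agent's `_replay` method below. The helper names `dqn_targets` and `dqn_loss` are hypothetical.

```python
import numpy as np


def dqn_targets(rewards, next_q_target, dones, gamma):
    """y = r + gamma * max_a' Q'(s', a'), with the bootstrap term dropped when s' is terminal."""
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)
    return np.asarray(rewards, dtype=np.float64) + gamma * next_q_target.max(axis=1) * not_done


def dqn_loss(q_sa, targets):
    """Mean squared error between Q(s, a) and the (fixed) targets over the batch."""
    return float(np.mean((np.asarray(q_sa) - targets) ** 2))


rewards = np.array([1.0, 1.0])
next_q_target = np.array([[2.0, 3.0], [5.0, 4.0]])  # Q'(s', a') for each next state
dones = np.array([False, True])                     # second transition is terminal
targets = dqn_targets(rewards, next_q_target, dones, gamma=0.9)
loss = dqn_loss(np.array([3.5, 1.5]), targets)
print(targets, loss)  # targets ≈ [3.7, 1.0], loss ≈ 0.145
```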

In [None]:
from __future__ import annotations

import random
from collections import deque

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def load_environment(training=True):
    """
    Load the Cart Pole environment

    :param training: if False, render the environment on screen ("human" mode)
    :return: environment, state-space size, number of actions
    """
    env = gym.make(
        "CartPole-v1",
        render_mode=None if training else "human",
    )
    return env, env.observation_space.shape[0], env.action_space.n


def epsilon_decay(epsilon, epsilon_min, decay_method: callable):
    """
    Decay the epsilon value
    
    :param epsilon: epsilon value
    :param epsilon_min: minimum epsilon value
    :param decay_method: decay method (e.g., lambda e: e * 0.999)
    :return: decayed epsilon value
    """
    return max(epsilon_min, decay_method(epsilon))


def epsilon_greedy(env, policy_dqn, state, epsilon, device):
    """
    Choose an action using an epsilon-greedy policy
    
    :param env: the environment
    :param policy_dqn: the policy DQN
    :param state: the current state
    :param epsilon: the epsilon value
    :return: the chosen action
    """
    if random.random() < epsilon:
        # Explore
        return env.action_space.sample()
    else:
        # Exploit
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            q_values = policy_dqn(state)
        return q_values.argmax().item()

In [None]:
class DeepQNetwork(nn.Module):
    """
    Deep Q-Network (DQN) class
    """

    def __init__(self, state_space: int, action_space: int):
        super().__init__()

        # MLP: state -> 512 -> 256 -> 64 -> one Q-value per action
        self.fc1 = nn.Linear(state_space, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, action_space)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
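
As a quick size check on `DeepQNetwork`: each `nn.Linear(n_in, n_out)` holds `n_in * n_out` weights plus `n_out` biases. The helper `mlp_param_count` is hypothetical; for an instantiated PyTorch model, `sum(p.numel() for p in model.parameters())` should give the same count.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# CartPole: 4 observations in, 2 actions out, hidden widths as in DeepQNetwork
print(mlp_param_count([4, 512, 256, 64, 2]))  # 150466
```

Most of the roughly 150k parameters sit in the 512 to 256 layer, which is large for a 4-dimensional input.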

In [None]:
class DQNCartPoleAgent:
    """
    Cart Pole agent using Deep Q-Network (DQN)
    """
    env: gym.Env
    state_space: int
    action_space: int

    def __init__(self):
        self.env, self.state_space, self.action_space = load_environment()

        # Set the device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device: {self.device}")

        # Initialize the Q-network and the target network Q' (a periodically synced copy)
        self.dqn = DeepQNetwork(self.state_space, self.action_space).to(self.device)
        self.target_dqn = DeepQNetwork(self.state_space, self.action_space).to(self.device)
        self._sync_target()

        # Initialize training parameters
        self.optimizer = optim.RMSprop
        self.criterion = nn.MSELoss()
        self.memory = deque(maxlen=10000)

    def _sync_target(self):
        """
        Copy the Q-network weights into the target network

        :return: None
        """
        self.target_dqn.load_state_dict(self.dqn.state_dict())

    def _remember(self, state, action, reward, next_state, done):
        """
        Store the experience in the replay buffer

        :param state: the current state
        :param action: the chosen action
        :param reward: the received reward
        :param next_state: the next state
        :param done: whether next_state is terminal
        :return: None
        """
        self.memory.append((state, action, reward, next_state, done))

    def _act(self, state, epsilon):
        """
        Choose an action using an epsilon-greedy policy

        :param state: the current state
        :param epsilon: the current epsilon value
        :return: the chosen action
        """
        if np.random.uniform() < epsilon:
            return self.env.action_space.sample()
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            q_values = self.dqn(state)
        return torch.argmax(q_values).item()

    def _replay(self, batch_size, gamma, optimizer):
        """
        Train the DQN using experience replay

        :param batch_size: the size of the mini-batch
        :param gamma: the discount factor
        :param optimizer: the optimizer to use
        :return: None
        """
        if len(self.memory) < batch_size:
            return  # Not enough memory to train

        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack into arrays first: building tensors from a list of ndarrays is slow
        states = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(np.array(dones), dtype=torch.bool, device=self.device)

        # Q(s, a) from the Q-network for the actions actually taken
        q_values = self.dqn(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # max_a' Q'(s', a') from the target network; no bootstrapping past terminal states
        with torch.no_grad():
            next_q_values = self.target_dqn(next_states).max(1).values
        target_q_values = rewards + gamma * next_q_values * (~dones).float()

        # Compute the loss L = ( Q(s, a) - ( r + gamma max_a' Q'(s', a') ) )^2
        loss = self.criterion(q_values, target_q_values)

        # Update the DQN
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def load(self, path):
        """
        Load the DQN model weights

        :param path: the path to the model weights
        :return: None
        """
        # map_location lets weights saved on a GPU be loaded on a CPU-only machine
        self.dqn.load_state_dict(torch.load(path, map_location=self.device))
        self._sync_target()

    def save(self, path):
        """
        Save the DQN model weights

        :param path: the path to save the model weights
        :return: None
        """
        torch.save(self.dqn.state_dict(), path)

    def train(self,
              lr, gamma, epsilon, epsilon_min,
              n_episodes, batch_size,
              train_start, target_sync_every=500):
        """
        Train the DQN

        :param target_sync_every: environment steps between target-network syncs
        :return: list of per-episode (shaped) rewards
        """
        optimizer = self.optimizer(self.dqn.parameters(), lr=lr)

        rewards = []
        episode_reward = 0
        steps = 0
        for episode in range(n_episodes):
            state, _ = self.env.reset()
            done = False

            episode_reward = 0
            while not done:
                # Choose an action using an epsilon-greedy policy
                action = self._act(state, epsilon)

                # Take the action, observe the reward and next state
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                # The episode ends on termination or on the time-limit truncation
                done = terminated or truncated

                # Penalize the agent for ending the episode early
                if terminated:
                    reward = -100

                # Store `terminated`, not `done`: a truncated state is not terminal,
                # so its target should still bootstrap from Q'(s', a')
                self._remember(state, action, reward, next_state, terminated)

                # Update the DQN using experience replay
                self._replay(batch_size, gamma, optimizer)

                # Periodically copy the Q-network weights into the target network
                steps += 1
                if steps % target_sync_every == 0:
                    self._sync_target()

                state = next_state
                episode_reward += reward

            # Perform epsilon decay
            if epsilon > epsilon_min and len(self.memory) >= train_start:
                epsilon = epsilon_decay(epsilon, epsilon_min, lambda e: e * 0.999)

            rewards.append(episode_reward)

            if episode % 100 == 0:
                print(f"Episode: {episode:5d}, Epsilon: {epsilon:.2f}, Reward: {episode_reward:4.2f}")

            # Stop early once the agent balances the pole for the whole time limit
            if truncated and not terminated:
                print(f"Reached the time limit at episode {episode}")
                break

        print("Training complete with episode reward:", episode_reward)
        return rewards

    def test(self, truncate):
        """
        Run one greedy episode with on-screen rendering

        :param truncate: maximum number of steps to run
        :return: None
        """
        env, _, _ = load_environment(training=False)

        state, _ = env.reset()
        done = False
        rewards = 0
        while not done:
            # render_mode="human" renders automatically on every step
            action = self._act(state, 0)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            rewards += reward

            if rewards >= truncate:
                print(f"Truncate at {truncate} steps")
                break

            if rewards % max(1, truncate // 10) == 0:
                print(f"Reached {rewards} steps")

        print("Test complete with reward:", rewards)
        env.close()

    def close(self):
        self.env.close()

In [None]:
agent = DQNCartPoleAgent()

rewards = agent.train(
    lr=0.00025,
    gamma=0.95,
    epsilon=1.0,
    epsilon_min=0.001,
    n_episodes=2000,
    batch_size=64,
    train_start=1000
)

# Save the model, creating the output directory if needed
import os

os.makedirs("output", exist_ok=True)
agent.save("output/cart_pole_dqn.pth")

# Plot the rewards
import matplotlib.pyplot as plt

plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("Cart Pole with DQN")
plt.savefig("output/cart_pole_rewards.png")

agent.close()
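
In `train`, epsilon is multiplied by 0.999 at most once per episode (and only after the replay memory holds `train_start` transitions), so with `n_episodes=2000` it cannot get close to `epsilon_min=0.001`. A quick arithmetic check of that schedule; the helper names are hypothetical.

```python
import math


def epsilon_after(n_decays, epsilon_start=1.0, epsilon_min=0.001, factor=0.999):
    """Epsilon after `n_decays` multiplicative decays, floored at epsilon_min."""
    return max(epsilon_min, epsilon_start * factor ** n_decays)


def decays_to_reach_min(epsilon_start=1.0, epsilon_min=0.001, factor=0.999):
    """Smallest n with epsilon_start * factor**n <= epsilon_min."""
    return math.ceil(math.log(epsilon_min / epsilon_start) / math.log(factor))


print(round(epsilon_after(2000), 3))  # about 0.135
print(decays_to_reach_min())          # 6905
```

So after 2000 episodes the agent still explores at least roughly 13% of the time; a faster decay factor or more episodes would be needed for epsilon to approach its minimum.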

In [None]:
agent = DQNCartPoleAgent()

agent.load("output/cart_pole_dqn.pth")
agent.test(200)
agent.close()

<img src="./output/cart_pole_rewards.png" style="margin:auto"/>

<br>

<img src="./output/result.gif" style="margin:auto"/>