In [None]:
%pip install gymnasium
import gymnasium as gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
from time import sleep

from tqdm.notebook import tqdm
from collections import namedtuple
from collections import deque
import random

<a href="https://colab.research.google.com/github/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/rl/DQN-workbook-empty-hard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
## The Cartpole environment

Cartpole is a classic RL control task. A cart lies on a track and has a pole attached to it which can freely rotate. At each time step, the cart may choose to move left or right. The goal is to keep the pole upright for as long as possible. 

### Termination and truncation conditions
The environment terminates if the pole angle is more than 12 degrees from vertical, or if the cart position is more than 2.4 units from the center.
An episode that has not terminated is truncated after 500 time steps.

### Rewards
The agent receives a reward of +1 at each time step, as long as the pole is upright.

### Observations
At each time step, the RL agent receives the following observations:
1. Cart position
2. Cart velocity 
3. Pole angle from vertical
4. Pole angular velocity

### Action space
The agent can take one of two actions at each time step:
1. Move the cart to the left (action `0`)
2. Move the cart to the right (action `1`)

# DQN algorithm components

The DQN algorithm requires the following components, with the following roles:
1. **Value network**: A neural network that takes in states and outputs predictions for the value of each action in that state
2. **Target network**: A copy of the value network with lagged parameters that is used to compute the regression targets for the value network
3. **Replay buffer**: A buffer that stores transitions (state, action, reward, next_state, terminated) that the agent has experienced. This is used to sample mini-batches of transitions to train the value network
4. **Policies**: These use the value network to select actions in the environment. We will consider two policies:
    - **Greedy policy**: This policy selects the action with the highest value, and is used to evaluate the value network
    - **$\epsilon$-greedy policy**: This policy selects a random action with probability $\epsilon$, and the action with the highest value with probability 1-$\epsilon$. This is used to explore the environment during the training phase. 
   

## The value and target networks

We start by defining the value network architecture. This is a simple feedforward neural network (a **multi-layer perceptron**). The input to the network is the state of the environment, and the output is the value prediction *of each action*. The target network is a copy of the value network, with lagged parameters. 

### Architecture recommendations
* Use two hidden layers with 128 units each
* Use ReLU activation functions 


In [None]:
# Define the QNet class
class QNet(nn.Module):
    """Multi-layer perceptron mapping environment observations to one Q-value per action.

    Used both as the value network and as the (lagged) target network of the DQN agent.
    The input width is ``env.observation_space.shape[0]`` and the output width is
    ``env.action_space.n``.

    NOTE: the ``____`` placeholders are exercise blanks. They are undefined names, so
    instantiating the class raises ``NameError`` until they are filled in.
    """

    # Initialise the network using the size of the observation space and the number of actions
    def __init__(self, env: gym.Env):
        # Use the nn.Module's __init__ method to ensure that the parameters can be updated during training
        _____________
        # Store the size of the observation space and the number of actions
        self.obs_size = env.observation_space.shape[0]
        self.n_actions = env.action_space.n

        # Define the layers of the network.
        # Recommended architecture: obs_size -> 128 -> 128 -> n_actions, with a ReLU after
        # each hidden layer. The output is a regression of Q-values (which may be negative),
        # so the final layer is typically left without an activation.
        self.layers = nn.Sequential(________________________)

    # Define the forward method (invoked by nn.Module when the network is applied to an input)
    def ________(self, x):
        # x: a single observation or a batch of observations; returns the Q-value of every action
        return __________

## The replay buffer

To implement the replay buffer, we first define a named tuple data type to store transitions. We then define the replay buffer class, which stores transitions and can sample mini-batches of transitions. The replay buffer has the following methods:
1. **Push**: This method stores a new transition in the buffer, and removes the oldest transition if the buffer is full
2. **Sample**: This method samples a mini-batch of transitions from the buffer

In [None]:
# Named tuple describing one experience step, with the five fields below in this order.
# The replay buffer stores these records and update_weights reads them back by field name.
Transition = namedtuple(
    "Transition",
    ["state", "action", "reward", "next_state", "terminated"],
)


# Define the ReplayBuffer class
class ReplayBuffer:
    """Bounded store of ``Transition`` tuples used for experience replay.

    ``push`` appends a new transition; once ``capacity`` transitions are stored, the oldest
    one should be discarded. ``sample`` draws a mini-batch uniformly at random, without
    replacement, via ``random.sample``.

    NOTE: the ``____`` placeholders are exercise blanks. They are undefined names, so
    creating a buffer raises ``NameError`` until they are filled in.
    """

    # Initialise the buffer with a capacity.
    def __init__(self, capacity: int):
        # Use a deque object to implement the buffer, bounded by `capacity` so that the
        # oldest transition is dropped when a new one is pushed onto a full buffer
        self.buffer = ___________

    # Define the push method to add a transition to the buffer
    def push(
        self,
        state: np.ndarray,
        action: int,
        reward: float,
        next_state: np.ndarray,
        terminated: bool,
    ):
        # Bundle the fields into a Transition (same field order as its definition) and
        # append it to the buffer
        new_transition = ___________
        self.buffer._________

    # Sample a mini_batch of transitions for training
    def sample(self, mini_batch_size: int):
        # Returns a list of `mini_batch_size` transitions drawn without replacement.
        # random.sample raises ValueError if mini_batch_size exceeds the number of stored
        # transitions, so the buffer must be pre-filled before the first update.
        return random.sample(________, mini_batch_size)

## The DQN agent
We now define the DQN agent class. This class will have:
1. **A value network**, implemented using the QNet class
2. **A target network**, implemented using the QNet class
3. **A replay buffer**, implemented using the ReplayBuffer class
4. **An optimiser**, which is used to train the value network

The DQN agent will have the following methods:
1. **Greedy policy**: This method selects the action with the highest value
2. **$\epsilon$-greedy policy**: This method selects a random action with probability $\epsilon$, and the action with the highest value with probability 1-$\epsilon$
3. **Sync**: This method synchronises the target network with the value network by loading the parameters of the value network into the target network

In [None]:
# Define the DQNAgent class
class DQNAgent:
    """DQN agent: value network, lagged target network, replay buffer and optimiser.

    Args:
        env: environment whose observation size and number of actions are used to size the networks.
        buffer_capacity: maximum number of transitions kept in the replay buffer.
        gamma: discount factor applied to the next-state value in the regression targets.
        epsilon: exploration rate, i.e. the probability that ``epsilon_greedy_policy`` picks a
            random action instead of the greedy one.
        lr: learning rate of the value network's optimiser.
        mini_batch_size: number of transitions sampled for each gradient update.

    NOTE: the ``____`` placeholders are exercise blanks. They are undefined names, so
    constructing an agent raises ``NameError`` until they are filled in.
    """

    def __init__(
        self,
        env: gym.Env,
        buffer_capacity: int = 200000,
        gamma: float = 0.95,
        epsilon: float = 0.05,
        lr: float = 0.003,
        mini_batch_size: int = 32,
    ):
        # Store the hyperparameters
        self.buffer_capacity = buffer_capacity
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.mini_batch_size = mini_batch_size

        # Save the number of actions and the observation size
        self.n_actions = env.action_space.n
        self.obs_size = env.observation_space.shape[0]

        # Define the value network for the agent (the network that is trained).
        self.value_network = ______
        # Define the target network for the agent: a separate network with the same
        # architecture, used only to compute the regression targets.
        self.target_network = ______
        # Sync the networks so that both start from the same weights
        _____________
        # Define the replay buffer, holding at most `buffer_capacity` transitions
        self.replay_buffer = _________________
        # Define the optimizer over the value network's parameters with learning rate `lr`
        # (the target network is never optimised directly; it only changes through `sync`)
        self.optimiser = _______________________

    # Define the sync method to sync the target network with the value network
    def sync(self):
        """Copy the value network's parameters into the target network (a hard update)."""
        # First, retrieve the state dictionary of the value network
        value_network_state_dict = ______________
        # Load the state dictionary into the target network
        _________________________

    # Define the greedy policy
    def greedy_policy(self, state) -> int:
        """Return the action with the highest Q-value under the value network, as a Python int.

        Used for evaluation, and by the epsilon-greedy policy when it does not explore.
        """
        # Enter no-gradient mode (since we are not training the network when we sample actions)
        with _____________:
            # Convert the state to a tensor
            state = torch.tensor(state, dtype=torch.float32)
            # Compute the Q values for the state using the value network
            # NOTE(review): assumes `state` is a single unbatched observation, so q_values is a
            # 1-D tensor with one entry per action — confirm if batched states are ever passed
            q_values = ____________
            # Find the action with the highest Q value, converting it to an integer
            max_action = _____________.item()
            # Return the action
            return max_action

    # Define the epsilon greedy policy
    def epsilon_greedy_policy(self, state) -> int:
        """Explore with probability ``self.epsilon``, otherwise act greedily.

        When exploring, the result should be a random action index in ``[0, n_actions)``;
        otherwise it is the action chosen by ``greedy_policy(state)``. This is the policy used
        to collect transitions during training.
        """
        # Sample a random number uniformly between 0 and 1
        rand_num = np.random.random()

        # If the random number is less than the exploration rate, choose a random action
        if ______________:
            # Choose a random action
            action = np.random.randint(___________)
            # Return the action
            return action

        # Otherwise, choose the action with the highest action-value
        else:
            # Use the greedy policy to choose the action
            action = ___________
            # Return the action
            return action

## Interacting with the environment
During training, the DQN agent must interact with the environment in order to collect transitions to be added to the replay buffer. We define an **interact** function that takes in the DQN agent, the environment, the current state, and the number of time steps to interact for. This function uses the $\epsilon$-greedy policy to select actions, stores the transitions in the replay buffer, and returns the state from which the next call should continue.

In [None]:
# Define the interact method
# This method takes in an agent, an environment, the current state, and the number of steps to interact for.
# It returns the final state after the interaction.
def interact(agent, env: gym.Env, current_state: np.ndarray, n_steps: int) -> np.ndarray:
    """Step ``env`` for ``n_steps`` steps with the agent's epsilon-greedy policy.

    Each step's transition is meant to be pushed into the agent's replay buffer. Episodes
    that terminate or are truncated are reset and interaction continues, so exactly
    ``n_steps`` environment steps are taken regardless of episode boundaries.

    Args:
        agent: DQN agent providing the epsilon-greedy policy and the replay buffer.
        env: environment to step; it is reset whenever an episode ends.
        current_state: observation to start acting from (e.g. the result of ``env.reset()``).
        n_steps: number of environment steps to take.

    Returns:
        The observation to continue from on the next call: the latest next state or, if the
        last step ended an episode, the first observation of a fresh episode.

    NOTE: the ``____`` placeholders are exercise blanks; calling the function raises
    ``NameError`` until they are filled in.
    """
    # Loop over the steps
    for _ in range(n_steps):
        # Choose an action using the epsilon greedy policy
        action = ________________(current_state)
        # Take a step in the environment using the action
        next_state, reward, terminated, truncated, info = env.step(action)
        # Push the transition to the replay buffer.
        # Store `terminated` (not `truncated`) as the terminal flag: a time-limit truncation
        # does not make next_state terminal, so its regression target should still bootstrap.
        _______________________
        # Reset the environment if the episode is terminated or truncated
        if terminated or truncated:
            current_state, _ = env.reset()
        else:
            # Update the current state
            current_state = next_state
    return current_state

## Training the DQN agent

We define an **update_weights** function that updates the weights of the value network of a DQN agent. This function will:
1. Sample a mini-batch of transitions from the replay buffer
2. Extract the states, actions, rewards, next_states, and terminated flags from the transitions
3. Use the target network, rewards, and terminated flags to compute the regression targets for the value network 
4. Compute the loss between the value network predictions and the regression targets
5. Run backpropagation and update the weights of the value network using that loss

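As a reminder, the regression targets for $Q(s,a)$ are given by:
$$ y = r + \gamma (1-d) \max_{a'} Q_{\text{target}}(s',a')$$
where $Q_{\text{target}}$ is the target network, and $d = 1$ if the transition ended the episode by termination (not truncation) and $d = 0$ otherwise.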

In [None]:
# This method updates the weights of the value-network of the agent using a mini-batch of transitions
def update_weights(agent: DQNAgent):
    """Perform one DQN gradient step on the agent's value network.

    Samples ``agent.mini_batch_size`` transitions from the replay buffer and builds the
    regression targets ``y = r + gamma * (1 - d) * max_a' Q_target(s', a')`` with the target
    network, without gradients. It then takes one optimiser step on the mean squared error
    between the value network's ``Q(s, a)`` for the taken actions and ``y``.

    Args:
        agent: the DQN agent whose value network is trained in place.

    NOTE: the ``____`` placeholders are exercise blanks; calling the function raises
    ``NameError`` until they are filled in.
    """
    # Sample a mini-batch of transitions from the replay buffer
    mini_batch = ___________________________

    # Extract the mini-batch of states as float32 tensors.
    # The per-transition arrays are stacked into one ndarray first. Building a tensor directly
    # from a Python list of ndarrays is very slow (PyTorch warns about it), and this runs on
    # every update.
    states = torch.tensor(
        np.array([transition.state for transition in mini_batch]), dtype=torch.float32
    )
    # Extract the mini-batch of actions as int64 tensors (int64 is what index ops expect)
    actions = torch.tensor([transition.action for transition in mini_batch], dtype=torch.int64)
    # Extract the mini-batch of rewards as float32 tensors
    rewards = torch.tensor([transition.reward for transition in mini_batch], dtype=torch.float32)
    # Extract the mini-batch of next states as float32 tensors (stacked for the same reason as states)
    next_states = torch.tensor(
        np.array([transition.next_state for transition in mini_batch]), dtype=torch.float32
    )
    # Extract the mini-batch of terminated flags as bool tensors (usable as a boolean mask)
    terminated = torch.tensor(
        [transition.terminated for transition in mini_batch], dtype=torch.bool
    )

    # Enter no-gradient mode: the regression targets are treated as constants for this update
    with ____________:
        # Compute the next state values using the target network
        next_state_values = agent._________(next_states)
        # Take the max over the actions (dim=1)
        max_next_values = _____________.max(dim=1)[0]
        # Zero out the max-next-values for the terminal states (no bootstrapping past termination)
        max_next_values[_______] = ________
        # Compute the regression targets y = r + gamma * (1 - d) * max_a' Q_target(s', a')
        regression_targets = ________________

    # Compute the Q values for all actions (with gradients, since this is the network being trained)
    q_values = agent.__________(states)
    # Compute the Q values for the actions that were taken
    q_SA = torch.zeros(agent.mini_batch_size)
    for i in range(agent.mini_batch_size):
        q_SA[i] = ____________
    # Compute the loss using the mean squared error
    loss = _________________

    # Zero the gradients using the optimiser
    ___________
    # Compute the gradients of the loss through backpropagation
    ___________
    # Take a step with the optimiser
    ___________

## The training loop for the agent
We now put everything together into a training loop for the DQN agent. This loop will:
1. Interact with the environment for a number of time steps
2. Update the weights of the value network
3. Synchronize the target network with the value network at regular intervals

In [None]:
# Define the train_loop method
def train_loop(
    agent: DQNAgent,
    env: gym.Env,
    interactions_per_update: int,
    update_steps: int,
    sync_delay: int,
    warmup_steps: int = 10000,
):
    """Train ``agent`` with DQN on ``env``.

    The function first fills the replay buffer with ``warmup_steps`` epsilon-greedy
    interactions. It then repeats ``update_steps`` times:
    - collect ``interactions_per_update`` more interactions;
    - take one gradient step on the value network;
    - every ``sync_delay`` steps, copy the value network's weights into the target network.

    Args:
        agent: the DQN agent to train in place.
        env: environment to collect transitions from; it is reset once at the start.
        interactions_per_update: environment steps collected before each gradient update.
        update_steps: number of gradient updates to perform.
        sync_delay: number of updates between target-network syncs.
        warmup_steps: environment steps collected before the first update. It should be large
            enough that the buffer holds at least ``agent.mini_batch_size`` transitions when
            the first mini-batch is sampled.

    NOTE: the ``____`` placeholders are exercise blanks; calling the function raises
    ``NameError`` until they are filled in.
    """
    # Initialise the environment
    state, _ = env.reset()

    # Gather initial interactions for the experience replay buffer, so that the first
    # mini-batch can be sampled
    state = interact(agent, env, state, warmup_steps)

    # Loop over the update_steps
    for step in tqdm(range(update_steps)):
        # Interact with the environment (interactions_per_update steps, continuing from `state`)
        state = ____________________________

        # Update the weights of the value network
        ____________

        # Sync the networks every sync_delay steps
        if _____________________:
            ___________

## Some helper functions
We define some helper functions to:
1. Evaluate the agent's performance
2. Visualise the agent's performance

In [None]:
# Define the evaluate function
def evaluate(agent: DQNAgent, env: gym.Env, n_episodes: int) -> float:
    """Play ``n_episodes`` full episodes with the agent's greedy policy and average the returns.

    Args:
        agent: agent whose ``greedy_policy`` chooses every action.
        env: environment to roll out in; it is reset at the start of each episode.
        n_episodes: number of episodes to average over.

    Returns:
        Mean undiscounted return (sum of rewards per episode) across the episodes.
    """
    episode_returns = []

    for _episode in tqdm(range(n_episodes)):
        obs, _info = env.reset()
        total_reward = 0
        done = False
        # Roll the episode forward until it terminates or hits the time limit
        while not done:
            obs, reward, terminated, truncated, _info = env.step(agent.greedy_policy(obs))
            total_reward += reward
            done = terminated or truncated
        episode_returns.append(total_reward)

    return np.mean(episode_returns)

In [None]:
# Define the visualise function
# This displays the agent's behaviour in the environment.
def visualise(agent: DQNAgent, env: gym.Env, n_steps: int, frame_delay: float = 0.003):
    """Roll out the greedy policy for ``n_steps`` steps, then replay the rendered frames inline.

    Args:
        agent: agent whose ``greedy_policy`` chooses every action.
        env: environment to roll out in. It must have been created with
            ``render_mode="rgb_array"`` so that ``env.render()`` returns an image array.
        n_steps: number of environment steps (and frames) to record. Episodes that end
            during the rollout are reset and recording continues.
        frame_delay: pause in seconds after displaying each frame.

    Raises:
        ValueError: if ``env.render()`` returns ``None`` (no ``rgb_array`` render mode).
    """
    # Reset the environment
    state, _ = env.reset()

    # Record one rendered frame per step before acting
    frames = []
    for _ in range(n_steps):
        frame = env.render()
        if frame is None:
            raise ValueError(
                'env.render() returned None; create the environment with render_mode="rgb_array"'
            )
        frames.append(frame)

        # Take an action using the greedy policy
        action = agent.greedy_policy(state)
        next_state, _reward, terminated, truncated, _info = env.step(action)
        if terminated or truncated:
            state, _ = env.reset()
        else:
            state = next_state

    # Display the movie: each frame replaces the previous output to animate the rollout
    for frame in frames:
        clear_output(wait=True)
        fig, ax = plt.subplots()
        ax.imshow(frame)
        plt.show()
        # Close explicitly: not every backend closes figures on show(), and one open figure
        # per frame would otherwise accumulate in memory
        plt.close(fig)
        sleep(frame_delay)

# Let's gooooo

We will now train our network using the DQN algorithm and visualise the agent's performance. Have fun!

In [None]:
# Fix the random seeds so that re-running the notebook reproduces the same run
# (up to any nondeterminism in library kernels)
SEED = 0
random.seed(SEED)        # Python RNG, used by ReplayBuffer.sample
np.random.seed(SEED)     # NumPy global RNG, used by the epsilon-greedy policy
torch.manual_seed(SEED)  # PyTorch RNG, used for network weight initialisation

# Create the environment
env = gym.make("CartPole-v1", render_mode="rgb_array")
# Seed the environment's RNG once; later env.reset() calls without a seed continue from it
env.reset(seed=SEED)
# Create the agent
agent = DQNAgent(env)

In [None]:
# Evaluate the untrained agent: mean return of its greedy policy over 100 episodes
score_before_training = evaluate(agent, env, n_episodes=100)
print("Performance before training:", score_before_training)

In [None]:
# Watch 100 steps of the untrained agent's greedy policy
visualise(agent, env, n_steps=100)

In [None]:
# Train the agent: 40,000 gradient updates with 8 environment interactions collected
# before each one, re-syncing the target network every 500 updates
train_loop(
    agent,
    env,
    interactions_per_update=8,
    update_steps=40_000,
    sync_delay=500,
)

In [None]:
# Evaluate the trained agent: mean return of its greedy policy over 100 episodes
score_after_training = evaluate(agent, env, n_episodes=100)
print("Performance after training:", score_after_training)

In [None]:
# Watch 200 steps of the trained agent's greedy policy
visualise(agent, env, n_steps=200)