#### **This notebook is based on [Nimish Sanghi's book on Deep RL](https://link.springer.com/book/10.1007/979-8-8688-0273-7).**

This notebook builds on the basic Deep Q-Networks (DQNs) notebook: https://www.kaggle.com/code/aryamanbansal/basic-dqn

Feel free to check out this notebook on Kaggle: https://www.kaggle.com/code/aryamanbansal/

### **Motivation**

In the original Deep Q-Network (DQN) algorithm, the action-value function is approximated using a neural network. However, a well-known issue with this approach is **overestimation bias**. 

When the max operator is used to select the best action during the target value computation, it tends to overestimate the Q-values, leading to suboptimal policies. **Double DQN (DDQN)** was proposed to mitigate this problem by decoupling the action selection and evaluation, thereby reducing the overoptimistic estimates.

### **Comparative Study: DDQN vs. Basic DQN**

- **Basic DQN:**
    - Pros: Simple to implement and effective in many scenarios.
    - Cons: Tends to overestimate action values due to the max operator.
- **Double DQN:**
    - Pros: Reduces overestimation bias by decoupling the action selection from evaluation.
    - Cons: Slightly more complex to implement since it requires two networks (or at least two forward passes).

In [None]:
import gymnasium as gym
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.signal import convolve
from scipy.signal.windows import gaussian

from base64 import b64encode
from IPython.display import HTML, clear_output

from tqdm import trange

print("imports done!")

imports done!


In [None]:
# Seed Python's `random`, NumPy and PyTorch with one shared value.
# `seed` is also passed to `env.reset(seed=seed)` further down the notebook.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

  and should_run_async(code)


<torch._C.Generator at 0x786050e12d90>

In [None]:
# Global compute device: CUDA when PyTorch reports it as available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression so the notebook cell displays the selected device.
device

device(type='cpu')

### **Neural Network for DQN**

In [None]:
class DQNAgent(nn.Module):
    """Feed-forward Q-network with an epsilon-greedy action sampler."""

    def __init__(self, state_shape, n_actions, epsilon=0.1):
        """
        Initializes the DQNAgent.

        Args:
            state_shape (tuple): Shape of the input state; only state_shape[0] is used
                as the input width of the first linear layer.
            n_actions (int): Number of possible actions (width of the output layer).
            epsilon (float): Exploration rate for epsilon-greedy policy.
        """
        super().__init__()
        self.epsilon = epsilon
        self.n_actions = n_actions
        self.state_shape = state_shape

        # Define a simple feedforward neural network
        self.network = nn.Sequential(
            nn.Linear(state_shape[0], 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, n_actions)
        )

    def _as_state_tensor(self, states):
        """Convert a batch of states to a float32 tensor on the network's own device."""
        # Use the device the parameters live on instead of a module-level global,
        # so the agent keeps working after `.to(...)` and outside the notebook.
        param_device = next(self.parameters()).device
        return torch.tensor(np.array(states), device=param_device, dtype=torch.float32)

    def forward(self, state_t):
        """
        Pass the state at time t through the network to get the Q-values Q(s, .).

        Args:
            state_t (torch.Tensor): The state (or batch of states) at time t.

        Returns:
            torch.Tensor: Q-values for each action.
        """
        qvalues = self.network(state_t)
        return qvalues

    def get_qvalues(self, states):
        """
        Compute Q-values for a batch of states provided as arrays.

        Args:
            states (array-like): Batch of states.

        Returns:
            np.array: Q-values for each state.
        """
        # Inference only: no autograd graph is needed for acting.
        with torch.no_grad():
            qvalues = self.forward(self._as_state_tensor(states))
        return qvalues.cpu().numpy()

    def get_action(self, states):
        """
        Returns the best (greedy) actions for a batch of states.

        Args:
            states (array-like): Batch of states.

        Returns:
            torch.Tensor: Index of the highest-valued action for each state
            (a tensor, not a NumPy array; callers convert it themselves).
        """
        with torch.no_grad():
            qvalues = self.forward(self._as_state_tensor(states))
        return qvalues.argmax(dim=-1)

    def sample_actions(self, qvalues):
        """
        Implements the epsilon-greedy policy on a batch of Q-values.

        Args:
            qvalues (np.array): Q-values for a batch of states, shape (batch, n_actions).

        Returns:
            np.array: Actions selected (random with probability epsilon, otherwise best action).
        """
        epsilon = self.epsilon
        batch_size, n_actions = qvalues.shape
        # Candidate random actions for exploration
        random_actions = np.random.choice(n_actions, size=batch_size)
        # Greedy actions based on Q-values
        best_actions = qvalues.argmax(axis=-1)
        # Per-sample flag: 1 (probability epsilon) means explore, 0 means exploit
        should_explore = np.random.choice([0, 1], batch_size, p=[1 - epsilon, epsilon])
        # Take the random action where the flag is 1, otherwise the greedy action
        return np.where(should_explore, random_actions, best_actions)

    def save(self, path):
        """
        Saves the model parameters to a file.

        Only the inner `network` state dict is written, and ".zip" is appended to `path`.

        Args:
            path (str): Path prefix to save the model to.
        """
        print("Saving model to:", path)
        torch.save(self.network.state_dict(), f"{path}.zip")



### **Replay Buffer**

In [None]:
class ReplayBuffer:
    """Fixed-capacity store of transitions that overwrites its oldest entry once full."""

    def __init__(self, size):
        """
        Create an empty buffer.

        Args:
            size (int): Capacity, i.e. the largest number of transitions kept at once.
        """
        self.size = size
        self.buffer = []   # stored transitions, at most `size` of them
        self.next_id = 0   # slot that the next overwrite will target

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)

    def add(self, state, action, reward, next_state, done):
        """
        Store one transition, replacing the oldest one when the buffer is full.

        Args:
            state (object): State the action was taken in.
            action (int): Action that was taken.
            reward (float): Reward that was received.
            next_state (object): State reached after the action.
            done (bool): Episode-end flag recorded with the transition.
        """
        transition = (state, action, reward, next_state, done)

        if len(self.buffer) >= self.size:
            # Full: recycle the slot the write cursor points at.
            self.buffer[self.next_id] = transition
        else:
            # Still growing: the new transition goes at the end.
            self.buffer.append(transition)

        # Advance the write cursor, wrapping around at capacity.
        self.next_id = (self.next_id + 1) % self.size

    def sample(self, batch_size):
        """
        Draw `batch_size` distinct transitions uniformly at random.

        Args:
            batch_size (int): How many transitions to draw (without replacement).

        Returns:
            A tuple of numpy arrays: (states, actions, rewards, next_states, done_flags)
        """
        chosen = np.random.choice(len(self.buffer), batch_size, replace=False)
        # Transpose the list of transitions into five per-field columns.
        states, actions, rewards, next_states, done_flags = zip(
            *(self.buffer[i] for i in chosen)
        )
        columns = (states, actions, rewards, next_states, done_flags)
        return tuple(np.array(column) for column in columns)

### **TD Loss for Double DQN**

1. **Standard DQN Target Calculation**: In basic DQN, the target is computed as:

$$y = r + \gamma \max\limits_{a'} Q_{\text{target}}(s', a')$$

The max operation can lead to an overestimation because it picks the highest value even if it is overestimated.

&nbsp;  

2. **Double DQN Target Calculation**: To reduce overestimation, DDQN decouples the selection and evaluation steps:

- **Action selection**:  $$a^* = \operatorname*{arg\,max}\limits_{a'} \; Q_{\text{online}}(s', a')$$
- **Action evaluation** (TD target $y$): $$y = r + \gamma \, Q_{\text{target}}(s', a^*)$$

This approach uses the online network to decide which action to take (action selection) and the target network to estimate its Q-value (action evaluation).

&nbsp;  

3. **Intuitive Analogy**: Imagine you are trying to decide the best restaurant to visit. Instead of asking the same person (who might be biased by recent positive experiences) for both the recommendation and the review, you ask two different people:
- One person (the online network) suggests the restaurant.
- Another (the target network) provides a more impartial review. This separation helps reduce the chance of both recommendation and review being overly optimistic.

In [None]:
def td_loss_ddqn(agent, target_network, states, actions, rewards, next_states, done_flags,
                 gamma=0.99, device=torch.device("cpu")):
    """
    Computes the Temporal Difference (TD) loss for Double DQN (DDQN) and returns the mean squared error (MSE) loss.

    In Double DQN, the online network (agent) is used to select the best next action, while the target network
    is used to evaluate the Q-value of that action. The whole TD target is computed without gradient tracking,
    so back-propagating the loss only produces gradients for the online network.

    Args:
        agent (torch.nn.Module): The online Q-network used for action selection and current Q-value estimation.
        target_network (torch.nn.Module): The target Q-network used for evaluating the Q-value of the selected action.
        states (array-like): Batch of current state observations.
        actions (array-like): Batch of actions taken corresponding to the states.
        rewards (array-like): Batch of rewards received after taking the actions.
        next_states (array-like): Batch of next state observations following the actions.
        done_flags (array-like): Batch of flags indicating terminal states (1/True if episode ended, 0/False otherwise).
        gamma (float, optional): Discount factor for future rewards. Default is 0.99.
        device (torch.device, optional): Device on which to perform the tensor computations (CPU/GPU). Default is CPU.

    Returns:
        torch.Tensor: The computed mean squared error loss (scalar).
    """
    # Convert the inputs to tensors on the requested device.
    # np.asarray makes plain lists / bool arrays acceptable as well as ndarrays.
    states = torch.as_tensor(np.asarray(states), device=device, dtype=torch.float)
    actions = torch.as_tensor(np.asarray(actions), device=device, dtype=torch.long)
    rewards = torch.as_tensor(np.asarray(rewards), device=device, dtype=torch.float)
    next_states = torch.as_tensor(np.asarray(next_states), device=device, dtype=torch.float)
    done_flags = torch.as_tensor(np.asarray(done_flags, dtype=np.float32), device=device, dtype=torch.float)

    # Q(s, a) of the online network for the actions that were actually taken.
    q_s = agent(states)  # Shape: [batch_size, num_actions]
    q_s_a = q_s.gather(1, actions.unsqueeze(1)).squeeze(1)  # Shape: [batch_size]

    # The TD target is a fixed regression target: build it without autograd so that
    # no gradient flows into either network through the next-state branch.
    with torch.no_grad():
        # Action selection: the online network picks the best next action.
        a1max = agent(next_states).argmax(dim=1)
        # Action evaluation: the target network scores that action.
        q_s1_a1max = target_network(next_states).gather(1, a1max.unsqueeze(1)).squeeze(1)
        # Non-terminal: r + gamma * Q_target(s', a*); terminal: r only.
        target_q = rewards + gamma * q_s1_a1max * (1 - done_flags)

    # Mean squared TD error between the prediction and the fixed target.
    loss = torch.mean((q_s_a - target_q).pow(2))

    return loss


### **Recording video of trained agents**



In [None]:
def record_video(env_id, video_folder, video_length, agent):
    """
    Record a video of the agent interacting with the environment.

    Args:
        env_id (str): Environment ID (e.g., 'CartPole-v1').
        video_folder (str): Folder where the video will be saved (with or without a trailing separator).
        video_length (int): Number of timesteps to record.
        agent: Trained agent with a get_action() method that returns a torch tensor of actions.

    Returns:
        str: The file path to the saved video, i.e. `video_folder` joined with the recorder's file name.
    """
    # Local import: keeps this cell self-contained without touching the notebook's import cell.
    import os

    # Create a dummy vectorized environment with rendering enabled.
    vec_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array")])

    # Wrap the environment with VecVideoRecorder; recording is triggered at step 0 only.
    vec_env = VecVideoRecorder(
        vec_env, video_folder,
        record_video_trigger=lambda x: x == 0,
        video_length=video_length,
        name_prefix=f"{type(agent).__name__}-{env_id}"
    )

    try:
        # Reset environment to start recording.
        obs = vec_env.reset()
        for _ in range(video_length + 1):
            # Get action from the agent and step the environment.
            action = agent.get_action(obs).detach().cpu().numpy()
            obs, _, _, _ = vec_env.step(action)

        # Read the file name before closing the recorder.
        # NOTE(review): relies on the `video_recorder.path` attribute of the installed
        # Stable-Baselines3 version -- confirm when upgrading SB3.
        file_name = os.path.basename(vec_env.video_recorder.path)
    finally:
        # Always release the environment and recorder, even if a step fails.
        vec_env.close()

    # os.path.join copes with a missing trailing slash and with absolute folders,
    # which plain string concatenation did not.
    return os.path.join(video_folder, file_name)


In [None]:
def play_video(file_path):
    """
    Display a video file in a Jupyter Notebook.

    The file is read fully into memory and embedded in the page as a base64 data URL.

    Args:
        file_path (str): Path to the video file.

    Returns:
        HTML: HTML object that can display the video.
    """
    # Read the video in binary mode; the context manager closes the handle right away.
    with open(file_path, 'rb') as video_file:
        mp4 = video_file.read()
    # Encode the video file in base64.
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    # Create HTML snippet with a video player.
    return HTML(f"""
        <video width=400 controls>
            <source src="{data_url}" type="video/mp4">
        </video>
        """)


### **Setting up the training parameters**



In [None]:
def play_and_record(start_state, agent, env, exp_replay, n_steps=1):
    """
    Run the agent in the environment for a fixed number of steps and record the transitions.

    This function allows the agent to interact with the environment for `n_steps` timesteps,
    collects the transitions (state, action, reward, next state, terminal flag), and stores them
    in the experience replay buffer. It also accumulates the total reward obtained during these steps.

    The flag stored with each transition is `terminated` only. A time-limit truncation ends the
    episode (the environment is reset) but is not a terminal state, so the TD target for such a
    transition must still bootstrap from the next state.

    Args:
        start_state: The initial state from which the agent starts.
        agent: The DQN agent that provides action selection via its get_qvalues and sample_actions methods.
        env: The environment in which the agent is acting (should follow the Gymnasium API).
        exp_replay: The experience replay buffer (an instance of ReplayBuffer) to store transitions.
        n_steps (int, optional): The number of steps to run the agent in the environment. Defaults to 1.

    Returns:
        tuple: A tuple (sum_rewards, s) where:
            - sum_rewards (float): The total reward accumulated over the n_steps.
            - s: The state at the end of the n_steps, which can be used as the starting state for subsequent calls.
    """
    s = start_state          # Current state.
    sum_rewards = 0          # Reward accumulator.

    for _ in range(n_steps):
        # Obtain Q-values for the current state (batch of one).
        qvalues = agent.get_qvalues([s])

        # Select an action using the agent's epsilon-greedy policy.
        a = agent.sample_actions(qvalues)[0]

        # Execute the action in the environment.
        next_s, r, terminated, truncated, _ = env.step(a)
        sum_rewards += r

        # Record the transition; only true termination cuts off bootstrapping.
        exp_replay.add(s, a, r, next_s, terminated)

        # The episode is over on either signal: start a new one, otherwise move on.
        if terminated or truncated:
            s, _ = env.reset()
        else:
            s = next_s

    return sum_rewards, s

In [None]:
# Build the training environment and read the problem dimensions from its spaces
env_name = 'CartPole-v1'
env = gym.make(
    env_name,
    render_mode="rgb_array",
    max_episode_steps=4000,
)
state_dim = env.observation_space.shape  # observation shape, e.g. (4,) for CartPole
n_actions = env.action_space.n           # number of discrete actions, e.g. 2 for CartPole

In [None]:
# Reset environment and set seed for reproducibility
state, _ = env.reset(seed=seed)

# Initialize DQN agent with initial high exploration (epsilon=1)
agent = DQNAgent(state_dim, n_actions, epsilon=1).to(device)
target_network = DQNAgent(state_dim, n_actions, epsilon=1).to(device)
target_network.load_state_dict(agent.state_dict())  # Synchronize target network

# Populate the experience replay buffer with initial random experiences
exp_replay = ReplayBuffer(10**4)  # Replay buffer with capacity 10,000
for _ in range(100):
    # Carry the returned state into the next call: the environment keeps running between
    # calls, so restarting from the stale initial `state` would store a transition whose
    # start state does not match what the environment actually was in.
    _, state = play_and_record(state, agent, env, exp_replay, n_steps=10**2)
    if len(exp_replay) >= 10**4:
        break

# Set up training hyperparameters
timesteps_per_epoch = 1        # Environment steps collected per training iteration
batch_size = 32                # Mini-batch size for training updates
total_steps = 50000            # Total training steps

# Initialize the optimizer (Adam) for updating the agent's parameters
opt = torch.optim.Adam(agent.parameters(), lr=1e-4)

# Define the exploration schedule (epsilon decay)
start_epsilon = 1              # Starting exploration rate
end_epsilon = 0.05             # Minimum exploration rate
eps_decay_final_step = 2 * 10**4  # Steps over which epsilon decays to end_epsilon

# Define frequencies for logging and updating the target network
loss_freq = 20                      # Log the loss every 20 steps
refresh_target_network_freq = 100   # Update target network every 100 steps
eval_freq = 1000                    # Evaluate the agent every 1000 steps

# Gradient-norm clipping threshold passed to clip_grad_norm_ during training
max_grad_norm = 5000

In [None]:
# Training logs, filled in place by the training loop:
mean_rw_history = list()   # mean evaluation return, one entry per evaluation
td_loss_history = list()   # TD loss values, one entry per logged step

  and should_run_async(code)


### **The main training loop**



In [None]:
def epsilon_schedule(start_eps, end_eps, step, final_step):
    """
    Compute the exploration epsilon for the current step using a linear decay schedule.

    Epsilon moves linearly from `start_eps` at step 0 to `end_eps` at `final_step` and stays
    at `end_eps` afterwards. Steps below 0 are treated as step 0. A non-positive `final_step`
    means "no decay phase", so `end_eps` is returned directly.

    Args:
        start_eps (float): The initial epsilon (e.g., 1.0).
        end_eps (float): The final epsilon after decay (e.g., 0.05).
        step (int): The current training step.
        final_step (int): The step at which epsilon decays to end_eps.

    Returns:
        float: The computed epsilon value for the current step.
    """
    # Nothing to interpolate over; also avoids dividing by zero.
    if final_step <= 0:
        return end_eps

    # Clamp the step into [0, final_step] so the result never leaves the two end points.
    clamped_step = max(0, min(step, final_step))
    return start_eps + (end_eps - start_eps) * clamped_step / final_step

In [None]:
def smoothen(values, window=100, std=100):
    """
    Smooths out the given values using a normalized Gaussian filter.

    The kernel is never wider than the input. With 'valid' convolution and a kernel longer
    than the data, SciPy swaps the operands and the result is just a scaled copy of the
    kernel, which would plot as a meaningless curve for short histories.

    Args:
        values (list or np.array): The 1-D sequence of values to smooth.
        window (int, optional): Maximum kernel width in samples. Defaults to 100.
        std (float, optional): Standard deviation of the Gaussian kernel. Defaults to 100.

    Returns:
        np.array: The smoothed values, of length len(values) - kernel_width + 1
        (an empty array for empty input).
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        # Nothing to smooth; return the empty array rather than letting convolve fail.
        return values

    # Shrink the kernel for short inputs, but keep at least one tap.
    width = max(1, min(int(window), values.size))
    kernel = gaussian(width, std=std)
    kernel = kernel / np.sum(kernel)  # unit sum keeps the scale of the data
    return convolve(values, kernel, 'valid')

In [None]:
def evaluate(env, agent, n_games=1, greedy=False, t_max=10000):
    """
    Evaluate the agent's performance by running it for a specified number of games.

    An episode ends when the environment reports `terminated` or `truncated`, or after
    `t_max` steps, whichever comes first.

    Args:
        env (gym.Env): The environment to evaluate in.
        agent (DQNAgent): The DQN agent.
        n_games (int): Number of games (episodes) to run.
        greedy (bool): If True, use the greedy policy (argmax); otherwise use epsilon-greedy.
        t_max (int): Maximum timesteps per episode.

    Returns:
        float: The average total reward over the evaluated games.
    """
    rewards = []
    for _ in range(n_games):
        s, _ = env.reset()
        total_reward = 0
        for _ in range(t_max):
            # Get Q-values from the agent (batch of one state).
            qvalues = agent.get_qvalues([s])
            # Choose action: greedy (argmax) if specified, otherwise use agent's sampling.
            action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]
            s, r, terminated, truncated, _ = env.step(action)
            total_reward += r
            # Stop on either end-of-episode signal: stepping an environment after it
            # reported truncation is not valid until it has been reset.
            if terminated or truncated:
                break
        rewards.append(total_reward)
    return np.mean(rewards)

In [None]:
def train_dqn(total_steps, timesteps_per_epoch, batch_size, 
              start_epsilon, end_epsilon, eps_decay_final_step,
              loss_freq, refresh_target_network_freq, eval_freq,
              max_grad_norm, agent, target_network, env, exp_replay,
              opt, td_loss_history, mean_rw_history, env_name, device):
    """
    Main training loop for the (Double) DQN agent.

    The function updates the agent by:
      - Decaying the exploration rate.
      - Collecting experiences and storing them in the replay buffer.
      - Sampling mini-batches from the replay buffer.
      - Computing the Double DQN TD loss and performing backpropagation.
      - Periodically updating the target network.
      - Evaluating and logging the agent's performance.

    Args:
        total_steps (int): Total number of training steps.
        timesteps_per_epoch (int): Number of environment steps per training epoch.
        batch_size (int): Mini-batch size for training.
        start_epsilon (float): Initial exploration rate.
        end_epsilon (float): Final exploration rate after decay.
        eps_decay_final_step (int): The step at which epsilon should reach end_epsilon.
        loss_freq (int): Frequency (in steps) to log TD loss.
        refresh_target_network_freq (int): Frequency (in steps) to update the target network.
        eval_freq (int): Frequency (in steps) to evaluate the agent.
        max_grad_norm (float): Maximum gradient norm for clipping.
        agent (DQNAgent): The online agent network.
        target_network (DQNAgent): The target network.
        env (gym.Env): The environment for interaction.
        exp_replay (ReplayBuffer): The experience replay buffer.
        opt (torch.optim.Optimizer): The optimizer for training.
        td_loss_history (list): List to record TD loss history (appended to in place).
        mean_rw_history (list): List to record mean reward history (appended to in place).
        env_name (str): Environment name (used for creating a new env during evaluation).
        device (torch.device): Device to perform computations on.

    Returns:
        None
    """
    # Reset the environment to get the initial state.
    # NOTE: `seed` is read from module scope; it is not one of this function's parameters.
    state, _ = env.reset(seed=seed)

    # Main training loop.
    for step in trange(total_steps + 1):
        # 1. Update exploration rate (epsilon) based on schedule.
        agent.epsilon = epsilon_schedule(start_epsilon, end_epsilon, step, eps_decay_final_step)

        # 2. Interact with the environment and record experiences in the replay buffer.
        _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)

        # 3. Sample a mini-batch from the replay buffer.
        states, actions, rewards, next_states, done_flags = exp_replay.sample(batch_size)

        # 4. Compute the TD loss using the agent and target networks.
        opt.zero_grad()
        loss = td_loss_ddqn(agent, target_network,
                            states, actions, rewards, next_states, done_flags,
                            gamma=0.99, device=device)

        # 5. Perform backpropagation and update the network.
        loss.backward()
        # Clip gradients to stabilize training.
        nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
        opt.step()

        # 6. Log the TD loss at specified intervals.
        if step % loss_freq == 0:
            td_loss_history.append(loss.item())

        # 7. Update the target network periodically.
        if step % refresh_target_network_freq == 0:
            target_network.load_state_dict(agent.state_dict())

        # 8. Evaluate the agent and update logs/plots.
        if step % eval_freq == 0:
            # Create a fresh environment for evaluation and always close it afterwards,
            # so repeated evaluations do not pile up open environments.
            eval_env = gym.make(env_name, render_mode="rgb_array", max_episode_steps=4000)
            try:
                mean_reward = evaluate(eval_env, agent, n_games=3, greedy=True, t_max=1000)
            finally:
                eval_env.close()
            mean_rw_history.append(mean_reward)

            clear_output(wait=True)
            print("Buffer size = %i, Epsilon = %.5f" % (len(exp_replay), agent.epsilon))

            # Plot the mean return and smoothened TD loss.
            plt.figure(figsize=[16, 5])
            plt.subplot(1, 2, 1)
            plt.title("Mean return per episode")
            plt.plot(mean_rw_history)
            plt.grid()

            plt.subplot(1, 2, 2)
            plt.title("TD loss history (smoothened)")
            plt.plot(smoothen(td_loss_history))
            plt.grid()

            plt.show()

### **Applying Double DQN to CartPole**



In [None]:
# Launch training; every argument is passed by keyword so none can be silently misordered.
train_dqn(
    total_steps=total_steps,
    timesteps_per_epoch=timesteps_per_epoch,
    batch_size=batch_size,
    start_epsilon=start_epsilon,
    end_epsilon=end_epsilon,
    eps_decay_final_step=eps_decay_final_step,
    loss_freq=loss_freq,
    refresh_target_network_freq=refresh_target_network_freq,
    eval_freq=eval_freq,
    max_grad_norm=max_grad_norm,
    agent=agent,
    target_network=target_network,
    env=env,
    exp_replay=exp_replay,
    opt=opt,
    td_loss_history=td_loss_history,
    mean_rw_history=mean_rw_history,
    env_name=env_name,
    device=device,
)

In [None]:
# Score the trained agent greedily over 30 episodes in a fresh environment,
# and close that environment once the evaluation is done.
final_env = gym.make(env_name, render_mode="rgb_array", max_episode_steps=4000)
try:
    final_score = evaluate(final_env, agent, n_games=30, greedy=True, t_max=1000)
finally:
    final_env.close()
print('final score:', final_score)

In [None]:
# video_folder = ""  # enter folder location
# video_length = 500

# video_file = record_video(env_name, video_folder, video_length, agent)

# play_video(video_file)