# Deep Reinforcement Learning: CartPole with DQN

In [None]:
import gym
import gymnasium
import numpy as np
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report


### DQN Agent

`DQN` (Deep Q-Network) is a subclass of `torch.nn.Module`: a simple feed-forward network with two hidden layers of size `hidden_size`. It takes the environment state (a vector of length `input_size`) and outputs one **Q-value** for each of the `output_size` actions.

In [None]:
class DQN(nn.Module):
    """Feed-forward Q-network with two ReLU hidden layers and a linear head.

    Args:
        input_size: Number of features in one state vector.
        hidden_size: Width of each of the two hidden layers.
        output_size: Number of actions; one Q-value is produced per action.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are declared in the order forward() applies them.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.relu2 = nn.ReLU()
        self.out = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return the Q-values for ``x``, one per action on the last dimension."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu2(self.fc2(hidden))
        return self.out(hidden)
    

### ReplayBuffer

A data structure (buffer) that efficiently stores experiences. Each experience is a `(state, action, reward, next_state, done)` tuple; we can add an experience to the buffer, sample a batch of experiences from it, and query its current length.

In [None]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of experience tuples.

    Once ``capacity`` entries are stored, each new entry overwrites the
    oldest one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        # Index of the slot the next push() writes to.
        self.position = 0

    def push(self, *args):
        """Store one experience; ``args`` is (state, action, reward, next_state, done)."""
        is_full = len(self.buffer) >= self.capacity
        if not is_full:
            # Grow by one placeholder slot, filled just below.
            self.buffer.append(None)
        write_at = self.position
        self.buffer[write_at] = args
        # Wrap around so the oldest entry is the next one replaced.
        self.position = (write_at + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen at random."""
        return random.sample(self.buffer, k=batch_size)

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.buffer)

### Target Network Update

if `soft` then:
for each parameter 
* we take a fraction (`tau`) of the corresponding parameter from the `model`
* we take a large fraction (`1-tau`) of the **current** parameter from the `target_model`

else (`hard`) then:
load the weights and biases of `model` into `target_model`


In [None]:
def update_target_model(model, target_model, soft=False, tau=0.1):
    """Update the target model from the online model's weights (main network).

    Parameters:
    - model: The online model.
    - target_model: The target model; modified in place.
    - soft: If True, blend the two models (soft update); otherwise copy the
      online model's state into the target (hard update).
    - tau: The interpolation factor for the soft update. Each blended entry
      becomes ``tau * online + (1 - tau) * target``.
    """
    if not soft:
        target_model.load_state_dict(model.state_dict())
        return

    target_model_state_dict = target_model.state_dict()
    for key, online_value in model.state_dict().items():
        if online_value.is_floating_point() or online_value.is_complex():
            target_model_state_dict[key] = (
                online_value * tau + target_model_state_dict[key] * (1 - tau)
            )
        else:
            # Integer/bool entries (e.g. BatchNorm's num_batches_tracked)
            # cannot be interpolated: the blend would be fractional and get
            # truncated on load, so copy the online value verbatim.
            target_model_state_dict[key] = online_value.clone()
    target_model.load_state_dict(target_model_state_dict)



### DQN Loss Function

Computes the loss for the DQN Agent. Implements the **Bellman Backup** 

$$L = E\bigg[\Big(Q(s, a) - \big(r + \gamma \max_{a'} Q'(s', a')\big)\Big)^2 \bigg]$$



In [None]:
# DQN loss function and training loop
def dqn_loss(model, target_model, batch, gamma):
    """Return the mean-squared Bellman error of ``model`` on ``batch``.

    Parameters:
    - model: The online network; gradients flow through its predictions.
    - target_model: The target network, used only to build the targets.
    - batch: Iterable of (state, action, reward, next_state, done) tuples.
    - gamma: The discount factor applied to the bootstrapped next-state value.

    Returns:
    The loss as a scalar tensor.
    """
    states, actions, rewards, next_states, dones = zip(*batch)
    # Convert to tensors. The per-sample scalars get a trailing dimension so
    # they are (batch, 1), matching the gathered Q-values below.
    states = torch.FloatTensor(states)
    actions = torch.LongTensor(actions).unsqueeze(1)
    rewards = torch.FloatTensor(rewards).unsqueeze(1)
    next_states = torch.FloatTensor(next_states)
    dones = torch.FloatTensor(dones).unsqueeze(1)

    # Q(s, a) for the action that was actually taken: shape (batch, 1).
    q_values = model(states).gather(1, actions)

    # max_a' Q'(s', a') from the target network. keepdim=True keeps the shape
    # (batch, 1); a flat (batch,) vector would broadcast against the (batch, 1)
    # rewards into a (batch, batch) matrix and silently give a wrong loss.
    with torch.no_grad():
        next_q_values = target_model(next_states).max(1, keepdim=True)[0]

    # Bellman target: r + gamma * max_a' Q'(s', a'), with the bootstrap term
    # zeroed where done is 1.
    expected_q_values = rewards + (1 - dones) * gamma * next_q_values

    # Both operands are (batch, 1), so the error is taken sample by sample.
    loss = nn.MSELoss()(q_values, expected_q_values)
    return loss

### Train Model

Takes a batch of past experiences, calculates how wrong the current model's predictions are (i.e. the loss from `dqn_loss(...)`), and then updates the model's weights to make its predictions better in future iterations.

* `model` is the online network, while `target_model` is the frozen (or slowly updated) target network (the target Q-network).
* `optimizer` is the PyTorch `Adam` optimizer that updates the weights of `model`
* `replay_buffer` is the replay buffer object, an instance of the `ReplayBuffer` class
* `gamma` is the discount factor used in the Q-learning update rule, determining the importance of future rewards

In [None]:
def train_model(
    model, target_model, replay_buffer, optimizer, batch_size, gamma
):
    """Run one gradient step on a batch sampled from the replay buffer.

    Returns the loss value as a float, or None (without training) when the
    buffer holds fewer than ``batch_size`` experiences.
    """
    if len(replay_buffer) >= batch_size:
        experiences = replay_buffer.sample(batch_size)
        loss = dqn_loss(model, target_model, experiences, gamma)

        # Standard optimiser cycle: clear old gradients, backprop, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        print(f"Loss: {loss_value}")
        return loss_value

    return None

## $\epsilon$-greedy policy

Implements the $\epsilon$-greedy action selection strategy (in order to balance **exploration** and **exploitation**)

* **Exploration**: If `random.random()` is less than `epsilon`, the agent **explores** by choosing a random action (not using its current knowledge)

* **Exploitation**: If `random.random()` is **not** less than `epsilon`, the agent **exploits** its current knowledge, choosing the action with the highest Q-value the model predicts for the current state


In [None]:
def e_greedy(state, model, epsilon, action_dim):
    """Pick an action with the epsilon-greedy rule.

    With probability ``epsilon`` return a uniformly random action index in
    ``[0, action_dim - 1]``; otherwise return the index of the largest
    Q-value that ``model`` predicts for ``state``.
    """
    explore = random.random() < epsilon
    if explore:
        return random.randint(0, action_dim - 1)

    # Add a leading batch dimension before calling the model.
    batched_state = torch.FloatTensor(state).unsqueeze(0)
    with torch.no_grad():
        best_action = model(batched_state).argmax()
    return best_action.item()


### Main function (training loop)

Hyperparameters
```input_size
    hidden_size
    output_size
    num_episodes
    batch_size
    gamma
    learning_rate
    target_update_freq
    replay_buffer_capacity
```

Continues on until `CartPole` environment-step gives `done=True`

In [None]:
def main():
    """Train a DQN agent on CartPole-v1, printing each episode's total reward.

    Builds the environment, the online and target networks, the optimizer and
    the replay buffer, then runs ``num_episodes`` episodes. After each episode
    one training step is attempted and, every ``target_update_freq`` episodes,
    the target network is soft-updated towards the online network.
    """
    # Hyperparameters
    env = gymnasium.make('CartPole-v1') # use gymnasium for the latest version
    input_size = env.observation_space.shape[0]
    hidden_size = 128
    output_size = env.action_space.n
    num_episodes = 1000
    batch_size = 64
    gamma = 0.99
    learning_rate = 0.001
    target_update_freq = 10
    replay_buffer_capacity = 10000
    epsilon = 0.1  # exploration rate passed to e_greedy

    # init the model, optimizer, and replay buffer; the target network starts
    # as an exact copy of the online network
    model = DQN(input_size, hidden_size, output_size)
    print("model", model)
    target_model = DQN(input_size, hidden_size, output_size)
    target_model.load_state_dict(model.state_dict())
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    replay_buffer = ReplayBuffer(replay_buffer_capacity)

    try:
        # main training loop
        for episode in range(num_episodes):
            state, _ = env.reset()
            done = False
            total_reward = 0

            while not done:
                # using our greedy function
                action = e_greedy(
                    state, model, epsilon=epsilon, action_dim=output_size
                )

                # Gymnasium's step() returns (obs, reward, terminated,
                # truncated, info): `terminated` is a true terminal state,
                # `truncated` is an external cut-off such as a time limit.
                next_state, reward, terminated, truncated, _ = env.step(action)
                total_reward += reward

                # Store only `terminated` as the done flag: a truncated
                # episode did not reach a terminal state, so its next-state
                # value should still be bootstrapped in the loss.
                replay_buffer.push(state, action, reward, next_state, terminated)

                # update the state
                state = next_state

                # Either signal ends the episode; ignoring `truncated` would
                # keep stepping past the environment's time limit.
                done = terminated or truncated

            print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {total_reward}")

            # train the model
            # NOTE(review): this is a single gradient step per episode; DQN
            # usually trains once per environment step -- confirm intent.
            train_model(model, target_model, replay_buffer, optimizer, batch_size, gamma)

            # update the target model every few (10) episodes
            if episode % target_update_freq == 0:
                update_target_model(model, target_model, soft=True, tau=0.1)
    finally:
        # Release the environment's resources even if training is interrupted.
        env.close()

In [None]:
# Run the training loop defined in main().
main()