# Deep $Q$-learning

In [None]:
import random
from collections import deque
from copy import deepcopy

import matplotlib.pyplot as plt
import gymnasium as gym
import torch
from torch import nn, optim, Tensor

# Seed PyTorch's global RNG (network weight initialisation) and Python's
# `random` module (epsilon-greedy draws and replay-buffer sampling).
# NOTE(review): `env.reset()` in DQNTrainer.train is called without a seed, so
# the environment's start states are presumably not covered by these seeds --
# confirm before relying on bit-for-bit reproducible runs.
torch.manual_seed(2022)
random.seed(2022)

In [None]:
class DeepQNetwork(nn.Module):
    """Fully connected network: (Linear -> activation) per hidden width, then Linear.

    The modules are stored in order in ``self.layers`` (an ``nn.ModuleList``);
    no activation is applied after the final linear layer.

    Args:
        input_dim: Number of input features.
        hidden_dims: Width of each hidden layer, in order. Must be non-empty.
        output_dim: Number of output features.
        activation_fn: Called with no arguments once per hidden layer to
            create that layer's activation module (e.g. pass ``nn.ReLU``,
            not ``nn.ReLU()``).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        output_dim: int,
        activation_fn: nn.Module,
    ):
        super().__init__()
        assert len(hidden_dims) > 0, "Must have at least one hidden layer."

        # Consecutive pairs of this list are the (in, out) sizes of every
        # linear layer that is followed by an activation.
        widths = [input_dim, *hidden_dims]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(in_features=fan_in, out_features=fan_out))
            stack.append(activation_fn())

        # Output head: last hidden width -> output_dim, left un-activated.
        stack.append(nn.Linear(in_features=widths[-1], out_features=output_dim))

        self.layers = nn.ModuleList(stack)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Feed ``input`` through every module in ``self.layers`` in order."""
        output = input
        for module in self.layers:
            output = module(output)
        return output


class DQNTrainer:
    """Deep Q-learning trainer with experience replay and a target network.

    Transitions are stored in a bounded replay buffer as
    ``(state, action, reward, next_state, next_state_is_terminal)`` tuples.
    ``qnet`` is optimised with Adam against TD targets computed from
    ``target_qnet``, a copy of ``qnet`` that is re-synced every
    ``network_update_interval`` gradient updates. Actions are chosen
    epsilon-greedily; epsilon starts at 1.0 and is decayed once per episode.

    Args:
        qnet: Network mapping a state (or a batch of states) to one value
            per action.
        num_actions: Number of discrete actions, indexed ``0..num_actions-1``.
        buffer_size: Maximum number of transitions kept in the replay buffer.
        batch_size: Number of transitions sampled per gradient update.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon_decay: After each episode epsilon is multiplied by
            ``1 - epsilon_decay``.
        epsilon_min: Lower bound for epsilon.
        network_update_interval: Number of gradient updates between copies of
            ``qnet``'s weights into ``target_qnet``.
        learning_rate: Adam learning rate.
    """

    def __init__(
        self,
        qnet: nn.Module,
        num_actions: int,
        buffer_size: int,
        batch_size: int,
        gamma: float,
        epsilon_decay: float,
        epsilon_min: float,
        network_update_interval: int,
        learning_rate: float
    ):
        # initialize replay buffer/memory; the oldest transitions fall off
        # once `buffer_size` is reached
        self.replay_buffer = deque([], maxlen=buffer_size)
        self.batch_size = batch_size
        self.qnet = qnet
        self.target_qnet = deepcopy(qnet)
        self.optimizer = optim.Adam(self.qnet.parameters(), lr=learning_rate)

        self.gamma = gamma
        self.epsilon = 1.0  # start fully exploratory
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.num_actions = num_actions
        self.network_update_interval = network_update_interval

        self.loss_func = nn.MSELoss()

    def select_action(self, state: torch.Tensor) -> int:
        """Pick an action epsilon-greedily for a single ``state``.

        With probability ``self.epsilon`` a uniformly random action in
        ``0..num_actions-1`` is returned; otherwise the action with the
        largest value under ``self.qnet``.
        """
        if self.epsilon > random.uniform(0, 1):
            return random.randrange(self.num_actions)

        # Action selection is never back-propagated through.
        with torch.no_grad():
            q_values = self.qnet(state)
        return int(torch.argmax(q_values).item())

    def compute_loss(
        self,
        minibacth: list[tuple[Tensor, int, float, Tensor, bool]],
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Return the loss between Q(state, action) and ``labels``.

        Args:
            minibacth: Transitions (parameter name kept as-is for backward
                compatibility); only the state and action of each are used.
            labels: One TD target per transition, in the same order.

        Returns:
            Scalar loss from ``self.loss_func`` (mean squared error), which
            equals the average of the per-transition squared errors.
        """
        if not minibacth:
            return torch.zeros(())

        states = torch.stack([transition[0] for transition in minibacth])
        actions = torch.tensor(
            [transition[1] for transition in minibacth], dtype=torch.int64
        )

        # One batched forward pass, then pick the value of the action that
        # was actually taken in each transition.
        q_taken = self.qnet(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        return self.loss_func(q_taken, labels.to(q_taken.dtype))

    def get_labels(self, minibacth: list[tuple[Tensor, int, float, Tensor, bool]]) -> torch.Tensor:
        """Compute the TD target for every transition in ``minibacth``.

        The target is ``reward`` when the next state is terminal, and
        ``reward + gamma * max_a target_qnet(next_state)[a]`` otherwise.
        The returned tensor carries no gradient.
        """
        if not minibacth:
            return torch.tensor([])

        next_states = torch.stack([transition[3] for transition in minibacth])
        is_terminal = torch.tensor(
            [bool(transition[4]) for transition in minibacth], dtype=torch.bool
        )

        # Targets are constants with respect to the optimisation.
        with torch.no_grad():
            max_next_q = self.target_qnet(next_states).max(dim=1).values

        rewards = torch.tensor(
            [float(transition[2]) for transition in minibacth],
            dtype=max_next_q.dtype,
        )
        # `where` (rather than multiplying by a mask) keeps terminal targets
        # exactly equal to the reward even if the network output is inf/nan.
        return torch.where(is_terminal, rewards, rewards + self.gamma * max_next_q)

    def _decay_epsilon(self) -> None:
        """Shrink epsilon by one decay step, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon * (1 - self.epsilon_decay), self.epsilon_min)

    def _optimize_step(self) -> float:
        """Run one gradient update on a sampled minibatch and return its loss."""
        # sample uniformly without replacement
        minibatch = random.sample(self.replay_buffer, self.batch_size)
        labels = self.get_labels(minibatch)
        loss = self.compute_loss(minibatch, labels)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def train(self, num_episodes: int, max_timesteps: int):
        """Train ``self.qnet`` on CartPole-v1.

        Args:
            num_episodes: Number of episodes to run.
            max_timesteps: Upper bound on environment steps per episode.

        Returns:
            ``(losses, episode_rewards)``: the loss of every gradient update,
            and the summed reward of each episode.
        """
        env = gym.make("CartPole-v1", render_mode=None)
        losses = []
        episode_rewards = [0.0] * num_episodes
        # Counted across episodes so that the target network is synced every
        # `network_update_interval` updates regardless of episode length.
        num_updates = 0

        try:
            for episode in range(num_episodes):
                # reset env at start of episode and get starting state
                observation, _ = env.reset()
                state = torch.tensor(observation, dtype=torch.float32)

                for _ in range(max_timesteps):
                    action = self.select_action(state)

                    observation, reward, terminated, truncated, _ = env.step(action)
                    new_state = torch.tensor(observation, dtype=torch.float32)
                    # Only `terminated` is stored: a time-limit truncation is
                    # not a true terminal state, so its target still bootstraps.
                    self.replay_buffer.append(
                        (state, action, float(reward), new_state, terminated)
                    )
                    episode_rewards[episode] += float(reward)

                    if len(self.replay_buffer) >= self.batch_size:
                        losses.append(self._optimize_step())
                        num_updates += 1

                        if num_updates % self.network_update_interval == 0:
                            # \hat{Q} = Q
                            self.target_qnet.load_state_dict(self.qnet.state_dict())

                    state = new_state

                    # The env must be reset after either kind of episode end.
                    if terminated or truncated:
                        break

                self._decay_epsilon()
        finally:
            env.close()

        return losses, episode_rewards

In [None]:
# Q-network: 4 inputs -> two hidden layers of 50 ReLU units -> 2 outputs.
# Presumably sized for CartPole-v1 (the environment hard-coded in
# DQNTrainer.train): 4 observation components and 2 actions.
qnet = DeepQNetwork(
    input_dim=4,
    hidden_dims=[50, 50],
    output_dim=2,
    activation_fn=nn.ReLU,
)

# Training hyperparameters, grouped by role.
trainer = DQNTrainer(
    qnet=qnet,
    num_actions=2,
    # replay memory
    buffer_size=75000,
    batch_size=256,
    # optimisation / targets
    learning_rate=0.0005,
    gamma=0.99,
    network_update_interval=250,
    # exploration schedule
    epsilon_decay=0.009,
    epsilon_min=0.005,
)



In [None]:
# Run training: 150 episodes of at most 500 environment steps each.
# `losses` holds one entry per gradient update, `episode_rewards` one summed
# reward per episode.
losses, episode_rewards = trainer.train(num_episodes=150, max_timesteps=500)

In [None]:
# Learning curve: total reward collected in each episode (x-axis = episode index).
plt.plot(episode_rewards)
plt.show()

In [None]:
# Training loss of every gradient update (x-axis = update index).
# Uncomment the next line to view the losses on a logarithmic y-axis.
# plt.yscale("log")
plt.plot(losses)
plt.show()