# Deep Q Network
---

What we need are:
1. $Q_\theta(s,a)$ to tell how good it is to take action $a$ given state $s$, parameterized by $\theta$
2. A policy $\pi_\theta(s)$ that exploits $Q_\theta(s,a)$ by acting greedily, $\pi_\theta(s)=\arg\max_a Q_\theta(s,a)$, while still being able to explore other actions in order to find better ones.
3. The temporal-difference (TD) error $\delta(s_t, a_t, s_{t+1}, r_{t+1})=Q_\theta(s_t, a_t) - (r_{t+1} + \max_{a_{t+1}}\gamma Q_\theta(s_{t+1}, a_{t+1}))$, which training drives towards zero.
4. A replay buffer that stores transitions $(s_t, a_t, s_{t+1}, r_{t+1})$ so they can be reused across many updates, improving sample efficiency.

### Define the Deep Q Network
What we need is $Q_\theta:S\times A\rightarrow R$.

For a discrete action space, we can instead define a function of type $S\rightarrow R^{|A|}$ that outputs one Q-value per action. We also need $V_\theta:S\rightarrow R$, defined as $V_\theta(s) =\max_a Q_\theta(s,a)$.

In [1]:
import torch
import torch.nn as nn
import numpy as np

class DQN(nn.Module):
    """MLP Q-network mapping a state vector to one Q-value per discrete action.

    Args:
        dim_state: length of the flat state vector (input size).
        n_actions: number of discrete actions (output size).
    """

    def __init__(self, dim_state, n_actions):
        super().__init__()
        self.dim_state = dim_state
        self.n_actions = n_actions
        self.net = nn.Sequential(
            nn.Linear(dim_state, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, n_actions),
        )

    def forward(self, states):
        """Return Q-values for every action; last dim goes dim_state -> n_actions."""
        return self.net(states)

    def q_vals(self, states, actions):
        """Return Q(s, a) for each (state, action) pair as a 1-D tensor.

        ``states`` must be a batch (the Q-value matrix is gathered along dim 1);
        ``actions`` holds one action index per state and is cast to int64.
        """
        index = actions.reshape(-1, 1).to(torch.int64)
        return self.forward(states).gather(1, index).reshape(-1)

    def v_vals(self, states):
        """Return V(s) = max_a Q(s, a), reducing over the last (action) dim."""
        return self.forward(states).amax(-1)

    def save(self, path):
        """Write the weights of the inner ``nn.Sequential`` to ``path``."""
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        """Load weights written by ``save``.

        Tensors are mapped onto this network's current device, so a checkpoint
        saved on a GPU can be loaded on a CPU-only machine and vice versa.
        """
        self.net.load_state_dict(torch.load(path, map_location=self.device))

    @property
    def device(self):
        """Device holding this network's parameters."""
        return next(self.net.parameters()).device

    def get_compat(self, x):
        """Stack a sequence of arrays/scalars into a tensor on ``self.device``.

        The dtype follows NumPy's inference (e.g. Python floats become float64).
        """
        x = np.stack(x)
        x = torch.tensor(x, device=self.device)
        return x


### Define the replay buffer

In [2]:
from collections import deque, namedtuple

# One stored step of experience; next_state is None for terminal transitions.
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward"])


class ReplayBuffer:
    """Fixed-capacity FIFO store of ``Transition`` records for experience replay.

    When full, pushing a new transition evicts the oldest one.
    """

    def __init__(self, capacity):
        self.deque = deque([], capacity)

    def push(self, *fields):
        """Store one transition built from ``(state, action, next_state, reward)``."""
        record = Transition(*fields)
        self.deque.append(record)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, without replacement."""
        return random.sample(self.deque, batch_size)

    def __len__(self):
        return self.deque.__len__()


### Define the Policy
Use $\epsilon$-greedy as exploration strategy.

In [3]:
import random

class EpsGreedy:
    """Epsilon-greedy policy on top of a Q-network.

    ``q_net`` must expose ``n_actions`` and ``get_compat`` and be callable on
    the converted state (as ``DQN`` is).
    """

    def __init__(self, q_net):
        self.q_value_net = q_net

    def get_action(self, state, mode="eval", eps=0.1):
        """Return an action index for a single ``state``.

        In ``"train"`` mode a uniformly random action is returned with
        probability ``eps``; otherwise (and in every other mode) the action
        with the highest Q-value is returned.
        """
        if mode == "train" and random.random() < eps:
            return random.randint(0, self.q_value_net.n_actions - 1)
        state = self.q_value_net.get_compat(state)
        # Acting needs no gradients; avoid building an autograd graph each step.
        with torch.no_grad():
            return self.q_value_net(state).argmax().item()

### Training
To stabilize training, use a target network, parameterized by $\theta'$, to estimate the value of the next state $s_{t+1}$:

$\delta(s_t, a_t, s_{t+1}, r_{t+1})=Q_\theta(s_t, a_t) - (r_{t+1} + \max_{a_{t+1}}\gamma Q_{\theta'}(s_{t+1}, a_{t+1}))$

$\theta'$ is synchronized with $\theta$ (its weights are copied over) every fixed number of environment steps, and is held constant in between.

In [4]:
import numpy as np
import copy
from itertools import count
from IPython.display import display

class DQNTrainer:
    """Train a Q-network on a Gymnasium-style environment with experience replay.

    A frozen copy of ``q_net`` (the target net) supplies the bootstrap value of
    the next state; it is re-synchronized with ``q_net`` every ``update_thres``
    environment steps. Exploration uses an epsilon-greedy policy whose epsilon
    decays exponentially from ``eps_start`` towards ``eps_end`` with time
    constant ``eps_decay`` steps.
    """

    def __init__(self, q_net, env,
                 lr=1e-4, gamma=0.99, replay_buf_size=10000,
                 eps_start=0.5, eps_end=0.01, eps_decay=1e4,
                 update_thres=500):
        self.q_net = q_net
        self.env = env
        self.opt = torch.optim.RMSprop(self.q_net.parameters(), lr=lr)
        self.gamma = gamma
        self.criterion = torch.nn.SmoothL1Loss()
        self.policy = EpsGreedy(q_net)
        self.replay_buf = ReplayBuffer(replay_buf_size)
        # Target network theta'; only ever changed by copying q_net's weights.
        self.target_net = copy.deepcopy(self.q_net)
        # Environment steps taken so far; drives eps decay and target sync.
        self.counter = 0

        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay

        self.update_thres = update_thres

    @property
    def eps(self):
        """Current exploration rate: eps_end + (eps_start - eps_end) * exp(-counter / eps_decay)."""
        ratio = np.exp(-self.counter / self.eps_decay)
        return self.eps_end + (self.eps_start - self.eps_end) * ratio

    def update(self, batch_size):
        """Take one gradient step on a random minibatch from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if batch_size > len(self.replay_buf):
            return
        batch = self.replay_buf.sample(batch_size)
        states, actions, next_states, rewards = zip(*batch)
        # A next_state of None marks a terminal transition (see `train`).
        non_final = [s is not None for s in next_states]
        to_tensor = self.q_net.get_compat

        # Q_theta(s, a) for the actions actually taken.
        state_action_vals = self.q_net.q_vals(to_tensor(states), to_tensor(actions))

        # V_theta'(s') from the target net; terminal states keep value 0.
        next_state_vals = torch.zeros_like(state_action_vals)
        if any(non_final):  # np.stack cannot stack an empty list
            mask = torch.tensor(non_final, dtype=torch.bool, device=next_state_vals.device)
            with torch.no_grad():
                next_state_vals[mask] = self.target_net.v_vals(
                    to_tensor([s for s in next_states if s is not None])
                )

        # TD target r + gamma * V(s'). Rewards stacked from Python floats come
        # out as float64; cast so the loss compares tensors of the same dtype.
        rewards = to_tensor(rewards).to(state_action_vals.dtype)
        expected_state_action_vals = rewards + self.gamma * next_state_vals

        loss = self.criterion(state_action_vals, expected_state_action_vals)
        self.opt.zero_grad()
        loss.backward()
        # Clamp every gradient element to [-1, 1] (parameters without a grad are skipped).
        torch.nn.utils.clip_grad_value_(self.q_net.parameters(), 1)
        self.opt.step()

    def train(self, n_episodes, batch_size=512, log_every=200):
        """Play ``n_episodes`` training episodes, updating after every step.

        Args:
            n_episodes: number of episodes to play.
            batch_size: minibatch size passed to ``update`` after each step.
            log_every: every ``log_every`` episodes, run one greedy evaluation
                episode and print progress; 0 or None disables logging.
        """
        for i in range(n_episodes):
            state, _ = self.env.reset()
            total_reward = 0
            for _ in count():
                self.counter += 1
                action = self.policy.get_action(state, "train", self.eps)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                total_reward += reward

                # Only true termination is stored as None (value 0); a
                # truncated episode keeps its next state so it is bootstrapped.
                if terminated:
                    next_state = None
                self.replay_buf.push(state, action, next_state, reward)
                state = next_state

                self.update(batch_size)
                if self.counter % self.update_thres == 0:
                    self.target_net.load_state_dict(self.q_net.state_dict())

                if terminated or truncated:
                    break

            if log_every and i % log_every == 0:
                eval_reward = self.eval()
                print(
                    "Episode: {:6}, Total_reward: {:7}, eps: {:6.4}, Total_reward_eval: {:7}".format(
                        i, total_reward, self.eps, eval_reward)
                )

    def eval(self, render=True):
        """Play one greedy episode and return its total reward.

        If ``render`` is true, ``env.render()`` is called after every step and
        its return value is discarded.
        """
        state, _ = self.env.reset()
        total_reward = 0
        for _ in count():
            action = self.policy.get_action(state)
            state, reward, terminated, truncated, _ = self.env.step(action)
            if render:
                self.env.render()
            total_reward += reward
            if terminated or truncated:
                break
        return total_reward

In [5]:
import gymnasium as gym

device = "cuda" if torch.cuda.is_available() else "cpu"

env = gym.make("CartPole-v1", render_mode="rgb_array")
# Size the network from the environment's spaces rather than hard-coding
# CartPole's 4-dim observation and 2 actions.
q_net = DQN(env.observation_space.shape[0], int(env.action_space.n))
q_net.to(device)
trainer = DQNTrainer(q_net, env)

**Train by running 1500 episodes**

In [6]:
# Train for 1500 episodes; progress is printed every 200 episodes.
trainer.train(1500)

**Save the deep q-network**

In [7]:
# Write the trained network's weights to disk.
q_net.save("cartpole_1500.pth")

**Test the performance**

In [8]:
def test(policy, env):
    """Run one episode of ``policy`` (greedy mode) in ``env``; return its total reward.

    ``env`` is always closed afterwards, even if resetting or stepping raises.
    """
    total_reward = 0
    try:
        state, _ = env.reset()
        for _ in count():
            action = policy.get_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
    finally:
        env.close()
    return total_reward

# Reload the saved weights and watch the greedy policy play one rendered episode.
q_net.load("cartpole_1500.pth")
test_env = gym.make("CartPole-v1", render_mode="human")
policy = EpsGreedy(q_net)
test(policy, test_env)