# Dueling DQN

## 1. Related Concepts

#### Optimal state-value function:

$V_*(s) = \max_\pi V_\pi (s), \forall s \in \mathcal{S}$, i.e., the largest expected return obtainable from state $s$ over all policies $\pi$.

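#### Optimal advantage function:

$A_*(s, a) = Q_*(s, a) - V_*(s)$. Because $V_*(s) = \max_{a \in \mathcal{A}} Q_*(s, a)$, taking the maximum over actions on both sides gives $\max_{a \in \mathcal{A}} A_*(s, a) = 0, \forall s \in \mathcal{S}$. Subtracting this zero term leaves the definition unchanged, so $Q_*(s, a) = V_*(s) + A_*(s, a) - \max_{a \in \mathcal{A}} A_*(s, a), \forall s \in \mathcal{S}, a \in \mathcal{A}$.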
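
The identity can be checked numerically for a single state. This is a minimal sketch with made-up $Q_*$ values for three hypothetical actions; it uses the fact that $V_*(s) = \max_a Q_*(s, a)$.

```python
# Made-up optimal action values for one state with three hypothetical actions.
q_star = {"left": 1.0, "stay": 2.5, "right": 4.0}

# V*(s) = max_a Q*(s, a)
v_star = max(q_star.values())

# A*(s, a) = Q*(s, a) - V*(s)
a_star = {a: q - v_star for a, q in q_star.items()}
print(a_star)                # {'left': -3.0, 'stay': -1.5, 'right': 0.0}
print(max(a_star.values()))  # 0.0: the best action has zero advantage

# Rebuild Q* with V* + A* - max_a A*; the subtracted term is zero.
max_adv = max(a_star.values())
q_rebuilt = {a: v_star + adv - max_adv for a, adv in a_star.items()}
print(q_rebuilt == q_star)   # True
```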

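#### Dueling network:

The dueling network has two parts: $A(s, a; w^A)$ approximates the optimal advantage function $A_*(s, a)$, and $V(s; w^V)$ approximates the optimal state-value function $V_*(s)$. The dueling network is defined as $Q(s, a; w) = V(s; w^V) + A(s, a; w^A) - \max_{a \in \mathcal{A}} A(s, a; w^A)$. Intuitively, the dueling network can learn which states are (or are not) valuable without having to evaluate the effect of every action in each state; in practice, it can achieve better results than DQN.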
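
The max-aggregation combine step on its own, as a small sketch on plain Python lists. The head outputs `v` and `advantages` are hypothetical numbers standing in for $V(s; w^V)$ and $A(s, \cdot; w^A)$ for one state.

```python
def dueling_q_max(v, advantages):
    """Q(s, a) = V(s) + A(s, a) - max_a' A(s, a') for one state."""
    max_adv = max(advantages)
    return [v + adv - max_adv for adv in advantages]

# Hypothetical head outputs for one state with |A| = 3.
v = 10.0
advantages = [0.5, 2.0, -1.0]
q = dueling_q_max(v, advantages)
print(q)  # [8.5, 10.0, 7.0]

# The greedy action is the same under Q and under A,
# and the greedy action's Q-value equals V(s).
greedy = max(range(len(q)), key=q.__getitem__)
print(greedy, q[greedy])  # 1 10.0
```

Because the best action's advantage is shifted to exactly zero, $\max_a Q(s, a) = V(s)$, which mirrors the property $\max_a A_*(s, a) = 0$ of the optimal functions.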

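#### Network architecture:

The two heads share an embedding layer (the feature extractor) that maps the state $s$ to a feature vector. (1) The advantage head outputs a vector whose size is the number of actions $|\mathcal{A}|$, with one element per action; (2) the state-value head outputs a single real number.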
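
A dependency-free sketch of that shape bookkeeping: one shared layer, a 1-output value head and an $|\mathcal{A}|$-output advantage head. The layer sizes, the uniform random initialization and the helper names (`linear`, `init_layer`, `forward`) are illustrative assumptions; the PyTorch model in the code cells uses two shared layers of sizes 64 and 32.

```python
import random

def linear(x, weights, bias):
    """y = W x + b, with W stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def relu(x):
    return [max(0.0, xi) for xi in x]

def init_layer(n_in, n_out, rng):
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out

state_size, hidden, num_actions = 4, 8, 2  # assumed sizes (CartPole-like)
rng = random.Random(0)
shared = init_layer(state_size, hidden, rng)      # shared embedding layer
value_head = init_layer(hidden, 1, rng)           # V(s; w^V): one scalar
adv_head = init_layer(hidden, num_actions, rng)   # A(s, .; w^A): one entry per action

def forward(state):
    features = relu(linear(state, *shared))
    v = linear(features, *value_head)[0]
    adv = linear(features, *adv_head)
    max_adv = max(adv)
    return [v + a - max_adv for a in adv]

q = forward([0.1, -0.2, 0.03, 0.4])
print(len(q))  # 2: one Q-value per action
```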

## 2. Non-uniqueness

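Although $\max_{a \in \mathcal{A}} A_*(s, a) = 0, \forall s \in \mathcal{S}$, this term is not dropped from the dueling network. If it were dropped, we would have $Q_*(s, a) = V_*(s) + A_*(s, a)$, and $V$ and $A$ would not be uniquely determined by $Q$: for example, increasing $V$ by 1 and decreasing $A$ by 1 leaves $Q$ unchanged. With the term kept, $\max_{a \in \mathcal{A}} Q(s, a; w) = V(s; w^V)$, so $V$ is pinned down. Therefore the last term of the dueling network cannot be omitted.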
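
A quick numeric check of the argument, using hypothetical head outputs for a two-action state: shifting $V$ up by 1 and $A$ down by 1 is invisible without the correction term, but changes $\max_a Q$ (and therefore the recovered $V$) once the correction is in place.

```python
def q_naive(v, advantages):
    # Q = V + A, without the correction term
    return [v + a for a in advantages]

def q_max(v, advantages):
    # Q = V + A - max_a A
    m = max(advantages)
    return [v + a - m for a in advantages]

v, adv = 3.0, [1.0, -2.0]                              # hypothetical head outputs
v_shift, adv_shift = v + 1.0, [a - 1.0 for a in adv]   # V up by 1, A down by 1

# Without the correction, two different (V, A) pairs give the same Q.
print(q_naive(v, adv), q_naive(v_shift, adv_shift))       # [4.0, 1.0] [4.0, 1.0]

# With the correction, max_a Q always equals V, so V is identifiable.
print(max(q_max(v, adv)), max(q_max(v_shift, adv_shift)))  # 3.0 4.0
```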

## 3. Implementing the dueling network

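In practice, the following form is used instead: $Q(s, a; w) = V(s; w^V) + A(s, a; w^A) - \text{mean}_{a \in \mathcal{A}} A(s, a; w^A)$, for the following reasons:

(1) With max, the gradient of the correction term reaches only the current argmax action and jumps to a different action whenever the argmax changes, which makes training unstable. With mean, the gradient is spread evenly over all actions, so it is smoother and avoids the abrupt changes caused by the max operation.

(2) Averaging reduces variance and smooths out some of the noise in the advantage estimates, which tends to make the model perform better.

(3) Max concentrates the correction on a single action, whereas mean treats all actions symmetrically and reduces this bias, which helps explore more policies.

With mean, $V$ and $A$ no longer equal $V_*$ and $A_*$ exactly (they are off by a constant per state), but the ranking of actions, and hence the greedy policy, is unchanged.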
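
A small sketch of the mean variant and of point (1): it computes, for hypothetical advantages, the (sub)gradient of the subtracted correction term with respect to each $A(s, a_j)$. For max it is 1 at the argmax and 0 elsewhere; for mean it is $1/|\mathcal{A}|$ for every action. The helper names are illustrative.

```python
def q_mean(v, advantages):
    # Q = V + A - mean_a A  (the variant used in the code cells)
    m = sum(advantages) / len(advantages)
    return [v + a - m for a in advantages]

def grad_of_correction(advantages, kind):
    """d(correction term) / dA_j for every action j."""
    n = len(advantages)
    if kind == "max":
        j_star = max(range(n), key=advantages.__getitem__)
        return [1.0 if j == j_star else 0.0 for j in range(n)]
    return [1.0 / n] * n

v, adv = 10.0, [0.5, 2.0, -1.0]  # hypothetical head outputs
q = q_mean(v, adv)
print(q)                  # [10.0, 11.5, 8.5]
print(sum(q) / len(q))    # 10.0: with mean aggregation, V(s) = mean_a Q(s, a)
print(grad_of_correction(adv, "max"))   # [0.0, 1.0, 0.0]
print(grad_of_correction(adv, "mean"))  # each entry is 1/3
```

The argmax of `q` is still action 1, the same as the argmax of `adv`, which illustrates why the greedy policy is unaffected by the choice of aggregation.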

## 4. Training procedure

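The training procedure is the same as for DQN (experience replay, a target network and ε-greedy exploration). The implementation below additionally computes the TD target in the Double DQN style, where the online network selects the next action and the target network evaluates it, and updates the target network with soft updates.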
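
The `compute_loss` method in the code cells builds its TD target with the online network `Q` choosing $a' = \arg\max_a Q(s', a)$ and `target_Q` scoring that action. A scalar sketch of that target for one transition, with made-up Q-values:

```python
def double_dqn_target(reward, done, next_q_online, next_q_target, gamma=0.99):
    """y = r + gamma * Q_target(s', argmax_a Q_online(s', a)) * (1 - done)."""
    a_next = max(range(len(next_q_online)), key=next_q_online.__getitem__)
    return reward + gamma * next_q_target[a_next] * (1.0 - done)

# Made-up values for one transition with two actions.
y = double_dqn_target(reward=1.0, done=0.0,
                      next_q_online=[2.0, 5.0],  # online net prefers action 1
                      next_q_target=[4.0, 3.0],  # target net values action 1 at 3.0
                      gamma=0.5)
print(y)  # 2.5 = 1.0 + 0.5 * 3.0

# A terminal transition does not bootstrap.
print(double_dqn_target(1.0, 1.0, [2.0, 5.0], [4.0, 3.0], gamma=0.5))  # 1.0
```

Plain DQN would instead take $\max_a$ over `next_q_target` (4.0 here, giving 3.0); decoupling action selection from evaluation is what Double DQN uses to reduce overestimation.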

# Code Implementation

In [1]:
# Import libraries
import os
import gym
import argparse
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Define the Dueling Q-Network
class Dueling_QNet(nn.Module):
    def __init__(self, state_size, num_actions):
        super().__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.V = nn.Linear(32, 1)
        self.A = nn.Linear(32, num_actions)
        
    def forward(self, state):
        # state: (batch_size, state_size)
        x = F.relu(self.fc1(state.float()))
        x = F.relu(self.fc2(x))
        # V: (batch_size, 1)
        V = self.V(x)
        # A: (batch_size, num_actions)
        A = self.A(x)
        return V + A - torch.mean(A, dim=-1, keepdim=True)

In [3]:
# Define the Dueling DQN agent class
class DuelingDQN:
    def __init__(self, state_size, num_actions, gamma, device):
        self.gamma = gamma
        self.Q = Dueling_QNet(state_size, num_actions).to(device)
        self.target_Q = Dueling_QNet(state_size, num_actions).to(device)
        self.target_Q.load_state_dict(self.Q.state_dict())
        
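    def get_action(self, state):
        # state: (state_size,)
        # qvals: (num_actions,)
        # Greedy action selection needs no gradients, so skip building the autograd graph
        with torch.no_grad():
            qvals = self.Q(state)
        return qvals.argmax().item()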
    
    def compute_loss(self, bs, ba, br, bd, bns):
        # bs, bns: (batch_size, state_size)
        # ba, br, bd: (batch_size,)
        # qvals, next_qvals: (batch_size,)
        qvals = self.Q(bs).gather(1, ba.unsqueeze(1)).squeeze(1)
        # next_actions: (batch_size, 1)
        next_actions = self.Q(bns).argmax(dim=1, keepdim=True)
        # detach() keeps gradients from flowing into the target network's parameters
        next_qvals = self.target_Q(bns).gather(1, next_actions).squeeze(1).detach()
        loss = F.mse_loss(qvals, br + self.gamma * next_qvals * (1 - bd))
        return loss        
    
    def soft_update(self, tau=0.01):
        for target_param, param in zip(self.target_Q.parameters(), self.Q.parameters()):
            target_param.data.copy_(param.data * tau + target_param.data * (1 - tau))
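
`soft_update` applies $\theta_{\text{target}} \leftarrow \tau\,\theta + (1 - \tau)\,\theta_{\text{target}}$ after every gradient step. A sketch on plain lists (the weight values are hypothetical) shows that, with the online weights held fixed, the gap to the target shrinks by a factor $(1 - \tau)$ per call:

```python
def soft_update(target, online, tau=0.01):
    """Element-wise theta_target <- tau * theta_online + (1 - tau) * theta_target."""
    return [tau * p + (1.0 - tau) * t for t, p in zip(target, online)]

online = [1.0, -2.0]  # hypothetical online weights, held fixed
target = [0.0, 0.0]
for _ in range(500):
    target = soft_update(target, online, tau=0.01)

# Remaining gap is 0.99**500 (about 0.0066) of the initial gap.
print([round(t, 3) for t in target])
```

With the default `tau=0.01`, the target network therefore lags the online network by roughly a hundred updates, which plays the role of the periodic hard copy in the original DQN.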

In [4]:
# Define the experience replay buffer
class ReplayBuffer:
    def __init__(self, max_size, device):
        self.max_size = max_size
        self.device = device
        self.size = 0
        self.state = []
        self.action = []
        self.reward = []
        self.done = []
        self.next_state = []
        
    def push(self, state, action, reward, done, next_state):
        if self.size < self.max_size:
            self.state.append(state)
            self.action.append(action)
            self.reward.append(reward)
            self.done.append(done)
            self.next_state.append(next_state)
        else:
            idx = self.size % self.max_size
            self.state[idx] = state
            self.action[idx] = action
            self.reward[idx] = reward
            self.done[idx] = done
            self.next_state[idx] = next_state
        self.size += 1
        
    def sample(self, n):
        sample_num = min(self.size, self.max_size)
        indices = np.random.choice(range(sample_num), size=n, replace=True) if self.size < n else np.random.choice(range(sample_num), size=n, replace=False)
        # Stack lists of ndarrays with np.array first; torch.tensor on a list of ndarrays is very slow
        state = torch.tensor(np.array([self.state[i] for i in indices]), dtype=torch.float32, device=self.device)
        action = torch.tensor([self.action[i] for i in indices], dtype=torch.long, device=self.device)
        reward = torch.tensor([self.reward[i] for i in indices], dtype=torch.float32, device=self.device)
        done = torch.tensor([self.done[i] for i in indices], dtype=torch.float32, device=self.device)
        next_state = torch.tensor(np.array([self.next_state[i] for i in indices]), dtype=torch.float32, device=self.device)
        return state, action, reward, done, next_state
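
`ReplayBuffer` implements a ring buffer by hand (`idx = self.size % self.max_size`). A hypothetical alternative sketch, not the code used above, gets the same "overwrite the oldest transition" behavior from `collections.deque(maxlen=...)`; it returns plain Python lists rather than the torch tensors that `ReplayBuffer.sample` produces.

```python
import random
from collections import deque

class DequeReplayBuffer:
    """Hypothetical alternative: deque(maxlen=...) evicts the oldest transition when full."""

    def __init__(self, max_size, seed=None):
        self.buffer = deque(maxlen=max_size)
        self.rng = random.Random(seed)

    def push(self, state, action, reward, done, next_state):
        self.buffer.append((state, action, reward, done, next_state))

    def sample(self, n):
        items = list(self.buffer)
        # Without replacement when possible, with replacement if fewer than n are stored
        batch = self.rng.sample(items, n) if len(items) >= n else self.rng.choices(items, k=n)
        # Transpose a list of 5-tuples into 5 per-field lists
        return tuple(list(field) for field in zip(*batch))

    def __len__(self):
        return len(self.buffer)

buf = DequeReplayBuffer(max_size=3, seed=0)
for t in range(5):
    buf.push([t], t % 2, 1.0, False, [t + 1])
print(len(buf))                     # 3
print([s for s, *_ in buf.buffer])  # [[2], [3], [4]]: the two oldest were evicted
states, actions, rewards, dones, next_states = buf.sample(2)
```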

In [5]:
# Training
def train(args, env):
    agent = DuelingDQN(args.state_size, args.num_actions, args.discount, args.device)
    replay_buffer = ReplayBuffer(10000, args.device)
    optimizer = torch.optim.Adam(agent.Q.parameters(), lr=args.lr)
    writer = SummaryWriter()
    
    epsilon = 1
    epsilon_max = 1
    epsilon_min = 0.1
    
    episode_reward = 0
    episode_length = 0
    episode_num = 1
    max_episode_reward = -float('inf')
    
    agent.Q.train()
    state, _ = env.reset()
    
    for i in range(args.max_steps):
        if np.random.rand() < epsilon or i < args.warmup_steps:
            action = env.action_space.sample()
        else:
            action = agent.get_action(torch.from_numpy(state).to(args.device))
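        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        episode_length += 1
        # Store `terminated` (not `done`) as the bootstrap mask: a time-limit truncation is not a true terminal state
        replay_buffer.push(state, action, reward, terminated, next_state)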
        state = next_state
        
        if done:
            if episode_reward > max_episode_reward:
                max_episode_reward = episode_reward
                save_path = os.path.join(args.output_dir, 'model.bin')
                torch.save(agent.Q.state_dict(), save_path)
                
            writer.add_scalar('Maximum reward', max_episode_reward, episode_num)
            writer.add_scalar('Episode reward', episode_reward, episode_num)
            writer.add_scalar('Episode length', episode_length, episode_num)
            print(f'step = {i}, reward = {episode_reward:.0f}, length = {episode_length}, max reward = {max_episode_reward}, epsilon = {epsilon:.3f}')
            
            episode_reward = 0
            episode_length = 0
            episode_num += 1
            epsilon = max(epsilon - (epsilon_max - epsilon_min) * args.epsilon_decay, epsilon_min)
            state, _ = env.reset()
            
        if i > args.warmup_steps:
            bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
            loss = agent.compute_loss(bs, ba, br, bd, bns)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            agent.soft_update()
            if i % 500 == 0:
                writer.add_scalar('loss', loss.item(), i)
    
    writer.close()
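
In `train`, ε is decayed once per finished episode (not per environment step) by $(\epsilon_{\max} - \epsilon_{\min}) \cdot$ `epsilon_decay` and clamped at $\epsilon_{\min}$. This sketch replays that update with the same values as the run cell, which is why the log only reaches ε = 0.100 after roughly a thousand episodes:

```python
def epsilon_schedule(num_episodes, eps_start=1.0, eps_max=1.0, eps_min=0.1, decay=1 / 1000):
    """Per-episode update from train(): eps <- max(eps - (eps_max - eps_min) * decay, eps_min)."""
    eps = eps_start
    history = [eps]
    for _ in range(num_episodes):
        eps = max(eps - (eps_max - eps_min) * decay, eps_min)
        history.append(eps)
    return history

hist = epsilon_schedule(1200)
print(round(hist[500], 3))  # 0.55 (= 1 - 0.9 * 500 / 1000)
print(hist[1200])           # 0.1, clamped at eps_min
```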

In [6]:
# Evaluation
def eval(args, env):
    agent = DuelingDQN(args.state_size, args.num_actions, args.discount, args.device)
    model_path = os.path.join(args.output_dir, 'model.bin')
    agent.Q.load_state_dict(torch.load(model_path, map_location=args.device))
    agent.Q.to(args.device)
    agent.Q.eval()
    
    episode_reward = 0
    episode_length = 0
    state, _ = env.reset()
    for _ in range(5000):
        action = agent.get_action(torch.from_numpy(state).to(args.device))
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        episode_length += 1
        state = next_state
        
        if done:
            state, _ = env.reset()
            print(f'episode reward = {episode_reward:.0f}, episode_length = {episode_length}')
            episode_reward = 0
            episode_length = 0

In [7]:
# Run
args = argparse.Namespace()
args.env = 'CartPole-v1'
args.state_size = 4
args.num_actions = 2
args.discount = 0.99
args.max_steps = int(1e5)
args.lr = 1e-3
args.batch_size = 32
args.warmup_steps = int(1e4)
args.output_dir = 'output'
args.epsilon_decay = 1 / 1000
if torch.backends.mps.is_available():
    args.device = torch.device('mps')
elif torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')
    
os.makedirs(args.output_dir, exist_ok=True)
    
env = gym.make(args.env)
env.reset(seed=42)
env.action_space.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if args.device == torch.device('cuda'):
    torch.cuda.manual_seed(42)

print('Training started...')
train(args, env)
print('Training completed!')

print('Evaluation started...')
eval(args, env)
print('Evaluation completed!')

Training started...
step = 28, reward = 29, length = 29, max reward = 29.0, epsilon = 1.000
step = 45, reward = 17, length = 17, max reward = 29.0, epsilon = 0.999
step = 114, reward = 69, length = 69, max reward = 69.0, epsilon = 0.998
step = 129, reward = 15, length = 15, max reward = 69.0, epsilon = 0.997
step = 168, reward = 39, length = 39, max reward = 69.0, epsilon = 0.996
step = 193, reward = 25, length = 25, max reward = 69.0, epsilon = 0.995
step = 207, reward = 14, length = 14, max reward = 69.0, epsilon = 0.995
step = 230, reward = 23, length = 23, max reward = 69.0, epsilon = 0.994
step = 258, reward = 28, length = 28, max reward = 69.0, epsilon = 0.993
step = 300, reward = 42, length = 42, max reward = 69.0, epsilon = 0.992
step = 325, reward = 25, length = 25, max reward = 69.0, epsilon = 0.991
step = 353, reward = 28, length = 28, max reward = 69.0, epsilon = 0.990
step = 370, reward = 17, length = 17, max reward = 69.0, epsilon = 0.989
step = 403, reward = 33, length =

step = 9326, reward = 28, length = 28, max reward = 73.0, epsilon = 0.615
step = 9344, reward = 18, length = 18, max reward = 73.0, epsilon = 0.614
step = 9360, reward = 16, length = 16, max reward = 73.0, epsilon = 0.613
step = 9390, reward = 30, length = 30, max reward = 73.0, epsilon = 0.612
step = 9406, reward = 16, length = 16, max reward = 73.0, epsilon = 0.611
step = 9419, reward = 13, length = 13, max reward = 73.0, epsilon = 0.610
step = 9437, reward = 18, length = 18, max reward = 73.0, epsilon = 0.609
step = 9448, reward = 11, length = 11, max reward = 73.0, epsilon = 0.608
step = 9467, reward = 19, length = 19, max reward = 73.0, epsilon = 0.608
step = 9517, reward = 50, length = 50, max reward = 73.0, epsilon = 0.607
step = 9532, reward = 15, length = 15, max reward = 73.0, epsilon = 0.606
step = 9549, reward = 17, length = 17, max reward = 73.0, epsilon = 0.605
step = 9562, reward = 13, length = 13, max reward = 73.0, epsilon = 0.604
step = 9576, reward = 14, length = 14,



step = 10044, reward = 60, length = 60, max reward = 73.0, epsilon = 0.586
step = 10060, reward = 16, length = 16, max reward = 73.0, epsilon = 0.585
step = 10076, reward = 16, length = 16, max reward = 73.0, epsilon = 0.584
step = 10086, reward = 10, length = 10, max reward = 73.0, epsilon = 0.583
step = 10108, reward = 22, length = 22, max reward = 73.0, epsilon = 0.582
step = 10119, reward = 11, length = 11, max reward = 73.0, epsilon = 0.581
step = 10141, reward = 22, length = 22, max reward = 73.0, epsilon = 0.581
step = 10153, reward = 12, length = 12, max reward = 73.0, epsilon = 0.580
step = 10162, reward = 9, length = 9, max reward = 73.0, epsilon = 0.579
step = 10172, reward = 10, length = 10, max reward = 73.0, epsilon = 0.578
step = 10187, reward = 15, length = 15, max reward = 73.0, epsilon = 0.577
step = 10196, reward = 9, length = 9, max reward = 73.0, epsilon = 0.576
step = 10206, reward = 10, length = 10, max reward = 73.0, epsilon = 0.575
step = 10225, reward = 19, le

step = 18434, reward = 32, length = 32, max reward = 259.0, epsilon = 0.489
step = 18465, reward = 31, length = 31, max reward = 259.0, epsilon = 0.488
step = 18595, reward = 130, length = 130, max reward = 259.0, epsilon = 0.487
step = 18635, reward = 40, length = 40, max reward = 259.0, epsilon = 0.486
step = 18751, reward = 116, length = 116, max reward = 259.0, epsilon = 0.485
step = 19098, reward = 347, length = 347, max reward = 347.0, epsilon = 0.484
step = 19187, reward = 89, length = 89, max reward = 347.0, epsilon = 0.483
step = 19258, reward = 71, length = 71, max reward = 347.0, epsilon = 0.482
step = 19297, reward = 39, length = 39, max reward = 347.0, epsilon = 0.482
step = 19331, reward = 34, length = 34, max reward = 347.0, epsilon = 0.481
step = 19367, reward = 36, length = 36, max reward = 347.0, epsilon = 0.480
step = 19381, reward = 14, length = 14, max reward = 347.0, epsilon = 0.479
step = 19423, reward = 42, length = 42, max reward = 347.0, epsilon = 0.478

step = 32229, reward = 136, length = 136, max reward = 500.0, epsilon = 0.392
step = 32273, reward = 44, length = 44, max reward = 500.0, epsilon = 0.392
step = 32290, reward = 17, length = 17, max reward = 500.0, epsilon = 0.391
step = 32302, reward = 12, length = 12, max reward = 500.0, epsilon = 0.390
step = 32340, reward = 38, length = 38, max reward = 500.0, epsilon = 0.389
step = 32461, reward = 121, length = 121, max reward = 500.0, epsilon = 0.388
step = 32584, reward = 123, length = 123, max reward = 500.0, epsilon = 0.387
step = 32627, reward = 43, length = 43, max reward = 500.0, epsilon = 0.386
step = 32654, reward = 27, length = 27, max reward = 500.0, epsilon = 0.385
step = 32685, reward = 31, length = 31, max reward = 500.0, epsilon = 0.384
step = 32808, reward = 123, length = 123, max reward = 500.0, epsilon = 0.383
step = 32976, reward = 168, length = 168, max reward = 500.0, epsilon = 0.383
step = 32989, reward = 13, length = 13, max reward = 500.0, epsilon = 0.382

step = 50009, reward = 500, length = 500, max reward = 500.0, epsilon = 0.297
step = 50418, reward = 409, length = 409, max reward = 500.0, epsilon = 0.296
step = 50430, reward = 12, length = 12, max reward = 500.0, epsilon = 0.295
step = 50769, reward = 339, length = 339, max reward = 500.0, epsilon = 0.294
step = 51269, reward = 500, length = 500, max reward = 500.0, epsilon = 0.293
step = 51769, reward = 500, length = 500, max reward = 500.0, epsilon = 0.293
step = 51837, reward = 68, length = 68, max reward = 500.0, epsilon = 0.292
step = 52337, reward = 500, length = 500, max reward = 500.0, epsilon = 0.291
step = 52536, reward = 199, length = 199, max reward = 500.0, epsilon = 0.290
step = 52683, reward = 147, length = 147, max reward = 500.0, epsilon = 0.289
step = 53183, reward = 500, length = 500, max reward = 500.0, epsilon = 0.288
step = 53473, reward = 290, length = 290, max reward = 500.0, epsilon = 0.287
step = 53559, reward = 86, length = 86, max reward = 500.0, epsilon 

step = 73038, reward = 314, length = 314, max reward = 500.0, epsilon = 0.200
step = 73159, reward = 121, length = 121, max reward = 500.0, epsilon = 0.199
step = 73325, reward = 166, length = 166, max reward = 500.0, epsilon = 0.198
step = 73425, reward = 100, length = 100, max reward = 500.0, epsilon = 0.197
step = 73439, reward = 14, length = 14, max reward = 500.0, epsilon = 0.196
step = 73779, reward = 340, length = 340, max reward = 500.0, epsilon = 0.195
step = 73840, reward = 61, length = 61, max reward = 500.0, epsilon = 0.194
step = 73857, reward = 17, length = 17, max reward = 500.0, epsilon = 0.194
step = 73948, reward = 91, length = 91, max reward = 500.0, epsilon = 0.193
step = 73975, reward = 27, length = 27, max reward = 500.0, epsilon = 0.192
step = 74163, reward = 188, length = 188, max reward = 500.0, epsilon = 0.191
step = 74189, reward = 26, length = 26, max reward = 500.0, epsilon = 0.190
step = 74397, reward = 208, length = 208, max reward = 500.0, epsilon = 0.18

step = 89877, reward = 500, length = 500, max reward = 500.0, epsilon = 0.104
step = 90118, reward = 241, length = 241, max reward = 500.0, epsilon = 0.104
step = 90584, reward = 466, length = 466, max reward = 500.0, epsilon = 0.103
step = 91084, reward = 500, length = 500, max reward = 500.0, epsilon = 0.102
step = 91280, reward = 196, length = 196, max reward = 500.0, epsilon = 0.101
step = 91602, reward = 322, length = 322, max reward = 500.0, epsilon = 0.100
step = 91738, reward = 136, length = 136, max reward = 500.0, epsilon = 0.100
step = 91870, reward = 132, length = 132, max reward = 500.0, epsilon = 0.100
step = 91986, reward = 116, length = 116, max reward = 500.0, epsilon = 0.100
step = 92311, reward = 325, length = 325, max reward = 500.0, epsilon = 0.100
step = 92327, reward = 16, length = 16, max reward = 500.0, epsilon = 0.100
step = 92827, reward = 500, length = 500, max reward = 500.0, epsilon = 0.100
step = 93185, reward = 358, length = 358, max reward = 500.0, epsi

#### Reward:

![fig1](https://raw.githubusercontent.com/Xavier-MaYiMing/Reinforcement-learning-and-combinatorial-optimzation/main/figs/Dueling%20DQN%20-%20reward.png)

#### Maximum reward:

![fig2](https://raw.githubusercontent.com/Xavier-MaYiMing/Reinforcement-learning-and-combinatorial-optimzation/main/figs/Dueling%20DQN%20-%20max%20reward.png)