In [1]:
import gym
import ptan
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from tensorboardX import SummaryWriter

from lib import dqn_model_noise, common_dqn_distrib

In [2]:
# ---- n-step returns ---------------------------------------------------------
# Environment steps folded into every stored transition; the bootstrapped
# target is therefore discounted by gamma ** REWARD_STEPS in the training loop.
REWARD_STEPS = 2

# ---- prioritized experience replay ------------------------------------------
PRIO_REPLAY_ALPHA = 0.6   # prioritization exponent handed to the replay buffer
BETA_START = 0.4          # initial importance-sampling exponent
BETA_FRAMES = 100000      # frames over which beta is annealed linearly to 1.0

# ---- categorical (C51) return distribution ----------------------------------
Vmin, Vmax = -10, 10      # lower / upper edge of the atom support
N_ATOMS = 51              # number of atoms on the support
DELTA_Z = (Vmax - Vmin) / (N_ATOMS - 1)   # distance between neighbouring atoms


In [3]:
class RainbowDQN(nn.Module):
    """Dueling + noisy + categorical (C51) DQN used by the Rainbow agent.

    A shared convolutional trunk feeds two heads built from NoisyLinear layers:

    * the value head emits N_ATOMS logits per state,
    * the advantage head emits N_ATOMS logits for every action.

    forward() combines them dueling-style (value + advantage - mean advantage)
    into per-action distribution logits of shape (batch, n_actions, N_ATOMS).
    """

    def __init__(self, input_shape, n_actions):
        super(RainbowDQN, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        # Value stream: one N_ATOMS-long vector of logits per state.
        self.fc_val = nn.Sequential(
            dqn_model_noise.NoisyLinear(conv_out_size, 256),
            nn.ReLU(),
            dqn_model_noise.NoisyLinear(256, N_ATOMS)
        )

        # Advantage stream: N_ATOMS logits for each of the n_actions actions.
        self.fc_adv = nn.Sequential(
            dqn_model_noise.NoisyLinear(conv_out_size, 256),
            nn.ReLU(),
            dqn_model_noise.NoisyLinear(256, n_actions * N_ATOMS)
        )

        # Atom locations spanning [Vmin, Vmax]. linspace always yields exactly
        # N_ATOMS points; the previous torch.arange(Vmin, Vmax + DELTA_Z, DELTA_Z)
        # depended on float rounding of the end point and is not guaranteed to.
        self.register_buffer("supports", torch.linspace(Vmin, Vmax, N_ATOMS))
        self.softmax = nn.Softmax(dim=1)

    def _get_conv_out(self, shape):
        """Flattened size of the conv trunk output, found by a dummy (1, *shape) zero input."""
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Return per-action distribution logits of shape (batch, n_actions, N_ATOMS).

        The value output is viewed as (batch, 1, N_ATOMS) so broadcasting adds it
        to every action. The advantage baseline is the mean over the action axis
        (keepdim=True, also (batch, 1, N_ATOMS)), which likewise broadcasts.
        No softmax is applied here; use apply_softmax() or both() for probabilities.
        """
        batch_size = x.size(0)
        # NOTE(review): presumably x holds uint8 pixels; /256 scales them into [0, 1).
        fx = x.float() / 256
        conv_out = self.conv(fx).view(batch_size, -1)
        val_out = self.fc_val(conv_out).view(batch_size, 1, N_ATOMS)
        adv_out = self.fc_adv(conv_out).view(batch_size, -1, N_ATOMS)
        adv_mean = adv_out.mean(dim=1, keepdim=True)
        return val_out + (adv_out - adv_mean)

    def both(self, x):
        """Run the network once and return (logits, q_values).

        logits:   raw forward() output, shape (batch, n_actions, N_ATOMS).
        q_values: expected return of each action's distribution,
                  sum_i softmax(logits)_i * supports_i, shape (batch, n_actions).
        """
        cat_out = self(x)
        probs = self.apply_softmax(cat_out)
        weights = probs * self.supports
        res = weights.sum(dim=2)
        return cat_out, res

    def qvals(self, x):
        """Q-values only, shape (batch, n_actions)."""
        return self.both(x)[1]

    def apply_softmax(self, t):
        """Softmax over the atom axis (last axis, length N_ATOMS); output has t's shape.

        reshape (rather than view) also accepts non-contiguous inputs.
        """
        return self.softmax(t.reshape(-1, N_ATOMS)).view(t.size())

In [4]:
def calc_loss(batch, batch_weights, net, tgt_net, gamma, device="cpu"):
    """Importance-weighted categorical (C51) loss with Double-DQN action selection.

    Args:
        batch: transitions sampled from the replay buffer, unpacked with
            common_dqn_distrib.unpack_batch.
        batch_weights: importance-sampling weights, one per transition.
        net: online network; selects the next actions and produces the logits
            that are trained.
        tgt_net: target network; evaluates the distribution of the selected
            next actions.
        gamma: discount for the bootstrapped part of the target (the training
            loop passes gamma ** REWARD_STEPS).
        device: torch device the batch tensors are moved to.

    Returns:
        (mean_loss, sample_priorities): the scalar loss to back-propagate, and a
        detached per-sample tensor (weighted loss + 1e-5) for the replay priorities.
    """
    states, actions, rewards, dones, next_states = common_dqn_distrib.unpack_batch(batch)
    batch_size = len(batch)

    states_v = torch.tensor(states).to(device)
    actions_v = torch.tensor(actions).to(device)
    next_states_v = torch.tensor(next_states).to(device)
    batch_weights_v = torch.tensor(batch_weights).to(device)

    # A single online-network pass over [states; next_states]. The first half
    # gives the logits being trained; the second half gives the q-values used to
    # choose next actions (Double DQN: the online net selects, the target net evaluates).
    distr_v, qvals_v = net.both(torch.cat((states_v, next_states_v)))
    next_qvals_v = qvals_v[batch_size:]
    distr_v = distr_v[:batch_size]
    next_actions_v = next_qvals_v.max(1)[1]

    # The target distribution is a constant for the optimizer, so no autograd
    # graph is built for the target-network pass.
    with torch.no_grad():
        next_distr_v = tgt_net(next_states_v)
        next_best_distr_v = next_distr_v[range(batch_size), next_actions_v]
        next_best_distr_v = tgt_net.apply_softmax(next_best_distr_v)
        next_best_distr = next_best_distr_v.cpu().numpy()

    # Builtin bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
    dones = dones.astype(bool)
    # Bellman update projected back onto the fixed atom support; used as the target.
    proj_distr = common_dqn_distrib.distr_projection(next_best_distr, rewards, dones, Vmin, Vmax, N_ATOMS, gamma)

    # Log-probabilities of the distribution for the action that was actually taken.
    state_action_values = distr_v[range(batch_size), actions_v]
    state_log_sm_v = F.log_softmax(state_action_values, dim=1)

    # Cross-entropy against the projected target (KL divergence up to a constant),
    # scaled per sample by the importance-sampling weights.
    proj_distr_v = torch.tensor(proj_distr).to(device)
    loss_v = -state_log_sm_v * proj_distr_v
    loss_v = batch_weights_v * loss_v.sum(dim=1)
    # NOTE(review): priorities come from the IS-weighted loss (kept from the
    # original code); the Rainbow paper prioritizes by the unweighted KL loss.
    return loss_v.mean(), loss_v.detach() + 1e-5

In [None]:
if __name__ == "__main__":
    # Build the training objects: online/target network, n-step experience
    # source, prioritized replay buffer and optimizer, then run the training loop.

    # Copy the shared hyperparameter dict so the in-place tweak below does not
    # leak into common_dqn_distrib.HYPERPARAMS. Previously, re-running this cell
    # in the same kernel doubled epsilon_frames again on every run.
    params = dict(common_dqn_distrib.HYPERPARAMS['pong'])
    params['epsilon_frames'] *= 2

    parser = argparse.ArgumentParser()
    # The old definition (default=True with action="store_true") made --cuda a
    # no-op and forced CUDA even on CPU-only machines. Now the default follows
    # CUDA availability; --cuda forces it on and --no-cuda forces CPU.
    parser.add_argument("--cuda", action="store_true", help="Enable cuda")
    parser.add_argument("--no-cuda", dest="cuda", action="store_false", help="Disable cuda")
    parser.set_defaults(cuda=torch.cuda.is_available())
    # parse_known_args tolerates extra argv entries (presumably the arguments
    # Jupyter passes to the kernel) instead of exiting with an error.
    args, _unknown = parser.parse_known_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    env = gym.make(params['env_name'])
    env = ptan.common.wrappers.wrap_dqn(env)

    writer = SummaryWriter(comment="-" + params['run_name'] + "-rainbow")
    net = RainbowDQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt_net = ptan.agent.TargetNet(net)
    # Greedy action selection (no epsilon); exploration is expected to come from
    # the NoisyLinear layers inside the network.
    agent = ptan.agent.DQNAgent(lambda x: net.qvals(x), ptan.actions.ArgmaxActionSelector(), device=device)

    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=params['gamma'], steps_count=REWARD_STEPS)
    buffer = ptan.experience.PrioritizedReplayBuffer(exp_source, params['replay_size'], PRIO_REPLAY_ALPHA)
    optimizer = optim.Adam(net.parameters(), lr=params['learning_rate'])

    frame_idx = 0
    beta = BETA_START

    try:
        with common_dqn_distrib.RewardTracker(writer, params['stop_reward']) as reward_tracker:
            while True:
                frame_idx += 1
                buffer.populate(1)
                # Anneal the importance-sampling exponent linearly from BETA_START to 1.0.
                beta = min(1.0, BETA_START + frame_idx * (1.0 - BETA_START) / BETA_FRAMES)

                new_rewards = exp_source.pop_total_rewards()
                if new_rewards:
                    # NOTE(review): presumably reward() returns True once stop_reward is reached.
                    if reward_tracker.reward(new_rewards[0], frame_idx):
                        break

                # Don't train until the buffer holds enough transitions.
                if len(buffer) < params['replay_initial']:
                    continue

                optimizer.zero_grad()
                batch, batch_indices, batch_weights = buffer.sample(params['batch_size'], beta)
                # n-step transitions: the bootstrapped term is discounted by gamma ** REWARD_STEPS.
                loss_v, sample_prios_v = calc_loss(batch, batch_weights, net, tgt_net.target_model,
                                                   params['gamma'] ** REWARD_STEPS, device=device)
                loss_v.backward()
                optimizer.step()
                buffer.update_priorities(batch_indices, sample_prios_v.data.cpu().numpy())

                if frame_idx % params['target_net_sync'] == 0:
                    tgt_net.sync()
    finally:
        # Release the environment even if training is interrupted.
        env.close()

761: done 1 games, mean reward -21.000, speed 155.05 f/s
1517: done 2 games, mean reward -21.000, speed 177.77 f/s
2277: done 3 games, mean reward -21.000, speed 161.44 f/s
3035: done 4 games, mean reward -21.000, speed 160.70 f/s
3792: done 5 games, mean reward -21.000, speed 162.52 f/s
4554: done 6 games, mean reward -21.000, speed 155.61 f/s
5314: done 7 games, mean reward -21.000, speed 161.34 f/s
6074: done 8 games, mean reward -21.000, speed 165.64 f/s
6830: done 9 games, mean reward -21.000, speed 164.52 f/s
7590: done 10 games, mean reward -21.000, speed 162.30 f/s
8348: done 11 games, mean reward -21.000, speed 163.33 f/s
9107: done 12 games, mean reward -21.000, speed 169.52 f/s
9867: done 13 games, mean reward -21.000, speed 166.31 f/s
10835: done 14 games, mean reward -20.857, speed 34.16 f/s
11672: done 15 games, mean reward -20.800, speed 29.84 f/s
12434: done 16 games, mean reward -20.812, speed 29.92 f/s
13195: done 17 games, mean reward -20.824, speed 30.34 f/s
13957: 

112002: done 140 games, mean reward -20.690, speed 22.09 f/s
112763: done 141 games, mean reward -20.690, speed 22.24 f/s
113604: done 142 games, mean reward -20.690, speed 22.58 f/s
114365: done 143 games, mean reward -20.690, speed 22.21 f/s
115123: done 144 games, mean reward -20.700, speed 22.53 f/s
115909: done 145 games, mean reward -20.710, speed 22.52 f/s
116684: done 146 games, mean reward -20.710, speed 22.31 f/s
117553: done 147 games, mean reward -20.700, speed 22.20 f/s
118388: done 148 games, mean reward -20.690, speed 21.93 f/s
119179: done 149 games, mean reward -20.690, speed 22.04 f/s
119942: done 150 games, mean reward -20.690, speed 21.93 f/s
120782: done 151 games, mean reward -20.690, speed 21.90 f/s
121614: done 152 games, mean reward -20.680, speed 22.24 f/s
122377: done 153 games, mean reward -20.700, speed 22.10 f/s
123311: done 154 games, mean reward -20.690, speed 28.55 f/s
124090: done 155 games, mean reward -20.710, speed 28.10 f/s
124853: done 156 games, 

220344: done 275 games, mean reward -20.550, speed 28.55 f/s
221106: done 276 games, mean reward -20.550, speed 28.38 f/s
221865: done 277 games, mean reward -20.550, speed 28.73 f/s
222697: done 278 games, mean reward -20.550, speed 28.53 f/s
223458: done 279 games, mean reward -20.550, speed 28.53 f/s
224237: done 280 games, mean reward -20.550, speed 28.25 f/s
225076: done 281 games, mean reward -20.540, speed 28.36 f/s
225864: done 282 games, mean reward -20.550, speed 28.52 f/s
226697: done 283 games, mean reward -20.550, speed 28.27 f/s
227566: done 284 games, mean reward -20.540, speed 28.59 f/s
228399: done 285 games, mean reward -20.540, speed 28.54 f/s
229261: done 286 games, mean reward -20.540, speed 28.37 f/s
230125: done 287 games, mean reward -20.530, speed 28.23 f/s
230960: done 288 games, mean reward -20.520, speed 28.27 f/s
231746: done 289 games, mean reward -20.530, speed 28.61 f/s
232552: done 290 games, mean reward -20.540, speed 28.59 f/s
233308: done 291 games, 

327314: done 410 games, mean reward -20.720, speed 28.30 f/s
328101: done 411 games, mean reward -20.720, speed 28.38 f/s
328860: done 412 games, mean reward -20.730, speed 28.36 f/s
329617: done 413 games, mean reward -20.730, speed 28.54 f/s
330378: done 414 games, mean reward -20.730, speed 28.43 f/s
331136: done 415 games, mean reward -20.730, speed 28.48 f/s
331893: done 416 games, mean reward -20.740, speed 28.50 f/s
332656: done 417 games, mean reward -20.740, speed 28.39 f/s
333417: done 418 games, mean reward -20.740, speed 28.37 f/s
334178: done 419 games, mean reward -20.750, speed 28.48 f/s
334940: done 420 games, mean reward -20.760, speed 28.70 f/s
335699: done 421 games, mean reward -20.760, speed 28.55 f/s
336461: done 422 games, mean reward -20.770, speed 28.46 f/s
337223: done 423 games, mean reward -20.770, speed 28.21 f/s
337981: done 424 games, mean reward -20.770, speed 28.17 f/s
338738: done 425 games, mean reward -20.780, speed 28.07 f/s
339573: done 426 games, 

433919: done 545 games, mean reward -20.690, speed 27.96 f/s
434757: done 546 games, mean reward -20.680, speed 27.37 f/s
435520: done 547 games, mean reward -20.680, speed 26.48 f/s
436350: done 548 games, mean reward -20.680, speed 25.69 f/s
437113: done 549 games, mean reward -20.690, speed 28.28 f/s
437949: done 550 games, mean reward -20.680, speed 27.91 f/s
438738: done 551 games, mean reward -20.680, speed 28.14 f/s
439496: done 552 games, mean reward -20.690, speed 26.11 f/s
440256: done 553 games, mean reward -20.690, speed 27.04 f/s
441013: done 554 games, mean reward -20.690, speed 26.59 f/s
441773: done 555 games, mean reward -20.700, speed 25.80 f/s
442535: done 556 games, mean reward -20.710, speed 28.02 f/s
443294: done 557 games, mean reward -20.710, speed 27.50 f/s
444129: done 558 games, mean reward -20.700, speed 27.54 f/s
444965: done 559 games, mean reward -20.720, speed 28.19 f/s
445740: done 560 games, mean reward -20.720, speed 27.23 f/s
446619: done 561 games, 

542709: done 680 games, mean reward -20.520, speed 24.98 f/s
543576: done 681 games, mean reward -20.520, speed 24.96 f/s
544354: done 682 games, mean reward -20.520, speed 24.19 f/s
545132: done 683 games, mean reward -20.520, speed 24.30 f/s
545890: done 684 games, mean reward -20.520, speed 25.11 f/s
546725: done 685 games, mean reward -20.510, speed 25.99 f/s
547561: done 686 games, mean reward -20.510, speed 26.17 f/s
548318: done 687 games, mean reward -20.510, speed 25.13 f/s
549076: done 688 games, mean reward -20.510, speed 26.31 f/s
549990: done 689 games, mean reward -20.490, speed 26.37 f/s
550934: done 690 games, mean reward -20.470, speed 26.34 f/s
551691: done 691 games, mean reward -20.480, speed 26.20 f/s
552452: done 692 games, mean reward -20.490, speed 26.26 f/s
553208: done 693 games, mean reward -20.490, speed 26.17 f/s
553964: done 694 games, mean reward -20.490, speed 25.31 f/s
554925: done 695 games, mean reward -20.470, speed 25.66 f/s
555681: done 696 games, 

650638: done 815 games, mean reward -20.660, speed 27.36 f/s
651473: done 816 games, mean reward -20.670, speed 28.06 f/s
652361: done 817 games, mean reward -20.660, speed 27.71 f/s
653202: done 818 games, mean reward -20.660, speed 28.22 f/s
653960: done 819 games, mean reward -20.660, speed 27.88 f/s
654737: done 820 games, mean reward -20.660, speed 27.87 f/s
655576: done 821 games, mean reward -20.650, speed 28.17 f/s
656414: done 822 games, mean reward -20.650, speed 28.27 f/s
657224: done 823 games, mean reward -20.650, speed 28.11 f/s
