In [1]:
import gym
import ptan
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from lib import dqn_model, common_dqn_distrib

In [2]:
# Debug switches. Enabling either one noticeably slows training down.
#   SAVE_STATES_IMG: every 10000 frames, save the predicted atom distributions of the
#     first 200 evaluation states as images (see save_state_images).
#   SAVE_TRANSITIONS_IMG: at most once per 30000 frames, save predicted / projected /
#     next-state distributions for a batch that contains a non-zero reward or an
#     episode end (see save_transition_images).
SAVE_STATES_IMG = False
SAVE_TRANSITIONS_IMG = False

# Plotting is only needed for the debug images above. Select the non-interactive
# "Agg" backend first, then import pyplot, so figures can be written to files on a
# headless machine. (pyplot rather than the discouraged pylab namespace module.)
if SAVE_STATES_IMG or SAVE_TRANSITIONS_IMG:
    import matplotlib as mpl
    mpl.use("Agg")
    import matplotlib.pyplot as plt


# Support of the value distribution: N_ATOMS equally spaced atoms on [Vmin, Vmax].
Vmax = 10
Vmin = -10
# Number of atoms.
N_ATOMS = 51
# Spacing between neighbouring atoms.
DELTA_Z = (Vmax - Vmin) / (N_ATOMS - 1)

# Number of replay states sampled once to track the mean value estimate, and how
# often (in frames) that mean is recomputed and logged.
STATES_TO_EVALUATE = 1000
EVAL_EVERY_FRAME = 100

In [3]:
class DistributionalDQN(nn.Module):
    """Categorical (C51-style) distributional DQN.

    For every action the network predicts N_ATOMS logits over a fixed support of
    atom values spanning [Vmin, Vmax]. ``forward`` returns these raw logits with
    shape (batch, n_actions, N_ATOMS). ``apply_softmax`` turns them into per-action
    probability distributions, and ``qvals`` turns them into expected values.

    Args:
        input_shape: observation shape; ``input_shape[0]`` is the number of input
            channels of the first convolution.
        n_actions: number of discrete actions.
    """

    def __init__(self, input_shape, n_actions):
        super(DistributionalDQN, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions * N_ATOMS)
        )
        # Atom values z_0 .. z_{N_ATOMS-1}. They are registered as a buffer so they
        # follow .to(device) and are stored in state_dict without being trained.
        # linspace always yields exactly N_ATOMS points; arange with a float step
        # can gain or drop the end point through floating-point rounding.
        self.register_buffer("supports", torch.linspace(Vmin, Vmax, N_ATOMS))
        self.softmax = nn.Softmax(dim=1)

    def _get_conv_out(self, shape):
        """Return the flattened size of the conv stack's output for one input of ``shape``."""
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        """Return raw atom logits of shape (batch, n_actions, N_ATOMS).

        The input is cast to float and scaled by 1/256. This presumably expects
        uint8 pixels in 0..255 -- TODO confirm against the env wrappers.
        """
        batch_size = x.size()[0]
        fx = x.float() / 256
        conv_out = self.conv(fx).view(batch_size, -1)
        fc_out = self.fc(conv_out)
        return fc_out.view(batch_size, -1, N_ATOMS)

    def both(self, x):
        """Return ``(logits, q_values)`` for a batch of observations.

        ``logits`` is the raw output of ``forward``. ``q_values`` has shape
        (batch, n_actions) and is the expectation of each action's distribution,
        sum_i p_i * z_i. Acting greedily on it behaves like a plain DQN whose
        Q-values are these expectations.
        """
        cat_out = self(x)
        probs = self.apply_softmax(cat_out)
        weights = probs * self.supports
        res = weights.sum(dim=2)
        return cat_out, res

    def qvals(self, x):
        """Return only the expected Q-values, shape (batch, n_actions)."""
        return self.both(x)[1]

    def apply_softmax(self, t):
        """Normalize logits ``t`` over the atom axis, keeping ``t``'s shape."""
        return self.softmax(t.view(-1, N_ATOMS)).view(t.size())

In [4]:
def calc_values_of_states(states, net, device="cpu"):
    """Return the mean greedy value, max_a Q(s, a), over ``states``.

    The states are split into 64 roughly equal chunks to bound peak memory. Each
    chunk is evaluated without gradient tracking, and the result is the exact
    mean over all states, not a mean of per-chunk means.

    Args:
        states: array-like of observations; the first axis indexes states.
        net: model exposing ``qvals(batch)`` that returns a (batch, n_actions) tensor.
        device: device the chunks are moved to before evaluation.

    Returns:
        float: mean of the per-state maximum Q-value, or nan if ``states`` is empty.
    """
    total = 0.0
    count = 0
    with torch.no_grad():
        for chunk in np.array_split(states, 64):
            if len(chunk) == 0:
                # array_split yields empty chunks when there are fewer than 64 states.
                continue
            states_v = torch.tensor(chunk).to(device)
            best_action_values_v = net.qvals(states_v).max(1)[0]
            total += best_action_values_v.sum().item()
            count += len(chunk)
    return total / count if count else float("nan")

In [5]:
def save_state_images(frame_idx, states, net, device="cpu", max_states=200):
    """Save bar charts of the predicted atom distributions for evaluation states.

    Writes one PNG per state to ``states/<state_idx>_<frame_idx>.png``, with one
    subplot per action, and writes at most ``max_states`` images. Relies on the
    module-level ``plt``, which is imported only when a debug-image flag is enabled.
    """
    import os

    os.makedirs("states", exist_ok=True)
    p = np.linspace(Vmin, Vmax, N_ATOMS)
    ofs = 0
    with torch.no_grad():
        for batch in np.array_split(states, 64):
            if ofs >= max_states:
                break
            if len(batch) == 0:
                # array_split yields empty chunks when there are fewer than 64 states.
                continue
            states_v = torch.tensor(batch).to(device)
            action_prob = net.apply_softmax(net(states_v)).cpu().numpy()
            batch_size, num_actions, _ = action_prob.shape
            # Stop exactly at max_states instead of finishing the whole chunk.
            for batch_idx in range(min(batch_size, max_states - ofs)):
                plt.clf()
                for action_idx in range(num_actions):
                    plt.subplot(num_actions, 1, action_idx + 1)
                    plt.bar(p, action_prob[batch_idx, action_idx], width=0.5)
                plt.savefig("states/%05d_%08d.png" % (ofs + batch_idx, frame_idx))
            ofs += batch_size

In [6]:
def save_transition_images(batch_size, predicted, projected, next_distr, dones, rewards, save_prefix):
    """Save one debug figure per transition comparing three atom distributions.

    Each PNG stacks three bar charts titled "Predicted", "Projected" and
    "Next state", drawn from ``predicted``, ``projected`` and ``next_distr`` at
    the same batch index. Files are named
    ``<save_prefix>_<idx>[_<reward>][_done].png``. Relies on the module-level
    ``plt``, which is imported only when a debug-image flag is enabled.
    """
    import os

    out_dir = os.path.dirname(save_prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Atom positions are the same for every transition, so compute them once.
    p = np.linspace(Vmin, Vmax, N_ATOMS)
    panels = (("Predicted", predicted), ("Projected", projected), ("Next state", next_distr))
    for batch_idx in range(batch_size):
        plt.clf()
        for row, (title, distr) in enumerate(panels, start=1):
            plt.subplot(3, 1, row)
            plt.bar(p, distr[batch_idx], width=0.5)
            plt.title(title)
        suffix = ""
        reward = rewards[batch_idx]
        if reward != 0.0:
            suffix += "_%.0f" % reward
        if dones[batch_idx]:
            suffix += "_done"
        plt.savefig("%s_%02d%s.png" % (save_prefix, batch_idx, suffix))

In [7]:
def calc_loss(batch, net, tgt_net, gamma, device="cpu", save_prefix=None):
    """Distributional (categorical) DQN loss for a batch of transitions.

    The target network picks the greedy next action from its expected Q-values.
    That action's next-state distribution is projected through the Bellman update
    onto the fixed atom support (``common_dqn_distrib.distr_projection``). The loss
    is the cross-entropy between that projected target and the online network's
    distribution for the action actually taken. This equals the KL divergence up
    to a term that does not depend on ``net``.

    Args:
        batch: sampled transitions, unpacked by ``common_dqn_distrib.unpack_batch``.
        net: online network being trained.
        tgt_net: target network (no gradients flow into it).
        gamma: discount factor passed to the projection.
        device: device for the tensors.
        save_prefix: if not None, also save debug plots with this file prefix.

    Returns:
        Scalar tensor: batch mean of the per-sample cross-entropy.
    """
    states, actions, rewards, dones, next_states = common_dqn_distrib.unpack_batch(batch)
    batch_size = len(batch)

    states_v = torch.tensor(states).to(device)
    actions_v = torch.tensor(actions).to(device)
    next_states_v = torch.tensor(next_states).to(device)

    # Target side: the results are only used as constants, so do not build a graph.
    with torch.no_grad():
        next_distr_v, next_qvals_v = tgt_net.both(next_states_v)
        next_actions = next_qvals_v.max(1)[1].cpu().numpy()
        next_distr = tgt_net.apply_softmax(next_distr_v).cpu().numpy()

    # Distribution of the greedy next action, projected onto the atom support.
    next_best_distr = next_distr[range(batch_size), next_actions]
    # Builtin bool: the np.bool alias does not exist in NumPy 1.24-1.26.
    dones = dones.astype(bool)
    proj_distr = common_dqn_distrib.distr_projection(next_best_distr, rewards, dones, Vmin, Vmax, N_ATOMS, gamma)

    # Online network's logits for the taken actions, turned into log-probabilities.
    distr_v = net(states_v)
    state_action_values = distr_v[range(batch_size), actions_v]
    state_log_sm_v = F.log_softmax(state_action_values, dim=1)
    proj_distr_v = torch.tensor(proj_distr).to(device)

    if save_prefix is not None:
        pred = F.softmax(state_action_values, dim=1).detach().cpu().numpy()
        save_transition_images(batch_size, pred, proj_distr, next_best_distr, dones, rewards, save_prefix)

    loss_v = -state_log_sm_v * proj_distr_v
    return loss_v.sum(dim=1).mean()

In [8]:
if __name__ == "__main__":
    import os

    # Hyperparameters, device selection, and the environment wrapped with the
    # standard PTAN DQN wrappers.
    params = common_dqn_distrib.HYPERPARAMS['pong']
    parser = argparse.ArgumentParser()
    # Default to CUDA only when a CUDA device exists. With default=True and
    # store_true the flag could never be switched off, and CPU-only runs crashed.
    parser.add_argument("--cuda", default=torch.cuda.is_available(), action="store_true", help="Enable cuda")
    # parse_known_args ignores unrecognized arguments (presumably the ones Jupyter
    # passes to the kernel).
    args, unknown = parser.parse_known_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    env = gym.make(params['env_name'])
    env = ptan.common.wrappers.wrap_dqn(env)

    writer = SummaryWriter(comment="-" + params['run_name'] + "-distrib")
    net = DistributionalDQN(env.observation_space.shape, env.action_space.n).to(device)
    # Target network: a copy of `net`, refreshed via tgt_net.sync() in the loop.
    tgt_net = ptan.agent.TargetNet(net)

    # The agent maps observations to actions: expected Q-values from the
    # distributional head, then epsilon-greedy selection. EpsilonTracker updates
    # epsilon from the frame counter.
    selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=params['epsilon_start'])
    epsilon_tracker = common_dqn_distrib.EpsilonTracker(selector, params)
    agent = ptan.agent.DQNAgent(lambda x: net.qvals(x), selector, device=device)

    # One-step transitions (steps_count=1) stored in a fixed-size replay buffer.
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=params['gamma'], steps_count=1)
    buffer = ptan.experience.ExperienceReplayBuffer(exp_source, buffer_size=params['replay_size'])

    optimizer = optim.Adam(net.parameters(), lr=params['learning_rate'])
    frame_idx = 0

    eval_states = None
    prev_save = 0
    save_prefix = None

    # Output directories for the optional debug images.
    if SAVE_STATES_IMG:
        os.makedirs("states", exist_ok=True)
    if SAVE_TRANSITIONS_IMG:
        os.makedirs("images", exist_ok=True)

    # Each iteration, buffer.populate(1) asks the experience source for one more
    # transition. The agent picks an action (epsilon-greedy), the environment is
    # stepped, and the transition is stored in the replay buffer, evicting the
    # oldest one once it is full.
    with common_dqn_distrib.RewardTracker(writer, params['stop_reward']) as reward_tracker:
        while True:
            frame_idx += 1
            buffer.populate(1)
            epsilon_tracker.frame(frame_idx)

            # Undiscounted returns of episodes that finished during this step.
            # reward_tracker.reward() returns True once the mean reward reaches
            # the stop threshold, which ends training.
            new_rewards = exp_source.pop_total_rewards()
            if new_rewards:
                if reward_tracker.reward(new_rewards[0], frame_idx, selector.epsilon):
                    break

            # Wait until the buffer holds enough transitions to start training.
            if len(buffer) < params['replay_initial']:
                continue

            if eval_states is None:
                # Fixed set of states used to track the mean value estimate.
                # np.asarray instead of np.array(..., copy=False): under NumPy 2
                # the latter raises whenever a copy cannot be avoided.
                sampled = buffer.sample(STATES_TO_EVALUATE)
                eval_states = np.asarray([np.asarray(transition.state) for transition in sampled])

            # One Adam update on a freshly sampled batch.
            optimizer.zero_grad()
            batch = buffer.sample(params['batch_size'])

            # Optionally dump distribution plots for a batch that contains an
            # episode end or a non-zero reward, at most once per 30000 frames.
            save_prefix = None
            if SAVE_TRANSITIONS_IMG:
                interesting = any(s.last_state is None or s.reward != 0.0 for s in batch)
                if interesting and frame_idx // 30000 > prev_save:
                    save_prefix = "images/img_%08d" % frame_idx
                    prev_save = frame_idx // 30000

            # save_prefix must be forwarded, otherwise SAVE_TRANSITIONS_IMG has no effect.
            loss_v = calc_loss(batch, net, tgt_net.target_model, gamma=params['gamma'],
                               device=device, save_prefix=save_prefix)
            loss_v.backward()
            optimizer.step()

            # Periodically copy the online weights into the target network.
            if frame_idx % params['target_net_sync'] == 0:
                tgt_net.sync()

            if frame_idx % EVAL_EVERY_FRAME == 0:
                mean_val = calc_values_of_states(eval_states, net, device=device)
                writer.add_scalar("values_mean", mean_val, frame_idx)

            if SAVE_STATES_IMG and frame_idx % 10000 == 0:
                save_state_images(frame_idx, eval_states, net, device=device)

918: done 1 games, mean reward -20.000, speed 202.35 f/s, eps 0.99
1678: done 2 games, mean reward -20.500, speed 232.78 f/s, eps 0.98
2557: done 3 games, mean reward -20.667, speed 224.73 f/s, eps 0.97
3437: done 4 games, mean reward -20.750, speed 186.29 f/s, eps 0.97
4453: done 5 games, mean reward -20.600, speed 166.15 f/s, eps 0.96
5239: done 6 games, mean reward -20.667, speed 159.58 f/s, eps 0.95
6298: done 7 games, mean reward -20.714, speed 158.13 f/s, eps 0.94
7331: done 8 games, mean reward -20.625, speed 160.45 f/s, eps 0.93
8508: done 9 games, mean reward -20.111, speed 156.74 f/s, eps 0.91
9353: done 10 games, mean reward -20.200, speed 159.89 f/s, eps 0.91
10238: done 11 games, mean reward -20.273, speed 90.77 f/s, eps 0.90
10995: done 12 games, mean reward -20.333, speed 47.33 f/s, eps 0.89
12107: done 13 games, mean reward -20.308, speed 42.34 f/s, eps 0.88
13016: done 14 games, mean reward -20.357, speed 43.96 f/s, eps 0.87
13900: done 15 games, mean reward -20.400, s

144563: done 120 games, mean reward -19.400, speed 44.46 f/s, eps 0.02
146359: done 121 games, mean reward -19.390, speed 44.20 f/s, eps 0.02
148739: done 122 games, mean reward -19.310, speed 44.62 f/s, eps 0.02
150210: done 123 games, mean reward -19.290, speed 43.61 f/s, eps 0.02
151521: done 124 games, mean reward -19.280, speed 45.30 f/s, eps 0.02
153791: done 125 games, mean reward -19.220, speed 45.30 f/s, eps 0.02
155055: done 126 games, mean reward -19.210, speed 45.35 f/s, eps 0.02
156522: done 127 games, mean reward -19.220, speed 45.60 f/s, eps 0.02
157892: done 128 games, mean reward -19.190, speed 45.69 f/s, eps 0.02
159138: done 129 games, mean reward -19.210, speed 45.50 f/s, eps 0.02
160629: done 130 games, mean reward -19.210, speed 45.70 f/s, eps 0.02
161959: done 131 games, mean reward -19.200, speed 45.67 f/s, eps 0.02
164278: done 132 games, mean reward -19.050, speed 45.65 f/s, eps 0.02
165836: done 133 games, mean reward -19.030, speed 45.71 f/s, eps 0.02
167340

447341: done 236 games, mean reward 0.190, speed 45.87 f/s, eps 0.02
449472: done 237 games, mean reward 0.410, speed 46.07 f/s, eps 0.02
451853: done 238 games, mean reward 0.680, speed 45.86 f/s, eps 0.02
453837: done 239 games, mean reward 1.030, speed 46.02 f/s, eps 0.02
456013: done 240 games, mean reward 1.260, speed 45.70 f/s, eps 0.02
457914: done 241 games, mean reward 1.540, speed 45.64 f/s, eps 0.02
459833: done 242 games, mean reward 1.840, speed 46.01 f/s, eps 0.02
461626: done 243 games, mean reward 2.190, speed 45.74 f/s, eps 0.02
463699: done 244 games, mean reward 2.480, speed 45.90 f/s, eps 0.02
465706: done 245 games, mean reward 2.750, speed 45.36 f/s, eps 0.02
467430: done 246 games, mean reward 3.070, speed 45.76 f/s, eps 0.02
469316: done 247 games, mean reward 3.380, speed 45.73 f/s, eps 0.02
471260: done 248 games, mean reward 3.660, speed 46.06 f/s, eps 0.02
473189: done 249 games, mean reward 3.900, speed 45.63 f/s, eps 0.02
475072: done 250 games, mean rewar