In [None]:
# Uncomment to install required modules
# !pip install mxnet-cu101 opencv-python numpy tensorboard mxboard matplotlib pandas_bokeh gym[atari]


In [1]:
import os
import gym
import numpy as np
import cv2
from mxnet import nd, gluon, init, autograd
from mxnet.gluon.rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length

from mxnet.gluon import nn, rnn
import mxnet as mx
%matplotlib inline
import matplotlib.pyplot as plt
import pickle
import pandas as pd
import pandas_bokeh
import gc
import os
import multiprocessing
import multiprocessing.connection
from mxboard import SummaryWriter
import datetime
from tqdm.notebook import trange, tqdm
import random

# Configuration
# Hyperparameters and run settings. Everything below is read as a module-level
# global by the model, loss and rollout code further down in the notebook.
output_name      = 'p4o_integrated'            # Prefix of the log directory name (see output_dir)
game             = "SeaquestDeterministic-v4"  # Gym environment id passed to gym.make
stacked_frames   = 4            # Number of stacked frames to use
context          = mx.gpu()     # GPU or CPU based training
opt_lr           = 2.5e-4       # Adam optimizer learning rate
opt_eps          = 1e-5         # Adam optimizer epsilon value
opt_clip         = .5           # Amount to clip gradients by
actor_coeff      = 1.           # Loss coefficient of the actor loss (for scaling the different loss components)
critic_coeff     = .5           # Loss coefficient of the critic loss
entropy_coeff    = .02          # Loss coefficient of the entropy term
pp_coeff         = 1.           # Loss coefficient of the predictive processing loss
hidden_size      = 512          # Number of latent hidden units before the LSTM layer
rnn_hidden_size  = 1024         # Number of hidden units in the LSTM
num_workers      = 16           # Number of parallel environments running
batch_steps      = 125          # Environment steps collected per worker in one rollout
c, w, h          = 1, 84, 84    # Channels per frame and the size frames are resized to in preprocess
gamma            = .99          # Discount factor used in calculate_advantages
lamda            = .95          # GAE factor, multiplied with gamma in the advantage recursion
clip_range       = 0.10         # NOTE(review): presumably the clip range handed to the actor/critic losses -- caller not visible here
schedule_steps   = 10000        # NOTE(review): not used in the visible part of the notebook -- TODO document
cooldown_period  = 200          # NOTE(review): not used in the visible part of the notebook -- TODO document
epochs           = 4            # NOTE(review): presumably optimisation passes per rollout -- training loop not visible here
n_mini_batch     = 5            # Number of mini-batches a rollout is split into (see mini_work_size / mini_batch_size)
pred_steps       = 3            # Number of future steps predicted for the predictive processing loss

# Initialize other globals
env             = gym.make(game)                          # Reference environment, used for the action-space size and initial lives
cur_eps         = np.zeros((num_workers), dtype=np.int32) # Per-worker value reported to the monitor when an episode ends, then reset to 0
batch_size      = num_workers * batch_steps               # Total samples per rollout
mini_work_size  = batch_steps // n_mini_batch             # Steps per worker in one mini-batch
mini_batch_size = batch_size // n_mini_batch              # Samples in one mini-batch
states          = np.zeros((num_workers, 1, stacked_frames*c, w, h), dtype=np.float32)  # Current stacked-frame input per worker
states_new      = np.zeros((num_workers, 210, 160, 3), dtype=np.float32)                # Buffer shaped like one raw 210x160 RGB frame per worker
lives           = np.zeros(num_workers, dtype=np.int32) + env.unwrapped.ale.lives()     # Lives left per worker, starting at the environment's initial count
total_episodes  = 0                                       # Number of finished episodes across all workers
all_grads       = []                                      # NOTE(review): not used in the visible part of the notebook
output_dir      = './logs/'+output_name+datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")  # Timestamped log directory

In [None]:
# Colab tensorboard extension, uncomment if running in colab

#%load_ext tensorboard

In [None]:
# Colab tensorboard, uncomment if running in Colab

#%tensorboard --logdir './logs'


In [2]:
class LSTMCell(rnn.HybridRecurrentCell):
    """Long-Short Term Memory (LSTM) cell modified for predictive processing.

    Differences from a standard LSTM cell, as implemented below:

    * The input is expected to be ``concat(observation_code, action)`` where
      the first ``hidden_size`` columns (the module-level ``hidden_size``, i.e.
      the width of the encoded observation) hold the observation code and the
      remaining columns hold the action vector.
    * Only the action part goes through the input-to-hidden projection.
    * The first ``hidden_size`` units of the previous hidden state are treated
      as a prediction of the observation code; they are replaced by the
      prediction error (prediction minus actual code) before the
      hidden-to-hidden projection.
    """

    def __init__(self, hidden_size,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, prefix=None, params=None, activation='tanh',
                 recurrent_activation='sigmoid'):
        super(LSTMCell, self).__init__(prefix=prefix, params=params)

        self._hidden_size = hidden_size
        self._input_size = input_size
        # The input-to-hidden weights only see the action part of the input,
        # hence the second dimension is the number of actions.
        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, env.action_space.n),
                                          init=i2h_weight_initializer,
                                          allow_deferred_init=True)
        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
                                          init=h2h_weight_initializer,
                                          allow_deferred_init=True)
        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
                                        init=i2h_bias_initializer,
                                        allow_deferred_init=True)
        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)
        # NOTE(review): not read anywhere in the visible code; kept so external readers do not break.
        self.test = 1
        self._activation = activation
        self._recurrent_activation = recurrent_activation


    def state_info(self, batch_size=0):
        """Describe the two recurrent states (hidden h and cell c)."""
        return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
                {'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}]

    def _alias(self):
        return 'lstm'

    def __repr__(self):
        s = '{name}({mapping})'
        shape = self.i2h_weight.shape
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def hybrid_forward(self, F, inputs, states, i2h_weight,
                       h2h_weight, i2h_bias, h2h_bias):
        """Run one time step.

        Parameters
        ----------
        inputs : observation code (first ``hidden_size`` columns) followed by
            the action vector.
        states : list ``[h, c]`` with the previous hidden and cell state.
            The list is NOT modified.

        Returns
        -------
        next_h, [next_h, next_c]
        """
        # pylint: disable=too-many-locals
        prefix = 't%d_'%self._counter
        # Only the action part of the input is projected.
        i2h = F.FullyConnected(data=inputs[:,hidden_size:], weight=i2h_weight, bias=i2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
        # Prediction error: predicted observation code minus the actual one.
        PE = F.elemwise_sub(states[0][:,:hidden_size], inputs[:,:hidden_size], name=prefix+'min0')
        # Build the modified hidden state in a local variable. Writing it back
        # into `states[0]` would change the caller's list in place, corrupting
        # any state list that is reused after this call (including the
        # per-step states collected by `unroll` when `valid_length` is given).
        prev_h = F.concat(PE, states[0][:,hidden_size:], dim = 1)
        h2h = F.FullyConnected(data=prev_h, weight=h2h_weight, bias=h2h_bias,
                               num_hidden=self._hidden_size*4, name=prefix+'h2h')


        gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')


        in_gate = self._get_activation(
            F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
        forget_gate = self._get_activation(
            F, slice_gates[1], self._recurrent_activation, name=prefix+'f')
        in_transform = self._get_activation(
            F, slice_gates[2], self._activation, name=prefix+'c')
        out_gate = self._get_activation(
            F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
        next_c = F.elemwise_add(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
                                F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
                                name=prefix+'state')
        next_h = F.elemwise_mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
                                name=prefix+'out')

        return next_h, [next_h, next_c]

    def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
               valid_length=None):
        """Unrolls an RNN cell across time steps.

        Parameters
        ----------
        length : int
            Number of steps to unroll.
        inputs : Symbol, list of Symbol, or None
            If `inputs` is a single Symbol (usually the output
            of Embedding symbol), it should have shape
            (batch_size, length, ...) if `layout` is 'NTC',
            or (length, batch_size, ...) if `layout` is 'TNC'.

            If `inputs` is a list of symbols (usually output of
            previous unroll), they should all have shape
            (batch_size, ...).
        begin_state : nested list of Symbol, optional
            Input states created by `begin_state()`
            or output state of another cell.
            Created from `begin_state()` if `None`.
        layout : str, optional
            `layout` of input symbol. Only used if inputs
            is a single Symbol.
        merge_outputs : bool, optional
            If `False`, returns outputs as a list of Symbols.
            If `True`, concatenates output across time steps
            and returns a single symbol with shape
            (batch_size, length, ...) if layout is 'NTC',
            or (length, batch_size, ...) if layout is 'TNC'.
            If `None`, output whatever is faster.
        valid_length : Symbol, NDArray or None
            `valid_length` specifies the length of the sequences in the batch without padding.
            This option is especially useful for building sequence-to-sequence models where
            the input and output sequences would potentially be padded.
            If `valid_length` is None, all sequences are assumed to have the same length.
            If `valid_length` is a Symbol or NDArray, it should have shape (batch_size,).
            The ith element will be the length of the ith sequence in the batch.
            The last valid state will be return and the padded outputs will be masked with 0.
            Note that `valid_length` must be smaller or equal to `length`.

        Returns
        -------
        outputs : list of Symbol or Symbol
            Symbol (if `merge_outputs` is True) or list of Symbols
            (if `merge_outputs` is False) corresponding to the output from
            the RNN from this unrolling.

        states : list of Symbol
            The new state of this RNN after this unrolling.
            The type of this symbol is same as the output of `begin_state()`.

        step_states : list
            ``[hidden_states, cell_states]``; each is a list of length
            `length` holding the state that ENTERED each time step (the begin
            state first, the final state excluded).
        """
        # pylint: disable=too-many-locals
        self.reset()

        inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
        begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

        states = begin_state
        outputs = []
        all_states = []
        # Start with the begin state so index t holds the state entering step t.
        all_states_h = [states[0]]
        all_states_c = [states[1]]

        for i in range(length):
            output, states = self(inputs[i], states)
            outputs.append(output)
            all_states_h.append(states[0])
            all_states_c.append(states[1])
            if valid_length is not None:
                all_states.append(states)
        if valid_length is not None:
            states = [F.SequenceLast(F.stack(*ele_list, axis=0),
                                     sequence_length=valid_length,
                                     use_sequence_length=True,
                                     axis=0)
                      for ele_list in zip(*all_states)]
            outputs = _mask_sequence_variable_length(F, outputs, length, valid_length, axis, True)
        outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs)

        # Drop the final state: it is already returned as `states`.
        return outputs, states, [all_states_h[:-1], all_states_c[:-1]]



In [3]:
# Define ResNet-based Encoder model

class Encoder(gluon.Block):
    """Convolutional encoder made of four stages.

    Each stage applies ReLU -> 3x3 convolution -> 3x3 max-pool (stride 2) ->
    two residual blocks, with 24, 32, 64 and 128 channels respectively.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        with self.name_scope():
            self.channel_list = [24,32,64,128]
            # Layers live in a plain dict, so they are registered as children
            # by hand below (attribute assignment would not pick them up).
            self.layers = {}
            for stage, width in enumerate(self.channel_list):
                tag = str(stage)
                self.layers['conv'+tag] = nn.Conv2D(width, 3, strides=1, padding=1)
                self.layers['max'+tag] = nn.MaxPool2D(pool_size=3, strides=2,padding=1)
                self.layers['res'+tag+'_0'] = ResidualBlock(width)
                self.layers['res'+tag+'_1'] = ResidualBlock(width)

            # Register in insertion order so child numbering stays stable.
            for child in self.layers.values():
                self.register_child(child)

    def forward(self, x):
        """Run the input through every stage in order and return the feature map."""
        out = x
        for stage in range(len(self.channel_list)):
            tag = str(stage)
            out = nd.relu(out)
            for key in ('conv'+tag, 'max'+tag, 'res'+tag+'_0', 'res'+tag+'_1'):
                out = self.layers[key](out)
        return out


class ResidualBlock(gluon.Block):
    """Two ReLU -> 3x3 convolution layers with a skip connection around them.

    `in_channels` is used as the number of channels of both convolutions, so
    the output has the same channel count and can be added to the input.
    """

    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        conv_args = dict(kernel_size=3, strides=1, padding=1)
        with self.name_scope():
            self.conv0 = nn.Conv2D(in_channels, **conv_args)
            self.conv1 = nn.Conv2D(in_channels, **conv_args)

        self.in_channels = in_channels

    def forward(self, x):
        """Return conv1(relu(conv0(relu(x)))) + x."""
        residual = self.conv0(nd.relu(x))
        residual = self.conv1(nd.relu(residual))
        return residual + x


# Main model definition

class Model(gluon.Block):
    """Actor-critic network: encoder -> dense -> predictive LSTM -> policy/value heads.

    The LSTM input is the encoded observation concatenated with an action
    vector. During normal forward passes the action part is all zeros; during
    input prediction (`rnn_input_pred`) it is the one-hot action taken.
    """

    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        with self.name_scope():
            self.enc = Encoder()
            self.dense = nn.Dense(hidden_size)
            self.lstm = LSTMCell(rnn_hidden_size)
            self.flat = nn.Flatten()
            self.action = nn.Dense(env.action_space.n)
            self.value = nn.Dense(1)

    def _append_blank_action(self, rnn_input):
        # Concatenate an all-zero action vector to every row of `rnn_input`.
        # The zeros are created on the same device as `rnn_input`.
        blank = nd.zeros((len(rnn_input), env.action_space.n), ctx=rnn_input.context, dtype=np.float32)
        return nd.concat(rnn_input, blank, dim = 1)

    def _policy_and_value(self, rnn_output):
        # Apply both heads; the softmax is computed in float64 and cast back.
        probs = self.action(rnn_output)
        values = self.value(rnn_output)
        return nd.softmax(probs.astype(np.float64)).astype(np.float32), values

    def forward(self, x, hidden):
        """Full pass from frames to action probabilities and state values.

        Parameters
        ----------
        x : array of shape (batch, timesteps, stacked_frames, H, W).
        hidden : LSTM begin state ``[h, c]``.

        Returns
        -------
        probs, values, hidden, rnn_input, rnn_output, all_hidden
            `rnn_input` is the encoded observation fed to the LSTM and
            `all_hidden` the per-step states returned by `LSTMCell.unroll`.
        """
        batch_size, timesteps = x.shape[0], x.shape[1]
        rnn_input = self.encode(x)
        input_and_action = self._append_blank_action(rnn_input)
        input_and_action = input_and_action.reshape(batch_size, timesteps, -1)
        x, hidden, all_hidden = self.lstm.unroll(length=timesteps, inputs=input_and_action, layout='NTC', begin_state=hidden, merge_outputs=True)

        rnn_output = x.reshape(batch_size*timesteps, -1)
        probs, values = self._policy_and_value(rnn_output)

        return probs, values, hidden, rnn_input, rnn_output, all_hidden

    def rnn_only(self, rnn_input, hidden):
        """Run a single LSTM step and the heads on already-encoded inputs.

        The batch size is taken from `rnn_input` itself rather than from the
        global number of workers, so any batch size is accepted.
        """
        n_rows = len(rnn_input)
        input_and_action = self._append_blank_action(rnn_input)
        input_and_action = input_and_action.reshape(n_rows, 1, -1)
        x, hidden, _ = self.lstm.unroll(length=1, inputs=input_and_action, layout='NTC', begin_state=hidden, merge_outputs=True)
        rnn_output = x.reshape(n_rows, -1)
        probs, values = self._policy_and_value(rnn_output)
        return probs, values, hidden, rnn_output

    def rnn_input_pred(self, rnn_input, hidden, action, steps):
        """Predict the next `steps` encoded observations from the taken actions.

        `rnn_input` is reshaped to (num_workers, mini_work_size, hidden_size)
        and the last `pred_steps` entries of every worker are dropped, since
        they have no complete set of future targets. Each prediction is the
        first `hidden_size` units of the LSTM output and is fed back as the
        input of the following step.

        `steps` must not exceed the global `pred_steps`, because the action
        slice below is built relative to `pred_steps`.

        Returns a list with one prediction array per step.
        """
        # Reshape to workers format and only take the relevant inputs for each
        rnn_input = rnn_input.reshape(num_workers, mini_work_size, hidden_size)[:,:-pred_steps]
        # Concatenate all worker data
        rnn_input = rnn_input.reshape(rnn_input.shape[0] * rnn_input.shape[1], hidden_size)

        # Concatenate hiddens
        hidden = [nd.reshape(hidden[0],(hidden[0].shape[0]*hidden[0].shape[1], rnn_hidden_size)), nd.reshape(hidden[1],(hidden[1].shape[0]*hidden[1].shape[1], rnn_hidden_size))]
        preds = []
        rnn_input_pred = rnn_input # Set to original input for first step
        for step in range(steps):
            n_rows = len(rnn_input_pred)
            cur_act = action[:,step:-(pred_steps-step)] # Get the relevant action for this step
            cur_act = cur_act.reshape(cur_act.shape[0] * cur_act.shape[1]) # Reshape to concatenate
            cur_act = nd.one_hot(cur_act, env.action_space.n, dtype=np.float32)
            input_and_action = nd.concat(rnn_input_pred, cur_act.as_in_context(rnn_input_pred.context).reshape(n_rows,-1), dim = 1)
            input_and_action = input_and_action.reshape(n_rows, 1, -1)
            x, hidden, _ = self.lstm.unroll(length=1, inputs=input_and_action, layout='NTC', begin_state=hidden, merge_outputs=True)
            rnn_output = x.reshape(n_rows, -1)
            rnn_input_pred = rnn_output[:,:hidden_size]
            preds.append(rnn_input_pred)
        return preds

    def encode(self, x):
        """Encode frames of shape (batch, timesteps, stacked_frames, H, W).

        Returns the tanh-activated dense code with one row per
        (batch, timestep) pair.
        """
        batch_size, timesteps, stacked_frames, H, W = x.shape
        x = x.reshape(batch_size*timesteps, stacked_frames, H, W)
        encoded = nd.relu(self.enc(x))
        encoded = nd.tanh(self.dense(encoded))
        return encoded



In [4]:
# Multiprocessing setup

class Game(object):
    """Thin wrapper around a gym environment used inside a worker process."""

    def __init__(self, game):
        # `game` is the gym environment id.
        environment = gym.make(game)
        self.env = environment

    def reset(self):
        """Reset the environment and return whatever `env.reset()` returns."""
        first_observation = self.env.reset()
        return first_observation

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

    def step(self, action):
        """Apply `action` and return the result of `env.step(action)` unchanged."""
        #self.env.render()
        outcome = self.env.step(action)
        return outcome


def runner_process(remote, game):
    """Worker loop: serve environment commands received over a pipe.

    Parameters
    ----------
    remote : multiprocessing connection used to talk to the owning process.
    game : gym environment id used to create this worker's environment.

    Supported messages are ``(cmd, data)`` tuples:

    * ``("step", action)``  -> replies with the result of ``Game.step(action)``
    * ``("reset", None)``   -> replies with the result of ``Game.reset()``
    * ``("close", None)``   -> stops the loop

    The loop also stops when the other end of the pipe is closed. In every
    case, including errors, the environment and the connection are closed
    before the function returns.

    Raises
    ------
    NotImplementedError
        If an unknown command is received.
    """
    worker_game = Game(game)
    try:
        while True:
            try:
                cmd, data = remote.recv()
            except EOFError:
                # The owning process closed its end: nothing left to serve.
                break
            if cmd == "step":
                remote.send(worker_game.step(data))
            elif cmd == "reset":
                remote.send(worker_game.reset())
            elif cmd == "close":
                break
            else:
                raise NotImplementedError("Unknown runner command: %r" % (cmd,))
    finally:
        # Release the emulator before closing the pipe; previously the
        # environment was never closed.
        worker_game.close()
        remote.close()

class Runner:
    """Handle on one environment worker process.

    Attributes
    ----------
    child : pipe end kept in this process; commands are sent and replies
        received through it.
    process : the started worker process running `runner_process`.
    """

    def __init__(self, game):
        local_end, worker_end = multiprocessing.Pipe()
        self.child = local_end
        self.process = multiprocessing.Process(
            target=runner_process,
            args=(worker_end, game),
        )
        self.process.start()

In [5]:
# MxBoard / Tensorboard Monitoring Setup

class Monitoring:
    """Collects training statistics and writes them to MXBoard / TensorBoard.

    Statistics are kept in plain Python lists with one entry per episode, per
    rollout or per update, so they can be inspected after training. The
    ``*_buffer`` lists accumulate per-mini-batch values and are averaged and
    cleared by `process_update`.
    """

    def __init__(self, output_dir):
        # Per-update bookkeeping
        self.update = []
        # Per-episode statistics
        self.total_episodes = []
        self.rewards = []
        self.mean100 = []
        self.max_reward = []
        # Episode statistics sampled at every board update
        self.update_max_reward = []
        self.update_rewards = []
        self.update_mean100 = []
        # Per-update mean losses
        self.critic_loss = []
        self.actor_loss = []
        self.pp_loss = []
        self.entropy_loss = []
        # Per-rollout action probability statistics
        self.min_action_prob = []
        self.max_action_prob = []
        self.avg_action_prob = []
        self.std_action_prob = []
        # Per-rollout state value statistics
        self.avg_value = []
        self.min_value = []
        self.max_value = []
        self.std_value = []
        self.ratio = []
        # Per-mini-batch buffers, averaged in process_update
        self.entropy_loss_buffer = []
        self.pp_loss_buffer = []
        self.actor_loss_buffer = []
        self.critic_loss_buffer = []
        self.ratio_buffer = []
        self.sw = SummaryWriter(logdir=output_dir, flush_secs=5)


    def process_episode(self, rewards, ep):
        """Record a finished episode.

        `rewards` is the episode's total reward and `ep` its episode index.
        Also appends the mean reward over the last (at most) 100 episodes and
        the best reward seen so far.
        """
        self.rewards.append(rewards)
        self.total_episodes.append(ep)
        # Slicing already handles the case of fewer than 100 episodes.
        self.mean100.append(np.mean(self.rewards[-100:]))
        self.max_reward.append(np.max(self.rewards))

    def process_rollout(self, data):
        """Record summary statistics of one rollout.

        `data` must provide the arrays 'action_dists', 'log_probs' and
        'values'. All statistics are taken over every element, so the arrays
        may have any shape.
        """
        action_dists = data['action_dists']
        values = data['values']
        self.min_action_prob.append(np.min(action_dists))
        self.max_action_prob.append(np.max(action_dists))
        # Mean probability of the actions that were actually taken.
        self.avg_action_prob.append(np.mean(np.exp(data['log_probs'])))
        self.std_action_prob.append(np.std(action_dists))
        self.avg_value.append(np.mean(values))
        self.min_value.append(np.min(values))
        self.max_value.append(np.max(values))
        self.std_value.append(np.std(values))

    def process_minibatch_loss(self, entropy_loss, pp_loss, actor_loss, critic_loss, ratio):
        """Buffer the loss components of one mini-batch until the next update."""
        self.entropy_loss_buffer.append(entropy_loss)
        self.actor_loss_buffer.append(actor_loss)
        self.critic_loss_buffer.append(critic_loss)
        self.ratio_buffer.append(ratio)
        self.pp_loss_buffer.append(pp_loss)

    def process_update(self, update):
        """Average the buffered mini-batch values for update number `update`
        and clear the buffers."""
        self.update.append(update)
        self.entropy_loss.append(np.mean(self.entropy_loss_buffer))
        self.critic_loss.append(np.mean(self.critic_loss_buffer))
        self.actor_loss.append(np.mean(self.actor_loss_buffer))
        self.pp_loss.append(np.mean(self.pp_loss_buffer))
        self.ratio.append(np.mean(self.ratio_buffer))
        self.entropy_loss_buffer = []
        self.actor_loss_buffer = []
        self.critic_loss_buffer = []
        self.ratio_buffer = []
        self.pp_loss_buffer = []

    def update_mxboard(self):
        """Write the most recent value of every statistic to the summary writer.

        The latest update number is used as the global step. Reward scalars
        are only written once at least one episode has finished.
        """
        step = self.update[-1]
        latest = [
            ('Losses/Critic_Loss',                         self.critic_loss),
            ('Losses/Actor_Loss',                          self.actor_loss),
            ('Losses/Entropy_Loss',                        self.entropy_loss),
            ('Losses/PP_loss',                             self.pp_loss),
            ('Probabilities/Average_Action_Probability',   self.avg_action_prob),
            ('Probabilities/Min._Action_Probability',      self.min_action_prob),
            ('Probabilities/Max._Action_Probability',      self.max_action_prob),
            ('Probabilities/Std._Dev._Action_Probability', self.std_action_prob),
            ('Values/Average_State_Value',                 self.avg_value),
            ('Values/Min._State_Value',                    self.min_value),
            ('Values/Max._State_Value',                    self.max_value),
            ('Values/Std._Dev._State_Value',               self.std_value),
            ('Ratio/Maximum_Ratio',                        self.ratio),
        ]
        if self.mean100:
            self.update_max_reward.append(self.max_reward[-1])
            self.update_rewards.append(self.rewards[-1])
            self.update_mean100.append(self.mean100[-1])
            latest += [
                ('Rewards/Rewards',              self.rewards),
                ('Rewards/Avg._Reward_Last_100', self.mean100),
                ('Rewards/Max._Reward',          self.max_reward),
            ]
        for tag, series in latest:
            self.sw.add_scalar(tag=tag, value=series[-1], global_step=step)


In [6]:
# Utilities


def preprocess(state):
    """Convert a raw RGB frame to a grayscale image scaled to [0, 1].

    The frame is resized to `h` rows by `w` columns, which is the
    (..., h, w) layout the rest of the notebook reshapes frames to.

    Parameters
    ----------
    state : RGB frame as returned by the environment.

    Returns
    -------
    2-D float array with values between 0 and 1.
    """
    gray = cv2.cvtColor(state, cv2.COLOR_RGB2GRAY)
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    # Passing (h, w) only worked because both are 84 in this configuration.
    resized = cv2.resize(gray, (w, h), interpolation=cv2.INTER_AREA)
    return resized / 255

def summarize(net, context):
    """Print a layer summary of `net` using one real frame from the environment.

    The reference environment is reset to obtain a frame, which is
    preprocessed and repeated `stacked_frames` times to build a
    (1, 1, stacked_frames, h, w) input on `context`.

    If the network exposes an `lstm` cell (as `Model` does, whose forward pass
    takes `(x, hidden)`), a fresh begin state for a batch of one is passed as
    the second input; otherwise the network is summarised with the frame only.
    """
    state = env.reset()
    state = preprocess(state)
    state = np.array([state,]*stacked_frames).reshape((1, 1,stacked_frames, h, w))
    inputs = [nd.array(state, ctx=context)]
    cell = getattr(net, 'lstm', None)
    if cell is not None:
        # Without a hidden state, a forward pass of the recurrent model fails
        # with a missing-argument error.
        inputs.append(cell.begin_state(batch_size=1, ctx=context))
    net.summary(*inputs)

def reset_state(runner):
    """Reset the worker's environment and build its initial frame stack.

    Sends a "reset" command through the runner's pipe, preprocesses the frame
    that comes back and repeats it `stacked_frames` times.
    """
    pipe = runner.child
    pipe.send(("reset", None))
    first_frame = preprocess(pipe.recv())
    return np.array([first_frame for _ in range(stacked_frames)])

def process_states(states_new, states, dones, infos, runners, game, lives):
    """Push the newest frames onto the frame stacks and handle episode ends.

    For every worker the oldest stacked frame is dropped and the preprocessed
    new frame appended. A worker is marked done (in `dones`, in place) when
    its episode ended or it lost a life. When the episode really ended, the
    episode is reported to the global monitor, the worker's counters are
    reset and its frame stack is replaced by a freshly reset environment
    state. `lives` is updated in place.

    `game` is accepted but not used.

    Returns the new `states` array and the updated `dones`.
    """
    global cur_eps
    global total_episodes
    global monitor
    newest = np.array([preprocess(frame) for frame in np.stack(states_new)])
    newest = newest.reshape(num_workers,1,1,h,w)
    # Slide the window along the frame axis: drop the oldest, append the newest.
    states = np.concatenate((states[:, :, 1:], newest), axis=2)
    for idx, (episode_over, runner, info) in enumerate(zip(dones, runners, infos)):
        life_lost = info['ale.lives'] < lives[idx]
        if not (episode_over or life_lost):
            continue
        lives[idx] = info['ale.lives']
        if episode_over:
            monitor.process_episode(cur_eps[idx], total_episodes)
            total_episodes += 1
            cur_eps[idx] = 0
            lives[idx] = env.unwrapped.ale.lives()
            states[idx] = reset_state(runner)
        # A lost life is treated like an episode end for the advantage mask.
        dones[idx] = True
    return states, dones




def standardize(adv):
    """Shift `adv` to zero mean and scale it by its population standard deviation.

    A small constant is added to the divisor to avoid division by zero.
    """
    centered = adv - adv.mean()
    spread = nd.sqrt(nd.power(centered, 2).sum() / len(adv))
    return centered / (spread + 1e-10)

def calculate_advantages(done, rewards, values, states, hidden):
    """Compute generalized advantage estimates (GAE) for a rollout.

    Parameters
    ----------
    done, rewards, values : arrays indexed as [worker, step].
    states : the stacked frames that resulted from the last action, used to
        bootstrap the value beyond the rollout.
    hidden : recurrent state to evaluate `states` with.

    Returns
    -------
    float32 array of advantages with the same [worker, step] layout.
    """
    n_workers, n_steps = values.shape[0], values.shape[1]
    advantages = np.zeros((n_workers, n_steps), dtype=np.float32)

    # Bootstrap: value of the state reached after the final step of the rollout
    last_states = nd.array(states, ctx=context, dtype=np.float32).reshape(num_workers, 1, stacked_frames*c, h, w)
    _, bootstrap_value, *_ = net_new(last_states, hidden)
    next_value = bootstrap_value.reshape(-1).asnumpy()
    next_advantage = 0

    # Walk backwards through the rollout. Steps marked done (episode end or
    # lost life) cut the recursion, so the estimate restarts from there.
    for t in range(n_steps - 1, -1, -1):
        not_done = 1.0 - done[:, t]
        td_error = rewards[:, t] + gamma * next_value * not_done - values[:, t]
        advantages[:, t] = td_error + gamma * lamda * next_advantage * not_done

        next_value = values[:, t]
        next_advantage = advantages[:, t]

    return advantages

In [7]:
# Loss functions

def calc_actor_loss(probs, mini_batch, clip_range, advantages):
    """PPO clipped surrogate loss for the policy.

    Returns the loss (negated mean of the element-wise minimum of the plain
    and the clipped surrogate) together with the probability ratio between
    the current policy and the one that collected the data.
    """
    # Probability ratio of the taken actions: current policy vs. rollout policy
    current_log_probs = nd.log(nd.pick(probs,mini_batch['actions'])+1e-10)
    rollout_log_probs = mini_batch['log_probs'].detach()
    ratio = nd.exp(current_log_probs - rollout_log_probs)
    clipped_ratio = nd.clip(ratio,1.0 - clip_range,1.0 + clip_range)

    # Stack both surrogates as rows and keep the smaller one per sample
    fixed_advantages = advantages.detach()
    plain = (ratio * fixed_advantages).reshape(1,-1)
    clipped = (clipped_ratio * fixed_advantages).reshape(1,-1)
    surrogate = nd.min(nd.concat(plain, clipped, dim = 0), axis=0)
    return -surrogate.mean(), ratio

def calc_critic_loss(mini_batch, value, clip_range):
    """Mean absolute error between the clipped value prediction and the return.

    The return is the rollout value plus the advantage. The new prediction is
    only allowed to move `clip_range` away from the rollout value; only this
    clipped prediction enters the loss.
    """
    rollout_values = mini_batch['values'].detach()
    batch_return = (mini_batch['values'] + mini_batch['advantages']).detach()
    value_shift = nd.clip(value.reshape(-1) - rollout_values, -clip_range, clip_range)
    clipped_value = rollout_values + value_shift
    return nd.abs(clipped_value - batch_return).mean()

def calc_entropy_loss(probs):
    """Mean entropy of the action distributions (one distribution per row).

    A small constant is added to the probabilities before taking the log.
    """
    safe_probs = probs+1e-10
    per_row_entropy = -(safe_probs * safe_probs.log()).sum(axis=1)
    return per_row_entropy.mean()


def calc_pp_loss(predictions, rnn_input):
    """Predictive processing loss: summed mean squared error over all steps.

    `predictions[step]` is compared with the encoded inputs shifted
    `step + 1` positions into the future along axis 1 of `rnn_input`. The
    window is trimmed at the end so every step has the same number of targets.
    """
    total_loss = nd.array([0.], ctx=context)
    for step in range(pred_steps):
        start = step + 1
        # The final step runs to the end of the sequence; `-0` would produce
        # an empty slice, so an open end (None) is used in that case.
        stop = -(pred_steps - start) or None
        target_rnn_input = rnn_input[:, start:stop]
        flat_target = target_rnn_input.reshape(target_rnn_input.shape[0]*target_rnn_input.shape[1],-1)
        step_loss = nd.mean((predictions[step] - flat_target)**2)
        total_loss = total_loss + step_loss
    return total_loss


In [8]:
# Batch rollout

def rollout():
    """Collect batch_steps transitions from every worker using `net`.

    Side effects: updates the globals `states`, `cur_eps`, `hidden` and
    `hiddens`, passes `lives` to process_states, and hands the raw numpy
    batch to monitor.process_rollout.

    Returns:
        Tuple (data, hiddens): `data` is a dict of float32 NDArrays on
        `context`, each shaped (num_workers, batch_steps, ...); `hiddens`
        holds the recurrent state recorded at the start of every
        mini_work_size-step segment.
    """
    global states
    global cur_eps
    global lives
    global hidden
    global hiddens

    # Initialize batch
    data = {'rewards':         np.zeros((num_workers, batch_steps)),\
            'values':          np.zeros((num_workers, batch_steps)),\
            'log_probs':       np.zeros((num_workers, batch_steps)),\
            'action_dists':    np.zeros((num_workers, batch_steps, env.action_space.n)),\
            'done':            np.zeros((num_workers, batch_steps)),\
            'states':          np.zeros((num_workers, batch_steps, stacked_frames * c, h, w)),\
            'actions':         np.zeros((num_workers, batch_steps), dtype=np.int32),\
            'rnn_inputs':      np.zeros((num_workers, batch_steps, hidden_size))}
    # One (index 0 / index 1) pair of recurrent state tensors per mini-batch segment
    hiddens =  nd.zeros((num_workers, n_mini_batch, 2, rnn_hidden_size), ctx=context)

    for step in range(batch_steps):

        # At the first step of each segment, record the recurrent state the segment starts from
        if step % mini_work_size == 0:
            hiddens[:, int(step / mini_work_size), 0] = hidden[0].asnumpy()
            hiddens[:, int(step / mini_work_size), 1] = hidden[1].asnumpy()

        # Forward pass
        probs, v, hidden, rnn_input, rnn_output, *_ = net(nd.array(states, ctx=context, dtype=np.float32),hidden)

        # Sample action
        act = mx.nd.sample_multinomial(probs)
        act_probs = nd.pick(probs, act)
        log_probs = nd.log(act_probs+1e-10)

        # Store the new values in the batch data
        data['rnn_inputs'][:, step]   = rnn_input.asnumpy().reshape(num_workers, -1) # rnn_input of this step, flattened per worker (stored at index `step`, not step-1)
        data['states'][:, step]       = states.reshape(num_workers, stacked_frames * c, h, w)
        data['action_dists'][:, step] = probs.asnumpy()
        data['values'][:, step]       = v.reshape(-1).asnumpy()
        data['actions'][:, step]      = act.reshape(-1).asnumpy()
        data['log_probs'][:, step]    = log_probs.reshape(-1).asnumpy()

        # Execute the chosen actions in the workers and retrieve + process the next states
        for idx, runner in enumerate(runners):
            runner.child.send(("step", data['actions'][idx,step]))
        # Each worker replies with a 4-tuple; transposing groups the replies per field
        states_new, data['rewards'][:, step], data['done'][:, step], info = np.transpose(np.array([runner.child.recv() for runner in runners], dtype='object'))
        cur_eps = cur_eps + data['rewards'][:, step] # Track cumulative reward of the running episodes
        # process_states returns the next `states` and a (possibly adjusted) done flag per worker
        states, data['done'][:, step] = process_states(states_new, states, data['done'][:, step], info, runners, game, lives)

    # Process rollout for monitoring and convert to mx ndarray for backwards pass
    monitor.process_rollout(data)
    for key, val in data.items():
        data[key] = nd.array(val, ctx=context, dtype=np.float32)

    return data, hiddens




In [9]:
# Initialize model and prepare for main loop

# `net` is the network called inside rollout(); `optimizer` is created for it,
# but the training loop shown in this notebook section only steps `optimizer_new`
net = Model()
net.initialize(ctx=context)
optimizer = gluon.Trainer(net.collect_params(), 'Adam', {'learning_rate': opt_lr, 'epsilon' : opt_eps})
# `net_new` is the network that receives the gradient updates in the main loop
net_new = Model()
net_new.initialize(ctx=context)

optimizer_new = gluon.Trainer(net_new.collect_params(), 'Adam', {'learning_rate': opt_lr, 'epsilon' : opt_eps})
monitor = Monitoring(output_dir)
# Initial recurrent state, drawn uniformly at random, one row per worker
hidden = mx.nd.random.uniform(shape=(num_workers, rnn_hidden_size), ctx=context, dtype=np.float32)
hidden = [hidden, hidden] # Gluon LSTM expects a list of recurrent state tensors (h0, c0); both entries reference the same NDArray here

# Print a layer-by-layer summary of `net` for the current states
net.summary(nd.array(states, ctx=context, dtype=np.float32), hidden)
# Run net_new once on the same inputs; the outputs are not used further in this cell.
# NOTE(review): presumably needed so net_new's parameters are fully initialized before the copy below -- confirm
probs, value, _, rnn_input, rnn_output, *_ = net_new(nd.array(states, ctx=context, dtype=np.float32), hidden)


# Copy net_new's parameters into net so both networks start out identical
params1 = net_new.collect_params()
params2 = net.collect_params()
for p1, p2 in zip(params1.values(), params2.values()):
    p2.set_data(p1.data())


--------------------------------------------------------------------------------
        Layer (type)                                Output Shape         Param #
               Input  (16, 1, 4, 84, 84), [(16, 1024), (16, 1024)]               0
            Conv2D-1                            (16, 24, 84, 84)             888
         MaxPool2D-2                            (16, 24, 42, 42)               0
            Conv2D-3                            (16, 24, 42, 42)            5208
            Conv2D-4                            (16, 24, 42, 42)            5208
     ResidualBlock-5                            (16, 24, 42, 42)               0
            Conv2D-6                            (16, 24, 42, 42)            5208
            Conv2D-7                            (16, 24, 42, 42)            5208
     ResidualBlock-8                            (16, 24, 42, 42)               0
            Conv2D-9                            (16, 32, 42, 42)            6944
        MaxPool2D-10      

[10:08:37] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)


In [10]:
# Initialize worker processes

# Create one Runner per parallel environment
runners = []
for worker_idx in range(num_workers):
    runners.append(Runner(game))

# Reset each environment and store its initial observation in `states`
initial_state_shape = (1, 1, stacked_frames * c, h, w)
for worker_idx, worker in enumerate(runners):
    states[worker_idx] = reset_state(worker).reshape(initial_state_shape)


In [None]:
# Main loop

with tqdm(range(schedule_steps+cooldown_period), desc='Training..') as updates:

    for update in updates:

        # Linearly decay the learning rate over the first schedule_steps updates;
        # during the cooldown period opt_lr keeps its last scheduled value
        if update<schedule_steps:
            progress = update/schedule_steps
            opt_lr = 2.5e-4 * (1 - progress)

        # Fetch new batch of data
        batch, hiddens = rollout()
        # NOTE(review): both keys are bound to the same array object, so the
        # in-place writes to batch['new_values'] below may also change
        # batch['values'] -- confirm whether a copy was intended
        batch['new_values'] = batch['values']
        batch['advantages'] = nd.zeros((num_workers,batch_steps), ctx=context, dtype=np.float32)

        for _ in range(epochs):

            for mbatch in range(0, n_mini_batch):

                # Copy the trained network's (net_new) parameters into the acting network (net)
                params1 = net_new.collect_params()
                params2 = net.collect_params()
                for p1, p2 in zip(params1.values(), params2.values()):
                    p2.set_data(p1.data())

                # Time-step window [first, final) covered by this mini-batch
                first = mbatch * mini_work_size
                final = first + mini_work_size

                # Create the mini-batch by slicing every batch entry to that window
                mini_batch = {}
                for key, val in batch.items():
                    mini_batch[key] = val[:,first:final]

                # Retrieve relevant hidden states (only the first, to minimize usage of stale hidden states)
                mb_initial_hidden = [hiddens[:, mbatch, 0], hiddens[:, mbatch, 1]] # adjust to gluon LSTM expected data format

                with autograd.record():

                    # Get updated outputs with latest model
                    probs, value, _, rnn_input, rnn_output, all_hidden = net_new(mini_batch['states'], mb_initial_hidden)
                    batch['new_values'][:,first:final] = value.reshape(num_workers,-1).detach()

                    # Update predictions for predictive processing:
                    # stack the per-step recurrent states and drop the last pred_steps of them
                    all_hidden[0] = nd.stack(*all_hidden[0]).swapaxes(0,1)
                    all_hidden[1] = nd.stack(*all_hidden[1]).swapaxes(0,1)
                    pred_hidden = [all_hidden[0][:,:-pred_steps],all_hidden[1][:,:-pred_steps]]
                    predictions = net_new.rnn_input_pred(rnn_input, pred_hidden, mini_batch['actions'], pred_steps)
                    pp_loss = calc_pp_loss(predictions, rnn_input.reshape(num_workers,mini_work_size,-1))

                    # Recompute the advantages from this mini-batch's first step to the
                    # end of the batch, using the refreshed value estimates
                    batch['advantages'][:,first:] = nd.array(calculate_advantages(batch['done'][:,first:].asnumpy(), batch['rewards'][:,first:].asnumpy(), batch['new_values'][:,first:].asnumpy(), states, hidden), ctx=context)
                    mini_batch['advantages'] = batch['advantages'][:,first:final]

                    # Concatenate all worker data for further loss calculation
                    for key, val in mini_batch.items():
                        mini_batch[key] = val.reshape(val.shape[0] * val.shape[1], *val.shape[2:])

                    standardized_adv = standardize(mini_batch['advantages'])

                    # Calculate losses
                    actor_loss, ratio  = calc_actor_loss(probs, mini_batch, clip_range, standardized_adv)
                    critic_loss  = calc_critic_loss(mini_batch, value, clip_range)
                    entropy_loss = calc_entropy_loss(probs)

                    # Total loss (the entropy term is subtracted, so higher entropy lowers the loss)
                    loss = actor_coeff * actor_loss \
                         + critic_coeff * critic_loss  \
                         + pp_coeff * pp_loss \
                         - entropy_coeff * entropy_loss

                    optimizer_new.set_learning_rate(opt_lr)

                # Backward pass with global gradient-norm clipping
                loss.backward()
                grads = [i.grad(context) for i in net_new.collect_params().values() if i._grad is not None]
                gluon.utils.clip_global_norm(grads, opt_clip)
                optimizer_new.step(1)

                # Update Tensorboard and logging
                monitor.process_minibatch_loss(entropy_coeff * entropy_loss.asnumpy()[0],\
                                               pp_coeff * pp_loss.asnumpy()[0],\
                                               actor_coeff * actor_loss.asnumpy()[0],\
                                               critic_coeff * critic_loss.asnumpy()[0],\
                                               np.max(ratio.asnumpy()))

        monitor.process_update(update)
        monitor.update_mxboard()
        if monitor.mean100:
            updates.set_description('Training.. (Avg. last 100: %.2f)' % monitor.mean100[-1])

        # Periodic checkpoint of the monitoring data and the parameters of `net`
        if update % 1000 == 0:
            save_data = monitor.__dict__.copy()
            del save_data["sw"]  # the "sw" attribute is left out of the pickle
            # Context manager guarantees the checkpoint file is flushed and closed
            with open(output_dir + "/monitor" + str(update) + ".pkl", "wb") as monitor_file:
                pickle.dump(save_data, monitor_file)
            net.save_parameters(output_dir + "/net"+ str(update) +".params")

Training..:   0%|          | 0/10200 [00:00<?, ?it/s]

In [None]:
# Final checkpoint: monitoring data (without the "sw" attribute) and the parameters of `net`
save_data = monitor.__dict__.copy()
del save_data["sw"]
# Context manager guarantees the pickle file is flushed and closed
with open(output_dir + "/monitor.pkl", "wb") as monitor_file:
    pickle.dump(save_data, monitor_file)
net.save_parameters(output_dir + "/net.params")