In [1]:
import sys
sys.path.append("../src/")
from plugin_write_and_run import *

In [2]:
%%write_and_run ../src/ppo_network.py
import tensorflow as tf
from tensorflow.keras.initializers import VarianceScaling
from tensorflow.keras.layers import (Add, Conv2D, Dense, Flatten, Input,
                                     Lambda, Subtract)
import tensorflow.keras.backend as K

from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import sys
sys.path.append("../src/")
from config import *

In [3]:
%%write_and_run -a ../src/ppo_network.py

def ppo_loss(advantage, old_prediction):
    """
    Creates the clipped-surrogate PPO loss for the policy ("probs") head

    The returned closure follows the Keras ``loss(y_true, y_pred)`` signature
    and captures the two extra tensors it needs.

    Arguments:
        advantage: Tensor of advantage estimates, multiplied against the
            probability ratio (build_ppo_network passes an Input of shape (1,))
        old_prediction: Tensor with the action probabilities of the policy
            that collected the data (build_ppo_network passes an Input of
            shape (n_actions,))

    Returns:
        A function loss(y_true, y_pred) returning a scalar tensor
    """
    def loss(y_true, y_pred):
        # y_true acts as a mask over the action axis
        # (presumably a one-hot encoding of the action taken -- TODO confirm).
        new_prob = y_true * y_pred
        previous_prob = y_true * old_prediction

        # Probability ratio new/old; the small constant avoids division by zero.
        ratio = new_prob / (previous_prob + 1e-10)
        bounded_ratio = K.clip(ratio, min_value=1 - CLIP_RATIO, max_value=1 + CLIP_RATIO)

        # Pessimistic (element-wise minimum) of the unclipped and clipped objectives.
        surrogate = K.minimum(ratio * advantage, bounded_ratio * advantage)

        # -p * log(p) term scaled by ENTROPY_C; p is clipped inside the log only.
        log_prob = K.log(K.clip(new_prob, K.epsilon(), 1-K.epsilon()))
        entropy_term = ENTROPY_C * -(new_prob * log_prob)

        # Negated because Keras minimizes while the objective is to be maximized.
        return -K.mean(surrogate + entropy_term)
    return loss

def build_ppo_network(n_actions, learning_rate=LR, input_shape=INPUT_SHAPE, history_length=HISTORY_LENGHT, hidden=HIDDEN):
    """
    Builds a PPO actor-critic network as a Keras model

    A shared convolutional trunk feeds two dense heads: "probs" (softmax over
    the actions) and "values" (a single scalar). Besides the observation, the
    model takes two auxiliary inputs (advantages and old action probabilities)
    that are not connected to the outputs; they are only consumed by the
    policy loss created with ppo_loss.

    Arguments:
        n_actions: Number of possible actions
        learning_rate: Learning rate passed to the Adam optimizer
        input_shape: Shape of the preprocessed image (only the first two
            entries are used)
        history_length: Number of historical frames to stack together
        hidden: Integer, Number of filters in the final convolutional layer.

    Returns:
        A compiled Keras model with inputs [observation, advantages,
        old predictions] and outputs [probs, values]
    """
    height, width = input_shape[0], input_shape[1]

    # Inputs are created in this order so auto-generated layer names stay stable.
    frames = Input(shape=(height, width, history_length))
    advantage_in = Input(shape=(1,))
    old_probs_in = Input(shape=(n_actions,))

    features = Lambda(lambda frame: frame / 255)(frames)  # normalize by 255

    # (filters, kernel size, stride) for each layer of the shared conv trunk.
    conv_specs = (
        (32, (8, 8), 4),
        (64, (4, 4), 2),
        (64, (3, 3), 1),
        (hidden, (7, 7), 1),
    )
    for filters, kernel_size, stride in conv_specs:
        features = Conv2D(
            filters,
            kernel_size,
            strides=stride,
            kernel_initializer=VarianceScaling(scale=2.),
            activation='relu',
            use_bias=False,
        )(features)

    # Critic head: scalar state value.
    critic_features = Flatten()(features)
    state_value = Dense(
        1,
        kernel_initializer=VarianceScaling(scale=2.),
        name="values",
    )(critic_features)

    # Actor head: probability distribution over the actions.
    actor_features = Flatten()(features)
    action_probs = Dense(
        n_actions,
        kernel_initializer=VarianceScaling(scale=2.),
        activation='softmax',
        name="probs",
    )(actor_features)

    # Build model
    model = Model(
        inputs=[frames, advantage_in, old_probs_in],
        outputs=[action_probs, state_value],
    )

    # Losses are keyed by the output layer names; the policy loss closes over
    # the two auxiliary input tensors.
    head_losses = {'probs' : ppo_loss(advantage_in, old_probs_in), 'values' : 'mean_squared_error'}
    head_weights = {'probs': 1e-1, 'values': 1.}
    model.compile(Adam(learning_rate), loss=head_losses, loss_weights=head_weights)

    return model

In [4]:
# Build the compiled PPO model for 6 discrete actions (all other arguments use the config defaults).
ppo_network = build_ppo_network(6)

In [5]:
# Print the layer-by-layer architecture and parameter counts as a sanity check.
ppo_network.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 84, 84, 4)]  0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 84, 84, 4)    0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 20, 20, 32)   8192        lambda[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 9, 9, 64)     32768       conv2d[0][0]                     
______________________________________________________________________________________________

In [6]:
from pong_wrapper import *

In [7]:
# Create the environment wrapper with a history length of 4.
# NOTE(review): presumably this must equal the model's history_length
# (the summary shows 4 input channels) -- confirm against config.
pw = PongWrapper(ENV_NAME, history_length=4)

In [8]:
# Single forward pass on the reset observation; [None, :] adds the batch axis.
# The advantage (1, 1) and old-prediction (1, 6) inputs are only used by the
# training loss, so zero placeholders are enough for inference. The 6 must
# match the n_actions the network was built with.
probs, values = ppo_network([pw.reset()[None, :], np.zeros((1, 1)), np.zeros((1, 6))])

In [9]:
# Sample one action index per batch row from the policy distribution.
# tf.random.categorical expects unnormalized log-probabilities (logits), not
# probabilities, so the softmax output is converted with a log first; passing
# probs directly would sample from softmax(probs), a much flatter distribution.
tf.squeeze(tf.random.categorical(tf.math.log(probs), 1), axis=-1)

<tf.Tensor: shape=(1,), dtype=int64, numpy=array([5])>

In [10]:
# Show the action probabilities returned by the forward pass (one row per batch item).
probs

<tf.Tensor: shape=(1, 6), dtype=float32, numpy=
array([[0.18297316, 0.14623234, 0.18190148, 0.13831313, 0.13417417,
        0.21640573]], dtype=float32)>