In [1]:
import sys
sys.path.append("../src/")

from config import *
from pong_wrapper import *
from process_image import *
from utilities import *
from plugin_write_and_run import *

In [2]:
%%write_and_run ../src/a2c_networks.py
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as kl
from tensorflow.keras.initializers import VarianceScaling

In [3]:
%%write_and_run -a ../src/a2c_networks.py

class ProbabilityDistribution(tf.keras.Model):
    """Keras model whose forward pass samples from categorical logits."""

    def call(self, logits, **kwargs):
        """Draw one categorical sample per row of `logits`.

        The sample axis produced by `tf.random.categorical` (size 1 here) is
        squeezed away before returning.
        """
        sampled = tf.random.categorical(logits, 1)
        return tf.squeeze(sampled, axis=-1)


class Model(tf.keras.Model):
    """Convolutional actor-critic network: one shared trunk, two heads.

    The trunk scales the input by 1/255 and applies four bias-free ReLU
    convolutions. Its flattened output feeds both a policy-logits head and a
    scalar state-value head.

    Args:
        num_actions: Number of units in the policy-logits head.
        hidden: Number of filters in the final convolution, i.e. the size of
            the feature vector handed to both heads when that convolution
            collapses the spatial dimensions to 1x1.
    """

    def __init__(self, num_actions, hidden):
        # Note: no tf.get_variable(), just simple Keras API!
        # The model name must be given by keyword: a lone positional string is
        # not interpreted as the name (the model ended up with the default
        # name "model"), and stricter Keras versions reject it outright.
        super().__init__(name='mlp_policy')

        # Scale inputs by 1/255 (assumes 8-bit pixel values - TODO confirm).
        self.normalize = kl.Lambda(lambda layer: layer / 255)
        self.conv1 = kl.Conv2D(32, (8, 8), strides=4, kernel_initializer=VarianceScaling(scale=2.), activation='relu', use_bias=False)
        self.conv2 = kl.Conv2D(64, (4, 4), strides=2, kernel_initializer=VarianceScaling(scale=2.), activation='relu', use_bias=False)
        self.conv3 = kl.Conv2D(64, (3, 3), strides=1, kernel_initializer=VarianceScaling(scale=2.), activation='relu', use_bias=False)
        # Final convolution producing `hidden` feature channels.
        self.conv4 = kl.Conv2D(hidden, (7, 7), strides=1, kernel_initializer=VarianceScaling(scale=2.), activation='relu', use_bias=False)

        self.flatten = kl.Flatten()

        # Output heads, both fed from the same flattened trunk features.
        self.value = kl.Dense(1, kernel_initializer=VarianceScaling(scale=2.), name="value")
        self.logits = kl.Dense(num_actions, kernel_initializer=VarianceScaling(scale=2.), name='policy_logits')

        # Sampling is wrapped in its own Keras model so that `action_value`
        # can run it through `predict_on_batch`.
        self.dist = ProbabilityDistribution()

    def call(self, inputs, **kwargs):
        """Run the forward pass.

        Args:
            inputs: Batch of observations (array-like or tensor).

        Returns:
            A `(logits, value)` tuple computed from the shared trunk features.
        """
        # Inputs may be a numpy array, convert to a tensor.
        x = tf.convert_to_tensor(inputs)
        # Shared trunk: every layer below feeds both heads.
        x = self.normalize(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)

        return self.logits(x), self.value(x)

    def action_value(self, obs):
        """Sample an action and estimate the state value for `obs`.

        Args:
            obs: Batch of observations. The final `np.squeeze(..., axis=-1)`
                on the sampled action only succeeds when that axis has
                length 1, so pass a batch containing a single observation.

        Returns:
            A `(action, value)` tuple, each with its last axis squeezed away.
        """
        # Executes `call()` under the hood.
        logits, value = self.predict_on_batch(obs)
        # Sample from the policy logits via the wrapped distribution model.
        action = self.dist.predict_on_batch(logits)
        return np.squeeze(action, axis=-1), np.squeeze(value, axis=-1)

In [4]:
# Build the environment wrapper. ENV_NAME is not defined in this notebook;
# presumably it comes from the star import of `config` - TODO confirm.
pw = PongWrapper(ENV_NAME)

In [5]:
# Size the policy head from the environment's action space. `hidden` is not
# defined in this notebook; presumably it comes from the star import of
# `config` - TODO confirm.
model = Model(num_actions=pw.env.action_space.n, hidden=hidden)

In [6]:
# Reset the wrapper and keep the observation it returns.
obs = pw.reset()

In [7]:
# Inspect the dimensions of the observation returned by reset().
obs.shape

(84, 84, 4)

In [11]:
# Identity check: is the returned observation the very object held in
# `pw.state`, rather than a copy?
obs is pw.state

True

In [9]:
# No feed_dict or tf.Session() needed at all!
action, value = model.action_value(obs[None, :])
print(action, value) # [1] [-0.00145713]

4 [0.02945016]


In [10]:
# Print the per-layer parameter counts. Run this after the forward pass above,
# which is what creates the layers' weights for this subclassed model.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lambda (Lambda)              multiple                  0         
_________________________________________________________________
conv2d (Conv2D)              multiple                  8192      
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  32768     
_________________________________________________________________
conv2d_2 (Conv2D)            multiple                  36864     
_________________________________________________________________
conv2d_3 (Conv2D)            multiple                  3211264   
_________________________________________________________________
flatten (Flatten)            multiple                  0         
_________________________________________________________________
value (Dense)                multiple                  1025  