Skip to content

Commit

Permalink
[Feature] filter_size can be an array (#326)
Browse files Browse the repository at this point in the history
* Issue #317 [feature request] filter_size can be a array instead of one value

* Issues #326 [Feature] filter_size can be a array

* Issue #326 [Feature] filter_size can be a array

* Issues #326 [Feature] filter_size can be a array: Line too long

* Update changelog.rst

* Issue #326 [Feature] filter_size can be a array, the added test code is test_a2c_conv.py

* Issues #326 [Feature] filter_size can be a array, remove the unused variables

* Issues #326 [Feature] filter_size can be a array, remove the unused library

* Issue #326, [Feature] filter_size can be a array. Clean up the test code
  • Loading branch information
yutingsz authored and araffin committed May 17, 2019
1 parent c9be8dc commit fddf169
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
4 changes: 2 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Release 2.5.2a0 (WIP)
- Bugfix for ``VecEnvWrapper.__getattr__`` which enables access to class attributes inherited from parent classes.
- Removed ``get_available_gpus`` function which hadn't been used anywhere (@Pastafarianist)
- Fixed path splitting in ``TensorboardWriter._get_latest_run_id()`` on Windows machines (@PatrickWalter214)

- The parameter ``filter_size`` of the function ``conv`` in A2C utils now supports passing a list/tuple of two integers (height and width), in order to have non-squared kernel matrix. (@yutingsz)

Release 2.5.1 (2019-05-04)
--------------------------
Expand Down Expand Up @@ -300,4 +300,4 @@ In random order...

Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @JohannesAck
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz
13 changes: 11 additions & 2 deletions stable_baselines/a2c/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,23 @@ def conv(input_tensor, scope, *, n_filters, filter_size, stride,
:param input_tensor: (TensorFlow Tensor) The input tensor for the convolution
:param scope: (str) The TensorFlow variable scope
:param n_filters: (int) The number of filters
:param filter_size: (int) The filter size
:param filter_size: (Union[int, [int], tuple<int, int>]) The filter size for the squared kernel matrix,
or the height and width of kernel filter if the input is a list or tuple
:param stride: (int) The stride of the convolution
:param pad: (str) The padding type ('VALID' or 'SAME')
:param init_scale: (int) The initialization scale
:param data_format: (str) The data format for the convolution weights
:param one_dim_bias: (bool) If the bias should be one dimentional or not
:return: (TensorFlow Tensor) 2d convolutional layer
"""
if isinstance(filter_size, list) or isinstance(filter_size, tuple):
assert len(filter_size) == 2, \
"Filter size must have 2 elements (height, width), {} were given".format(len(filter_size))
filter_height = filter_size[0]
filter_width = filter_size[1]
else:
filter_height = filter_size
filter_width = filter_size
if data_format == 'NHWC':
channel_ax = 3
strides = [1, stride, stride, 1]
Expand All @@ -122,7 +131,7 @@ def conv(input_tensor, scope, *, n_filters, filter_size, stride,
raise NotImplementedError
bias_var_shape = [n_filters] if one_dim_bias else [1, n_filters, 1, 1]
n_input = input_tensor.get_shape()[channel_ax].value
wshape = [filter_size, filter_size, n_input, n_filters]
wshape = [filter_height, filter_width, n_input, n_filters]
with tf.variable_scope(scope):
weight = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
bias = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0))
Expand Down
41 changes: 41 additions & 0 deletions tests/test_a2c_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import tensorflow as tf
import numpy as np
from stable_baselines.a2c.utils import conv
import gym
from stable_baselines.common.input import observation_input

ENV_ID = 'BreakoutNoFrameskip-v4'
SEED = 3


def test_conv_kernel():
"""
test convolution kernel with various input formats
"""
filter_size_1 = 4 # The size of squared filter for the first layer
filter_size_2 = (3, 5) # The size of non-squared filter for the second layer
target_shape_1 = [2, 52, 40, 32] # The desired shape of the first layer
target_shape_2 = [2, 13, 9, 32] # The desired shape of the second layer
kwargs = {}
n_envs = 1
n_steps = 2
n_batch = n_envs * n_steps
scale = False
env = gym.make(ENV_ID)
ob_space = env.observation_space

graph = tf.Graph()
with graph.as_default():
_, scaled_images = observation_input(ob_space, n_batch, scale=scale)
activ = tf.nn.relu
layer_1 = activ(conv(scaled_images, 'c1', n_filters=32, filter_size=filter_size_1, stride=4
, init_scale=np.sqrt(2), **kwargs))
layer_2 = activ(conv(layer_1, 'c2', n_filters=32, filter_size=filter_size_2, stride=4
, init_scale=np.sqrt(2), **kwargs))
assert layer_1.shape == target_shape_1 \
, "The shape of layer based on the squared kernel matrix is not correct. " \
"The current shape is {} and the desired shape is {}".format(layer_1.shape, target_shape_1)
assert layer_2.shape == target_shape_2 \
, "The shape of layer based on the non-squared kernel matrix is not correct. " \
"The current shape is {} and the desired shape is {}".format(layer_2.shape, target_shape_2)
env.close()

0 comments on commit fddf169

Please sign in to comment.