Skip to content

Commit

Permalink
Merge pull request #650 from PGrothaus/spatial_pyramid
Browse files Browse the repository at this point in the history
Spatial Pyramid Pooling
  • Loading branch information
f0k committed Apr 27, 2016
2 parents eee78c5 + a31557f commit f1d290c
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/modules/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,5 @@
dnn.Pool2DDNNLayer
dnn.MaxPool3DDNNLayer
dnn.Pool3DDNNLayer
dnn.SpatialPyramidPoolingDNNLayer

94 changes: 91 additions & 3 deletions lasagne/layers/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"MaxPool3DDNNLayer",
"Conv2DDNNLayer",
"Conv3DDNNLayer",
"SpatialPyramidPoolingDNNLayer",
]


Expand Down Expand Up @@ -139,7 +140,6 @@ def __init__(self, incoming, pool_size, stride=None,
**kwargs)



class Pool3DDNNLayer(Layer):
"""
3D pooling layer
Expand Down Expand Up @@ -208,7 +208,6 @@ def __init__(self, incoming, pool_size, stride=None, pad=(0, 0, 0),
raise NotImplementedError("Pool3DDNNLayer does not support "
"ignore_border=False.")


def get_output_shape_for(self, input_shape):
output_shape = list(input_shape) # copy / convert to mutable list

Expand All @@ -225,7 +224,7 @@ def get_output_shape_for(self, input_shape):
pad=self.pad[1],
ignore_border=True,
)

output_shape[4] = pool_output_length(input_shape[4],
pool_size=self.pool_size[2],
stride=self.stride[2],
Expand Down Expand Up @@ -503,3 +502,92 @@ def convolve(self, input, **kwargs):
conv_mode=conv_mode
)
return conved


class SpatialPyramidPoolingDNNLayer(Layer):
"""
Spatial Pyramid Pooling Layer
Performs spatial pyramid pooling (SPP) over the input.
It will turn a 2D input of arbitrary size into an output of fixed
dimension.
Hence, the convolutional part of a DNN can be connected to a dense part
with a fixed number of nodes even if the dimensions of the
input image are unknown.
The pooling is performed over :math:`l` pooling levels.
Each pooling level :math:`i` will create :math:`M_i` output features.
:math:`M_i` is given by :math:`n_i * n_i`,
with :math:`n_i` as the number of pooling operation per dimension in
level :math:`i`, and we use a list of the :math:`n_i`'s as a
parameter for SPP-Layer.
The length of this list is the level of the spatial pyramid.
Parameters
----------
incoming : a :class:`Layer` instance or tuple
The layer feeding into this layer, or the expected input shape.
pool_dims : list of integers
The list of :math:`n_i`'s that define the output dimension of each
pooling level :math:`i`. The length of pool_dims is the level of
the spatial pyramid.
mode : string
Pooling mode, one of 'max', 'average_inc_pad' or 'average_exc_pad'.
Defaults to 'max'.
**kwargs
Any additional keyword arguments are passed to the :class:`Layer`
superclass.
Notes
-----
This layer should be inserted between the convolutional part of a
DNN and its dense part. Convolutions can be used for
arbitrary input dimensions, but the size of their output will
depend on their input dimensions. Connecting the output of the
convolutional to the dense part then usually demands us to fix
the dimensions of the network's InputLayer.
The spatial pyramid pooling layer, however, allows us to leave the
network input dimensions arbitrary. The advantage over a global
pooling layer is the added robustness against object deformations
due to the pooling on different scales.
References
----------
.. [1] He, Kaiming et al (2015):
Spatial Pyramid Pooling in Deep Convolutional Networks
for Visual Recognition.
http://arxiv.org/pdf/1406.4729.pdf.
"""
def __init__(self, incoming, pool_dims=[4, 2, 1], mode='max', **kwargs):
super(SpatialPyramidPoolingDNNLayer, self).__init__(incoming,
**kwargs)
if len(self.input_shape) != 4:
raise ValueError("Tried to create a SPP layer with "
"input shape %r. Expected 4 input dimensions "
"(batchsize, channels, 2 spatial dimensions)."
% (self.input_shape,))
self.mode = mode
self.pool_dims = pool_dims

def get_output_for(self, input, **kwargs):
input_size = tuple(symb if fixed is None else fixed
for fixed, symb
in zip(self.input_shape[2:], input.shape[2:]))
pool_list = []
for pool_dim in self.pool_dims:
win_size = tuple((i + pool_dim - 1) // pool_dim
for i in input_size)
str_size = tuple(i // pool_dim for i in input_size)

pool = dnn.dnn_pool(input, win_size, str_size, self.mode, (0, 0))
pool = pool.flatten(2)
pool_list.append(pool)

return theano.tensor.concatenate(pool_list, axis=1)

def get_output_shape_for(self, input_shape):
num_features = sum(p*p for p in self.pool_dims)
return (input_shape[0], input_shape[1], num_features)
87 changes: 87 additions & 0 deletions lasagne/tests/layers/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,29 @@ def upscale_2d(data, scale_factor):
return upscaled


def spatial_pool(data, pool_dims):

def ceildiv(a, b):
return (a + b - 1) // b

def floordiv(a, b):
return a // b

input_size = data.shape[2:]
pooled_data_list = []
for pool_dim in pool_dims:
pool_size = tuple(ceildiv(i, pool_dim) for i in input_size)
stride_size = tuple(floordiv(i, pool_dim) for i in input_size)

pooled_part = max_pool_2d_ignoreborder(
data, pool_size, stride_size, (0, 0))
pooled_part = pooled_part.reshape(
np.shape(data)[0], np.shape(data)[1] * pool_dim ** 2)
pooled_data_list.append(pooled_part)

return np.concatenate(pooled_data_list, axis=1)


class TestFeaturePoolLayer:
def pool_test_sets():
for pool_size in [2, 3]:
Expand Down Expand Up @@ -815,3 +838,67 @@ def test_get_output_for(self, layer):
np_result = input.get_value().reshape((2, 3, -1)).mean(-1)

assert np.allclose(result, np_result)


class TestSpatialPyramidPoolingDNNLayer:
def pool_dims_test_sets():
for pyramid_level in [2, 3, 4]:
pool_dims = list(range(1, pyramid_level))
yield pool_dims

def input_layer(self, output_shape):
return Mock(output_shape=output_shape)

def layer(self, input_layer, pool_dims):
try:
from lasagne.layers.dnn import SpatialPyramidPoolingDNNLayer
except ImportError:
pytest.skip("cuDNN not available")

return SpatialPyramidPoolingDNNLayer(input_layer, pool_dims=pool_dims)

@pytest.mark.parametrize(
"pool_dims", list(pool_dims_test_sets()))
def test_get_output_for_ignoreborder(self, pool_dims):
try:
input = floatX(np.random.randn(8, 16, 17, 13))
input_layer = self.input_layer(input.shape)
input_theano = theano.shared(input)

result = self.layer(input_layer, pool_dims).get_output_for(
input_theano)

result_eval = result.eval()
numpy_result = spatial_pool(input, pool_dims)

assert np.all(numpy_result.shape == result_eval.shape)
assert np.allclose(result_eval, numpy_result)
except NotImplementedError:
pytest.skip()

@pytest.mark.parametrize(
"input_shape,output_shape",
[((32, 64, 24, 24), (32, 64, 21)),
((None, 64, 24, 24), (None, 64, 21)),
((32, None, 24, 24), (32, None, 21)),
((None, None, None, None), (None, None, 21))],
)
def test_get_output_shape_for(self, input_shape, output_shape):
try:
input_layer = self.input_layer(input_shape)
layer = self.layer(input_layer, pool_dims=[1, 2, 4])
assert layer.get_output_shape_for(input_shape) == output_shape
except NotImplementedError:
raise

def test_fail_on_mismatching_dimensionality(self):
try:
from lasagne.layers.dnn import SpatialPyramidPoolingDNNLayer
except ImportError:
pytest.skip("cuDNN not available")
with pytest.raises(ValueError) as exc:
SpatialPyramidPoolingDNNLayer((10, 20, 30))
assert "Expected 4 input dimensions" in exc.value.args[0]
with pytest.raises(ValueError) as exc:
SpatialPyramidPoolingDNNLayer((10, 20, 30, 40, 50))
assert "Expected 4 input dimensions" in exc.value.args[0]

0 comments on commit f1d290c

Please sign in to comment.