Skip to content

Commit

Permalink
applying the changes for the case of ignore_border plus the changes f…
Browse files Browse the repository at this point in the history
…or pep8 for issue Theano#2196
  • Loading branch information
Sina Honari committed Dec 4, 2014
1 parent f67a9f1 commit 51bc9ec
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 79 deletions.
71 changes: 41 additions & 30 deletions theano/tensor/signal/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def max_pool_2d(input, ds, ignore_border=False):
:param input: input images. Max pooling will be done over the 2 last
dimensions.
:type ds: tuple of length 2
:param ds: factor by which to downscale (vertical ds, horizontal ds).
(2,2) will halve the image in each dimension.
:param ds: factor by which to downscale (vertical ds, horizontal ds).
(2,2) will halve the image in each dimension.
:param ignore_border: boolean value. When True, (5,5) input with ds=(2,2)
will generate a (2,2) output. (3,3) otherwise.
"""
Expand Down Expand Up @@ -81,7 +81,7 @@ def out_shape(imgshape, ds, ignore_border=False, st=None):
this parameter indicates the size of the pooling region
:type ds: list or tuple of two ints
:param st: the stride size. This is the distance between the pooling
:param st: the stride size. This is the distance between the pooling
regions. If it's set to None, in which case it equlas ds.
:type st: list or tuple of two ints
Expand All @@ -102,29 +102,34 @@ def out_shape(imgshape, ds, ignore_border=False, st=None):
st = ds
r, c = imgshape[-2:]

out_r = (r - ds[0]) // st[0] + 1
out_c = (c - ds[1]) // st[1] + 1

if isinstance(r, theano.Variable):
nr = tensor.maximum(out_r, 0)
else:
nr = numpy.maximum(out_r, 0)
if isinstance(c, theano.Variable):
nc = tensor.maximum(out_c, 0)
if ignore_border:
out_r = (r - ds[0]) // st[0] + 1
out_c = (c - ds[1]) // st[1] + 1
if isinstance(r, theano.Variable):
nr = tensor.maximum(out_r, 0)
else:
nr = numpy.maximum(out_r, 0)
if isinstance(c, theano.Variable):
nc = tensor.maximum(out_c, 0)
else:
nc = numpy.maximum(out_c, 0)
else:
nc = numpy.maximum(out_c, 0)

if not ignore_border:
if isinstance(r, theano.Variable):
nr = tensor.switch(tensor.ge(st[0], ds[0]), (r - 1) // st[0] + 1, tensor.maximum(0, (r - 1 - ds[0]) // st[0] + 1) + 1)
elif st[0] >= ds[0]:
nr = tensor.switch(tensor.ge(st[0], ds[0]),
(r - 1) // st[0] + 1,
tensor.maximum(0, (r - 1 - ds[0])
// st[0] + 1) + 1)
elif st[0] >= ds[0]:
nr = (r - 1) // st[0] + 1
else:
nr = max(0, (r - 1 - ds[0]) // st[0] + 1) + 1

if isinstance(c, theano.Variable):
nc = tensor.switch(tensor.ge(st[1], ds[1]), (c - 1) // st[1] + 1, tensor.maximum(0, (c - 1 - ds[1]) // st[1] + 1) + 1)
elif st[1] >= ds[1]:
nc = tensor.switch(tensor.ge(st[1], ds[1]),
(c - 1) // st[1] + 1,
tensor.maximum(0, (c - 1 - ds[1])
// st[1] + 1) + 1)
elif st[1] >= ds[1]:
nc = (c - 1) // st[1] + 1
else:
nc = max(0, (c - 1 - ds[1]) // st[1] + 1) + 1
Expand All @@ -134,14 +139,15 @@ def out_shape(imgshape, ds, ignore_border=False, st=None):

def __init__(self, ds, ignore_border=False, st=None):
"""
:param ds: downsample factor over rows and column. ds indicates the pool region size
:param ds: downsample factor over rows and column.
ds indicates the pool region size.
:type ds: list or tuple of two ints
: param st: stride size, which is the number of shifts
: param st: stride size, which is the number of shifts
over rows/cols to get the the next pool region.
if st is None, it is considered equal to ds
if st is None, it is considered equal to ds
(no overlap on pooling regions)
: type st: list or tuple of two ints
: type st: list or tuple of two ints
:param ignore_border: if ds doesn't divide imgshape, do we include
an extra row/col of partial downsampling (False) or
Expand All @@ -163,11 +169,12 @@ def __eq__(self, other):
self.ignore_border == other.ignore_border)

def __hash__(self):
return hash(type(self)) ^ hash(self.ds) ^ hash(self.st) ^ hash(self.ignore_border)
return hash(type(self)) ^ hash(self.ds) ^ \
hash(self.st) ^ hash(self.ignore_border)

def __str__(self):
return '%s{%s,%s,%s}' % (self.__class__.__name__,
self.ds, self.st, self.ignore_border)
self.ds, self.st, self.ignore_border)

def make_node(self, x):
if x.type.ndim != 4:
Expand All @@ -192,8 +199,10 @@ def perform(self, node, inp, out):

## zz needs to be initialized with -inf for the following to work
zz -= numpy.inf
pr = zz.shape[-2] # number of pooling output rows
pc = zz.shape[-1] # number of pooling output cols
#number of pooling output rows
pr = zz.shape[-2]
#number of pooling output cols
pc = zz.shape[-1]
ds0, ds1 = self.ds
st0, st1 = self.st
img_rows = x.shape[-2]
Expand All @@ -209,11 +218,13 @@ def perform(self, node, inp, out):
col_end = __builtin__.min(col_st + ds1, img_cols)
for row_ind in xrange(row_st, row_end):
for col_ind in xrange(col_st, col_end):
zz[n, k, r, c] = __builtin__.max(zz[n, k, r, c],
x[n, k, row_ind, col_ind])
zz[n, k, r, c] = \
__builtin__.max(zz[n, k, r, c],
x[n, k, row_ind, col_ind])

def infer_shape(self, node, in_shapes):
shp = self.out_shape(in_shapes[0], self.ds, self.ignore_border, self.st)
shp = self.out_shape(in_shapes[0], self.ds,
self.ignore_border, self.st)
return [shp]

def grad(self, inp, grads):
Expand Down
111 changes: 62 additions & 49 deletions theano/tensor/signal/tests/test_downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def numpy_max_pool_2d(input, ds, ignore_border=False):
'''Helper function, implementing max_pool_2d in pure numpy'''
if len(input.shape) < 2:
raise NotImplementedError('input should have at least 2 dim,'
' shape is %s'\
% str(input.shape))
' shape is %s'
% str(input.shape))
xi = 0
yi = 0
if not ignore_border:
Expand Down Expand Up @@ -45,10 +45,10 @@ def numpy_max_pool_2d_stride(input, ds, ignore_border=False, st=None):
for the pooling regions. if not indicated, st == sd.'''
if len(input.shape) < 2:
raise NotImplementedError('input should have at least 2 dim,'
' shape is %s'\
% str(input.shape))
' shape is %s'
% str(input.shape))

if st == None:
if st is None:
st = ds
xi = 0
yi = 0
Expand All @@ -58,25 +58,24 @@ def numpy_max_pool_2d_stride(input, ds, ignore_border=False, st=None):
out_r = 0
out_c = 0
if img_rows - ds[0] >= 0:
out_r = (img_rows - ds[0]) // st[0] + 1
out_r = (img_rows - ds[0]) // st[0] + 1
if img_cols - ds[1] >= 0:
out_c = (img_cols - ds[1]) // st[1] + 1

if not ignore_border:
if out_r > 0:
if img_rows - ((out_r - 1) * st[0] + ds[0]) > 0 :
if img_rows - ((out_r - 1) * st[0] + ds[0]) > 0:
rr = img_rows - out_r * st[0]
if rr > 0:
out_r += 1
else:
if img_rows > 0:
out_r += 1

if out_c > 0:
if img_cols - ((out_c - 1) * st[1] + ds[1]) > 0 :
if img_cols - ((out_c - 1) * st[1] + ds[1]) > 0:
cr = img_cols - out_c * st[1]
if cr > 0:
out_c +=1
out_c += 1
else:
if img_cols > 0:
out_c += 1
Expand Down Expand Up @@ -119,7 +118,8 @@ def test_DownsampleFactorMax(self):

#DownsampleFactorMax op
maxpool_op = DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(images)
ignore_border=
ignore_border)(images)
f = function([images], maxpool_op)
output_val = f(imval)
utt.assert_allclose(output_val, numpy_output_val)
Expand All @@ -130,11 +130,12 @@ def test_DownsampleFactorMaxStride(self):
stridesizes = ((1, 1), (3, 3), (5, 7))
# generate random images
imval = rng.rand(4, 10, 16, 16)
outputshps = ((4, 10, 16, 16), (4, 10, 6, 6), (4, 10, 4, 3), (4, 10, 16, 16), \
(4, 10, 6, 6), (4, 10, 4, 3), (4, 10, 14, 14), (4, 10, 5, 5), \
(4, 10, 3, 2), (4, 10, 14, 14), (4, 10, 6, 6), (4, 10, 4, 3), \
(4, 10, 12, 14), (4, 10, 4, 5), (4, 10, 3, 2), (4, 10, 12, 14), \
(4, 10, 5, 6), (4, 10, 4, 3))
outputshps = ((4, 10, 16, 16), (4, 10, 6, 6), (4, 10, 4, 3),
(4, 10, 16, 16), (4, 10, 6, 6), (4, 10, 4, 3),
(4, 10, 14, 14), (4, 10, 5, 5), (4, 10, 3, 2),
(4, 10, 14, 14), (4, 10, 6, 6), (4, 10, 4, 3),
(4, 10, 12, 14), (4, 10, 4, 5), (4, 10, 3, 2),
(4, 10, 12, 14), (4, 10, 5, 6), (4, 10, 4, 3))
images = tensor.dtensor4()
indx = 0
for maxpoolshp in maxpoolshps:
Expand All @@ -143,30 +144,36 @@ def test_DownsampleFactorMaxStride(self):
outputshp = outputshps[indx]
indx += 1
#DownsampleFactorMax op
numpy_output_val = self.numpy_max_pool_2d_stride(imval, maxpoolshp,
ignore_border, stride)
numpy_output_val = \
self.numpy_max_pool_2d_stride(imval, maxpoolshp,
ignore_border, stride)
assert numpy_output_val.shape == outputshp, (
"outshape is %s, calculated shape is %s"
%(outputshp, numpy_output_val.shape))
maxpool_op = DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border, st=stride)(images)
"outshape is %s, calculated shape is %s"
% (outputshp, numpy_output_val.shape))
maxpool_op = \
DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border,
st=stride)(images)
f = function([images], maxpool_op)
output_val = f(imval)
utt.assert_allclose(output_val, numpy_output_val)

def test_DownsampleFactorMaxStrideExtra(self):
rng = numpy.random.RandomState(utt.fetch_seed())
maxpoolshps = ((5, 3), (5, 3), (5, 3), (5, 5), (3, 2), (7, 7), (9, 9))
stridesizes = ((3, 2), (7, 5), (10, 6), (1, 1), (2, 3), (10, 10), (1, 1))
imvsizs = ((16, 16), (16, 16), (16, 16), (8, 5), (8, 5), (8, 5), (8, 5))
outputshps = ((4, 10, 4, 7), (4, 10, 5, 8), (4, 10, 2, 3), (4, 10, 3, 4), \
(4, 10, 2, 3), (4, 10, 2, 3), (4, 10, 4, 1), (4, 10, 4, 1), \
(4, 10, 3, 2), (4, 10, 4, 2), (4, 10, 1, 0), (4, 10, 1, 1), \
stridesizes = ((3, 2), (7, 5), (10, 6), (1, 1),
(2, 3), (10, 10), (1, 1))
imvsizs = ((16, 16), (16, 16), (16, 16), (8, 5),
(8, 5), (8, 5), (8, 5))
outputshps = ((4, 10, 4, 7), (4, 10, 5, 8), (4, 10, 2, 3),
(4, 10, 3, 4), (4, 10, 2, 3), (4, 10, 2, 3),
(4, 10, 4, 1), (4, 10, 4, 1), (4, 10, 3, 2),
(4, 10, 4, 2), (4, 10, 1, 0), (4, 10, 1, 1),
(4, 10, 0, 0), (4, 10, 1, 1))
images = tensor.dtensor4()
for indx in numpy.arange(len(maxpoolshps)):
imvsize = imvsizs[indx]
imval = rng.rand(4, 10 , imvsize[0], imvsize[1])
imvsize = imvsizs[indx]
imval = rng.rand(4, 10, imvsize[0], imvsize[1])
stride = stridesizes[indx]
maxpoolshp = maxpoolshps[indx]
for ignore_border in [True, False]:
Expand All @@ -175,13 +182,16 @@ def test_DownsampleFactorMaxStrideExtra(self):
indx_out += 1
outputshp = outputshps[indx_out]
#DownsampleFactorMax op
numpy_output_val = self.numpy_max_pool_2d_stride(imval, maxpoolshp,
ignore_border, stride)
numpy_output_val = \
self.numpy_max_pool_2d_stride(imval, maxpoolshp,
ignore_border, stride)
assert numpy_output_val.shape == outputshp, (
"outshape is %s, calculated shape is %s"
%(outputshp, numpy_output_val.shape))
maxpool_op = DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border, st=stride)(images)
"outshape is %s, calculated shape is %s"
% (outputshp, numpy_output_val.shape))
maxpool_op = \
DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border,
st=stride)(images)
f = function([images], maxpool_op)
output_val = f(imval)
utt.assert_allclose(output_val, numpy_output_val)
Expand All @@ -198,7 +208,8 @@ def test_DownsampleFactorMax_grad(self):
#print 'ignore_border =', ignore_border
def mp(input):
return DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(input)
ignore_border=
ignore_border)(input)
utt.verify_grad(mp, [imval], rng=rng)

def test_DownsampleFactorMaxGrad_grad(self):
Expand Down Expand Up @@ -257,7 +268,8 @@ def test_max_pool_2d_2D(self):
output_val = function([images], output)(imval)
assert numpy.all(output_val == numpy_output_val), (
"output_val is %s, numpy_output_val is %s"
%(output_val, numpy_output_val))
% (output_val, numpy_output_val))

def mp(input):
return max_pool_2d(input, maxpoolshp, ignore_border)
utt.verify_grad(mp, [imval], rng=rng)
Expand All @@ -278,15 +290,15 @@ def test_max_pool_2d_3D(self):
output_val = function([images], output)(imval)
assert numpy.all(output_val == numpy_output_val), (
"output_val is %s, numpy_output_val is %s"
%(output_val, numpy_output_val))
% (output_val, numpy_output_val))
c = tensor.sum(output)
c_val = function([images], c)(imval)
g = tensor.grad(c, images)
g_val = function([images],
[g.shape,
tensor.min(g, axis=(0, 1, 2)),
tensor.max(g, axis=(0, 1, 2))]
)(imval)
[g.shape,
tensor.min(g, axis=(0, 1, 2)),
tensor.max(g, axis=(0, 1, 2))]
)(imval)

#removed as already tested in test_max_pool_2d_2D
#This make test in debug mode too slow.
Expand Down Expand Up @@ -335,19 +347,20 @@ def test_infer_shape(self):

# checking shapes generated by DownsampleFactorMax
self._compile_and_check([image],
[DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(image)],
[image_val], DownsampleFactorMax)
[DownsampleFactorMax(maxpoolshp,
ignore_border=ignore_border)(image)],
[image_val], DownsampleFactorMax)

# checking shapes generated by DownsampleFactorMaxGrad
maxout_val = rng.rand(*out_shapes[i][j])
gz_val = rng.rand(*out_shapes[i][j])
self._compile_and_check([image, maxout, gz],
[DownsampleFactorMaxGrad(maxpoolshp,
ignore_border=ignore_border)(image, maxout, gz)],
[image_val, maxout_val, gz_val],
[DownsampleFactorMaxGrad(maxpoolshp,
ignore_border=ignore_border)
(image, maxout, gz)],
[image_val, maxout_val, gz_val],
DownsampleFactorMaxGrad,
warn=False)
warn=False)


if __name__ == '__main__':
Expand Down

0 comments on commit 51bc9ec

Please sign in to comment.