-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,7 +68,7 @@ class DownsampleFactorMax(Op): | |
""" | ||
|
||
@staticmethod | ||
def out_shape(imgshape, ds, ignore_border=False): | ||
def out_shape(imgshape, ds, st, ignore_border=False): | ||
"""Return the shape of the output from this op, for input of given | ||
shape and flags. | ||
|
@@ -78,8 +78,12 @@ def out_shape(imgshape, ds, ignore_border=False): | |
scalar Theano variable. | ||
:param ds: downsample factor over rows and columns; | ||
this parameter indicates the pooling region | ||
:type ds: list or tuple of two ints | ||
:param st: the stride size | ||
:type st: list or tuple of two ints | ||
:param ignore_border: if ds doesn't divide imgshape, do we include an | ||
extra row/col of partial downsampling (False) or ignore it (True). | ||
:type ignore_border: bool | ||
|
@@ -93,24 +97,30 @@ def out_shape(imgshape, ds, ignore_border=False): | |
raise TypeError('imgshape must have at least two elements ' | ||
'(rows, cols)') | ||
r, c = imgshape[-2:] | ||
rval = list(imgshape[:-2]) + [r // ds[0], c // ds[1]] | ||
rval = list(imgshape[:-2]) + [(r - ds[0]) // st[0] + 1, (c - ds[1]) // st[1] + 1] | ||
|
||
if not ignore_border: | ||
if isinstance(r, theano.Variable): | ||
rval[-2] = tensor.switch(r % ds[0], rval[-2] + 1, rval[-2]) | ||
elif r % ds[0]: | ||
rval[-2] = tensor.switch((r - ds[0]) % st[0], rval[-2] + 1, rval[-2]) | ||
elif (r - ds[0]) % st[0]: | ||
rval[-2] += 1 | ||
if isinstance(c, theano.Variable): | ||
rval[-1] = tensor.switch(c % ds[1], rval[-1] + 1, rval[-1]) | ||
elif c % ds[1]: | ||
rval[-1] = tensor.switch((c - ds[1]) % st[1], rval[-1] + 1, rval[-1]) | ||
elif (c - ds[1]) % st[1]: | ||
rval[-1] += 1 | ||
return rval | ||
|
||
def __init__(self, ds, ignore_border=False): | ||
def __init__(self, ds, ignore_border=False, st=None): | ||
""" | ||
:param ds: downsample factor over rows and columns | ||
:param ds: downsample factor over rows and columns. ds indicates the pool region size | ||
:type ds: list or tuple of two ints | ||
:param st: stride size, which is the number of shifts | ||
over rows/cols to get the next pool region. | ||
If st is None, it is considered equal to ds | ||
(no overlap on pooling regions). | ||
:type st: list or tuple of two ints | ||
:param ignore_border: if ds doesn't divide imgshape, do we include | ||
an extra row/col of partial downsampling (False) or | ||
ignore it (True). | ||
|
@@ -119,19 +129,23 @@ def __init__(self, ds, ignore_border=False): | |
TODO: why is poolsize an op parameter here? | ||
""" | ||
self.ds = tuple(ds) | ||
if st == None: | ||
st = ds | ||
self.st = tuple(st) | ||
self.ignore_border = ignore_border | ||
|
||
def __eq__(self, other): | ||
return (type(self) == type(other) and | ||
self.ds == other.ds and | ||
self.st == other.st and | ||
self.ignore_border == other.ignore_border) | ||
|
||
def __hash__(self): | ||
return hash(type(self)) ^ hash(self.ds) ^ hash(self.ignore_border) | ||
return hash(type(self)) ^ hash(self.ds) ^ hash(self.st) ^ hash(self.ignore_border) | ||
|
||
def __str__(self): | ||
return '%s{%s,%s}' % (self.__class__.__name__, | ||
self.ds, self.ignore_border) | ||
self.ds, self.st, self.ignore_border) | ||
|
||
def make_node(self, x): | ||
if x.type.ndim != 4: | ||
|
@@ -147,35 +161,49 @@ def perform(self, node, inp, out): | |
if len(x.shape) != 4: | ||
raise NotImplementedError( | ||
'DownsampleFactorMax requires 4D input for now') | ||
z_shape = self.out_shape(x.shape, self.ds, self.ignore_border) | ||
z_shape = self.out_shape(x.shape, self.ds, self.st, self.ignore_border) | ||
if (z[0] is None) or (z[0].shape != z_shape): | ||
z[0] = numpy.zeros(self.out_shape(x.shape, self.ds, | ||
z[0] = numpy.zeros(self.out_shape(x.shape, self.ds, self.st, | ||
self.ignore_border)) | ||
z[0] = theano._asarray(z[0], dtype=x.dtype) | ||
zz = z[0] | ||
|
||
## zz needs to be initialized with -inf for the following to work | ||
zz -= numpy.inf | ||
pr = zz.shape[-2] # number of pooling output rows | ||
pc = zz.shape[-1] # number of pooling output cols | ||
ds0, ds1 = self.ds | ||
st0, st1 = self.st | ||
img_rows = x.shape[-2] | ||
img_cols = x.shape[-1] | ||
|
||
if self.ignore_border: | ||
x_usable2 = (x.shape[2] // ds0 * ds0) | ||
x_usable2 = (x.shape[2] - ds0) // st0 * st0 + ds0 | ||
else: | ||
x_usable2 = x.shape[2] | ||
if self.ignore_border: | ||
x_usable3 = (x.shape[3] // ds1 * ds1) | ||
x_usable3 = (x.shape[3] - ds1) // st1 * st1 + ds1 | ||
else: | ||
x_usable3 = x.shape[3] | ||
for n in xrange(x.shape[0]): | ||
for k in xrange(x.shape[1]): | ||
for i in xrange(x_usable2): | ||
zi = i / ds0 | ||
for j in xrange(x_usable3): | ||
zj = j / ds1 | ||
zz[n, k, zi, zj] = __builtin__.max(zz[n, k, zi, zj], | ||
x[n, k, i, j]) | ||
for r in xrange(pr): | ||
row_st = r * st0 | ||
for c in xrange(pc): | ||
col_st = c * st1 | ||
for i in xrange(ds0): | ||
row_ind = row_st + i | ||
if row_ind >= img_rows: | ||
continue | ||
for j in xrange(ds1): | ||
col_ind = col_st + j | ||
if col_ind >= img_cols: | ||
continue | ||
zz[n, k, r, c] = __builtin__.max(zz[n, k, r, c], | ||
x[n, k, row_ind, col_ind]) | ||
|
||
def infer_shape(self, node, in_shapes): | ||
shp = self.out_shape(in_shapes[0], self.ds, self.ignore_border) | ||
shp = self.out_shape(in_shapes[0], self.ds, self.st, self.ignore_border) | ||
return [shp] | ||
|
||
def grad(self, inp, grads): | ||
|
@@ -186,7 +214,7 @@ def grad(self, inp, grads): | |
ignore_border=self.ignore_border)( | ||
x, maxout, gz)] | ||
|
||
def c_code(self, node, name, inp, out, sub): | ||
def c_code_tmp(self, node, name, inp, out, sub): | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
SinaHonari
Owner
|
||
x, = inp | ||
z, = out | ||
fail = sub['fail'] | ||
|
@@ -258,7 +286,7 @@ def c_code(self, node, name, inp, out, sub): | |
} | ||
""" % locals() | ||
|
||
def c_code_cache_version(self): | ||
def c_code_cache_version_tmp(self): | ||
return (0, 1) | ||
|
||
|
||
|
Instead of renaming those methods, you could add this in the body:
That way, we keep the C code for the currently supported case.