Skip to content

Commit

Permalink
Merge pull request #715 from f0k/bcast-params
Browse files Browse the repository at this point in the history
Have create_param() set the broadcast pattern
  • Loading branch information
benanne committed Jul 4, 2016
2 parents 639972e + 4d4e0b0 commit 822b042
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
9 changes: 9 additions & 0 deletions lasagne/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ def test_create_param_retain_ndarray_dtype():
assert (result.dtype == param.dtype)


def test_create_param_broadcast_pattern():
from lasagne.utils import create_param
for shape in (10, 1, 20), (1, 2), (3, 1), (2, 3):
bcast = tuple(s == 1 for s in shape)
assert create_param(np.zeros, shape).broadcastable == bcast
assert create_param(np.zeros(shape, np.float32),
shape).broadcastable == bcast


def test_unroll_scan():
from lasagne.utils import unroll_scan
k = 2
Expand Down
10 changes: 6 additions & 4 deletions lasagne/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,19 +348,21 @@ def create_param(spec, shape, name=None):
if spec.shape != shape:
raise ValueError("%s has shape %s, should be %s" %
(err_prefix % "numpy array", spec.shape, shape))
spec = theano.shared(spec)
# We assume parameter variables do not change shape after creation.
# We can thus fix their broadcast pattern, to allow Theano to infer
# broadcastable dimensions of expressions involving these parameters.
bcast = tuple(s == 1 for s in shape)
spec = theano.shared(spec, broadcastable=bcast)

if isinstance(spec, theano.Variable):
# We cannot check the shape here, Theano expressions (even shared
# variables) do not have a fixed compile-time shape. We can check the
# dimensionality though.
# Note that we cannot assign a name here. We could assign to the
# `name` attribute of the variable, but the user may have already
# named the variable and we don't want to override this.
if spec.ndim != len(shape):
raise ValueError("%s has %d dimensions, should be %d" %
(err_prefix % "Theano variable", spec.ndim,
len(shape)))
# We only assign a name if the user hasn't done so already.
if not spec.name:
spec.name = name
return spec
Expand Down

0 comments on commit 822b042

Please sign in to comment.