Skip to content

Commit

Permalink
Have create_param() set the broadcast pattern of created shared varia…
Browse files Browse the repository at this point in the history
…bles
  • Loading branch information
f0k committed Jul 1, 2016
1 parent 7599698 commit 4d4e0b0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
9 changes: 9 additions & 0 deletions lasagne/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ def test_create_param_retain_ndarray_dtype():
assert (result.dtype == param.dtype)


def test_create_param_broadcast_pattern():
from lasagne.utils import create_param
for shape in (10, 1, 20), (1, 2), (3, 1), (2, 3):
bcast = tuple(s == 1 for s in shape)
assert create_param(np.zeros, shape).broadcastable == bcast
assert create_param(np.zeros(shape, np.float32),
shape).broadcastable == bcast


def test_unroll_scan():
from lasagne.utils import unroll_scan
k = 2
Expand Down
6 changes: 5 additions & 1 deletion lasagne/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,11 @@ def create_param(spec, shape, name=None):
if spec.shape != shape:
raise ValueError("%s has shape %s, should be %s" %
(err_prefix % "numpy array", spec.shape, shape))
spec = theano.shared(spec)
# We assume parameter variables do not change shape after creation.
# We can thus fix their broadcast pattern, to allow Theano to infer
# broadcastable dimensions of expressions involving these parameters.
bcast = tuple(s == 1 for s in shape)
spec = theano.shared(spec, broadcastable=bcast)

if isinstance(spec, theano.Variable):
# We cannot check the shape here, Theano expressions (even shared
Expand Down

0 comments on commit 4d4e0b0

Please sign in to comment.