Skip to content

Commit

Permalink
Merge pull request #7 from jaberg/small_stuff
Browse files Browse the repository at this point in the history
Small fixes to AdvancedSubtensor1 and CVM.
  • Loading branch information
dwf committed Aug 25, 2011
2 parents 9c8cf38 + dd58f74 commit 92a4c45
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
12 changes: 12 additions & 0 deletions theano/compile/function_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,16 @@ def __contains__(self, item):
self._value = ValueAttribute()
self._container = ContainerAttribute()

# Compute self.n_returned_outputs.
# This is used only when fn.need_update_inputs is False
# because we're using one of the VM objects and it is
# putting updates back into the input containers all by itself.
assert len(self.maker.expanded_inputs) == len(self.input_storage)
self.n_returned_outputs = len(self.output_storage)
for input in self.maker.expanded_inputs:
if input.update is not None:
self.n_returned_outputs -= 1

def __contains__(self, item):
return self.value.__contains__(item)

Expand Down Expand Up @@ -636,6 +646,8 @@ def __call__(self, *args, **kwargs):
for input, storage in reversed(zip(self.maker.expanded_inputs, self.input_storage)):
if input.update is not None:
storage.data = outputs.pop()
else:
outputs = outputs[:self.n_returned_outputs]

# Put default values back in the storage
for i, (required, refeed, value) in enumerate(self.defaults):
Expand Down
13 changes: 6 additions & 7 deletions theano/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,9 +1270,12 @@ def __getitem__(self, args):
break

if advanced:
if len(args) == 1 and isinstance(args[0],
(list, TensorVariable,
theano.tensor.sharedvar.TensorSharedVariable)):
if (len(args) == 1
and isinstance(args[0], (
list,
TensorVariable,
TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))):
return advanced_subtensor1(self, *args)
else:
return AdvancedSubtensor(args)(self, *args)
Expand Down Expand Up @@ -4863,10 +4866,6 @@ def make_node(self, x, ilist):
raise TypeError('index must be vector')
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
if x_.type.broadcastable[0]:
# the caller should have made a copy of x len(ilist) times
raise TypeError('cannot index into a broadcastable dimension')

return Apply(self, [x_, ilist_], [x_.type()])

def perform(self, node, inp, out_):
Expand Down

0 comments on commit 92a4c45

Please sign in to comment.