Skip to content

Fixes to argsort when axis is None. #596

Merged
merged 3 commits into from Apr 11, 2012

2 participants

@lamblin
Theano member
lamblin commented Apr 5, 2012

No description provided.

@delallea delallea and 1 other commented on an outdated diff Apr 5, 2012
theano/tensor/basic.py
@@ -6381,6 +6384,13 @@ def perform(self, node, inputs, output_storage):
z[0] = numpy.argsort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
+ if (isinstance(node.inputs[1], Constant) and
+ node.inputs[1].data is None):
+ return [(mul(*inputs_shapes[0]),)]
+ # axis should not be None, so there should be the same number of
+ # dimensions in the input and output
+ assert node.inputs[0].ndim == node.outputs[0].ndim
+ assert inputs_shapes[1] is ()
@delallea
Theano member
delallea added a note Apr 5, 2012

Maybe change is into ==? I don't think it can be guaranteed that they are the same objects, can it?

@lamblin
Theano member
lamblin added a note Apr 5, 2012

I'm not sure, I copied the code in Sort.infer_shape. I can change both.

@delallea
Theano member
delallea added a note Apr 5, 2012

Ok, yes then, please change both to be safe.

@lamblin
Theano member
lamblin added a note Apr 11, 2012

Done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
@delallea
Theano member
delallea commented Apr 5, 2012

I guess there was no test with axis=None? Maybe add one?

@lamblin
Theano member
lamblin commented Apr 5, 2012

There is a test. It was failing in DebugMode.

@delallea
Theano member
delallea commented Apr 5, 2012

Ok.

@delallea delallea merged commit 9aa725a into Theano:master Apr 11, 2012
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Something went wrong with that request. Please try again.