Skip to content

Commit

Permalink
Reverse Sequence small fixes
Browse files Browse the repository at this point in the history
Signed-off-by: maheshambule <mahesh_ambule@persistent.com>
  • Loading branch information
maheshambule committed May 1, 2020
1 parent a5bc9e8 commit 1ae7f0f
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ struct ReverseSequenceAttrs : public tvm::AttrsNode<ReverseSequenceAttrs> {
Integer batch_axis;

TVM_DECLARE_ATTRS(ReverseSequenceAttrs, "relay.attrs.ReverseSequenceAttrs") {
TVM_ATTR_FIELD(seq_axis).set_default(NullValue<Integer>())
TVM_ATTR_FIELD(seq_axis).set_default(1)
.describe("The seq axis along which to reverse elements.");
TVM_ATTR_FIELD(batch_axis).set_default(NullValue<Integer>())
TVM_ATTR_FIELD(batch_axis).set_default(0)
.describe("The batch axis along which to slice the tensor.");
}
}; // struct ReverseSequenceAttrs
Expand Down
6 changes: 2 additions & 4 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,9 +1738,7 @@ def convert_reverse_sequence(self, op):
batch_axis = options.BatchDim()
seq_axis = options.SeqDim()

out = _op.reverse_sequence(in_expr, length_expr, seq_axis, batch_axis)

return out
return _op.reverse_sequence(in_expr, length_expr, seq_axis, batch_axis)

def convert_cast(self, op):
"""Convert TFLite CAST"""
Expand Down Expand Up @@ -2384,7 +2382,7 @@ def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))

def get_tensor_expr(self, tensor):
""" Return the expr for tensor. """
""" Return the Relay expr for tensor. """
if self.has_expr(tensor.tensor_idx):
expr = self.get_expr(tensor.tensor_idx)
else:
Expand Down
1 change: 1 addition & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,7 @@ Array<te::Tensor> ReverseCompute(const Attrs& attrs,
const Type& out_type) {
const ReverseAttrs *param = attrs.as<ReverseAttrs>();
CHECK(param != nullptr);
//pass empty seq_length tensor to reverse_sequence
return { topi::reverse_sequence(inputs[0], te::Tensor(), param->axis)};
}

Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0):
Parameters
----------
data : tvm.te.Tensor
a : tvm.te.Tensor
The tensor to be reversed.
seq_lengths : tvm.te.Tensor
Expand Down
1 change: 1 addition & 0 deletions topi/src/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TVM_REGISTER_GLOBAL("topi.transpose")

TVM_REGISTER_GLOBAL("topi.flip")
.set_body([](TVMArgs args, TVMRetValue *rv) {
//pass empty seq_lengths tensor to reverse_sequence
*rv = reverse_sequence(args[0], Tensor(), args[1]);
});

Expand Down

0 comments on commit 1ae7f0f

Please sign in to comment.