Skip to content

Commit

Permalink
Merge pull request #808 from wjj19950828/fixed_nonzero
Browse files Browse the repository at this point in the history
Fixed nonzero op
  • Loading branch information
jiangjiajun committed Jun 8, 2022
2 parents 2cb5ba2 + 06a3deb commit 7101837
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions x2paddle/op_mapper/onnx2paddle/opset9/opset.py
Expand Up @@ -1880,30 +1880,13 @@ def Where(self, node):
@print_mapping_info
def NonZero(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_x_dim = len(val_x.out_shapes[0])
if val_x_dim == 1:
self.paddle_graph.add_layer(
"paddle.nonzero",
inputs={"x": val_x.name},
outputs=[val_x.name])
self.paddle_graph.add_layer(
"paddle.transpose",
inputs={"x": val_x.name},
outputs=[node.layer_name],
perm=[1, 0])
if val_x_dim > 1:
self.paddle_graph.add_layer(
"paddle.nonzero",
inputs={"x": val_x.name},
outputs=[val_x.name])
self.paddle_graph.add_layer(
"paddle.split",
inputs={"x": val_x.name},
outputs=[val_x.name],
num_or_sections=1,
axis=val_x_dim)
self.paddle_graph.add_layer(
"paddle.concat", inputs={"x": val_x.name}, outputs=[node.name])
self.paddle_graph.add_layer(
"paddle.nonzero",
inputs={"x": val_x.name},
outputs=[val_x.name],
as_tuple=True)
self.paddle_graph.add_layer(
"paddle.concat", inputs={"x": val_x.name}, outputs=[node.name])

@print_mapping_info
def Identity(self, node):
Expand Down

0 comments on commit 7101837

Please sign in to comment.