Skip to content

Commit

Permalink
Fixed topk op (#805)
Browse files Browse the repository at this point in the history
* Support bigbird model

* Support GPT2

* rm useless code

* deal with comments

* fixed topk

* deal with comments

* add notes
  • Loading branch information
wjj19950828 committed Jun 6, 2022
1 parent ea254a8 commit 2cb5ba2
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions x2paddle/op_mapper/onnx2paddle/opset9/opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2577,27 +2577,42 @@ def assign_params(op_name, weights, weight_idx=0, suffix=''):
def TopK(self, node):
    """Map the ONNX ``TopK`` op onto ``paddle.topk``.

    ONNX TopK takes the data tensor as input 0 and ``k`` as input 1
    (a tensor). When ``k`` is a graph constant we fold it into the
    ``k`` attribute of ``paddle.topk``; otherwise ``k`` is wired in as
    a tensor input, cast to int32 first since paddle.topk expects an
    integer ``k`` tensor.

    Emits two outputs, ``<layer_name>_p0`` (values) and
    ``<layer_name>_p1`` (indices), matching ONNX TopK's
    (Values, Indices) output pair.
    """
    val_x = self.graph.get_input_node(node, idx=0, copy=True)
    val_k = self.graph.get_input_node(node, idx=1, copy=True)
    layer_attrs = dict()
    # ONNX defaults: axis=-1, largest=1, sorted=1.
    layer_attrs["axis"] = node.get_attr('axis', -1)
    layer_attrs["largest"] = node.get_attr('largest', 1) == 1
    layer_attrs["sorted"] = node.get_attr('sorted', 1) == 1
    k = _const_weight_or_none(val_k)
    # ONNX stores k as a 1-D single-element tensor; unwrap to a scalar.
    if isinstance(k, (list, tuple, np.ndarray)):
        k = k[0]
    # If k can get the value directly, it is used as an attribute;
    # otherwise it is used as an input tensor.
    if k is not None:
        layer_attrs["k"] = k
        self.paddle_graph.add_layer(
            "paddle.topk",
            inputs={"x": val_x.name},
            outputs=[
                "{}_p{}".format(node.layer_name, 0),
                "{}_p{}".format(node.layer_name, 1)
            ],
            **layer_attrs)
    else:
        # Dynamic k: paddle.topk needs an integer tensor, so insert a
        # cast when the ONNX k tensor is not already int32.
        if val_k.dtype != "int32":
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": val_k.name},
                outputs=[val_k.name],
                dtype=string('int32'))
        self.paddle_graph.add_layer(
            "paddle.topk",
            inputs={"x": val_x.name,
                    "k": val_k.name},
            outputs=[
                "{}_p{}".format(node.layer_name, 0),
                "{}_p{}".format(node.layer_name, 1)
            ],
            **layer_attrs)

@print_mapping_info
def LRN(self, node):
Expand Down

0 comments on commit 2cb5ba2

Please sign in to comment.