Skip to content

Commit

Permalink
转换规则 No. 323/333 (#198)
Browse files Browse the repository at this point in the history
Add tests
  • Loading branch information
co63oc committed Jul 31, 2023
1 parent 27ebdae commit 7a7c826
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 0 deletions.
12 changes: 12 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -9844,6 +9844,18 @@
"out"
]
},
"torch.special.polygamma": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.polygamma",
"args_list": [
"n",
"input",
"out"
],
"kwargs_change": {
"input": "x"
}
},
"torch.special.psi": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.digamma",
Expand Down
73 changes: 73 additions & 0 deletions tests/test_nn_GLU.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.GLU")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
m = torch.nn.GLU()
result = m(x)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
m = torch.nn.GLU(dim=-1)
result = m(x)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
m = torch.nn.GLU(dim=2)
result = m(x)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
53 changes: 53 additions & 0 deletions tests/test_special_polygamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.special.polygamma")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.special.polygamma(1, torch.tensor([1, 0.5]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([1, 0.5])
result = torch.special.polygamma(1, a)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
a = [1, 0.5]
out = torch.tensor(a)
result = torch.special.polygamma(1, torch.tensor(a), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])

0 comments on commit 7a7c826

Please sign in to comment.