[PIR API adaptor No. 89, 226] Migrate paddle.nn.functional.gather_tree and paddle.nn.functional.temporal_shift into PIR #58792
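For context, a minimal sketch of how the two migrated APIs are called in dynamic mode. The F alias and the tensor shapes/values are illustrative and not taken from this PR.

import paddle
import paddle.nn.functional as F

# gather_tree backtracks beam-search results.
# ids/parents have shape [max_time, batch_size, beam_size]; values are illustrative.
ids = paddle.to_tensor(
    [[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]], dtype='int64'
)
parents = paddle.to_tensor(
    [[[0, 0], [1, 1]], [[1, 0], [0, 0]], [[0, 0], [0, 1]]], dtype='int64'
)
final_sequences = F.gather_tree(ids, parents)  # same shape as ids

# temporal_shift expects x of shape [N*T, C, H, W]; here N=2 clips of T=3 frames each.
x = paddle.randn([6, 4, 8, 8])
out = F.temporal_shift(x, seg_num=3, shift_ratio=0.25, data_format='NCHW')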

Merged (3 commits) on Nov 9, 2023
Changes from all commits
10 changes: 7 additions & 3 deletions python/paddle/nn/functional/extension.py
@@ -21,7 +21,11 @@
from ...base.data_feeder import check_type, check_variable_and_dtype
from ...base.layer_helper import LayerHelper
from ...common_ops_import import Variable
from ...framework import convert_np_dtype_to_dtype_, core
from ...framework import (
convert_np_dtype_to_dtype_,
core,
in_dynamic_or_pir_mode,
)

__all__ = []

@@ -206,7 +210,7 @@ def gather_tree(ids, parents):
if ids.ndim != parents.ndim:
raise ValueError("The ids's shape must be the same as parents' shape. ")

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.gather_tree(ids, parents)
else:
helper = LayerHelper('gather_tree', **locals())
@@ -292,7 +296,7 @@ def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
"Attr(data_format) should be 'NCHW' or 'NHWC'. "
f"Received Attr(data_format): {data_format}."
)
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.temporal_shift(x, seg_num, shift_ratio, data_format)
else:
helper = LayerHelper("temporal_shift", **locals())
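The change above is the core of the migration: the _C_ops fast path, previously guarded by in_dynamic_mode(), now also covers the new PIR static graph via in_dynamic_or_pir_mode(), while the legacy static graph keeps the LayerHelper path. A minimal sketch of that dispatch follows; import paths mirror the diff, and the append_op input/output key names are illustrative rather than verified against the op definition.

from paddle import _C_ops
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_or_pir_mode


def gather_tree_sketch(ids, parents):
    if in_dynamic_or_pir_mode():
        # dynamic graph and the new PIR static graph both call the C++ op directly
        return _C_ops.gather_tree(ids, parents)
    # the legacy static graph still builds the op through LayerHelper
    helper = LayerHelper('gather_tree', **locals())
    out = helper.create_variable_for_type_inference(dtype=ids.dtype)
    helper.append_op(
        type='gather_tree',
        inputs={'Ids': ids, 'Parents': parents},  # key names illustrative
        outputs={'Out': out},
    )
    return out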
17 changes: 16 additions & 1 deletion test/legacy_test/test_gather_tree_op.py
@@ -19,6 +19,7 @@

import paddle
from paddle.base.framework import Program, program_guard
from paddle.pir_utils import test_with_pir_api


class TestGatherTreeOp(OpTest):
@@ -36,7 +37,7 @@ def setUp(self):
self.outputs = {'Out': self.backtrace(ids, parents)}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

@staticmethod
def backtrace(ids, parents):
@@ -55,6 +56,7 @@ def backtrace(ids, parents):


class TestGatherTreeOpAPI(unittest.TestCase):
@test_with_pir_api
def test_case(self):
paddle.enable_static()
ids = paddle.static.data(name='ids', shape=[5, 2, 2], dtype='int64')
@@ -77,6 +79,7 @@ def test_case2(self):


class TestGatherTreeOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
@@ -99,6 +102,18 @@ def test_Variable_parents():

self.assertRaises(TypeError, test_Variable_parents)

paddle.disable_static()


class TestGatherTreeOpErrorForOthers(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
ids = paddle.static.data(name='ids', shape=[5, 2, 2], dtype='int64')
parents = paddle.static.data(
name='parents', shape=[5, 2, 2], dtype='int64'
Contributor review comment:
Please note gather_tree's unit-test coverage in the PR description, and point out that these test_errors cases are not yet covered.

)

def test_type_ids():
# dtype must be int32 or int64
bad_ids = paddle.static.data(
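The test-side changes use two mechanisms: check_pir=True makes OpTest.check_output / check_grad also verify the operator under PIR, and the test_with_pir_api decorator reruns a static-graph API test against the PIR program. A minimal sketch of the decorator usage, assuming only the paddle.pir_utils import shown in the diff; the class name and test body are illustrative.

import unittest

import paddle
from paddle.pir_utils import test_with_pir_api


class GatherTreeStaticSketch(unittest.TestCase):
    @test_with_pir_api
    def test_build_graph(self):
        # the decorated body runs twice: once with the legacy static graph
        # and once with the new PIR static graph
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            ids = paddle.static.data(name='ids', shape=[5, 2, 2], dtype='int64')
            parents = paddle.static.data(
                name='parents', shape=[5, 2, 2], dtype='int64'
            )
            out = paddle.nn.functional.gather_tree(ids, parents)
        paddle.disable_static()


if __name__ == '__main__':
    unittest.main()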
14 changes: 8 additions & 6 deletions test/legacy_test/test_temporal_shift_op.py
@@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def temporal_shift(x, seg_num, shift_ratio, data_format):
@@ -73,10 +74,10 @@ def init_dtype(self):
self.dtype = 'float64'

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad_ignore_uv(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)

def initTestCase(self):
self.x_shape = (6, 4, 4, 4)
@@ -123,12 +124,12 @@ def initTestCase(self):
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad_ignore_uv(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


class TestTemporalShiftAPI(unittest.TestCase):
@@ -146,6 +147,7 @@ def test_api(self):
x=input, seg_num=2, shift_ratio=0.2
)

@test_with_pir_api
def test_static_fp16_gpu(self):
if paddle.base.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
@@ -224,11 +226,11 @@ def setUp(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad_ignore_uv(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)


if __name__ == "__main__":
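At the operator-test level, the same flag is simply threaded through the existing check_* helpers. A hedged sketch of what such a test looks like; the class name and attribute values are illustrative, and the numpy reference temporal_shift is the module-level helper defined near the top of test_temporal_shift_op.py (truncated above).

import numpy as np
from op_test import OpTest  # test helper shipped with test/legacy_test
from test_temporal_shift_op import temporal_shift  # numpy reference shown above

import paddle


class TestTemporalShiftPirSketch(OpTest):
    def setUp(self):
        self.op_type = 'temporal_shift'
        self.python_api = paddle.nn.functional.temporal_shift
        x = np.random.random((6, 4, 4, 4)).astype('float64')
        self.attrs = {'seg_num': 3, 'shift_ratio': 0.25, 'data_format': 'NCHW'}
        self.inputs = {'X': x}
        self.outputs = {'Out': temporal_shift(x, 3, 0.25, 'NCHW')}

    def test_check_output(self):
        # check_pir=True additionally builds and runs the op under PIR
        self.check_output(check_pir=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_pir=True)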