From 79c00458bdcb9d24e94893562b5f7ca03f2b44ce Mon Sep 17 00:00:00 2001 From: Wang Huan Date: Mon, 20 May 2024 10:43:52 +0000 Subject: [PATCH] fix pir api 9 --- test/legacy_test/op_test.py | 9 +++++++++ test/{deprecated => }/legacy_test/test_cummin_op.py | 0 test/{deprecated => }/legacy_test/test_frame_op.py | 0 test/{deprecated => }/legacy_test/test_lu_op.py | 0 4 files changed, 9 insertions(+) rename test/{deprecated => }/legacy_test/test_cummin_op.py (100%) rename test/{deprecated => }/legacy_test/test_frame_op.py (100%) rename test/{deprecated => }/legacy_test/test_lu_op.py (100%) diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py index 8f79ab6975a1c..ed4e0f478ed38 100644 --- a/test/legacy_test/op_test.py +++ b/test/legacy_test/op_test.py @@ -1529,6 +1529,10 @@ def _calc_output( check_cinn=False, ): with paddle.pir_utils.OldIrGuard(): + if hasattr(self, "attrs"): + for k, v in self.attrs.items(): + if isinstance(v, paddle.base.core.DataType): + self.attrs[k] = paddle.pir.core.datatype_to_vartype[v] program = Program() block = program.global_block() op = self._append_ops(block) @@ -3251,6 +3255,11 @@ def check_grad_with_place( if "use_mkldnn" in op_attrs and op_attrs["use_mkldnn"]: op_attrs["use_mkldnn"] = False use_onednn = True + if hasattr(self, "attrs"): + for k, v in self.attrs.items(): + if isinstance(v, paddle.base.core.DataType): + self.attrs[k] = paddle.pir.core.datatype_to_vartype[v] + self.op = create_op( self.scope, self.op_type, diff --git a/test/deprecated/legacy_test/test_cummin_op.py b/test/legacy_test/test_cummin_op.py similarity index 100% rename from test/deprecated/legacy_test/test_cummin_op.py rename to test/legacy_test/test_cummin_op.py diff --git a/test/deprecated/legacy_test/test_frame_op.py b/test/legacy_test/test_frame_op.py similarity index 100% rename from test/deprecated/legacy_test/test_frame_op.py rename to test/legacy_test/test_frame_op.py diff --git a/test/deprecated/legacy_test/test_lu_op.py b/test/legacy_test/test_lu_op.py similarity index 100% rename from test/deprecated/legacy_test/test_lu_op.py rename to test/legacy_test/test_lu_op.py