diff --git a/python/paddle/fluid/tests/unittests/test_inplace.py b/python/paddle/fluid/tests/unittests/test_inplace.py index 052109674ae0c..399640ce6b1d2 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace.py +++ b/python/paddle/fluid/tests/unittests/test_inplace.py @@ -483,7 +483,16 @@ def inplace_api_processing(self, var): return var.subtract_(input_var_2) -class TestDygraphInplaceRemainder(TestDygraphInplaceAdd): +class TestDygraphInplaceRemainder(unittest.TestCase): + + def setUp(self): + self.init_data() + self.set_np_compare_func() + + def init_data(self): + self.input_var_numpy = np.random.rand(2, 3, 4) + self.dtype = "float32" + self.input_var_numpy_2 = np.random.rand(2, 3, 4).astype(self.dtype) def non_inplace_api_processing(self, var): input_var_2 = paddle.to_tensor(self.input_var_numpy_2) @@ -493,6 +502,41 @@ def inplace_api_processing(self, var): input_var_2 = paddle.to_tensor(self.input_var_numpy_2) return var.remainder_(input_var_2) + def set_np_compare_func(self): + self.np_compare = np.array_equal + + def func_test_inplace_api(self): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + inplace_var = self.inplace_api_processing(var) + self.assertTrue(id(var) == id(inplace_var)) + + inplace_var[0] = 2. + np.testing.assert_array_equal(var.numpy(), inplace_var.numpy()) + + def test_inplace_api(self): + with _test_eager_guard(): + self.func_test_inplace_api() + self.func_test_inplace_api() + + def func_test_forward_version(self): + with paddle.fluid.dygraph.guard(): + var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype) + self.assertEqual(var.inplace_version, 0) + + inplace_var = self.inplace_api_processing(var) + self.assertEqual(var.inplace_version, 1) + + inplace_var[0] = 2. + self.assertEqual(var.inplace_version, 2) + + inplace_var = self.inplace_api_processing(inplace_var) + self.assertEqual(var.inplace_version, 3) + + def test_forward_version(self): + with _test_eager_guard(): + self.func_test_forward_version() + self.func_test_forward_version() + class TestLossIsInplaceVar(unittest.TestCase):