diff --git a/test/auto_parallel/test_tuning_recompute.py b/test/auto_parallel/test_tuning_recompute.py
index e39091b8f57b7..239569e47f144 100644
--- a/test/auto_parallel/test_tuning_recompute.py
+++ b/test/auto_parallel/test_tuning_recompute.py
@@ -36,7 +36,7 @@ def generate_model():
     gpt = GPTModel(
         vocab_size=50304,
         hidden_size=1024,
-        num_hidden_layers=14,
+        num_hidden_layers=13,
         num_attention_heads=16,
         intermediate_size=1024 * 4,
         hidden_act="gelu",
@@ -97,14 +97,25 @@ def test_recompute_pass(self):
         engine = auto.Engine(model, loss, opt, strategy=strategy)
         engine._tune(self.dataset, 3, batch_size=self.batch_size)
 
-        assert (
-            len(
-                engine._dist_contexts[
-                    'train'
-                ].strategy.recompute.no_recompute_segments
-            )
-            > 0
+        gpu_memory_size = round(
+            paddle.device.cuda.get_device_properties(0).total_memory
+            / 1024
+            / 1024
+            / 1024
         )
+        dist_strategy = engine._dist_contexts['train'].strategy
+        if gpu_memory_size in [16, 32]:
+            self.assertGreater(
+                len(dist_strategy.recompute.no_recompute_segments),
+                0,
+                "When GPU memory size is 16G or 32G, the length of no_recompute_segments should be greater than 0.",
+            )
+        elif gpu_memory_size >= 40:
+            self.assertEqual(
+                dist_strategy.recompute.enable,
+                False,
+                "When GPU memory size is greater than 40GB, the recompute strategy should be disable.",
+            )
 
 
 if __name__ == "__main__":