diff --git a/tests/strategies/test_ddp_strategy_with_comm_hook.py b/tests/strategies/test_ddp_strategy_with_comm_hook.py
index 9445565ec9c3b..dada03e83a5a4 100644
--- a/tests/strategies/test_ddp_strategy_with_comm_hook.py
+++ b/tests/strategies/test_ddp_strategy_with_comm_hook.py
@@ -30,11 +30,26 @@
     import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
 
 
+class TestDDPStrategy(DDPStrategy):
+    def __init__(self, expected_ddp_comm_hook_name, *args, **kwargs):
+        self.expected_ddp_comm_hook_name = expected_ddp_comm_hook_name
+        super().__init__(*args, **kwargs)
+
+    def teardown(self):
+        # check here before unwrapping DistributedDataParallel in self.teardown
+        attached_ddp_comm_hook_name = self.model._get_ddp_logging_data()["comm_hook"]
+        assert attached_ddp_comm_hook_name == self.expected_ddp_comm_hook_name
+        return super().teardown()
+
+
 @RunIf(min_gpus=2, min_torch="1.9.0", skip_windows=True, standalone=True)
 def test_ddp_fp16_compress_comm_hook(tmpdir):
     """Test for DDP FP16 compress hook."""
     model = BoringModel()
-    strategy = DDPStrategy(ddp_comm_hook=default.fp16_compress_hook)
+    strategy = TestDDPStrategy(
+        expected_ddp_comm_hook_name=default.fp16_compress_hook.__qualname__,
+        ddp_comm_hook=default.fp16_compress_hook,
+    )
     trainer = Trainer(
         max_epochs=1,
         accelerator="gpu",
@@ -47,9 +62,6 @@ def test_ddp_fp16_compress_comm_hook(tmpdir):
         enable_model_summary=False,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
-    expected_comm_hook = default.fp16_compress_hook.__qualname__
-    assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 
 
@@ -57,7 +69,8 @@ def test_ddp_fp16_compress_comm_hook(tmpdir):
 def test_ddp_sgd_comm_hook(tmpdir):
     """Test for DDP FP16 compress hook."""
     model = BoringModel()
-    strategy = DDPStrategy(
+    strategy = TestDDPStrategy(
+        expected_ddp_comm_hook_name=powerSGD.powerSGD_hook.__qualname__,
         ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
         ddp_comm_hook=powerSGD.powerSGD_hook,
     )
@@ -73,9 +86,6 @@ def test_ddp_sgd_comm_hook(tmpdir):
         enable_model_summary=False,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
-    expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
-    assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 
 
@@ -83,7 +93,8 @@ def test_ddp_sgd_comm_hook(tmpdir):
 def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
     """Test for DDP FP16 compress wrapper for SGD hook."""
     model = BoringModel()
-    strategy = DDPStrategy(
+    strategy = TestDDPStrategy(
+        expected_ddp_comm_hook_name=default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__,
         ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
         ddp_comm_hook=powerSGD.powerSGD_hook,
         ddp_comm_wrapper=default.fp16_compress_wrapper,
@@ -100,9 +111,6 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
         enable_model_summary=False,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
-    expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
-    assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 
 
@@ -130,8 +138,8 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
 def test_ddp_post_local_sgd_comm_hook(tmpdir):
     """Test for DDP post-localSGD hook."""
     model = BoringModel()
-
-    strategy = DDPStrategy(
+    strategy = TestDDPStrategy(
+        expected_ddp_comm_hook_name=post_localSGD.post_localSGD_hook.__qualname__,
         ddp_comm_state=post_localSGD.PostLocalSGDState(
             process_group=None,
             subgroup=None,
@@ -151,9 +159,6 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
         enable_model_summary=False,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
-    expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
-    assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"