Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions tests/strategies/test_ddp_strategy_with_comm_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,26 @@
import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD


class TestDDPStrategy(DDPStrategy):
def __init__(self, expected_ddp_comm_hook_name, *args, **kwargs):
self.expected_ddp_comm_hook_name = expected_ddp_comm_hook_name
super().__init__(*args, **kwargs)

def teardown(self):
# check here before unwrapping DistributedDataParallel in self.teardown
attached_ddp_comm_hook_name = self.model._get_ddp_logging_data()["comm_hook"]
assert attached_ddp_comm_hook_name == self.expected_ddp_comm_hook_name
return super().teardown()


@RunIf(min_gpus=2, min_torch="1.9.0", skip_windows=True, standalone=True)
def test_ddp_fp16_compress_comm_hook(tmpdir):
"""Test for DDP FP16 compress hook."""
model = BoringModel()
strategy = DDPStrategy(ddp_comm_hook=default.fp16_compress_hook)
strategy = TestDDPStrategy(
expected_ddp_comm_hook_name=default.fp16_compress_hook.__qualname__,
ddp_comm_hook=default.fp16_compress_hook,
)
trainer = Trainer(
max_epochs=1,
accelerator="gpu",
Expand All @@ -47,17 +62,15 @@ def test_ddp_fp16_compress_comm_hook(tmpdir):
enable_model_summary=False,
)
trainer.fit(model)
trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
expected_comm_hook = default.fp16_compress_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(min_gpus=2, min_torch="1.9.0", skip_windows=True, standalone=True)
def test_ddp_sgd_comm_hook(tmpdir):
"""Test for DDP FP16 compress hook."""
model = BoringModel()
strategy = DDPStrategy(
strategy = TestDDPStrategy(
expected_ddp_comm_hook_name=powerSGD.powerSGD_hook.__qualname__,
ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
ddp_comm_hook=powerSGD.powerSGD_hook,
)
Expand All @@ -73,17 +86,15 @@ def test_ddp_sgd_comm_hook(tmpdir):
enable_model_summary=False,
)
trainer.fit(model)
trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(min_gpus=2, min_torch="1.9.0", skip_windows=True, standalone=True)
def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
"""Test for DDP FP16 compress wrapper for SGD hook."""
model = BoringModel()
strategy = DDPStrategy(
strategy = TestDDPStrategy(
expected_ddp_comm_hook_name=default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__,
ddp_comm_state=powerSGD.PowerSGDState(process_group=None),
ddp_comm_hook=powerSGD.powerSGD_hook,
ddp_comm_wrapper=default.fp16_compress_wrapper,
Expand All @@ -100,9 +111,6 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
enable_model_summary=False,
)
trainer.fit(model)
trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"


Expand Down Expand Up @@ -130,8 +138,8 @@ def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
def test_ddp_post_local_sgd_comm_hook(tmpdir):
"""Test for DDP post-localSGD hook."""
model = BoringModel()

strategy = DDPStrategy(
strategy = TestDDPStrategy(
expected_ddp_comm_hook_name=post_localSGD.post_localSGD_hook.__qualname__,
ddp_comm_state=post_localSGD.PostLocalSGDState(
process_group=None,
subgroup=None,
Expand All @@ -151,9 +159,6 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
enable_model_summary=False,
)
trainer.fit(model)
trainer_comm_hook = trainer.strategy.model.get_ddp_logging_data().comm_hook
expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
assert trainer_comm_hook == expected_comm_hook
assert trainer.state.finished, f"Training failed with {trainer.state}"


Expand Down