Skip to content

Commit 10675b4

Browse files
Synchronize gradients in manual optimization with DDPStrategy(static_graph=True) (#21251)
fix: synchronize gradients in manual optimization with DDPStrategy(static_graph=True). Ensures gradients are reduced correctly when using manual optimization and DDP with static_graph enabled. Adds a regression test covering all combinations of optimization mode and static_graph, and properly initializes the _pl_static_graph_delay_done attribute. Includes a changelog entry. Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
1 parent 2f448e1 commit 10675b4

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6161
- Fixed how `ThroughputMonitor` calculated training time ([#21291](https://github.com/Lightning-AI/pytorch-lightning/pull/21291))
6262

6363

64+
- Fixed synchronization of gradients in manual optimization with `DDPStrategy(static_graph=True)` ([#21251](https://github.com/Lightning-AI/pytorch-lightning/pull/21251))
65+
66+
67+
6468
---
6569

6670
## [2.5.5] - 2025-09-05

src/lightning/pytorch/strategies/ddp.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
self._process_group_backend: Optional[str] = process_group_backend
104104
self._timeout: Optional[timedelta] = timeout
105105
self._start_method = start_method
106+
self._pl_static_graph_delay_done = False
106107

107108
@property
108109
def is_distributed(self) -> bool: # pragma: no-cover
@@ -319,6 +320,27 @@ def pre_backward(self, closure_loss: Tensor) -> None:
319320
if not self.lightning_module.automatic_optimization:
320321
prepare_for_backward(self.model, closure_loss)
321322

323+
@override
324+
def post_backward(self, closure_loss: Tensor) -> None:
325+
# Only for first static-graph iteration with manual optimization
326+
model = self.model
327+
lm = self.lightning_module
328+
if not isinstance(model, DistributedDataParallel):
329+
return
330+
if lm is None or lm.automatic_optimization:
331+
return
332+
if not getattr(model, "static_graph", False):
333+
return
334+
if self._pl_static_graph_delay_done:
335+
return
336+
337+
# Call DDP's own first-iter static-graph flush.
338+
# This is what actually launches the bucket all-reduces.
339+
reducer = model.reducer
340+
reducer._delay_all_reduce()
341+
342+
self._pl_static_graph_delay_done = True
343+
322344
@override
323345
def model_to_device(self) -> None:
324346
log.debug(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")

tests/tests_pytorch/strategies/test_ddp_integration.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,50 @@ def creates_processes_externally(self):
448448
RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
449449
):
450450
trainer.fit(model)
451+
452+
453+
@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize("automatic_optimization", [True, False])
@pytest.mark.parametrize("static_graph", [True, False])
def test_ddp_gradients_synced(tmp_path, automatic_optimization, static_graph):
    """Ensure gradients are synchronized across ranks for both optimization modes and static_graph settings."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = automatic_optimization

        def training_step(self, batch, batch_idx):
            # Automatic optimization: defer entirely to the base implementation.
            if self.automatic_optimization:
                return super().training_step(batch, batch_idx)
            # Manual optimization: drive zero_grad / backward / step ourselves.
            optimizer = self.optimizers()
            optimizer.zero_grad()
            output = super().training_step(batch, batch_idx)
            self.manual_backward(output["loss"])
            optimizer.step()
            return output

        def on_train_batch_end(self, *args, **kwargs):
            # Record a per-rank gradient summary; min/max reductions across
            # ranks let us detect any divergence afterwards.
            bias_grad_sum = self.layer.bias.grad.detach().sum()
            for metric_name, reduction in (("grad_sum_min", "min"), ("grad_sum_max", "max")):
                self.log(metric_name, bias_grad_sum, sync_dist=True, reduce_fx=reduction)

    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="gpu",
        devices=2,
        strategy=DDPStrategy(static_graph=static_graph),
        max_steps=1,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(TestModel(), datamodule=BoringDataModule())

    # If gradients were synced, every rank saw the same value, so min == max.
    minimum = trainer.callback_metrics["grad_sum_min"]
    maximum = trainer.callback_metrics["grad_sum_max"]
    assert torch.allclose(minimum, maximum)

0 commit comments

Comments
 (0)