From 70476a8e7687bb9de5c7678c4b43232d03044abd Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Mon, 24 Nov 2025 16:38:04 -0500 Subject: [PATCH 1/3] Fix ModelCheckpoint.file_exists OOM in DDP --- .../pytorch/callbacks/model_checkpoint.py | 6 ++-- .../test_checkpoint_callback_frequency.py | 25 ++++++++++++++ .../checkpointing/test_model_checkpoint.py | 33 +++++++++++++++++++ 3 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 5357ee39a942a..dfa1300ac1cf2 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -999,8 +999,10 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None: def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool: """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state to diverge between ranks.""" - exists = self._fs.exists(filepath) - return trainer.strategy.broadcast(exists) + # In distributed setups, only global rank 0 touches the filesystem + local_decision = self._fs.exists(filepath) if trainer.is_global_zero else False + # Reduce the decision across ranks using an "any"-style reduction to decide if the file exists anywhere + return trainer.strategy.reduce_boolean_decision(local_decision, all=False) def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool: """Checks if the previous checkpoint should be deleted. diff --git a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py index 2e998c42ed2b7..a89cc0a12efc3 100644 --- a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py @@ -121,3 +121,28 @@ def on_train_epoch_end(self): trainer.fit(model) if os.getenv("LOCAL_RANK") == "0": assert save_mock.call_count == expected + + +@RunIf(min_cuda_gpus=2, standalone=True) +def test_model_checkpoint_ddp_monitor_none(tmp_path): + """Ensure that ModelCheckpoint with monitor=None works correctly under DDP and exercises the file_exists path.""" + + model = BoringModel() + checkpoint = callbacks.ModelCheckpoint(dirpath=tmp_path, monitor=None, save_top_k=1) + + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[checkpoint], + enable_progress_bar=False, + enable_model_summary=False, + max_epochs=1, + strategy="ddp", + accelerator="gpu", + devices=2, + limit_train_batches=2, + limit_val_batches=0, + ) + + trainer.fit(model) + if os.getenv("LOCAL_RANK") == "0": + assert checkpoint.best_model_path diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 449484da970a8..0ba3ea8689cdd 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -2180,3 +2180,36 @@ def on_validation_epoch_end(self): assert len(checkpoint_files) == expected_files, ( f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}" ) + + +def test_model_checkpoint_file_exists_distributed_branch(tmp_path): + """Ensure the distributed branch of ModelCheckpoint.file_exists uses reduce_boolean_decision.""" + + checkpoint = ModelCheckpoint(dirpath=tmp_path) + calls = [] + + class DummyStrategy: + def reduce_boolean_decision(self, decision, all=True): + calls.append((decision, all)) + return decision + + class DummyTrainer: + def __init__(self, is_global_zero: bool): + self.world_size = 2 + self.is_global_zero = is_global_zero + self.strategy = DummyStrategy() + + # global rank 0: filesystem is touched and decision=True is reduced with all=False + checkpoint._fs.exists = Mock(return_value=True) + trainer = DummyTrainer(is_global_zero=True) + assert checkpoint.file_exists("ignored", trainer) + checkpoint._fs.exists.assert_called_once_with("ignored") + assert calls == [(True, False)] + + # non-global ranks: filesystem is not touched and local decision is False + calls.clear() + checkpoint._fs.exists = Mock(return_value=True) + trainer = DummyTrainer(is_global_zero=False) + assert not checkpoint.file_exists("ignored", trainer) + checkpoint._fs.exists.assert_not_called() + assert calls == [(False, False)] From f6e48c01ecfaac23416c38e13df09ce0a9bca514 Mon Sep 17 00:00:00 2001 From: LittlebullGit Date: Mon, 24 Nov 2025 16:38:07 -0500 Subject: [PATCH 2/3] Document ModelCheckpoint.file_exists DDP memory fix --- src/lightning/pytorch/CHANGELOG.md | 3 ++ .../checkpointing/test_model_checkpoint.py | 33 ------------------- 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 709d307580094..b99e1c5969ccb 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed FSDP mixed precision semantics and added user warning ([#21361](https://github.com/Lightning-AI/pytorch-lightning/pull/21361)) +- Fixed `ModelCheckpoint.file_exists` using broadcast in DDP, reducing memory usage when checking for existing checkpoints ([#19674](https://github.com/Lightning-AI/pytorch-lightning/issues/19674)) + + --- ## [2.5.6] - 2025-11-05 diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 0ba3ea8689cdd..449484da970a8 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -2180,36 +2180,3 @@ def on_validation_epoch_end(self): assert len(checkpoint_files) == expected_files, ( f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}" ) - - -def test_model_checkpoint_file_exists_distributed_branch(tmp_path): - """Ensure the distributed branch of ModelCheckpoint.file_exists uses reduce_boolean_decision.""" - - checkpoint = ModelCheckpoint(dirpath=tmp_path) - calls = [] - - class DummyStrategy: - def reduce_boolean_decision(self, decision, all=True): - calls.append((decision, all)) - return decision - - class DummyTrainer: - def __init__(self, is_global_zero: bool): - self.world_size = 2 - self.is_global_zero = is_global_zero - self.strategy = DummyStrategy() - - # global rank 0: filesystem is touched and decision=True is reduced with all=False - checkpoint._fs.exists = Mock(return_value=True) - trainer = DummyTrainer(is_global_zero=True) - assert checkpoint.file_exists("ignored", trainer) - checkpoint._fs.exists.assert_called_once_with("ignored") - assert calls == [(True, False)] - - # non-global ranks: filesystem is not touched and local decision is False - calls.clear() - checkpoint._fs.exists = Mock(return_value=True) - trainer = DummyTrainer(is_global_zero=False) - assert not checkpoint.file_exists("ignored", trainer) - checkpoint._fs.exists.assert_not_called() - assert calls == [(False, False)] From 58d8c505b5027b2f081d964a32f80300e9daf17b Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:04:43 +0100 Subject: [PATCH 3/3] Update src/lightning/pytorch/callbacks/model_checkpoint.py --- src/lightning/pytorch/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index dfa1300ac1cf2..34dca4232a475 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -997,7 +997,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None: yaml.dump(best_k, fp) def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool: - """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal + """Checks if a file exists on rank 0 and synchronizes the result to all other ranks, preventing the internal state to diverge between ranks.""" # In distributed setups, only global rank 0 touches the filesystem local_decision = self._fs.exists(filepath) if trainer.is_global_zero else False