Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- If not set by the user, Lightning will set `OMP_NUM_THREADS` to `num_cpus / num_processes` when launching subprocesses (e.g. when DDP is used) to avoid system overload for CPU-intensive tasks ([#18677](https://github.com/Lightning-AI/lightning/pull/18677))
- The `ModelCheckpoint` no longer deletes files under the save-top-k mechanism when resuming from a folder that is not the same as the current checkpoint folder ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
- The `ModelCheckpoint` no longer deletes the file that was passed to `Trainer.fit(ckpt_path=...)` ([#18750](https://github.com/Lightning-AI/lightning/pull/18750))
- Calling `trainer.fit()` twice now raises an error with strategies that spawn subprocesses through `multiprocessing` (ddp_spawn, xla) ([#18776](https://github.com/Lightning-AI/lightning/pull/18776))

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
f" {', '.join(mp.get_all_start_methods())}"
)
self.procs: List[mp.Process] = []
self._already_fit = False

@property
def is_interactive_compatible(self) -> bool:
Expand All @@ -106,6 +107,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
_check_bad_cuda_fork()
if self._start_method == "spawn":
_check_missing_main_guard()
if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
# resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
raise NotImplementedError(
"Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
" supported. You can work around this limitation by creating a new Trainer instance and passing the"
" `fit(ckpt_path=...)` argument."
)

# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
Expand Down Expand Up @@ -137,6 +145,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
if trainer is None:
return worker_output

self._already_fit |= trainer.state.fn == TrainerFn.FITTING
self._recover_results_in_main_process(worker_output, trainer)
return worker_output.trainer_results

Expand Down
9 changes: 9 additions & 0 deletions src/lightning/pytorch/strategies/launchers/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
**kwargs: Optional keyword arguments to be passed to the given function.

"""
if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
# resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
raise NotImplementedError(
"Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
" supported. You can work around this by creating a new Trainer instance and passing the"
" `fit(ckpt_path=...)` argument."
)

using_pjrt = _using_pjrt()
# pjrt requires that the queue is serializable
return_queue: Union[queue.Queue, mp.SimpleQueue] = (
Expand Down Expand Up @@ -102,6 +110,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
if trainer is None:
return worker_output

self._already_fit |= trainer.state.fn == TrainerFn.FITTING
self._recover_results_in_main_process(worker_output, trainer)
return worker_output.trainer_results

Expand Down
32 changes: 16 additions & 16 deletions src/lightning/pytorch/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,9 @@ def fit(
model = _maybe_unwrap_optimized(model)
self.strategy._lightning_module = model
_verify_strategy_supports_compile(model, self.strategy)
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
call._call_and_handle_interrupt(
self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
)
Expand All @@ -553,10 +556,6 @@ def _fit_impl(
) -> None:
log.debug(f"{self.__class__.__name__}: trainer fit stage")

self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True

# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
datamodule = train_dataloaders
Expand All @@ -572,6 +571,7 @@ def _fit_impl(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

assert self.state.fn is not None
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn,
ckpt_path,
Expand Down Expand Up @@ -640,6 +640,9 @@ def validate(
model = _maybe_unwrap_optimized(model)
self.strategy._lightning_module = model
_verify_strategy_supports_compile(self.lightning_module, self.strategy)
self.state.fn = TrainerFn.VALIDATING
self.state.status = TrainerStatus.RUNNING
self.validating = True
return call._call_and_handle_interrupt(
self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
)
Expand All @@ -657,10 +660,6 @@ def _validate_impl(
# --------------------
log.debug(f"{self.__class__.__name__}: trainer validate stage")

self.state.fn = TrainerFn.VALIDATING
self.state.status = TrainerStatus.RUNNING
self.validating = True

# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
datamodule = dataloaders
Expand All @@ -680,6 +679,7 @@ def _validate_impl(
# links data to the trainer
self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

assert self.state.fn is not None
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)
Expand Down Expand Up @@ -749,6 +749,9 @@ def test(
model = _maybe_unwrap_optimized(model)
self.strategy._lightning_module = model
_verify_strategy_supports_compile(self.lightning_module, self.strategy)
self.state.fn = TrainerFn.TESTING
self.state.status = TrainerStatus.RUNNING
self.testing = True
return call._call_and_handle_interrupt(
self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
)
Expand All @@ -766,10 +769,6 @@ def _test_impl(
# --------------------
log.debug(f"{self.__class__.__name__}: trainer test stage")

self.state.fn = TrainerFn.TESTING
self.state.status = TrainerStatus.RUNNING
self.testing = True

# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
datamodule = dataloaders
Expand All @@ -789,6 +788,7 @@ def _test_impl(
# links data to the trainer
self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

assert self.state.fn is not None
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)
Expand Down Expand Up @@ -859,6 +859,9 @@ def predict(
model = _maybe_unwrap_optimized(model)
self.strategy._lightning_module = model
_verify_strategy_supports_compile(self.lightning_module, self.strategy)
self.state.fn = TrainerFn.PREDICTING
self.state.status = TrainerStatus.RUNNING
self.predicting = True
return call._call_and_handle_interrupt(
self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
)
Expand All @@ -876,10 +879,6 @@ def _predict_impl(
# --------------------
log.debug(f"{self.__class__.__name__}: trainer predict stage")

self.state.fn = TrainerFn.PREDICTING
self.state.status = TrainerStatus.RUNNING
self.predicting = True

self.predict_loop.return_predictions = return_predictions # type: ignore[assignment]

# if a datamodule comes in as the second arg, then fix it for the user
Expand All @@ -898,6 +897,7 @@ def _predict_impl(
# links data to the trainer
self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

assert self.state.fn is not None
ckpt_path = self._checkpoint_connector._select_ckpt_path(
self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,20 @@ def test_check_for_missing_main_guard():
return_value=Mock(_inheriting=True), # pretend that main is importing itself
), pytest.raises(RuntimeError, match="requires that your script guards the main"):
launcher.launch(function=Mock())


def test_fit_twice_raises():
    """A second ``trainer.fit()`` on a spawn-based strategy must raise ``NotImplementedError``.

    Guards the restriction introduced for multiprocessing launchers: once a fit has
    completed, the same Trainer instance cannot fit again with ``ddp_spawn``.
    """
    trainer = Trainer(
        barebones=True,
        strategy="ddp_spawn",
        max_epochs=1,
        limit_train_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
    )
    model = BoringModel()
    trainer.fit(model)
    # Running a non-fit stage in between must not reset the launcher's fit bookkeeping.
    trainer.test(model)
    trainer.fit_loop.max_epochs += 1
    with pytest.raises(NotImplementedError, match=r"twice.*is not supported"):
        trainer.fit(model)