2 changes: 0 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -109,7 +109,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# optionally can be set by user
self._example_input_array = None
self._current_fx_name: Optional[str] = None
self._current_dataloader_idx: Optional[int] = None
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = {}
@@ -419,7 +418,6 @@ def log(
reduce_fx=reduce_fx,
enable_graph=enable_graph,
add_dataloader_idx=add_dataloader_idx,
dataloader_idx=self._current_dataloader_idx,
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
14 changes: 6 additions & 8 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -100,14 +100,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs evaluation on one single dataloader."""
void(*args, **kwargs)

dataloader_idx: int = self.current_dataloader_idx
dataloader_idx = self.current_dataloader_idx
dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=dataloader_idx
)
dl_max_batches = self._max_batches[dataloader_idx]

dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
dl_outputs = self.epoch_loop.run(
dataloader, dataloader_idx if self.num_dataloaders > 1 else None, dl_max_batches
)

# store batch level output per dataloader
self._outputs.append(dl_outputs)
@@ -212,17 +214,13 @@ def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
# inform logger the batch loop has finished
self.trainer.logger_connector.epoch_end_reached()

# call the model epoch end
model = self.trainer.lightning_module

# unset dataloader_idx in model
model._current_dataloader_idx = None

# with a single dataloader don't pass a 2D list
output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs
)

# call the model epoch end
model = self.trainer.lightning_module
if self.trainer.testing:
if is_overridden("test_epoch_end", model):
model._current_fx_name = "test_epoch_end"
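Illustrative note (not part of the PR): the dataloader loop now hands the epoch loop a real index only when several dataloaders are evaluated, and None otherwise. Below is a minimal sketch of that convention with a made-up helper name, assuming nothing beyond the `dataloader_idx if self.num_dataloaders > 1 else None` expression above.

from typing import Optional


def resolve_dataloader_idx(current_idx: int, num_dataloaders: int) -> Optional[int]:
    # Mirrors the expression passed to `epoch_loop.run` above: a single
    # dataloader never surfaces an index downstream.
    return current_idx if num_dataloaders > 1 else None


assert resolve_dataloader_idx(0, num_dataloaders=1) is None
assert resolve_dataloader_idx(2, num_dataloaders=3) == 2
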
59 changes: 23 additions & 36 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -48,7 +48,6 @@ def __init__(self) -> None:

self._outputs: EPOCH_OUTPUT = []
self._dl_max_batches = 0
self._num_dataloaders = 0
self._dataloader_iter: Optional[Iterator] = None
self._data_fetcher: Optional[AbstractDataFetcher] = None
self._dataloader_state_dict: Dict[str, Any] = {}
@@ -61,7 +60,6 @@ def done(self) -> bool:
def reset(self) -> None:
"""Resets the loop's internal state."""
self._dl_max_batches = 0
self._num_dataloaders = 0
self._data_fetcher = None
self._outputs = []

@@ -71,39 +69,36 @@ def reset(self) -> None:
self.batch_progress.reset_on_restart()

def on_run_start( # type: ignore[override]
self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
) -> None:
"""Adds the passed arguments to the loop's state if necessary.

Args:
data_fetcher: the current data_fetcher wrapping the dataloader
dataloader_idx: index of the current dataloader
dl_max_batches: maximum number of batches the dataloader can produce
num_dataloaders: the total number of dataloaders
"""
void(dataloader_idx)
self._dl_max_batches = dl_max_batches
self._num_dataloaders = num_dataloaders
self._data_fetcher = data_fetcher

self._reload_dataloader_state_dict(data_fetcher)
self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_progress.current.ready)

def advance( # type: ignore[override]
self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
self, data_fetcher: AbstractDataFetcher, dataloader_idx: Optional[int], dl_max_batches: int
) -> None:
"""Calls the evaluation step with the corresponding hooks and updates the logger connector.

Args:
data_fetcher: iterator over the dataloader
dataloader_idx: index of the current dataloader
dl_max_batches: maximum number of batches the dataloader can produce
num_dataloaders: the total number of dataloaders

Raises:
StopIteration: If the current batch is None
"""
void(dl_max_batches, num_dataloaders)
void(dl_max_batches)

assert self._dataloader_iter is not None
batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
@@ -113,24 +108,27 @@ def advance( # type: ignore[override]

if not data_fetcher.store_on_device:
with self.trainer.profiler.profile("evaluation_batch_to_device"):
batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=dataloader_idx)
batch = self.trainer.training_type_plugin.batch_to_device(batch, dataloader_idx=(dataloader_idx or 0))

self.batch_progress.increment_ready()

# configure step_kwargs
kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

# hook
self._on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
self._on_evaluation_batch_start(**kwargs)

self.batch_progress.increment_started()

# lightning module methods
with self.trainer.profiler.profile("evaluation_step_and_end"):
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
output = self._evaluation_step(**kwargs)
output = self._evaluation_step_end(output)

self.batch_progress.increment_processed()

# track loss history
self._on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)
self._on_evaluation_batch_end(output, **kwargs)

self.batch_progress.increment_completed()

@@ -208,7 +206,7 @@ def _num_completed_batches_reached(self) -> bool:
def _has_completed(self) -> bool:
return self.batch_progress.current.ready == self.batch_progress.current.completed

def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[STEP_OUTPUT]:
def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
"""The evaluation step (validation_step or test_step depending on the trainer's state).

Args:
@@ -219,17 +217,14 @@ def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> O
Returns:
the outputs of the step
"""
# configure step_kwargs
step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

if self.trainer.testing:
self.trainer.lightning_module._current_fx_name = "test_step"
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator.test_step(*step_kwargs.values())
output = self.trainer.accelerator.test_step(*kwargs.values())
else:
self.trainer.lightning_module._current_fx_name = "validation_step"
with self.trainer.profiler.profile("validation_step"):
output = self.trainer.accelerator.validation_step(*step_kwargs.values())
output = self.trainer.accelerator.validation_step(*kwargs.values())

return output

@@ -239,7 +234,7 @@ def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPU
output = self.trainer.call_hook(hook_name, *args, **kwargs)
return output

def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
"""Calls the ``on_{validation/test}_batch_start`` hook.

Args:
@@ -250,19 +245,15 @@ def _on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
Raises:
AssertionError: If the number of dataloaders is None (has not yet been set).
"""
self.trainer.logger_connector.on_batch_start(batch_idx, batch)

assert self._num_dataloaders is not None
self.trainer.logger_connector.on_evaluation_batch_start(dataloader_idx, self._num_dataloaders)
self.trainer.logger_connector.on_batch_start(**kwargs)

kwargs.setdefault("dataloader_idx", 0) # TODO: the argument should be keyword for these
if self.trainer.testing:
self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
self.trainer.call_hook("on_test_batch_start", *kwargs.values())
else:
self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)
self.trainer.call_hook("on_validation_batch_start", *kwargs.values())

def _on_evaluation_batch_end(
self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
) -> None:
def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
"""The ``on_{validation/test}_batch_end`` hook.

Args:
@@ -271,12 +262,13 @@ def _on_evaluation_batch_end(
batch_idx: The index of the current batch
dataloader_idx: Index of the dataloader producing the current batch
"""
kwargs.setdefault("dataloader_idx", 0) # TODO: the argument should be keyword for these
hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)
self.trainer.call_hook(hook_name, output, *kwargs.values())

self.trainer.logger_connector.on_batch_end()

def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]:
def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Dict[str, Union[Any, int]]:
"""Helper function to build the arguments for the current step.

Args:
@@ -289,13 +281,8 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
"""
# make dataloader_idx arg in validation_step optional
step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])

multiple_val_loaders = not self.trainer.testing and self._num_dataloaders > 1
multiple_test_loaders = self.trainer.testing and self._num_dataloaders > 1

if multiple_test_loaders or multiple_val_loaders:
if dataloader_idx is not None:
step_kwargs["dataloader_idx"] = dataloader_idx

return step_kwargs

@lru_cache(1)
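Illustrative note (not part of the PR): the epoch loop now builds the step arguments once and reuses them for `_evaluation_step` and the batch hooks, adding `dataloader_idx` only when it is set. A rough standalone sketch of that convention, assuming nothing beyond the `_build_kwargs` hunk above:

from collections import OrderedDict
from typing import Any, Dict, Optional


def build_step_kwargs(batch: Any, batch_idx: int, dataloader_idx: Optional[int]) -> Dict[str, Any]:
    # `dataloader_idx` is appended only when present, which is what keeps the
    # argument optional in `validation_step`/`test_step` for single-dataloader runs.
    step_kwargs: Dict[str, Any] = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
    if dataloader_idx is not None:
        step_kwargs["dataloader_idx"] = dataloader_idx
    return step_kwargs


# single dataloader: the step is effectively called as validation_step(batch, batch_idx)
assert list(build_step_kwargs("batch", 3, None)) == ["batch", "batch_idx"]
# multiple dataloaders: validation_step(batch, batch_idx, dataloader_idx)
assert list(build_step_kwargs("batch", 3, 1)) == ["batch", "batch_idx", "dataloader_idx"]
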
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -158,7 +158,7 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: # type: ignore[ov

self.batch_progress.increment_ready()

self.trainer.logger_connector.on_batch_start(batch_idx, batch)
self.trainer.logger_connector.on_batch_start(batch, batch_idx)

if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -139,17 +139,12 @@ def _increment_eval_log_step(self) -> None:
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1

def on_evaluation_batch_start(self, dataloader_idx: int, num_dataloaders: int) -> None:
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

def update_eval_step_metrics(self) -> None:
assert not self._epoch_end_reached
if self.trainer.sanity_checking:
return

# logs user requested information to logger
assert not self._epoch_end_reached
self.log_metrics(self.metrics["log"], step=self._eval_log_step)

# increment the step even if nothing was logged
@@ -259,23 +254,29 @@ def _log_gpus_metrics(self) -> None:
def on_epoch_start(self) -> None:
self._epoch_end_reached = False

def on_batch_start(self, batch_idx: int, batch: Any) -> None:
def on_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> None:
self._batch_idx = batch_idx
self._epoch_end_reached = False

assert self.trainer._results is not None
results = self.trainer._results
assert results is not None
# attach reference to the new batch and remove the cached batch_size
self.trainer._results.batch = batch
self.trainer._results.batch_size = None
results.batch = batch
results.batch_size = None
results.dataloader_idx = dataloader_idx

def epoch_end_reached(self) -> None:
self._epoch_end_reached = True
self._batch_idx = None
self._split_idx = None
assert self.trainer._results is not None

def on_epoch_end(self) -> None:
assert self._epoch_end_reached
results = self.trainer._results
assert results is not None
# we need to reset this index before the `self.metrics` call below
results.dataloader_idx = None

metrics = self.metrics
self._progress_bar_metrics.update(metrics["pbar"])
self._callback_metrics.update(metrics["callback"])
@@ -308,8 +309,9 @@ def reset_metrics(self) -> None:
self._callback_metrics = {}

def reset_results(self) -> None:
if self.trainer._results is not None:
self.trainer._results.reset()
results = self.trainer._results
if results is not None:
results.reset()

self._batch_idx = None
self._split_idx = None
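Illustrative note (not part of the PR): the logger connector now owns the per-batch bookkeeping that previously lived on the LightningModule: `on_batch_start` attaches the batch, clears the cached batch size and records the dataloader index on the results object, and `on_epoch_end` resets that index before epoch metrics are aggregated. A simplified lifecycle sketch using stand-in classes (not the real connector or results objects):

from typing import Any, Optional


class _ResultsStub:
    def __init__(self) -> None:
        self.batch: Optional[Any] = None
        self.batch_size: Optional[int] = None
        self.dataloader_idx: Optional[int] = None


class _ConnectorSketch:
    def __init__(self) -> None:
        self.results = _ResultsStub()

    def on_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> None:
        # attach the new batch and forget the cached batch size, as in the diff above
        self.results.batch = batch
        self.results.batch_size = None
        self.results.dataloader_idx = dataloader_idx

    def on_epoch_end(self) -> None:
        # reset the index before aggregating epoch metrics, as in the diff above
        self.results.dataloader_idx = None


connector = _ConnectorSketch()
connector.on_batch_start(batch={"x": 1}, batch_idx=0, dataloader_idx=1)
assert connector.results.dataloader_idx == 1
connector.on_epoch_end()
assert connector.results.dataloader_idx is None
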
pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -393,6 +393,7 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] =
self.device: Optional[Union[str, torch.device]] = device
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.dataloader_idx: Optional[int] = None

@property
def result_metrics(self) -> List[ResultMetric]:
@@ -436,7 +437,6 @@ def log(
sync_dist_fn: Callable = _Sync.no_op,
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
dataloader_idx: Optional[int] = None,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
rank_zero_only: bool = False,
@@ -453,9 +453,9 @@
# storage key
key = f"{fx}.{name}"
# add dataloader_suffix to both key and fx
if add_dataloader_idx and dataloader_idx is not None:
key += f".{dataloader_idx}"
fx += f".{dataloader_idx}"
if add_dataloader_idx and self.dataloader_idx is not None:
key += f".{self.dataloader_idx}"
fx += f".{self.dataloader_idx}"

meta = _Metadata(
fx=fx,
@@ -467,7 +467,7 @@
reduce_fx=reduce_fx,
enable_graph=enable_graph,
add_dataloader_idx=add_dataloader_idx,
dataloader_idx=dataloader_idx,
dataloader_idx=self.dataloader_idx,
metric_attribute=metric_attribute,
)
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)
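Illustrative note (not part of the PR): with `ResultCollection` carrying `dataloader_idx` itself, `log()` no longer needs the index passed in and derives the storage-key suffix from its own state. A minimal sketch of that suffixing rule, assuming only the key format visible in the hunk above and using a plain dict in place of the real collection:

from typing import Any, Dict, Optional


class ResultCollectionSketch:
    def __init__(self) -> None:
        self.dataloader_idx: Optional[int] = None  # set per batch by the logger connector
        self.items: Dict[str, Any] = {}

    def log(self, fx: str, name: str, value: Any, add_dataloader_idx: bool = True) -> None:
        key = f"{fx}.{name}"
        if add_dataloader_idx and self.dataloader_idx is not None:
            key += f".{self.dataloader_idx}"  # same suffixing rule as in the hunk above
        self.items[key] = value


results = ResultCollectionSketch()
results.dataloader_idx = 1  # what the loop/logger connector now does at batch start
results.log("validation_step", "val_loss", 0.25)
assert "validation_step.val_loss.1" in results.items
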
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1430,7 +1430,6 @@ def _call_teardown_hook(self) -> None:
self.call_hook("teardown", stage=fn)

self.lightning_module._current_fx_name = None
self.lightning_module._current_dataloader_idx = None
# these could have become stale if metrics are defined in `setup`
self.lightning_module._metric_attributes = None

4 changes: 4 additions & 0 deletions tests/loops/test_loop_state_dict.py
@@ -81,13 +81,15 @@ def test_loops_state_dict_structure():
"epoch_loop.val_loop._results": {
"batch": None,
"batch_size": None,
"dataloader_idx": None,
"training": False,
"device": None,
"items": {},
},
"epoch_loop._results": {
"batch": None,
"batch_size": None,
"dataloader_idx": None,
"training": True,
"device": None,
"items": {},
@@ -109,6 +111,7 @@ def test_loops_state_dict_structure():
"_results": {
"batch": None,
"batch_size": None,
"dataloader_idx": None,
"training": False,
"device": None,
"items": {},
@@ -126,6 +129,7 @@ def test_loops_state_dict_structure():
"_results": {
"batch": None,
"batch_size": None,
"dataloader_idx": None,
"training": False,
"device": None,
"items": {},