4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -219,6 +219,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))


+- Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))
+
+
 - Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839))

@@ -231,7 +234,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))



--


3 changes: 2 additions & 1 deletion pytorch_lightning/core/lightning.py
@@ -418,7 +418,8 @@ def log(
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
-            dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
+            add_dataloader_idx=add_dataloader_idx,
+            dataloader_idx=self._current_dataloader_idx,
             batch_size=batch_size,
             sync_dist=sync_dist and distributed_available(),
             sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
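For context, `add_dataloader_idx` is the flag already exposed on `LightningModule.log`; this hunk now forwards it alongside the raw `dataloader_idx` instead of collapsing both into one argument. A minimal sketch of the user-facing behavior, with the model and metric names invented for illustration:

```python
import torch
from pytorch_lightning import LightningModule


class MultiValLoaderModel(LightningModule):  # illustrative name
    def validation_step(self, batch, batch_idx, dataloader_idx):
        loss = torch.tensor(0.0)  # placeholder value
        # Default add_dataloader_idx=True: the key is suffixed per dataloader,
        # e.g. "val_loss/dataloader_idx_0", "val_loss/dataloader_idx_1".
        self.log("val_loss", loss)
        # add_dataloader_idx=False: no suffix is appended, so the key itself
        # must already be unique across dataloaders.
        self.log(f"val_loss_dl{dataloader_idx}", loss, add_dataloader_idx=False)
```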
21 changes: 10 additions & 11 deletions pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -155,21 +155,20 @@ def update_eval_step_metrics(self) -> None:
         # increment the step even if nothing was logged
         self._increment_eval_log_step()

-    @staticmethod
-    def _filter_metrics_for_dataloader(
-        dl_idx: int, metrics: _OUT_DICT, metric_prefix: str = "dataloader_idx"
-    ) -> _OUT_DICT:
-        return {k: v for k, v in metrics.items() if metric_prefix not in k or k.endswith(f"{metric_prefix}_{dl_idx}")}
-
-    def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
+    def _prepare_eval_loop_results(self) -> None:
         if self.trainer.sanity_checking:
             return

+        on_step = not self._epoch_end_reached
         num_dataloaders = self.trainer._evaluation_loop.num_dataloaders
         has_been_initialized = len(self.eval_loop_results) == num_dataloaders
-        for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
-            # remove callback metrics that don't belong to this dataloader
-            callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
+        assert self.trainer._evaluation_loop._results is not None
+        for dl_idx in range(num_dataloaders):
+            metrics = self.trainer._evaluation_loop._results.metrics(
+                on_step, dataloader_idx=dl_idx if num_dataloaders > 1 else None
+            )
+            callback_metrics = metrics["callback"]
+
             if has_been_initialized:
                 self.eval_loop_results[dl_idx].update(callback_metrics)
             else:
@@ -183,7 +182,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
         # log all the metrics as a single dict
         self.log_metrics(metrics["log"])

-        self._prepare_eval_loop_results(metrics["callback"])
+        self._prepare_eval_loop_results()

         # log results of evaluation
         if (
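The user-visible effect of this connector change is in the list returned by `trainer.validate`/`trainer.test`: each dict now only contains the metrics of its own dataloader. A rough sketch, assuming the repo's `BoringModel`/`RandomDataset` test helpers and an invented metric name:

```python
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from tests.helpers import BoringModel, RandomDataset


class TwoValLoaderModel(BoringModel):  # illustrative subclass
    def validation_step(self, batch, batch_idx, dataloader_idx):
        self.log("x", float(dataloader_idx))

    def val_dataloader(self):
        return [DataLoader(RandomDataset(32, 64)) for _ in range(2)]


model = TwoValLoaderModel()
model.validation_epoch_end = None  # disable the inherited epoch-end hook
trainer = Trainer(fast_dev_run=1)
results = trainer.validate(model)
# Before this PR, both dicts carried both suffixed keys; now each entry only
# holds its own dataloader's metrics, e.g.:
#   results[0] == {"x/dataloader_idx_0": 0.0}
#   results[1] == {"x/dataloader_idx_1": 1.0}
```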
20 changes: 14 additions & 6 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -113,6 +113,7 @@ class _Metadata:
     on_epoch: bool = True
     reduce_fx: Callable = torch.mean
     enable_graph: bool = False
+    add_dataloader_idx: bool = True
     dataloader_idx: Optional[int] = None
     metric_attribute: Optional[str] = None
     _sync: Optional[_Sync] = None
@@ -434,6 +435,7 @@ def log(
         sync_dist: bool = False,
         sync_dist_fn: Callable = _Sync.no_op,
         sync_dist_group: Optional[Any] = None,
+        add_dataloader_idx: bool = True,
         dataloader_idx: Optional[int] = None,
         batch_size: Optional[int] = None,
         metric_attribute: Optional[str] = None,
@@ -451,7 +453,7 @@
         # storage key
         key = f"{fx}.{name}"
         # add dataloader_suffix to both key and fx
-        if dataloader_idx is not None:
+        if add_dataloader_idx and dataloader_idx is not None:
             key += f".{dataloader_idx}"
             fx += f".{dataloader_idx}"

@@ -464,6 +466,7 @@
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
+            add_dataloader_idx=add_dataloader_idx,
             dataloader_idx=dataloader_idx,
             metric_attribute=metric_attribute,
         )
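To make the guarded suffixing concrete, here is a toy re-implementation of the key-building rule above (a standalone sketch, not the real `ResultCollection` code):

```python
from typing import Optional


def storage_key(fx: str, name: str, add_dataloader_idx: bool, dataloader_idx: Optional[int] = None) -> str:
    # Same rule as above: only suffix when the flag is set and an index exists.
    key = f"{fx}.{name}"
    if add_dataloader_idx and dataloader_idx is not None:
        key += f".{dataloader_idx}"
    return key


assert storage_key("validation_step", "acc", True, 1) == "validation_step.acc.1"
assert storage_key("validation_step", "acc", False, 1) == "validation_step.acc"
assert storage_key("validation_step", "acc", True) == "validation_step.acc"
```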
@@ -522,24 +525,29 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
                 return cache.detach()
         return cache

-    def valid_items(self) -> Generator:
+    def valid_items(self, dataloader_idx: Optional[int] = None) -> Generator:
         """This function is used to iterate over current valid metrics."""
-        return ((k, v) for k, v in self.items() if not (isinstance(v, ResultMetric) and v.has_reset))
+        return (
+            (k, v)
+            for k, v in self.items()
+            if not (isinstance(v, ResultMetric) and v.has_reset) and (dataloader_idx in (None, v.meta.dataloader_idx))
+        )

     def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
         name = result_metric.meta.name
         forked_name = result_metric.meta.forked_name(on_step)
+        add_dataloader_idx = result_metric.meta.add_dataloader_idx
         dl_idx = result_metric.meta.dataloader_idx
-        if dl_idx is not None:
+        if add_dataloader_idx and dl_idx is not None:
             dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx)
             name += dataloader_suffix
             forked_name += dataloader_suffix
         return name, forked_name

-    def metrics(self, on_step: bool) -> _METRICS:
+    def metrics(self, on_step: bool, dataloader_idx: Optional[int] = None) -> _METRICS:
         metrics = _METRICS(callback={}, log={}, pbar={})

-        for _, result_metric in self.valid_items():
+        for _, result_metric in self.valid_items(dataloader_idx):

             # extract forward_cache or computed from the ResultMetric. ignore when the output is None
             value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
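The `dataloader_idx in (None, v.meta.dataloader_idx)` test in `valid_items` reads tersely; a self-contained sketch of the same filtering idea, with stand-in classes rather than the real `ResultMetric`:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeMeta:  # stand-in for _Metadata
    dataloader_idx: Optional[int] = None


items = {
    "validation_step.loss.0": FakeMeta(dataloader_idx=0),
    "validation_step.loss.1": FakeMeta(dataloader_idx=1),
}


def valid_keys(dataloader_idx: Optional[int] = None) -> List[str]:
    # None requests everything; an int keeps only that dataloader's entries.
    return [k for k, meta in items.items() if dataloader_idx in (None, meta.dataloader_idx)]


assert valid_keys() == ["validation_step.loss.0", "validation_step.loss.1"]
assert valid_keys(1) == ["validation_step.loss.1"]
```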
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1364,7 +1364,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
" The best model of the previous `fit` call will be used."
f" You can pass `{fn}(ckpt_path='best')` to use and best model"
" checkpoint and avoid this warning or"
" `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
" `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model."
)
ckpt_path = "best"

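This fix replaces a non-existent attribute: `Trainer` exposes its `ModelCheckpoint` instance as `trainer.checkpoint_callback`, not `trainer.model_checkpoint`. A hedged usage sketch of the path the corrected message suggests, reusing the repo's `BoringModel` test helper:

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)

# Pass the last checkpoint explicitly instead of relying on ckpt_path="best".
trainer.test(model, ckpt_path=trainer.checkpoint_callback.last_model_path)
```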
2 changes: 1 addition & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
@@ -132,7 +132,7 @@ def on_predict_start(self) -> None:
         assert isinstance(self.trainer.model, LightningModule)


-@RunIf(skip_windows=True, skip_49370=True)
+@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
 def test_ddp_spawn_configure_ddp(tmpdir):
     """Tests with ddp spawn plugin."""
     trainer = Trainer(default_root_dir=tmpdir, num_processes=2, strategy="ddp_spawn", fast_dev_run=True)
69 changes: 42 additions & 27 deletions tests/trainer/logging_/test_eval_loop_logging.py
@@ -23,7 +23,6 @@

 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -676,32 +675,6 @@ def val_dataloader(self):
     trainer.fit(model)

-
-@pytest.mark.parametrize(
-    ["kwargs", "expected"],
-    [
-        ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
-        (
-            {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
-            {"acc/dataloader_idx_0": 123},
-        ),
-        (
-            {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
-            {"acc/dataloader_idx_10": 321},
-        ),
-        (
-            {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
-            {"top_3_acc/dataloader_idx_3": 321},
-        ),
-        # theoretical case, as `/dataloader_idx_3` would have been added
-        ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
-    ],
-)
-def test_filter_metrics_for_dataloader(kwargs, expected):
-    """Logged metrics should only include metrics from the concerned dataloader."""
-    actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
-    assert actual == expected
-

 @RunIf(min_gpus=1)
 def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir):
     class TestModel(BoringModel):
@@ -723,3 +696,45 @@ def validation_epoch_end(self, outputs):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, gpus=1)
     trainer.validate(model, verbose=False)
+
+
+def test_logging_results_with_no_dataloader_idx(tmpdir):
+    num_dataloaders = 2
+    log_common_same_val = {"test_log_common": 789}
+    log_common_diff_val = "test_log_common_diff_value"
+    log_key_no_dl_idx = "test_log_no_dl_idx_{}"
+    log_key_dl0 = {"test_log_a_class": 123}
+    log_key_dl1 = {"test_log_b_class": 456}
+
+    class CustomBoringModel(BoringModel):
+        def test_step(self, batch, batch_idx, dataloader_idx):
+            self.log_dict(log_common_same_val)
+            self.log(log_common_diff_val, dataloader_idx + 1)
+            self.log(
+                log_key_no_dl_idx.format(dataloader_idx),
+                321 * (dataloader_idx + 1),
+                add_dataloader_idx=False,
+            )
+            self.log_dict(log_key_dl0 if dataloader_idx == 0 else log_key_dl1, add_dataloader_idx=False)
+
+        def test_dataloader(self):
+            return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]
+
+    model = CustomBoringModel()
+    model.test_epoch_end = None
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    results = trainer.test(model)
+
+    assert len(results) == num_dataloaders
+    assert results[0] == {
+        "test_log_common/dataloader_idx_0": 789.0,
+        "test_log_common_diff_value/dataloader_idx_0": 1.0,
+        "test_log_no_dl_idx_0": 321,
+        "test_log_a_class": 123.0,
+    }
+    assert results[1] == {
+        "test_log_common/dataloader_idx_1": 789.0,
+        "test_log_common_diff_value/dataloader_idx_1": 2.0,
+        "test_log_no_dl_idx_1": 321 * 2,
+        "test_log_b_class": 456.0,
+    }