From ad44300988996aa37b1cd202bf767e256835f0db Mon Sep 17 00:00:00 2001
From: Shion
Date: Wed, 4 Oct 2023 21:29:52 +0900
Subject: [PATCH 1/2] use new update called

---
 .../pytorch/trainer/connectors/logger_connector/result.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index a097ae93d8038..f28a41f527dd1 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -265,7 +265,7 @@ def _wrap_compute(self, compute: Any) -> Any:
         # Override to avoid syncing - we handle it ourselves.
         @wraps(compute)
         def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
-            if not self._update_called:
+            if not self.update_called:
                 rank_zero_warn(
                     f"The ``compute`` method of metric {self.__class__.__name__}"
                     " was called before the ``update`` method which may lead to errors,"

From 74780e4a065afd13fcae14800e3d3f7b8465bf0d Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Fri, 6 Oct 2023 13:28:30 +0200
Subject: [PATCH 2/2] add the switch

---
 .../pytorch/trainer/connectors/logger_connector/result.py | 4 +++-
 src/lightning/pytorch/utilities/imports.py                | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index f28a41f527dd1..26bddea1b211e 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -26,6 +26,7 @@
 from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.utilities.data import extract_batch_size
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
 from lightning.pytorch.utilities.memory import recursive_detach
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.warnings import PossibleUserWarning
@@ -265,7 +266,8 @@ def _wrap_compute(self, compute: Any) -> Any:
         # Override to avoid syncing - we handle it ourselves.
         @wraps(compute)
         def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
-            if not self.update_called:
+            update_called = self.update_called if _TORCHMETRICS_GREATER_EQUAL_1_0_0 else self._update_called
+            if not update_called:
                 rank_zero_warn(
                     f"The ``compute`` method of metric {self.__class__.__name__}"
                     " was called before the ``update`` method which may lead to errors,"
diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py
index 766eb9956a619..159b0b7758644 100644
--- a/src/lightning/pytorch/utilities/imports.py
+++ b/src/lightning/pytorch/utilities/imports.py
@@ -21,6 +21,7 @@
 _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
 _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
+_TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")