diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index a097ae93d8038..26bddea1b211e 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -26,6 +26,7 @@ from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0 from lightning.pytorch.utilities.data import extract_batch_size from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0 from lightning.pytorch.utilities.memory import recursive_detach from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn from lightning.pytorch.utilities.warnings import PossibleUserWarning @@ -265,7 +266,8 @@ def _wrap_compute(self, compute: Any) -> Any: # Override to avoid syncing - we handle it ourselves. @wraps(compute) def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]: - if not self._update_called: + update_called = self.update_called if _TORCHMETRICS_GREATER_EQUAL_1_0_0 else self._update_called + if not update_called: rank_zero_warn( f"The ``compute`` method of metric {self.__class__.__name__}" " was called before the ``update`` method which may lead to errors," diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index 766eb9956a619..159b0b7758644 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -21,6 +21,7 @@ _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11) _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task +_TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0") _OMEGACONF_AVAILABLE = package_available("omegaconf") _TORCHVISION_AVAILABLE = RequirementCache("torchvision")