diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 2189063ee20ef..b3041a72cac3b 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -27,10 +27,10 @@ class LightningEnum(str, Enum): @classmethod def from_str(cls, value: str) -> LightningEnum | None: - statuses = [status for status in dir(cls) if not status.startswith("_")] + statuses = cls.__members__.keys() for st in statuses: if st.lower() == value.lower(): - return getattr(cls, st) + return cls[st] return None def __eq__(self, other: object) -> bool: @@ -43,21 +43,21 @@ def __hash__(self) -> int: return hash(self.value.lower()) -class _OnAccessEnumMeta(EnumMeta): - """Enum with a hook to run a function whenever a member is accessed. +class _DeprecatedEnumMeta(EnumMeta): + """Enum that calls `deprecate()` whenever a member is accessed. - Adapted from: - https://www.buzzphp.com/posts/how-do-i-detect-and-invoke-a-function-when-a-python-enum-member-is-accessed + Adapted from: https://stackoverflow.com/a/62309159/208880 """ def __getattribute__(cls, name: str) -> Any: obj = super().__getattribute__(name) - if isinstance(obj, Enum): + # ignore __dunder__ names -- prevents potential recursion errors + if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum): obj.deprecate() return obj def __getitem__(cls, name: str) -> Any: - member: _OnAccessEnumMeta = super().__getitem__(name) + member: _DeprecatedEnumMeta = super().__getitem__(name) member.deprecate() return member @@ -68,6 +68,12 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: return obj +class _DeprecatedEnum(LightningEnum, metaclass=_DeprecatedEnumMeta): + """_DeprecatedEnum calls an enum's `deprecate()` method on member access.""" + + pass + + class AMPType(LightningEnum): """Type of Automatic Mixed Precission used for training. @@ -104,7 +110,7 @@ def supported_types() -> list[str]: return [x.value for x in PrecisionType] -class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta): +class DistributedType(_DeprecatedEnum): """Define type of training strategy. Deprecated since v1.6.0 and will be removed in v1.8.0. @@ -145,7 +151,7 @@ def deprecate(self) -> None: ) -class DeviceType(LightningEnum, metaclass=_OnAccessEnumMeta): +class DeviceType(_DeprecatedEnum): """Define Device type by its nature - accelerators. Deprecated since v1.6.0 and will be removed in v1.8.0.