Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class LightningEnum(str, Enum):

@classmethod
def from_str(cls, value: str) -> LightningEnum | None:
statuses = [status for status in dir(cls) if not status.startswith("_")]
statuses = cls.__members__.keys()
for st in statuses:
if st.lower() == value.lower():
return getattr(cls, st)
return cls[st]
return None

def __eq__(self, other: object) -> bool:
Expand All @@ -43,21 +43,21 @@ def __hash__(self) -> int:
return hash(self.value.lower())


class _OnAccessEnumMeta(EnumMeta):
"""Enum with a hook to run a function whenever a member is accessed.
class _DeprecatedEnumMeta(EnumMeta):
"""Enum that calls `deprecate()` whenever a member is accessed.

Adapted from:
https://www.buzzphp.com/posts/how-do-i-detect-and-invoke-a-function-when-a-python-enum-member-is-accessed
Adapted from: https://stackoverflow.com/a/62309159/208880
"""

def __getattribute__(cls, name: str) -> Any:
obj = super().__getattribute__(name)
if isinstance(obj, Enum):
# ignore __dunder__ names -- prevents potential recursion errors
if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
obj.deprecate()
return obj

def __getitem__(cls, name: str) -> Any:
member: _OnAccessEnumMeta = super().__getitem__(name)
member: _DeprecatedEnumMeta = super().__getitem__(name)
member.deprecate()
return member

Expand All @@ -68,6 +68,12 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
return obj


class _DeprecatedEnum(LightningEnum, metaclass=_DeprecatedEnumMeta):
"""_DeprecatedEnum calls an enum's `deprecate()` method on member access."""

pass


class AMPType(LightningEnum):
"""Type of Automatic Mixed Precission used for training.

Expand Down Expand Up @@ -104,7 +110,7 @@ def supported_types() -> list[str]:
return [x.value for x in PrecisionType]


class DistributedType(LightningEnum, metaclass=_OnAccessEnumMeta):
class DistributedType(_DeprecatedEnum):
"""Define type of training strategy.

Deprecated since v1.6.0 and will be removed in v1.8.0.
Expand Down Expand Up @@ -145,7 +151,7 @@ def deprecate(self) -> None:
)


class DeviceType(LightningEnum, metaclass=_OnAccessEnumMeta):
class DeviceType(_DeprecatedEnum):
"""Define Device type by its nature - accelerators.

Deprecated since v1.6.0 and will be removed in v1.8.0.
Expand Down