Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor StrEnum.from_str #102

Merged
merged 12 commits into from
Feb 14, 2023
55 changes: 21 additions & 34 deletions src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
import warnings
from enum import Enum
from typing import Optional
from typing import List, Optional

from typing_extensions import Literal

Expand All @@ -19,46 +18,29 @@ class StrEnum(str, Enum):
True
>>> MySE.from_str("t-2", source="value") == MySE.t2
True
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any")
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
"""

@classmethod
def from_str(
cls, value: str, source: Literal["key", "value", "any"] = "key", strict: bool = False
) -> Optional["StrEnum"]:
"""Create StrEnum from a sting matching the key or value.
Borda marked this conversation as resolved.
Show resolved Hide resolved
def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum":
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""Create ``StrEnum`` from a string matching the key or value.

Args:
value: matching string
source: compare with:

- ``"key"``: validates only with Enum keys, typical alphanumeric with "_"
- ``"value"``: validates only with Enum values, could be any string
- ``"key"``: validates with any key or value, but key has priority
Borda marked this conversation as resolved.
Show resolved Hide resolved

strict: allow not matching string and returns None; if false raises exceptions
- ``"key"``: validates only from the enum keys, typical alphanumeric with "_"
Borda marked this conversation as resolved.
Show resolved Hide resolved
- ``"value"``: validates only from the values, could be any string
- ``"any"``: validates with any key or value, but key has priority

Raises:
ValueError:
if requested string does not match any option based on selected source and use ``"strict=True"``
UserWarning:
if requested string does not match any option based on selected source and use ``"strict=False"``

Example:
>>> class MySE(StrEnum):
... t1 = "T-1"
... t2 = "T-2"
>>> MySE.from_str("t-1", source="key")
>>> MySE.from_str("t-2", source="value")
Borda marked this conversation as resolved.
Show resolved Hide resolved
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any", strict=True)
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
if requested string does not match any option based on selected source.
"""
allowed = cls._allowed_matches(source)
if strict and not any(enum_.lower() == value.lower() for enum_ in allowed):
raise ValueError(f"Invalid match: expected one of {allowed}, but got {value}.")

if source in ("key", "any"):
for enum_key in cls.__members__.keys():
if enum_key.lower() == value.lower():
Expand All @@ -67,12 +49,17 @@ def from_str(
for enum_key, enum_val in cls.__members__.items():
if enum_val == value:
return cls[enum_key]
raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.")

warnings.warn(UserWarning(f"Invalid string: expected one of {allowed}, but got {value}."))
return None
@classmethod
def maybe_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
Borda marked this conversation as resolved.
Show resolved Hide resolved
Borda marked this conversation as resolved.
Show resolved Hide resolved
try:
return cls.from_str(value, source)
except ValueError:
return
carmocca marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def _allowed_matches(cls, source: str) -> list:
def _allowed_matches(cls, source: str) -> List[str]:
keys, vals = [], []
for enum_key, enum_val in cls.__members__.items():
keys.append(enum_key)
Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/core/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class MyEnum(StrEnum):
T2 = "t:2"

assert MyEnum.from_str("T1", source="key")
assert MyEnum.from_str("T1", source="value") is None
assert MyEnum.maybe_from_str("T1", source="value") is None
assert MyEnum.from_str("T1", source="any")

assert MyEnum.from_str("T:2", source="key") is None
assert MyEnum.maybe_from_str("T:2", source="key") is None
assert MyEnum.from_str("T:2", source="value")
assert MyEnum.from_str("T:2", source="any")