Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor StrEnum.from_str #102

Merged
merged 12 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed `StrEnum.from_str` with source as key ([#99](https://github.com/Lightning-AI/utilities/pull/99))
- Fixed `StrEnum.from_str` with source as key (
[#99](https://github.com/Lightning-AI/utilities/pull/99),
[#102](https://github.com/Lightning-AI/utilities/pull/102)
)


## [0.6.0] - 2023-01-23
Expand Down
54 changes: 23 additions & 31 deletions src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
import warnings
from enum import Enum
from typing import Optional
from typing import List, Optional

from typing_extensions import Literal

Expand All @@ -19,46 +19,30 @@ class StrEnum(str, Enum):
True
>>> MySE.from_str("t-2", source="value") == MySE.t2
True
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any")
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
"""

@classmethod
def from_str(
cls, value: str, source: Literal["key", "value", "any"] = "key", strict: bool = False
) -> Optional["StrEnum"]:
"""Create StrEnum from a sting matching the key or value.
Borda marked this conversation as resolved.
Show resolved Hide resolved
def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum":
Borda marked this conversation as resolved.
Show resolved Hide resolved
"""Create ``StrEnum`` from a string matching the key or value.

Args:
value: matching string
source: compare with:

- ``"key"``: validates only with Enum keys, typical alphanumeric with "_"
- ``"value"``: validates only with Enum values, could be any string
- ``"key"``: validates with any key or value, but key has priority
Borda marked this conversation as resolved.
Show resolved Hide resolved

strict: allow not matching string and returns None; if false raises exceptions
- ``"key"``: validates only from the enum keys, typical alphanumeric with "_"
Borda marked this conversation as resolved.
Show resolved Hide resolved
- ``"value"``: validates only from the values, could be any string
- ``"any"``: validates with any key or value, but key has priority

Raises:
ValueError:
if requested string does not match any option based on selected source and use ``"strict=True"``
UserWarning:
if requested string does not match any option based on selected source and use ``"strict=False"``

Example:
>>> class MySE(StrEnum):
... t1 = "T-1"
... t2 = "T-2"
>>> MySE.from_str("t-1", source="key")
>>> MySE.from_str("t-2", source="value")
Borda marked this conversation as resolved.
Show resolved Hide resolved
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any", strict=True)
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
if requested string does not match any option based on selected source.
"""
allowed = cls._allowed_matches(source)
if strict and not any(enum_.lower() == value.lower() for enum_ in allowed):
raise ValueError(f"Invalid match: expected one of {allowed}, but got {value}.")

if source in ("key", "any"):
for enum_key in cls.__members__.keys():
if enum_key.lower() == value.lower():
Expand All @@ -67,12 +51,20 @@ def from_str(
for enum_key, enum_val in cls.__members__.items():
if enum_val == value:
return cls[enum_key]
raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.")

warnings.warn(UserWarning(f"Invalid string: expected one of {allowed}, but got {value}."))
@classmethod
def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
try:
return cls.from_str(value, source)
except ValueError:
warnings.warn(
UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.")
)
return None

@classmethod
def _allowed_matches(cls, source: str) -> list:
def _allowed_matches(cls, source: str) -> List[str]:
keys, vals = [], []
for enum_key, enum_val in cls.__members__.items():
keys.append(enum_key)
Expand Down
Empty file removed tests/unittests/core/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/unittests/core/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, ClassVar, List, Optional

import pytest
from unittests.core.mocks import torch
from unittests.mocks import torch

from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/core/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class MyEnum(StrEnum):
T2 = "t:2"

assert MyEnum.from_str("T1", source="key")
assert MyEnum.from_str("T1", source="value") is None
assert MyEnum.try_from_str("T1", source="value") is None
assert MyEnum.from_str("T1", source="any")

assert MyEnum.from_str("T:2", source="key") is None
assert MyEnum.try_from_str("T:2", source="key") is None
assert MyEnum.from_str("T:2", source="value")
assert MyEnum.from_str("T:2", source="any")
File renamed without changes.