Skip to content

Commit

Permalink
or operate
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 23, 2023
1 parent 681d01e commit 4e531fc
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
14 changes: 12 additions & 2 deletions nepattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)

from tarina import DateParser, Empty, lang
from typing_extensions import Self
from typing_extensions import Self, Unpack, TypeVarTuple

from .core import BasePattern, MatchMode, ResultFlag, ValidateResult
from .exception import MatchFailed
Expand All @@ -30,6 +30,8 @@
TOrigin = TypeVar("TOrigin")
TDefault = TypeVar("TDefault")
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
Ts = TypeVarTuple("Ts")


class DirectPattern(BasePattern[TOrigin, TOrigin]):
Expand Down Expand Up @@ -173,7 +175,7 @@ class UnionPattern(BasePattern[Any, _T]):

__slots__ = ("base", "optional", "for_validate", "for_equal")

def __init__(self, base: Iterable[_T | BasePattern[Any, _T]]):
def __init__(self, base: Iterable[BasePattern[Any, _T] | _T]):
self.base = list(base)
self.optional = False
self.for_validate = []
Expand Down Expand Up @@ -201,6 +203,12 @@ def match(self, text: Any):
lang.require("nepattern", "content_error").format(target=text, expected=self.alias)
)
return text

@classmethod
def _(cls, *types: type[_T1]) -> UnionPattern[_T1]:
from .main import parser

return cls([parser(i) for i in types]) # type: ignore

def __calc_repr__(self):
return "|".join(repr(a) for a in (*self.for_validate, *self.for_equal))
Expand All @@ -221,6 +229,8 @@ def suffixed(self) -> Self:
+ [parser(eq).suffixed() if isinstance(eq, str) else eq for eq in self.for_equal], # type: ignore
)

def __or__(self, other: BasePattern[Any, _T1]) -> UnionPattern[Union[_T, _T1]]:
return UnionPattern([*self.base, other]) # type: ignore

TSeq = TypeVar("TSeq", list, tuple, set)

Expand Down
15 changes: 10 additions & 5 deletions nepattern/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class ResultFlag(str, Enum):

T = TypeVar("T")
TInput = TypeVar("TInput")
TInput1 = TypeVar("TInput1")
TInput2 = TypeVar("TInput2")
TInput3 = TypeVar("TInput3")
TOrigin = TypeVar("TOrigin")
TVOrigin = TypeVar("TVOrigin")
TDefault = TypeVar("TDefault")
Expand Down Expand Up @@ -348,14 +345,14 @@ def prefixed(self):
cp_self = deepcopy(self)
if self.mode in (MatchMode.REGEX_MATCH, MatchMode.REGEX_CONVERT):
cp_self.regex_pattern = re.compile(f"^{self.pattern}")
return cp_self
return cp_self # pragma: no cover

def suffixed(self):
"""让表达式能在某些场景下实现后缀匹配; 返回自身的拷贝"""
cp_self = deepcopy(self)
if self.mode in (MatchMode.REGEX_MATCH, MatchMode.REGEX_CONVERT):
cp_self.regex_pattern = re.compile(f"{self.pattern}$")
return cp_self
return cp_self # pragma: no cover

def validate(self, input_: Any, default: TDefault | Empty = Empty) -> ValidateResult[TOrigin | TDefault, ResultFlag]: # type: ignore
"""
Expand Down Expand Up @@ -386,5 +383,13 @@ def __matmul__(self, other) -> Self: # pragma: no cover
self.alias = other
return self

def __or__(self, other):
from .base import UnionPattern

if isinstance(other, BasePattern):
return UnionPattern([self, other]) # type: ignore
raise TypeError( # pragma: no cover
f"unsupported operand type(s) for |: 'BasePattern' and '{other.__class__.__name__}'"
)

__all__ = ["MatchMode", "BasePattern", "ValidateResult", "TOrigin", "ResultFlag"]
2 changes: 2 additions & 0 deletions nepattern/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ TInput1 = TypeVar("TInput1")
TInput2 = TypeVar("TInput2")
TInput3 = TypeVar("TInput3")
TOrigin = TypeVar("TOrigin")
TOrigin1 = TypeVar("TOrigin1")
TVOrigin = TypeVar("TVOrigin")
TDefault = TypeVar("TDefault")
TVRF = TypeVar("TVRF", bound=ResultFlag)
Expand Down Expand Up @@ -394,3 +395,4 @@ class BasePattern(Generic[TOrigin, TInput]):
) -> ValidateResult[T, Literal[ResultFlag.VALID]] | ValidateResult[T, Literal[ResultFlag.ERROR]]: ...
def __rmatmul__(self, other) -> Self: ...
def __matmul__(self, other) -> Self: ...
def __or__(self, other: BasePattern[TOrigin1, TInput1]) -> BasePattern[TOrigin1 | TOrigin, TInput1 | TInput]: ...
11 changes: 7 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ def test_pattern_accepts():
assert pat6_2.validate(123.123).value() == 123.123
assert pat6_2.validate(b'123').value() == b'123'
print(pat6_2)
pat6_3 = BasePattern(mode=MatchMode.KEEP, addition_accepts=NUMBER)
pat6_3 = BasePattern(mode=MatchMode.KEEP, addition_accepts=INTEGER | BOOLEAN)
assert pat6_3.validate(123).value() == 123
assert pat6_3.validate(123.123).value() == 123.123
assert pat6_3.validate(True).value() is True
assert pat6_3.validate(b'123').failed


Expand Down Expand Up @@ -290,7 +290,7 @@ def __setitem__(self):


def test_union_pattern():
from typing import Union, Optional
from typing import Union, Optional, List

pat12 = parser(Union[int, bool])
assert pat12.validate(123).success
Expand All @@ -304,6 +304,9 @@ def test_union_pattern():
assert pat12_2.validate("abc").success
assert pat12_2.validate("bca").failed
print(pat12, pat12_1, pat12_2)
pat12_3 = UnionPattern._(List[bool], int)
pat12_4 = pat12_2 | pat12_3
print(pat12_3, pat12_4)


def test_seq_pattern():
Expand Down Expand Up @@ -473,7 +476,7 @@ def test_regex_pattern():
pat18_1 = parser(r"re:(\d+)") # str starts with "re:" will convert to BasePattern instead of RegexPattern
assert pat18_1.validate("1234").value() == '1234'
pat18_2 = parser(r"rep:(\d+)") # str starts with "rep:" will convert to RegexPattern
assert pat18_2.validate("1234").value().groups() == ('1234',)
assert pat18_2.validate("1234").value().groups() == ('1234',) # type: ignore
pat18_3 = parser(compile(r"(\d+)")) # re.Pattern will convert to RegexPattern
assert pat18_3.validate("1234").value().groups() == ('1234',)

Expand Down

0 comments on commit 4e531fc

Please sign in to comment.