Skip to content

Commit

Permalink
💥 No.3 TypeVar on BasePattern
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Apr 21, 2024
1 parent a95264a commit c413ef4
Show file tree
Hide file tree
Showing 8 changed files with 749 additions and 265 deletions.
2 changes: 1 addition & 1 deletion nepattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .base import URL as URL
from .base import UnionPattern as UnionPattern
from .base import WIDE_BOOLEAN as WIDE_BOOLEAN
from .base import pipe as pipe
from .context import Patterns as Patterns
from .context import all_patterns as all_patterns
from .context import create_local_patterns as create_local_patterns
Expand All @@ -45,7 +46,6 @@
from .core import MatchMode as MatchMode
from .core import ValidateResult as ValidateResult
from .exception import MatchFailed as MatchFailed
from .main import Bind as Bind
from .main import parser as parser
from .util import RawStr as RawStr
from .util import TPattern as TPattern
Expand Down
112 changes: 66 additions & 46 deletions nepattern/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from enum import Enum
from decimal import Decimal
from datetime import datetime
from pathlib import Path
import re
Expand All @@ -24,7 +23,7 @@

from tarina import DateParser, Empty, lang

from .core import BasePattern, MatchMode, ResultFlag, ValidateResult
from .core import BasePattern, MatchMode, ResultFlag, ValidateResult, _MATCHES, TInput, TOrigin, TMM
from .exception import MatchFailed
from .util import TPattern

Expand All @@ -34,7 +33,7 @@
_T1 = TypeVar("_T1")


class DirectPattern(BasePattern[TOrigin, TOrigin]):
class DirectPattern(BasePattern[TOrigin, TOrigin, Literal[MatchMode.KEEP]]):
"""直接判断"""

__slots__ = ("target",)
Expand Down Expand Up @@ -82,7 +81,7 @@ def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, DirectPattern) and self.target == other.target


class DirectTypePattern(BasePattern[TOrigin, TOrigin]):
class DirectTypePattern(BasePattern[TOrigin, TOrigin, Literal[MatchMode.KEEP]]):
"""直接类型判断"""

__slots__ = ("target",)
Expand Down Expand Up @@ -134,7 +133,7 @@ def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, DirectTypePattern) and self.target == other.target


class RegexPattern(BasePattern[Match[str], str]):
class RegexPattern(BasePattern[Match[str], str, Literal[MatchMode.REGEX_MATCH]]):
"""针对正则的特化匹配,支持正则组"""

def __init__(self, pattern: str | TPattern, alias: str | None = None):
Expand All @@ -159,7 +158,7 @@ def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, RegexPattern) and self.pattern == other.pattern


class UnionPattern(BasePattern[Any, _T]):
class UnionPattern(BasePattern[Any, _T, Literal[MatchMode.KEEP]]):
"""多类型参数的匹配"""

optional: bool
Expand All @@ -168,7 +167,7 @@ class UnionPattern(BasePattern[Any, _T]):

__slots__ = ("base", "optional", "for_validate", "for_equal")

def __init__(self, base: Iterable[BasePattern[Any, _T] | _T]):
def __init__(self, base: Iterable[BasePattern[Any, _T, Any] | _T]):
self.base = list(base)
self.optional = False
self.for_validate = []
Expand All @@ -183,19 +182,19 @@ def __init__(self, base: Iterable[BasePattern[Any, _T] | _T]):
else:
self.for_equal.append(arg)
alias_content = "|".join([repr(a) for a in self.for_validate] + [repr(a) for a in self.for_equal])
super().__init__(mode=MatchMode.KEEP, origin=str, alias=alias_content)
super().__init__(mode=MatchMode.KEEP, origin=Any, alias=alias_content)

def match(self, text: Any):
if not text:
def match(self, input_: Any):
if not input_:
text = None
if text not in self.for_equal:
if input_ not in self.for_equal:
for pat in self.for_validate:
if (res := pat.validate(text)).success:
if (res := pat.validate(input_)).success:
return res.value()
raise MatchFailed(
lang.require("nepattern", "content_error").format(target=text, expected=self.alias)
lang.require("nepattern", "content_error").format(target=input_, expected=self.alias)
)
return text
return input_

@classmethod
def _(cls, *types: type[_T1]) -> UnionPattern[_T1]:
Expand All @@ -206,7 +205,7 @@ def _(cls, *types: type[_T1]) -> UnionPattern[_T1]:
def __calc_repr__(self):
return "|".join(repr(a) for a in (*self.for_validate, *self.for_equal))

def __or__(self, other: BasePattern[Any, _T1]) -> UnionPattern[Union[_T, _T1]]:
def __or__(self, other: BasePattern[Any, _T1, Any]) -> UnionPattern[Union[_T, _T1]]:
return UnionPattern([*self.base, other]) # type: ignore

def __calc_eq__(self, other): # pragma: no cover
Expand All @@ -225,7 +224,7 @@ class IterMode(str, Enum):
TIterMode = TypeVar("TIterMode", bound=IterMode)


class SequencePattern(BasePattern[TSeq, Union[str, TSeq]], Generic[TSeq, TIterMode]):
class SequencePattern(BasePattern[TSeq, Union[str, TSeq], Literal[MatchMode.REGEX_CONVERT]], Generic[TSeq, TIterMode]):
"""匹配列表或者元组或者集合"""

base: BasePattern
Expand All @@ -243,9 +242,10 @@ def __init__(self, form: type[TSeq], base: BasePattern, mode: TIterMode = IterMo
else:
raise ValueError(lang.require("nepattern", "sequence_form_error").format(target=str(form)))
self.converter = lambda _, x: x[1] # type: ignore
self._match = _MATCHES[MatchMode.REGEX_CONVERT](self)

def match(self, text: Any):
_res = self._MATCHES[MatchMode.REGEX_CONVERT](self, text) # type: ignore
def match(self, input_: Any):
_res = self._match(self, input_) # type: ignore
_max = 0
success: list[tuple[int, Any]] = []
fail: list[tuple[int, MatchFailed]] = []
Expand Down Expand Up @@ -276,19 +276,19 @@ def __calc_repr__(self):


class MappingPattern(
BasePattern[Dict[TKey, TVal], Union[str, Dict[TKey, TVal]]],
BasePattern[Dict[TKey, TVal], Union[str, Dict[TKey, TVal]], Literal[MatchMode.REGEX_CONVERT]],
Generic[TKey, TVal, TIterMode],
):
"""匹配字典或者映射表"""

key: BasePattern[TKey, Any]
value: BasePattern[TVal, Any]
key: BasePattern[TKey, Any, Any]
value: BasePattern[TVal, Any, Any]
itermode: TIterMode

def __init__(
self,
arg_key: BasePattern[TKey, Any],
arg_value: BasePattern[TVal, Any],
arg_key: BasePattern[TKey, Any, Any],
arg_value: BasePattern[TVal, Any, Any],
mode: TIterMode = IterMode.ALL
):
self.key = arg_key
Expand All @@ -301,9 +301,10 @@ def __init__(
alias=f"dict[{self.key}, {self.value}]",
)
self.converter = lambda _, x: x[1]
self._match = _MATCHES[MatchMode.REGEX_CONVERT](self)

def match(self, input_: str | dict):
_res = self._MATCHES[MatchMode.REGEX_CONVERT](self, input_) # type: ignore
_res = self._match(self, input_) # type: ignore
success: list[tuple[int, Any, Any]] = []
fail: list[tuple[int, MatchFailed]] = []
_max = 0
Expand Down Expand Up @@ -346,7 +347,7 @@ def __calc_repr__(self):
_TSwtich = TypeVar("_TSwtich")


class SwitchPattern(BasePattern[_TCase, _TSwtich]):
class SwitchPattern(BasePattern[_TCase, _TSwtich, Literal[MatchMode.TYPE_CONVERT]]):
switch: dict[_TSwtich | ellipsis, _TCase]

__slots__ = ("switch",)
Expand All @@ -372,7 +373,7 @@ def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, SwitchPattern) and self.switch == other.switch


class ForwardRefPattern(BasePattern[Any, Any]):
class ForwardRefPattern(BasePattern[Any, Any, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self, ref: ForwardRef):
self.ref = ref
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=ref.__forward_arg__)
Expand All @@ -397,9 +398,9 @@ def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, ForwardRefPattern) and self.ref == other.ref


class AntiPattern(BasePattern[TOrigin, Any]):
def __init__(self, pattern: BasePattern[TOrigin, Any]):
self.base: BasePattern[TOrigin, Any] = pattern
class AntiPattern(BasePattern[TOrigin, Any, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self, pattern: BasePattern[TOrigin, Any, Any]):
self.base: BasePattern[TOrigin, Any, Any] = pattern
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=pattern.origin, alias=f"!{pattern}")

@overload
Expand Down Expand Up @@ -452,7 +453,7 @@ def __calc_eq__(self, other): # pragma: no cover
TInput = TypeVar("TInput")


class CustomMatchPattern(BasePattern[TOrigin, TInput]):
class CustomMatchPattern(BasePattern[TOrigin, TInput, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(
self,
origin: type[TOrigin],
Expand All @@ -471,7 +472,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class AnyPattern(BasePattern[Any, Any]):
class AnyPattern(BasePattern[Any, Any, Literal[MatchMode.KEEP]]):
def __init__(self):
super().__init__(mode=MatchMode.KEEP, origin=Any, alias="any")

Expand All @@ -487,7 +488,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class AnyStrPattern(BasePattern[str, Any]):
class AnyStrPattern(BasePattern[str, Any, Literal[MatchMode.KEEP]]):
def __init__(self):
super().__init__(mode=MatchMode.KEEP, origin=str, alias="any_str")

Expand All @@ -503,9 +504,9 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class StrPattern(BasePattern[str, Any]):
class StrPattern(BasePattern[str, str | bytes | bytearray, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.KEEP, origin=str, alias="str")
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=str, accepts=Union[str, bytes, bytearray], alias="str")

def match(self, input_: Any) -> str:
if isinstance(input_, str):
Expand All @@ -525,13 +526,15 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class BytesPattern(BasePattern[bytes, Any]):
class BytesPattern(BasePattern[bytes, str | bytes | bytearray, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.KEEP, origin=bytes, alias="bytes")
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=bytes, accepts=Union[str, bytes, bytearray], alias="bytes")

def match(self, input_: Any) -> bytes:
if isinstance(input_, bytes):
return input_
elif isinstance(input_, bytearray):
return bytes(input_)
elif isinstance(input_, str):
return input_.encode()
raise MatchFailed(
Expand All @@ -548,7 +551,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class IntPattern(BasePattern[int, Any]):
class IntPattern(BasePattern[int, Any, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=int, alias="int")

Expand All @@ -573,7 +576,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class FloatPattern(BasePattern[float, Any]):
class FloatPattern(BasePattern[float, Any, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=float, alias="float")

Expand All @@ -597,11 +600,11 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class NumberPattern(BasePattern[Union[int, float], Any]):
class NumberPattern(BasePattern[Union[int, float], Any, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Union[int, float], alias="number")
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Union[int, float], alias="number") # type: ignore

def match(self, input_: Any) -> float:
def match(self, input_: Any) -> int | float:
if isinstance(input_, (float, int)):
return input_
try:
Expand All @@ -621,7 +624,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class BoolPattern(BasePattern[bool, Any]):
class BoolPattern(BasePattern[bool, bool | str | bytes, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=bool, alias="bool")

Expand Down Expand Up @@ -649,7 +652,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class WideBoolPattern(BasePattern[bool, Any]):
class WideBoolPattern(BasePattern[bool, Any, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=bool, alias="bool")

Expand Down Expand Up @@ -708,7 +711,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class HexPattern(BasePattern[int, str]):
class HexPattern(BasePattern[int, str, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=int, alias="hex", accepts=str)

Expand Down Expand Up @@ -738,7 +741,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class DateTimePattern(BasePattern[datetime, Union[str, int, float]]):
class DateTimePattern(BasePattern[datetime, Union[str, int, float], Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(
mode=MatchMode.TYPE_CONVERT, origin=datetime, alias="datetime", accepts=Union[str, int, float]
Expand All @@ -764,7 +767,7 @@ def __calc_eq__(self, other): # pragma: no cover


@final
class PathPattern(BasePattern[Path, Any]):
class PathPattern(BasePattern[Path, Any, Literal[MatchMode.TYPE_CONVERT]]):
def __init__(self):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Path, alias="path")

Expand Down Expand Up @@ -793,3 +796,20 @@ def __calc_eq__(self, other): # pragma: no cover
alias="filedata",
converter=lambda _, x: x.read_bytes() if x.exists() and x.is_file() else None,
)


def pipe(previous: BasePattern[Any, Any, Literal[MatchMode.VALUE_OPERATE]], current: BasePattern[TOrigin, TInput, TMM]) -> BasePattern[TOrigin, TInput, TMM]:
_new = current.copy()
_match = _new.match

def match(self, input_):
return _match(previous.match(input_))

_new.match = match.__get__(_new)
return _new


DelimiterInt = pipe(
BasePattern(mode=MatchMode.VALUE_OPERATE, origin=str, converter=lambda _, x: x.replace(",", "_")),
INTEGER
)
8 changes: 4 additions & 4 deletions nepattern/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from .core import BasePattern
class Patterns(UserDict[Any, BasePattern]):
name: str
def __init__(self, name: str): ...
def set(self, target: BasePattern, alias: str | None = None, cover: bool = True, no_alias=False):
def set(self, target: BasePattern[Any, Any, Any], alias: str | None = None, cover: bool = True, no_alias=False):
"""
增加可使用的类型转换器
Expand All @@ -18,13 +18,13 @@ class Patterns(UserDict[Any, BasePattern]):
no_alias: 是否不使用目标类型自带的别名
"""
...
def sets(self, patterns: Iterable[BasePattern], cover: bool = True, no_alias=False): ...
def merge(self, patterns: dict[str, BasePattern], no_alias=False): ...
def sets(self, patterns: Iterable[BasePattern[Any, Any, Any]], cover: bool = True, no_alias=False): ...
def merge(self, patterns: dict[str, BasePattern[Any, Any, Any]], no_alias=False): ...
def remove(self, origin_type: type, alias: str | None = None): ...

def create_local_patterns(
name: str,
data: dict[Any, BasePattern] | None = None,
data: dict[Any, BasePattern[Any, Any, Any]] | None = None,
set_current: bool = True,
) -> Patterns:
"""
Expand Down

0 comments on commit c413ef4

Please sign in to comment.