Skip to content

Commit

Permalink
✨ version 0.6.1
Browse files Browse the repository at this point in the history
update `__eq__` and `__hash__`
  • Loading branch information
RF-Tar-Railt committed Sep 23, 2023
1 parent 5ca0a66 commit 32d4f81
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 28 deletions.
69 changes: 53 additions & 16 deletions nepattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def validate(self, input_: Any, default: Union[TDefault, Empty] = Empty) -> Vali
return ValidateResult(error=e, flag=ResultFlag.ERROR)
return ValidateResult(default, flag=ResultFlag.DEFAULT) # type: ignore

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, DirectPattern) and self.target == other.target

class DirectTypePattern(BasePattern[TOrigin, TOrigin]):
"""直接类型判断"""
Expand Down Expand Up @@ -126,6 +128,8 @@ def validate(self, input_: Any, default: Union[TDefault, Empty] = Empty) -> Vali
return ValidateResult(error=e, flag=ResultFlag.ERROR)
return ValidateResult(default, flag=ResultFlag.DEFAULT) # type: ignore

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, DirectTypePattern) and self.target == other.target

class RegexPattern(BasePattern[Match[str], str]):
"""针对正则的特化匹配,支持正则组"""
Expand All @@ -148,6 +152,9 @@ def match(self, input_: Any) -> Match[str]:
lang.require("nepattern", "content_error").format(target=input_, expected=self.pattern)
)

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, RegexPattern) and self.pattern == other.pattern


class UnionPattern(BasePattern[Any, _T]):
"""多类型参数的匹配"""
Expand Down Expand Up @@ -199,8 +206,10 @@ def __calc_repr__(self):
def __or__(self, other: BasePattern[Any, _T1]) -> UnionPattern[Union[_T, _T1]]:
return UnionPattern([*self.base, other]) # type: ignore

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, UnionPattern) and self.base == other.base

TSeq = TypeVar("TSeq", list, tuple, set)
#TIterMode = TypeVar("TIterMode", bound=Literal["pre", "suf", "all"])

class IterMode(str, Enum):
PRE = "pre"
Expand Down Expand Up @@ -253,6 +262,7 @@ def match(self, text: Any):
def __calc_repr__(self):
return f"{self.origin.__name__}[{self.base}]"


TKey = TypeVar("TKey")
TVal = TypeVar("TVal")

Expand Down Expand Up @@ -351,6 +361,9 @@ def match(self, input_: Any) -> _TCase:
lang.require("nepattern", "content_error").format(target=input_, expected=self._repr)
) from e

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, SwitchPattern) and self.switch == other.switch


class ForwardRefPattern(BasePattern[Any, Any]):
def __init__(self, ref: ForwardRef):
Expand All @@ -373,6 +386,9 @@ def match(self, input_: Any):
)
return input_

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, ForwardRefPattern) and self.ref == other.ref


class AntiPattern(BasePattern[TOrigin, Any]):
def __init__(self, pattern: BasePattern[TOrigin, Any]):
Expand Down Expand Up @@ -422,6 +438,8 @@ def validate(self, input_: _T, default: Union[TDefault, Empty] = Empty) -> Valid
default = cast(TDefault, default)
return ValidateResult(default, flag=ResultFlag.DEFAULT)

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, AntiPattern) and self.base == other.base

TInput = TypeVar("TInput")

Expand All @@ -434,8 +452,11 @@ def __init__(
alias: str | None = None,
):
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=origin, alias=alias or func.__name__)
self.__func__ = func
self.match = func.__get__(self) # type: ignore

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, CustomMatchPattern) and self.__func__ == other.__func__

NONE = BasePattern(mode=MatchMode.KEEP, origin=None, alias="none") # type: ignore

Expand All @@ -448,17 +469,23 @@ def _any_str(_, x: Any) -> str:
AnyString = CustomMatchPattern(str, _any_str, "any_str")
"""匹配任意内容并转为字符串的表达式"""


def _string(_, x: str) -> str:
if not isinstance(x, str): # pragma: no cover
raise MatchFailed(
lang.require("nepattern", "type_error").format(type=x.__class__, target=x, expected="str")
)
return x


STRING = CustomMatchPattern(str, _string, "str")

@final
class StrPattern(BasePattern[str, str]):
def __init__(self):
super().__init__(mode=MatchMode.KEEP, origin=str, alias="str", accepts=str)

def match(self, input_: str) -> str:
if not isinstance(input_, str): # pragma: no cover
raise MatchFailed(
lang.require("nepattern", "type_error")
.format(type=input_.__class__, target=input_, expected="str")
)
return input_

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, StrPattern)

STRING = StrPattern()

@final
class IntPattern(BasePattern[int, Union[str, int]]):
Expand All @@ -481,6 +508,8 @@ def match(self, input_: Union[str, int]) -> int:
lang.require("nepattern", "content_error").format(target=input_, expected="int")
) from e

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, IntPattern)


INTEGER = IntPattern()
Expand Down Expand Up @@ -508,7 +537,8 @@ def match(self, input_: Union[str, float, int]) -> float:
lang.require("nepattern", "content_error").format(target=input_, expected="float")
) from e


def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, FloatPattern)

FLOAT = FloatPattern()
"""浮点数表达式"""
Expand Down Expand Up @@ -536,7 +566,8 @@ def match(self, input_: Union[str, float]) -> float:
lang.require("nepattern", "content_error").format(target=input_, expected="int | float")
) from e


def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, NumberPattern)

NUMBER = NumberPattern()
"""一般数表达式,既可以浮点数也可以整数 """
Expand All @@ -561,7 +592,9 @@ def match(self, input_: Union[str, bool]) -> bool:
if input_ in self._BOOL:
return self._BOOL[input_]
raise MatchFailed(lang.require("nepattern", "content_error").format(target=input_, expected="bool"))


def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, BoolPattern)


BOOLEAN = BoolPattern()
Expand Down Expand Up @@ -608,7 +641,9 @@ def match(self, input_: str) -> int:
raise MatchFailed(
lang.require("nepattern", "content_error").format(target=input_, expected="hex")
) from e


def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, HexPattern)

HEX = HexPattern()
"""匹配16进制数的表达式"""
Expand All @@ -635,6 +670,8 @@ def match(self, input_: Union[str, int, float]) -> datetime:
)
return DateParser.parse(input_)

def __calc_eq__(self, other): # pragma: no cover
return isinstance(other, DateTimePattern)

DATETIME = DateTimePattern()
"""匹配时间的表达式"""
Expand Down
16 changes: 11 additions & 5 deletions nepattern/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from copy import deepcopy
from enum import Enum, IntEnum
import re
from typing import Any, Callable, Generic, TypeVar
Expand Down Expand Up @@ -226,6 +225,10 @@ class BasePattern(Generic[TOrigin, TInput]):
"_repr",
)

def __new__(cls, *args, **kwargs):
cls.__eq__ = cls.__calc_eq__
return super().__new__(cls, *args, **kwargs)

def __init__(
self,
pattern: str = ".+",
Expand Down Expand Up @@ -261,7 +264,7 @@ def __init__(
self._accepts = get_args(accepts) or (accepts,)
self._pattern_accepts = addition_accepts
self._repr = self.__calc_repr__()
self._hash = hash(self._repr)
self._hash = self.__calc_hash__()

if not addition_accepts:
self.accept = (lambda x: True) if _accepts is Any else (lambda _: generic_isinstance(_, _accepts))
Expand All @@ -276,7 +279,10 @@ def __init__(

def refresh(self): # pragma: no cover
self._repr = self.__calc_repr__()
self._hash = hash(self._repr)
self._hash = self.__calc_hash__()

def __calc_hash__(self):
return hash((self._repr, self.origin, self.mode, self.alias, self.previous, self._accepts, self.pattern))

def __calc_repr__(self):
if self.mode == MatchMode.KEEP:
Expand Down Expand Up @@ -315,8 +321,8 @@ def __str__(self):
def __hash__(self):
return self._hash

def __eq__(self, other):
return isinstance(other, BasePattern) and self._repr == other._repr
def __calc_eq__(self, other):
return isinstance(other, self.__class__) and self._hash == other._hash

@staticmethod
def of(unit: type[TOrigin]):
Expand Down
2 changes: 2 additions & 0 deletions nepattern/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,9 @@ class BasePattern(Generic[TOrigin, TInput]):
validators: list[Callable[[TOrigin], bool]] | None = None,
): ...
def refresh(self): ...
def __calc_hash__(self): ...
def __calc_repr__(self): ...
def __calc_eq__(self, other): ...
def __repr__(self): ...
def __str__(self): ...
def __hash__(self): ...
Expand Down
6 changes: 3 additions & 3 deletions nepattern/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def parser(item: Any, extra: str = "allow") -> BasePattern:
with suppress(TypeError):
if item and (pat := all_patterns().get(item, None)):
return pat
with suppress(TypeError):
if not inspect.isclass(item) and isinstance(item, (GenericAlias, CGenericAlias, CUnionType)):
return _generic_parser(item, extra)
#with suppress(TypeError):
if not inspect.isclass(item) and isinstance(item, (GenericAlias, CGenericAlias, CUnionType)):
return _generic_parser(item, extra)
if isinstance(item, TypeVar):
return _typevar_parser(item)
if inspect.isclass(item) and getattr(item, "_is_protocol", False):
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nepattern"
version = "0.6.0"
version = "0.6.1"
description = "a complex pattern, support typing"
authors = [
{name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com"},
Expand Down Expand Up @@ -76,6 +76,7 @@ exclude_lines = [
"def __repr__",
"def __str__",
"def __eq__",
"def __calc_eq__",
"except ImportError:",
]

Expand Down
9 changes: 6 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,9 @@ def test_parser():
pat11 = parser(int)
assert pat11.validate(-321).success
pat11_1 = parser(123)
assert pat11_1 == BasePattern.on(123)
print(pat11, pat11_1)
pat11_2 = BasePattern.to(int)
assert pat11_2 == pat11
assert BasePattern.to(None) == NONE
assert parser(BasePattern.of(int)) == BasePattern.of(int)
assert isinstance(parser(Literal["a", "b"]), UnionPattern)
assert parser(Type[int]).origin is type
assert parser(complex) == BasePattern.of(complex)
Expand Down Expand Up @@ -557,6 +554,12 @@ def test_value_operate():
assert pat22_1.validate("123.0").failed
assert pat22_1.validate([]).failed

def test_eq():
assert parser(123) == BasePattern.on(123)
assert BasePattern.to(None) == NONE
assert parser(BasePattern.of(int)) == BasePattern.of(int)
assert parser(str) == STRING

if __name__ == "__main__":
import pytest

Expand Down

0 comments on commit 32d4f81

Please sign in to comment.