Skip to content

Commit

Permalink
💥 AntiPattern
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Sep 20, 2023
1 parent fb39150 commit 682ff3e
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 92 deletions.
1 change: 1 addition & 0 deletions nepattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .base import ANY as ANY
from .base import AnyString as AnyString
from .base import AntiPattern as AntiPattern
from .base import BOOLEAN as BOOLEAN
from .base import DATETIME as DATETIME
from .base import DICT as DICT
Expand Down
43 changes: 37 additions & 6 deletions nepattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from pathlib import Path
import re
import sys
from typing import Any, Dict, ForwardRef, Iterable, Literal, Match, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, Iterable, Literal, Match, TypeVar, Union, cast

from tarina import DateParser, Empty, lang

from .core import BasePattern, MatchMode, ResultFlag, ValidateResult
from .exception import MatchFailed
from .util import TPattern

TDefault = TypeVar("TDefault")


class DirectPattern(BasePattern):
"""直接判断"""
Expand Down Expand Up @@ -88,7 +90,7 @@ class UnionPattern(BasePattern):
for_validate: list[BasePattern]
for_equal: list[str | object]

def __init__(self, base: Iterable[BasePattern | object | str], anti: bool = False):
def __init__(self, base: Iterable[BasePattern | object | str]):
self.base = list(base)
self.optional = False
self.for_validate = []
Expand All @@ -103,7 +105,7 @@ def __init__(self, base: Iterable[BasePattern | object | str], anti: bool = Fals
else:
self.for_equal.append(arg)
alias_content = "|".join([repr(a) for a in self.for_validate] + [repr(a) for a in self.for_equal])
super().__init__(mode=MatchMode.KEEP, origin=str, alias=alias_content, anti=anti)
super().__init__(mode=MatchMode.KEEP, origin=str, alias=alias_content)

def match(self, text: Any):
if not text:
Expand All @@ -118,15 +120,14 @@ def match(self, text: Any):
return text

def __calc_repr__(self):
return ("!" if self.anti else "") + ("|".join(repr(a) for a in (*self.for_validate, *self.for_equal)))
return "|".join(repr(a) for a in (*self.for_validate, *self.for_equal))

def prefixed(self) -> UnionPattern:
from .main import parser

return UnionPattern(
[pat.prefixed() for pat in self.for_validate]
+ [parser(eq).prefixed() if isinstance(eq, str) else eq for eq in self.for_equal], # type: ignore
self.anti,
)

def suffixed(self) -> UnionPattern:
Expand All @@ -135,7 +136,6 @@ def suffixed(self) -> UnionPattern:
return UnionPattern(
[pat.suffixed() for pat in self.for_validate]
+ [parser(eq).suffixed() if isinstance(eq, str) else eq for eq in self.for_equal], # type: ignore
self.anti,
)


Expand Down Expand Up @@ -312,6 +312,37 @@ def match(self, input_: Any):
return input_


class AntiPattern(BasePattern[Any]):
def __init__(self, pattern: BasePattern[Any]):
self.base = pattern
super().__init__(mode=MatchMode.TYPE_CONVERT, origin=Any, alias=f"!{pattern}")

def validate(self, input_: Any, default: TDefault | Empty = Empty) -> ValidateResult[Any | TDefault]:
"""
对传入的值进行反向验证,返回可能的匹配与转化结果。
若传入默认值,验证失败会返回默认值
"""
try:
res = self.base.match(input_)
except MatchFailed:
return ValidateResult(value=input_, flag=ResultFlag.VALID)
else: # pragma: no cover
for i in self.base.validators + self.validators:
if not i(res):
return ValidateResult(value=input_, flag=ResultFlag.VALID)
if default is Empty:
return ValidateResult(
error=MatchFailed(
lang.require("nepattern", "content_error").format(target=input_, expected=self._repr)
),
flag=ResultFlag.ERROR,
)
if TYPE_CHECKING:
default = cast(TDefault, default)
return ValidateResult(default, flag=ResultFlag.DEFAULT)


NONE = BasePattern(mode=MatchMode.KEEP, origin=None, alias="none") # type: ignore

ANY = BasePattern(mode=MatchMode.KEEP, origin=Any, alias="any")
Expand Down
75 changes: 5 additions & 70 deletions nepattern/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def step(
return self.success # type: ignore
if callable(other) and self.success:
return other(self.value)
return other.exec(self.value) if isinstance(other, BasePattern) else self
return other.validate(self.value) if isinstance(other, BasePattern) else self

@overload
def __rshift__(self, other: BasePattern[T]) -> ValidateResult[T]:
Expand Down Expand Up @@ -146,7 +146,6 @@ class BasePattern(Generic[TOrigin]):
converter: Callable[[BasePattern[TOrigin], Any], TOrigin | None]
validators: list[Callable[[TOrigin], bool]]

anti: bool
origin: type[TOrigin]
pattern_accepts: tuple[BasePattern, ...]
type_accepts: tuple[type, ...]
Expand All @@ -158,7 +157,6 @@ class BasePattern(Generic[TOrigin]):
"pattern",
"mode",
"converter",
"anti",
"origin",
"pattern_accepts",
"type_accepts",
Expand All @@ -180,7 +178,6 @@ def __init__(
previous: BasePattern | None = None,
accepts: list[type | BasePattern] | None = None,
validators: list[Callable[[TOrigin], bool]] | None = None,
anti: bool = False,
):
...

Expand All @@ -195,7 +192,6 @@ def __init__(
previous: BasePattern | None = None,
accepts: list[type | BasePattern] | None = None,
validators: list[Callable[[TOrigin], bool]] | None = None,
anti: bool = False,
):
...

Expand All @@ -210,7 +206,6 @@ def __init__(
previous: BasePattern | None = None,
accepts: list[type | BasePattern] | None = None,
validators: list[Callable[[TOrigin], bool]] | None = None,
anti: bool = False,
):
...

Expand All @@ -225,7 +220,6 @@ def __init__(
previous: BasePattern | None = None,
accepts: list[type | BasePattern] | None = None,
validators: list[Callable[[TOrigin], bool]] | None = None,
anti: bool = False,
):
...

Expand All @@ -239,7 +233,6 @@ def __init__(
previous: BasePattern | None = None,
accepts: list[type | BasePattern] | None = None,
validators: list[Callable[[TOrigin], bool]] | None = None,
anti: bool = False,
):
"""
初始化参数匹配表达式
Expand All @@ -259,19 +252,18 @@ def __init__(
lambda _, x: (get_origin(origin) or origin)(x) if mode == MatchMode.TYPE_CONVERT else eval(x[0])
)
self.validators = validators or []
self.anti = anti
self._repr = self.__calc_repr__()
self._hash = hash(self._repr)
if not self.pattern_accepts and not self.type_accepts:
self._accept = lambda _: True
elif not self.pattern_accepts:
self._accept = lambda x: generic_isinstance(x, self.type_accepts)
elif not self.type_accepts:
self._accept = lambda x: any(map(lambda y: y.exec(x).flag == "valid", self.pattern_accepts))
self._accept = lambda x: any(map(lambda y: y.validate(x).flag == "valid", self.pattern_accepts))
else:
self._accept = lambda x: (
generic_isinstance(x, self.type_accepts)
or any(map(lambda y: y.exec(x).flag == "valid", self.pattern_accepts))
or any(map(lambda y: y.validate(x).flag == "valid", self.pattern_accepts))
)

def __calc_repr__(self):
Expand Down Expand Up @@ -303,7 +295,7 @@ def __calc_repr__(self):
text = self.alias
return (
f"{f'{self.previous.__repr__()} -> ' if self.previous and id(self.previous) != id(self) else ''}"
f"{'!' if self.anti else ''}{text}"
f"{text}"
)

def __repr__(self):
Expand Down Expand Up @@ -338,13 +330,6 @@ def to(content: Any) -> BasePattern:
res = parser(content, "allow")
return res if isinstance(res, BasePattern) else parser(Any)

def reverse(self) -> Self:
"""改变 pattern 的 anti 值"""
self.anti = not self.anti
self._repr = self.__calc_repr__()
self._hash = hash(self._repr)
return self

def prefixed(self):
"""让表达式能在某些场景下实现前缀匹配; 返回自身的拷贝"""
cp_self = deepcopy(self)
Expand Down Expand Up @@ -440,58 +425,8 @@ def validate(self, input_: Any, default: TDefault | Empty = Empty) -> ValidateRe
default = cast(TDefault, default)
return ValidateResult(default, flag=ResultFlag.DEFAULT)

@overload
def invalidate(self, input_: Any) -> ValidateResult[Any]:
...

@overload
def invalidate(self, input_: Any, default: TDefault) -> ValidateResult[Any | TDefault]:
...

def invalidate(self, input_: Any, default: TDefault | Empty = Empty) -> ValidateResult[Any | TDefault]:
"""
对传入的值进行反向验证,返回可能的匹配与转化结果。
若传入默认值,验证失败会返回默认值
"""
try:
res = self.match(input_)
except MatchFailed:
return ValidateResult(value=input_, flag=ResultFlag.VALID)
else:
for i in self.validators:
if not i(res):
return ValidateResult(value=input_, flag=ResultFlag.VALID)
if default is Empty:
return ValidateResult(
error=MatchFailed(
lang.require("nepattern", "content_error").format(target=input_, expected=self._repr)
),
flag=ResultFlag.ERROR,
)
if TYPE_CHECKING:
default = cast(TDefault, default)
return ValidateResult(default, flag=ResultFlag.DEFAULT)

@overload
def exec(self, input_: Any) -> ValidateResult[TOrigin]:
...

@overload
def exec(self, input_: Any, default: TDefault) -> ValidateResult[TOrigin | TDefault]:
...

def exec(self, input_: Any, default: TDefault | Empty = Empty) -> ValidateResult[TOrigin | TDefault]:
"""
依据 anti 值 自动选择验证方式
"""
if self.anti:
return self.invalidate(input_, default) # type: ignore
else:
return self.validate(input_, default) # type: ignore

def __rrshift__(self, other):
return self.exec(other)
return self.validate(other)

def __rmatmul__(self, other) -> Self: # pragma: no cover
if isinstance(other, str):
Expand Down
26 changes: 10 additions & 16 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,12 @@ def __repr__(self):
def test_pattern_anti():
"""测试 BasePattern 的反向验证功能"""
pat8 = BasePattern.of(int)
assert pat8.validate(123).success
assert pat8.invalidate(123).failed
pat8_1 = AntiPattern(pat8)
assert pat8.validate(123).value == 123
assert pat8.validate("123").failed
assert pat8.invalidate("123").success
pat8.reverse()
assert pat8.exec(123).failed
assert pat8.exec("123").success
pat8.reverse()
assert pat8.exec(123).success
assert pat8.exec("123").failed
assert pat8_1.validate(123).failed
assert pat8_1.validate("123").value == "123"



def test_pattern_validator():
Expand All @@ -183,17 +179,15 @@ def test_pattern_validator():
)
assert pat9.validate(23).value == 23
assert pat9.validate(-23).failed
assert pat9.invalidate(-23).success
pat9_1 = BasePattern.to(set_unit(int, lambda x: x != 0))
assert pat9_1.invalidate("123").failed
print(pat9)


def test_pattern_default():
pat10 = BasePattern.of(int)
assert pat10.validate("123", 123).or_default
assert pat10.invalidate("123", 123).success
assert pat10.invalidate(123, "123").value == "123"
assert pat10.validate("123", 123).value == 123
assert AntiPattern(pat10).validate(123, "123").value == "123"


def test_parser():
Expand Down Expand Up @@ -413,10 +407,10 @@ def test_suffix():
def test_dunder():
pat17 = BasePattern.of(float)
assert ("test_float" @ pat17).alias == "test_float"
assert pat17.exec(1.33).step(str) == pat17.exec(1.33) >> str == "1.33"
assert (pat17.exec(1.33) >> 1).value == 1.33
assert pat17.validate(1.33).step(str) == pat17.validate(1.33) >> str == "1.33"
assert (pat17.validate(1.33) >> 1).value == 1.33
assert not '1.33' >> pat17
assert pat17.exec(1.33) >> bool
assert pat17.validate(1.33) >> bool
assert BasePattern.of(int).validate(1).step(lambda x: x + 2) == 3
pat17_1 = BasePattern(r"@(\d+)", MatchMode.REGEX_CONVERT, str, lambda _, x: x[0][1:])
pat17_2: BasePattern[int] = parser(int)
Expand Down

0 comments on commit 682ff3e

Please sign in to comment.