Skip to content

Commit

Permalink
feat: compare ast
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed May 28, 2024
1 parent f04acde commit 5adff65
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 70 deletions.
5 changes: 3 additions & 2 deletions inline_snapshot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from ._code_repr import customize
from ._code_repr import EqualTo
from ._code_repr import HasRepr
from ._code_repr import register_repr
from ._external import external
from ._external import outsource
from ._inline_snapshot import snapshot

__all__ = ["snapshot", "external", "outsource", "register_repr", "HasRepr"]
__all__ = ["snapshot", "external", "outsource", "customize", "HasRepr", "EqualTo"]

__version__ = "0.10.2"
189 changes: 133 additions & 56 deletions inline_snapshot/_code_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,54 @@ def used_hasrepr(tree):
]


class DefaultBehaviour:
@staticmethod
def repr(v):
return real_repr(v)


@singledispatch
def code_repr_dispatch(v):
return real_repr(v)
def customize_dispatch(datatype):
return DefaultBehaviour


def customize(datatype):
def f(customize_class):

@customize_dispatch.register
def _(v: datatype):
return customize_class

return customize_class

return f


class EqualTo:
def __init__(self, obj):
self._obj = obj

def register_repr(f):
"""Register a funtion which should be used to get the code representation
of a object.
def __repr__(self):
return repr(self._obj)

def __eq__(self, obj):
if isinstance(obj, EqualTo):
obj = obj._obj

```python
@register_repr
def _(obj: MyCustomClass):
return f"MyCustomClass({repr(obj.attr)})"
```
it is important to use `repr()` inside the implementation, because it is mocked to return the code represenation
if type(self._obj) is not type(obj):
return NotImplemented

you dont have to provide a custom implementation if:
* __repr__() of your class returns a valid code representation,
* and __repr__() uses `repr()` to get the representaion of the child objects
"""
code_repr_dispatch.register(f)
c = customize_dispatch(obj)
return c.equal_to(self._obj, obj)


def code_repr(obj):
c = customize_dispatch(obj)
with mock.patch("builtins.repr", code_repr):
result = code_repr_dispatch(obj)
result = c.repr(obj)

if hasattr(c, "equal_to"):
result = f"EqualTo({result})"

try:
ast.parse(result)
Expand All @@ -69,38 +91,52 @@ def code_repr(obj):
return result


@register_repr
def _(v: Enum):
return f"{type(v).__qualname__}.{v.name}"
@customize(Enum)
class CustomizeEnum:
@staticmethod
def repr(v: Enum):
return f"{type(v).__qualname__}.{v.name}"


@register_repr
def _(v: Flag):
name = type(v).__qualname__
return " | ".join(f"{name}.{flag.name}" for flag in type(v) if flag in v)
@customize(Flag)
class CustomizeFlag:
@staticmethod
def repr(v: Flag):
name = type(v).__qualname__
return " | ".join(f"{name}.{flag.name}" for flag in type(v) if flag in v)


@register_repr
def _(v: list):
return "[" + ", ".join(map(repr, v)) + "]"
@customize(list)
class CustomizeList:
@staticmethod
def repr(v: list):
return "[" + ", ".join(map(repr, v)) + "]"


@register_repr
def _(v: set):
if len(v) == 0:
return "set()"
@customize(set)
class CustomizeSet:
@staticmethod
def repr(v: set):
if len(v) == 0:
return "set()"

return "{" + ", ".join(map(repr, v)) + "}"
return "{" + ", ".join(map(repr, v)) + "}"


@register_repr
def _(v: dict):
return "{" + ", ".join(f"{repr(k)}:{repr(value)}" for k, value in v.items()) + "}"
@customize(dict)
class CustomizeDict:
@staticmethod
def repr(v: dict):
return (
"{" + ", ".join(f"{repr(k)}:{repr(value)}" for k, value in v.items()) + "}"
)


@register_repr
def _(v: type):
return v.__qualname__
@customize(type)
class CustomizeType:
@staticmethod
def repr(v: type):
return v.__qualname__


from dataclasses import is_dataclass, fields
Expand All @@ -113,15 +149,54 @@ def __subclasshook__(subclass):
return is_dataclass(subclass)


@register_repr
def _(v: IsDataclass):
attrs = []
for field in fields(v): # type: ignore
if field.repr:
value = getattr(v, field.name)
attrs.append(f"{field.name} = {repr(value)}")
@customize(IsDataclass)
class CustomizeDataclass_:
@staticmethod
def repr(v: IsDataclass):
attrs = []
for field in fields(v): # type: ignore
if field.repr:
value = getattr(v, field.name)
attrs.append(f"{field.name} = {repr(value)}")

return f"{repr(type(v))}({', '.join(attrs)})"


import ast
from itertools import zip_longest
from typing import List, Union


def compare_ast(
node1: Union[ast.AST, List[ast.AST]], node2: Union[ast.AST, List[ast.AST]]
) -> bool:
if type(node1) is not type(node2):
return False

return f"{repr(type(v))}({', '.join(attrs)})"
if isinstance(node1, ast.AST):
for k, v in vars(node1).items():
if k in ("lineno", "end_lineno", "col_offset", "end_col_offset"):
continue
if not compare_ast(v, getattr(node2, k)):
return False
return True

elif isinstance(node1, list) and isinstance(node2, list):
return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2))

else:
return node1 == node2


@customize(ast.AST)
class CustomizeAst:
@staticmethod
def repr(node: ast.AST):
return ast.dump(node)

@staticmethod
def equal_to(node1: ast.AST, node2: ast.AST):
return compare_ast(node1, node2)


try:
Expand All @@ -130,14 +205,16 @@ def _(v: IsDataclass):
pass
else:

@register_repr
def _(model: BaseModel):
return (
type(model).__qualname__
+ "("
+ ", ".join(
e + "=" + repr(getattr(model, e))
for e in sorted(model.__pydantic_fields_set__)
@customize(BaseModel)
class CustomizePydanticBaseModel:
@staticmethod
def repr(model: BaseModel):
return (
type(model).__qualname__
+ "("
+ ", ".join(
e + "=" + repr(getattr(model, e))
for e in sorted(model.__pydantic_fields_set__)
)
+ ")"
)
+ ")"
)
16 changes: 6 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ast
import os
import platform
import re
Expand All @@ -19,7 +18,7 @@
import inline_snapshot._external
from .utils import snapshot_env
from inline_snapshot import _inline_snapshot
from inline_snapshot import register_repr
from inline_snapshot import customize
from inline_snapshot._format import format_code
from inline_snapshot._inline_snapshot import Flags
from inline_snapshot._rewrite_code import ChangeRecorder
Expand All @@ -32,14 +31,11 @@
black.files.find_project_root = black.files.find_project_root.__wrapped__ # type: ignore


@register_repr
def _(v: executing.Source):
return f"<source {Path(v.filename).name}>"


@register_repr
def _(v: ast.AST):
return repr(ast.dump(v))
@customize(executing.Source)
class _:
@staticmethod
def repr(v: executing.Source):
return f"<source {Path(v.filename).name}>"


@pytest.fixture()
Expand Down
4 changes: 2 additions & 2 deletions tests/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def run_inline(
if reported_flags is not None:
assert sorted(snapshot_flags) == reported_flags

# if changes is not None:
# assert all_changes == changes
if changes is not None:
assert all_changes == changes

recorder.fix_all()

Expand Down
51 changes: 51 additions & 0 deletions tests/test_code_repr.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from ast import Add
from ast import BinOp
from ast import Constant
from ast import Expr
from ast import Module

from .example import Example
from inline_snapshot import snapshot

Expand Down Expand Up @@ -248,3 +254,48 @@ class Color(Enum):
}
),
).run_inline()


def test_equal_to():
import ast
from inline_snapshot import EqualTo
from inline_snapshot._code_repr import code_repr

assert ast.Constant(value=5) == EqualTo(ast.Constant(value=5))
assert ast.Constant(value=5) == EqualTo(ast.Constant(value=5, lineno=5))

assert EqualTo(ast.Constant(value=5)) == EqualTo(ast.Constant(value=5))
assert not EqualTo(ast.Constant(value=5)) == EqualTo(
ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=1))
)

assert not EqualTo(
ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=1))
) == EqualTo(
ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=2))
)

assert not EqualTo(
ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=1))
) == EqualTo(
ast.BinOp(left=ast.Constant(value=1), op=ast.Sub(), right=ast.Constant(value=1))
)

assert code_repr(ast.Constant(value=5)) == snapshot("EqualTo(Constant(value=5))")

assert EqualTo(ast.Constant(value=5)) == snapshot(EqualTo(Constant(value=5)))

assert ast.parse("1+1") == snapshot(
EqualTo(
Module(
body=[
Expr(
value=BinOp(
left=Constant(value=1), op=Add(), right=Constant(value=1)
)
)
],
type_ignores=[],
)
)
)
19 changes: 19 additions & 0 deletions tests/test_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from ast import Constant

from executing import Source

from .example import Example
from inline_snapshot import EqualTo
from inline_snapshot import HasRepr
from inline_snapshot import snapshot
from inline_snapshot._change import Replace


def test_example():
Expand All @@ -23,6 +30,18 @@ def test_a():
e.run_inline(
"fix",
reported_flags=snapshot(["fix"]),
changes=snapshot(
[
Replace(
flag="fix",
source=HasRepr(Source, "<source test_a.py>"),
node=EqualTo(Constant(value=2)),
new_code="1",
old_value=2,
new_value=1,
)
]
),
).run_inline(
"fix",
changed_files=snapshot({}),
Expand Down

0 comments on commit 5adff65

Please sign in to comment.