diff --git a/inline_snapshot/__init__.py b/inline_snapshot/__init__.py index 74d4829..03e1984 100644 --- a/inline_snapshot/__init__.py +++ b/inline_snapshot/__init__.py @@ -1,9 +1,10 @@ +from ._code_repr import customize +from ._code_repr import EqualTo from ._code_repr import HasRepr -from ._code_repr import register_repr from ._external import external from ._external import outsource from ._inline_snapshot import snapshot -__all__ = ["snapshot", "external", "outsource", "register_repr", "HasRepr"] +__all__ = ["snapshot", "external", "outsource", "customize", "HasRepr", "EqualTo"] __version__ = "0.10.2" diff --git a/inline_snapshot/_code_repr.py b/inline_snapshot/_code_repr.py index 806c243..5015661 100644 --- a/inline_snapshot/_code_repr.py +++ b/inline_snapshot/_code_repr.py @@ -34,32 +34,54 @@ def used_hasrepr(tree): ] +class DefaultBehaviour: + @staticmethod + def repr(v): + return real_repr(v) + + @singledispatch -def code_repr_dispatch(v): - return real_repr(v) +def customize_dispatch(datatype): + return DefaultBehaviour + + +def customize(datatype): + def f(customize_class): + + @customize_dispatch.register + def _(v: datatype): + return customize_class + + return customize_class + + return f + +class EqualTo: + def __init__(self, obj): + self._obj = obj -def register_repr(f): - """Register a funtion which should be used to get the code representation - of a object. + def __repr__(self): + return repr(self._obj) + + def __eq__(self, obj): + if isinstance(obj, EqualTo): + obj = obj._obj - ```python - @register_repr - def _(obj: MyCustomClass): - return f"MyCustomClass({repr(obj.attr)})" - ``` - it is important to use `repr()` inside the implementation, because it is mocked to return the code represenation + if type(self._obj) is not type(obj): + return NotImplemented - you dont have to provide a custom implementation if: - * __repr__() of your class returns a valid code representation, - * and __repr__() uses `repr()` to get the representaion of the child objects - """ - code_repr_dispatch.register(f) + c = customize_dispatch(obj) + return c.equal_to(self._obj, obj) def code_repr(obj): + c = customize_dispatch(obj) with mock.patch("builtins.repr", code_repr): - result = code_repr_dispatch(obj) + result = c.repr(obj) + + if hasattr(c, "equal_to"): + result = f"EqualTo({result})" try: ast.parse(result) @@ -69,38 +91,52 @@ def code_repr(obj): return result -@register_repr -def _(v: Enum): - return f"{type(v).__qualname__}.{v.name}" +@customize(Enum) +class CustomizeEnum: + @staticmethod + def repr(v: Enum): + return f"{type(v).__qualname__}.{v.name}" -@register_repr -def _(v: Flag): - name = type(v).__qualname__ - return " | ".join(f"{name}.{flag.name}" for flag in type(v) if flag in v) +@customize(Flag) +class CustomizeFlag: + @staticmethod + def repr(v: Flag): + name = type(v).__qualname__ + return " | ".join(f"{name}.{flag.name}" for flag in type(v) if flag in v) -@register_repr -def _(v: list): - return "[" + ", ".join(map(repr, v)) + "]" +@customize(list) +class CustomizeList: + @staticmethod + def repr(v: list): + return "[" + ", ".join(map(repr, v)) + "]" -@register_repr -def _(v: set): - if len(v) == 0: - return "set()" +@customize(set) +class CustomizeSet: + @staticmethod + def repr(v: set): + if len(v) == 0: + return "set()" - return "{" + ", ".join(map(repr, v)) + "}" + return "{" + ", ".join(map(repr, v)) + "}" -@register_repr -def _(v: dict): - return "{" + ", ".join(f"{repr(k)}:{repr(value)}" for k, value in v.items()) + "}" +@customize(dict) +class CustomizeDict: + @staticmethod + def repr(v: dict): + return ( + "{" + ", ".join(f"{repr(k)}:{repr(value)}" for k, value in v.items()) + "}" + ) -@register_repr -def _(v: type): - return v.__qualname__ +@customize(type) +class CustomizeType: + @staticmethod + def repr(v: type): + return v.__qualname__ from dataclasses import is_dataclass, fields @@ -113,15 +149,54 @@ def __subclasshook__(subclass): return is_dataclass(subclass) -@register_repr -def _(v: IsDataclass): - attrs = [] - for field in fields(v): # type: ignore - if field.repr: - value = getattr(v, field.name) - attrs.append(f"{field.name} = {repr(value)}") +@customize(IsDataclass) +class CustomizeDataclass_: + @staticmethod + def repr(v: IsDataclass): + attrs = [] + for field in fields(v): # type: ignore + if field.repr: + value = getattr(v, field.name) + attrs.append(f"{field.name} = {repr(value)}") + + return f"{repr(type(v))}({', '.join(attrs)})" + + +import ast +from itertools import zip_longest +from typing import List, Union + + +def compare_ast( + node1: Union[ast.AST, List[ast.AST]], node2: Union[ast.AST, List[ast.AST]] +) -> bool: + if type(node1) is not type(node2): + return False - return f"{repr(type(v))}({', '.join(attrs)})" + if isinstance(node1, ast.AST): + for k, v in vars(node1).items(): + if k in ("lineno", "end_lineno", "col_offset", "end_col_offset"): + continue + if not compare_ast(v, getattr(node2, k)): + return False + return True + + elif isinstance(node1, list) and isinstance(node2, list): + return all(compare_ast(n1, n2) for n1, n2 in zip_longest(node1, node2)) + + else: + return node1 == node2 + + +@customize(ast.AST) +class CustomizeAst: + @staticmethod + def repr(node: ast.AST): + return ast.dump(node) + + @staticmethod + def equal_to(node1: ast.AST, node2: ast.AST): + return compare_ast(node1, node2) try: @@ -130,14 +205,16 @@ def _(v: IsDataclass): pass else: - @register_repr - def _(model: BaseModel): - return ( - type(model).__qualname__ - + "(" - + ", ".join( - e + "=" + repr(getattr(model, e)) - for e in sorted(model.__pydantic_fields_set__) + @customize(BaseModel) + class CustomizePydanticBaseModel: + @staticmethod + def repr(model: BaseModel): + return ( + type(model).__qualname__ + + "(" + + ", ".join( + e + "=" + repr(getattr(model, e)) + for e in sorted(model.__pydantic_fields_set__) + ) + + ")" ) - + ")" - ) diff --git a/tests/conftest.py b/tests/conftest.py index 2eca13e..74cf09d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -import ast import os import platform import re @@ -19,7 +18,7 @@ import inline_snapshot._external from .utils import snapshot_env from inline_snapshot import _inline_snapshot -from inline_snapshot import register_repr +from inline_snapshot import customize from inline_snapshot._format import format_code from inline_snapshot._inline_snapshot import Flags from inline_snapshot._rewrite_code import ChangeRecorder @@ -32,14 +31,11 @@ black.files.find_project_root = black.files.find_project_root.__wrapped__ # type: ignore -@register_repr -def _(v: executing.Source): - return f"" - - -@register_repr -def _(v: ast.AST): - return repr(ast.dump(v)) +@customize(executing.Source) +class _: + @staticmethod + def repr(v: executing.Source): + return f"" @pytest.fixture() diff --git a/tests/example.py b/tests/example.py index bc9f4b2..796e1ae 100644 --- a/tests/example.py +++ b/tests/example.py @@ -95,8 +95,8 @@ def run_inline( if reported_flags is not None: assert sorted(snapshot_flags) == reported_flags - # if changes is not None: - # assert all_changes == changes + if changes is not None: + assert all_changes == changes recorder.fix_all() diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 1cef7fe..c2f37e2 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -1,3 +1,9 @@ +from ast import Add +from ast import BinOp +from ast import Constant +from ast import Expr +from ast import Module + from .example import Example from inline_snapshot import snapshot @@ -248,3 +254,48 @@ class Color(Enum): } ), ).run_inline() + + +def test_equal_to(): + import ast + from inline_snapshot import EqualTo + from inline_snapshot._code_repr import code_repr + + assert ast.Constant(value=5) == EqualTo(ast.Constant(value=5)) + assert ast.Constant(value=5) == EqualTo(ast.Constant(value=5, lineno=5)) + + assert EqualTo(ast.Constant(value=5)) == EqualTo(ast.Constant(value=5)) + assert not EqualTo(ast.Constant(value=5)) == EqualTo( + ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=1)) + ) + + assert not EqualTo( + ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=1)) + ) == EqualTo( + ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=2)) + ) + + assert not EqualTo( + ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=1)) + ) == EqualTo( + ast.BinOp(left=ast.Constant(value=1), op=ast.Sub(), right=ast.Constant(value=1)) + ) + + assert code_repr(ast.Constant(value=5)) == snapshot("EqualTo(Constant(value=5))") + + assert EqualTo(ast.Constant(value=5)) == snapshot(EqualTo(Constant(value=5))) + + assert ast.parse("1+1") == snapshot( + EqualTo( + Module( + body=[ + Expr( + value=BinOp( + left=Constant(value=1), op=Add(), right=Constant(value=1) + ) + ) + ], + type_ignores=[], + ) + ) + ) diff --git a/tests/test_example.py b/tests/test_example.py index dbcd93e..68ab9ad 100644 --- a/tests/test_example.py +++ b/tests/test_example.py @@ -1,5 +1,12 @@ +from ast import Constant + +from executing import Source + from .example import Example +from inline_snapshot import EqualTo +from inline_snapshot import HasRepr from inline_snapshot import snapshot +from inline_snapshot._change import Replace def test_example(): @@ -23,6 +30,18 @@ def test_a(): e.run_inline( "fix", reported_flags=snapshot(["fix"]), + changes=snapshot( + [ + Replace( + flag="fix", + source=HasRepr(Source, ""), + node=EqualTo(Constant(value=2)), + new_code="1", + old_value=2, + new_value=1, + ) + ] + ), ).run_inline( "fix", changed_files=snapshot({}),