diff --git a/inline_snapshot/_code_repr.py b/inline_snapshot/_code_repr.py index 806c243..5f93703 100644 --- a/inline_snapshot/_code_repr.py +++ b/inline_snapshot/_code_repr.py @@ -19,6 +19,13 @@ def __repr__(self): return f"HasRepr({self._type.__qualname__}, {self._str_repr!r})" def __eq__(self, other): + if isinstance(other, HasRepr): + if other._type is not self._type: + return False + else: + if type(other) is not self._type: + return False + other_repr = code_repr(other) return other_repr == self._str_repr or other_repr == repr(self) diff --git a/tests/conftest.py b/tests/conftest.py index 2eca13e..24b762a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ -import ast import os import platform import re @@ -19,7 +18,6 @@ import inline_snapshot._external from .utils import snapshot_env from inline_snapshot import _inline_snapshot -from inline_snapshot import register_repr from inline_snapshot._format import format_code from inline_snapshot._inline_snapshot import Flags from inline_snapshot._rewrite_code import ChangeRecorder @@ -32,16 +30,6 @@ black.files.find_project_root = black.files.find_project_root.__wrapped__ # type: ignore -@register_repr -def _(v: executing.Source): - return f"" - - -@register_repr -def _(v: ast.AST): - return repr(ast.dump(v)) - - @pytest.fixture() def check_update(source): def w(source_code, *, flags="", reported_flags=None, number=1): diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index 1cef7fe..182a4a1 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -1,4 +1,5 @@ from .example import Example +from inline_snapshot import HasRepr from inline_snapshot import snapshot @@ -34,7 +35,7 @@ class color(Enum): ) -def test_hasrepr(): +def test_snapshot_generates_hasrepr(): Example( """\ @@ -76,6 +77,13 @@ def test_thing(): ) +def test_hasrepr_type(): + assert 5 == HasRepr(int, "5") + assert not "a" == HasRepr(int, "5") + assert not HasRepr(float, "nan") == HasRepr(str, "nan") + assert not HasRepr(str, "a") == HasRepr(str, "b") + + def test_enum_in_dataclass(check_update): assert (