Skip to content

Commit

Permalink
feat: customize repr (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed May 26, 2024
1 parent c78d21c commit 5ed02b7
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 4 deletions.
3 changes: 2 additions & 1 deletion inline_snapshot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ._code_repr import register_repr
from ._external import external
from ._external import outsource
from ._inline_snapshot import snapshot

__all__ = ["snapshot", "external", "outsource"]
__all__ = ["snapshot", "external", "outsource", "register_repr"]

__version__ = "0.10.0"
90 changes: 90 additions & 0 deletions inline_snapshot/_code_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import ast
from enum import Enum
from enum import Flag
from functools import singledispatch
from unittest import mock

real_repr = repr


class HasRepr:
"""This class is used for objects where `__repr__()` returns an non
parsable representation."""

def __init__(self, str_repr: str) -> None:
self._str_repr = str_repr

def __repr__(self):
return f"HasRepr({self._str_repr!r})"


@singledispatch
def code_repr_dispatch(v):
return real_repr(v)


register_repr = code_repr_dispatch.register


def code_repr(obj):
with mock.patch("builtins.repr", code_repr_dispatch):
result = code_repr_dispatch(obj)

try:
ast.parse(result)
except SyntaxError:
return real_repr(HasRepr(result))

return result


@register_repr
def _(v: Enum):
return str(v)


@register_repr
def _(v: Flag):
name = type(v).__name__
return " | ".join(str(flag) for flag in type(v) if flag in v)


@register_repr
def _(v: list):
return "[" + ", ".join(map(repr, v)) + "]"


@register_repr
def _(v: set):
if len(v) == 0:
return "set()"

return "{" + ", ".join(map(repr, v)) + "}"


@register_repr
def _(v: dict):
return "{" + ", ".join(f"{repr(k)}:{repr(value)}" for k, value in v.items()) + "}"


@register_repr
def _(v: type):
return v.__qualname__


try:
from pydantic import BaseModel
except ImportError:
pass
else:

@register_repr
def _(model: BaseModel):
return (
type(model).__name__
+ "("
+ ", ".join(
e + "=" + repr(getattr(model, e)) for e in model.__pydantic_fields_set__
)
+ ")"
)
4 changes: 3 additions & 1 deletion inline_snapshot/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import tokenize
from collections import namedtuple

from ._code_repr import code_repr


def normalize_strings(token_sequence):
"""Normalize string concattenanion.
Expand Down Expand Up @@ -118,7 +120,7 @@ def __eq__(self, other):


def value_to_token(value):
input = io.StringIO(repr(value))
input = io.StringIO(code_repr(value))

def map_string(tok):
"""Convert strings with newlines in triple quoted strings."""
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run(self, *flags):
error = False

try:
exec(compile(filename.read_text("utf-8"), filename, "exec"))
exec(compile(filename.read_text("utf-8"), filename, "exec"), {})
except AssertionError:
traceback.print_exc()
error = True
Expand Down
92 changes: 92 additions & 0 deletions tests/test_code_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from inline_snapshot import snapshot


def test_enum(check_update):

assert (
check_update(
"""
from enum import Enum
class color(Enum):
val="val"
assert [color.val] == snapshot()
""",
flags="create",
)
== snapshot(
"""\
from enum import Enum
class color(Enum):
val="val"
assert [color.val] == snapshot([color.val])
"""
)
)


def test_flag(check_update):

assert (
check_update(
"""
from enum import Flag, auto
class Color(Flag):
red = auto()
green = auto()
blue = auto()
assert Color.red | Color.blue == snapshot()
""",
flags="create",
)
== snapshot(
"""\
from enum import Flag, auto
class Color(Flag):
red = auto()
green = auto()
blue = auto()
assert Color.red | Color.blue == snapshot(Color.red | Color.blue)
"""
)
)


def test_type(check_update):

assert (
check_update(
"""\
class Color:
pass
assert [Color,int] == snapshot()
""",
flags="create",
)
== snapshot(
"""\
class Color:
pass
assert [Color,int] == snapshot([Color, int])
"""
)
)
2 changes: 1 addition & 1 deletion tests/test_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ class Thing:
def __repr__(self):
return "+++"
assert Thing() == snapshot()
assert Thing() == snapshot(HasRepr("+++"))
"""
)
)
Expand Down

0 comments on commit 5ed02b7

Please sign in to comment.