Skip to content

Commit

Permalink
feat: customize repr (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed May 3, 2024
1 parent 9060935 commit 7d47ec9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
3 changes: 2 additions & 1 deletion inline_snapshot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ._code_repr import register_repr
from ._external import external
from ._external import outsource
from ._inline_snapshot import snapshot

__all__ = ["snapshot", "external", "outsource"]
__all__ = ["snapshot", "external", "outsource", "register_repr"]

__version__ = "0.8.2"
59 changes: 59 additions & 0 deletions inline_snapshot/_code_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from enum import Enum
from functools import singledispatch
from unittest import mock

real_repr = repr


@singledispatch
def code_repr_dispatch(v):
return real_repr(v)


register_repr = code_repr_dispatch.register


def code_repr(obj):
with mock.patch("builtins.repr", code_repr_dispatch):
return code_repr_dispatch(obj)


@register_repr
def _(v: Enum):
return type(v).__name__ + "." + v._name_


@register_repr
def _(v: list):
return "[" + ", ".join(map(repr, v)) + "]"


@register_repr
def _(v: set):
if len(v) == 0:
return "set()"

return "{" + ", ".join(map(repr, v)) + "}"


@register_repr
def _(v: dict):
return "{" + ", ".join(f"{repr(k)}:{repr(value)}" for k, value in v.items()) + "}"


try:
from pydantic import BaseModel
except ImportError:
pass
else:

@register_repr
def _(model: BaseModel):
return (
type(model).__name__
+ "("
+ ", ".join(
e + "=" + repr(getattr(model, e)) for e in model.__pydantic_fields_set__
)
+ ")"
)
4 changes: 3 additions & 1 deletion inline_snapshot/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import tokenize
from collections import namedtuple

from ._code_repr import code_repr


def normalize_strings(token_sequence):
"""Normalize string concattenanion.
Expand Down Expand Up @@ -102,7 +104,7 @@ def __eq__(self, other):


def value_to_token(value):
input = io.StringIO(repr(value))
input = io.StringIO(code_repr(value))

def map_string(tok):
"""Convert strings with newlines in triple quoted strings."""
Expand Down

0 comments on commit 7d47ec9

Please sign in to comment.