Skip to content

Commit

Permalink
refactor: moved and renamed some things
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed May 20, 2024
1 parent 067456e commit 626ed94
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 27 deletions.
40 changes: 21 additions & 19 deletions inline_snapshot/_pandas.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
from functools import wraps
from typing import Optional

from pandas import DataFrame
from pandas import Index
from pandas import Series
from pandas.testing import assert_frame_equal as real_assert_frame_equal
from pandas.testing import assert_index_equal as real_assert_index_equal
from pandas.testing import assert_series_equal as real_assert_series_equal


def make_assert_equals(data_type, assert_equals, repr_function):
def make_assert_equal(data_type, assert_equal, repr_function):

class Wrapper:
def __init__(self, df, cmp):
Expand All @@ -24,14 +17,14 @@ def __eq__(self, other):
return NotImplemented
return self.cmp(self.df, other)

@wraps(assert_equals)
@wraps(assert_equal)
def result(df, df_snapshot, *args, **kargs):
error: Optional[AssertionError] = None

def cmp(a, b):
nonlocal error
try:
assert_equals(a, b, *args, **kargs)
assert_equal(a, b, *args, **kargs)
except AssertionError as e:
error = e
return False
Expand All @@ -44,12 +37,21 @@ def cmp(a, b):
return result


assert_frame_equal = make_assert_equals(
DataFrame, real_assert_frame_equal, lambda df: df.to_dict("records")
)
assert_series_equal = make_assert_equals(
Series, real_assert_series_equal, lambda df: df.to_dict()
)
assert_index_equal = make_assert_equals(
Index, real_assert_index_equal, lambda df: df.to_list()
)
try:
import pandas
except:
pass
else:
from pandas.testing import assert_frame_equal
from pandas.testing import assert_index_equal
from pandas.testing import assert_series_equal

assert_frame_equal = make_assert_equal(
pandas.DataFrame, assert_frame_equal, lambda df: df.to_dict("records")
)
assert_series_equal = make_assert_equal(
pandas.Series, assert_series_equal, lambda df: df.to_dict()
)
assert_index_equal = make_assert_equal(
pandas.Index, assert_index_equal, lambda df: df.to_list()
)
16 changes: 8 additions & 8 deletions tests/test_pandas.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import sys

import pytest
from pandas import DataFrame
from pandas import Index
from pandas import Series

from inline_snapshot import snapshot
from inline_snapshot._pandas import assert_frame_equal
from inline_snapshot._pandas import assert_index_equal
from inline_snapshot._pandas import assert_series_equal
if sys.version_info >= (3, 9):
from pandas import DataFrame
from pandas import Index
from pandas import Series

nan = float("nan")
from inline_snapshot import snapshot
from inline_snapshot._pandas import assert_frame_equal
from inline_snapshot._pandas import assert_index_equal
from inline_snapshot._pandas import assert_series_equal


@pytest.mark.skipif(sys.version_info < (3, 9), reason="no pandas for 3.9")
Expand Down

0 comments on commit 626ed94

Please sign in to comment.