diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 25f76d9..d6970a6 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -3,6 +3,7 @@ import datetime as dt from collections.abc import Mapping +from typing import cast import polars as pl from polars.datatypes import DataType, DataTypeClass @@ -206,12 +207,7 @@ def _compare_sequence_columns( n_elements = dtype_right.shape[0] has_same_length = col_left.list.len().eq(pl.lit(n_elements)) else: # pl.List vs pl.List - if not isinstance(max_list_length, int): - # Fallback for nested list comparisons where no max_list_length is - # available: perform a direct equality comparison without element-wise - # unrolling. - return _eq_missing(col_left.eq_missing(col_right), col_left, col_right) - n_elements = max_list_length + n_elements = cast(int, max_list_length) has_same_length = col_left.list.len().eq_missing(col_right.list.len()) if n_elements == 0: @@ -232,7 +228,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, - max_list_length=None, + max_list_length=max_list_length, ) for i in range(n_elements) ] diff --git a/diffly/comparison.py b/diffly/comparison.py index ee46e55..d4697fb 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str] @cached_property def _max_list_lengths_by_column(self) -> dict[str, int]: - list_columns = [ - col - for col in self._other_common_columns - if isinstance(self.left_schema[col], pl.List) - and isinstance(self.right_schema[col], pl.List) - ] - if not list_columns: + """Max list length across all nesting levels, for columns where both sides + contain a List anywhere in their type tree.""" + left_exprs: list[pl.Expr] = [] + right_exprs: list[pl.Expr] = [] + columns: list[str] = [] + + for col in self._other_common_columns: + col_left = _list_length_exprs(pl.col(col), self.left_schema[col]) + col_right = _list_length_exprs(pl.col(col), self.right_schema[col]) + if not (col_left and col_right): + continue + columns.append(col) + left_exprs.append(pl.max_horizontal(col_left).alias(col)) + right_exprs.append(pl.max_horizontal(col_right).alias(col)) + + if not columns: return {} - exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns] [left_max, right_max] = pl.collect_all( - [self.left.select(exprs), self.right.select(exprs)] + [self.left.select(left_exprs), self.right.select(right_exprs)] ) return { col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0)) - for col in list_columns + for col in columns } def _condition_equal_rows(self, columns: list[str]) -> pl.Expr: @@ -833,3 +841,21 @@ def right_only(self) -> Schema: """Columns that are only present in the right data frame, mapped to their data types.""" return self.right() - self.left() + + +def _list_length_exprs( + expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass +) -> list[pl.Expr]: + """Collect max-list-length scalar expressions for every List level in the type + tree.""" + if isinstance(dtype, pl.List): + return [expr.list.len().max(), *_list_length_exprs(expr.explode(), dtype.inner)] + if isinstance(dtype, pl.Array): + return _list_length_exprs(expr.explode(), dtype.inner) + if isinstance(dtype, pl.Struct): + return [ + e + for field in dtype.fields + for e in _list_length_exprs(expr.struct[field.name], field.dtype) + ] + return [] diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 8aaeedb..308d2a3 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -7,6 +7,7 @@ import pytest from diffly._conditions import _can_compare_dtypes, condition_equal_columns +from diffly.comparison import compare_frames def test_condition_equal_columns_struct() -> None: @@ -14,17 +15,20 @@ def test_condition_equal_columns_struct() -> None: lhs = pl.DataFrame( { "pk": [1, 2], - "a_left": [{"x": 1.0, "y": 2.0}, {"x": 2.0, "y": 2.1}], + "a": [{"x": 1.0, "y": 2.0}, {"x": 2.0, "y": 2.1}], } ) rhs = pl.DataFrame( { "pk": [1, 2], - "a_right": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], + "a": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], } ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -32,15 +36,16 @@ def test_condition_equal_columns_struct() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, - abs_tol=0.5, - rel_tol=0, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, False] @@ -49,17 +54,20 @@ def test_condition_equal_columns_different_struct_fields() -> None: lhs = pl.DataFrame( { "pk": [1, 2], - "a_left": [{"x": 1.0, "z": 2.0}, {"x": 2.0, "z": 2.1}], + "a": [{"x": 1.0, "z": 2.0}, {"x": 2.0, "z": 2.1}], } ) rhs = pl.DataFrame( { "pk": [1, 2], - "a_right": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], + "a": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], } ) + c = compare_frames(lhs, rhs, primary_key="pk") # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -67,13 +75,16 @@ def test_condition_equal_columns_different_struct_fields() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False, False] @@ -88,21 +99,18 @@ def test_condition_equal_columns_list_array_with_tolerance( ) -> None: # Arrange lhs = pl.DataFrame( - { - "pk": [1, 2, 3], - "a_left": [[1.0, 1.1], [2.0, 2.1], [3.0, 3.0]], - }, - schema={"pk": pl.Int64, "a_left": lhs_type}, + {"pk": [1, 2, 3], "a": [[1.0, 1.1], [2.0, 2.1], [3.0, 3.0]]}, + schema={"pk": pl.Int64, "a": lhs_type}, ) rhs = pl.DataFrame( - { - "pk": [1, 2, 3], - "a_right": [[1.0, 1.1], [2.0, 2.2], [3.0, 3.7]], - }, - schema={"pk": pl.Int64, "a_right": rhs_type}, + {"pk": [1, 2, 3], "a": [[1.0, 1.1], [2.0, 2.2], [3.0, 3.7]]}, + schema={"pk": pl.Int64, "a": rhs_type}, ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -110,14 +118,19 @@ def test_condition_equal_columns_list_array_with_tolerance( "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - abs_tol=0.5, - rel_tol=0, - max_list_length=2, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + # Assert + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): + assert c._max_list_lengths_by_column == {"a": 2} + else: + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, True, False] @@ -136,27 +149,30 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( lhs = pl.DataFrame( { "pk": [1, 2, 3], - "a_left": [ + "a": [ [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], [[3.0, 3.0, 3.1], [4.0, 4.0, 4.1]], [[5.0, 5.0, 5.1], [6.0, 6.0, 6.1]], ], }, - schema={"pk": pl.Int64, "a_left": lhs_type}, + schema={"pk": pl.Int64, "a": lhs_type}, ) rhs = pl.DataFrame( { "pk": [1, 2, 3], - "a_right": [ + "a": [ [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], [[3.0, 3.0, 3.1], [4.0, 4.4, 4.1]], [[5.0, 5.0, 5.1], [6.0, 6.8, 6.1]], ], }, - schema={"pk": pl.Int64, "a_right": rhs_type}, + schema={"pk": pl.Int64, "a": rhs_type}, ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -164,36 +180,31 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - abs_tol=0.5, - rel_tol=0, - max_list_length=2, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + # Assert if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): - assert actual.to_list() == [True, False, False] + assert c._max_list_lengths_by_column == {"a": 3} else: - assert actual.to_list() == [True, True, False] + assert c._max_list_lengths_by_column == {} + assert actual.to_list() == [True, True, False] def test_condition_equal_columns_nested_dtype_mismatch() -> None: # Arrange - lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [{"x": 1}, {"x": 2}], - }, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [[1.0, 1.1], [2.0, 2.2]], - }, - ) + lhs = pl.DataFrame({"pk": [1, 2], "a": [{"x": 1}, {"x": 2}]}) + rhs = pl.DataFrame({"pk": [1, 2], "a": [[1.0, 1.1], [2.0, 2.2]]}) + c = compare_frames(lhs, rhs, primary_key="pk") # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -201,32 +212,28 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False, False] def test_condition_equal_columns_exactly_one_nested() -> None: # Arrange - lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [{"x": 1}, {"x": 2}], - }, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [1, 2], - }, - ) + lhs = pl.DataFrame({"pk": [1, 2], "a": [{"x": 1}, {"x": 2}]}) + rhs = pl.DataFrame({"pk": [1, 2], "a": [1, 2]}) + c = compare_frames(lhs, rhs, primary_key="pk") # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -234,13 +241,16 @@ def test_condition_equal_columns_exactly_one_nested() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False, False] @@ -249,7 +259,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None: lhs = pl.DataFrame( { "pk": [1, 2, 3, 4], - "a_left": [ + "a": [ dt.datetime(2025, 1, 1, 9, 0, 0), dt.datetime(2025, 1, 1, 10, 0, 0), None, @@ -260,7 +270,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None: rhs = pl.DataFrame( { "pk": [1, 2, 3, 4], - "a_right": [ + "a": [ dt.datetime(2025, 1, 1, 9, 0, 1), dt.datetime(2025, 1, 1, 10, 0, 5), dt.datetime(2025, 1, 1, 10, 0, 0), @@ -268,8 +278,13 @@ def test_condition_equal_columns_temporal_tolerance() -> None: ], }, ) + c = compare_frames( + lhs, rhs, primary_key="pk", abs_tol_temporal=dt.timedelta(seconds=2) + ) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -277,31 +292,36 @@ def test_condition_equal_columns_temporal_tolerance() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, - abs_tol_temporal=dt.timedelta(seconds=2), + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], + abs_tol_temporal=c.abs_tol_temporal_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, False, False, True] def test_condition_equal_columns_two_lists() -> None: + # Arrange lhs = pl.DataFrame( - { - "pk": [1, 2, 3, 4, 5], - "a_left": [[1.0, 2.0], [3.0], [5.0, None], None, None], - }, + {"pk": [1, 2, 3, 4, 5], "a": [[1.0, 2.0], [3.0], [5.0, None], None, None]}, ) rhs = pl.DataFrame( { "pk": [1, 2, 3, 4, 5], - "a_right": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0], None], + "a": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0], None], }, ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -309,31 +329,31 @@ def test_condition_equal_columns_two_lists() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - abs_tol=0.5, - rel_tol=0, - max_list_length=2, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + + # Assert + assert c._max_list_lengths_by_column == {"a": 2} assert actual.to_list() == [True, False, False, False, True] def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: + # Arrange lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [[1.0, 2.0], [3.0, 4.0]], - }, - schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [[1.0, 2.0], [3.0]], - }, + {"pk": [1, 2], "a": [[1.0, 2.0], [3.0, 4.0]]}, + schema={"pk": pl.Int64, "a": pl.Array(pl.Float64, shape=2)}, ) + rhs = pl.DataFrame({"pk": [1, 2], "a": [[1.0, 2.0], [3.0]]}) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -341,32 +361,34 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, - abs_tol=0.5, - rel_tol=0, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + + # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, False] def test_condition_equal_columns_two_arrays_different_shapes() -> None: + # Arrange lhs = pl.DataFrame( - { - "pk": [1], - "a_left": [[1.0, 2.0]], - }, - schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, + {"pk": [1], "a": [[1.0, 2.0]]}, + schema={"pk": pl.Int64, "a": pl.Array(pl.Float64, shape=2)}, ) rhs = pl.DataFrame( - { - "pk": [1], - "a_right": [[1.0, 2.0, 3.0]], - }, - schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=3)}, + {"pk": [1], "a": [[1.0, 2.0, 3.0]]}, + schema={"pk": pl.Int64, "a": pl.Array(pl.Float64, shape=3)}, ) + c = compare_frames(lhs, rhs, primary_key="pk") + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -374,11 +396,16 @@ def test_condition_equal_columns_two_arrays_different_shapes() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + + # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False] @@ -391,21 +418,80 @@ def test_condition_equal_columns_two_arrays_different_shapes() -> None: def test_condition_equal_columns_empty_list_array( lhs_type: pl.DataType, rhs_type: pl.DataType ) -> None: + # Arrange + lhs = pl.DataFrame( + {"pk": [1, 2], "a": [[], None]}, + schema={"pk": pl.Int64, "a": lhs_type}, + ) + rhs = pl.DataFrame( + {"pk": [1, 2], "a": [[], None]}, + schema={"pk": pl.Int64, "a": rhs_type}, + ) + c = compare_frames(lhs, rhs, primary_key="pk") + + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], + ) + ) + .to_series() + ) + + # Assert + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): + assert c._max_list_lengths_by_column == {"a": 0} + else: + assert c._max_list_lengths_by_column == {} + assert actual.to_list() == [True, True] + + +def test_condition_equal_columns_lists_only_inner() -> None: + # Arrange lhs = pl.DataFrame( { "pk": [1, 2], - "a_left": [[], None], + "a": [ + { + "x": 1, + "y": [1.0, 2.0, 3.0], + }, + { + "x": 2, + "y": [4.0, 5.0, 6.0], + }, + ], }, - schema={"pk": pl.Int64, "a_left": lhs_type}, ) rhs = pl.DataFrame( { "pk": [1, 2], - "a_right": [[], None], + "a": [ + { + "x": 1, + "y": [1.0, 2.1, 3.0], + }, + { + "x": 2, + "y": [4.0, 5.3, 6.0], + }, + ], }, - schema={"pk": pl.Int64, "a_right": rhs_type}, ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.2, rel_tol=0) + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -413,12 +499,17 @@ def test_condition_equal_columns_empty_list_array( "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) - assert actual.to_list() == [True, True] + + # Assert + assert c._max_list_lengths_by_column == {"a": 3} + assert actual.to_list() == [True, False] @pytest.mark.parametrize(