Skip to content

Commit

Permalink
fix: search for nested field constraints on validation, ignore nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
brendancooley committed Nov 3, 2023
1 parent 161300b commit 9e132bc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
26 changes: 19 additions & 7 deletions src/patito/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,22 +244,34 @@ def _find_errors( # noqa: C901
"minLength": lambda v: col.str.len_chars() >= v,
"maxLength": lambda v: col.str.len_chars() <= v,
}
checks = [
if "anyOf" in column_properties:
checks = [
check(x[key])
for key, check in filters.items()
for x in column_properties["anyOf"]
if key in x
]
else:
checks = []
checks += [
check(column_properties[key])
for key, check in filters.items()
if key in column_properties
]
if checks:
lazy_df = dataframe.lazy()
n_invalid_rows = 0
for check in checks:
lazy_df = lazy_df.filter(check)
valid_rows = lazy_df.collect()
invalid_rows = dataframe.height - valid_rows.height
if invalid_rows > 0:
lazy_df = dataframe.lazy()
lazy_df = lazy_df.filter(
~check
) # get failing rows (nulls will evaluate to null on boolean check, we only want failures (false)))
invalid_rows = lazy_df.collect()
n_invalid_rows += invalid_rows.height
if n_invalid_rows > 0:
errors.append(
ErrorWrapper(
RowValueError(
f"{invalid_rows} row{'' if invalid_rows == 1 else 's'} "
f"{n_invalid_rows} row{'' if n_invalid_rows == 1 else 's'} "
"with out of bound values."
),
loc=column_name,
Expand Down
21 changes: 21 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from datetime import date, datetime
from typing import List, Optional, Union, Literal
import re

import polars as pl
import pytest
Expand Down Expand Up @@ -600,3 +601,23 @@ class ListModel(pt.Model):
# print(old, new)
with pytest.raises(DataFrameValidationError):
ListModel.validate(valid_df.with_columns(pl.col(old).alias(new)))


def test_nested_field_attrs():
"""ensure that constraints are respected even when embedded inside 'anyOf'"""

class Test(pt.Model):
foo: int | None = pt.Field(
dtype=pl.Int64, ge=0, le=100, constraints=pt.field.sum() == 100
)

test_df = Test.DataFrame(
{"foo": [110, -10]}
) # meets constraint, but violates bounds (embedded in 'anyOf' in properties)
with pytest.raises(DataFrameValidationError) as e:
Test.validate(test_df)
pattern = re.compile(r"2 rows with out of bound values")
assert len(pattern.findall(str(e.value))) == 1

null_test_df = Test.DataFrame({"foo": [100, None, None]})
Test.validate(null_test_df) # should not raise

0 comments on commit 9e132bc

Please sign in to comment.