Skip to content

Commit

Permalink
check: annotated dtypes match those specified in Field.dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
brendancooley committed Nov 3, 2023
1 parent e2bf0d7 commit 017c59b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/patito/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def _valid_dtypes( # noqa: C901
return [pl.List(dtype) for dtype in item_dtypes]

if "dtype" in props and 'anyOf' not in props:
if props['dtype'] not in cls._pydantic_type_to_valid_polars_types(props): # TODO should we allow pl floats for integer columns? Other type hierarchies to consider?
raise ValueError(f"Invalid dtype {props['dtype']} for column '{column}'. Check that specified dtype is allowable for the given type annotations.")
return [
props["dtype"],
]
Expand All @@ -237,7 +239,12 @@ def _valid_dtypes( # noqa: C901
column, {"type": PYTHON_TO_PYDANTIC_TYPES.get(type(props["const"]))}
)
return None
elif props["type"] == "integer":

return cls._pydantic_type_to_valid_polars_types(props)

@staticmethod
def _pydantic_type_to_valid_polars_types(props: Dict) -> Optional[List[pl.DataType]]:
if props["type"] == "integer":
return [
pl.Int64,
pl.Int32,
Expand Down Expand Up @@ -272,8 +279,6 @@ def _valid_dtypes( # noqa: C901
return None # pragma: no cover
elif props["type"] == "null":
return [pl.Null]
else: # pragma: no cover
return None

@property
def valid_sql_types( # type: ignore # noqa: C901
Expand Down
19 changes: 19 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,22 @@ class Test(pt.Model):

assert Test.nullable_columns == {"foo"}
assert set(Test.valid_dtypes['foo']) == {pl.Utf8, pl.Null}


def test_conflicting_type_dtype():

class Test(pt.Model):
foo: int = pt.Field(dtype=pl.Utf8)

with pytest.raises(ValueError):
Test.valid_dtypes()

class Test(pt.Model):
foo: str = pt.Field(dtype=pl.Float32)

with pytest.raises(ValueError):
Test.valid_dtypes()


if __name__ == "__main__":
test_conflicting_type_dtype()

0 comments on commit 017c59b

Please sign in to comment.