From 017c59b2373705059b951b6a22f8f6dd787c851a Mon Sep 17 00:00:00 2001 From: Brendan Cooley Date: Fri, 3 Nov 2023 10:07:19 -0400 Subject: [PATCH] check: annotated dtypes match those specified in Field.dtype --- src/patito/pydantic.py | 11 ++++++++--- tests/test_model.py | 19 +++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py index 6e63e75..23123a8 100644 --- a/src/patito/pydantic.py +++ b/src/patito/pydantic.py @@ -220,6 +220,8 @@ def _valid_dtypes( # noqa: C901 return [pl.List(dtype) for dtype in item_dtypes] if "dtype" in props and 'anyOf' not in props: + if props['dtype'] not in cls._pydantic_type_to_valid_polars_types(props): # TODO should we allow pl floats for integer columns? Other type hierarchies to consider? + raise ValueError(f"Invalid dtype {props['dtype']} for column '{column}'. Check that specified dtype is allowable for the given type annotations.") return [ props["dtype"], ] @@ -237,7 +239,12 @@ def _valid_dtypes( # noqa: C901 column, {"type": PYTHON_TO_PYDANTIC_TYPES.get(type(props["const"]))} ) return None - elif props["type"] == "integer": + + return cls._pydantic_type_to_valid_polars_types(props) + + @staticmethod + def _pydantic_type_to_valid_polars_types(props: Dict) -> Optional[List[pl.DataType]]: + if props["type"] == "integer": return [ pl.Int64, pl.Int32, @@ -272,8 +279,6 @@ def _valid_dtypes( # noqa: C901 return None # pragma: no cover elif props["type"] == "null": return [pl.Null] - else: # pragma: no cover - return None @property def valid_sql_types( # type: ignore # noqa: C901 diff --git a/tests/test_model.py b/tests/test_model.py index 6ae74b6..c528a66 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -462,3 +462,22 @@ class Test(pt.Model): assert Test.nullable_columns == {"foo"} assert set(Test.valid_dtypes['foo']) == {pl.Utf8, pl.Null} + + +def test_conflicting_type_dtype(): + + class Test(pt.Model): + foo: int = pt.Field(dtype=pl.Utf8) + + with pytest.raises(ValueError): + Test.valid_dtypes() + + class Test(pt.Model): + foo: str = pt.Field(dtype=pl.Float32) + + with pytest.raises(ValueError): + Test.valid_dtypes() + + +if __name__ == "__main__": + test_conflicting_type_dtype()