diff --git a/src/patito/exceptions.py b/src/patito/exceptions.py index 7240c55..1a816e6 100644 --- a/src/patito/exceptions.py +++ b/src/patito/exceptions.py @@ -127,7 +127,7 @@ def __repr_args__(self) -> "ReprArgs": class DataFrameValidationError(Representation, ValueError): __slots__ = "raw_errors", "model", "_error_cache" - def __init__(self, errors: Sequence[ErrorList], model: "BaseModel") -> None: + def __init__(self, errors: Sequence[ErrorList], model: Type["BaseModel"]) -> None: self.raw_errors = errors self.model = model self._error_cache: Optional[List["ErrorDict"]] = None diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py index 23123a8..86cd75a 100644 --- a/src/patito/pydantic.py +++ b/src/patito/pydantic.py @@ -126,7 +126,7 @@ def columns(cls: Type[ModelType]) -> List[str]: # type: ignore @property def dtypes( # type: ignore cls: Type[ModelType], # pyright: ignore - ) -> dict[str, Type[pl.DataType]]: + ) -> dict[str, PolarsDataType]: """ Return the polars dtypes of the dataframe. @@ -153,7 +153,7 @@ def dtypes( # type: ignore @property def valid_dtypes( # type: ignore cls: Type[ModelType], # pyright: ignore - ) -> dict[str, List[Union[pl.PolarsDataType, pl.List]]]: + ) -> dict[str, List[Union[PolarsDataType, pl.List]]]: """ Return a list of polars dtypes which Patito considers valid for each field. @@ -197,10 +197,10 @@ def valid_dtypes( # type: ignore @classmethod def _valid_dtypes( # noqa: C901 - cls: Type[ModelType], + cls: Type[ModelType], # pyright: ignore column: str, props: Dict, - ) -> Optional[List[pl.PolarsDataType]]: + ) -> Optional[List[PolarsDataType]]: """ Map schema property to list of valid polars data types. @@ -218,10 +218,14 @@ def _valid_dtypes( # noqa: C901 f"No valid dtype mapping found for column '{column}'." ) return [pl.List(dtype) for dtype in item_dtypes] - - if "dtype" in props and 'anyOf' not in props: - if props['dtype'] not in cls._pydantic_type_to_valid_polars_types(props): # TODO should we allow pl floats for integer columns? Other type hierarchies to consider? - raise ValueError(f"Invalid dtype {props['dtype']} for column '{column}'. Check that specified dtype is allowable for the given type annotations.") + + if "dtype" in props and "anyOf" not in props: + if props["dtype"] not in cls._pydantic_type_to_valid_polars_types( + props + ): # TODO should we allow pl floats for integer columns? Other type hierarchies to consider? + raise ValueError( + f"Invalid dtype {props['dtype']} for column '{column}'. Check that specified dtype is allowable for the given type annotations." + ) return [ props["dtype"], ] @@ -239,11 +243,13 @@ def _valid_dtypes( # noqa: C901 column, {"type": PYTHON_TO_PYDANTIC_TYPES.get(type(props["const"]))} ) return None - + return cls._pydantic_type_to_valid_polars_types(props) @staticmethod - def _pydantic_type_to_valid_polars_types(props: Dict) -> Optional[List[pl.DataType]]: + def _pydantic_type_to_valid_polars_types( + props: Dict, + ) -> Optional[List[PolarsDataType]]: if props["type"] == "integer": return [ pl.Int64, @@ -574,6 +580,9 @@ class Model(BaseModel, metaclass=ModelMetaclass): defaults: ClassVar[Dict[str, Any]] + if TYPE_CHECKING: + model_fields: ClassVar[dict[str, FieldInfo]] + @classmethod # type: ignore[misc] @property def DataFrame( @@ -786,6 +795,8 @@ def example_value( # noqa: C901 field_type = "null" else: field_type = allowable[0] + else: + raise NotImplementedError if "const" in properties: # The default value is the only valid value, provided as const return properties["const"] @@ -1461,8 +1472,13 @@ def _derive_field( if x in field.__slots__ and x not in ["annotation", "default"] } if make_nullable: - # This originally non-nullable field has become nullable - field_type = Optional[field_type] + if field_type is None: + raise TypeError( + "Cannot make field nullable if no type annotation is provided!" + ) + else: + # This originally non-nullable field has become nullable + field_type = Optional[field_type] elif field.is_required() and default is None: # We need to replace Pydantic's None default value with ... in order # to make it clear that the field is still non-nullable and @@ -1507,10 +1523,10 @@ def __init__( ) -def Field( +def Field( # noqa: C901 *args, **kwargs, -): +) -> Any: pt_kwargs = {k: kwargs.pop(k, None) for k in get_args(PT_INFO)} meta_kwargs = { k: v for k, v in kwargs.items() if k in fields.FieldInfo.metadata_lookup diff --git a/tests/test_model.py b/tests/test_model.py index c528a66..e075342 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -461,23 +461,18 @@ class Test(pt.Model): foo: str | None = pt.Field(dtype=pl.Utf8) assert Test.nullable_columns == {"foo"} - assert set(Test.valid_dtypes['foo']) == {pl.Utf8, pl.Null} + assert set(Test.valid_dtypes["foo"]) == {pl.Utf8, pl.Null} def test_conflicting_type_dtype(): - - class Test(pt.Model): + class Test1(pt.Model): foo: int = pt.Field(dtype=pl.Utf8) with pytest.raises(ValueError): - Test.valid_dtypes() - - class Test(pt.Model): + Test1.valid_dtypes + + class Test2(pt.Model): foo: str = pt.Field(dtype=pl.Float32) - - with pytest.raises(ValueError): - Test.valid_dtypes() - -if __name__ == "__main__": - test_conflicting_type_dtype() + with pytest.raises(ValueError): + Test2.valid_dtypes diff --git a/tests/test_polars.py b/tests/test_polars.py index b156abe..c288c0e 100644 --- a/tests/test_polars.py +++ b/tests/test_polars.py @@ -253,5 +253,3 @@ class Model(pt.Model): # Or a list of columns assert df.drop(["column_1", "column_2"]).columns == [] - -