Skip to content

Commit

Permalink
chore: misc typing improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
brendancooley committed Nov 3, 2023
1 parent 017c59b commit 161300b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/patito/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __repr_args__(self) -> "ReprArgs":
class DataFrameValidationError(Representation, ValueError):
__slots__ = "raw_errors", "model", "_error_cache"

def __init__(self, errors: Sequence[ErrorList], model: "BaseModel") -> None:
def __init__(self, errors: Sequence[ErrorList], model: Type["BaseModel"]) -> None:
self.raw_errors = errors
self.model = model
self._error_cache: Optional[List["ErrorDict"]] = None
Expand Down
44 changes: 30 additions & 14 deletions src/patito/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def columns(cls: Type[ModelType]) -> List[str]: # type: ignore
@property
def dtypes( # type: ignore
cls: Type[ModelType], # pyright: ignore
) -> dict[str, Type[pl.DataType]]:
) -> dict[str, PolarsDataType]:
"""
Return the polars dtypes of the dataframe.
Expand All @@ -153,7 +153,7 @@ def dtypes( # type: ignore
@property
def valid_dtypes( # type: ignore
cls: Type[ModelType], # pyright: ignore
) -> dict[str, List[Union[pl.PolarsDataType, pl.List]]]:
) -> dict[str, List[Union[PolarsDataType, pl.List]]]:
"""
Return a list of polars dtypes which Patito considers valid for each field.
Expand Down Expand Up @@ -197,10 +197,10 @@ def valid_dtypes( # type: ignore

@classmethod
def _valid_dtypes( # noqa: C901
cls: Type[ModelType],
cls: Type[ModelType], # pyright: ignore
column: str,
props: Dict,
) -> Optional[List[pl.PolarsDataType]]:
) -> Optional[List[PolarsDataType]]:
"""
Map schema property to list of valid polars data types.
Expand All @@ -218,10 +218,14 @@ def _valid_dtypes( # noqa: C901
f"No valid dtype mapping found for column '{column}'."
)
return [pl.List(dtype) for dtype in item_dtypes]

if "dtype" in props and 'anyOf' not in props:
if props['dtype'] not in cls._pydantic_type_to_valid_polars_types(props): # TODO should we allow pl floats for integer columns? Other type hierarchies to consider?
raise ValueError(f"Invalid dtype {props['dtype']} for column '{column}'. Check that specified dtype is allowable for the given type annotations.")

if "dtype" in props and "anyOf" not in props:
if props["dtype"] not in cls._pydantic_type_to_valid_polars_types(
props
): # TODO should we allow pl floats for integer columns? Other type hierarchies to consider?
raise ValueError(
f"Invalid dtype {props['dtype']} for column '{column}'. Check that specified dtype is allowable for the given type annotations."
)
return [
props["dtype"],
]
Expand All @@ -239,11 +243,13 @@ def _valid_dtypes( # noqa: C901
column, {"type": PYTHON_TO_PYDANTIC_TYPES.get(type(props["const"]))}
)
return None

return cls._pydantic_type_to_valid_polars_types(props)

@staticmethod
def _pydantic_type_to_valid_polars_types(props: Dict) -> Optional[List[pl.DataType]]:
def _pydantic_type_to_valid_polars_types(
props: Dict,
) -> Optional[List[PolarsDataType]]:
if props["type"] == "integer":
return [
pl.Int64,
Expand Down Expand Up @@ -574,6 +580,9 @@ class Model(BaseModel, metaclass=ModelMetaclass):

defaults: ClassVar[Dict[str, Any]]

if TYPE_CHECKING:
model_fields: ClassVar[dict[str, FieldInfo]]

@classmethod # type: ignore[misc]
@property
def DataFrame(
Expand Down Expand Up @@ -786,6 +795,8 @@ def example_value( # noqa: C901
field_type = "null"
else:
field_type = allowable[0]
else:
raise NotImplementedError
if "const" in properties:
# The default value is the only valid value, provided as const
return properties["const"]
Expand Down Expand Up @@ -1461,8 +1472,13 @@ def _derive_field(
if x in field.__slots__ and x not in ["annotation", "default"]
}
if make_nullable:
# This originally non-nullable field has become nullable
field_type = Optional[field_type]
if field_type is None:
raise TypeError(
"Cannot make field nullable if no type annotation is provided!"
)
else:
# This originally non-nullable field has become nullable
field_type = Optional[field_type]
elif field.is_required() and default is None:
# We need to replace Pydantic's None default value with ... in order
# to make it clear that the field is still non-nullable and
Expand Down Expand Up @@ -1507,10 +1523,10 @@ def __init__(
)


def Field(
def Field( # noqa: C901
*args,
**kwargs,
):
) -> Any:
pt_kwargs = {k: kwargs.pop(k, None) for k in get_args(PT_INFO)}
meta_kwargs = {
k: v for k, v in kwargs.items() if k in fields.FieldInfo.metadata_lookup
Expand Down
19 changes: 7 additions & 12 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,23 +461,18 @@ class Test(pt.Model):
foo: str | None = pt.Field(dtype=pl.Utf8)

assert Test.nullable_columns == {"foo"}
assert set(Test.valid_dtypes['foo']) == {pl.Utf8, pl.Null}
assert set(Test.valid_dtypes["foo"]) == {pl.Utf8, pl.Null}


def test_conflicting_type_dtype():

class Test(pt.Model):
class Test1(pt.Model):
foo: int = pt.Field(dtype=pl.Utf8)

with pytest.raises(ValueError):
Test.valid_dtypes()
class Test(pt.Model):
Test1.valid_dtypes

class Test2(pt.Model):
foo: str = pt.Field(dtype=pl.Float32)

with pytest.raises(ValueError):
Test.valid_dtypes()


if __name__ == "__main__":
test_conflicting_type_dtype()
with pytest.raises(ValueError):
Test2.valid_dtypes
2 changes: 0 additions & 2 deletions tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,5 +253,3 @@ class Model(pt.Model):

# Or a list of columns
assert df.drop(["column_1", "column_2"]).columns == []


0 comments on commit 161300b

Please sign in to comment.