Skip to content

Commit

Permalink
Merge pull request #65 from JakobGM/thomasaarholt/simplify-default-dt…
Browse files Browse the repository at this point in the history
…ypes

Simplify default_dtypes and return DataType over DataTypeClass when possible"
  • Loading branch information
thomasaarholt committed Apr 7, 2024
2 parents f80aba1 + 9578319 commit fd2f49e
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions src/patito/_pydantic/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,17 @@ def valid_dtypes_for_model(
@cache
def default_dtypes_for_model(
cls: Type[ModelType],
) -> dict[str, DataTypeClass | DataType]:
default_dtypes = {}
) -> dict[str, DataType]:
default_dtypes: dict[str, DataType] = {}
for column in cls.columns:
dtype = cls.column_infos[column].dtype
dtype = (
cls.column_infos[column].dtype
or DtypeResolver(cls.model_fields[column].annotation).default_polars_dtype()
)
if dtype is None:
default_dtype = DtypeResolver(
cls.model_fields[column].annotation
).default_polars_dtype()
if default_dtype is None:
raise ValueError(
f"Unable to find a default dtype for column `{column}`"
)
else:
default_dtypes[column] = default_dtype
else:
default_dtypes[column] = (
dtype if isinstance(dtype, DataType) else dtype()
) # if dtype is not instantiated, instantiate it
raise ValueError(f"Unable to find a default dtype for column `{column}`")

default_dtypes[column] = dtype if isinstance(dtype, DataType) else dtype()
return default_dtypes


Expand Down Expand Up @@ -130,9 +123,9 @@ def valid_polars_dtypes(self) -> DataTypeGroup:
return PT_BASE_SUPPORTED_DTYPES
return self._valid_polars_dtypes_for_schema(self.schema)

def default_polars_dtype(self) -> DataTypeClass | DataType | None:
def default_polars_dtype(self) -> DataType | None:
if self.annotation == Any:
return pl.String
return pl.String()
return self._default_polars_dtype_for_schema(self.schema)

def _valid_polars_dtypes_for_schema(
Expand Down Expand Up @@ -197,9 +190,7 @@ def _pydantic_subschema_to_valid_polars_types(
PydanticBaseType(pyd_type), props.get("format"), props.get("enum")
)

def _default_polars_dtype_for_schema(
self, schema: Dict
) -> DataTypeClass | DataType | None:
def _default_polars_dtype_for_schema(self, schema: Dict) -> DataType | None:
if "anyOf" in schema:
if len(schema["anyOf"]) == 2: # look for optionals first
schema = _without_optional(schema)
Expand All @@ -216,10 +207,12 @@ def _default_polars_dtype_for_schema(
def _pydantic_subschema_to_default_dtype(
self,
props: Dict,
) -> DataTypeClass | DataType | None:
) -> DataType | None:
if "column_info" in props: # user has specified in patito model
if props["column_info"]["dtype"] is not None:
return dtype_from_string(props["column_info"]["dtype"])
dtype = dtype_from_string(props["column_info"]["dtype"])
dtype = dtype() if isinstance(dtype, DataTypeClass) else dtype
return dtype
if "type" not in props:
if "enum" in props:
raise TypeError("Mixed type enums not supported by patito.")
Expand Down

0 comments on commit fd2f49e

Please sign in to comment.