### Example *pydantic* models that describe a dataframe (they could just as well be patito models)

In [264]:
from pydantic import BaseModel, Field
from enum import Enum
import polars as pl
from datetime import date, datetime
from typing import Literal, Optional, get_args, get_origin
from pydantic.fields import FieldInfo
from polars.datatypes import (
    DataTypeClass as PolarsDataType,
)  # no idea if this should be DataTypeClass or DatType, but only the former works
from pprint import pprint
from types import NoneType, UnionType


class Foo(Enum):
    """Example enum with integer values, used to annotate `NearlyCompleteExample.enum_value`."""

    A = 1
    B = 2


class SimpleExample(BaseModel):
    """Small example model; each trailing comment lists the column info we expect.

    Comment format: polars dtype, required, nullable, unique, Python type, constraints.
    """

    id: str                                                                    # pl.Utf8, required, not nullable, not unique, str, no constraints
    name: str                                                                  # pl.Utf8, required, not nullable, not unique, str, no constraints
    int_with_dtype_value: int = Field(json_schema_extra={"dtype": pl.Int16()}) # pl.Int16 (set explicitly via json_schema_extra), required, not nullable, not unique, int, no constraints
    not_required_bc_has_default: bool = True                                   # pl.Boolean, not required (has a default), not nullable (annotation is not `bool | None`), not unique, bool, no constraints


class NearlyCompleteExample(BaseModel):
    """Example model covering most annotation kinds a column may use."""

    int_with_dtype_value: int = Field(json_schema_extra={"dtype": pl.Int64()})  # dtype given explicitly
    int_value: int
    float_value: float
    str_value: str
    bool_value: bool
    list_value: list[int]
    list_value_nullable: list[int | None]  # None is allowed for the list *elements*, not for the field itself
    literal_value: Literal["a", "b"]
    default_value: str = "my_default"  # has a default, so the field is not required
    optional_value: Optional[int]  # None allowed but no default; pydantic v2 still treats this field as required
    bounded_value: int = Field(ge=10, le=20)  # ge/le bounds that constraint parsing would need to pick up
    date_value: date
    datetime_value: datetime
    enum_value: Foo  # enum whose member values are ints

### A dataclass-like structure that describes the schema of a column, including any constraints

In [100]:
class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
    """A model containing info patito needs about a column.

    `arbitrary_types_allowed=True` lets pydantic accept field types it has no
    built-in schema for, such as `PolarsDataType` and `pl.Expr`.
    """

    name: str  # column name; the conversion cell fills it with the pydantic field name
    dtype: PolarsDataType  # polars dtype of the column
    required: bool  # filled from pydantic's `FieldInfo.is_required()`
    nullable: bool  # whether the column may contain nulls
    unique: bool  # whether the column's values must be unique
    # The field's original type hint.
    # NOTE(review): hints such as `list[int]` or `int | None` may not be
    # instances of `type`; confirm pydantic accepts them for this field.
    type_annotation: type
    # NOTE(review): "contraints" is a typo for "constraints"; renaming it would
    # break keyword callers and reprs, so it is kept as-is for now.
    contraints: list[pl.Expr] | None = None

### Mapping from Python type hints to the widest matching polars data type

In [101]:
# Default polars dtype for each supported Python type. The values are dtype
# *classes* (not instances): `get_polars_type_from_annotation` instantiates the
# innermost one and wraps it in the outer ones (e.g. `pl.List(pl.Int64())`).
# `date`/`datetime` are included so annotations such as those in
# `NearlyCompleteExample` resolve instead of raising KeyError.
PYTHON_TO_POLARS_TYPES: dict[type, PolarsDataType] = {
    str: pl.Utf8,
    int: pl.Int64,
    float: pl.Float64,
    bool: pl.Boolean,
    date: pl.Date,
    datetime: pl.Datetime,
    list: pl.List,
}

### Utility functions that let us reason about type hints

In [258]:
from re import U
from typing import cast


def is_single_type(type: type) -> bool:
    """Type hint is a single type.

    True for: int, str, float, bool, etc.
    False for: Optional[int], Union[int, str], Literal["a", "b"], etc.
    """
    return get_args(type) == ()


def is_literal(type_: type) -> bool:
    "Determine whether the type hint is a Literal type."
    try:
        return type_.__dict__["__origin__"] is Literal
    except KeyError:
        return False


def get_enum_inner_type(enum: type) -> type | None:
    "Get the type of the values of the enum if it exists, None otherwise."
    if issubclass(enum, Enum):
        enum_types = set(type(value) for value in enum)  # type: ignore
        if len(enum_types) > 1:
            raise TypeError(
                "All enumerated values of enums used to annotate "
                "Patito model fields must have the same type. "
                "Encountered types: "
                f"{sorted(map(lambda t: t.__name__, enum_types))}."
            )
        enum_type = enum_types.pop()
    else:
        enum_type = None
    return enum_type


def get_polars_dtype(field_info: FieldInfo) -> PolarsDataType | None:
    """Get the polars dtype if it is specified in the field's json schema extra.

    Args:
        field_info: The pydantic field to inspect.

    Returns:
        The value stored under ``"dtype"`` in ``json_schema_extra``, or None
        when there is no such entry.
    """
    schema_extra = field_info.json_schema_extra
    # pydantic also accepts a callable for `json_schema_extra`. Only a dict can
    # carry a "dtype" entry, and calling `.get` on a callable would fail.
    if not isinstance(schema_extra, dict):
        return None
    return cast(PolarsDataType, schema_extra.get("dtype"))


def get_dtype(field_info: FieldInfo, python_type_hint: type) -> PolarsDataType:
    """Resolve the polars dtype for a field.

    A ``dtype`` given explicitly in the field's json schema extra wins.
    Otherwise the dtype is derived from ``python_type_hint``, which is only
    inspected when no explicit dtype is available.
    """
    explicit_dtype = get_polars_dtype(field_info)
    if explicit_dtype:
        return explicit_dtype
    return get_polars_type_from_annotation(python_type_hint)


def get_is_unique(field_info: FieldInfo) -> bool:
    """Get whether the field is unique if it is specified in the json schema extra.

    Args:
        field_info: The pydantic field to inspect.

    Returns:
        ``bool(json_schema_extra["unique"])`` when present, otherwise False.
    """
    schema_extra = field_info.json_schema_extra
    # pydantic also accepts a callable for `json_schema_extra`, which cannot
    # carry a "unique" entry; treat it like "not specified".
    if not isinstance(schema_extra, dict):
        return False
    return bool(schema_extra.get("unique", False))


def unnest_type_hint(python_type_hint: type) -> list[type]:
    """Disassemble a generic type into its components.

    Example:
        list[list[int]] -> [list, list, int]
    """
    origin = get_origin(python_type_hint)
    if origin is None:
        return [python_type_hint]
    else:
        args = get_args(python_type_hint)
        foo = list(unnest_type_hint(args[0]))
        return [origin] + foo

# def unnest_type_hint(python_type_hint: type) -> list[type | tuple[type, ...]]:
#     """Disassemble a generic type into its components.

#     Example:
#         list[list[int]] -> [list, list, int]
#     """
#     origin = get_origin(python_type_hint)

#     if origin is None:
#         return [python_type_hint]

#     else:
#         inner_args = get_args(python_type_hint)
#         if get_origin(python_type_hint) is UnionType:
#             print(len(inner_args))
#             foo = tuple(unnest_type_hint(arg) for arg in inner_args)
#         else:
#             foo = unnest_type_hint(inner_args[0])
#         return [origin] + [foo]

def get_type_hint_as_list(python_type_hint: type) -> list[type]:
    """Unnest a type hint and drop its union layers.

    Like ``unnest_type_hint``, only the first argument of each generic is
    followed, so for ``int | None`` only ``int`` remains.

    Example:
        list[int | None] -> [list, int]
        Optional[int]    -> [int]
    """
    # `X | None` reports `types.UnionType` as its origin, while
    # `Optional[X]` / `Union[...]` report `typing.Union`; both must be dropped.
    from typing import Union

    return [
        type_
        for type_ in unnest_type_hint(python_type_hint)
        if type_ is not UnionType and type_ is not Union
    ]

def get_polars_type_from_annotation(python_type_hint: type) -> PolarsDataType:
    """Get the full nested polars type from a type hint.

    Union layers are skipped (via ``get_type_hint_as_list``) because a union is
    assumed to mean ``X | None``. Nullability is not encoded in the polars
    dtype here. As a result, ``list[int | None]`` resolves to
    ``pl.List(pl.Int64())`` instead of failing on the union.

    Args:
        python_type_hint: The (possibly nested) Python type hint.

    Returns:
        An instantiated polars dtype, e.g. ``pl.List(pl.Int64())``.

    Raises:
        KeyError: If a component type has no entry in ``PYTHON_TO_POLARS_TYPES``.
    """
    python_types = get_type_hint_as_list(python_type_hint)
    polars_types = [PYTHON_TO_POLARS_TYPES[type_] for type_ in python_types]

    # Build inside-out. The innermost entry is a bare dtype class such as
    # `pl.Int64`, so call it to get an instance, then wrap it in each outer
    # type, e.g. pl.List(pl.Int64()).
    innermost, *outer_types = reversed(polars_types)
    full_type = innermost()
    for outer_type in outer_types:
        full_type = outer_type(full_type)
    return full_type

In [217]:
# Test unnesting of type hints
# NOTE(review): these expectations describe a *future* unnest_type_hint that
# returns tuples and keeps union members grouped (e.g. `int | None` ->
# ((int, None),)). With the `unnest_type_hint` defined above, which returns a
# list and follows only the first type argument, the very first assert fails
# ([list, list, int] != (list, list, int)).
test_cases = [
    (list[list[int]], (list, list, int)),
    (list[float], (list, float)),
    (float, (float,)),
    (int | None, ((int, None),)),
    (list[int | None] | None, ((list, (int, None)), None)),
]

# Run the test cases
for python_type_hint, expected in test_cases:
    result = unnest_type_hint(python_type_hint)
    assert result == expected, python_type_hint


# NOTE(review): identical to the `test_cases` list above and not used by this
# cell; it can be removed.
test_cases = [
    (list[list[int]], (list, list, int)),
    (list[float], (list, float)),
    (float, (float,)),
    (int | None, ((int, None),)),
    (list[int | None] | None, ((list, (int, None)), None)),
]

got origin union
here None
here None
here None


### Convert a pydantic schema into a patito schema

In [261]:
model = SimpleExample  # or NearlyCompleteExample

# Read the fields from `model` rather than a hard-coded class, so switching the
# example above actually changes which schema is converted.
model_fields = model.model_fields

# print the field name, field info and the result of typing.get_args on the type annotation
for field_name, field_info in model_fields.items():
    print(field_name, field_info, "\t\t", get_args(field_info.annotation))

id annotation=str required=True 		 ()
name annotation=str required=True 		 ()
int_with_dtype_value annotation=int required=True json_schema_extra={'dtype': Int64} 		 ()
not_required_bc_has_default annotation=bool required=False default=True 		 ()


In [262]:
from typing import Union

# Build a ColumnInfo for every field of the selected model.
fields = {}
for field_name, field_info in model_fields.items():
    annotation = field_info.annotation

    assert annotation is not None, (
        f"Encountered a case where `field_info.annotation` is None for field `{field_name}`. "
        "Please report this with an example of your Model in an issue to the patito github repo."
    )

    # A union is only supported as `X | None` / `Optional[X]`: it marks the
    # column as nullable, and `X` describes the values. The PEP 604 spelling
    # reports `types.UnionType` as its origin, while `Optional[X]` reports
    # `typing.Union`, so both are checked.
    if get_origin(annotation) in (UnionType, Union):
        type_args = get_args(annotation)  # e.g. (int, NoneType)
        nullable = NoneType in type_args
        non_null_type_args = [arg for arg in type_args if arg is not NoneType]
        if not nullable or len(non_null_type_args) != 1:
            raise TypeError(
                f"Field `{field_name}` is annotated with `{annotation}`; only a single "
                "type unioned with None (e.g. `int | None`) is supported."
            )
        python_type = non_null_type_args[0]
    else:
        nullable = False
        python_type = annotation

    if is_single_type(python_type):
        # Plain types (int, str, ...) are used as-is. Enums are stored by the
        # type of their member values.
        if enum_type := get_enum_inner_type(python_type):
            python_type = enum_type
    elif is_literal(python_type):
        # Literal["a", "b"] is stored as the (single) type of its values, e.g. str.
        literal_types = {type(value) for value in get_args(python_type)}
        if len(literal_types) != 1:
            raise TypeError(
                f"Field `{field_name}`: all values of `{python_type}` must share one type."
            )
        python_type = literal_types.pop()
    # Any other generic (e.g. `list[int]`) is passed through unchanged so the
    # dtype lookup can resolve the full nested type, e.g. pl.List(pl.Int64()).

    dtype = get_dtype(field_info, python_type)

    fields[field_name] = ColumnInfo(
        name=field_name,
        dtype=dtype,
        required=field_info.is_required(),
        nullable=nullable,
        unique=get_is_unique(field_info),
        type_annotation=annotation,
    )

pprint(fields)

{'id': ColumnInfo(name='id', dtype=Utf8, required=True, nullable=False, unique=False, type_annotation=<class 'str'>, contraints=None),
 'int_with_dtype_value': ColumnInfo(name='int_with_dtype_value', dtype=Int64, required=True, nullable=False, unique=False, type_annotation=<class 'int'>, contraints=None),
 'name': ColumnInfo(name='name', dtype=Utf8, required=True, nullable=False, unique=False, type_annotation=<class 'str'>, contraints=None),
 'not_required_bc_has_default': ColumnInfo(name='not_required_bc_has_default', dtype=Boolean, required=False, nullable=False, unique=False, type_annotation=<class 'bool'>, contraints=None)}


The fields above are now mostly ready to be used in patito. We still need logic for parsing constraints from `json_schema_extra`. We also need to decide how to handle inner-optional types such as `list[int | None]`: how should that be represented in `ColumnInfo`? And what about `list[list[int | None] | None]`?

### Tests I've written along the way

In [82]:
# Test unnesting of type hints
# NOTE(review): the first three cases match the `unnest_type_hint` defined
# above. The cases containing `int | None` do not: the function follows only
# the first type argument and keeps the union origin (e.g. `int | None` ->
# [UnionType, int]), and `get_args` yields `NoneType` rather than `None`.
# The assert therefore fails on `list[int | None]`.
test_cases = [
    (list[list[int]], [list, list, int]),
    (float, [float,]),
    (list[float], [list, float]),
    (list[int | None], [list, (int, None)]),
    (int | None, [(int, None)]),
    (list[list[int | None]], [list, list, (int, None)]),
]

# Run the test cases
for python_type_hint, expected in test_cases:
    result = unnest_type_hint(python_type_hint)
    assert result == expected, python_type_hint




In [None]:
# An attempt at unnesting type hints into tuples
type_ = list[int | None] | None

entries = (type_,)

len_entries = len(entries)
count = 0
while count < len_entries:
    entries2 = ()
    count = 0
    len_entries = len(entries)
    for entry in entries:
        origin = get_origin(entry)
        if origin == UnionType:
            print("got origin union")
            union_entries = tuple(entry2 for entry2 in get_args(type_))
            entries2 += union_entries
        elif origin == list:
            union_entries = tuple(entry2 for entry2 in get_args(type_))
            entries2 += ((list, union_entries),)
        else:
            print("here", origin)
            entries2 += (entry,)
            count += 1
    entries = entries2
