In [264]:
from typing import Any, Sequence, get_args, get_origin, Callable
import functools

from typing_extensions import Annotated, TypeVar

import numpy as np

import numpydantic

import pydantic
from pydantic_core import core_schema
from pydantic import (
    BaseModel,
    GetCoreSchemaHandler,
    WrapValidator,
    ValidationInfo,
    ValidatorFunctionWrapHandler,
)

import pint

In [2]:
ureg = pint.UnitRegistry()

In [3]:
class DimensionError(ValueError):
    pass

In [36]:
_pint_base_repr = core_schema.tuple_positional_schema(items_schema=[
    core_schema.float_schema(),
    core_schema.str_schema()
])

In [94]:
numpydantic.maps.np_to_python[numpydantic.dtype.Float]

KeyError: (<class 'numpy.float16'>, <class 'numpy.float32'>, <class 'numpy.float64'>, <class 'numpy.float32'>, <class 'numpy.float64'>)

In [9]:
numpydantic.NDArray.__get_pydantic_core_schema__(_source_type=numpydantic.NDArray[numpydantic.Shape['2 x, 2 y'], numpydantic.dtype.Float], _handler=None)

{'type': 'json-or-python',
 'json_schema': {'type': 'list',
  'items_schema': {'type': 'list',
   'items_schema': {'type': 'float'},
   'min_length': 2,
   'max_length': 2,
   'metadata': {'name': 'y'}},
  'min_length': 2,
  'max_length': 2,
  'metadata': {'name': 'x'}},
 'python_schema': {'type': 'function-plain',
  'function': {'type': 'with-info',
   'function': <function numpydantic.schema.get_validate_interface.<locals>.validate_interface(value: Any, info: Optional[ForwardRef('ValidationInfo')] = None) -> numpydantic.types.NDArrayType>}},
 'serialization': {'type': 'function-plain',
  'function': <function numpydantic.schema._jsonize_array(value: Any, info: pydantic_core.core_schema.SerializationInfo) -> Union[list, dict]>,
  'info_arg': True,
  'when_used': 'json'}}

In [109]:
numpydantic.schema.get_validate_interface(Any, numpydantic.dtype.Int)(2.3)

DtypeError: Invalid dtype! expected (<class 'numpy.int8'>, <class 'numpy.int16'>, <class 'numpy.int32'>, <class 'numpy.int64'>, <class 'numpy.int16'>, <class 'numpy.uint8'>, <class 'numpy.uint16'>, <class 'numpy.uint32'>, <class 'numpy.uint64'>, <class 'numpy.uint16'>), got float64

typing.Union[str, type, typing.Any, numpy.generic]

In [322]:
def to_tuple(q: pint.Quantity):
    base = q.to_base_units()
    return (base.magnitude, str(base.units))

def to_array_tuple(q: pint.Quantity, info, array_serializer: Callable):
    base = q.to_base_units()
    return (array_serializer(base.magnitude, info=info), str(base.units))

In [323]:
def dummy_serializer(value, info):
    print("dummy", value, info)

functools.partial(to_array_tuple, array_serializer=dummy_serializer)(pint.Quantity(23, 'm'), info=None)

dummy 23 None


(None, 'meter')

In [277]:
def get_basic_type(t):
    if isinstance(t, str):
        t = np.dtype(t)
    if isinstance(t, Sequence):
        # numpydantic.dtype.Float is a sequence, for example
        # They all map to the same basic Python type
        t = t[0]
    if t in (float, int, complex):
        return t
    dtype = np.dtype(t)
    return numpydantic.maps.np_to_python[dtype.type]

In [278]:
def get_schema(t):
    basic_type = get_basic_type(t)
    if basic_type is float:
        return core_schema.float_schema()
    elif basic_type is int:
        return core_schema.int_schema()
    elif basic_type is complex:
        return core_schema.complex_schema()
    else:
        raise NotImplementedError(t)


In [279]:
class PintAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        
        return core_schema.json_or_python_schema(
            json_schema=_pint_base_repr,
            python_schema=core_schema.is_instance_schema(pint.Quantity),
            serialization=core_schema.plain_serializer_function_ser_schema(
                to_tuple
            ),
        )

In [280]:
class PintArrayAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        
        return core_schema.json_or_python_schema(
            json_schema=_pint_base_repr,
            python_schema=core_schema.is_instance_schema(pint.Quantity),
            serialization=core_schema.plain_serializer_function_ser_schema(
                to_tuple
            ),
        )

In [281]:
_length_dim = ureg.meter.dimensionality

In [282]:
str((1 * ureg.meter).to_base_units().units)

'meter'

In [283]:
t = numpydantic.NDArray[numpydantic.Shape['2 x, 2 y'], float]

In [284]:
t.__get_pydantic_core_schema__(t, None)

{'type': 'json-or-python',
 'json_schema': {'type': 'list',
  'items_schema': {'type': 'list',
   'items_schema': {'type': 'float'},
   'min_length': 2,
   'max_length': 2,
   'metadata': {'name': 'y'}},
  'min_length': 2,
  'max_length': 2,
  'metadata': {'name': 'x'}},
 'python_schema': {'type': 'function-plain',
  'function': {'type': 'with-info',
   'function': <function numpydantic.schema.get_validate_interface.<locals>.validate_interface(value: Any, info: Optional[ForwardRef('ValidationInfo')] = None) -> numpydantic.types.NDArrayType>}},
 'serialization': {'type': 'function-plain',
  'function': <function numpydantic.schema._jsonize_array(value: Any, info: pydantic_core.core_schema.SerializationInfo) -> Union[list, dict]>,
  'info_arg': True,
  'when_used': 'json'}}

In [326]:
class Length[DType](pint.Quantity):
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        (dtype, ) = get_args(_source_type)
        magnitude_schema = get_schema(dtype)
        reference = 1 * ureg.meter
        units = str(reference.to_base_units().units)
        validate_function = numpydantic.schema.get_validate_interface(Any, dtype)
        target_type = get_basic_type(dtype)
        
        def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity:
            if isinstance(v, pint.Quantity):
                pass
            elif isinstance(v, Sequence):
                magnitude, unit = v
                # Turn into Quantity: magnitude * unit
                v = magnitude * ureg(unit)
            else:
                raise ValueError(f"Don't know how to interpret type {type(v)}.")
            # Check dimension
            if not v.check(reference.dimensionality):
                raise DimensionError(f"Expected dimensionality {reference.dimensionality}, got quantity {v}.")
            try:
                # First, try as-is
                validate_function(v.magnitude)
            except Exception:
                # See if we can go from int to float, for example
                if np.can_cast(type(v.magnitude), target_type):
                    v = target_type(v.magnitude) * v.units
                    validate_function(v.magnitude)
                else:
                    raise
            # Return target type
            return v
        
        json_schema = core_schema.tuple_positional_schema(items_schema=[
            magnitude_schema,
            core_schema.literal_schema([units])
        ])
        return core_schema.json_or_python_schema(
            json_schema=json_schema,
            python_schema=core_schema.with_info_plain_validator_function(validator),
            serialization=core_schema.plain_serializer_function_ser_schema(
                to_tuple
            ),
        )

class LengthArray[Shape, DType](pint.Quantity):
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        shape, dtype = get_args(_source_type)
        numpydantic_type = numpydantic.NDArray[shape, dtype]
        reference = 1 * ureg.meter
        units = str(reference.to_base_units().units)
        validate_function = numpydantic.schema.get_validate_interface(shape, dtype)
        target_type = get_basic_type(dtype)

        magnitude_schema = numpydantic_type.__get_pydantic_core_schema__(
            _source_type=numpydantic_type,
            _handler=_handler
        )
        
        def validator(v: Any, info: core_schema.ValidationInfo) -> pint.Quantity:
            if isinstance(v, pint.Quantity):
                pass
            elif isinstance(v, Sequence):
                magnitude, unit = v
                # Turn into Quantity: magnitude * unit
                v = np.asarray(magnitude) * ureg(unit)
            else:
                raise ValueError(f"Don't know how to interpret type {type(v)}.")
            # Check dimension
            if not v.check(reference.dimensionality):
                raise DimensionError(f"Expected dimensionality {reference.dimensionality}, got quantity {v}.")
            try:
                # First, try as-is
                validate_function(v.magnitude)
            except Exception:
                # See if we can go from int to float, for example
                if np.can_cast(v.magnitude, target_type):
                    v = v.magnitude.astype(target_type) * v.units
                    validate_function(v.magnitude)
                else:
                    raise
            # Return target type
            return v
        
        json_schema = core_schema.tuple_positional_schema(items_schema=[
            magnitude_schema['json_schema'],
            core_schema.literal_schema([units])
        ])

        serializer = functools.partial(to_array_tuple, array_serializer=magnitude_schema['serialization']['function'])
        return core_schema.json_or_python_schema(
            json_schema=json_schema,
            python_schema=core_schema.with_info_plain_validator_function(validator),
            serialization=core_schema.plain_serializer_function_ser_schema(
                function=serializer,
                info_arg=True,
            ),
        )


In [327]:
functools.partial(to_array_tuple, array_serializer=dummy_serializer)(pint.Quantity(23, 'm'), 2)

dummy 23 2


(None, 'meter')

In [335]:
class TestSchema(BaseModel):
    l: LengthArray[numpydantic.Shape['* y, 2 x'], numpydantic.dtype.Float]
    m: Length[complex]

In [338]:
t = TestSchema(m=pint.Quantity(23.3, 'cm'), l=pint.Quantity(np.array([(1, 2), (3, 4), (5, 6)]), 'km'))

In [339]:
t.model_dump()

{'l': ([[1000.0, 2000.0], [3000.0, 4000.0], [5000.0, 6000.0]], 'meter'),
 'm': ((0.233+0j), 'meter')}

In [340]:
class NumpySchema(BaseModel):
    a: NDArray[Shape['2 x, 2 y'], Float]
    b: float

NameError: name 'NDArray' is not defined

In [33]:
n = NumpySchema(a=np.zeros((2, 2)), b=3)

In [34]:
n.model_json_schema()

{'properties': {'a': {'items': {'items': {'type': 'number'},
    'maxItems': 2,
    'minItems': 2,
    'type': 'array'},
   'maxItems': 2,
   'minItems': 2,
   'title': 'A',
   'type': 'array'},
  'b': {'title': 'B', 'type': 'number'}},
 'required': ['a', 'b'],
 'title': 'NumpySchema',
 'type': 'object'}