From 87f5f60fcc7e73747da6abfe85e9bad9019b55b5 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Sat, 25 May 2024 12:24:33 -0400 Subject: [PATCH] fix: array-like constructors (#117) * fix: array-like constructors Signed-off-by: nstarman --- src/unxt/_quantity/base.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/unxt/_quantity/base.py b/src/unxt/_quantity/base.py index 9ba02db..a851ed4 100644 --- a/src/unxt/_quantity/base.py +++ b/src/unxt/_quantity/base.py @@ -5,12 +5,13 @@ from collections.abc import Callable, Sequence from dataclasses import fields, replace -from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard, TypeVar import equinox as eqx import jax import jax.core import jax.numpy as jnp +import numpy as np from astropy.units import ( CompositeUnit, Quantity as AstropyQuantity, @@ -32,6 +33,8 @@ FMT = TypeVar("FMT") +ArrayLikeScalar: TypeAlias = np.bool_ | np.number | bool | int | float | complex +ArrayLikeSequence: TypeAlias = list[ArrayLikeScalar] | tuple[ArrayLikeScalar, ...] def _flip_binop(binop: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]: @@ -139,7 +142,7 @@ def __repr__(self) -> str: @dispatcher def constructor( cls: "type[AbstractQuantity]", - value: ArrayLike, + value: ArrayLike | ArrayLikeSequence, unit: Any, /, *, @@ -149,7 +152,7 @@ def constructor( Parameters ---------- - value : ArrayLike + value : ArrayLike | list[...] | tuple[...] The array-like value. unit : Any The unit of the value. @@ -173,6 +176,12 @@ def constructor( >>> Quantity.constructor(x, "m") Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='m') + >>> Quantity.constructor([1.0, 2, 3], "m") + Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='m') + + >>> Quantity.constructor((1.0, 2, 3), "m") + Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='m') + """ # Dispatch on both arguments. # Construct using the standard `__init__` method. @@ -181,7 +190,11 @@ def constructor( @classmethod # type: ignore[no-redef] @dispatcher def constructor( - cls: "type[AbstractQuantity]", value: ArrayLike, *, unit: Any, dtype: Any = None + cls: "type[AbstractQuantity]", + value: ArrayLike | ArrayLikeSequence, + *, + unit: Any, + dtype: Any = None, ) -> "AbstractQuantity": """Construct a `Quantity` from an array-like value and a unit kwarg. @@ -195,7 +208,7 @@ def constructor( @classmethod # type: ignore[no-redef] @dispatcher def constructor( - cls: "type[AbstractQuantity]", *, value: ArrayLike, unit: Any, dtype: Any = None + cls: "type[AbstractQuantity]", *, value: Any, unit: Any, dtype: Any = None ) -> "AbstractQuantity": """Construct a `Quantity` from value and unit kwargs.