Skip to content

Commit

Permalink
fix: array-like constructors (#117)
Browse files Browse the repository at this point in the history
* fix: array-like constructors

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed May 25, 2024
1 parent 73c9e5e commit 87f5f60
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/unxt/_quantity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

from collections.abc import Callable, Sequence
from dataclasses import fields, replace
from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard, TypeVar

import equinox as eqx
import jax
import jax.core
import jax.numpy as jnp
import numpy as np
from astropy.units import (
CompositeUnit,
Quantity as AstropyQuantity,
Expand All @@ -32,6 +33,8 @@


FMT = TypeVar("FMT")
ArrayLikeScalar: TypeAlias = np.bool_ | np.number | bool | int | float | complex
ArrayLikeSequence: TypeAlias = list[ArrayLikeScalar] | tuple[ArrayLikeScalar, ...]


def _flip_binop(binop: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
Expand Down Expand Up @@ -139,7 +142,7 @@ def __repr__(self) -> str:
@dispatcher
def constructor(
cls: "type[AbstractQuantity]",
value: ArrayLike,
value: ArrayLike | ArrayLikeSequence,
unit: Any,
/,
*,
Expand All @@ -149,7 +152,7 @@ def constructor(
Parameters
----------
value : ArrayLike
value : ArrayLike | list[...] | tuple[...]
The array-like value.
unit : Any
The unit of the value.
Expand All @@ -173,6 +176,12 @@ def constructor(
>>> Quantity.constructor(x, "m")
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='m')
>>> Quantity.constructor([1.0, 2, 3], "m")
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='m')
>>> Quantity.constructor((1.0, 2, 3), "m")
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='m')
"""
# Dispatch on both arguments.
# Construct using the standard `__init__` method.
Expand All @@ -181,7 +190,11 @@ def constructor(
@classmethod # type: ignore[no-redef]
@dispatcher
def constructor(
cls: "type[AbstractQuantity]", value: ArrayLike, *, unit: Any, dtype: Any = None
cls: "type[AbstractQuantity]",
value: ArrayLike | ArrayLikeSequence,
*,
unit: Any,
dtype: Any = None,
) -> "AbstractQuantity":
"""Construct a `Quantity` from an array-like value and a unit kwarg.
Expand All @@ -195,7 +208,7 @@ def constructor(
@classmethod # type: ignore[no-redef]
@dispatcher
def constructor(
cls: "type[AbstractQuantity]", *, value: ArrayLike, unit: Any, dtype: Any = None
cls: "type[AbstractQuantity]", *, value: Any, unit: Any, dtype: Any = None
) -> "AbstractQuantity":
"""Construct a `Quantity` from value and unit kwargs.
Expand Down

0 comments on commit 87f5f60

Please sign in to comment.