Skip to content

Commit

Permalink
fix: promotion in sqrt
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed May 24, 2024
1 parent 4b49e41 commit 63a1cc0
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion src/unxt/_quantity/register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3307,10 +3307,51 @@ def _sort_p_one_operand(


@register(lax.sqrt_p)
def _sqrt_p(x: AbstractQuantity) -> AbstractQuantity:
def _sqrt_p_q(x: AbstractQuantity) -> AbstractQuantity:
"""Square root of a quantity.
Examples
--------
>>> import quaxed.numpy as qnp
>>> from unxt import UncheckedQuantity
>>> q = UncheckedQuantity(9, "m")
>>> qnp.sqrt(q)
UncheckedQuantity(Array(3., dtype=float32), unit='m(1/2)')
>>> from unxt import Quantity
>>> q = Quantity(9, "m")
>>> qnp.sqrt(q)
Quantity['m0.5'](Array(3., dtype=float32), unit='m(1/2)')
"""
# Apply sqrt to the value and adjust the unit
return type_np(x)(lax.sqrt(x.value), unit=x.unit ** (1 / 2))


@register(lax.sqrt_p)
def _sqrt_p_d(x: AbstractDistance) -> Quantity:
"""Square root of a quantity.
Examples
--------
>>> import quaxed.numpy as qnp
>>> from unxt import Distance
>>> q = Distance(9, "m")
>>> qnp.sqrt(q)
Quantity['m0.5'](Array(3., dtype=float32), unit='m(1/2)')
>>> from unxt import Parallax
>>> q = Parallax(9, "mas")
>>> qnp.sqrt(q)
Quantity['rad0.5'](Array(3., dtype=float32), unit='mas(1/2)')
"""
# Promote to something that supports sqrt units.
return Quantity(lax.sqrt(x.value), unit=x.unit ** (1 / 2))


# ==============================================================================


Expand Down

0 comments on commit 63a1cc0

Please sign in to comment.