From 63a1cc0b949bcccab4708606da452b30d12f6bbf Mon Sep 17 00:00:00 2001 From: nstarman Date: Fri, 24 May 2024 10:18:55 -0400 Subject: [PATCH] fix: promotion in sqrt Signed-off-by: nstarman --- src/unxt/_quantity/register_primitives.py | 43 ++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/unxt/_quantity/register_primitives.py b/src/unxt/_quantity/register_primitives.py index 1a4e28f..6df96d8 100644 --- a/src/unxt/_quantity/register_primitives.py +++ b/src/unxt/_quantity/register_primitives.py @@ -3307,10 +3307,51 @@ def _sort_p_one_operand( @register(lax.sqrt_p) -def _sqrt_p(x: AbstractQuantity) -> AbstractQuantity: +def _sqrt_p_q(x: AbstractQuantity) -> AbstractQuantity: + """Square root of a quantity. + + Examples + -------- + >>> import quaxed.numpy as qnp + + >>> from unxt import UncheckedQuantity + >>> q = UncheckedQuantity(9, "m") + >>> qnp.sqrt(q) + UncheckedQuantity(Array(3., dtype=float32), unit='m(1/2)') + + >>> from unxt import Quantity + >>> q = Quantity(9, "m") + >>> qnp.sqrt(q) + Quantity['m0.5'](Array(3., dtype=float32), unit='m(1/2)') + + """ + # Apply sqrt to the value and adjust the unit return type_np(x)(lax.sqrt(x.value), unit=x.unit ** (1 / 2)) +@register(lax.sqrt_p) +def _sqrt_p_d(x: AbstractDistance) -> Quantity: + """Square root of a quantity. + + Examples + -------- + >>> import quaxed.numpy as qnp + + >>> from unxt import Distance + >>> q = Distance(9, "m") + >>> qnp.sqrt(q) + Quantity['m0.5'](Array(3., dtype=float32), unit='m(1/2)') + + >>> from unxt import Parallax + >>> q = Parallax(9, "mas") + >>> qnp.sqrt(q) + Quantity['rad0.5'](Array(3., dtype=float32), unit='mas(1/2)') + + """ + # Promote to something that supports sqrt units. + return Quantity(lax.sqrt(x.value), unit=x.unit ** (1 / 2)) + + # ==============================================================================