Skip to content

Commit

Permalink
feat: position mul
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Aug 7, 2024
1 parent d724a47 commit fa62473
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 120 deletions.
5 changes: 2 additions & 3 deletions src/coordinax/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from dataclasses import fields
from enum import Enum
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from typing_extensions import Never
from typing import TYPE_CHECKING, Any, Literal, NoReturn, TypeVar

import astropy.units as u
import equinox as eqx
Expand Down Expand Up @@ -371,7 +370,7 @@ def __neg__(self) -> "Self":
def __rmul__(self: "AbstractVector", other: Any) -> Any:
return NotImplemented

def __setitem__(self, k: Any, v: Any) -> Never:
def __setitem__(self, k: Any, v: Any) -> NoReturn:
msg = f"{type(self).__name__} is immutable."
raise TypeError(msg)

Expand Down
153 changes: 134 additions & 19 deletions src/coordinax/_base_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@
from inspect import isabstract
from typing import TYPE_CHECKING, Any, TypeVar

import equinox as eqx
import jax
from jax import lax
from jaxtyping import ArrayLike
from plum import convert, dispatch
from quax import quaxify, register

import quaxed.array_api as xp
import quaxed.lax as qlax
from dataclassish import field_items
from unxt import Quantity

import coordinax._typing as ct
from ._base import AbstractVector
from ._utils import classproperty

Expand Down Expand Up @@ -96,7 +99,13 @@ def aval(self) -> jax.core.ShapedArray:
>>> vec = cx.CartesianPosition3D.constructor(Quantity([1, 2, 3], "m"))
>>> vec.aval()
ShapedArray(float32[3])
ConcreteArray([1. 2. 3.], dtype=float32)
>>> vec = cx.CartesianPosition3D.constructor(
... Quantity([[1, 2, 3], [4, 5, 6]], "m"))
>>> vec.aval()
ConcreteArray([[1. 2. 3.]
[4. 5. 6.]], dtype=float32)
"""
return jax.core.get_aval(
Expand Down Expand Up @@ -152,13 +161,13 @@ def __sub__(
def __mul__(
self: "AbstractPosition", other: ArrayLike
) -> "AbstractPosition": # TODO: use Self
return replace(self, **{k: v * other for k, v in field_items(self)})
return qlax.mul(self, other)

Check warning on line 164 in src/coordinax/_base_pos.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_base_pos.py#L164

Added line #L164 was not covered by tests

@AbstractVector.__rmul__.dispatch # type: ignore[misc]
def __rmul__(
self: "AbstractPosition", other: ArrayLike
) -> "AbstractPosition": # TODO: use Self
return replace(self, **{k: other * v for k, v in field_items(self)})
return qlax.mul(other, self)

Check warning on line 170 in src/coordinax/_base_pos.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_base_pos.py#L170

Added line #L170 was not covered by tests

@AbstractVector.__truediv__.dispatch # type: ignore[misc]
def __truediv__(
Expand Down Expand Up @@ -220,8 +229,8 @@ def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:

return represent_as(self, target, **kwargs)

@partial(jax.jit)
def norm(self) -> Quantity["length"]:
@partial(jax.jit, inline=True)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Returns
Expand All @@ -231,18 +240,30 @@ def norm(self) -> Quantity["length"]:
Examples
--------
We assume the following imports:
>>> from unxt import Quantity
>>> from coordinax import CartesianPosition3D
>>> import coordinax as cx
We can compute the norm of a vector
>>> x, y, z = Quantity(1, "meter"), Quantity(2, "meter"), Quantity(3, "meter")
>>> vec = CartesianPosition3D(x=x, y=y, z=z)
>>> vec.norm()
>>> from unxt import Quantity
>>> from coordinax import CartesianPosition1D, RadialPosition
>>> v = cx.CartesianPosition1D.constructor(Quantity([-1], "kpc"))
>>> v.norm()
Quantity['length'](Array(1., dtype=float32), unit='kpc')
>>> v = cx.CartesianPosition2D.constructor(Quantity([3, 4], "kpc"))
>>> v.norm()
Quantity['length'](Array(5., dtype=float32), unit='kpc')
>>> v = cx.PolarPosition(r=Quantity(3, "kpc"), phi=Quantity(90, "deg"))
>>> v.norm()
Quantity['length'](Array(3., dtype=float32), unit='kpc')
>>> v = cx.CartesianPosition3D.constructor(Quantity([1, 2, 3], "m"))
>>> v.norm()
Quantity['length'](Array(3.7416575, dtype=float32), unit='m')
"""
return self.represent_as(self._cartesian_cls).norm()
return xp.linalg.vector_norm(self, axis=-1)


# ===================================================================
Expand All @@ -265,7 +286,100 @@ def normalize_vector(x: AbstractPosition, /) -> AbstractPosition:
raise NotImplementedError # pragma: no cover


@register(lax.reshape_p) # type: ignore[misc]
# ------------------------------------------------


@register(jax.lax.mul_p) # type: ignore[misc]
def _mul_p_vq(lhs: ArrayLike, rhs: AbstractPosition, /) -> AbstractPosition:
"""Scale a position by a scalar.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> import quaxed.numpy as jnp
>>> vec = cx.CartesianPosition3D.constructor(Quantity([1, 2, 3], "m"))
>>> jnp.multiply(2, vec)
CartesianPosition3D(
x=Quantity[...](value=f32[], unit=Unit("m")),
y=Quantity[...](value=f32[], unit=Unit("m")),
z=Quantity[...](value=f32[], unit=Unit("m"))
)
"""
# Validation
lhs = eqx.error_if(
lhs, any(jax.numpy.shape(lhs)), f"must be a scalar, not {type(lhs)}"
)
rhs = eqx.error_if(
rhs,
isinstance(rhs, rhs._cartesian_cls), # noqa: SLF001
"must register a Cartesian-specific dispatch",
)

rc = rhs.represent_as(rhs._cartesian_cls) # noqa: SLF001
nr = qlax.mul(lhs, rc)
return nr.represent_as(type(rhs))


@register(jax.lax.mul_p) # type: ignore[misc]
def _mul_p_qv(lhs: AbstractPosition, rhs: ArrayLike, /) -> AbstractPosition:
"""Scale a position by a scalar.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> import quaxed.numpy as jnp
>>> vec = cx.CartesianPosition3D.constructor(Quantity([1, 2, 3], "m"))
>>> jnp.multiply(vec, 2)
CartesianPosition3D(
x=Quantity[...](value=f32[], unit=Unit("m")),
y=Quantity[...](value=f32[], unit=Unit("m")),
z=Quantity[...](value=f32[], unit=Unit("m"))
)
"""
return qlax.mul(rhs, lhs) # re-dispatch on the other side


@register(jax.lax.mul_p) # type: ignore[misc]
def _mul_p_qq(lhs: AbstractPosition, rhs: AbstractPosition, /) -> Quantity:
"""Multiply two positions.
This is required to take the dot product of two vectors.
Examples
--------
>>> import quaxed.array_api as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> vec = cx.CartesianPosition3D(
... x=Quantity([1, 2, 3], "m"),
... y=Quantity([4, 5, 6], "m"),
... z=Quantity([7, 8, 9], "m"))
>>> jnp.multiply(vec, vec) # element-wise multiplication
Quantity['area'](Array([[ 1., 16., 49.],
[ 4., 25., 64.],
[ 9., 36., 81.]], dtype=float32), unit='m2')
>>> jnp.linalg.vector_norm(vec, axis=-1)
Quantity['length'](Array([ 8.124039, 9.643651, 11.224972], dtype=float32), unit='m')
""" # noqa: E501
lq = convert(lhs.represent_as(lhs._cartesian_cls), Quantity) # noqa: SLF001
rq = convert(rhs.represent_as(rhs._cartesian_cls), Quantity) # noqa: SLF001
return qlax.mul(lq, rq) # re-dispatch to Quantities


# ------------------------------------------------


@register(jax.lax.reshape_p) # type: ignore[misc]
def _reshape_p(
operand: AbstractPosition, *, new_sizes: tuple[int, ...], **kwargs: Any
) -> AbstractPosition:
Expand All @@ -282,19 +396,20 @@ def _reshape_p(
... z=Quantity([7, 8, 9], "m"))
>>> jnp.reshape(vec, shape=(3, 1, 3)) # (n_components *shape)
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[1,3], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[1,3], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[1,3], unit=Unit("m"))
x=Quantity[PhysicalType('length')](value=f32[1,1,3], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[1,1,3], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[1,1,3], unit=Unit("m"))
)
"""
# Adjust the sizes for the components
new_sizes = new_sizes[1:]
new_sizes = (new_sizes[0] // len(operand.components), *new_sizes[1:])
# TODO: check integer division
# Reshape the components
return replace(
operand,
**{
k: quaxify(lax.reshape_p.bind)(v, new_sizes=new_sizes, **kwargs)
k: quaxify(jax.lax.reshape_p.bind)(v, new_sizes=new_sizes, **kwargs)
for k, v in field_items(operand)
},
)
36 changes: 17 additions & 19 deletions src/coordinax/_d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import equinox as eqx
import jax
from jaxtyping import ArrayLike
from quax import register

import quaxed.array_api as xp
from unxt import Quantity
Expand Down Expand Up @@ -109,25 +111,6 @@ def __sub__(
cart = other.represent_as(CartesianPosition1D)
return replace(self, x=self.x - cart.x)

# -----------------------------------------------------
# Methods

@partial(jax.jit)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Examples
--------
>>> from unxt import Quantity
>>> from coordinax import CartesianPosition1D, RadialPosition
>>> q = CartesianPosition1D.constructor(Quantity([-1], "kpc"))
>>> q.norm()
Quantity['length'](Array(1., dtype=float32), unit='kpc')
"""
return xp.abs(self.x)


@final
class CartesianVelocity1D(AbstractVelocity1D):
Expand Down Expand Up @@ -208,3 +191,18 @@ def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableAcc:
"""
return xp.abs(self.d2_x)


# ===================================================================


@register(jax.lax.mul_p) # type: ignore[misc]
def _mul_p(lhs: ArrayLike, rhs: CartesianPosition1D, /) -> CartesianPosition1D:
"""Scale a position by a scalar."""
# Validation
lhs = eqx.error_if(

Check warning on line 203 in src/coordinax/_d1/cartesian.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_d1/cartesian.py#L203

Added line #L203 was not covered by tests
lhs, any(jax.numpy.shape(lhs)), f"must be a scalar, not {type(lhs)}"
)

# Scale the components
return replace(rhs, x=lhs * rhs.x)

Check warning on line 208 in src/coordinax/_d1/cartesian.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_d1/cartesian.py#L208

Added line #L208 was not covered by tests
47 changes: 17 additions & 30 deletions src/coordinax/_d2/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import equinox as eqx
import jax
from jaxtyping import ArrayLike
from quax import register

import quaxed.array_api as xp
from unxt import Quantity
Expand Down Expand Up @@ -103,21 +105,6 @@ def __sub__(
cart = other.represent_as(CartesianPosition2D)
return replace(self, x=self.x - cart.x, y=self.y - cart.y)

@partial(jax.jit)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Examples
--------
>>> from unxt import Quantity
>>> from coordinax import CartesianPosition2D
>>> q = CartesianPosition2D.constructor(Quantity([3, 4], "kpc"))
>>> q.norm()
Quantity['length'](Array(5., dtype=float32), unit='kpc')
"""
return xp.sqrt(self.x**2 + self.y**2)


@final
class CartesianVelocity2D(AbstractVelocity2D):
Expand All @@ -143,21 +130,6 @@ def integral_cls(cls) -> type[CartesianPosition2D]:
def differential_cls(cls) -> type["CartesianAcceleration2D"]:
return CartesianAcceleration2D

@partial(jax.jit)
def norm(self, _: AbstractPosition2D | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Examples
--------
>>> from unxt import Quantity
>>> from coordinax import CartesianVelocity2D
>>> v = CartesianVelocity2D.constructor(Quantity([3, 4], "km/s"))
>>> v.norm()
Quantity['speed'](Array(5., dtype=float32), unit='km / s')
"""
return xp.sqrt(self.d_x**2 + self.d_y**2)


@final
class CartesianAcceleration2D(AbstractAcceleration2D):
Expand Down Expand Up @@ -192,3 +164,18 @@ def norm(self, _: AbstractVelocity2D | None = None, /) -> ct.BatchableAcc:
"""
return xp.sqrt(self.d2_x**2 + self.d2_y**2)


# ===================================================================


@register(jax.lax.mul_p) # type: ignore[misc]
def _mul_p(lhs: ArrayLike, rhs: CartesianPosition2D, /) -> CartesianPosition2D:
"""Scale a position by a scalar."""
# Validation
lhs = eqx.error_if(

Check warning on line 176 in src/coordinax/_d2/cartesian.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_d2/cartesian.py#L176

Added line #L176 was not covered by tests
lhs, any(jax.numpy.shape(lhs)), f"must be a scalar, not {type(lhs)}"
)

# Scale the components
return replace(rhs, x=lhs * rhs.x, y=lhs * rhs.y)

Check warning on line 181 in src/coordinax/_d2/cartesian.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_d2/cartesian.py#L181

Added line #L181 was not covered by tests
Loading

0 comments on commit fa62473

Please sign in to comment.