From 08aa512cd4b3bc39a2f3053b586feecd45d25598 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 13 Mar 2024 15:26:01 -0400 Subject: [PATCH] feat: interpolated integration Signed-off-by: nstarman --- src/galax/coordinates/_psp/__init__.py | 3 +- src/galax/coordinates/_psp/interpolated.py | 85 ++++++++++ src/galax/coordinates/_psp/psp.py | 85 +--------- .../dynamics/_dynamics/integrate/_api.py | 9 +- .../dynamics/_dynamics/integrate/_base.py | 8 +- .../dynamics/_dynamics/integrate/_builtin.py | 156 +++++++++++++++--- .../dynamics/_dynamics/integrate/_funcs.py | 38 ++++- src/galax/dynamics/_dynamics/orbit.py | 2 +- src/galax/potential/_potential/base.py | 15 +- src/galax/typing.py | 3 + 10 files changed, 279 insertions(+), 125 deletions(-) create mode 100644 src/galax/coordinates/_psp/interpolated.py diff --git a/src/galax/coordinates/_psp/__init__.py b/src/galax/coordinates/_psp/__init__.py index 0ec944b1..9dcaa589 100644 --- a/src/galax/coordinates/_psp/__init__.py +++ b/src/galax/coordinates/_psp/__init__.py @@ -8,4 +8,5 @@ from . import operator_compat # noqa: F401 from .base import AbstractPhaseSpacePosition -from .psp import InterpolatedPhaseSpacePosition, PhaseSpacePosition +from .interpolated import InterpolatedPhaseSpacePosition +from .psp import PhaseSpacePosition diff --git a/src/galax/coordinates/_psp/interpolated.py b/src/galax/coordinates/_psp/interpolated.py new file mode 100644 index 00000000..cd9e8ca6 --- /dev/null +++ b/src/galax/coordinates/_psp/interpolated.py @@ -0,0 +1,85 @@ +"""galax: Galactic Dynamix in Jax.""" + +__all__ = ["InterpolatedPhaseSpacePosition"] + +from typing import Protocol, final, runtime_checkable + +import equinox as eqx +import jax.numpy as jnp + +from coordinax import Abstract3DVector, Abstract3DVectorDifferential +from unxt import AbstractUnitSystem, Quantity + +import galax.typing as gt +from .base import AbstractPhaseSpacePosition +from .utils import _p_converter, _q_converter +from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape + + +@runtime_checkable +class Interpolation(Protocol): + """Protocol for interpolating phase-space positions.""" + + units: AbstractUnitSystem + + def __call__(self, t: gt.VecTime) -> gt.BatchVecTime6: + pass + + +@final +class InterpolatedPhaseSpacePosition(AbstractPhaseSpacePosition): + """Interpolated phase-space position.""" + + q: Abstract3DVector = eqx.field(converter=_q_converter) + """Positions, e.g Cartesian3DVector. + + This is a 3-vector with a batch shape allowing for vector inputs. + """ + + p: Abstract3DVectorDifferential = eqx.field(converter=_p_converter) + r"""Conjugate momenta, e.g. CartesianDifferential3D. + + This is a 3-vector with a batch shape allowing for vector inputs. + """ + + t: gt.BroadBatchFloatQScalar | gt.QVec1 = eqx.field( + converter=Quantity["time"].constructor + ) + """The time corresponding to the positions. + + This is a Quantity with the same batch shape as the positions and + velocities. If `t` is a scalar it will be broadcast to the same batch shape + as `q` and `p`. + """ + + interpolation: Interpolation + """The interpolation function.""" + + def __post_init__(self) -> None: + """Post-initialization.""" + # Need to ensure t shape is correct. Can be Vec0. + if self.t.ndim in (0, 1): + t = expand_batch_dims(self.t, ndim=self.q.ndim - self.t.ndim) + object.__setattr__(self, "t", t) + + def __call__(self, t: gt.BatchFloatQScalar) -> PhaseSpacePosition: + """Call the interpolation.""" + qp = self.interpolation(t) + units = self.interpolation.units + return PhaseSpacePosition( + q=Quantity(qp[..., 0:3], units["length"]), + p=Quantity(qp[..., 3:6], units["speed"]), + t=t, + ) + + # ========================================================================== + # Array properties + + @property + def _shape_tuple(self) -> tuple[tuple[int, ...], ComponentShapeTuple]: + """Batch, component shape.""" + qbatch, qshape = vector_batched_shape(self.q) + pbatch, pshape = vector_batched_shape(self.p) + tbatch, _ = batched_shape(self.t, expect_ndim=0) + batch_shape = jnp.broadcast_shapes(qbatch, pbatch, tbatch) + return batch_shape, ComponentShapeTuple(q=qshape, p=pshape, t=1) diff --git a/src/galax/coordinates/_psp/psp.py b/src/galax/coordinates/_psp/psp.py index c0a765da..ce49535a 100644 --- a/src/galax/coordinates/_psp/psp.py +++ b/src/galax/coordinates/_psp/psp.py @@ -1,15 +1,14 @@ """galax: Galactic Dynamix in Jax.""" -__all__ = ["PhaseSpacePosition", "InterpolatedPhaseSpacePosition"] +__all__ = ["PhaseSpacePosition"] -from typing import Any, NamedTuple, Protocol, TypeAlias, final, runtime_checkable +from typing import Any, NamedTuple, final import equinox as eqx import jax.numpy as jnp -from jaxtyping import Shaped from coordinax import Abstract3DVector, Abstract3DVectorDifferential -from unxt import AbstractUnitSystem, Quantity +from unxt import Quantity import galax.typing as gt from .base import AbstractPhaseSpacePosition @@ -134,8 +133,8 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): def __post_init__(self) -> None: """Post-initialization.""" # Need to ensure t shape is correct. Can be Vec0. - if self.t is not None and self.t.ndim in (0, 1): - t = expand_batch_dims(self.t, ndim=self.q.ndim - self.t.ndim) + if (t := self.t) is not None and t.ndim in (0, 1): + t = expand_batch_dims(t, ndim=self.q.ndim - t.ndim) object.__setattr__(self, "t", t) # ========================================================================== @@ -163,77 +162,3 @@ def wt(self, *, units: Any) -> gt.BatchVec7: self.t, self.t is None, "No time defined for phase-space position" ) return super().wt(units=units) - - -# ============================================================================ - -BatchVecTime6: TypeAlias = Shaped[gt.VecTime6, "*batch"] - - -@runtime_checkable -class Interpolation(Protocol): - """Protocol for interpolating phase-space positions.""" - - units: AbstractUnitSystem - - def __call__(self, t: gt.VecTime) -> BatchVecTime6: - pass - - -@final -class InterpolatedPhaseSpacePosition(AbstractPhaseSpacePosition): - """Interpolated phase-space position.""" - - q: Abstract3DVector = eqx.field(converter=_q_converter) - """Positions, e.g Cartesian3DVector. - - This is a 3-vector with a batch shape allowing for vector inputs. - """ - - p: Abstract3DVectorDifferential = eqx.field(converter=_p_converter) - r"""Conjugate momenta, e.g. CartesianDifferential3D. - - This is a 3-vector with a batch shape allowing for vector inputs. - """ - - t: gt.BroadBatchFloatQScalar | gt.QVec1 = eqx.field( - converter=Quantity["time"].constructor - ) - """The time corresponding to the positions. - - This is a Quantity with the same batch shape as the positions and - velocities. If `t` is a scalar it will be broadcast to the same batch shape - as `q` and `p`. - """ - - interpolation: Interpolation - """The interpolation function.""" - - def __post_init__(self) -> None: - """Post-initialization.""" - # Need to ensure t shape is correct. Can be Vec0. - if self.t.ndim in (0, 1): - t = expand_batch_dims(self.t, ndim=self.q.ndim - self.t.ndim) - object.__setattr__(self, "t", t) - - def __call__(self, t: gt.BatchFloatQScalar) -> PhaseSpacePosition: - """Call the interpolation.""" - qp = self.interpolation(t) - units = self.interpolation.units - return PhaseSpacePosition( - q=Quantity(qp[..., 0:3], units["length"]), - p=Quantity(qp[..., 3:6], units["speed"]), - t=t, - ) - - # ========================================================================== - # Array properties - - @property - def _shape_tuple(self) -> tuple[tuple[int, ...], ComponentShapeTuple]: - """Batch, component shape.""" - qbatch, qshape = vector_batched_shape(self.q) - pbatch, pshape = vector_batched_shape(self.p) - tbatch, _ = batched_shape(self.t, expect_ndim=0) - batch_shape = jnp.broadcast_shapes(qbatch, pbatch, tbatch) - return batch_shape, ComponentShapeTuple(q=qshape, p=pshape, t=1) diff --git a/src/galax/dynamics/_dynamics/integrate/_api.py b/src/galax/dynamics/_dynamics/integrate/_api.py index 4afd0ae5..4b1e29fe 100644 --- a/src/galax/dynamics/_dynamics/integrate/_api.py +++ b/src/galax/dynamics/_dynamics/integrate/_api.py @@ -1,11 +1,11 @@ __all__ = ["Integrator"] -from typing import Any, Protocol, TypeAlias, runtime_checkable +from typing import Any, Literal, Protocol, TypeAlias, runtime_checkable from unxt import AbstractUnitSystem +import galax.coordinates as gc import galax.typing as gt -from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition from galax.utils.dataclasses import _DataclassInstance SaveT: TypeAlias = gt.BatchQVecTime | gt.QVecTime | gt.BatchVecTime | gt.VecTime @@ -52,14 +52,15 @@ class Integrator(_DataclassInstance, Protocol): def __call__( self, F: FCallable, - w0: AbstractPhaseSpacePosition | gt.BatchVec6, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, t0: gt.FloatQScalar | gt.FloatScalar, t1: gt.FloatQScalar | gt.FloatScalar, /, savet: SaveT | None = None, *, units: AbstractUnitSystem, - ) -> PhaseSpacePosition: + interpolated: Literal[False, True] = False, + ) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition: """Integrate. Parameters diff --git a/src/galax/dynamics/_dynamics/integrate/_base.py b/src/galax/dynamics/_dynamics/integrate/_base.py index 79272e01..102cc66e 100644 --- a/src/galax/dynamics/_dynamics/integrate/_base.py +++ b/src/galax/dynamics/_dynamics/integrate/_base.py @@ -1,14 +1,15 @@ __all__ = ["AbstractIntegrator"] import abc +from typing import Literal import equinox as eqx from unxt import AbstractUnitSystem +import galax.coordinates as gc import galax.typing as gt from ._api import FCallable -from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc] @@ -28,7 +29,7 @@ class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, mis def __call__( self, F: FCallable, - w0: AbstractPhaseSpacePosition | gt.BatchVec6, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, t0: gt.FloatQScalar | gt.FloatScalar, t1: gt.FloatQScalar | gt.FloatScalar, /, @@ -37,7 +38,8 @@ def __call__( ) = None, *, units: AbstractUnitSystem, - ) -> PhaseSpacePosition: + interpolated: Literal[False, True] = False, + ) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition: """Run the integrator. Parameters diff --git a/src/galax/dynamics/_dynamics/integrate/_builtin.py b/src/galax/dynamics/_dynamics/integrate/_builtin.py index dd0a7c99..f60594ca 100644 --- a/src/galax/dynamics/_dynamics/integrate/_builtin.py +++ b/src/galax/dynamics/_dynamics/integrate/_builtin.py @@ -3,21 +3,28 @@ from collections.abc import Mapping from dataclasses import KW_ONLY from functools import partial -from typing import Any, final +from typing import Any, Literal, final import diffrax import equinox as eqx import jax +import jax.numpy as jnp +from diffrax import DenseInterpolation +from plum import overload import quaxed.array_api as xp -from unxt import AbstractUnitSystem, Quantity, to_units_value +from unxt import AbstractUnitSystem, Quantity, to_units_value, unitsystem import galax.coordinates as gc import galax.typing as gt from ._api import FCallable from ._base import AbstractIntegrator from galax.utils import ImmutableDict -from galax.utils._jax import vectorize_method + + +@partial(jnp.vectorize, signature="(T,6),(T,1)->(T,7)") +def _broadcast_concat(w: gt.VecTime6, ts: gt.VecTime1, /) -> gt.VecTime7: + return xp.concat((w, ts), axis=1) @final @@ -63,31 +70,89 @@ class DiffraxIntegrator(AbstractIntegrator): default=(("scan_kind", "bounded"),), static=True, converter=ImmutableDict ) - @vectorize_method(excluded=(0,), signature="(6),(),(),(T)->(T,7)") - @partial(jax.jit, static_argnums=(0, 1)) + # ===================================================== + # Call + + @partial(jax.jit, static_argnums=(0, 1, 6)) def _call_implementation( self, F: FCallable, - w0: gt.Vec6, + w0: gt.BatchVec6, t0: gt.FloatScalar, t1: gt.FloatScalar, ts: gt.VecTime, /, - ) -> gt.VecTime7: - solution = diffrax.diffeqsolve( - terms=diffrax.ODETerm(F), - solver=self.Solver(**self.solver_kw), - t0=t0, - t1=t1, - y0=w0, - dt0=None, - args=(), - saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=False), - stepsize_controller=self.stepsize_controller, - **self.diffeq_kw, - ) - ts = solution.ts[:, None] if solution.ts.ndim == 1 else solution.ts - return xp.concat((solution.ys, ts), axis=1) + interpolated: Literal[False, True], + ) -> tuple[gt.BatchVecTime7, DenseInterpolation | None]: + # TODO: less awkward munging of the diffrax API + kw = dict(self.diffeq_kw) + if interpolated and kw.get("max_steps") is None: + kw.pop("max_steps") + + # `jax.numpy.vectorize` only works on arrays, so breaks for + # `diffrax.Solution` objects. To get around this we emulate the + # `jax.numpy.vectorize` by manually flattening the batch axes to be + # just the first axis, vmapping, then reshaping the result. + terms = diffrax.ODETerm(F) + solver = self.Solver(**self.solver_kw) + + @partial(jax.vmap, in_axes=(0, 0)) + def solve_diffeq(ts: gt.VecTime, w0: gt.Vec6, /) -> diffrax.Solution: + return diffrax.diffeqsolve( + terms=terms, + solver=solver, + t0=t0, + t1=t1, + y0=w0, + dt0=None, + args=(), + saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=interpolated), + stepsize_controller=self.stepsize_controller, + **kw, + ) + + # Reshape inputs. Need to ensure that the inputs are batched then + # flattened, so that the vmap'ed `solve` can be applied. + num_t = ts.shape[-1] # number of times + batch_w = w0.shape[:-1] # batch shape + w0 = jnp.atleast_2d(w0) + ts = jnp.broadcast_to(jnp.atleast_2d(ts), (*w0.shape[:-1], num_t)) + + # Perform the integration + solution = solve_diffeq(ts.reshape(-1, num_t), w0.reshape(-1, 6)) + + # Parse the solution + w = jnp.concat((solution.ys, solution.ts[..., None]), axis=-1) + interp = solution.interpolation + + # Reshape outputs + w = w.reshape((*batch_w, num_t, 7)) + + return w, interp + + @overload + def __call__( + self, + F: FCallable, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, + /, + ts: gt.BatchQVecTime | gt.BatchVecTime | gt.QVecTime | gt.VecTime, + *, + units: AbstractUnitSystem, + interpolated: Literal[False] = False, + ) -> gc.PhaseSpacePosition: ... + + @overload + def __call__( + self, + F: FCallable, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, + /, + ts: gt.BatchQVecTime | gt.BatchVecTime | gt.QVecTime | gt.VecTime, + *, + units: AbstractUnitSystem, + interpolated: Literal[True], + ) -> gc.InterpolatedPhaseSpacePosition: ... def __call__( self, @@ -101,7 +166,8 @@ def __call__( ) = None, *, units: AbstractUnitSystem, - ) -> gc.PhaseSpacePosition: + interpolated: Literal[False, True] = False, + ) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition: """Run the integrator. Parameters @@ -119,6 +185,8 @@ def __call__( units : `unxt.AbstractUnitSystem` The unit system to use. + interpolated : bool, keyword-only + Whether to return an interpolated solution. Returns ------- @@ -200,14 +268,50 @@ def __call__( w0_: gt.Vec6 = ( w0.w(units=units) if isinstance(w0, gc.AbstractPhaseSpacePosition) else w0 ) + added_ndim = int(w0_.shape[:-1] == ()) # Perform the integration - w = self._call_implementation(F, w0_, t0_, t1_, savet_) + w, interp = self._call_implementation(F, w0_, t0_, t1_, savet_, interpolated) w = w[..., -1, :] if savet is None else w # Return + if interpolated: + out = gc.InterpolatedPhaseSpacePosition( # shape = (*batch, T) + q=Quantity(w[..., 0:3], units["length"]), + p=Quantity(w[..., 3:6], units["speed"]), + t=Quantity(savet_, units["time"]), + interpolation=DiffraxInterpolation( + interp, units=units, added_ndim=added_ndim + ), + ) + else: + out = gc.PhaseSpacePosition( # shape = (*batch, T) + q=Quantity(w[..., 0:3], units["length"]), + p=Quantity(w[..., 3:6], units["speed"]), + t=Quantity(w[..., -1], units["time"]), + ) + + return out + + +class DiffraxInterpolation(eqx.Module): # type: ignore[misc] + """Wrapper for ``diffrax.DenseInterpolation``.""" + + interpolation: DenseInterpolation + """:class:`diffrax.DenseInterpolation` object.""" + + units: AbstractUnitSystem = eqx.field(static=True, converter=unitsystem) + """The :class:`unxt.AbstractUnitSystem`.""" + + added_ndim: tuple[int, ...] = eqx.field(static=True) + """The number of dimensions added to the output of the interpolation.""" + + def __call__(self, t: gt.QVecTime, **_: Any) -> gc.PhaseSpacePosition: + t_ = jnp.atleast_1d(t.to_units_value(self.units["time"])) + ys = jax.vmap(lambda s: jax.vmap(s.evaluate)(t_))(self.interpolation) + ys = ys[(0,) * (ys.ndim - self.added_ndim - 1)] return gc.PhaseSpacePosition( - q=Quantity(w[..., 0:3], units["length"]), - p=Quantity(w[..., 3:6], units["speed"]), - t=Quantity(w[..., -1], units["time"]), + q=Quantity(ys[..., 0:3], self.units["length"]), + p=Quantity(ys[..., 3:6], self.units["speed"]), + t=t, ) diff --git a/src/galax/dynamics/_dynamics/integrate/_funcs.py b/src/galax/dynamics/_dynamics/integrate/_funcs.py index 7721dddf..0c07d457 100644 --- a/src/galax/dynamics/_dynamics/integrate/_funcs.py +++ b/src/galax/dynamics/_dynamics/integrate/_funcs.py @@ -4,6 +4,7 @@ from dataclasses import replace from functools import partial +from typing import Literal import jax import jax.numpy as jnp @@ -16,7 +17,7 @@ from ._api import Integrator from ._builtin import DiffraxIntegrator from galax.coordinates import PhaseSpacePosition -from galax.dynamics._dynamics.orbit import Orbit +from galax.dynamics._dynamics.orbit import InterpolatedOrbit, Orbit from galax.potential._potential.base import AbstractPotentialBase ############################################################################## @@ -29,20 +30,21 @@ _select_w0 = jnp.vectorize(jax.lax.select, signature="(),(6),(6)->(6)") -@partial(jax.jit, static_argnames=("integrator",)) +@partial(jax.jit, static_argnames=("integrator", "interpolated")) def evaluate_orbit( pot: AbstractPotentialBase, w0: PhaseSpacePosition | gt.BatchVec6, t: gt.QVecTime | gt.VecTime | APYQuantity, *, integrator: Integrator | None = None, -) -> Orbit: + interpolated: Literal[True, False] = False, +) -> Orbit | InterpolatedOrbit: """Compute an orbit in a potential. - :class:`~galax.coordinates.PhaseSpacePosition` includes a time in - addition to the position (and velocity) information, enabling the orbit to - be evaluated over a time range that is different from the initial time of - the position. + :class:`~galax.coordinates.PhaseSpacePosition` includes a time in addition + to the position (and velocity) information, enabling the orbit to be + evaluated over a time range that is different from the initial time of the + position. Parameters ---------- @@ -82,6 +84,10 @@ def evaluate_orbit( is used twice: once to integrate from `w0.t` to `t[0]` and then from `t[0]` to `t[1]`. + interpolated: bool, optional keyword-only + If `True`, return an interpolated orbit. If `False`, return the orbit + at the requested times. Default is `False`. + Returns ------- orbit : :class:`~galax.dynamics.Orbit` @@ -225,7 +231,23 @@ def evaluate_orbit( t[-1], savet=t, units=units, + interpolated=interpolated, + ) + wt = ( # get rid of unused batch dimensions + ws.t[..., 0] if ws.t is not None else None ) # Construct the orbit object - return Orbit(q=ws.q, p=ws.p, t=t, potential=pot) + # TODO: easier construction from the (Interpolated)PhaseSpacePosition + if interpolated: + out = InterpolatedOrbit( + q=ws.q, + p=ws.p, + t=wt, + interpolation=ws.interpolation, + potential=pot, + ) + else: + out = Orbit(q=ws.q, p=ws.p, t=wt, potential=pot) + + return out diff --git a/src/galax/dynamics/_dynamics/orbit.py b/src/galax/dynamics/_dynamics/orbit.py index 9104f880..9ff5f9d7 100644 --- a/src/galax/dynamics/_dynamics/orbit.py +++ b/src/galax/dynamics/_dynamics/orbit.py @@ -10,7 +10,7 @@ from unxt import Quantity from .base import AbstractOrbit -from galax.coordinates._psp.psp import Interpolation +from galax.coordinates._psp.interpolated import Interpolation from galax.coordinates._psp.utils import _p_converter, _q_converter from galax.potential._potential.base import AbstractPotentialBase from galax.typing import BatchFloatQScalar, QVec1, QVecTime diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index 2075f255..14710141 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -4,7 +4,7 @@ from dataclasses import KW_ONLY, fields from functools import partial from types import MappingProxyType -from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, cast import equinox as eqx import jax @@ -1896,6 +1896,7 @@ def evaluate_orbit( t: gt.QVecTime | gt.VecTime | APYQuantity, # TODO: must be a Quantity *, integrator: "Integrator | None" = None, + interpolated: Literal[True, False] = False, ) -> "Orbit": """Compute an orbit in a potential. @@ -1943,6 +1944,11 @@ def evaluate_orbit( Integrator to use. If `None`, the default integrator :class:`~galax.integrator.DiffraxIntegrator` is used. + interpolated: bool, optional keyword-only + If `True`, return an interpolated orbit. If `False`, return the orbit + at the requested times. Default is `False`. + + Returns ------- orbit : :class:`~galax.dynamics.Orbit` @@ -1956,4 +1962,9 @@ def evaluate_orbit( """ from galax.dynamics import evaluate_orbit - return cast("Orbit", evaluate_orbit(self, w0, t, integrator=integrator)) + return cast( + "Orbit", + evaluate_orbit( + self, w0, t, integrator=integrator, interpolated=interpolated + ), + ) diff --git a/src/galax/typing.py b/src/galax/typing.py index cc1f1d0a..bee9f66e 100644 --- a/src/galax/typing.py +++ b/src/galax/typing.py @@ -69,6 +69,7 @@ # Time vector VecTime = Float[Array, "time"] QVecTime = Float[Quantity, "time"] +VecTime1 = Float[Vec1, "time"] VecTime3 = Float[Vec3, "time"] VecTime6 = Float[Vec6, "time"] VecTime7 = Float[Vec7, "time"] @@ -105,6 +106,8 @@ # Specific BatchVecTime = Shaped[VecTime, "*batch"] +BatchVecTime6 = Shaped[VecTime6, "*batch"] +BatchVecTime7 = Shaped[VecTime7, "*batch"] BatchQVecTime = Shaped[QVecTime, "*batch"] # -----------------