Skip to content

Commit

Permalink
refactor: update check init
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jul 3, 2024
1 parent 6a2c98f commit d1cde22
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 11 deletions.
32 changes: 21 additions & 11 deletions src/galax/potential/_potential/builtin/multipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ class MultipoleInnerPotential(AbstractMultipolePotential):

def __check_init__(self) -> None:
shape = (self.l_max + 1, self.l_max + 1)
# TODO: don't use .value
if self.Slm.value.shape != shape or self.Tlm.value.shape != shape:
t = Quantity(0.0, "Gyr")
s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape
# TODO: check shape across time.
if s_shape != shape or t_shape != shape:
msg = (
"Slm and Tlm must have the shape (l_max + 1, l_max + 1)."
f"Slm shape: {self.Slm.value.shape}, Tlm shape: {self.Tlm.value.shape}"
f"Slm shape: {s_shape}, Tlm shape: {t_shape}"
)
raise ValueError(msg)

Expand Down Expand Up @@ -117,9 +119,14 @@ class MultipoleOuterPotential(AbstractMultipolePotential):

def __check_init__(self) -> None:
shape = (self.l_max + 1, self.l_max + 1)
# TODO: don't use .value
if self.Slm.value.shape != shape or self.Tlm.value.shape != shape:
msg = "Slm and Tlm must have the shape (l_max + 1, l_max + 1)."
t = Quantity(0.0, "Gyr")
s_shape, t_shape = self.Slm(t).shape, self.Tlm(t).shape
# TODO: check shape across time.
if s_shape != shape or t_shape != shape:
msg = (
"Slm and Tlm must have the shape (l_max + 1, l_max + 1)."
f"Slm shape: {s_shape}, Tlm shape: {t_shape}"
)
raise ValueError(msg)

@partial(jax.jit, inline=True)
Expand Down Expand Up @@ -178,12 +185,15 @@ class MultipolePotential(AbstractMultipolePotential):

def __check_init__(self) -> None:
shape = (self.l_max + 1, self.l_max + 1)
# TODO: don't use .value
t = Quantity(0.0, "Gyr")
is_shape, it_shape = self.ISlm(t).shape, self.ITlm(t).shape
os_shape, ot_shape = self.OSlm(t).shape, self.OTlm(t).shape
# TODO: check shape across time.
if (
self.ISlm.value.shape != shape
or self.ITlm.value.shape != shape
or self.OSlm.value.shape != shape
or self.OTlm.value.shape != shape
is_shape != shape
or it_shape != shape
or os_shape != shape
or ot_shape != shape
):
msg = "I/OSlm and I/OTlm must have the shape (l_max + 1, l_max + 1)."
raise ValueError(msg)
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/potential/builtin/multipole/test_innermultipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ def fields_(

# ==========================================================================

def test_check_init(
self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any]
) -> None:
"""Test the `MultipoleInnerPotential.__check_init__` method."""
fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape
with pytest.raises(ValueError, match="Slm and Tlm must have the shape"):
pot_cls(**fields_)

# ==========================================================================

def test_potential(self, pot: gp.MultipoleInnerPotential, x: gt.QVec3) -> None:
expect = Quantity(32.96969177, unit="kpc2 / Myr2")
assert qnp.isclose(
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/potential/builtin/multipole/test_multipole.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test the `MultipolePotential` class."""

import re
from typing import Any

import astropy.units as u
Expand Down Expand Up @@ -252,6 +253,17 @@ def fields_(

# ==========================================================================

def test_check_init(
self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any]
) -> None:
"""Test the `MultipoleInnerPotential.__check_init__` method."""
fields_["ISlm"] = fields_["ISlm"][::2] # make it the wrong shape
match = re.escape("I/OSlm and I/OTlm must have the shape")
with pytest.raises(ValueError, match=match):
pot_cls(**fields_)

# ==========================================================================

def test_potential(self, pot: gp.MultipolePotential, x: gt.QVec3) -> None:
expect = Quantity(33.59908611, unit="kpc2 / Myr2")
assert jnp.isclose(
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/potential/builtin/multipole/test_outermultipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ def fields_(

# ==========================================================================

def test_check_init(
self, pot_cls: type[gp.MultipoleInnerPotential], fields_: dict[str, Any]
) -> None:
"""Test the `MultipoleInnerPotential.__check_init__` method."""
fields_["Slm"] = fields_["Slm"][::2] # make it the wrong shape
with pytest.raises(ValueError, match="Slm and Tlm must have the shape"):
pot_cls(**fields_)

# ==========================================================================

def test_potential(self, pot: gp.MultipoleOuterPotential, x: gt.QVec3) -> None:
expect = Quantity(0.62939434, unit="kpc2 / Myr2")
assert qnp.isclose(
Expand Down

0 comments on commit d1cde22

Please sign in to comment.