Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 10, 2023
1 parent 08afd46 commit b045720
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 22 deletions.
5 changes: 1 addition & 4 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@ def docs(session: nox.Session) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--serve", action="store_true", help="Serve after building")
parser.add_argument(
"-b",
dest="builder",
default="html",
help="Build target (default: html)",
"-b", dest="builder", default="html", help="Build target (default: html)"
)
args, posargs = parser.parse_known_args(session.posargs)

Expand Down
2 changes: 1 addition & 1 deletion src/galdynamix/dynamics/_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Orbit(eqx.Module): # type: ignore[misc]
@property
@partial_jit()
def qp(self) -> jt.Array:
"""Return as a single Array[(N, Q + P + T),]."""
"""Return as a single Array[(N, Q + P),]."""
# Determine output shape
qd = self.q.shape[1] # dimensionality of q
shape = (self.q.shape[0], qd + self.p.shape[1])
Expand Down
2 changes: 1 addition & 1 deletion src/galdynamix/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class MockStream(eqx.Module): # type: ignore[misc]
@property
@partial_jit()
def qp(self) -> jt.Array:
"""Return as a single Array[(N, Q + P + T),]."""
"""Return as a single Array[(N, Q + P),]."""
# Determine output shape
qd = self.q.shape[1] # dimensionality of q
shape = (self.q.shape[0], qd + self.p.shape[1])
Expand Down
10 changes: 6 additions & 4 deletions src/galdynamix/dynamics/mockstream/_df/fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,19 @@ def _sample(
kvz_samp = kvz_bar + jax.random.normal(keyd, shape=(1,)) * sigma_kvz

# Trailing arm
x_trail = x + kr_samp * r_hat * (r_tidal) # nudge out
x_trail = x_trail + z_hat * kz_samp * (r_tidal / 1.0)
x_trail = (
x + (kr_samp * r_hat * (r_tidal)) + (z_hat * kz_samp * (r_tidal / 1.0))
)
v_trail = (
v
+ (0.0 + kvphi_samp * v_circ * (1.0)) * phi_hat
+ (kvz_samp * v_circ * (1.0)) * z_hat
)

# Leading arm
x_lead = x + kr_samp * r_hat * (-r_tidal) # nudge in
x_lead = x_lead + z_hat * kz_samp * (-r_tidal / 1.0)
x_lead = (
x + (kr_samp * r_hat * (-r_tidal)) + (z_hat * kz_samp * (-r_tidal / 1.0))
)
v_lead = (
v
+ (0.0 + kvphi_samp * v_circ * (-1.0)) * phi_hat
Expand Down
7 changes: 1 addition & 6 deletions src/galdynamix/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc]

@partial_jit(static_argnames=("seed_num",))
def _run_scan(
self,
ts: jt.Array,
prog_w0: jt.Array,
prog_mass: jt.Array,
*,
seed_num: int,
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> tuple[tuple[jt.Array, jt.Array], Orbit]:
"""Generate stellar stream by scanning over the release model/integration.
Expand Down
6 changes: 2 additions & 4 deletions src/galdynamix/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class DiffraxIntegrator(AbstractIntegrator):
Solver: AbstractSolver = eqx.field(default=Dopri5, static=True)
SaveAt: DiffraxSaveAt = eqx.field(default=DiffraxSaveAt, static=True)
stepsize_controller: AbstractStepSizeController = eqx.field(
default=PIDController(rtol=1e-7, atol=1e-7),
static=True,
default=PIDController(rtol=1e-7, atol=1e-7), static=True
)
diffeq_kw: tuple[tuple[str, Any], ...] = eqx.field(
default_factory=lambda: (
Expand All @@ -39,8 +38,7 @@ class DiffraxIntegrator(AbstractIntegrator):
static=True,
)
solver_kw: tuple[tuple[str, Any], ...] = eqx.field(
default_factory=lambda: (("scan_kind", "bounded"),),
static=True,
default_factory=lambda: (("scan_kind", "bounded"),), static=True
)

def run(
Expand Down
3 changes: 1 addition & 2 deletions src/galdynamix/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def _init_units(self) -> None:
value = getattr(self, f.name)
if isinstance(value, u.Quantity):
value = value.to_value(
self.units[param.physical_type],
equivalencies=param.equivalencies,
self.units[param.physical_type], equivalencies=param.equivalencies
)
object.__setattr__(self, f.name, value)

Expand Down

0 comments on commit b045720

Please sign in to comment.