
refactor: parameterise Ellipse + EllipseMultipole math on xp #407

@Jammy2211

Description


Overview

Add an xp=np keyword argument to every numeric method on Ellipse, EllipseMultipole, and EllipseMultipoleScaled, replacing bare np.* calls with xp.* so the geometry computations trace under jax.jit. Also replaces the two Python while loops in EllipseMultipole.get_shape_angle with a single xp.mod-based wrap so that method is JIT-safe too. Step 5 of 7 in the ellipse_fitting_jax feature (PyAutoPrompt/z_features/ellipse_fitting_jax.md). NumPy-path behaviour is preserved — xp=np is the default at every entry point and call sites in FitEllipse continue to pass nothing.
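
Below is a minimal sketch of the xp=np convention. It uses a hypothetical ToyEllipse class rather than PyAutoGalaxy's Ellipse, and its method names, arguments and radius formula are illustrative assumptions, not the library's. Every array op goes through xp, and each method passes xp on to the methods it calls, so a single call site chooses the backend.

```python
import numpy as np


class ToyEllipse:
    """Hypothetical stand-in for Ellipse, illustrating the xp=np convention only."""

    def __init__(self, major_axis, axis_ratio, angle_deg=0.0):
        self.major_axis = major_axis
        self.axis_ratio = axis_ratio
        self.angle_rad = np.radians(angle_deg)

    def angles_from(self, n_points, xp=np):
        # xp.linspace accepts the same arguments in numpy and jax.numpy
        return xp.linspace(0.0, 2.0 * np.pi, n_points, endpoint=False)

    def radii_from(self, n_points, xp=np):
        # Pass xp to the inner call so every op uses one backend
        phi = self.angles_from(n_points, xp=xp) - self.angle_rad
        a = self.major_axis
        b = self.major_axis * self.axis_ratio
        # Polar form of an ellipse, measured from its major axis
        return xp.divide(
            a * b, xp.sqrt(xp.add((b * xp.cos(phi)) ** 2, (a * xp.sin(phi)) ** 2))
        )

    def points_from(self, n_points, xp=np):
        theta = self.angles_from(n_points, xp=xp)
        r = self.radii_from(n_points, xp=xp)
        # (n_points, 2) array of (x, y) coordinates
        return xp.stack([r * xp.cos(theta), r * xp.sin(theta)], axis=-1)
```

On the JAX path, a caller would pass xp=jax.numpy inside a jitted function. n_points has to remain a static Python int there, for the same reason that total_points_from stays NumPy.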

Plan

  • Thread xp=np through the five array-returning Ellipse methods (angles_from_x0_from and the four *_from_major_axis_from methods); replace np.* calls with xp.* in their bodies.
  • Thread xp=np through EllipseMultipole.points_perturbed_from and EllipseMultipoleScaled.points_perturbed_from.
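  • Replace the two Python while loops in EllipseMultipole.get_shape_angle with a single arithmetic wrap using xp.mod. Document the boundary-semantics change in a docstring note. The loops return angles in the closed interval [-180/m, 180/m], whereas the xp.mod wrap returns the half-open interval [-180/m, 180/m), so only an input landing exactly on +180/m changes (it now maps to -180/m). Pin this with a numpy-only boundary test.
  • Guard the if np.sum(idx) > 0: raise NotImplementedError() NaN check in points_from_major_axis_from behind if xp is np:. A Python if on a traced array cannot be evaluated under jax.jit. On the JAX path, the check is therefore skipped, and any NaNs propagate through the downstream nansum/nanmean instead of raising.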
  • Ellipse.total_points_from stays numpy (returns a Python int used as a static shape outside the JIT trace).
  • No JAX in unit tests per PyAutoGalaxy/CLAUDE.md rule. JAX parity is verified at the workspace_test level — implicitly through prompt 2's autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/ scripts once prompt 7 wires the JIT path. Until then, the xp=jnp path is exercised only manually during development.
  • Numpy-path numerics must not drift: prompt 3's rtol=1e-12 reference arrays and prompt 2's workspace_test reference numbers continue to pass.
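
The NaN guard could look roughly like the sketch below. points_from_xy is a hypothetical helper, not a PyAutoGalaxy function. The outer "if xp is np" compares module objects, so Python resolves it before any tracing happens. Only the inner, value-dependent check would fail under jax.jit, and that check is never reached on the JAX path.

```python
import numpy as np


def points_from_xy(x, y, xp=np):
    """Hypothetical helper mirroring the guarded NaN check in points_from_major_axis_from."""
    if xp is np:
        # Branching on array values is only safe on the NumPy path;
        # under jax.jit, x and y would be tracers and this `if` would fail.
        idx = np.logical_or(np.isnan(x), np.isnan(y))
        if np.sum(idx) > 0:
            raise NotImplementedError()
    # On the JAX path, any NaNs flow through to downstream nansum / nanmean.
    return xp.stack([x, y], axis=-1)
```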
Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary)

Work Classification

Library

Branch Survey

Repository       Current Branch   Dirty?
./PyAutoGalaxy   main             clean

Suggested branch: feature/ellipse-xp
Worktree root: ~/Code/PyAutoLabs-wt/ellipse-xp/ (created later by /start_library)

Implementation Steps

  1. autogalaxy/ellipse/ellipse/ellipse.py: add xp=np to:

    • Ellipse.angles_from_x0_from(self, pixel_scale, n_i=0, xp=np) — replace np.linspace with xp.linspace.
    • Ellipse.ellipse_radii_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.divide, np.sqrt, np.add, np.sin, np.cos. Thread xp into the inner angles_from_x0_from call.
    • Ellipse.x_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.cos. Thread xp into inner calls.
    • Ellipse.y_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.sin. Thread xp into inner calls.
    • Ellipse.points_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.stack. Wrap the NaN-check (idx = np.logical_or(np.isnan(x), np.isnan(y)); if np.sum(idx) > 0: raise NotImplementedError()) in if xp is np:. Thread xp into x_from_major_axis_from / y_from_major_axis_from calls.
  2. Ellipse.total_points_from — unchanged. Stays numpy because it returns a Python int used outside JIT traces.

  3. autogalaxy/ellipse/ellipse/ellipse_multipole.py:

    • EllipseMultipole.get_shape_angle(self, ellipse, xp=np) — replace the two while angle < -180/self.m: angle += 360/self.m / while angle > 180/self.m: angle -= 360/self.m loops with:
      period = 360.0 / self.m
      angle = xp.mod(angle + period / 2.0, period) - period / 2.0
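      The original loops return angles in the closed interval [-period/2, period/2], because an input exactly on either endpoint is left unchanged. The rewrite returns [-period/2, period/2), so an input landing exactly on +period/2 now maps to -period/2. Document the boundary difference in a docstring note and pin the boundary case (angle = period/2.0 exactly) in a new unit test.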
    • EllipseMultipole.points_perturbed_from(self, pixel_scale, points, ellipse, n_i=0, xp=np) — replace np.arctan2, np.cos, np.sin, np.stack. multipole_comps_from / multipole_k_m_and_phi_m_from from convert.py operate on Python tuples and don't trip JIT; leave untouched.
    • EllipseMultipoleScaled.points_perturbed_from(self, pixel_scale, points, ellipse, n_i=0, xp=np) — same treatment as EllipseMultipole.points_perturbed_from.
  4. Call sites in FitEllipse (autogalaxy/ellipse/fit_ellipse.py:69-136) and elsewhere must NOT be updated in this PR — they pass nothing, get the numpy default, behaviour unchanged. Threading xp into those call sites happens in prompts 6 and 7.

  5. test_autogalaxy/ellipse/test_ellipse.py:

    • Add test__multipole__get_shape_angle__boundary — numpy-only. Input ellipse with angle() such that the un-normalised offset equals period/2.0 exactly; assert the rewrite returns -period/2.0 (the new convention). Documents the boundary-case change vs the original while-loop behaviour.
    • Do NOT add JAX-parity tests here. PyAutoGalaxy/CLAUDE.md: "Never use JAX in unit tests." JAX parity is checked at the workspace_test level via prompt 2's autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/ scripts once prompt 7 flips them to the JIT path.
  6. Run python -m pytest test_autogalaxy/ellipse/ -v from the worktree. Must report all tests pass — including prompt-3's rtol=1e-12 pins on _points_from_major_axis. If those drift, the numpy path semantics have changed and there's a bug.

  7. Run python scripts/jax_likelihood_functions/ellipse/{fit,multipoles}.py from autogalaxy_workspace_test/. Confirm all 8 reference numbers match the prompt-2 baseline.
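
The sketch below runs the quoted while loops and the proposed xp.mod wrap side by side, to show exactly which inputs change under the step 3 rewrite. wrap_while, wrap_mod and disagreements are hypothetical helpers that isolate the arithmetic from EllipseMultipole. The real method also derives the angle from the multipole components and the ellipse, and that part is not reproduced here.

```python
import numpy as np


def wrap_while(angle, m):
    """The original loops, as quoted in step 3: Python control flow, not jit-traceable."""
    while angle < -180 / m:
        angle += 360 / m
    while angle > 180 / m:
        angle -= 360 / m
    return angle


def wrap_mod(angle, m, xp=np):
    """The proposed xp.mod rewrite: branch-free, so it can trace under jax.jit."""
    period = 360.0 / m
    return xp.mod(angle + period / 2.0, period) - period / 2.0


def disagreements(angles, m, atol=1e-9):
    """Return the input angles where the two wraps give different results."""
    return [a for a in angles if abs(wrap_while(a, m) - wrap_mod(a, m)) > atol]
```

Take m = 4 (period 90) and sweep from -360 to 360 in 15 degree steps. The only disagreements are the inputs the loops leave at +45 (45, 135, 225, 315), which the mod form sends to -45. Inputs landing on -45 agree, because the loops keep that endpoint too.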

Key Files

  • autogalaxy/ellipse/ellipse/ellipse.py — thread xp through five methods.
  • autogalaxy/ellipse/ellipse/ellipse_multipole.py — thread xp through three methods, rewrite get_shape_angle.
  • test_autogalaxy/ellipse/test_ellipse.py — one new test (boundary case for get_shape_angle).
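
The new test might take roughly the shape below. This is a sketch that exercises the wrap in isolation through a hypothetical wrap_angle stand-in. The real test would build an Ellipse and an EllipseMultipole whose angles put the un-normalised offset exactly on period/2, then call get_shape_angle. Those constructor arguments are not given here.

```python
import numpy as np


def wrap_angle(angle, m, xp=np):
    # Hypothetical stand-in for the xp.mod wrap inside EllipseMultipole.get_shape_angle
    period = 360.0 / m
    return xp.mod(angle + period / 2.0, period) - period / 2.0


def test__multipole__get_shape_angle__boundary():
    # NumPy only: PyAutoGalaxy's unit tests never import JAX
    for m in (1, 3, 4):
        period = 360.0 / m
        # +period/2 now wraps to -period/2 (the while loops returned +period/2)
        assert wrap_angle(period / 2.0, m) == -period / 2.0
        # -period/2 is inside the half-open range and is returned unchanged
        assert wrap_angle(-period / 2.0, m) == -period / 2.0
```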

Testing Approach

  • pytest: must remain green, including prompt-3 reference-array pins at rtol=1e-12.
  • Workspace parity: prompt-2 reference numbers byte-stable.
  • JAX-path verification: deferred to workspace_test level (prompt 7).
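
Until prompt 7 wires the JIT path, the xp=jnp side is only checked by hand during development. A throwaway script along the lines below could compare the jitted JAX result with the NumPy path. It is not a unit test, it uses a hypothetical wrap_angle stand-in, and it returns None when JAX is not installed.

```python
import numpy as np

try:  # JAX is optional: this is a manual development check, not a unit test
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None
    jnp = None


def wrap_angle(angle, m, xp=np):
    # Hypothetical stand-in for the xp.mod wrap inside get_shape_angle
    period = 360.0 / m
    return xp.mod(angle + period / 2.0, period) - period / 2.0


def check_parity(angles, m=4, rtol=1e-6):
    """Compare jitted JAX output with the NumPy path; return None if JAX is absent."""
    expected = wrap_angle(np.asarray(angles, dtype=float), m)
    if jax is None:
        return None
    # m is closed over as a Python constant, so only the angles are traced
    jitted = jax.jit(lambda a: wrap_angle(a, m, xp=jnp))
    got = np.asarray(jitted(jnp.asarray(angles)))
    np.testing.assert_allclose(got, expected, rtol=rtol)
    return True
```

JAX defaults to float32 unless 64-bit mode is enabled. This script therefore uses a loose tolerance like 1e-6 rather than the 1e-12 of the NumPy-path pins.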

Original Prompt


Step 5 of the ellipse-JAX series. With the 2D interpolator in place from prompt 4, the next blocker is the geometry math in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse.py and @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py. Every routine on Ellipse uses bare np.*, and EllipseMultipole.get_shape_angle uses Python while loops to wrap an angle into [-180/m, 180/m] — both incompatible with jax.jit tracing. Convert these to the xp=np pattern documented in @PyAutoGalaxy/CLAUDE.md "JAX Support" section.

Please:

  1. Add xp=np as a keyword argument to every method in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse.py that returns a numerical array:

    • Ellipse.angles_from_x0_from
    • Ellipse.ellipse_radii_from_major_axis_from
    • Ellipse.x_from_major_axis_from
    • Ellipse.y_from_major_axis_from
    • Ellipse.points_from_major_axis_from

    Replace bare np.* with xp.* inside the function bodies (xp.linspace, xp.sin, xp.cos, xp.divide, xp.add, xp.sqrt, xp.stack). The total_points_from method stays numpy — its return type is a Python int and it's used to set static shapes outside the JIT trace.

    Special case in points_from_major_axis_from: the idx = np.logical_or(np.isnan(x), np.isnan(y)); if np.sum(idx) > 0: raise NotImplementedError() guard is JAX-incompatible (Python if on a traced value). Replace with if xp is np: around the guard — under JAX, NaNs propagate through downstream nansum/nanmean and we'd rather see them than crash inside a JIT trace.

  2. Same treatment for EllipseMultipole.points_perturbed_from and EllipseMultipoleScaled.points_perturbed_from in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py. Add xp=np, swap np.* for xp.*. The multipole_comps_from and multipole_k_m_and_phi_m_from helpers from @PyAutoGalaxy/autogalaxy/convert.py are called outside the math loop on Python tuples — leave those as-is unless they trip JIT (verify by tracing).

  3. Replace the while loops in EllipseMultipole.get_shape_angle (@PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py:66-69) with arithmetic that JAX can trace. The intent is "wrap angle into the half-open interval (-180/m, 180/m]". A direct replacement using xp.mod works:

    period = 360.0 / self.m
    angle = xp.mod(angle + period / 2.0, period) - period / 2.0

    This produces values in [-period/2, period/2) rather than (-period/2, period/2], which is a tiny boundary-case difference. Verify against the existing tests in @PyAutoGalaxy/test_autogalaxy/ellipse/ and add a test pinning the new behaviour at the boundary (angle = period/2.0) so future changes don't drift unnoticed.

  4. The existing call sites in FitEllipse and elsewhere don't pass xp — they get the numpy default and behaviour is unchanged. Don't thread xp through the call sites in this prompt; that happens in prompt 6 and 7 where it actually matters.

  5. Add unit tests in @PyAutoGalaxy/test_autogalaxy/ellipse/test_ellipse.py asserting that, for one fixed Ellipse and one fixed EllipseMultipole, the xp=np and xp=jnp paths produce points that agree to rtol=1e-6. Gate the JAX side with pytest.importorskip("jax").

(NOTE: step 5 above conflicts with PyAutoGalaxy/CLAUDE.md's "Never use JAX in unit tests" rule — the issuing session caught this and dropped the JAX-parity tests during issue creation. Only the boundary test for get_shape_angle is added in test_autogalaxy/.)

  6. Test bar:
    • python -m pytest test_autogalaxy/ellipse/ -v passes.
    • The reference numbers from prompt 2's workspace_test scripts are unchanged on the numpy path.
