Overview
Add an xp=np keyword argument to every array-returning numeric method on Ellipse, EllipseMultipole, and EllipseMultipoleScaled, replacing bare np.* calls with xp.* so the geometry computations trace under jax.jit. Also replace the two Python while loops in EllipseMultipole.get_shape_angle with a single xp.mod-based wrap so that method is JIT-safe too. This is step 5 of 7 in the ellipse_fitting_jax feature (PyAutoPrompt/z_features/ellipse_fitting_jax.md). NumPy-path behaviour is preserved: xp=np is the default at every entry point, and call sites in FitEllipse continue to pass nothing.
Plan
- Thread xp=np through the five Ellipse methods that return arrays (angles_from_x0_from and the four *_from_major_axis_from methods); replace np.* calls with xp.* in their bodies.
- Thread xp=np through EllipseMultipole.points_perturbed_from and EllipseMultipoleScaled.points_perturbed_from.
- Replace the two Python while loops in EllipseMultipole.get_shape_angle with a single arithmetic wrap using xp.mod. The loops leave an angle of exactly 180/m unchanged, while the rewrite maps it to -180/m and returns values in [-180/m, 180/m). Document this boundary semantics change in a docstring note and pin it with a numpy-only boundary test.
- Guard the if np.sum(idx) > 0: raise NotImplementedError() NaN check in points_from_major_axis_from behind if xp is np:. Under JAX, a Python if on a traced value cannot be evaluated inside a JIT trace, and NaNs propagate to the downstream nansum/nanmean instead, so a JIT-time raise would be incorrect.
- Ellipse.total_points_from stays numpy (returns a Python int used as a static shape outside the JIT trace).
- No JAX in unit tests, per the PyAutoGalaxy/CLAUDE.md rule. JAX parity is verified at the workspace_test level, implicitly through prompt 2's autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/ scripts once prompt 7 wires the JIT path. Until then, the xp=jnp path is exercised only manually during development.
- Numpy-path numerics must not drift: prompt 3's rtol=1e-12 reference arrays and prompt 2's workspace_test reference numbers continue to pass.
Detailed implementation plan
Affected Repositories
Work Classification
Library
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoGalaxy | main | clean |
Suggested branch: feature/ellipse-xp
Worktree root: ~/Code/PyAutoLabs-wt/ellipse-xp/ (created later by /start_library)
Implementation Steps
- autogalaxy/ellipse/ellipse/ellipse.py: add xp=np to:
  - Ellipse.angles_from_x0_from(self, pixel_scale, n_i=0, xp=np): replace np.linspace with xp.linspace.
  - Ellipse.ellipse_radii_from_major_axis_from(self, pixel_scale, n_i=0, xp=np): replace np.divide, np.sqrt, np.add, np.sin, np.cos. Thread xp into the inner angles_from_x0_from call.
  - Ellipse.x_from_major_axis_from(self, pixel_scale, n_i=0, xp=np): replace np.cos. Thread xp into inner calls.
  - Ellipse.y_from_major_axis_from(self, pixel_scale, n_i=0, xp=np): replace np.sin. Thread xp into inner calls.
  - Ellipse.points_from_major_axis_from(self, pixel_scale, n_i=0, xp=np): replace np.stack. Wrap the NaN check (idx = np.logical_or(np.isnan(x), np.isnan(y)); if np.sum(idx) > 0: raise NotImplementedError()) in if xp is np:. Thread xp into the x_from_major_axis_from / y_from_major_axis_from calls.
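The threading pattern in this step can be sketched on a standalone toy class. ToyEllipse, its attributes, and its method bodies are hypothetical stand-ins and not the real Ellipse implementation; only the pattern follows the plan (xp=np default, xp forwarded to inner calls, NaN guard restricted to the NumPy path).

```python
import numpy as np


class ToyEllipse:
    """Hypothetical stand-in for Ellipse; the geometry is illustrative only."""

    def __init__(self, major_axis=1.0, axis_ratio=0.5, total_points=8):
        self.major_axis = major_axis
        self.axis_ratio = axis_ratio
        self.total_points = total_points  # plain Python int, usable as a static shape

    def angles_from(self, xp=np):
        # linspace exists with this signature in both numpy and jax.numpy.
        return xp.linspace(0.0, 2.0 * np.pi, self.total_points + 1)[:-1]

    def x_from(self, xp=np):
        # Forward xp to the inner call so the whole chain uses one backend.
        return self.major_axis * xp.cos(self.angles_from(xp=xp))

    def y_from(self, xp=np):
        minor_axis = self.major_axis * self.axis_ratio
        return minor_axis * xp.sin(self.angles_from(xp=xp))

    def points_from(self, xp=np):
        x = self.x_from(xp=xp)
        y = self.y_from(xp=xp)

        if xp is np:
            # Eager NumPy path only: a Python `if` on array values works here,
            # but would fail on traced values inside jax.jit.
            if np.sum(np.logical_or(np.isnan(x), np.isnan(y))) > 0:
                raise NotImplementedError()

        return xp.stack((y, x), axis=-1)


# No xp is passed, so the NumPy default applies, as at the existing call sites.
points = ToyEllipse().points_from()
```

Passing xp=jnp (with jax.numpy imported as jnp) would route the same method bodies through JAX and skip the guard; that path is not exercised in this sketch.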
- Ellipse.total_points_from: unchanged. It stays numpy because it returns a Python int used outside JIT traces.
- autogalaxy/ellipse/ellipse/ellipse_multipole.py:
  - EllipseMultipole.get_shape_angle(self, ellipse, xp=np): replace the two loops (while angle < -180/self.m: angle += 360/self.m and while angle > 180/self.m: angle -= 360/self.m) with:

        period = 360.0 / self.m
        angle = xp.mod(angle + period / 2.0, period) - period / 2.0

    Because the loops use strict comparisons, they leave an angle of exactly period/2 unchanged (and likewise -period/2), whereas the rewrite returns values in [-period/2, period/2), so period/2 now maps to -period/2. Document the boundary difference in a docstring note and pin the boundary case (angle = period/2.0 exactly) in a new unit test.
  - EllipseMultipole.points_perturbed_from(self, pixel_scale, points, ellipse, n_i=0, xp=np): replace np.arctan2, np.cos, np.sin, np.stack. multipole_comps_from / multipole_k_m_and_phi_m_from from convert.py operate on Python tuples and don't trip JIT; leave them untouched.
  - EllipseMultipoleScaled.points_perturbed_from(self, pixel_scale, points, ellipse, n_i=0, xp=np): same treatment as EllipseMultipole.points_perturbed_from.
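To see where the loop-based wrap and the xp.mod wrap agree and where they differ, here is a minimal standalone comparison. The function names and the choice m = 4 are illustrative assumptions; the loop body mirrors the two while statements quoted in this step, not the full get_shape_angle method.

```python
import numpy as np


def wrap_with_loops(angle, m):
    """Loop-based wrap, mirroring the two quoted while statements."""
    while angle < -180 / m:
        angle += 360 / m
    while angle > 180 / m:
        angle -= 360 / m
    return angle


def wrap_with_mod(angle, m, xp=np):
    """Arithmetic wrap with no Python control flow on the angle."""
    period = 360.0 / m
    return xp.mod(angle + period / 2.0, period) - period / 2.0


m = 4  # period = 90 degrees, so wrapped angles lie between -45 and 45
interior = [-200.0, -50.0, -10.0, 0.0, 30.0, 100.0, 400.0]
loops = np.array([wrap_with_loops(a, m) for a in interior])
mod = wrap_with_mod(np.array(interior), m)
agree_on_interior = np.allclose(loops, mod)

# The upper boundary is the one input where the two conventions differ.
boundary_loops = wrap_with_loops(45.0, m)  # strict `>` leaves 45.0 unchanged
boundary_mod = float(wrap_with_mod(45.0, m))  # mod wraps 45.0 to -45.0
```

The mod form also accepts whole arrays, which the loop form does not, so it is the only one of the two that can be applied to a traced value.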
- Call sites in FitEllipse (autogalaxy/ellipse/fit_ellipse.py:69-136) and elsewhere must NOT be updated in this PR: they pass nothing, get the numpy default, and behave as before. Threading xp into those call sites happens in prompts 6 and 7.
- test_autogalaxy/ellipse/test_ellipse.py:
  - Add test__multipole__get_shape_angle__boundary (numpy-only). Input ellipse with angle() such that the un-normalised offset equals period/2.0 exactly; assert the rewrite returns -period/2.0 (the new convention). Documents the boundary-case change vs the original while-loop behaviour.
  - Do NOT add JAX-parity tests here. PyAutoGalaxy/CLAUDE.md: "Never use JAX in unit tests." JAX parity is checked at the workspace_test level via prompt 2's autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/ scripts once prompt 7 flips them to the JIT path.
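A rough shape for the boundary test, written against a hypothetical stand-in function rather than the real EllipseMultipole.get_shape_angle (whose constructor arguments and ellipse fixture are not reproduced here). The real test would build an ellipse and multipole whose un-normalised offset equals period/2.0 and assert on the method's return value instead.

```python
import numpy as np


def wrap_shape_angle(angle, m, xp=np):
    # Hypothetical stand-in for the wrap performed inside get_shape_angle.
    period = 360.0 / m
    return xp.mod(angle + period / 2.0, period) - period / 2.0


def test__multipole__get_shape_angle__boundary():
    for m in (1, 3, 4):
        period = 360.0 / m
        # An offset exactly at the upper boundary wraps to the lower boundary.
        assert wrap_shape_angle(period / 2.0, m) == -period / 2.0
        # The lower boundary itself is unchanged by the wrap.
        assert wrap_shape_angle(-period / 2.0, m) == -period / 2.0


# pytest would collect the function by name; it is called directly here
# so the sketch runs as a plain script.
test__multipole__get_shape_angle__boundary()
```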
- Run python -m pytest test_autogalaxy/ellipse/ -v from the worktree. All tests must pass, including prompt 3's rtol=1e-12 pins on _points_from_major_axis. If those drift, the numpy-path semantics have changed and there is a bug.
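For context on what an rtol=1e-12 pin checks, here is a small sketch using numpy.testing.assert_allclose. The function points_sketch and the recomputed reference are assumptions made for illustration; the actual pins from prompt 3 compare against stored reference arrays, which are not reproduced here.

```python
import numpy as np


def points_sketch(total_points=4, xp=np):
    # Hypothetical stand-in for a refactored method whose NumPy output is pinned.
    angles = xp.linspace(0.0, 2.0 * np.pi, total_points + 1)[:-1]
    return xp.stack((xp.sin(angles), xp.cos(angles)), axis=-1)


# Stand-in for a stored reference: the same quantities computed with bare np.* calls,
# as the pre-refactor code would have done.
angles_ref = np.linspace(0.0, 2.0 * np.pi, 5)[:-1]
reference = np.stack((np.sin(angles_ref), np.cos(angles_ref)), axis=-1)

# With the default atol=0, rtol=1e-12 is a near-exact comparison, so a changed
# operation order or dtype on the NumPy path would be expected to fail it.
np.testing.assert_allclose(points_sketch(), reference, rtol=1e-12)
pin_holds = True
```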
- Run python scripts/jax_likelihood_functions/ellipse/{fit,multipoles}.py from autogalaxy_workspace_test/. Confirm all 8 reference numbers match the prompt-2 baseline.
Key Files
autogalaxy/ellipse/ellipse/ellipse.py — thread xp through five methods.
autogalaxy/ellipse/ellipse/ellipse_multipole.py — thread xp through three methods, rewrite get_shape_angle.
test_autogalaxy/ellipse/test_ellipse.py — one new test (boundary case for get_shape_angle).
Testing Approach
- pytest: must remain green, including prompt-3 reference-array pins at rtol=1e-12.
- Workspace parity: prompt-2 reference numbers byte-stable.
- JAX-path verification: deferred to workspace_test level (prompt 7).
Original Prompt
Step 5 of the ellipse-JAX series. With the 2D interpolator in place from prompt 4, the next blocker is the geometry math in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse.py and @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py. Every routine on Ellipse uses bare np.*, and EllipseMultipole.get_shape_angle uses Python while loops to wrap an angle into [-180/m, 180/m] — both incompatible with jax.jit tracing. Convert these to the xp=np pattern documented in @PyAutoGalaxy/CLAUDE.md "JAX Support" section.
Please:
- Add xp=np as a keyword argument to every method in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse.py that returns a numerical array:
  - Ellipse.angles_from_x0_from
  - Ellipse.ellipse_radii_from_major_axis_from
  - Ellipse.x_from_major_axis_from
  - Ellipse.y_from_major_axis_from
  - Ellipse.points_from_major_axis_from

  Replace bare np.* with xp.* inside the function bodies (xp.linspace, xp.sin, xp.cos, xp.divide, xp.add, xp.sqrt, xp.stack). The total_points_from method stays numpy: its return type is a Python int and it's used to set static shapes outside the JIT trace.

  Special case in points_from_major_axis_from: the idx = np.logical_or(np.isnan(x), np.isnan(y)); if np.sum(idx) > 0: raise NotImplementedError() guard is JAX-incompatible (Python if on a traced value). Replace with if xp is np: around the guard. Under JAX, NaNs propagate through downstream nansum/nanmean and we'd rather see them than crash inside a JIT trace.
- Same treatment for EllipseMultipole.points_perturbed_from and EllipseMultipoleScaled.points_perturbed_from in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py. Add xp=np, swap np.* for xp.*. The multipole_comps_from and multipole_k_m_and_phi_m_from helpers from @PyAutoGalaxy/autogalaxy/convert.py are called outside the math loop on Python tuples; leave those as-is unless they trip JIT (verify by tracing).
- Replace the while loops in EllipseMultipole.get_shape_angle (@PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py:66-69) with arithmetic that JAX can trace. The intent is "wrap angle into the half-open interval (-180/m, 180/m]". A direct replacement using xp.mod works:

      period = 360.0 / self.m
      angle = xp.mod(angle + period / 2.0, period) - period / 2.0

  This produces values in [-period/2, period/2) rather than (-period/2, period/2], which is a tiny boundary-case difference. Verify against the existing tests in @PyAutoGalaxy/test_autogalaxy/ellipse/ and add a test pinning the new behaviour at the boundary (angle = period/2.0) so future changes don't drift unnoticed.
- The existing call sites in FitEllipse and elsewhere don't pass xp; they get the numpy default and behaviour is unchanged. Don't thread xp through the call sites in this prompt; that happens in prompts 6 and 7, where it actually matters.
- Add unit tests in @PyAutoGalaxy/test_autogalaxy/ellipse/test_ellipse.py that, for one fixed Ellipse and one fixed EllipseMultipole, check that the xp=np and xp=jnp paths produce numerically identical points to rtol=1e-6. Gate the JAX side with pytest.importorskip("jax").

  (NOTE: step 5 above conflicts with PyAutoGalaxy/CLAUDE.md's "Never use JAX in unit tests" rule. The issuing session caught this and dropped the JAX-parity tests during issue creation. Only the boundary test for get_shape_angle is added in test_autogalaxy/.)
- Test bar: python -m pytest test_autogalaxy/ellipse/ -v passes.
- The reference numbers from prompt 2's workspace_test scripts are unchanged on the numpy path.