Conversation
There was a problem hiding this comment.
Pull request overview
This PR refactors traced_grid_2d_list_from in autolens/lens/tracer_util.py to adjust how grids/deflections are converted to arrays during multi-plane ray-tracing, and simplifies the per-plane deflection summation.
Changes:
- Switches
scaled_gridinitialization fromgrid.copy()togrid.arrayand scales/subtracts deflections using.array. - Wraps
scaled_gridintoaa.Grid2DIrregularafter applying all prior-plane deflection subtractions. - Simplifies deflection calculation from
map(...)to a generator expression and removes the non-finite sanitization step.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -156,10 +158,14 @@ def traced_grid_2d_list_from( | |||
| ) | |||
|
|
|||
| scaled_deflections = ( | |||
| scaling_factor * traced_deflection_list[previous_plane_index] | |||
| scaling_factor * traced_deflection_list[previous_plane_index].array | |||
| ) | |||
|
|
|||
| scaled_grid -= scaled_deflections | |||
| scaled_grid = scaled_grid - scaled_deflections | |||
|
|
|||
| scaled_grid = aa.Grid2DIrregular( | |||
| values=scaled_grid, | |||
| ) | |||
There was a problem hiding this comment.
scaled_grid is now always wrapped as aa.Grid2DIrregular, even when the input grid is an aa.Grid2D (or other Grid2DLike). This changes the function’s return types and also makes it inconsistent with grid_2d_at_redshift_from, which returns grid.copy() (preserving the original grid type) for redshift <= plane_redshifts[0]. Recommend preserving the input grid type (e.g., keep scaled_grid as grid.copy() for plane 0, or conditionally re-wrap based on the original grid class) to avoid downstream type assumptions and potential test regressions.
| scaled_grid = scaled_grid - scaled_deflections | ||
|
|
||
| scaled_grid = aa.Grid2DIrregular( | ||
| values=scaled_grid, |
There was a problem hiding this comment.
aa.Grid2DIrregular is constructed without xp=xp, even though this utility function is explicitly parameterized by xp for NumPy vs JAX backends. In other JAX-sensitive code paths (e.g. constructing grids in AnalysisLens.tracer_via_instance_from), xp is passed to keep arrays on the correct backend. Consider passing xp=xp here as well to avoid accidental coercion back to NumPy when xp is JAX.
| values=scaled_grid, | |
| values=scaled_grid, | |
| xp=xp, |
| deflections_yx_2d = sum( | ||
| map(lambda g: g.deflections_yx_2d_from(grid=scaled_grid, xp=xp), galaxies) | ||
| ) | ||
|
|
||
| # Remove NaN deflection values to sanitize the ray-tracing calculation for JAX. | ||
| deflections_yx_2d = xp.where( | ||
| xp.isfinite(deflections_yx_2d.array), deflections_yx_2d.array, 0.0 | ||
| (g.deflections_yx_2d_from(grid=scaled_grid, xp=xp) for g in galaxies) | ||
| ) |
There was a problem hiding this comment.
Removing the xp.where(xp.isfinite(...), ..., 0.0) sanitization means any NaN/Inf produced by a galaxy deflection (e.g. at profile centres) will now propagate into traced_deflection_list and subsequent scaled_grid calculations. This can break ray-tracing results and defeats the previous JAX-focused safeguard. Consider reinstating a finite-value sanitization step (or ensuring deflection implementations never return non-finite values) and add a regression test that exercises the non-finite case under the JAX backend.
This pull request refactors the
traced_grid_2d_list_fromfunction inautolens/lens/tracer_util.pyto improve the handling of grid and deflection arrays, ensuring compatibility and correctness in ray-tracing calculations. The main changes focus on converting grids and deflections to arrays at appropriate points, and simplifying the deflection computation loop.Grid and deflection handling improvements:
scaled_gridto use the.arrayattribute instead of.copy(), ensuring the grid is consistently treated as a NumPy array for calculations..arrayfor each deflection, and replaced in-place subtraction with explicit array subtraction for clarity and correctness.scaled_gridin anaa.Grid2DIrregularobject after all scaling and subtraction, standardizing the output type.Deflection computation simplification:
map, and removed the explicit sanitization of NaN values, streamlining the loop and improving readability.