Overview
Reported on Slack by @qiuhan96 (working with an undergraduate at the University of Groningen) on the public pip release of autolens. Setting a subhalo's redshift as a free parameter (af.UniformPrior) raises jax.errors.TracerBoolConversionError during fitting. The only workaround today is use_jax=False, which is much slower. The cause is that autolens/lens/tracer_util.py performs Python-level <, <=, ==, float() on the redshift values, which are traced arrays under jax.jit when the subhalo redshift is free.
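A minimal sketch of this failure pattern, independent of autolens. It assumes the lens and source redshifts are concrete Python floats and the subhalo redshift is the jitted argument. The first three functions mimic the Python-level operations named above, and `jax_friendly` computes the same information with traced operations only. All function names here are illustrative, not autolens code.

```python
import jax
import jax.numpy as jnp

LENS_Z, SOURCE_Z = 0.5, 1.0  # concrete Python floats, as in the reporter's model


def python_sort(subhalo_z):
    # Like `sorted(galaxies, key=lambda g: g.redshift)`: sorting calls bool() on a traced `<`.
    return sorted([LENS_Z, subhalo_z, SOURCE_Z])


def python_float(subhalo_z):
    # Like `[float(g.redshift) for g in ...]`: float() needs a concrete value.
    return float(subhalo_z)


def python_branch(subhalo_z):
    # Like `if redshift <= plane_redshifts[0]:`: a Python `if` on a traced boolean.
    return 0 if subhalo_z <= LENS_Z else 1


def jax_friendly(subhalo_z):
    # The same information, computed without ever leaving the traced world.
    redshifts = jnp.sort(jnp.array([LENS_Z, subhalo_z, SOURCE_Z]))
    insert_index = jnp.sum(jnp.less(jnp.array([LENS_Z, SOURCE_Z]), subhalo_z))
    return redshifts, insert_index


def jit_error_name(fn, subhalo_z=0.7):
    """Run `fn` under jax.jit and return the concretization error's class name, or None."""
    try:
        jax.jit(fn)(subhalo_z)
    except jax.errors.ConcretizationTypeError as err:  # TracerBoolConversionError is a subclass
        return type(err).__name__
    return None


for fn in (python_sort, python_float, python_branch, jax_friendly):
    print(fn.__name__, jit_error_name(fn))
```

The first three functions report an error class. The exact class name depends on the operation and the JAX version. `jax_friendly` reports `None`.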
Plan
- Add a unit test that fits a model with a free-parameter subhalo redshift under JAX and asserts that no TracerBoolConversionError is raised.
- Reformulate grid_2d_at_redshift_from so the redshift comparisons used to choose where to insert the subhalo plane are JAX-friendly (e.g. compute candidate traced grids and select with jnp.where / jax.lax.switch, rather than branching on Python booleans).
- Stop calling float() and pairwise < on potentially-traced redshifts in plane_redshifts_from / planes_from: sort by the concrete redshifts of the non-subhalo galaxies (which are Python floats) and carry the subhalo redshift through as a traced scalar.
- Stop forcing the traced subhalo centre back through tuple(...) in AnalysisLens.tracer_via_instance_from (autolens/analysis/analysis/lens.py:116).
- Cover the edge cases in tests: subhalo before the lens plane, between lens and source, equal to the lens redshift, equal to the source redshift.
- Verify that the full Nautilus/Dynesty fit from the reporter's script runs end-to-end with use_jax=True.
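The selection proposed in the second and third items can be sketched on a toy ray-tracer. The plane redshifts are concrete and sorted in plain Python. Each candidate insertion position gets its own branch, and jax.lax.switch picks one using an index built with jnp.less. This is an illustration only. The constant per-plane offsets in `PLANE_DEFLECTIONS` are a hypothetical stand-in for the real deflection angles and multi-plane scaling factors, and `grid_at_redshift` is not the autolens function.

```python
import jax
import jax.numpy as jnp

# Concrete plane redshifts (lens, source), sorted with plain Python.
PLANE_REDSHIFTS = jnp.asarray(sorted([1.0, 0.5]))
# Hypothetical constant (y, x) deflection of each plane; real planes deflect by an
# amount that depends on the grid position and on the multi-plane scaling factors.
PLANE_DEFLECTIONS = jnp.array([[0.1, -0.2], [0.05, 0.05]])


def trace_through(grid, n_planes):
    """Toy ray-tracing: deflect `grid` by the first `n_planes` planes."""
    for i in range(n_planes):  # n_planes is a Python int, so this loop unrolls at trace time
        grid = grid - PLANE_DEFLECTIONS[i]
    return grid


# One branch per candidate insertion position: before plane 0, between planes, after the last.
BRANCHES = [
    (lambda grid, k=k: trace_through(grid, k))  # k=k binds the current loop value
    for k in range(PLANE_REDSHIFTS.shape[0] + 1)
]


@jax.jit
def grid_at_redshift(grid, redshift):
    """Grid at a possibly traced `redshift`, with no Python branching on it."""
    # Number of planes strictly in front of `redshift`: the index that the existing loop
    # `if redshift > plane_redshift: plane_index_insert = plane_index + 1` would produce.
    index = jnp.sum(jnp.less(PLANE_REDSHIFTS, redshift))
    return jax.lax.switch(index, BRANCHES, grid)
```

For example, `grid_at_redshift(jnp.array([[1.0, 1.0]]), 0.7)` traces the grid through the lens plane only. A redshift at or below 0.5 returns the grid undeflected, which matches the `<=` early return. jax.lax.switch compiles every branch but runs only the selected one. Under jax.vmap it falls back to evaluating all of them. The stacked jnp.where variant always evaluates every candidate. Because the result stays an array, it can feed the subhalo centre without a tuple(...) round-trip.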
Detailed implementation plan
Affected Repositories
Jammy2211/PyAutoLens (primary)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoLens | main | clean |
Suggested branch: feature/subhalo-redshift-jax-fix
Implementation Steps
- Reproduce. Add a regression test in test_autolens/lens/test_tracer_util.py (or a new test_autolens/analysis/test_analysis_lens_jax.py) that builds a 3-galaxy model (lens at z=0.5, subhalo with af.UniformPrior(0.2, 0.9), source at z=1.0) and calls jax.jit(analysis.fit_from)(instance), asserting that it returns without raising.
- Refactor autolens/lens/tracer_util.py:plane_redshifts_from. The list comprehension [float(galaxy.redshift) for galaxy in galaxies_ascending_redshift] (line 49) and the sorted(galaxies, key=lambda g: g.redshift) call (line 46) both fail on traced redshifts. Either:
  - split the inputs into "concrete-redshift galaxies" (sorted with Python sorted) and "traced-redshift galaxies" (the subhalo), and recombine them without calling float(); or
  - accept that this helper only ever sees Python-float redshifts, and move the subhalo handling out before the call.
- Refactor autolens/lens/tracer_util.py:grid_2d_at_redshift_from (line 199). The three branches that fail under JAX are:
  - Line 249: if redshift <= plane_redshifts[0]: return grid.copy()
  - Line 257: [... if galaxies[0].redshift == redshift]
  - Lines 267-268: for plane_index, plane_redshift in enumerate(plane_redshifts): if redshift > plane_redshift: plane_index_insert = plane_index + 1
  Replace these with a structured selection. Compute the traced grid for every candidate insertion position (before plane 0, between each pair of planes, after the last plane). Then pick the right one via jax.lax.switch (or jnp.where over a stacked result), using a comparison vector built with jnp.less / jnp.less_equal. The numpy path keeps the existing implementation behind if xp is np:.
- Update autolens/analysis/analysis/lens.py:99-116:AnalysisLens.tracer_via_instance_from. The line instance.galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0]) forces a Python tuple(...) of traced scalars. Either keep the centre as a traced 2-vector, or skip the round-trip and pass the traced centre directly into the downstream Tracer build.
- Tests. Cover four scenarios in test_autolens/lens/test_tracer_util.py:
  - subhalo redshift < lens.redshift
  - subhalo redshift == lens.redshift
  - lens.redshift < subhalo redshift < source.redshift (the typical case)
  - subhalo redshift == source.redshift
  Run each under both xp=np and xp=jnp (the latter inside jax.jit).
- End-to-end check. Run the reporter's reproduction script (the model in the Reproduction section of the original prompt, plus a Nautilus/Dynesty search with a tiny nlive) with use_jax=True and confirm that the fit completes.
Key Files
autolens/lens/tracer_util.py: plane_redshifts_from, planes_from, grid_2d_at_redshift_from.
autolens/analysis/analysis/lens.py: AnalysisLens.tracer_via_instance_from (subhalo branch at lines 99-116).
test_autolens/lens/test_tracer_util.py: new JAX regression tests.
Out of scope
- Changes to autogalaxy or autoarray.
- Anything related to the instance.perturb shortcut at lens.py:99; that path is unaffected by the bug (the subhalo redshift is taken from the model, not perturbed).
Original Prompt
Free-parameter subhalo redshift breaks under JAX (TracerBoolConversionError)
Reporter
Reported on Slack by @qiuhan96 (with an undergraduate at the University of Groningen). Running the public pip release of autolens on Python 3.13.
Symptom
Setting the subhalo's redshift as a free parameter (an af.UniformPrior) raises a jax.errors.TracerBoolConversionError during model fitting. The fit only runs if use_jax=False is set, which is much slower.
Reproduction
```python
import autofit as af
import autolens as al

mask_radius = 3.0  # hypothetical value; the reporter's script defines its own mask radius

# Lens
bulge = al.model_util.mge_model_from(
    mask_radius=mask_radius,
    total_gaussians=20,
    gaussian_per_basis=2,
    centre_prior_is_uniform=True,
)
mass = af.Model(al.mp.Isothermal)
shear = af.Model(al.mp.ExternalShear)
lens = af.Model(al.Galaxy, redshift=0.5, bulge=bulge, mass=mass, shear=shear)

# Subhalo
subhalo_mass = af.Model(al.mp.IsothermalSph)
subhalo_mass.centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1)
subhalo_mass.centre_1 = af.UniformPrior(lower_limit=1.2, upper_limit=1.8)
subhalo_mass.einstein_radius = af.UniformPrior(lower_limit=0.01, upper_limit=0.4)

# Trigger: free-parameter redshift
redshift_subhalo = af.UniformPrior(lower_limit=0.2, upper_limit=0.9)
# redshift_subhalo = 0.6  # <-- works fine
subhalo_galaxy = af.Model(al.Galaxy, redshift=redshift_subhalo, mass=subhalo_mass)

# Source
bulge = al.model_util.mge_model_from(
    mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False
)
source = af.Model(al.Galaxy, redshift=1.0, bulge=bulge)

model = af.Collection(galaxies=af.Collection(lens=lens, subhalo=subhalo_galaxy, source=source))
```
When redshift_subhalo is a UniformPrior, the fit raises:
```
File autolens/analysis/analysis/lens.py:99, in AnalysisLens.tracer_via_instance_from
    subhalo_centre = tracer_util.grid_2d_at_redshift_from(
        galaxies=instance.galaxies,
        redshift=instance.galaxies.subhalo.redshift,
        ...
    )
File autolens/lens/tracer_util.py:247, in grid_2d_at_redshift_from
    plane_redshifts = plane_redshifts_from(galaxies=galaxies)
File autolens/lens/tracer_util.py:46, in plane_redshifts_from
    galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift)
...
jax._src.core.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
```
Workaround: setting use_jax=False lets the fit run, but is much slower.
Root cause
autolens/lens/tracer_util.py performs several Python-level operations on the redshift values that fail when one of them is a JAX traced scalar:
- Line 46: sorted(galaxies, key=lambda g: g.redshift) does pairwise < comparisons on Python objects holding traced redshifts.
- Line 49: [float(g.redshift) for g in ...] calls float() on a traced scalar.
- Line 249: if redshift <= plane_redshifts[0]: is a Python branch on a traced boolean.
- Line 257: [plane_index for ... if galaxies[0].redshift == redshift] filters on a traced boolean.
- Lines 267-268: for ...: if redshift > plane_redshift: plane_index_insert = plane_index + 1 again branches on a traced boolean, and uses a Python integer to index the inserted plane.
grid_2d_at_redshift_from is called from autolens/analysis/analysis/lens.py:99 whenever instance.galaxies.subhalo exists, with redshift=instance.galaxies.subhalo.redshift. When that redshift is free, the value passed in is a traced array and every comparison above is illegal under jax.jit.
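One way to remove the two plane_redshifts_from failures (lines 46 and 49) follows the first option in the implementation plan. Sort only the redshifts that are concrete Python numbers, leave the subhalo redshift traced, and compute its position with jnp.searchsorted. `split_redshifts` and `plane_layout` are hypothetical helpers that work on bare numbers rather than galaxies. The sketch assumes exactly one non-concrete redshift.

```python
import jax
import jax.numpy as jnp


def split_redshifts(redshifts):
    """Split into (sorted concrete Python numbers, everything else such as tracers)."""
    concrete = sorted(z for z in redshifts if isinstance(z, (int, float)))
    other = [z for z in redshifts if not isinstance(z, (int, float))]
    return concrete, other


@jax.jit
def plane_layout(subhalo_z):
    """Concrete plane redshifts plus the (traced) insertion index of the subhalo."""
    concrete, (z_sub,) = split_redshifts([0.5, subhalo_z, 1.0])  # lens, subhalo, source
    plane_redshifts = jnp.asarray(concrete)
    # side="left" counts planes strictly below z_sub, matching `redshift > plane_redshift`.
    index = jnp.searchsorted(plane_redshifts, z_sub, side="left")
    return plane_redshifts, index
```

No float() or Python comparison ever touches `z_sub`. The concrete redshifts never become traced, so they can still drive Python-level plane construction.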