
fix: subhalo redshift as free parameter raises TracerBoolConversionError under JAX #498

@Jammy2211


Overview

Reported on Slack by @qiuhan96 (working with an undergraduate at the University of Groningen) against the public pip release of autolens. Making a subhalo's redshift a free parameter (af.UniformPrior) raises jax.errors.TracerBoolConversionError during fitting. The only current workaround is use_jax=False, which is much slower. The cause is that autolens/lens/tracer_util.py applies Python-level <, <=, == and float() to the redshift values. When the subhalo redshift is free, its value is a traced array under jax.jit, and these operations need concrete values.
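A minimal illustration of the failure mode, independent of autolens: a Python `if` on a comparison that involves a jitted argument raises `TracerBoolConversionError`, while the same decision written with `jnp.where` stays traced. The function names and the 0.5 threshold are hypothetical.

```python
import jax
import jax.numpy as jnp


def plane_choice_python(redshift, first_plane_redshift=0.5):
    # Python branch: under jit `redshift <= ...` is a traced bool, and `if`
    # needs a concrete True/False, so tracing fails here.
    if redshift <= first_plane_redshift:
        return 0
    return 1


def plane_choice_jax(redshift, first_plane_redshift=0.5):
    # Same decision as an array operation: the result stays a traced value.
    return jnp.where(redshift <= first_plane_redshift, 0, 1)


try:
    jax.jit(plane_choice_python)(0.6)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

print(raised)                               # True
print(int(jax.jit(plane_choice_jax)(0.6)))  # 1
print(int(jax.jit(plane_choice_jax)(0.4)))  # 0
```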

Plan

  • Add a unit test that fits a model with a free-parameter subhalo redshift under JAX and asserts no TracerBoolConversionError.
  • Reformulate grid_2d_at_redshift_from so the redshift comparisons used to choose where to insert the subhalo plane are JAX-friendly (e.g. compute candidate traced grids and select with jnp.where / jax.lax.switch, rather than branching on Python booleans).
  • Stop calling float() and pairwise < on potentially-traced redshifts in plane_redshifts_from / planes_from — sort by the concrete redshifts of the non-subhalo galaxies (which are Python floats) and carry the subhalo redshift through as a traced scalar.
  • Stop forcing the traced subhalo centre back through tuple(...) in AnalysisLens.tracer_via_instance_from (autolens/analysis/analysis/lens.py:116).
  • Cover the edge cases in tests: subhalo before lens plane, between lens and source, equal to lens redshift, equal to source redshift.
  • Verify the full Nautilus/Dynesty fit from the reporter's script runs end-to-end with use_jax=True.
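The four edge cases can be pinned down with a table-driven check that runs under both NumPy and `jax.jit`. This is a sketch: `insert_index` is a hypothetical helper, not autolens API. It counts the planes strictly in front of the subhalo, which is our reading of how grid_2d_at_redshift_from chooses where to insert the subhalo plane (lens at z=0.5, source at z=1.0).

```python
import jax
import jax.numpy as jnp
import numpy as np

LENS_Z, SOURCE_Z = 0.5, 1.0

# (subhalo redshift, expected number of planes in front of it)
CASES = [
    (0.2, 0),  # subhalo redshift < lens redshift
    (0.5, 0),  # subhalo redshift == lens redshift
    (0.7, 1),  # lens < subhalo < source (the typical case)
    (1.0, 1),  # subhalo redshift == source redshift
]


def insert_index(plane_redshifts, redshift, xp):
    """Hypothetical helper: planes strictly in front of `redshift`, for xp=np or jnp."""
    return xp.sum(xp.less(xp.asarray(plane_redshifts), redshift))


def test_insert_index_numpy_and_jit():
    planes = [LENS_Z, SOURCE_Z]
    # Under jit, `z` is traced: only array ops touch it, no Python comparisons.
    jitted = jax.jit(lambda z: insert_index(planes, z, jnp))
    for subhalo_z, expected in CASES:
        assert int(insert_index(planes, subhalo_z, np)) == expected
        assert int(jitted(subhalo_z)) == expected
```

Under pytest, the loop could become `@pytest.mark.parametrize` over `CASES`; the real tests would compare traced grids rather than indices.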
Detailed implementation plan

Affected Repositories

  • Jammy2211/PyAutoLens (primary)

Branch Survey

Repository | Current Branch | Dirty?
--- | --- | ---
./PyAutoLens | main | clean

Suggested branch: feature/subhalo-redshift-jax-fix

Implementation Steps

  1. Reproduce. Add a regression test in test_autolens/lens/test_tracer_util.py (or a new test_autolens/analysis/test_analysis_lens_jax.py) that builds a 3-galaxy model (lens at z=0.5, subhalo with af.UniformPrior(0.2, 0.9), source at z=1.0) and calls jax.jit(analysis.fit_from)(instance), asserting it returns without raising.

  2. Refactor autolens/lens/tracer_util.py:plane_redshifts_from. The list comprehension [float(galaxy.redshift) for galaxy in galaxies_ascending_redshift] (line 49) and the sorted(galaxies, key=lambda g: g.redshift) call (line 46) both fail on traced redshifts. Either:

    • Split inputs into "concrete-redshift galaxies" (sorted with Python sorted) and "traced-redshift galaxies" (the subhalo), and recombine without calling float(); or
    • Accept that this helper only ever sees Python-float redshifts and move the subhalo handling out before the call.
  3. Refactor autolens/lens/tracer_util.py:grid_2d_at_redshift_from (line 199). The three branches that fail under JAX are:

    • Line 249: if redshift <= plane_redshifts[0]: return grid.copy()
    • Line 257: [... if galaxies[0].redshift == redshift]
    • Lines 267-268: for plane_index, plane_redshift in enumerate(plane_redshifts): if redshift > plane_redshift: plane_index_insert = plane_index + 1

    Replace with a structured selection: compute the traced grid for every candidate insertion position (before plane 0, between each pair, after the last plane), then pick the right one via jax.lax.switch (or jnp.where over a stacked result) using a comparison vector built with jnp.less / jnp.less_equal. The numpy path keeps the existing implementation behind if xp is np:.

  4. Update autolens/analysis/analysis/lens.py:99-116:AnalysisLens.tracer_via_instance_from. The line instance.galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0]) forces a Python tuple(...) of traced scalars. Either keep the centre as a traced 2-vector or skip the round-trip and pass the traced centre directly into the downstream Tracer build.

  5. Tests. Cover four scenarios in test_autolens/lens/test_tracer_util.py:

    • subhalo redshift < lens.redshift
    • subhalo redshift == lens.redshift
    • lens.redshift < subhalo redshift < source.redshift (the typical case)
    • subhalo redshift == source.redshift

    Run each under both xp=np and xp=jnp (the latter inside jax.jit).

  6. End-to-end check. Run the reporter's reproduction script (the model defined above plus a Nautilus/Dynesty search with a tiny nlive) with use_jax=True and confirm the fit completes.
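Step 3's structured selection can be sketched on a toy model, assuming the non-subhalo plane redshifts are concrete and ascending. Each plane here just subtracts a hypothetical deflection `strength * grid` from the grid it receives. That stands in for the real multi-plane ray-tracing (which uses cosmological scaling factors) and is not autolens code. The selection index counts the planes strictly in front of `redshift`, which is how we read the three branches listed in step 3: an unchanged grid for `redshift <= plane_redshifts[0]`, the grid arriving at a plane when `redshift` equals its redshift, and otherwise insertion after the last plane in front.

```python
import jax
import jax.numpy as jnp


def toy_deflections(grid, strength):
    # Hypothetical stand-in for a plane's deflection angles.
    return strength * grid


def grid_at_redshift_sketch(grid, redshift, plane_redshifts, plane_strengths):
    """Trace `grid` to `redshift` without branching on `redshift` in Python.

    plane_redshifts / plane_strengths: concrete, ascending, same length.
    redshift: may be a traced scalar under jax.jit.
    """
    candidates = [grid]  # candidate 0: in front of every plane
    traced = grid
    for strength in plane_strengths:
        traced = traced - toy_deflections(traced, strength)
        candidates.append(traced)  # candidate k: after the first k planes
    stacked = jnp.stack(candidates)  # shape (n_planes + 1, n_points, 2)

    # Number of planes strictly in front of `redshift`, as a traced integer.
    index = jnp.sum(jnp.less(jnp.asarray(plane_redshifts), redshift))
    return stacked[index]


grid = jnp.ones((3, 2))
trace = jax.jit(
    lambda z: grid_at_redshift_sketch(grid, z, [0.5, 1.0], [0.1, 0.2])
)
for z in (0.3, 0.5, 0.7, 1.0, 1.5):
    print(z, trace(z)[0])
```

Because candidate k+1 is built from candidate k, stacking costs a single pass through the planes. `jax.lax.switch(index, branches)`, the alternative named in step 3, can avoid materialising every candidate but needs one branch function per insertion position.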

Key Files

  • autolens/lens/tracer_util.py: plane_redshifts_from, planes_from, grid_2d_at_redshift_from.
  • autolens/analysis/analysis/lens.py: AnalysisLens.tracer_via_instance_from (subhalo branch at lines 99-116).
  • test_autolens/lens/test_tracer_util.py: new JAX regression tests.

Out of scope

  • Changes to autogalaxy or autoarray.
  • Anything related to the instance.perturb shortcut at lens.py:99 — that path is unaffected by the bug (subhalo redshift is taken from the model, not perturbed).

Original Prompt


Free-parameter subhalo redshift breaks under JAX (TracerBoolConversionError)

Reporter

Reported on Slack by @qiuhan96 (with an undergraduate at the University of Groningen), running the public pip release of autolens on Python 3.13.

Symptom

Setting the subhalo's redshift as a free parameter (an af.UniformPrior) raises a jax.errors.TracerBoolConversionError during model fitting. The fit only runs with use_jax=False, which is much slower.

Reproduction

```python
import autofit as af
import autolens as al

# mask_radius is defined earlier in the reporter's script (not shown here).

bulge = al.model_util.mge_model_from(
    mask_radius=mask_radius,
    total_gaussians=20,
    gaussian_per_basis=2,
    centre_prior_is_uniform=True,
)

mass = af.Model(al.mp.Isothermal)
shear = af.Model(al.mp.ExternalShear)

lens = af.Model(al.Galaxy, redshift=0.5, bulge=bulge, mass=mass, shear=shear)

# Subhalo
subhalo_mass = af.Model(al.mp.IsothermalSph)
subhalo_mass.centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1)
subhalo_mass.centre_1 = af.UniformPrior(lower_limit=1.2, upper_limit=1.8)
subhalo_mass.einstein_radius = af.UniformPrior(lower_limit=0.01, upper_limit=0.4)

# Trigger: free-parameter redshift
redshift_subhalo = af.UniformPrior(lower_limit=0.2, upper_limit=0.9)
# redshift_subhalo = 0.6   # <-- works fine

subhalo_galaxy = af.Model(al.Galaxy, redshift=redshift_subhalo, mass=subhalo_mass)

# Source
bulge = al.model_util.mge_model_from(
    mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False
)
source = af.Model(al.Galaxy, redshift=1.0, bulge=bulge)

model = af.Collection(galaxies=af.Collection(lens=lens, subhalo=subhalo_galaxy, source=source))
```

When redshift_subhalo is a UniformPrior, the fit raises:

```
File autolens/analysis/analysis/lens.py:99, in AnalysisLens.tracer_via_instance_from
    subhalo_centre = tracer_util.grid_2d_at_redshift_from(
        galaxies=instance.galaxies,
        redshift=instance.galaxies.subhalo.redshift,
        ...
    )
File autolens/lens/tracer_util.py:247, in grid_2d_at_redshift_from
    plane_redshifts = plane_redshifts_from(galaxies=galaxies)
File autolens/lens/tracer_util.py:46, in plane_redshifts_from
    galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift)
...
jax._src.core.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
```

Workaround: setting use_jax=False lets the fit run, but is much slower.

Root cause

autolens/lens/tracer_util.py performs several Python-level operations on the redshift values that fail when one of them is a JAX traced scalar:

  • Line 46: sorted(galaxies, key=lambda g: g.redshift) compares the redshift keys pairwise with <; when one key is traced, sorted needs a Python bool from a traced comparison.
  • Line 49: [float(g.redshift) for g in ...] calls float() on a traced scalar.
  • Line 249: if redshift <= plane_redshifts[0]: is a Python branch on a traced boolean.
  • Line 257: [plane_index for ... if galaxies[0].redshift == redshift] filters on a traced boolean.
  • Lines 267-268: for ...: if redshift > plane_redshift: plane_index_insert = plane_index + 1 again branches on a traced boolean and uses a Python integer to index the inserted plane.

grid_2d_at_redshift_from is called from autolens/analysis/analysis/lens.py:99 whenever instance.galaxies.subhalo exists, with redshift=instance.galaxies.subhalo.redshift. When that redshift is free, the value passed in is a traced array and every comparison above is illegal under jax.jit.
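For the sorted and float() failures at lines 46 and 49, the first option in step 2 of the plan (split concrete from traced redshifts) could look like the sketch below. It assumes the helper receives bare redshift values rather than galaxies; `split_redshifts` and `plane_redshifts_sketch` are hypothetical names. Concrete numbers are sorted with Python `sorted`; anything else is treated as possibly traced and merged back with `jnp.sort`, so no `float()` or Python `<` ever touches a tracer.

```python
import numpy as np
import jax
import jax.numpy as jnp


def split_redshifts(redshifts):
    """Hypothetical helper: separate concrete redshifts from possibly-traced ones."""
    concrete, traced = [], []
    for z in redshifts:
        if isinstance(z, (int, float, np.floating)):
            concrete.append(float(z))  # safe: z is a plain number, not a tracer
        else:
            traced.append(z)  # never float() or compare this in Python
    return sorted(concrete), traced


def plane_redshifts_sketch(redshifts):
    """Ascending plane redshifts, built without Python comparisons on tracers."""
    concrete, traced = split_redshifts(redshifts)
    if not traced:
        return jnp.asarray(concrete)
    merged = jnp.concatenate(
        [jnp.asarray(concrete), jnp.stack([jnp.asarray(z) for z in traced])]
    )
    return jnp.sort(merged)  # jnp.sort works on traced values under jit


# Lens at 0.5, source at 1.0, subhalo redshift traced by jit.
planes = jax.jit(lambda z_sub: plane_redshifts_sketch([1.0, z_sub, 0.5]))(0.7)
print(planes)
```

One caveat: when the subhalo redshift equals an existing plane's, the merged array holds that redshift twice. A fixed-shape array cannot collapse the duplicate into one plane under jit, so downstream code would have to tolerate a repeated plane redshift; that is one reason the second option in step 2 may be simpler.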
