Conversation
There was a problem hiding this comment.
Pull request overview
Refactors kappa_s_and_scale_radius for the gNFW virial-mass–concentration parameterization to improve backend-agnostic execution (NumPy vs JAX) and replace numerical quadrature with an analytic hypergeometric normalization.
Changes:
- Adds JAX detection and backend dispatch (
xp) to support traced/JIT execution. - Replaces
scipy.integrate.quadnormalization with a closed-formhyp2f1expression. - Makes overdensity selection differentiable via backend array ops (
xp.where) and expands documentation.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp. | ||
|
|
||
| - NumPy: scipy.special.hyp2f1 | ||
| - JAX (if available): jax.scipy.special.hyp2f1 | ||
| - JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case) | ||
| """ | ||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| # Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z) | ||
| # We implement general 2F1 series: | ||
| # 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n! | ||
| # | ||
| # Recurrence for terms: | ||
| # t_0 = 1 | ||
| # t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z | ||
| # | ||
| # This is JIT-safe with static max_terms. | ||
| def hyp2f1_series(a, b, c, z): | ||
| a = jnp.asarray(a) | ||
| b = jnp.asarray(b) | ||
| c = jnp.asarray(c) | ||
| z = jnp.asarray(z) | ||
|
|
||
| def body_fun(n, carry): | ||
| t, s = carry | ||
| n_f = jnp.asarray(n, dtype=t.dtype) | ||
| t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z | ||
| s = s + t | ||
| return (t, s) | ||
|
|
||
| # Start: t0 = 1, s0 = 1 | ||
| t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z)) | ||
| s0 = t0 | ||
|
|
||
| # fori_loop has static iteration count => good under jit/vmap | ||
| tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0)) | ||
| return sN | ||
|
|
||
| return hyp2f1_series | ||
|
|
There was a problem hiding this comment.
The JAX fallback hyp2f1_series uses the Maclaurin series in z, which only converges for |z| < 1 (and conditionally at |z| = 1). In this file the call site uses z = -c where c is the halo concentration (typically ~5–20), so the fallback will diverge / return incorrect results for the default parameters. Consider removing the fallback (raise a clear error requiring jax.scipy.special.hyp2f1) or replacing it with a numerically stable method for z < -1 (e.g., a quadrature with fixed nodes that is JIT-safe, or an analytic continuation tailored to 2F1(a,a;a+1;−c)).
| Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp. | |
| - NumPy: scipy.special.hyp2f1 | |
| - JAX (if available): jax.scipy.special.hyp2f1 | |
| - JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case) | |
| """ | |
| import jax | |
| import jax.numpy as jnp | |
| # Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z) | |
| # We implement general 2F1 series: | |
| # 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n! | |
| # | |
| # Recurrence for terms: | |
| # t_0 = 1 | |
| # t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z | |
| # | |
| # This is JIT-safe with static max_terms. | |
| def hyp2f1_series(a, b, c, z): | |
| a = jnp.asarray(a) | |
| b = jnp.asarray(b) | |
| c = jnp.asarray(c) | |
| z = jnp.asarray(z) | |
| def body_fun(n, carry): | |
| t, s = carry | |
| n_f = jnp.asarray(n, dtype=t.dtype) | |
| t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z | |
| s = s + t | |
| return (t, s) | |
| # Start: t0 = 1, s0 = 1 | |
| t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z)) | |
| s0 = t0 | |
| # fori_loop has static iteration count => good under jit/vmap | |
| tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0)) | |
| return sN | |
| return hyp2f1_series | |
| Returns a callable hyp2f1(a,b,c,z) compatible with the JAX backend. | |
| This helper requires `jax.scipy.special.hyp2f1` to be available. No | |
| series-based fallback is used, because the gNFW use-case calls 2F1 | |
| with |z| > 1 (e.g. z = -c with c ~ 5–20), where the Maclaurin series | |
| diverges and would yield incorrect results. | |
| """ | |
| try: | |
| import jax # noqa: F401 # kept for potential JAX-type detection elsewhere | |
| import jax.numpy as jnp # noqa: F401 # retained to avoid changing imports broadly | |
| from jax.scipy.special import hyp2f1 as jax_hyp2f1 | |
| except Exception as exc: | |
| raise RuntimeError( | |
| "The gNFW virial-mass–concentration profile requires " | |
| "`jax.scipy.special.hyp2f1` when using the JAX backend. " | |
| "Please install a version of JAX/jaxlib that provides " | |
| "`jax.scipy.special.hyp2f1`, or disable the JAX path." | |
| ) from exc | |
| # Ignore `xp` and `max_terms` here; they are kept for API compatibility. | |
| return jax_hyp2f1 | |
| Parameters | ||
| ---------- | ||
| cosmology | ||
| Cosmology object providing critical density, angular diameter distance | ||
| conversions, and surface mass density calculations. Must support an `xp` | ||
| argument for NumPy/JAX interoperability. | ||
| virial_mass | ||
| Virial mass of the halo in units of solar masses. | ||
| c_2 | ||
| Concentration-like parameter, converted internally to the gNFW | ||
| concentration via `(2 - inner_slope) * c_2`. | ||
| overdens | ||
| Overdensity with respect to the critical density. If zero, the | ||
| Bryan & Norman (1998) redshift-dependent overdensity is used. | ||
| redshift_object | ||
| Redshift of the lens (halo). | ||
| redshift_source | ||
| Redshift of the background source. | ||
| inner_slope | ||
| Inner logarithmic density slope γ of the gNFW profile. | ||
| xp | ||
| Array backend module (`numpy` or `jax.numpy`). All array operations | ||
| are dispatched through this module to ensure compatibility with | ||
| both standard NumPy execution and JAX tracing / JIT compilation. |
There was a problem hiding this comment.
The docstring documents an xp parameter and states JIT-compatibility when xp=jax.numpy, but kappa_s_and_scale_radius does not accept xp and instead infers the backend from virial_mass. Update the docstring to match the actual API, or add an explicit xp argument (and propagate it through) so callers can control the backend deterministically.
| is_jax_bool = is_jax(virial_mass) | ||
|
|
||
| if not is_jax_bool: | ||
| xp = np | ||
| else: | ||
| from jax import numpy as jnp | ||
| xp = jnp | ||
|
|
There was a problem hiding this comment.
Backend selection is based only on is_jax(virial_mass). If a caller passes virial_mass as a Python/NumPy scalar but supplies other inputs as JAX arrays/tracers (e.g. c_2, overdens, or redshifts), this will select NumPy and then attempt NumPy ops on traced values, breaking JIT/grad. Consider either adding an explicit xp parameter, or inferring the backend from all numeric inputs (or at least from c_2 / overdens as well).
| # Bryan & Norman (1998) overdensity if overdens == 0 | ||
| x = cosmology.Om(redshift_object, xp=xp) - 1.0 | ||
| overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2 | ||
| overdens = xp.where(overdens == 0, overdens_bn98, overdens) |
There was a problem hiding this comment.
Using xp.where here changes the NumPy-path behavior: when overdens is a scalar, np.where(...) returns a 0-d ndarray, which then propagates so the function returns arrays instead of plain scalars under NumPy. If downstream code expects Python / NumPy scalars (e.g. for JSON serialization), consider keeping the original Python if overdens == 0: branch in the NumPy path and reserving xp.where for the JAX path (or explicitly converting 0-d arrays back to scalars in the NumPy backend).
| overdens = xp.where(overdens == 0, overdens_bn98, overdens) | |
| # Use a Python scalar branch for NumPy scalars to avoid returning 0-d arrays; | |
| # keep xp.where for JAX and array inputs. | |
| if is_jax(overdens) or not np.isscalar(overdens): | |
| overdens = xp.where(overdens == 0, overdens_bn98, overdens) | |
| else: | |
| if overdens == 0: | |
| overdens = overdens_bn98 |
| from jax import Array | ||
| from jax.core import Tracer | ||
| return isinstance(x, (Array, Tracer)) | ||
| except Exception: |
There was a problem hiding this comment.
is_jax catches Exception broadly, which can mask real import/runtime errors inside JAX and silently fall back to NumPy. It would be safer to catch only ImportError (and potentially AttributeError for older JAX types) so unexpected failures don’t change numerical backends without warning.
| except Exception: | |
| except (ImportError, ModuleNotFoundError, AttributeError): |
This pull request significantly refactors the
kappa_s_and_scale_radiusfunction ingnfw_virial_mass_conc.pyto improve backend compatibility, performance, and analytic correctness. The update introduces JAX support (including JIT and vmap compatibility), replaces numerical quadrature with an analytic normalization using the hypergeometric function, and ensures all array operations are backend-agnostic. The function now also handles overdensity calculation in a fully differentiable way and is documented in detail.Backend compatibility and JAX support:
is_jaxutility to detect JAX arrays and ensure correct backend selection.Analytic improvements and performance:
Documentation and maintainability: