Skip to content

feature/jaxify_gnfw_conc#286

Merged
Jammy2211 merged 2 commits intomainfrom
feature/jaxify_gnfw_conc
Mar 2, 2026
Merged

feature/jaxify_gnfw_conc#286
Jammy2211 merged 2 commits intomainfrom
feature/jaxify_gnfw_conc

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

This pull request significantly refactors the kappa_s_and_scale_radius function in gnfw_virial_mass_conc.py to improve backend compatibility, performance, and analytic correctness. The update introduces JAX support (including JIT and vmap compatibility), replaces numerical quadrature with an analytic normalization using the hypergeometric function, and ensures all array operations are backend-agnostic. The function now also handles overdensity calculation in a fully differentiable way and is documented in detail.

Backend compatibility and JAX support:

  • Refactored the function to work seamlessly with both NumPy and JAX backends, including dynamic dispatch of array operations and special functions, and added a fallback series implementation for the hypergeometric function if JAX's native version is unavailable.
  • Added the is_jax utility to detect JAX arrays and ensure correct backend selection.

Analytic improvements and performance:

  • Replaced the numerical quadrature for the gNFW normalization with a closed-form analytic expression using the hypergeometric function, improving both speed and differentiability.
  • Overdensity is now calculated using backend array operations, supporting differentiability and JIT compilation.

Documentation and maintainability:

  • Added comprehensive docstrings explaining the function’s parameters, return values,

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Refactors kappa_s_and_scale_radius for the gNFW virial-mass–concentration parameterization to improve backend-agnostic execution (NumPy vs JAX) and replace numerical quadrature with an analytic hypergeometric normalization.

Changes:

  • Adds JAX detection and backend dispatch (xp) to support traced/JIT execution.
  • Replaces scipy.integrate.quad normalization with a closed-form hyp2f1 expression.
  • Makes overdensity selection differentiable via backend array ops (xp.where) and expands documentation.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 19 to 59
Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp.

- NumPy: scipy.special.hyp2f1
- JAX (if available): jax.scipy.special.hyp2f1
- JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case)
"""
import jax
import jax.numpy as jnp

# Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z)
# We implement general 2F1 series:
# 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n!
#
# Recurrence for terms:
# t_0 = 1
# t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z
#
# This is JIT-safe with static max_terms.
def hyp2f1_series(a, b, c, z):
a = jnp.asarray(a)
b = jnp.asarray(b)
c = jnp.asarray(c)
z = jnp.asarray(z)

def body_fun(n, carry):
t, s = carry
n_f = jnp.asarray(n, dtype=t.dtype)
t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z
s = s + t
return (t, s)

# Start: t0 = 1, s0 = 1
t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z))
s0 = t0

# fori_loop has static iteration count => good under jit/vmap
tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0))
return sN

return hyp2f1_series

Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The JAX fallback hyp2f1_series uses the Maclaurin series in z, which only converges for |z| < 1 (and conditionally at |z| = 1). In this file the call site uses z = -c where c is the halo concentration (typically ~5–20), so the fallback will diverge / return incorrect results for the default parameters. Consider removing the fallback (raise a clear error requiring jax.scipy.special.hyp2f1) or replacing it with a numerically stable method for z < -1 (e.g., a quadrature with fixed nodes that is JIT-safe, or an analytic continuation tailored to 2F1(a,a;a+1;−c)).

Suggested change
Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp.
- NumPy: scipy.special.hyp2f1
- JAX (if available): jax.scipy.special.hyp2f1
- JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case)
"""
import jax
import jax.numpy as jnp
# Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z)
# We implement general 2F1 series:
# 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n!
#
# Recurrence for terms:
# t_0 = 1
# t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z
#
# This is JIT-safe with static max_terms.
def hyp2f1_series(a, b, c, z):
a = jnp.asarray(a)
b = jnp.asarray(b)
c = jnp.asarray(c)
z = jnp.asarray(z)
def body_fun(n, carry):
t, s = carry
n_f = jnp.asarray(n, dtype=t.dtype)
t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z
s = s + t
return (t, s)
# Start: t0 = 1, s0 = 1
t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z))
s0 = t0
# fori_loop has static iteration count => good under jit/vmap
tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0))
return sN
return hyp2f1_series
Returns a callable hyp2f1(a,b,c,z) compatible with the JAX backend.
This helper requires `jax.scipy.special.hyp2f1` to be available. No
series-based fallback is used, because the gNFW use-case calls 2F1
with |z| > 1 (e.g. z = -c with c ~ 520), where the Maclaurin series
diverges and would yield incorrect results.
"""
try:
import jax # noqa: F401 # kept for potential JAX-type detection elsewhere
import jax.numpy as jnp # noqa: F401 # retained to avoid changing imports broadly
from jax.scipy.special import hyp2f1 as jax_hyp2f1
except Exception as exc:
raise RuntimeError(
"The gNFW virial-mass–concentration profile requires "
"`jax.scipy.special.hyp2f1` when using the JAX backend. "
"Please install a version of JAX/jaxlib that provides "
"`jax.scipy.special.hyp2f1`, or disable the JAX path."
) from exc
# Ignore `xp` and `max_terms` here; they are kept for API compatibility.
return jax_hyp2f1

Copilot uses AI. Check for mistakes.
Comment on lines +93 to +116
Parameters
----------
cosmology
Cosmology object providing critical density, angular diameter distance
conversions, and surface mass density calculations. Must support an `xp`
argument for NumPy/JAX interoperability.
virial_mass
Virial mass of the halo in units of solar masses.
c_2
Concentration-like parameter, converted internally to the gNFW
concentration via `(2 - inner_slope) * c_2`.
overdens
Overdensity with respect to the critical density. If zero, the
Bryan & Norman (1998) redshift-dependent overdensity is used.
redshift_object
Redshift of the lens (halo).
redshift_source
Redshift of the background source.
inner_slope
Inner logarithmic density slope γ of the gNFW profile.
xp
Array backend module (`numpy` or `jax.numpy`). All array operations
are dispatched through this module to ensure compatibility with
both standard NumPy execution and JAX tracing / JIT compilation.
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring documents an xp parameter and states JIT-compatibility when xp=jax.numpy, but kappa_s_and_scale_radius does not accept xp and instead infers the backend from virial_mass. Update the docstring to match the actual API, or add an explicit xp argument (and propagate it through) so callers can control the backend deterministically.

Copilot uses AI. Check for mistakes.
Comment on lines +137 to +144
is_jax_bool = is_jax(virial_mass)

if not is_jax_bool:
xp = np
else:
from jax import numpy as jnp
xp = jnp

Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backend selection is based only on is_jax(virial_mass). If a caller passes virial_mass as a Python/NumPy scalar but supplies other inputs as JAX arrays/tracers (e.g. c_2, overdens, or redshifts), this will select NumPy and then attempt NumPy ops on traced values, breaking JIT/grad. Consider either adding an explicit xp parameter, or inferring the backend from all numeric inputs (or at least from c_2 / overdens as well).

Copilot uses AI. Check for mistakes.
# Bryan & Norman (1998) overdensity if overdens == 0
x = cosmology.Om(redshift_object, xp=xp) - 1.0
overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using xp.where here changes the NumPy-path behavior: when overdens is a scalar, np.where(...) returns a 0-d ndarray, which then propagates so the function returns arrays instead of plain scalars under NumPy. If downstream code expects Python / NumPy scalars (e.g. for JSON serialization), consider keeping the original Python if overdens == 0: branch in the NumPy path and reserving xp.where for the JAX path (or explicitly converting 0-d arrays back to scalars in the NumPy backend).

Suggested change
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
# Use a Python scalar branch for NumPy scalars to avoid returning 0-d arrays;
# keep xp.where for JAX and array inputs.
if is_jax(overdens) or not np.isscalar(overdens):
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
else:
if overdens == 0:
overdens = overdens_bn98

Copilot uses AI. Check for mistakes.
from jax import Array
from jax.core import Tracer
return isinstance(x, (Array, Tracer))
except Exception:
Copy link

Copilot AI Mar 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_jax catches Exception broadly, which can mask real import/runtime errors inside JAX and silently fall back to NumPy. It would be safer to catch only ImportError (and potentially AttributeError for older JAX types) so unexpected failures don’t change numerical backends without warning.

Suggested change
except Exception:
except (ImportError, ModuleNotFoundError, AttributeError):

Copilot uses AI. Check for mistakes.
@Jammy2211 Jammy2211 merged commit 5850af2 into main Mar 2, 2026
8 checks passed
@Jammy2211 Jammy2211 deleted the feature/jaxify_gnfw_conc branch April 2, 2026 11:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants