In [1]:
import jax
# Enable 64-bit floats in JAX (off by default). The residual comparisons in
# this notebook are at the ~1e-16 level and so need double precision.
jax.config.update("jax_enable_x64", True)

import galsim
import numpy as np

import jax.numpy as jnp

In [2]:
from jax_galsim.spergel import (
    fz_nup1, _gammap1, _spergel_hlr_pade,
    fluxfractionFunc, fz_nu, calculateFluxRadius,
)

@jax.jit
def _calculateFluxRadius_newtons_kernel(i, args):
    """Perform a single log-space Newton update for the flux-radius root find.

    Intended as a ``jax.lax.fori_loop`` body: ``i`` is the loop index (not
    used) and ``args`` is the carried tuple ``(lnz, alpha, nu)``.

    With the normalization ``norm = 2**nu * _gammap1(nu)``, the residual and
    its derivative with respect to z are

        f(z)  = 1 - fz_nup1(z, nu) / norm - alpha
        f'(z) = z * fz_nu(z, nu) / norm

    and the update is Newton's step taken in ln(z):

        lnz -> lnz - f(z) / f'(z) / z

    The new ``lnz`` is clipped to [-100, 100] to avoid numerical issues near
    the bounds.

    Returns the updated carry ``(lnz, alpha, nu)``; only ``lnz`` changes.
    """
    log_z, frac, index = args
    z = jnp.exp(log_z)

    # shared normalization of the residual and of its derivative
    norm = (jnp.power(2.0, index) * _gammap1(index))
    resid = 1.0 - fz_nup1(z, index) / norm - frac
    slope = z * fz_nu(z, index) / norm

    # Newton step in ln(z): divide by z because df/dln(z) = z * f'(z)
    step = resid / slope / z
    updated_log_z = jnp.clip(
        log_z - step,
        min=-100,
        max=100,
    )

    return updated_log_z, frac, index


@jax.jit
def calculateFluxRadiusNewton(alpha, nu):
    """Return the radius R enclosing flux fraction alpha, in units of the scale radius r0.

    Method: solve F(R/r0 = z)/Flux - alpha = 0 using Newton's method.

    Integrating the profile gives

        F(R)/F = int( 1/(2^nu Gamma(nu+1)) (r/r0)^(nu+1) K_nu(r/r0) dr/r0; r=0..R) = alpha

    So if we define z = R/r0 and f(z) = F(z * r0)/F - alpha, Newton's method is

        z -> z - f(z) / f'(z)

    We actually iterate on ln(z). Because d f / d ln(z) = z f'(z), the update is

        ln(z) -> ln(z) - f(z) / f'(z) / z

    Working in ln(z) guarantees the iterate z = exp(ln(z)) stays positive.

    Typical use cases include:

      - alpha = 1/2 => R = half-light radius,
      - alpha = 1 - folding_threshold => R used for the stepk computation.

    Parameters
    ----------
    alpha : float
        Target enclosed flux fraction. It is expected to lie strictly between
        0 and 1: the log of the alpha-scaled seed is taken, and at alpha = 1
        the root is at infinity.
    nu : float
        Spergel index.

    Returns
    -------
    z : float
        R / r0 after a fixed number of Newton iterations. There is no
        convergence test, so check the residual with ``fluxfractionFunc``
        if needed.
    """
    # Fixed iteration count, so the loop has static bounds under jit.
    n_iter = 100

    # Seed the iteration with the Pade approximation to the HLR, scaled by
    # sqrt(alpha / 0.5) so the seed grows with the requested flux fraction.
    zalpha = _spergel_hlr_pade(nu) * jnp.sqrt(alpha / 0.5)
    lnz_final = jax.lax.fori_loop(
        0, n_iter,
        _calculateFluxRadius_newtons_kernel,
        (jnp.log(zalpha), alpha, nu),
    )[0]
    return jnp.exp(lnz_final)


In [3]:
# Compare the Newton solver with jax_galsim's calculateFluxRadius at extreme
# flux fractions (alpha = eps and alpha = 1 - eps) for two Spergel indices.
# For each radius we print the residual fluxfractionFunc(z, nu, alpha), which
# is 0 at the exact root. For the Newton radius we also print GalSim's
# enclosed flux (unit flux, unit scale radius), which should reproduce alpha.
for eps in [1e-12, 0.1]:
    for alpha in [eps, 1.0 - eps]:
        for nu in [-0.84, 3.999]:

            print("\neps, nu, log10(alpha):", eps, nu, np.log10(alpha))
            z_newton = calculateFluxRadiusNewton(alpha, nu)
            z_lib = calculateFluxRadius(alpha, nu)
            galsim_flux = galsim.Spergel(nu, scale_radius=1.0).calculateIntegratedFlux(z_newton)
            print(
                "newton: z =", z_newton,
                "resid =", fluxfractionFunc(z_newton, nu, alpha),
                "| calculateFluxRadius: z =", z_lib,
                "resid =", fluxfractionFunc(z_lib, nu, alpha),
                "| galsim flux(<z_newton) =", galsim_flux,
            )


eps, nu, log10(alpha): 1e-12 -0.84 -12.0
3.5138887102897e-38 -2.2121720121483927e-17 1.0587911840678754e-21 1.8761616702453412e-07 1.000534100015216e-12

eps, nu, log10(alpha): 1e-12 3.999 -12.0
3.9966817649384216e-06 -1.576433954596703e-15 3.999832106175669e-06 3.1094518726606304e-16 9.984622740022494e-13

eps, nu, log10(alpha): 1e-12 -0.84 -4.3428487456249e-13
25.572845945758726 0.0 25.572509765625 -3.3306690738754696e-16 0.999999999999

eps, nu, log10(alpha): 1e-12 3.999 -4.3428487456249e-13
38.6677767503012 0.0 38.6676025390625 -1.1102230246251565e-16 0.999999999999

eps, nu, log10(alpha): 0.1 -0.84 -1.0
0.0008333666650951336 6.38378239159465e-16 0.0008333666650951221 3.0531133177191805e-16 0.10000000000000096

eps, nu, log10(alpha): 0.1 3.999 -1.0
1.3092245672406861 6.38378239159465e-16 1.3092245672406833 8.326672684688674e-17 0.10000000000000037

eps, nu, log10(alpha): 0.1 -0.84 -0.045757490560675115
1.2147258941802845 0.0 1.214725894180284 -1.1102230246251565e-16 0.900000000000