In [None]:
%reload_ext autoreload
%autoreload 2

from functools import partial
from typing import List, Optional

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from potential_cylspline import *
from constants import *

from cubic_spline import *

In [2]:
from constants import *
# Show the regularisation epsilon and the per-m lower hypergeometric switch points.
print(EPSILON, X_THRESHOLD0, sep="\n")

1e-12
[0.72  0.72  0.8   0.8   0.83  0.86  0.85  0.88  0.88  0.88  0.885 0.9
 0.91 ]


In [90]:
import time as tt

# Time the first lookup into each coefficient table.
# - perf_counter is a monotonic high-resolution clock (time.time is neither).
# - jax.block_until_ready waits for JAX's asynchronous dispatch to finish; non-JAX
#   leaves (e.g. plain NumPy arrays) are returned as-is, so this works either way.
start = tt.perf_counter()
jax.block_until_ready(HYPERGEOM_0[0])
end = tt.perf_counter()
print("HYPERGEOM_0[0] lookup time (ms):", 1000*(end - start))

start = tt.perf_counter()
jax.block_until_ready(HYPERGEOM_1[0])
end = tt.perf_counter()
print("HYPERGEOM_1[0] lookup time (ms):", 1000*(end - start))

Numpy time: 0.5922317504882812
Numpy time: 0.07796287536621094


In [4]:
@jax.jit
def rho_xyz(x, y, z):
    """Toy density in Cartesian coordinates.

    rho = exp(-sqrt(x^2 + y^2) / 5) * exp(-|z|): exponential fall-off with
    scale 5 in cylindrical radius and scale 1 in |z|.
    """
    radial = jnp.exp(-jnp.sqrt(x**2 + y**2) / 5.0)
    vertical = jnp.exp(-jnp.abs(z))
    return radial * vertical

@jax.jit
def cylindrical_to_cartesian(R, phi, z):
    """Convert cylindrical (R, phi, z) to Cartesian (x, y, z); z passes through unchanged."""
    cos_phi, sin_phi = jnp.cos(phi), jnp.sin(phi)
    return R * cos_phi, R * sin_phi, z

@jax.jit
def rho_Rzphi(R, z, phi):
    """Density at cylindrical (R, z, phi); |z| is used, so the result is even in z."""
    x, y, z_abs = cylindrical_to_cartesian(R, phi, jnp.abs(z))
    return rho_xyz(x, y, z_abs)

@jax.jit
def rho_last(R, z, m, phi, Nphi=200):
    """Single-angle term of the azimuthal Fourier sum for rho_m.

    Returns rho(R, z, phi) * exp(-i m phi) * (2*pi / Nphi) / (2*pi). Summing this over
    Nphi angles gives a rectangle-rule estimate of (1/2pi) * integral rho e^{-i m phi} dphi,
    so Nphi must match the number of angle nodes that are actually summed.
    """
    density = rho_Rzphi(R, z, phi)
    phase = jnp.exp(-1j * m * phi)
    step = (2*jnp.pi) / Nphi
    return density * phase * step / (2.0 * jnp.pi)

@jax.jit
def rho_phiZRm(R, z, m, phi):
    """Azimuthal Fourier coefficient rho_m(R, z), summed over the angle nodes in ``phi``.

    Each node contributes ``rho_last(R, z, m, phi_k, n_phi)``. The step 2*pi/n_phi
    is taken from the length of ``phi``, not from rho_last's default of 200 nodes, so
    the coefficient stays correctly normalised when the number of angle nodes changes.
    Assumes ``phi`` is 1-D and equally spaced over one full period, e.g.
    ``linspace(0, 2*pi, N, endpoint=False)``. TODO: confirm before using other grids.
    """
    n_phi = phi.shape[0]  # static at trace time
    terms = jax.vmap(lambda phi_k: rho_last(R, z, m, phi_k, n_phi))(phi)
    return jnp.sum(terms, axis=0)

@jax.jit
def compute_rho_m(R, z, m, phi):
    """Evaluate rho_phiZRm for every azimuthal order in the 1-D array ``m``.

    Results for the different orders are stacked along a new leading axis.
    """
    over_orders = jax.vmap(rho_phiZRm, in_axes=(None, None, 0, None))
    return over_orders(R, z, m, phi)

In [6]:
# Grid configuration for tabulating rho_m on (R, z >= 0).
NR, NZ = 30, 30
Rmin, Rmax = 1e-3, 20.
Zmin, Zmax = 1e-3, 10.
Mmax = 1.
Nphi = 200

# Azimuthal orders 0..Mmax (float-valued because Mmax is a float literal).
M = jnp.arange(0, Mmax + 1)

R = jnp.geomspace(jnp.maximum(Rmin, 1e-3), Rmax, NR)
R0_eff = R[NR // 2]

# z = 0 followed by NZ log-spaced positive heights.
Zpos = jnp.geomspace(jnp.maximum(Zmin, 1e-3), Zmax, NZ)
Z_nonneg = jnp.concatenate([jnp.array([0.0]), Zpos])

Rg, Zg = jnp.meshgrid(R, Z_nonneg, indexing="ij")
phi = jnp.linspace(0.0, 2*jnp.pi, Nphi, endpoint=False)
dphi = (2*jnp.pi) / Nphi

# vmap over the R axis of the mesh gives (NR, n_m, NZ+1); reorder to (n_m, NR, NZ+1).
rho_m = jnp.transpose(
    jax.vmap(compute_rho_m, in_axes=(0, 0, None, None))(Rg, Zg, M, phi),
    (1, 0, 2),
)
rho_m.shape

(2, 30, 31)

In [49]:
# Seeded random evaluation points (R in [1, 4], z in [1, 3]) for spot-checking the splines.
rng = np.random.default_rng(0)
test_points = np.column_stack([
    rng.uniform(1, 4, 100),
    rng.uniform(1, 3, 100),
])

# Spline tables: key m -> (values, M_x, M_y) from jax_precompute_splines;
# keys -1 / -2 hold the R / z grids that jax_rho_m_eval reads back.
_rho_m_real_interp = {-1: R, -2: Z_nonneg}
_rho_m_imag_interp = {-1: R, -2: Z_nonneg}

# One spline per azimuthal order and per real/imaginary part. The loop covers every
# order present in rho_m, rather than a hand-unrolled m = 0, 1.
for m in range(rho_m.shape[0]):
    for table, values in ((_rho_m_real_interp, rho_m[m].real),
                          (_rho_m_imag_interp, rho_m[m].imag)):
        M_x, M_y = jax_precompute_splines((R, Z_nonneg), values)
        table[m] = (values, M_x, M_y)


@partial(jax.jit, static_argnums=(0,))
def jax_rho_m_eval(m, R, z, _rho_m_real_interp, _rho_m_imag_interp):
    """Evaluate the complex density harmonic rho_m at the points (R, z).

    ``m`` is static, so it must be a hashable key into the tables. Each table maps
    m -> (values, M_x, M_y), and -1 / -2 -> the R / z grids. |z| is used, so the result
    is even in z. ``fill_value=0.0`` is passed to ``cubic_spline_evaluate``, presumably
    for points outside the grid. The output has the shape of ``R``.
    """
    grids = (_rho_m_real_interp[-1], _rho_m_real_interp[-2])
    pts = jnp.column_stack((R.ravel(), jnp.abs(z).ravel()))

    def _eval(table):
        entry = table[m]
        flat = cubic_spline_evaluate(pts, grids, entry[0], entry[1], entry[2], fill_value=0.0)
        return flat.reshape(R.shape)

    return _eval(_rho_m_real_interp) + 1j * _eval(_rho_m_imag_interp)

# Spot-check rho_0 at (R, z) = (2, 0.5) and (3, 1.0).
m = 0
a = jax_rho_m_eval(
    m,
    jnp.array([2.0, 3.0]),
    jnp.array([0.5, 1.0]),
    _rho_m_real_interp,
    _rho_m_imag_interp,
)




In [52]:
@jax.jit
def jax_hypergeom_m(m, x):
    """Piecewise approximation F_m(x) built from the tabulated coefficients for order m.

    m: int (may be traced; indexes HYPERGEOM_0 / HYPERGEOM_I / HYPERGEOM_1 and the thresholds)
    x: array-like

    - x <  X_THRESHOLD0[m]:                   continued fraction with HYPERGEOM_0[m]
    - X_THRESHOLD0[m] <= x < X_THRESHOLD1[m]: same continued fraction with HYPERGEOM_I[m]
    - x >= X_THRESHOLD1[m]:                   series in y = 1 - x with log(y) terms (HYPERGEOM_1[m]),
                                              with y floored at 1e-12 inside the log.
    All three branches are evaluated and the result is selected elementwise.
    """
    def _continued_fraction(coef):
        # A0 + A1/(x + A2 + A3/(x + A4 + A5/(x + A6 + A7/(x + A8)))), evaluated inside-out.
        denom = x + coef[8]
        denom = x + coef[6] + coef[7] / denom
        denom = x + coef[4] + coef[5] / denom
        denom = x + coef[2] + coef[3] / denom
        return coef[0] + coef[1] / denom

    low = _continued_fraction(HYPERGEOM_0[m])
    mid = _continued_fraction(HYPERGEOM_I[m])

    c = HYPERGEOM_1[m]
    y = 1.0 - x
    y2 = y*y
    log_y = jnp.log(jnp.where(y > 1e-12, y, 1e-12))
    high = ((c[0] + c[1]*log_y)
            + (c[2] + c[3]*log_y) * y
            + ((c[4] + c[5]*log_y)
               + (c[6] + c[7]*log_y) * y
               + (c[8] + c[9]*log_y) * y2) * y2)

    below_upper = jnp.where(x < X_THRESHOLD0[m], low, mid)
    return jnp.where(x < X_THRESHOLD1[m], below_upper, high)

@jax.jit
def jax_legendreQ(n, x):
    """Legendre-Q function of half-integer degree, Q_n(x), from the tabulated approximation.

    n: float. The table index is m = round(n + 0.5), so n is expected to be m - 1/2
       (jax_kernel_Xi_m calls this with n = m - 0.5).
    x: array-like. Values below 1 are silently clamped to 1; no domain error is raised.

    Returns Q_PREFACTOR[m] / sqrt(x) / x**m * jax_hypergeom_m(m, 1 / x**2).
    """
    x = jnp.where(x < 1.0, 1.0, x)
    m = jnp.round(n + 0.5).astype(jnp.int32)

    prefactor = Q_PREFACTOR[m] / jnp.sqrt(x) / (x**m)
    return prefactor * jax_hypergeom_m(m, 1.0/(x*x))

@jax.jit
def jax_kernel_Xi_m(m, R, z, Rp, zp):
    """
    Azimuthal-harmonic kernel Xi_m(R, z | R', z').

    Off the axis (R >= 1e-3 and R' >= 1e-3):
        chi  = (R^2 + R'^2 + (z - z')^2) / (2 R R'),  clamped to >= 1
        Xi_m = Q_{m-1/2}(chi) / (pi * sqrt(R R'))
    On the axis (R < 1e-3 or R' < 1e-3):
        Xi_0 = 1 / sqrt(R^2 + R'^2 + (z - z')^2),    Xi_{m>0} = 0

    m:  int
    R:  float or array (field-point radius)
    z:  float or array (field-point height)
    Rp: array-like (source radii)
    zp: array-like (source heights)
    All arguments broadcast together; the result has the broadcast shape.
    """
    axis_tol = 1e-3
    dz = z - zp
    on_axis = (R < axis_tol) | (Rp < axis_tol)

    # On-axis limit: only the m = 0 harmonic survives.
    on_axis_val = jnp.where(m > 0, 0.0, 1.0 / jnp.sqrt(R*R + Rp*Rp + dz*dz))

    # Swap on-axis radii for a harmless placeholder before dividing. Those entries are
    # discarded by the final where anyway, but without this the discarded branch yields
    # inf/NaN (division by R*R' = 0), which then leaks into gradients through jnp.where.
    R_safe = jnp.where(on_axis, 1.0, R)
    Rp_safe = jnp.where(on_axis, 1.0, Rp)

    chi = (R_safe*R_safe + Rp_safe*Rp_safe + dz*dz) / (2.0 * R_safe * Rp_safe)
    chi = jnp.maximum(chi, 1.0)
    Q = jax_legendreQ(m - 0.5, chi)
    off_axis_val = (1.0 / (jnp.pi * jnp.sqrt(R_safe * Rp_safe))) * Q

    # jnp.where (not lax.cond) so that R may be an array as well as a scalar.
    return jnp.where(on_axis, on_axis_val, off_axis_val)

jax_kernel_Xi_m(
    1,                            # azimuthal order m
    2.0, 0.5,                     # field point (R, z)
    jnp.array([1.0, 2.0, 4.]),    # source radii R'
    jnp.array([0.5, 1.0, 3.]),    # source heights z'
)

Array([0.13896658, 0.23909934, 0.03668395], dtype=float32)

In [None]:
### Compute the potential components Phi_m(R, z): Simpson quadrature setup

N_int = 10_000
# Nodes per axis: max(9, floor(sqrt(max(16, N_int)))), bumped up to the next odd number
# so composite Simpson applies. This uses the same rule and plain Python ints as
# compute_phi_m_grid_fixed_mapped, because these counts are static array sizes.
base = max(9, int(np.sqrt(max(16, N_int))))
if base % 2 == 0:
    base += 1

n_xi = base
n_eta = base

@partial(jax.jit, static_argnames=['n'])
def simpson_weights(n):
    """Composite Simpson weights for n equally spaced nodes on [0, 1].

    n: static Python int; must be odd and >= 3 (an even count gives the wrong
       4/2 pattern at the right end, and n = 1 divides by zero).
    Returns h/3 * [1, 4, 2, 4, ..., 2, 4, 1] with h = 1/(n-1).
    Raises ValueError (at trace time) for an invalid n.
    """
    if n < 3 or n % 2 == 0:
        raise ValueError("Simpson's rule needs an odd number of nodes n >= 3.")
    w = jnp.ones(n)
    w = w.at[1:-1:2].set(4.0)
    w = w.at[2:-1:2].set(2.0)
    w *= (1.0 / (n - 1)) / 3.0   # h = 1/(n-1), scale by h/3
    return w

# Quadrature nodes on the unit square (xi along axis 0, eta along axis 1) and their Simpson weights.
xi = jnp.linspace(0.0, 1.0, n_xi)
eta = jnp.linspace(0.0, 1.0, n_eta)
XI, ETA = jnp.meshgrid(xi, eta, indexing="ij")

wxi, weta = (simpson_weights(int(n)) for n in (n_xi, n_eta))

@jax.jit
def _xieta_to_Rz_jacobian(xi, eta):
    """Map unit-square coordinates (xi, eta) to (R', z') with a log-type stretch.

        R'(xi)  = R[1]        * ((1 + Rmax / R[1])^xi - 1)          -> spans [0, Rmax]
        z'(eta) = Z_nonneg[1] * ((1 + Zmax / Z_nonneg[1])^eta - 1)  -> spans [0, Zmax]

    Returns (R', z', J) with J = dR'/dxi * dz'/deta. The 2*pi*R' area factor is NOT
    included. Reads the notebook globals R, Z_nonneg, Rmax and Zmax, which jit bakes
    in as constants at trace time.
    """
    def _stretch(u, scale, upper):
        # u in [0, 1] -> scale * ((1 + upper/scale)^u - 1), plus its derivative d/du.
        ratio = 1.0 + upper / scale
        mapped = scale * (jnp.power(ratio, u) - 1.0)
        derivative = jnp.log(ratio) * (scale + mapped)
        return mapped, derivative

    Rp, dR_dxi = _stretch(xi, R[1], Rmax)
    zp, dz_deta = _stretch(eta, Z_nonneg[1], Zmax)
    return Rp, zp, dR_dxi * dz_deta

# Map the (xi, eta) Simpson nodes to (R', z') with the Jacobian, and form the
# tensor-product weights W2D[i, j] = wxi[i] * weta[j].
Rp, zp, Jmap = _xieta_to_Rz_jacobian(XI, ETA)
W2D = jnp.einsum('i,j->ij', wxi, weta)

# Azimuthal orders to compute (same values as M).
m_list = jnp.array(M)

[0. 1.]


In [None]:
def compute_phi_m_grid_fixed_mapped(
    self,
    *,
    N_int: int = 10_000,
    n_xi: Optional[int] = None,
    n_eta: Optional[int] = None,
    m_list: Optional[List[int]] = None,
    progress: bool = False,
):
    """
    Compute Φ_m(R,z) on the (R, z≥0) grid using a tensor Simpson rule over (xi,eta)∈[0,1]^2,
    with AGAMA-style log mapping and Jacobian.

    Only z'≥0 is integrated, but the kernel contributions from +z' and −z' are *summed*
    (Xi_plus + Xi_minus) inside the integrand. This covers the z'<0 half of the even-in-z
    density for every z0 (including z0≠0), with no extra factor of 2 at the end:

        Φ_m(R0,z0) = -G * ∬_{[0,1]^2}  ρ_m(R',z') * [Ξ_m(R0,z0|R',+z') + Ξ_m(R0,z0|R',-z')]
                                * (2π R') * J(xi,eta)  dxi deta

    Parameters
    ----------
    N_int : int
        Target total node count. Used only when n_xi or n_eta is None: the per-axis
        count is max(9, floor(sqrt(max(16, N_int)))), bumped up to odd.
    n_xi, n_eta : int, optional
        Explicit Simpson node counts along xi / eta (odd, ≥ 3).
    m_list : list of int, optional
        Azimuthal orders to compute; defaults to 0..self.mmax.
    progress : bool
        If True, print one line after each radial row is finished.

    Returns
    -------
    dict
        The node counts actually used and a description of the rule.

    Side effects
    ------------
    Resets self._Phi_m_grid, self._Phi_m_real_interp and self._Phi_m_imag_interp,
    then fills them for every m in m_list. self._Phi_m_interp is reset but not filled.

    Raises
    ------
    RuntimeError
        If compute_rho_m() has not been run.
    ValueError
        If a node count is even or smaller than 3.
    """
    if not self._rho_m_real_interp:
        raise RuntimeError("Run compute_rho_m() first.")

    # --- choose Simpson node counts (odd ≥3) from N_int if not provided ---
    if (n_xi is None) or (n_eta is None):
        base = max(9, int(np.sqrt(max(16, N_int))))
        if base % 2 == 0:
            base += 1
        n_xi = base if n_xi is None else n_xi
        n_eta = base if n_eta is None else n_eta
    # Normalise to Python ints in every branch: they are used as array sizes below.
    n_xi = int(n_xi)
    n_eta = int(n_eta)
    if n_xi < 3 or n_xi % 2 == 0:
        raise ValueError("Simpson along xi needs odd n_xi ≥ 3.")
    if n_eta < 3 or n_eta % 2 == 0:
        raise ValueError("Simpson along eta needs odd n_eta ≥ 3.")

    def _simpson_weights(n: int) -> np.ndarray:
        # Composite Simpson on [0,1]: h/3 * [1, 4, 2, ..., 2, 4, 1], h = 1/(n-1).
        w = np.ones(n)
        w[1:-1:2] = 4.0
        w[2:-1:2] = 2.0
        w *= (1.0 / (n - 1)) / 3.0
        return w

    wxi = _simpson_weights(n_xi)
    weta = _simpson_weights(n_eta)

    # --- tensor nodes in (xi,eta) and mapped (R', z') with Jacobian ---
    XI, ETA = np.meshgrid(np.linspace(0.0, 1.0, n_xi),
                          np.linspace(0.0, 1.0, n_eta), indexing="ij")   # (n_xi, n_eta)
    Rp, zp, Jmap = self._xieta_to_Rz_jacobian(XI, ETA)                   # (n_xi, n_eta) each
    W2D = np.outer(wxi, weta)                                             # Simpson product weights

    # Loop-invariant part of the integrand: Simpson weight × (2π R') × Jacobian.
    node_weight = W2D * (2.0 * np.pi) * Rp * Jmap

    # which m to compute
    m_list = list(range(self.mmax + 1)) if m_list is None else list(m_list)

    self._Phi_m_grid = {}
    self._Phi_m_interp = {}        # legacy attribute, kept for callers that read it
    # These are written below, so they must exist and must not keep entries from a previous run.
    self._Phi_m_real_interp = {}
    self._Phi_m_imag_interp = {}

    for m in tqdm(m_list, total=len(m_list)):
        Phi = np.zeros((self.NR, self.Z_nonneg.size), dtype=complex)

        # ρ_m at all (R', z') nodes (complex, even in z), folded with the node weights once per m.
        weighted_rho = self.rho_m_eval(m, Rp, zp) * node_weight

        for i, R0 in enumerate(self.R):
            for j, z0 in enumerate(self.Z_nonneg):
                # Kernel from +z' and −z' (sum, not average): covers the z'<0 half.
                Xi_sum = (self.kernel_Xi_m(m, R0, z0, Rp, zp)
                          + self.kernel_Xi_m(m, R0, z0, Rp, -zp))
                Phi[i, j] = -self.G * np.sum(weighted_rho * Xi_sum)

            if progress:
                print(f"[mapped simpson] m={m}  R[{i+1}/{self.NR}]")

        # store grid + spline interpolators of the real and imaginary parts
        self._Phi_m_grid[m] = Phi
        self._Phi_m_real_interp[m] = CubicSpline2D(
            (self.R, self.Z_nonneg), Phi.real,
            bounds_error=False, fill_value=0.
        )
        self._Phi_m_imag_interp[m] = CubicSpline2D(
            (self.R, self.Z_nonneg), Phi.imag,
            bounds_error=False, fill_value=0.
        )

    return {
        "n_xi": n_xi,
        "n_eta": n_eta,
        "total_nodes": n_xi * n_eta,
        "rule": "simpson([0,1]^2) with log-mapping; Xi(+z')+Xi(−z')",
    }

(Array(101, dtype=int32), 101)