In [1]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
from functools import partial

import jax
import jax.numpy as jnp

from potential_cylspline import *
from constants import *

from cubic_spline import *

In [2]:
from constants import *
print(EPSILON)
print(G)
print(X_THRESHOLD0)

1e-12
4.3e-06
[0.72  0.72  0.8   0.8   0.83  0.86  0.85  0.88  0.88  0.88  0.885 0.9
 0.91 ]


In [3]:
import time as tt

start = tt.time()
HYPERGEOM_0[0]
end = tt.time()
print("Numpy time:", 1000*(end - start))

start = tt.time()
HYPERGEOM_1[0]
end = tt.time()
print("Numpy time:", 1000*(end - start))

Numpy time: 26.974201202392578
Numpy time: 20.968914031982422


In [4]:
@jax.jit
def rho_xyz(x,y,z):
    r = jnp.sqrt(x**2 + y**2)
    return jnp.exp(-r/5.0) * jnp.exp(-jnp.abs(z))

@jax.jit
def cylindrical_to_cartesian(R, phi, z):
    x = R * jnp.cos(phi)
    y = R * jnp.sin(phi)
    return x, y, z

@jax.jit
def rho_Rzphi(R, z, phi):
    z = jnp.abs(z)  # even symmetry
    x, y, zz = cylindrical_to_cartesian(R, phi, z)
    return rho_xyz(x, y, zz)

@jax.jit
def rho_last(R, z, m, phi, Nphi=200):
    dphi = (2*jnp.pi) / Nphi
    vals = rho_Rzphi(R, z, phi)
    exp_ph = jnp.exp(-1j * m * phi)
    rho_m_stack = vals * exp_ph * dphi / (2.0 * jnp.pi)
    return rho_m_stack

@jax.jit
def rho_phiZRm(R, z, m, phi):
    return jnp.sum(jax.vmap(rho_last, in_axes=(None, None, None, 0))(R, z, m, phi), axis=0)

@jax.jit
def compute_rho_m(R, z, m, phi):
    return jax.vmap(rho_phiZRm, in_axes=(None, None, 0, None))(R, z, m, phi)

In [44]:
NR, NZ, Rmin, Rmax, Zmin, Zmax, Mmax = 30, 30, 1e-3, 20., 1e-3, 10., 8. 
Nphi = 200

M = jnp.arange(0, Mmax + 1)

R = jnp.geomspace(jnp.maximum(Rmin, 1e-3), Rmax, NR)
R0_eff = R[NR // 2]

Zpos = jnp.geomspace(jnp.maximum(Zmin, 1e-3), Zmax, NZ)
Z_nonneg = jnp.concatenate([jnp.array([0.0]), Zpos])

Rg, Zg = jnp.meshgrid(R, Z_nonneg, indexing="ij")
phi = jnp.linspace(0.0, 2*jnp.pi, Nphi, endpoint=False)

dphi = (2*jnp.pi) / Nphi

rho_m = jax.vmap(compute_rho_m, in_axes=(0, 0, None, None))(Rg, Zg, M, phi).transpose(1,0,2) 

In [62]:
# import scipy.interpolate as interpax

test_points = np.array([
    np.random.uniform(1, 4, 100),
    np.random.uniform(1, 3, 100)
]).T
    
# results = cubic_spline_evaluate(jnp.array(test_points), (R, Z_nonneg), rho_m[0], M_x, M_y, fill_value=0.0)

_rho_m_real_interp = {-1: R, -2: Z_nonneg}
_rho_m_imag_interp = {-1: R, -2: Z_nonneg}

i = -1
for m in M.astype(int):
    i+=1
    M_x, M_y = jax_precompute_splines((R, Z_nonneg), rho_m[m].real)
    _rho_m_real_interp[i] = (rho_m[m].real, M_x, M_y)

    M_x, M_y = jax_precompute_splines((R, Z_nonneg), rho_m[m].imag)
    _rho_m_imag_interp[i] = (rho_m[m].imag, M_x, M_y)

# @partial(jax.jit, static_argnums=(0,))
def jax_rho_m_eval(m, R, z, _rho_m_real_interp, _rho_m_imag_interp):

    Rgrid = _rho_m_real_interp[-1]
    Zgrid = _rho_m_real_interp[-2]
    real_values = _rho_m_real_interp[m][0]
    imag_values = _rho_m_imag_interp[m][0]
    M_x_real = _rho_m_real_interp[m][1]
    M_y_real = _rho_m_real_interp[m][2]
    M_x_imag = _rho_m_imag_interp[m][1]
    M_y_imag = _rho_m_imag_interp[m][2]

    shape = R.shape
    pts = jnp.column_stack((R.ravel(), jnp.abs(z).ravel()))

    real_part = cubic_spline_evaluate(pts, (Rgrid, Zgrid), real_values, M_x_real, M_y_real, fill_value=0.0).reshape(shape)
    imag_part = cubic_spline_evaluate(pts, (Rgrid, Zgrid), imag_values, M_x_imag, M_y_imag, fill_value=0.0).reshape(shape)

    return real_part + 1j * imag_part


In [63]:
@jax.jit
def jax_hypergeom_m(m, x):
    """
    m: int,
    x: array-like
    """


    y = 1.0 - x
    y2 = y*y
    z = jnp.log(jnp.where(y > 1e-12, y, 1e-12))

    HYPERGEOM_0_m = HYPERGEOM_0[m]
    HYPERGEOM_I_m = HYPERGEOM_I[m]
    HYPERGEOM_1_m = HYPERGEOM_1[m]

    xA8_1 = x + HYPERGEOM_0_m[8]
    xA6_1 = x + HYPERGEOM_0_m[6] + HYPERGEOM_0_m[7] / xA8_1
    xA4_1 = x + HYPERGEOM_0_m[4] + HYPERGEOM_0_m[5] / xA6_1
    xA2_1 = x + HYPERGEOM_0_m[2] + HYPERGEOM_0_m[3] / xA4_1
    val_1 = HYPERGEOM_0_m[0] + HYPERGEOM_0_m[1] / xA2_1

    xA8_2 = x + HYPERGEOM_I_m[8]
    xA6_2 = x + HYPERGEOM_I_m[6] + HYPERGEOM_I_m[7] / xA8_2
    xA4_2 = x + HYPERGEOM_I_m[4] + HYPERGEOM_I_m[5] / xA6_2
    xA2_2 = x + HYPERGEOM_I_m[2] + HYPERGEOM_I_m[3] / xA4_2
    val_2 = HYPERGEOM_I_m[0] + HYPERGEOM_I_m[1] / xA2_2

    val3 = (HYPERGEOM_1_m[0] + HYPERGEOM_1_m[1]*z +
             (HYPERGEOM_1_m[2] + HYPERGEOM_1_m[3]*z) * y +
             (HYPERGEOM_1_m[4] + HYPERGEOM_1_m[5]*z + 
             (HYPERGEOM_1_m[6] + HYPERGEOM_1_m[7]*z) * y + 
             (HYPERGEOM_1_m[8] + HYPERGEOM_1_m[9]*z) * y2) * y2)

    F = jnp.where(x < X_THRESHOLD1[m],
                 jnp.where(x < X_THRESHOLD0[m], val_1, val_2),
                 val3)

    return F

@jax.jit
def jax_legendreQ(n, x):
    """
    n: float,
    x: array-like
    """

    x = jnp.where(x < 1.0, 1.0, x)
    out = jnp.empty_like(x)
    m = jnp.round(n + 0.5).astype(jnp.int32)

    pref = Q_PREFACTOR[m] / jnp.sqrt(x) / (x**m)
    F = jax_hypergeom_m(m, 1.0/(x*x))
    out = pref * F

    return out

@jax.jit
def jax_kernel_Xi_m(m, R, z, Rp, zp):

    """
    m: int,
    R: float,
    z: float,
    Rp: array-like,
    zp: array-like
    """
    zeros = jnp.zeros_like(Rp, dtype=float)

    val1 = zeros
    val2 = 1.0 / jnp.sqrt(R*R + Rp*Rp + (z - zp)**2)

    val_zero = jax.lax.cond(m>0, lambda: val1, lambda: val2)
    
    Rp_reg = Rp
    dz = (z - zp)
    chi = (R*R + Rp_reg*Rp_reg + dz*dz) / (2.0 * R * Rp_reg)
    chi = jnp.maximum(chi, 1.0)
    Q = jax_legendreQ(m - 0.5, chi)
    val_nonzero = (1.0 / (jnp.pi * jnp.sqrt(R * Rp_reg))) * Q


    val_out = jnp.where(Rp<1e-3, val_zero, val_nonzero)

    out = jax.lax.cond(R < 1e-3, lambda: val_zero, lambda: val_out)

    return out

In [64]:
### Now for Compute the Potential components

N_int = 10_000
base = jnp.maximum(9, jnp.sqrt(jnp.maximum(16, N_int)).astype(int))
base += jnp.abs(base % 2 - 1)  # make it odd

n_xi = base
n_eta = base

@partial(jax.jit, static_argnames=['n'])
def simpson_weights(n):
    w = jnp.ones(n)
    w = w.at[1:-1:2].set(4.0)
    w = w.at[2:-1:2].set(2.0)
    w *= (1.0 / (n - 1)) / 3.0   # h = 1/(n-1), scale by h/3
    return w

wxi  = simpson_weights(int(n_xi))
weta = simpson_weights(int(n_eta))

xi  = jnp.linspace(0.0, 1.0, n_xi)
eta = jnp.linspace(0.0, 1.0, n_eta)
XI, ETA = jnp.meshgrid(xi, eta, indexing="ij")

@jax.jit
def _xieta_to_Rz_jacobian(xi, eta):
    Rmin_map = R[1]
    Rmax_map = Rmax
    zmin_map = Z_nonneg[1]
    zmax_map = Zmax

    # Precompute logs
    LR = jnp.log(1.0 + Rmax_map / Rmin_map)
    LZ = jnp.log(1.0 + zmax_map / zmin_map)

    # Map to physical coordinates
    pR = jnp.power(1.0 + Rmax_map / Rmin_map, xi)
    pZ = jnp.power(1.0 + zmax_map / zmin_map, eta)
    Rp = Rmin_map * (pR - 1.0)
    zp = zmin_map * (pZ - 1.0)

    # Jacobian part from the coordinate transform (no 2πR' here)
    dR_dxi  = LR * (Rmin_map + Rp)
    dz_deta = LZ * (zmin_map + zp)
    J = dR_dxi * dz_deta
    return Rp, zp, J

Rp, zp, Jmap = _xieta_to_Rz_jacobian(XI, ETA) 
W2D = jnp.einsum('i,j->ij', wxi, weta)

m_list = jnp.array(M)

In [67]:
# @jax.jit
def m_wrapper(m, R0, z0, Rp, zp):
    rho_grid = jax_rho_m_eval(m.astype(int), Rp, zp, _rho_m_real_interp, _rho_m_imag_interp)

    return jax.vmap(R_wrapper, in_axes=(None, 0, None, None, None, None))(m, R0, z0, Rp, zp, rho_grid)

@jax.jit
def R_wrapper(m, R0, z0, Rp, zp, rho_grid):
    return jax.vmap(Z_wrapper, in_axes=(None, None, 0, None, None, None))(m, R0, z0, Rp, zp, rho_grid)

@jax.jit
def Z_wrapper(m, R0, z0, Rp, zp, rho_grid):
    Xi_plus  = jax_kernel_Xi_m(m, R0, z0, Rp, zp)
    Xi_minus = jax_kernel_Xi_m(m, R0, z0, Rp, -zp)
    Xi_sum   = Xi_plus + Xi_minus

    F = rho_grid * Xi_sum * (2.0 * np.pi) * Rp * Jmap

    I = np.sum(W2D * F)

    return -G * I

phi_m = jax.vmap(m_wrapper, in_axes=(0, None, None, None, None))(M.astype(int), R, Z_nonneg, Rp, zp)

TypeError: cannot use 'jax._src.interpreters.batching.BatchTracer' as a dict key (unhashable type: 'BatchTracer')

In [40]:
_phi_m_real_interp = {-1: R, -2: Z_nonneg}
_phi_m_imag_interp = {-1: R, -2: Z_nonneg}

i = -1
for m in M.astype(int):
    i+=1
    M_x, M_y = jax_precompute_splines((R, Z_nonneg), phi_m[m].real)
    _phi_m_real_interp[i] = (phi_m[m].real, M_x, M_y)

    M_x, M_y = jax_precompute_splines((R, Z_nonneg), phi_m[0].imag)
    _phi_m_imag_interp[i] = (phi_m[m].imag, M_x, M_y)

In [None]:
phi_m = jax.vmap(m_wrapper, in_axes=(0, None, None, None, None))(M.astype(int), R, Z_nonneg, Rp, zp)

In [None]:
def potential(self, R: ArrayLike, z: ArrayLike, phi: ArrayLike) -> Array:
    Rb = np.asarray(R, float); zb = np.asarray(z, float); ph = np.asarray(phi, float)
    Rb, zb, ph = np.broadcast_arrays(Rb, zb, ph)
    # m=0 term
    out = self.phi_m_eval(0, Rb, zb).real
    # m>=1 terms: Φ += 2 * (Re Φ_m * cos mφ  - Im Φ_m * sin mφ)
    for m in range(1, self.mmax+1):
        Phi_m = self.phi_m_eval(m, Rb, zb)
        out += 2.0 * (Phi_m.real * np.cos(m*ph) - Phi_m.imag * np.sin(m*ph))
    return out
