In [130]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

from potential_cylspline import *
from constants import *

In [2]:
from constants import *

# Echo the two constants used below, one per line
# (presumably both are provided by the star import above — verify).
for _const in (EPSILON, X_THRESHOLD0):
    print(_const)

1e-12
[0.72  0.72  0.8   0.8   0.83  0.86  0.85  0.88  0.88  0.88  0.885 0.9
 0.91 ]


In [90]:
import time as tt

# Time one first-element access on each hypergeometric table.
# perf_counter() is a monotonic, high-resolution interval clock; time.time()
# is wall-clock time and is a poor fit for sub-millisecond measurements.
# Each line is labelled with the table it measured so the two results can be
# told apart (both were previously printed under the same label).
# NOTE(review): if these tables are JAX arrays the access may be dispatched
# asynchronously — confirm whether a block_until_ready() is needed here.
start = tt.perf_counter()
HYPERGEOM_0[0]
end = tt.perf_counter()
print("HYPERGEOM_0[0] time (ms):", 1000*(end - start))

start = tt.perf_counter()
HYPERGEOM_1[0]
end = tt.perf_counter()
print("HYPERGEOM_1[0] time (ms):", 1000*(end - start))

Numpy time: 0.5922317504882812
Numpy time: 0.07796287536621094


In [None]:
@jax.jit
def rho_xyz(x, y, z):
    """Density at Cartesian (x, y, z).

    Returns exp(-r / 5) * exp(-|z|) with r = sqrt(x**2 + y**2), computed
    elementwise on the inputs.
    """
    radial_profile = jnp.exp(-jnp.sqrt(x**2 + y**2) / 5.0)
    vertical_profile = jnp.exp(-jnp.abs(z))
    return radial_profile * vertical_profile

@jax.jit
def cylindrical_to_cartesian(R, phi, z):
    """Convert cylindrical (R, phi, z) to Cartesian (x, y, z).

    x = R cos(phi), y = R sin(phi); z is passed through unchanged.
    Note the argument order is (R, phi, z).
    """
    return R * jnp.cos(phi), R * jnp.sin(phi), z

@jax.jit
def rho_Rzphi(R, z, phi):
    """Density at cylindrical (R, z, phi), evaluated at |z| (even in z).

    The argument order here is (R, z, phi), whereas
    cylindrical_to_cartesian expects (R, phi, z).
    """
    # Fold z onto the non-negative half before converting to Cartesian.
    cartesian = cylindrical_to_cartesian(R, phi, jnp.abs(z))
    return rho_xyz(*cartesian)

@jax.jit
def rho_last(R, z, m, phi, Nphi=200):
    """One azimuthal sample of the Fourier integrand for harmonic ``m``.

    Returns rho_Rzphi(R, z, phi) * exp(-i m phi) * dphi / (2 pi), where
    dphi = 2 pi / Nphi. ``Nphi`` defaults to 200; the weight is only the
    intended one when ``Nphi`` matches the number of phi samples the caller
    sums over.
    """
    weight = (2*jnp.pi) / Nphi
    phase = jnp.exp(-1j * m * phi)
    density = rho_Rzphi(R, z, phi)
    return density * phase * weight / (2.0 * jnp.pi)

@jax.jit
def rho_phiZRm(R, z, m, phi):
    """Sum the per-angle terms from ``rho_last`` over every sample in ``phi``.

    Parameters
    ----------
    R, z : passed unchanged to ``rho_last`` for every angle.
    m : harmonic index, passed unchanged to ``rho_last``.
    phi : 1-D array of azimuthal sample points; the sum runs over its
        leading axis.

    Returns
    -------
    The sum over phi of ``rho_last(R, z, m, phi_k, Nphi)``.
    """
    # Pass the actual number of samples so the dphi weight inside rho_last
    # matches this grid, instead of relying on its default Nphi=200 (which
    # silently mis-normalises the sum for any other grid size).
    n_phi = phi.shape[0]
    per_angle = jax.vmap(rho_last, in_axes=(None, None, None, 0, None))(
        R, z, m, phi, n_phi
    )
    return jnp.sum(per_angle, axis=0)

@jax.jit
def compute_rho_m(R, z, m, phi):
    """Evaluate ``rho_phiZRm`` for every entry of ``m``.

    Only ``m`` is mapped over; R, z and phi are shared across harmonics.
    The leading axis of the result indexes ``m``.
    """
    over_harmonics = jax.vmap(rho_phiZRm, in_axes=(None, None, 0, None))
    return over_harmonics(R, z, m, phi)

In [177]:
# Grid resolution and extents.
NR, NZ = 30, 30
Rmin, Rmax = 1e-3, 20.
Zmin, Zmax = 1e-3, 10.
Mmax = 1.   # a float, so M below is a float array
Nphi = 200

# Harmonic indices 0 .. Mmax.
M = jnp.arange(0, Mmax + 1)

# Log-spaced radial nodes; the lower edge is clamped to at least 1e-3.
R = jnp.geomspace(jnp.maximum(Rmin, 1e-3), Rmax, NR)
R0_eff = R[NR // 2]

# Log-spaced z > 0 nodes with z = 0 prepended.
Zpos = jnp.geomspace(jnp.maximum(Zmin, 1e-3), Zmax, NZ)
Z_nonneg = jnp.concatenate([jnp.array([0.0]), Zpos])

Rg, Zg = jnp.meshgrid(R, Z_nonneg, indexing="ij")

# Uniform azimuthal samples on [0, 2*pi) and their spacing.
phi = jnp.linspace(0.0, 2*jnp.pi, Nphi, endpoint=False)
dphi = (2*jnp.pi) / Nphi

# Map over the R axis of the mesh, then move the harmonic axis to the front
# so rho_m is indexed as (m, R, z).
_rho_m_on_grid = jax.vmap(compute_rho_m, in_axes=(0, 0, None, None))
rho_m = _rho_m_on_grid(Rg, Zg, M, phi).transpose(1,0,2)

In [215]:
# `interpax` here is only a local alias for jax.scipy.interpolate.
import jax.scipy.interpolate as interpax

# Linear interpolant of the rho_m[1] slice over the (R, Z_nonneg) grid.
_rho_m_interp = interpax.RegularGridInterpolator(
    points=(R, Z_nonneg),
    values=rho_m[1],
    method="linear",
)

In [240]:
### Now compute the potential components

# Total quadrature budget; each axis gets roughly sqrt(N_int) nodes.
N_int = 10_000

# Per-axis node count: truncated sqrt(max(16, N_int)), but never below 9.
base = jnp.sqrt(jnp.maximum(16, N_int)).astype(int)
base = jnp.maximum(9, base)
# Even counts are bumped up by one so the count is odd.
base = base + jnp.abs(base % 2 - 1)

n_xi = n_eta = base

@partial(jax.jit, static_argnames=['n'])
def simpson_weights(n):
    """Composite Simpson weights for ``n`` equally spaced nodes on [0, 1].

    Parameters
    ----------
    n : int
        Number of nodes (static under jit). Must be odd and at least 3,
        because the 1-4-2-...-4-1 pattern needs an even number of intervals.

    Returns
    -------
    Array of shape (n,) holding the weights 1, 4, 2, 4, ..., 4, 1 scaled by
    h / 3 with h = 1 / (n - 1).

    Raises
    ------
    ValueError
        If ``n`` is even or smaller than 3. Even ``n`` previously returned
        weights that are not a valid Simpson rule, and ``n == 1`` divided by
        zero.
    """
    # n is a static Python integer, so this check runs at trace time.
    if n < 3 or n % 2 == 0:
        raise ValueError(f"simpson_weights requires an odd n >= 3, got {n}")
    w = jnp.ones(n)
    w = w.at[1:-1:2].set(4.0)   # odd interior nodes
    w = w.at[2:-1:2].set(2.0)   # even interior nodes
    w *= (1.0 / (n - 1)) / 3.0   # h = 1/(n-1), scale by h/3
    return w

# 1-D Simpson weights per axis; int() turns the sizes into the plain Python
# integers that simpson_weights takes as its static argument.
wxi, weta = simpson_weights(int(n_xi)), simpson_weights(int(n_eta))

# Uniform nodes on [0, 1] along each axis, and their tensor-product grid.
xi, eta = jnp.linspace(0.0, 1.0, n_xi), jnp.linspace(0.0, 1.0, n_eta)
XI, ETA = jnp.meshgrid(xi, eta, indexing="ij")

@jax.jit
def _xieta_to_Rz_jacobian(xi, eta):
    """Map unit-interval coordinates (xi, eta) to (R', z') plus the Jacobian.

    The mapping is
        R' = R[1]        * ((1 + Rmax / R[1])**xi         - 1)
        z' = Z_nonneg[1] * ((1 + Zmax / Z_nonneg[1])**eta - 1)
    so xi = 0 gives R' = 0 and xi = 1 gives R' = Rmax (likewise for z').

    Returns (Rp, zp, J) with J = dR'/dxi * dz'/deta. The 2πR' factor is not
    included here.

    R, Rmax, Z_nonneg and Zmax are read from the notebook globals.
    NOTE(review): under jax.jit these are presumably fixed at trace time —
    re-run this cell after changing the grid.
    """
    R_scale, z_scale = R[1], Z_nonneg[1]

    # Bases of the exponential stretch along each axis.
    R_ratio = 1.0 + Rmax / R_scale
    z_ratio = 1.0 + Zmax / z_scale

    # Physical coordinates.
    Rp = R_scale * (jnp.power(R_ratio, xi) - 1.0)
    zp = z_scale * (jnp.power(z_ratio, eta) - 1.0)

    # Derivatives of the mapping: d/dxi [s * (b**xi - 1)] = log(b) * (s + R').
    dR_dxi = jnp.log(R_ratio) * (R_scale + Rp)
    dz_deta = jnp.log(z_ratio) * (z_scale + zp)
    return Rp, zp, dR_dxi * dz_deta

# Evaluate the (xi, eta) -> (R', z') mapping and its Jacobian on the full
# tensor-product grid.
Rp, zp, Jmap = _xieta_to_Rz_jacobian(XI, ETA) 
# Outer product of the 1-D weights: W2D[i, j] = wxi[i] * weta[j].
W2D = jnp.einsum('i,j->ij', wxi, weta)


In [232]:
# Scratch check: show n_xi next to its int() conversion.
n_xi, int(n_xi)

(Array(101, dtype=int32), 101)