# In [None]:
"""
Differentiable Schrödinger-Poisson + Collisionless Stars (JAX)
-------------------------------------------------------------
• Frames 0–99  : evolve with random stars (warm-up)
• Frame  100   : swap to edge-on stellar disk
• Frames 101–299: evolve with the disk

Same physics & visualization as your working script, rewritten in the
documented, sectioned style of the Mocz-style demo.
"""

# ======================================================================
# 0. Imports & housekeeping
# ======================================================================
import jax, jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import HTML

# ----------------------------------------------------------------------
# 1. Unit system & global parameters (input)
# ----------------------------------------------------------------------
# Grid: number of cells per axis, and box edge lengths [kpc]
nx, ny, nz = 128, 64, 16
Lx, Ly, Lz = 128.0, 64.0, 16.0

rho_bar = 1.0e4   # mean density [Msun/kpc³]
t_end   = 30.0    # stop time [kpc/(km/s) ≈ Gyr]
m_22    = 0.8     # axion mass [10⁻²² eV]

# Star particles: count, and total mass = 10 % of the box mass at mean density [Msun]
n_s = 400
M_s = 0.1 * rho_bar * Lx * Ly * Lz

# ----------------------------------------------------------------------
# 2. Physical constants
# ----------------------------------------------------------------------
G          = 4.30241002e-6               # gravitational constant [kpc (km/s)² / Msun]
hbar       = 1.71818134e-87              # reduced Planck constant [Msun · kpc · km/s]
ev_to_msun = 8.96215334e-67              # mass of 1 eV/c² expressed in Msun
m          = m_22 * 1e-22 * ev_to_msun   # axion mass [Msun]
m_per_hbar = m / hbar                    # m/ħ, the factor used in every phase exponent
m_s        = M_s / n_s                   # mass of a single star particle [Msun]

# ----------------------------------------------------------------------
# 3. Mesh & Fourier space
# ----------------------------------------------------------------------
# Real-space grid: cell-centred sample points along each axis
dx, dy, dz = Lx/nx, Ly/ny, Lz/nz
x, y, z = (jnp.linspace(0.5*d, L-0.5*d, n)
           for d, L, n in ((dx, Lx, nx), (dy, Ly, ny), (dz, Lz, nz)))
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")


def _angular_wavenumbers(n, L, axis):
    """Return 2π/L · [-n/2, …, n/2-1] in FFT order, shaped to broadcast along `axis` of a 3-D grid."""
    shape = [1, 1, 1]
    shape[axis] = n
    k = 2*jnp.pi/L * jnp.arange(-n/2, n/2)
    return jnp.fft.ifftshift(k.reshape(shape))


# Fourier grid: broadcastable (nx,1,1), (1,ny,1), (1,1,nz) wavenumber arrays
kx = _angular_wavenumbers(nx, Lx, 0)
ky = _angular_wavenumbers(ny, Ly, 1)
kz = _angular_wavenumbers(nz, Lz, 2)
k_sq = kx**2 + ky**2 + kz**2

# ----------------------------------------------------------------------
# 4. Time step setup
# ----------------------------------------------------------------------
# Largest step allowed by the kinetic term:
#   dt_kin = dt_fac · (m/ħ)/6 · (dx·dy·dz)^(2/3)
dt_fac   = 1.0
dt_kin   = dt_fac * m_per_hbar / 6 * (dx*dy*dz)**(2/3)

# The animation below renders 300 frames (100 warm-up + 200 with the disk),
# each advancing nt_sub sub-steps. The step count must therefore be split over
# that many frames; otherwise 300·nt_sub·dt overshoots t_end.
n_frames = 300                                  # keep equal to FuncAnimation(frames=...)

# Round nt up to a multiple of n_frames so every frame gets the same number of
# sub-steps. dt = t_end/nt is then <= dt_kin and the last frame ends at t_end.
nt       = n_frames * int(np.ceil(t_end / dt_kin / n_frames))
nt_sub   = nt // n_frames                       # sub-steps per frame
dt       = t_end / nt

# pack some constants for easy passing (optional)
params = dict(nx=nx, ny=ny, nz=nz, Lx=Lx, Ly=Ly, Lz=Lz,
              dx=dx, dy=dy, dz=dz, m_per_hbar=m_per_hbar,
              G=G, rho_bar=rho_bar, k_sq=k_sq, kx=kx, ky=ky, kz=kz)

# ======================================================================
# 5. Helper kernels
# ======================================================================

def potential(rho):
    """Gravitational potential of `rho` from an FFT Poisson solve on the periodic mesh.

    Solves ∇²V = 4πG (rho - rho_bar), i.e. V̂ = -4πG (rho - rho_bar)^ / k².
    At k = 0 the denominator is replaced by 1 to avoid dividing by zero. That
    mode then only adds a spatially uniform offset to V.
    """
    source_hat = jnp.fft.fftn(4*jnp.pi*G*(rho-rho_bar))
    denom = k_sq + (k_sq == 0)
    V_hat = -source_hat / denom
    return jnp.real(jnp.fft.ifftn(V_hat))


def cic_indices_weights(pos):
    """Cloud-in-cell stencil (neighbouring cell indices and weights) for each particle.

    Cell centres sit at (k + 1/2)·d. On every axis a particle therefore lies
    between the centres with lower index `i` and upper index `ip1` = i + 1.

    Returns
    -------
    i, ip1 : integer arrays shaped like `pos`, wrapped periodically into [0, n)
    w0, w1 : weights of the `i` and `ip1` cells (w0 = 1 - w1 per axis)
    """
    spacing = jnp.array([dx, dy, dz])
    n_cells = jnp.array([nx, ny, nz])
    cell_coord = (pos - 0.5*spacing) / spacing
    lower = jnp.floor(cell_coord).astype(int)
    # weights use the unwrapped index; only the returned indices are wrapped
    w1 = cell_coord - lower
    w0 = 1 - w1
    return jnp.mod(lower, n_cells), jnp.mod(lower + 1, n_cells), w0, w1


def deposit_stars(pos):
    """Cloud-in-cell (CIC) deposit of the star particles onto the density mesh.

    Each star of mass m_s is shared among the 8 cells around it, using the
    trilinear weights from `cic_indices_weights`. Dividing by the cell volume
    converts the deposited mass to density.

    Parameters
    ----------
    pos : (N, 3) array of star positions [kpc]; any N works, including 0.

    Returns
    -------
    (nx, ny, nz) array of stellar mass density [Msun/kpc³].
    """
    i, ip1, w0, w1 = cic_indices_weights(pos)
    fac = m_s/(dx*dy*dz)                 # density of one whole star placed in one cell
    rho = jnp.zeros((nx, ny, nz))
    corners = ((i, w0), (ip1, w1))
    # One vectorised scatter per stencil corner. JAX's .at[].add applies every
    # update, so stars sharing a cell accumulate instead of overwriting each other.
    for ix, wx in corners:
        for iy, wy in corners:
            for iz, wz in corners:
                rho = rho.at[ix[:, 0], iy[:, 1], iz[:, 2]].add(wx[:, 0]*wy[:, 1]*wz[:, 2]*fac)
    return rho


def star_acceleration(pos, rho):
    """Gravitational acceleration at each star, CIC-interpolated from the mesh.

    The acceleration field a = -∇V is computed spectrally from the same Poisson
    solution as `potential` (V̂ = -4πG (rho - rho_bar)^ / k², â = -i k V̂). It is
    then gathered to the particles with the cloud-in-cell weights also used for
    deposition.

    Parameters
    ----------
    pos : (N, 3) array of star positions [kpc]
    rho : total density on the (nx, ny, nz) mesh [Msun/kpc³]

    Returns
    -------
    (N, 3) array of accelerations [(km/s)² / kpc].
    """
    i, ip1, w0, w1 = cic_indices_weights(pos)
    V_hat = -jnp.fft.fftn(4*jnp.pi*G*(rho-rho_bar)) / (k_sq + (k_sq==0))
    # -∂V/∂x_j for each axis, stacked on the last axis -> (nx, ny, nz, 3)
    grid = jnp.stack([-jnp.real(jnp.fft.ifftn(1j*k*V_hat)) for k in (kx, ky, kz)], -1)
    # Sized from pos rather than the global n_s, so any number of stars works.
    a = jnp.zeros_like(pos)
    corners = ((i, w0), (ip1, w1))
    for ix, wx in corners:
        for iy, wy in corners:
            for iz, wz in corners:
                w = (wx[:, 0]*wy[:, 1]*wz[:, 2])[:, None]
                a += w * grid[ix[:, 0], iy[:, 1], iz[:, 2]]
    return a

# ----------------------------------------------------------------------
# 6. Symplectic sub-step (kick-drift-kick)
# ----------------------------------------------------------------------
@jax.jit
def substep(psi,pos,vel):
    """One kick-drift-kick step of length dt for the coupled axion field and stars.

    half kick : psi gains the phase exp(-i (m/ħ) V dt/2) and stars gain a·dt/2.
                Both use the total density |psi|² + deposited stars.
    drift     : free Schrödinger evolution of psi in Fourier space; stars move
                by vel·dt and are wrapped periodically into the box.
    half kick : as above, with the density recomputed after the drift.
    """
    def half_kick(psi_in, pos_in, vel_in):
        rho_tot = jnp.abs(psi_in)**2 + deposit_stars(pos_in)
        V = potential(rho_tot)
        psi_out = jnp.exp(-1j*m_per_hbar*dt/2*V)*psi_in
        vel_out = vel_in + star_acceleration(pos_in, rho_tot)*dt/2
        return psi_out, vel_out

    psi, vel = half_kick(psi, pos, vel)

    kinetic_phase = jnp.exp(-1j*dt*k_sq/(2*m_per_hbar))
    psi = jnp.fft.ifftn(jnp.fft.fftn(psi) * kinetic_phase)
    pos = jnp.mod(pos + vel*dt, jnp.array([Lx,Ly,Lz]))

    psi, vel = half_kick(psi, pos, vel)
    return psi, pos, vel

# ======================================================================
# 7. Initial conditions
# ======================================================================
# Initial wavefunction: uniform floor of 10 plus Gaussian bumps in x-y (constant
# along z), rescaled so the mean density equals rho_bar.
amp, sigma = 100.0, 4.0
bump_centres = [(0.5,0.4), (0.6,0.5), (0.4,0.6), (0.6,0.4), (0.6,0.6)]  # fractions of (Lx, Ly)
rho = sum((amp * jnp.exp(-((X-cx*Lx)**2 + (Y-cy*Ly)**2)/(2*sigma**2))
           for cx, cy in bump_centres), 10.0)
rho *= rho_bar/jnp.mean(rho)
psi0 = jnp.sqrt(rho) + 0j

# star samplers ---------------------------------------------------------

def sample_random():
    """Place n_s stars uniformly at random in the box, all initially at rest.

    Uses a fixed seed (0), so the realisation is reproducible.
    Returns (positions, velocities) as JAX arrays of shape (n_s, 3).
    """
    box = np.array([Lx, Ly, Lz])
    unit_cube = np.random.default_rng(0).uniform(0, 1, (n_s, 3))
    positions = unit_cube * box
    velocities = np.zeros_like(positions)
    return jnp.array(positions), jnp.array(velocities)


def sample_edge_disk():
    """Sample a disk + bulge star population centred in the box (seed 42).

    * int(0.8·n_s) disk stars
      - lie in the x-z plane, so the disk is edge-on in the x-y projection;
      - cylindrical radius R is drawn from an exponential distribution (scale 5 kpc);
      - thickness along y is Gaussian with σ = 1 kpc;
      - they move on circular orbits at the Plummer circular speed for mass
        M_s and scale 5 kpc.
    * The remaining stars form a Plummer bulge (scale 3 kpc) with isotropic
      directions, initially at rest.

    Positions are not wrapped into the box here; `substep` wraps them with
    jnp.mod after each drift.

    Returns
    -------
    (pos, vel) : JAX arrays of shape (n_s, 3) [kpc, km/s].
    """
    rng = np.random.default_rng(42)
    c = np.array([Lx/2, Ly/2, Lz/2])
    nd = int(n_s*0.8)
    pos, vel = [], []

    # --- disk stars
    for _ in range(nd):
        R = rng.exponential(5.0)
        phi = rng.uniform(0, 2*np.pi)
        px = R*np.cos(phi) + c[0]
        pz = R*np.sin(phi) + c[2]
        py = c[1] + rng.normal(scale=1.0)
        v_c = 0 if R == 0 else np.sqrt(G*M_s*R**2/(R**2+5**2)**1.5)
        pos.append((px, py, pz))
        vel.append((-v_c*np.sin(phi), 0, v_c*np.cos(phi)))   # tangential in x-z

    # --- bulge stars: Plummer radius by inverting M(<r)/M = r³/(r²+a²)^(3/2)
    for _ in range(n_s - nd):
        u = rng.random()
        r = 3/np.sqrt(u**(-2/3) - 1)
        # Isotropic directions need cos(theta) uniform on [-1, 1]; drawing
        # theta uniformly on [0, pi] would over-populate the poles.
        cos_th = rng.uniform(-1, 1)
        sin_th = np.sqrt(1 - cos_th**2)
        ph = rng.uniform(0, 2*np.pi)
        px = r*sin_th*np.cos(ph) + c[0]
        py = r*sin_th*np.sin(ph) + c[1]
        pz = r*cos_th + c[2]
        pos.append((px, py, pz))
        vel.append((0, 0, 0))
    return jnp.array(pos), jnp.array(vel)

# Draw both star populations up front: random stars for the warm-up,
# the disk sample for frame 100 onwards.
(pos_rand, vel_rand), (pos_disk, vel_disk) = sample_random(), sample_edge_disk()

# ======================================================================
# 8. Matplotlib setup
# ======================================================================
fig, ax = plt.subplots(figsize=(10,8), dpi=150)
# colorbar axis appended to the right of the main axes
div = make_axes_locatable(ax)
cax = div.append_axes('right', '5%', pad=0.05)

# Initial field: log10 of the z-averaged axion density |psi0|², transposed so
# x runs horizontally. The extent puts both axes in kpc.
rho_proj = jnp.log10(jnp.mean(jnp.abs(psi0)**2, axis=2)).T
im = ax.imshow(rho_proj, cmap='viridis', origin='lower', extent=(0,Lx,0,Ly))
# Stars are plotted in kpc, the same data coordinates as the image extent.
# Scaling by nx/Lx would give cell units, which only match kpc while dx = dy = 1.
sc = ax.scatter(pos_rand[:,0], pos_rand[:,1], s=4, c='cyan', edgecolors='white')
cb = fig.colorbar(im, cax=cax)
cb.set_label('log10 ρ', rotation=270, labelpad=15)
ax.set_xlabel('x [kpc]')
ax.set_ylabel('y [kpc]')
ax.set_aspect('equal')

# ======================================================================
# 9. Animation driver (frames 0-299)
# ======================================================================
state = (psi0, pos_rand, vel_rand)


def _log_projected_density(psi):
    """log10 of |psi|² averaged along z, transposed so rows are y and columns are x (imshow layout)."""
    return jnp.log10(jnp.mean(jnp.abs(psi)**2, axis=2)).T


def init_animation():
    """Reset the simulation state and the artists to the initial condition.

    Passed as FuncAnimation's init_func. Without one, FuncAnimation draws
    its initial frame by calling `animate` on the first frame. Because
    `animate` advances the global `state`, that would run extra sub-steps
    before the recorded frames start.
    """
    global state
    state = (psi0, pos_rand, vel_rand)
    rho_p = _log_projected_density(psi0)
    im.set_data(rho_p)
    im.set_clim(float(rho_p.min()), float(rho_p.max()))
    sc.set_offsets(np.asarray(pos_rand[:, :2]))
    return im, sc


def animate(frame):
    """Advance the simulation by one frame (nt_sub sub-steps) and refresh the artists.

    At frame 100 the random stars are replaced by the disk sample; the
    wavefunction carries on unchanged.
    """
    global state
    psi, pos, vel = state
    if frame == 100:
        pos, vel = pos_disk, vel_disk
    for _ in range(nt_sub):
        psi, pos, vel = substep(psi, pos, vel)
    state = (psi, pos, vel)

    # update image & stars
    rho_p = _log_projected_density(psi)
    im.set_data(rho_p)
    im.set_clim(float(rho_p.min()), float(rho_p.max()))
    # Positions are in kpc, matching the image extent (0..Lx, 0..Ly).
    sc.set_offsets(np.asarray(pos[:, :2]))
    return im, sc


anim = animation.FuncAnimation(fig, animate, frames=300, init_func=init_animation,
                               interval=100, blit=True)
# The size of embedded HTML animations is capped by rcParams['animation.embed_limit']
# (in MB). Matplotlib drops frames beyond the cap with a warning, and 300 frames at
# this figure size/dpi may exceed the default, so raise it just for this export.
with plt.rc_context({'animation.embed_limit': 200}):
    anim_html = HTML(anim.to_jshtml())
anim_html