In [None]:
"""
Schrödinger–Poisson (JAX) + Globular-Cluster Orbits
---------------------------------------------------
• Frames 0–49  : warm-up with random tracer stars
• Frame  50    : swap in 6 compact globular clusters
• Frames 50–299: follow their COM orbits, leaving trails

Everything else (the solver and the density rendering) is
unchanged from the edge-on-disk version.
"""

# ──────────────────────────────────────────────────────────────
# 0. Imports & constants
# ──────────────────────────────────────────────────────────────
import jax, jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from IPython.display import HTML

# grid & box
# A periodic nx × ny × nz mesh over [0,Lx)×[0,Ly)×[0,Lz) (particles are
# wrapped with jnp.mod in substep()). x, y, z are cell-centre coordinates;
# X, Y, Z are the matching 3-D coordinate arrays of shape (nx, ny, nz).
nx, ny, nz = 128, 64, 16
Lx, Ly, Lz = 128.0, 64.0, 16.0
dx, dy, dz = Lx/nx, Ly/ny, Lz/nz
x = jnp.linspace(0.5*dx, Lx-0.5*dx, nx)
y = jnp.linspace(0.5*dy, Ly-0.5*dy, ny)
z = jnp.linspace(0.5*dz, Lz-0.5*dz, nz)
X, Y, Z = jnp.meshgrid(x, y, z, indexing="ij")

# physical params
# NOTE(review): the magnitudes suggest G and hbar are expressed in
# kpc / (km/s) / M_sun units and m_22 is the boson mass in units of
# 1e-22 eV (8.96215334e-67 converting eV to the code mass unit). None of
# this is stated in the file; confirm before changing any of these.
rho_bar = 1e4          # mean density: psi0 is normalised to it, pot()/accel() subtract it
t_end   = 30.0         # time split into nt steps below (see the time-step note)
m_22    = 0.8
G       = 4.30241002e-6
hbar    = 1.71818134e-87
m       = m_22*1e-22*8.96215334e-67
m_h     = m / hbar     # m/hbar, the coefficient used in substep()'s phase factors

# stars / clusters
n_s = 400
M_s = 0.05 * rho_bar * Lx*Ly*Lz   # total stellar mass: 5% of the mean mass in the box (smaller than the disk run)
m_s = M_s / n_s                   # mass of one star particle
N_gc = 6                          # number of globular clusters

# Fourier grid
# Angular wavenumbers 2*pi*n/L, reordered into FFT layout by ifftshift and
# shaped to broadcast along their own axis; k_sq is |k|^2 on the full grid.
kx = jnp.fft.ifftshift(2*jnp.pi/Lx * jnp.arange(-nx/2, nx/2)[:,None,None])
ky = jnp.fft.ifftshift(2*jnp.pi/Ly * jnp.arange(-ny/2, ny/2)[None,:,None])
kz = jnp.fft.ifftshift(2*jnp.pi/Lz * jnp.arange(-nz/2, nz/2)[None,None,:])
k_sq = kx**2 + ky**2 + kz**2

# time step
# dt_kin is the step-size limit (presumably from the kinetic term, per its
# name). nt is rounded up to a multiple of 200 so that each of 200 frames
# advances exactly nt_sub substeps, and dt = t_end/nt never exceeds dt_kin.
# NOTE(review): the FuncAnimation call at the end of the file renders 300
# frames of nt_sub substeps each, i.e. 1.5*t_end of simulated time. Confirm
# whether 200 here or 300 there is intended.
dt_fac = 1.0
dt_kin = dt_fac*m_h/6*(dx*dy*dz)**(2/3)
nt     = int(jnp.ceil(jnp.ceil(t_end/dt_kin)/200)*200)
nt_sub = nt // 200
dt     = t_end / nt

# ──────────────────────────────────────────────────────────────
# 1. FFT helpers
# ──────────────────────────────────────────────────────────────
def pot(rho):
    """Solve the periodic Poisson equation  lap V = 4 pi G (rho - rho_bar).

    The solve is spectral: V_hat = -F[4 pi G (rho - rho_bar)] / |k|^2. The
    k = 0 denominator is replaced by 1 to avoid dividing by zero, so any
    residual mean of rho - rho_bar only adds a constant to V.

    Returns the real-space potential on the same grid as ``rho``.
    """
    source = 4*jnp.pi*G*(rho - rho_bar)
    safe_k_sq = k_sq + (k_sq == 0)
    V_hat = -jnp.fft.fftn(source) / safe_k_sq
    return jnp.fft.ifftn(V_hat).real

def cic_weights(pos):
    """Cloud-in-cell indices and weights for particles at ``pos``.

    ``pos`` is indexed per row with one column per axis (x, y, z). Positions
    are measured relative to cell centres, so each particle lies between a
    lower cell ``lo`` and the next cell ``hi``, both wrapped periodically.

    Returns (lo, hi, w_lo, w_hi): integer cell indices and the matching
    linear weights along each axis, with w_lo + w_hi == 1.
    """
    spacing = jnp.array([dx, dy, dz])
    n_cells = jnp.array([nx, ny, nz])
    s = (pos - 0.5*spacing)/spacing        # position in cell-centre units
    base = jnp.floor(s).astype(int)
    w_hi = s - base                        # weight of the upper neighbour
    w_lo = 1 - w_hi
    lo = jnp.mod(base, n_cells)
    hi = jnp.mod(lo + 1, n_cells)
    return lo, hi, w_lo, w_hi

def deposit(pos):
    """Cloud-in-cell mass deposit of the star particles onto the mesh.

    Every row of ``pos`` is one star of mass ``m_s``. Its mass is shared among
    the 8 cells whose centres bracket it, using the weights from
    cic_weights(), and divided by the cell volume. The result is a density
    grid of shape (nx, ny, nz).

    The scatter is vectorised over all particles. ``.at[...].add`` applies
    every update even when several particles land in the same cell, so no
    per-particle loop is needed. The deposit also covers every row of
    ``pos`` rather than assuming exactly ``n_s`` particles.
    """
    i, ip1, w0, w1 = cic_weights(pos)
    fac = m_s/(dx*dy*dz)                  # star mass per unit cell volume
    rho = jnp.zeros((nx, ny, nz))
    for ix, wx in ((i, w0), (ip1, w1)):
        for iy, wy in ((i, w0), (ip1, w1)):
            for iz, wz in ((i, w0), (ip1, w1)):
                # One scatter-add per CIC corner; duplicate indices accumulate.
                rho = rho.at[ix[:, 0], iy[:, 1], iz[:, 2]].add(
                    wx[:, 0]*wy[:, 1]*wz[:, 2]*fac)
    return rho

def accel(pos,rho):
    """Gravitational acceleration -grad V interpolated to each particle.

    V is the same spectral Poisson solution as in pot(). Its gradient is
    taken in Fourier space (multiplying V_hat by i*k), and the three grid
    components are read back at the particle positions with the CIC weights
    from cic_weights() (the same weights used by deposit()).

    Parameters
    ----------
    pos : particle positions, one row per particle, columns (x, y, z).
    rho : density grid on the (nx, ny, nz) mesh.

    Returns
    -------
    Array of shape (len(pos), 3) with one acceleration vector per particle.
    """
    i, ip1, w0, w1 = cic_weights(pos)
    Vh = -jnp.fft.fftn(4*jnp.pi*G*(rho-rho_bar))/(k_sq+(k_sq==0))
    # gx/gy/gz (not ax/ay/az) to avoid shadowing the module-level axes `ax`.
    gx = -jnp.real(jnp.fft.ifftn(1j*kx*Vh))
    gy = -jnp.real(jnp.fft.ifftn(1j*ky*Vh))
    gz = -jnp.real(jnp.fft.ifftn(1j*kz*Vh))
    grid = jnp.stack((gx, gy, gz), -1)
    # Size the output from pos rather than the global n_s, so any number of
    # particles can be passed (consistent with deposit()).
    a = jnp.zeros((pos.shape[0], 3))
    for ix, wx in ((i, w0), (ip1, w1)):
        for iy, wy in ((i, w0), (ip1, w1)):
            for iz, wz in ((i, w0), (ip1, w1)):
                a += (wx[:, 0]*wy[:, 1]*wz[:, 2])[:, None]*grid[ix[:, 0], iy[:, 1], iz[:, 2]]
    return a

def _half_kick(psi, pos, vel):
    """Half-step potential kick for psi and half-step velocity kick for stars.

    Both kicks use the combined density |psi|^2 plus the deposited stars.
    """
    total_rho = jnp.abs(psi)**2 + deposit(pos)
    psi = jnp.exp(-1j*m_h*dt/2*pot(total_rho))*psi
    vel = vel + accel(pos, total_rho)*dt/2
    return psi, vel


@jax.jit
def substep(psi,pos,vel):
    """Advance the wave function and the stars by one step of length ``dt``.

    Kick-drift-kick: a half kick, then a full drift (the kinetic operator
    applied to psi in Fourier space, and a periodic-wrapped position update
    for the stars), then a second half kick at the new positions.

    Returns the updated (psi, pos, vel).
    """
    psi, vel = _half_kick(psi, pos, vel)
    psi = jnp.fft.ifftn(jnp.fft.fftn(psi)*jnp.exp(-1j*dt*k_sq/(2*m_h)))
    box = jnp.array([Lx, Ly, Lz])
    pos = jnp.mod(pos + vel*dt, box)
    psi, vel = _half_kick(psi, pos, vel)
    return psi, pos, vel

# ──────────────────────────────────────────────────────────────
# 2. Wave-function initial condition
# ──────────────────────────────────────────────────────────────
# Background level 10 plus five Gaussian over-densities (columns along z)
# centred at the fractional (x, y) positions below. The sum is then rescaled
# so its mean equals rho_bar, and psi0 = sqrt(rho) with zero phase.
amp, sigma = 100.0, 4.0
_blob_centres = [(0.5, 0.4), (0.6, 0.5), (0.4, 0.6), (0.6, 0.4), (0.6, 0.6)]
rho = sum(
    (amp*jnp.exp(-((X - fx*Lx)**2 + (Y - fy*Ly)**2)/(2*sigma**2))
     for fx, fy in _blob_centres),
    10.0,
)
rho *= rho_bar/jnp.mean(rho)
psi0 = jnp.sqrt(rho) + 0j

# ──────────────────────────────────────────────────────────────
# 3. Star / cluster samplers
# ──────────────────────────────────────────────────────────────
def sample_random():
    """Scatter ``n_s`` tracer stars uniformly over the box, initially at rest.

    Uses a fixed seed (0), so the draw is reproducible.

    Returns (pos, vel, cid): positions and zero velocities of shape
    (n_s, 3), and int32 cluster ids that are all -1 (no cluster).
    """
    generator = np.random.default_rng(0)
    unit_draws = generator.uniform(0, 1, (n_s, 3))
    positions = unit_draws*np.array([Lx, Ly, Lz])
    velocities = jnp.zeros_like(positions)
    labels = jnp.full(n_s, -1, dtype=jnp.int32)
    return jnp.array(positions), velocities, labels

def sample_globulars(N_gc=6,r_core=0.8):
    """Distribute the ``n_s`` stars among ``N_gc`` Plummer spheres on a ring.

    Cluster centres lie on a circle of radius 20 around the box centre, in
    the mid-plane z = Lz/2. Stars are split as evenly as possible: the first
    ``n_s % N_gc`` clusters get one extra star. Each star's radius comes from
    inverting the Plummer cumulative mass profile with scale radius
    ``r_core``, in an isotropic random direction. Positions are not wrapped
    into the box here. The seed is fixed (123), so the draw is reproducible.

    Returns
    -------
    pos : (n_s, 3) star positions.
    vel : (n_s, 3) zero velocities (floating point, like sample_random()).
    cid : (n_s,) int32 index of each star's cluster, 0 .. N_gc-1.
    """
    rng = np.random.default_rng(123)
    box_mid = np.array([Lx/2, Ly/2, Lz/2])
    theta = np.linspace(0, 2*np.pi, N_gc, endpoint=False)
    centres = np.column_stack([box_mid[0] + 20*np.cos(theta),
                               box_mid[1] + 20*np.sin(theta),
                               np.full(N_gc, box_mid[2])])
    base, extra = divmod(n_s, N_gc)
    pos, cid = [], []
    for k, c in enumerate(centres):
        n_here = base + (1 if k < extra else 0)
        for _ in range(n_here):
            u = rng.random()
            r = r_core/np.sqrt(u**(-2/3) - 1)   # inverse Plummer mass profile
            phi = rng.uniform(0, 2*np.pi)
            cost = rng.uniform(-1, 1)
            sint = np.sqrt(1 - cost**2)
            pos.append(c + r*np.array([sint*np.cos(phi), sint*np.sin(phi), cost]))
            cid.append(k)
    # Float zeros give vel the same floating dtype as sample_random()'s
    # velocities. An integer array would force the jitted substep() to be
    # re-traced for a new input dtype when the clusters are swapped in.
    vel = np.zeros((len(pos), 3))
    return jnp.array(pos), jnp.array(vel), jnp.array(cid, dtype=jnp.int32)

# Draw both particle sets up front: random tracers for the warm-up, and the
# globular-cluster stars that animate() swaps in at frame 50.
# (The COM-trail buffers are created together with the rest of the
# animation state in section 5, right before animate() is defined.)
pos_rand, vel_rand, cid_rand = sample_random()
pos_gc, vel_gc, cid_gc = sample_globulars(N_gc)

# ──────────────────────────────────────────────────────────────
# 4. Matplotlib figure   (cluster-coloured stars)
# ──────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(9, 7), dpi=150)
div = make_axes_locatable(ax)
cax = div.append_axes('right', '5%', pad=0.05)   # axis for the colour bar

# Initial log-density averaged along z (axis 2), transposed so that x runs
# horizontally and y vertically in the image.
rho_proj = jnp.log10((jnp.abs(psi0)**2).mean(axis=2)).T
im = ax.imshow(rho_proj, cmap='viridis', origin='lower',
               extent=(0, Lx, 0, Ly))

# One tab10 colour per cluster, used by both its star scatter and its
# centre-of-mass trail.
colors = plt.cm.tab10(range(N_gc))
scatters = [ax.scatter([], [], s=22, c=[col], edgecolors='white', zorder=3)
            for col in colors]
lines = [ax.plot([], [], lw=1.3, color=col, zorder=2)[0]
         for col in colors]

cb = fig.colorbar(im, cax=cax)
cb.set_label('log10 ρ', rotation=270, labelpad=15)
ax.set_xlabel('x [kpc]')
ax.set_ylabel('y [kpc]')
ax.set_aspect('equal')


# ──────────────────────────────────────────────────────────────
# 5. Animation driver
# ──────────────────────────────────────────────────────────────
# Simulation state read and replaced by animate():
# (wave function, star positions, star velocities, cluster ids).
state = (psi0, pos_rand, vel_rand, cid_rand)
# Per-cluster centre-of-mass trail coordinates, appended to by animate().
history_x = [list() for _ in range(N_gc)]
history_y = [list() for _ in range(N_gc)]

def periodic_mean(points, box):
    """Per-axis circular mean of ``points`` in a periodic box.

    ``points`` has shape (n, d) and ``box`` holds the d box lengths. Each
    coordinate is mapped to an angle 2*pi*x/L, the unit vectors are averaged,
    and the mean angle is mapped back into the box. Unlike an arithmetic
    mean, this handles groups that straddle a periodic boundary: x = 1 and
    x = L-1 average to 0, not to L/2.
    """
    box = np.asarray(box, dtype=float)
    theta = 2*np.pi*np.asarray(points, dtype=float)/box
    mean_angle = np.arctan2(np.sin(theta).mean(axis=0),
                            np.cos(theta).mean(axis=0))
    return np.mod(mean_angle/(2*np.pi)*box, box)


def animate(frame):
    """Advance the simulation by one frame and update all artists.

    Frame 0 restarts from the initial condition. FuncAnimation calls the
    callback for frame 0 during its initial draw as well as when rendering,
    and without the reset those extra calls would advance the physics and
    leave stale points in the trails. At frame 50 the random tracer stars are
    replaced by the globular-cluster stars.

    Each call integrates ``nt_sub`` substeps, then redraws:
    - the z-averaged log density,
    - each cluster's stars (in the same physical units as the image extent),
    - from frame 50 on, each cluster's periodic centre-of-mass trail.

    Returns the tuple of updated artists required by ``blit=True``.
    """
    global state
    if frame == 0:
        state = (psi0, pos_rand, vel_rand, cid_rand)
        for trail in (*history_x, *history_y):
            trail.clear()
    psi, pos, vel, cid = state

    # swap to globular clusters at frame 50
    if frame == 50:
        pos, vel, cid = pos_gc, vel_gc, cid_gc

    # integrate physics
    for _ in range(nt_sub):
        psi, pos, vel = substep(psi, pos, vel)
    state = (psi, pos, vel, cid)

    # update density image
    rho_p = jnp.log10(jnp.mean(jnp.abs(psi)**2, axis=2)).T
    im.set_data(rho_p)
    im.set_clim(float(rho_p.min()), float(rho_p.max()))

    # Copy particle data to the host once rather than once per cluster.
    pos_np = np.asarray(pos)
    cid_np = np.asarray(cid)
    box = np.array([Lx, Ly])

    # Tracer stars carry cid == -1, so no cluster artist draws them.
    for k in range(N_gc):
        xy = pos_np[cid_np == k, :2]
        # imshow uses extent=(0, Lx, 0, Ly), so plot in physical units,
        # the same units as the trails.
        scatters[k].set_offsets(xy)
        if frame < 50 or len(xy) == 0:
            continue
        cx, cy = periodic_mean(xy, box)
        hx, hy = history_x[k], history_y[k]
        # Break the line with a NaN gap when the COM wraps across a box
        # edge, so the trail is not drawn as a streak through the whole box.
        if hx and (abs(cx - hx[-1]) > Lx/2 or abs(cy - hy[-1]) > Ly/2):
            hx.append(np.nan)
            hy.append(np.nan)
        hx.append(float(cx))
        hy.append(float(cy))
        lines[k].set_data(hx, hy)

    return (im, *scatters, *lines)


# Render all 300 frames into an HTML/JavaScript player for inline display.
anim = animation.FuncAnimation(fig, animate, frames=300, interval=100, blit=True)
html = anim.to_jshtml()
# Close the figure so the inline backend does not also show a redundant
# static copy of it below the animation player.
plt.close(fig)
HTML(html)