In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import gamma
from imripy import merger_system as ms
from imripy import halo
from imripy import inspiral

In [None]:
# Basic spike and merger system properties
# (masses are converted with ms.solar_mass_to_pc, i.e. expressed in pc)
g_spike = 7./3.                        # power-law index of the spike
m1 = 1e3*ms.solar_mass_to_pc           # central mass; the potential below is m1/r
m2 = 1*ms.solar_mass_to_pc             # secondary mass
rho_spike = 226.*ms.solar_mass_to_pc   # spike density normalisation
r_spike = ((3 - g_spike) * m1 / (2 * np.pi * rho_spike) * 0.2**(3.-g_spike))**(1./3.)
D = 1e6                                # passed to SystemProp; presumably a distance in pc — confirm

Eps_grid = np.geomspace(1e-13, 1e1, 1000)

# Power-law phase space distribution f(Eps) = f_norm * Eps**(g_spike - 3/2).
# The factors are kept in the original left-to-right order so the floats match.
f_norm = (rho_spike * g_spike*(g_spike-1.)/(2.*np.pi)**(3./2.)
          * (r_spike/m1)**g_spike
          * gamma(g_spike-1.)/gamma(g_spike-1./2.))
f_grid = f_norm * Eps_grid**(g_spike-3./2.)


def potential(r):
    """Point-mass potential m1/r of the central mass."""
    return m1/r


dh = halo.DynamicSS(Eps_grid, f_grid, potential)
sp = ms.SystemProp(m1, m2, dh, D)

In [None]:
# Inspiral properties and refined grid
R0 = 50. * sp.r_isco()                            # initial separation
R_fin = 45. * sp.r_isco()                         # final separation passed to the evolutions
r_grid = np.geomspace(sp.r_isco(), 50*R0, 1000)   # radii used for the density plots

# Energy grid: the coarse log grid plus 2000 extra points between
# 0.1*(m1/R0 - v_orb**2/2) and 1.2*m1/R0, with v_orb = omega_s(R0)*R0.
# Rebuilt from scratch so re-running this cell does not keep appending points.
v_orb_R0 = sp.omega_s(R0)*R0
Eps_fine = np.geomspace(1e-1 * (sp.m1/R0 - v_orb_R0**2 / 2.), 1.2 * sp.m1/R0, 2000)
Eps_grid = np.sort(np.concatenate([np.geomspace(1e-13, 1e1, 1000), Eps_fine]))

sp.halo.Eps_grid = Eps_grid
sp.halo.update_Eps()

# Re-evaluate the initial power-law distribution on the refined grid
f_prefactor = (rho_spike * g_spike*(g_spike-1.)/(2.*np.pi)**(3./2.)
               * (r_spike/m1)**g_spike
               * gamma(g_spike-1.)/gamma(g_spike-1./2.))
sp.halo.f_grid = f_prefactor * Eps_grid**(g_spike-3./2.)

haloModel = inspiral.HaloFeedback(sp)

In [None]:
# Evolve the phase space distribution with the first method
# (Evolve_HFK from R0 to R_fin, with adjust_stepsize enabled)
hfk_kwargs = {"R_fin": R_fin, "adjust_stepsize": True}
ev = haloModel.Evolve_HFK(R0, **hfk_kwargs)
print(len(ev.t), ev.t, ev.R)

In [None]:
# Reset the phase space distribution to its initial power law and evolve with
# the second method (the first evolution presumably left dh.f_grid modified).
# Eps_grid is the refined grid built in the grid setup cell.
f_prefactor = (rho_spike * g_spike*(g_spike-1.)/(2.*np.pi)**(3./2.)
               * (r_spike/m1)**g_spike
               * gamma(g_spike-1.)/gamma(g_spike-1./2.))
f_grid = f_prefactor * Eps_grid**(g_spike-3./2.)
dh.f_grid = f_grid

ev_2 = haloModel.Evolve(R0, R_fin=R_fin)
print(len(ev_2.t), ev_2.t)

In [None]:
from matplotlib.animation import FuncAnimation

# One frame per snapshot step of the longer evolution (update_plot advances
# that evolution by one snapshot per frame).
n_frame = max(len(ev.t), len(ev_2.t)) - 1
print(n_frame)

fig, (ax_rho, ax_f) = plt.subplots(2, 1, figsize=(20,20))

# Current snapshot indices into ev / ev_2; advanced by update_plot.
index1 = 0; index2 = 0

v_0 = sp.omega_s(R0)*R0                # orbital speed at the initial separation R0
Tini_orb =  2.*np.pi / sp.omega_s(R0)  # initial orbital period
lr0 = ax_rho.axvline(R0/sp.r_isco(), linestyle='-.', label='$r_0$', color='black')

# Orbital speed at every radius of r_grid, used as v_max for the dashed
# rho_{v<v_orb} curves (computed once, shared by both evolutions).
v_orb_grid = [sp.omega_s(r)*r for r in r_grid]

# --- first method (Evolve_HFK): initial snapshot ---
dh.f_grid = ev.f[0,:]
lrho, = ax_rho.loglog(r_grid/sp.r_isco(), dh.density(r_grid), label=r'$\rho$')
lrho_v0, = ax_rho.loglog(r_grid/sp.r_isco(), dh.density(r_grid, v_max=v_orb_grid),
                         color=lrho.get_c(), linestyle='--', label=r'$\rho_{v<v_{orb}}$')

lr = ax_rho.axvline(ev.R[0]/sp.r_isco(), linestyle='-.', color=lrho.get_c(), label='$r$')
lf, = ax_f.loglog(dh.Eps_grid, dh.f_grid, label="$f$", color=lrho.get_c())

# Change of f over one initial orbit. Raw string so "\D" is not parsed as an
# (invalid) escape sequence.
ldf, = ax_f.loglog(dh.Eps_grid, np.abs(haloModel.dfHalo_dt(ev.R[0], v_cut=v_0)*Tini_orb),
                   linestyle='--', color=lrho.get_c(), label=r"$|\Delta f|$")
# m1/r is an energy (the potential at r), so its marker belongs on the
# epsilon axis of ax_f, not on ax_rho whose x-axis is r/r_isco.
lmr = ax_f.axvline(sp.m1/ev.R[0], linestyle='-.', color=lrho.get_c(), label='$m1/r_1$')
l1 = [lrho, lrho_v0, lf, ldf, lr, lmr]

# --- second method (Evolve): initial snapshot ---
dh.f_grid = ev_2.f[0,:]
lrho2, = ax_rho.loglog(r_grid/sp.r_isco(), dh.density(r_grid), label=r'$\rho_2$')
lrho2_v0, = ax_rho.loglog(r_grid/sp.r_isco(), dh.density(r_grid, v_max=v_orb_grid),
                          linestyle='--', color=lrho2.get_c(), label=r'$\rho_{2,v<v_{orb}}$')

lr2 = ax_rho.axvline(ev_2.R[0]/sp.r_isco(), linestyle='-.', color=lrho2.get_c(), label='$r_2$')

lf2, = ax_f.loglog(dh.Eps_grid, dh.f_grid, color=lrho2.get_c(), label="$f_2$")
ldf2, = ax_f.loglog(dh.Eps_grid, np.abs(haloModel.dfHalo_dt(ev_2.R[0], v_cut=v_0)*Tini_orb),
                    linestyle='--', color=lrho2.get_c(), label=r"$|\Delta f_2|$")
# Energy marker: see the comment on lmr above.
lmr2 = ax_f.axvline(sp.m1/ev_2.R[0], linestyle='-.', color=lrho2.get_c(), label='$m1/r_2$')
l2 = [lrho2, lrho2_v0, lf2, ldf2, lr2, lmr2]


ax_rho.set_ylabel(r'$\rho$ / $pc^{-2}$', fontsize=20); ax_rho.set_xlabel(r'$r$ / $r_{isco}$', fontsize=20); ax_rho.grid()
ax_f.set_ylabel(r'$f$ / $pc^{-2}$',fontsize=20); ax_f.set_xlabel(r'$\epsilon$', fontsize=20); ax_f.grid()
ax_rho.set_xlim((r_grid[0]/sp.r_isco(), r_grid[-1]/sp.r_isco()))
fig.legend(fontsize=20, loc='center right')

def _draw_snapshot(lines, evolution, index, pass_t_scale):
    """Redraw the artists of one evolution for its snapshot ``index``.

    Parameters
    ----------
    lines : list
        ``[rho, rho_{v<v_orb}, f, |Delta f|, r-marker, m1/r-marker]`` artists
        of this evolution (``l1`` or ``l2``).
    evolution : object
        Result of ``haloModel.Evolve_HFK`` / ``haloModel.Evolve``; its ``t``,
        ``R`` and ``f`` arrays are indexed by snapshot.
    index : int
        Snapshot to display.
    pass_t_scale : bool
        Whether ``dfHalo_dt`` receives ``t_scale=delta_t``.

    Side effect: sets ``dh.f_grid`` to the displayed snapshot, because the
    density and ``dfHalo_dt`` evaluations below work on the halo's current
    distribution (presumably ``haloModel`` reads ``dh`` through ``sp.halo``).
    """
    lrho_i, lrho_v0_i, lf_i, ldf_i, lr_i, lmr_i = lines
    R_i = evolution.R[index]

    dh.f_grid = evolution.f[index, :]
    x_r = r_grid/sp.r_isco()
    lrho_i.set_data(x_r, dh.density(r_grid))
    lrho_v0_i.set_data(x_r, dh.density(r_grid, v_max=[sp.omega_s(r)*r for r in r_grid]))
    # axvline artists hold two x values; Line2D.set_xdata needs a sequence
    # (passing a scalar is deprecated since matplotlib 3.7).
    lr_i.set_xdata([R_i/sp.r_isco()] * 2)
    lf_i.set_data(dh.Eps_grid, dh.f_grid)

    # Delta f over the step to the next snapshot; the last snapshot has none.
    if index < len(evolution.t) - 1:
        delta_t = evolution.t[index+1] - evolution.t[index]
        v_orb = sp.omega_s(R_i)*R_i
        # NOTE(review): t_scale is passed only for the Evolve_HFK curve;
        # confirm whether the Evolve curve should receive it as well.
        if pass_t_scale:
            df_dt = haloModel.dfHalo_dt(R_i, v_cut=v_orb, t_scale=delta_t)
        else:
            df_dt = haloModel.dfHalo_dt(R_i, v_cut=v_orb)
        ldf_i.set_data(dh.Eps_grid, np.abs(df_dt*delta_t))

    lmr_i.set_xdata([sp.m1/R_i] * 2)


def init_plot():
    """FuncAnimation init function: show the first snapshot of both evolutions.

    matplotlib runs the init function before saving and again whenever a
    repeating animation restarts. Resetting the global snapshot indices here
    lets the animation be replayed (``repeat=True``, or ``plt.show()`` after
    ``ani.save``) without ``update_plot`` indexing past the end of
    ``ev.t`` / ``ev_2.t``.

    Returns
    -------
    list
        All animated artists, as required for ``blit=True``.
    """
    global index1, index2
    index1 = 0
    index2 = 0
    _draw_snapshot(l1, ev, 0, pass_t_scale=True)
    _draw_snapshot(l2, ev_2, 0, pass_t_scale=False)
    return l1 + l2


def update_plot(frame):
    """FuncAnimation callback: advance to the next frame.

    The evolution with more snapshots advances by one snapshot per frame.
    The other advances by at most one snapshot, and only once the leading
    evolution's time has passed its current time.

    Parameters
    ----------
    frame : int
        Frame number supplied by FuncAnimation. It is unused; progress is
        tracked in the globals ``index1`` / ``index2``.

    Returns
    -------
    list
        Artists that changed, as required for ``blit=True``.
    """
    global index1, index2
    updt1 = updt2 = False
    if len(ev.t) > len(ev_2.t):
        index1 += 1; updt1 = True
        if ev.t[index1] > ev_2.t[index2] and index2 < len(ev_2.t) - 1:
            index2 += 1; updt2 = True
    else:
        index2 += 1; updt2 = True
        if ev_2.t[index2] > ev.t[index1] and index1 < len(ev.t) - 1:
            index1 += 1; updt1 = True

    ax_rho.set_title(f"t_1={ev.t[index1]/ms.year_to_pc : .4f} yrs, t_2={ev_2.t[index2]/ms.year_to_pc : .4f} yrs")

    if updt1:
        _draw_snapshot(l1, ev, index1, pass_t_scale=True)
    if updt2:
        _draw_snapshot(l2, ev_2, index2, pass_t_scale=False)

    if not updt1:
        return l2
    if not updt2:
        return l1
    return l1 + l2


# Build the comparison animation, write it to disk, then display it
ani = FuncAnimation(
    fig,
    update_plot,
    frames=n_frame,
    init_func=init_plot,
    blit=True,
    interval=100,   # milliseconds between frames
    repeat=True,
)
ani.save("HaloFeedbackMethodComparison.mp4")
plt.show()