# In [2]:

"""
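    Simulate the inspiral of a binary (primary mass m1, secondary mass m2) with optional
    dark-matter dynamical friction, post-Newtonian corrections and GW emission; compute
    the resulting GW strain and its spectrum, and compare the characteristic strain with
    the LISA noise curve.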
"""



#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Packages Used ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#------------------------To solve the ODE----------------------------
using DifferentialEquations
#using LSODA

#------------------------For Plotting---------------------------------
using Plots; 
#pyplot()
# using GRUtils
# using GR

#------------------------For Calling Python---------------------------
using PyCall
np = pyimport("numpy")
using SciPy

#------------------------For rendering LaTeX in Plots-----------------
using LaTeXStrings

# #------------------------Writing Data to CSV file---------------------
# using CSV
# using Tables

#-----------------------For Benchmarking the ODE solver--------------
# using BenchmarkTools
# using Logging: global_logger
# using TerminalLoggers: TerminalLogger
# global_logger(TerminalLogger()) 



#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Defining the System ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


#---------------------------function to calculate r_isco------------------------

"""
    r_isco(m)

Return the radius of the innermost stable circular orbit (ISCO) of a
Schwarzschild black hole of mass `m` (geometric units), i.e. `6m`, as a
floating-point number.
"""
function r_isco(m)
    return 6.0 * float(m)
end


#---------------------------Initial Conditions-----------------------------------

"""
    initial_conditions(mass, a0, e0, phi0 = 0.)

Return the initial state `[r0, phi0, dr0, dphi0]` of a Keplerian orbit with
semi-major axis `a0` and eccentricity `e0`, starting at angle `phi0`.

`mass[3]` is used as the total mass `M`. With `p = a0 (1 - e0^2)`:

- `r0    = p / (1 + e0 cos(phi0))`
- `dphi0 = sqrt(M p) / r0^2`
- `dr0   = p e0 sin(phi0) / (1 + e0 cos(phi0))^2 * dphi0`  (time derivative of `r(phi)`)

# Throws
- `ArgumentError` if `e0` is not in `[0, 1)` (the orbit would not be a bound ellipse).
"""
function initial_conditions(mass, a0, e0, phi0 = 0.)
    0 <= e0 < 1 || throw(ArgumentError("eccentricity e0 must satisfy 0 <= e0 < 1, got $e0"))
    r0 = a0 * (1. - e0^2) / (1. + e0 * cos(phi0))
    dphi0 = sqrt(mass[3] * a0 * (1. - e0^2)) / r0^2
    # dr/dt = (dr/dphi) * (dphi/dt) for r(phi) = p / (1 + e cos(phi))
    dr0 = a0* (1. - e0^2) / (1 + e0*cos(phi0))^2 * e0 *sin(phi0)*dphi0
    return [r0, phi0, dr0, dphi0]
end

#-------------------------Halo Models--------------------------------------------

"""
    Spike_Halo(r, rho_spike, r_spike, alpha, r_min=0.)

Power-law dark-matter spike density profile from Eda et al.
[https://arxiv.org/abs/1408.3534]:

    rho(r) = rho_spike * (r_spike / r)^alpha    for r > r_min
    rho(r) = 0                                  otherwise

# Arguments
- `r`: radius; scalars and arrays are both accepted (evaluated element-wise).
- `rho_spike`: density normalization of the spike profile.
- `r_spike`: scale radius of the spike profile.
- `alpha`: power-law index of the spike, expected to satisfy `0 < alpha < 3`.
- `r_min`: inner cutoff; the density is zero for `r <= r_min` (default `0.`).

Returns a scalar for scalar `r`, otherwise an array shaped like `r`.
"""
function Spike_Halo(r, rho_spike, r_spike, alpha, r_min=0.)   # Spike Model
    # `ifelse` evaluates both branches; at r = 0 the power-law branch is Inf,
    # but the cutoff branch (0.) is the one selected.
    return ifelse.(r .> r_min, rho_spike .* (r_spike ./ r) .^ alpha, 0.)
end


"""
    Const_Halo(r, rho_0, r_min=0.)

Constant dark-matter density `rho_0` for `r > r_min`, and zero otherwise.
`r` may be a scalar or an array (evaluated element-wise).
"""
function Const_Halo(r, rho_0, r_min=0.)    # Constant Density Model
    return ifelse.(r .> r_min, rho_0, 0.)
end



"""
    rho_effective(r, Mbh)

Fitted effective dark-matter spike density at radius `r` around a black hole
of mass `Mbh`. Both `r` and `Mbh` are taken in pc; per the unit bookkeeping
below the result is in 1/pc^2.

    rho = rho_bar(x) * 10^delta * (rho0/0.3)^alpha * (M/1e6)^beta * (a/20)^gamma
    rho_bar(x) = A * max(1 - 4 eta / x, 0)^w * (4.17e11 / x)^q,   x = r / Mbh

`rho0` and `a` are the halo parameters quoted from the Eda paper, and `M` is
`Mbh` converted to solar masses.

`r` may be a scalar or an array (evaluated element-wise). The factor
`1 - 4 eta / x` is clamped at zero, so the density vanishes for
`r <= 4 eta Mbh` instead of raising a `DomainError` from a negative base
raised to the non-integer power `w`.
"""
function rho_effective(r, Mbh)
    # scale params
    a = 23.1                  # [pc] from eda paper
    rho0 = 3.8*1e-22          # [g/cm^3] from eda paper
    rho0 *= 5.60958*1e+23     # [g/cm^3 to GeV/cm^3]

    # profile params
    alpha = 0.331
    beta = -1.66
    gamma = 0.32
    delta = -0.000282

    # fit params
    eta = 1
    A = 6.42
    w = 1.82
    q = 1.91

    # convert units so rho_bar has units 1/pc^2 and x_tilde is unitless
    Mbh1 = Mbh / (4.8e-14)                # [pc to Msol]
    A *= (10^-43) * (4.367e26)            # [Msol^-2 to pc^-2: 1/2.29e-27]
    x_tilde = r ./ Mbh

    # `1 .-` (not `1 -`) so array-valued r broadcasts; clamping at 0 keeps the
    # non-integer power `w` away from negative bases.
    inner_factor = max.(1 .- (4. * eta ./ x_tilde), 0.0)
    # 4.17 * 1e11 equals the original 4.17 * 10^11 without Int arithmetic
    rho_bar = A .* (inner_factor .^ w) .* (((4.17 * 1e11) ./ x_tilde) .^ q)

    rho = rho_bar .* (10^delta) .* ((rho0/0.3)^alpha) .* ((Mbh1/1e6)^beta) .* ((a/20)^gamma)
    return rho
end


#---------------------------System Parameters-----------------------------------
# Masses and lengths below are in pc (geometric units).
m1, m2 = 4.8e-11, 4.8e-14          # central mass (passed as Mbh to the halo model) and companion
D  = 1e5                           # observer distance; enters the strain as 1/D
a0 = 100.0 * r_isco(m1)            # initial semi-major axis: 100 ISCO radii of m1
e0 = 0.                            # initial eccentricity

#---------------------------Mass Array------------------------------------------
# Layout used throughout the file: mass = [m1, m2, m_total, mu]
m_total = m1 + m2
mu = (m1 * m2) / (m1 + m2)         # reduced mass
mass = [m1, m2, m_total, mu]

#--------------------Halo Parameters from Eda paper-----------------------------
rho_spike = 1.0848e-11             # [226.*solar_mass_to_pc]
alpha_spike = 2.544
r_spike = 0.54

# Keplerian starting state [r, phi, dr, dphi]
u0 = initial_conditions(mass, a0, e0)



#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Evolving the System ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#-------------------------Evolution Parameters--------------------------------------------

F0 = sqrt(mass[3] / a0^3) / 2 / pi   # Keplerian orbital frequency at a0
nOrbits = 1000                        # number of orbits to integrate
tp_per_orb = 100                      # saved time points per orbit
dp = nOrbits * tp_per_orb             # total number of saved time points
t_end = nOrbits / F0                  # integration end time


#-------------------------Evolution Function-----------------------------------------------

"""
    Evolve(mass, u0, t_end, t_start=0., gwEmission=true, dynamicalFriction=true,
           PostNewtonian=true, coulombLog=3., nSteps=1000)

Integrate the relative orbit of the binary from `t_start` to `t_end`, starting
from the state `u0 = [r, phi, dr, dphi]`. Return the DifferentialEquations
solution saved at `nSteps` equally spaced times.

# Arguments
- `mass`: `[m1, m2, m_total, mu]`. `m1` is the mass passed to `rho_effective`
  and also sets the termination radius.
- `u0`: initial state `[r, phi, dr, dphi]`.
- `t_end`, `t_start`: integration interval.
- `gwEmission`: include the GW-emission terms.
- `dynamicalFriction`: include the dynamical-friction drag built from
  `rho_effective(r, m1)`.
- `PostNewtonian`: include the post-Newtonian correction terms.
- `coulombLog`: Coulomb logarithm used in the dynamical-friction term.
- `nSteps`: number of equally spaced output times.

Integration stops early once `r < 7 * r_isco(mass[1])`. The solver call is
wrapped in `@time`, so timing information is printed.
"""
function Evolve(mass, u0, t_end, t_start=0., gwEmission=true, dynamicalFriction=true, PostNewtonian=true, coulombLog=3., nSteps=1000)
    t_eval = LinRange(t_start, t_end, nSteps)
    eta    = mass[4]/mass[3]        # mu / m_total
    m      = mass
    tspan  = (t_start,t_end)
    # Termination radius from the `mass` argument (not the global `m1`), so the
    # function honours whatever masses it is called with.
    r_stop = 7*r_isco(mass[1])

    #---------------------------Derivative Function---------------------------------------

    function orbit!(du, u, m, t)

        # state:  r=u[1],  phi=u[2],  dr=u[3],  dphi=u[4]
        # params: m = [m1,m2,m_total,mu]

        v = sqrt(u[3]^2 + (u[1]^2)*(u[4])^2)    # orbital speed

        # drag coefficient from dynamical friction; multiplies dr and dphi below
        P_df     = dynamicalFriction ? 4*pi*(m[2]^2 / m[4]) * (rho_effective(u[1], m[1])) * (coulombLog/v^3) : 0.

        # post-Newtonian corrections to the radial and angular equations
        P_pn_r   = PostNewtonian ? (m[3]/u[1]^2)*( (4 + 2*eta)*(m[3]/u[1]) - (3*eta +1)*(u[1]^2)*(u[4]^2) + (3 - 7*eta/2)*u[3]^2) : 0.

        P_pn_phi = PostNewtonian ? (m[3]/u[1]^2)*(4 - 2*eta)*u[3]*u[4] : 0.

        # GW-emission coefficients; multiply dr and dphi below
        P_gw_r   = gwEmission ? ( (8/5)*m[4]*(m[3]/u[1]^3)*(2*v^2 + (8/3)*(m[3]/u[1]) )) : 0.

        P_gw_phi = gwEmission ? ( (8/5)*m[4]*(m[3]/u[1]^3)*(v^2 + 3*(m[3]/u[1]) )) : 0.


        du[1] = u[3]  # dr
        du[2] = u[4]  # dphi
        du[3] = -(-P_gw_r + P_df)*(u[3]) + P_pn_r - (m[3]/u[1]^2) + u[1]*(u[4]^2)  # ddr
        du[4] = -(P_gw_phi + P_df)*(u[4]) + P_pn_phi - (2*u[3]*u[4]/u[1])          # ddphi
    end

    #---------------------Call-back functions---------------------------------------------

    # DiscreteCallback: the condition is checked after each integrator step
    function terminate_condition(u,t,integrator)
        u[1] < r_stop
    end

    function terminate_affect!(integrator)
        terminate!(integrator)
    end

    terminate_cb = DiscreteCallback(terminate_condition, terminate_affect!)


    #-------------------------Calling the solver--------------------------------------------

    # Callbacks do not work with lsoda(); they do work with the other algorithms.
    # Alternatives tried: lsoda(), AN5(), VCABM5(), Tsit5(), DP5(), AutoVern7(Rodas5()),
    # Vern9(lazy=false), Feagin12(), Vern7().
    prob = ODEProblem(orbit!, u0, tspan, m, callback=terminate_cb)  # exclude the callback part when using lsoda()
    alg = VCABM()
    @time sol = solve(prob, alg, abstol=1e-14, reltol=1e-12, saveat=t_eval, dense=false, maxiters=Int(1e9))    # For predefined time-steps
    #@time sol = solve(prob, alg, abstol=1e-14, reltol=1e-12, dense=true, maxiters=Int(1e9))                   # For adaptive time-steps
    return sol
end


# Run 1: every effect enabled (GW emission, dynamical friction, post-Newtonian)
solution_all = Evolve(mass, u0, t_end, 0., true, true, true, 3., dp)
sol_trans_all = solution_all'
r1, phi1, dr1, dphi1 = (sol_trans_all[:, k] for k in 1:4)     # columns: r, phi, dr, dphi
t1 = solution_all.t

# Run 2: identical, but without dark matter (dynamical friction disabled)
solution_woDM = Evolve(mass, u0, t_end, 0., true, false, true, 3., dp)
sol_trans_woDM = solution_woDM'
r2, phi2, dr2, dphi2 = (sol_trans_woDM[:, k] for k in 1:4)
t2 = solution_woDM.t

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Calculating the GW strain~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


#--------------------------------Strain Function------------------------------------------------

"""
    strain(mass, D, r, phi, t)

Return the GW polarizations `[h_plus, h_cross]` of the orbit `(r, phi)` sampled
at times `t`, from the quadrupole formula for an observer at distance `D`.

The observer angles are fixed inside the function (`theta_o = 0`, `phi_o = 0`).
The quadrupole `Q_ij = mu * x_i * x_j` (`mu = mass[4]`, orbit in the z = 0
plane) is differentiated twice with `np.gradient`. This uses the spacing
`t[2] - t[1]`, so `t` is assumed uniformly sampled.

Each polarization is an array with one entry per sample.
"""
function strain(mass, D, r, phi, t)
    #---------Observer parameters--------------
    mu = mass[4]
    theta_o = 0    # inclination_angle
    phi_o = 0      # pericenter_angle

    #---------Scaling--------------------------
    # Differentiate r/r0 (O(1) values) and restore the r0^2 factor afterwards,
    # since Q is quadratic in r. r0 comes from the data itself rather than the
    # global `u0`; for an orbit saved from its initial state they coincide.
    r0 = first(r)
    r_scaled = r ./ r0
    T = t[2]-t[1]

    #---------Rotating body parameters---------
    x = r_scaled .* cos.(phi)
    y = r_scaled .* sin.(phi)
    z = zeros(length(x))

    #--------Quadrupole Moment Tensor----------
    Q  = [ [mu*x.*x, mu*x.*y, mu*x.*z],
           [mu*y.*x, mu*y.*y, mu*y.*z],
           [mu*z.*x, mu*z.*y, mu*z.*z] ]

    #--------Derivative Functions---------------

    # first time derivative of every component (numpy.gradient, uniform step T)
    function Mdt(Q)
        dQdt = [ [np.gradient(Q[1][1], T), np.gradient(Q[1][2], T), np.gradient(Q[1][3], T)],
                 [np.gradient(Q[2][1], T), np.gradient(Q[2][2], T), np.gradient(Q[2][3], T)],
                 [np.gradient(Q[3][1], T), np.gradient(Q[3][2], T), np.gradient(Q[3][3], T)] ]

        return dQdt
    end

    # second time derivative, with the r0^2 scaling undone
    function Mdt2(Q)
        dQ2dt2 = Mdt(Mdt(Q))

        return dQ2dt2*(r0^2)
    end


    d2Qd2t = Mdt2(Q)


    h_plus =  (1.0/D) * (   d2Qd2t[1][1].*(cos.(phi_o).^2 - sin.(phi_o).^2 .*cos.(theta_o).^2)
                          + d2Qd2t[2][2].*(sin.(phi_o).^2 - cos.(phi_o).^2 .*cos.(theta_o).^2)
                          - d2Qd2t[3][3].*(sin.(theta_o).^2)
                          - d2Qd2t[1][2].*(sin.(2*phi_o).*(1.0 .+ cos.(theta_o).^2))
                          + d2Qd2t[1][3].*(sin.(phi_o).*sin.(2*theta_o))
                          + d2Qd2t[2][3].*(cos.(phi_o).*sin.(2*theta_o))     )

    h_cross = (1.0/D) * (   (d2Qd2t[1][1]-d2Qd2t[2][2]).*sin.(2*phi_o).*cos.(theta_o)
                             + 2*d2Qd2t[1][2].*cos.(2*phi_o).*cos.(theta_o)
                             - 2*d2Qd2t[1][3].*cos.(phi_o).*sin.(theta_o)
                             + 2*d2Qd2t[2][3].*sin.(theta_o).*sin.(phi_o)    )

    return [h_plus, h_cross]

end


"""
    strain2d(mass, D, r, phi, t)

Return the plus polarization `h_plus` of a planar orbit `(r, phi)` sampled at
times `t`, from the quadrupole formula for an observer at distance `D`.

Only the in-plane quadrupole components `Q_xx, Q_xy, Q_yy` (`Q_ij = mu x_i x_j`,
`mu = mass[4]`) are formed, because the out-of-plane ones vanish for z = 0:

    h_+ = (1/D) [ Q̈_xx (cos²φ_o - sin²φ_o cos²θ_o)
                + Q̈_yy (sin²φ_o - cos²φ_o cos²θ_o)
                - Q̈_xy sin(2φ_o) (1 + cos²θ_o) ]

The observer angles are fixed inside the function (`theta_o = 0`, `phi_o = 0`).
Derivatives use `np.gradient` with spacing `t[2] - t[1]`, so `t` is assumed
uniformly sampled. Returns an array with one entry per sample.
"""
function strain2d(mass, D, r, phi, t)
    #-------------Observer parameters----------
    mu = mass[4]
    theta_o = 0    # inclination_angle
    phi_o = 0      # pericenter_angle

    #----------------Scaling-------------------
    # Differentiate r/r0 (O(1) values) and restore r0^2 afterwards (Q ∝ r^2).
    # r0 comes from the data itself instead of the global `u0`.
    r0 = first(r)
    r_scaled = r ./ r0
    T = t[2]-t[1]

    #---------Rotating body parameters---------
    x = r_scaled .* cos.(phi)
    y = r_scaled .* sin.(phi)

    #--------Quadrupole Moment Tensor----------
    Q  = [ [mu*x.*x, mu*x.*y],
           [mu*y.*x, mu*y.*y] ]


    #--------Derivative Functions---------------

    # first time derivative of every component (numpy.gradient, uniform step T)
    function Mdt(Q)
        dQdt = [ [np.gradient(Q[1][1], T), np.gradient(Q[1][2], T)],
                 [np.gradient(Q[2][1], T), np.gradient(Q[2][2], T)], ]

        return dQdt
    end

    # second time derivative, with the r0^2 scaling undone
    function Mdt2(Q)
        dQ2dt2 = Mdt(Mdt(Q))

        return dQ2dt2*(r0^2)
    end


    d2Qd2t = Mdt2(Q)

    # Full planar angular factors; the previous expression dropped the
    # sin(phi_o) and Q̈_xy terms and was only valid for phi_o = 0.
    c_xx = cos(phi_o)^2 - sin(phi_o)^2 * cos(theta_o)^2
    c_yy = sin(phi_o)^2 - cos(phi_o)^2 * cos(theta_o)^2
    c_xy = sin(2*phi_o) * (1.0 + cos(theta_o)^2)

    h_plus = (1.0/D) * (d2Qd2t[1][1] .* c_xx .+ d2Qd2t[2][2] .* c_yy .- d2Qd2t[1][2] .* c_xy)

    return h_plus

end


#--------------------------------FFT Function------------------------------------------------

# Extract polarization `i` from a 2×N matrix (one polarization per row).
_polarization(s::AbstractMatrix, i) = s[i, :]
# Extract polarization `i` from a vector of arrays such as `[h_plus, h_cross]`.
_polarization(s::AbstractVector, i) = s[i]

"""
    strainFFT(t, strain, f_bin)

Fourier transform both GW polarizations and keep only frequencies strictly
inside `(f_bin[1], f_bin[2])`.

`strain` may be the `[h_plus, h_cross]` vector returned by `strain`, or a 2×N
matrix with the polarizations as rows. `t` must be uniformly sampled (the
spacing is `t[2] - t[1]`). Each transform is divided by `2π N`.

Returns `[xf, h_plus_fft, h_cross_fft]`.
"""
function strainFFT(t, strain, f_bin)
    N = length(t)
    T = t[2] - t[1]

    # Dispatch on the container: `strain[1, :]` on a vector of arrays would wrap
    # the whole vector instead of extracting the first polarization.
    h_plus_fft = SciPy.fft.fft(_polarization(strain, 1))/(2*pi*N)
    h_cross_fft = SciPy.fft.fft(_polarization(strain, 2))/(2*pi*N)
    xf = SciPy.fft.fftfreq(N, T)

    in_band = (xf .> f_bin[1]) .&& (xf .< f_bin[2])    # computed once, used for all three

    return [xf[in_band], h_plus_fft[in_band], h_cross_fft[in_band]]
end


"""
    strainFFT2d(t, strain, f_bin)

Fourier transform a single polarization `strain`, sampled at the uniformly
spaced times `t`. Keep only the frequencies strictly between `f_bin[1]` and
`f_bin[2]`. The transform is divided by `2π N` with `N = length(t)`.

Returns `[xf, h_plus_fft]`.
"""
function strainFFT2d(t, strain, f_bin)
    nsamples = length(t)
    dt = t[2] - t[1]

    freqs = SciPy.fft.fftfreq(nsamples, dt)
    spectrum = SciPy.fft.fft(strain[:])/(2*pi*nsamples)

    keep = (freqs .> f_bin[1]) .&& (freqs .< f_bin[2])
    return [freqs[keep], spectrum[keep]]
end


#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Calculating the Strain~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
hz_to_invpc = 1.029e8   # 1 Hz expressed in 1/pc

# Strain time series of both runs, then their spectra restricted to 0.1 mHz .. 0.1 Hz.
strain_spike_all  = strain2d(mass, D, r1, phi1, t1)
strain_spike_woDM = strain2d(mass, D, r2, phi2, t2)

strain_FFT_spike_all  = strainFFT2d(t1, strain_spike_all,
                                    [1e-4*hz_to_invpc, 1e-1*hz_to_invpc])
strain_FFT_spike_woDM = strainFFT2d(t2, strain_spike_woDM,
                                    [1e-4*hz_to_invpc, 1e-1*hz_to_invpc])
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~LISA Curve~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Unit conversions into pc-based units
hz_to_invpc = 1.029e8          # 1 Hz      -> 1/pc
s_to_pc = 9.716e-9             # 1 s       -> pc
m_to_pc = 3.241e-17            # 1 m       -> pc
solar_mass_to_pc = 4.8e-14     # 1 Msun    -> pc
g_cm3_to_invpc2 = 7.072e8      # 1 g/cm^3  -> 1/pc^2
year_to_pc = 0.3064            # 1 year    -> pc

# LISA band (1e-4 Hz .. 1 Hz) and a 100-point log-spaced frequency grid across it
lisa_bandwidth = [1e-4, 1] .* hz_to_invpc
f_gw = np.geomspace(lisa_bandwidth[1], lisa_bandwidth[2], 100)

"""
    NoiseSpectralDensity(f)

LISA noise power spectral density evaluated element-wise at frequencies `f`
(in 1/pc). It relies on the global conversions `m_to_pc`, `s_to_pc` and
`hz_to_invpc`:

    S_n = 10/(3 L^2) * (P_oms + 2 (1 + cos^2(f/f_s)) P_acc / (2π f)^4) * (1 + 0.6 (f/f_s)^2)

with arm length `L = 2.5e9 m` and `f_s = 19.09 mHz`. No galactic-confusion
term is included.
"""
function NoiseSpectralDensity(f)
    # optical metrology system noise
    oms_level = (1.5e-11*m_to_pc)^2
    P_oms = oms_level .* (1 .+ (2e-3 .* hz_to_invpc ./ f) .^ 4) ./ hz_to_invpc

    # test-mass acceleration noise
    acc_level = (3e-15*m_to_pc/s_to_pc^2)^2
    P_acc = acc_level .* (1 .+ (0.4e-3 .* hz_to_invpc ./ f) .^ 2) .* (1 .+ (f ./ 8e-3 ./ hz_to_invpc) .^ 4) ./ hz_to_invpc

    f_s = 19.09e-3*hz_to_invpc
    L = 2.5e9*m_to_pc

    prefactor = 10/3/L^2
    noise_sum = P_oms .+ 2*(1 .+ cos.(f ./ f_s) .^ 2) .* P_acc ./ (2*pi .* f) .^ 4
    high_freq = 1 .+ 6/10 .* (f ./ f_s) .^ 2
    return prefactor .* noise_sum .* high_freq
end

NoiseStrain = map(sqrt, f_gw .* NoiseSpectralDensity(f_gw));   # noise strain sqrt(f * S_n(f))


#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~Plotting Functions~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

"""
    Semi_Major_Axis_Plot(r1, dr1, phi1, dphi1, t1, r2, dr2, phi2, dphi2, t2)

Plot the semi-major axis, in units of `r_isco(m1)`, against time in years.
Arguments ending in `1` are the run with the dark-matter halo; those ending in
`2` are the run without it. A zoomed inset is added and the figure is saved
as "Semi_Major_Axis_Plot_NEW.png". Returns `nothing`.

The semi-major axis comes from the vis-viva relation `a = M / |v^2 - 2M/r|`
with `M = m_total`. Reads the globals `m_total`, `m1` and `year_to_pc`.
`phi1` and `phi2` are accepted but not used.
"""
function Semi_Major_Axis_Plot(r1,dr1,phi1,dphi1,t1,r2,dr2,phi2,dphi2,t2)
    # vis-viva semi-major axis of one trajectory
    function vis_viva_sma(r, dr, dphi)
        speed = sqrt.(dr.^2 + (r.^2 .* dphi.^2))
        return m_total*(1 ./abs.(speed.^2 - 2*m_total ./r))
    end

    sma_with = vis_viva_sma(r1, dr1, dphi1)
    sma_without = vis_viva_sma(r2, dr2, dphi2)

    gr()
    every = 1                       # plot every `every`-th sample
    isco = r_isco(m1)

    @time Plots.plot(t1[1:every:end]/year_to_pc, sma_with[1:every:end]/isco,
                     dpi=300, xlabel = L"$t$ $[years]$", ylabel = L"$a(t)/r_{isco}$",
                     label="with DM halo", legend=:bottomleft);
    # inset with a hard-coded zoom window
    Plots.lens!([0.0065, 0.007], [99.60, 99.65]; inset = (1, bbox(0.8, 0.4, 0.2, 0.2)),
                subplot=2, xticks=[], yticks=[], framestyle=:box)
    plot1 = Plots.plot!(t2[1:every:end]/year_to_pc, sma_without[1:every:end]/isco,
                        label="without DM halo");
    Plots.savefig(plot1,"Semi_Major_Axis_Plot_NEW.png");
    return nothing
end

"""
    GW_Strain_plot()

Plot a window of the plus-polarization strain `h_+` against time in years for
the runs with and without the dark-matter halo. Save it as
"GW_Strain_Plot_NEW.png" and return `nothing`.

Reads the globals `t1`, `t2`, `strain_spike_all` and `strain_spike_woDM`.
The window is samples 20000:20500, truncated to the shortest series so that a
run which terminated early does not raise a `BoundsError`.
"""
function GW_Strain_plot()
    year_to_pc = 0.3064
    tstart = 20000
    # an early-terminated run may hold fewer samples than the requested window
    tend   = min(20500, length(t1), length(t2), length(strain_spike_all), length(strain_spike_woDM))
    y_ticks= [-3.2,0,3.2]*10^-21
    @time Plots.plot(t1[tstart:tend]/year_to_pc, strain_spike_all[tstart:tend], ylim=(-4e-21, 4.5e-21), dpi=300, lw=2, yticks=y_ticks, ylabel=L"$h_+$", xlabel=L"$t$ $[years]$",label="with DM halo");
    plot2 = Plots.plot!(t2[tstart:tend]/year_to_pc, strain_spike_woDM[tstart:tend], dpi=300,lw=2, ls=:solid, label="without DM halo");
    Plots.savefig(plot2,"GW_Strain_Plot_NEW.png");
    return nothing
end

"""
    Char_Strain_plot()

Plot the characteristic strain `h_c(f) = 2 f |h̃(f)|` of the runs with and
without the dark-matter halo, together with the LISA noise curve, on log-log
axes. Save it as "Char_Strain_Plot_NEW.png" and return `nothing`.

Reads the globals `strain_FFT_spike_all` and `strain_FFT_spike_woDM` (each
`[frequencies, spectrum]`), plus `f_gw`, `NoiseStrain` and `hz_to_invpc`.
Each run is plotted against its own frequency grid, so the two runs may have
different lengths (e.g. when one terminated early).
"""
function Char_Strain_plot()
    step_hchar = 1

    # frequency in Hz and characteristic strain 2 f |h(f)| of one FFT result
    function hc_curve(fft_result)
        freq = abs.(fft_result[1])
        return freq/hz_to_invpc, 2*freq .* abs.(fft_result[2])
    end

    freq_all, h_char_all = hc_curve(strain_FFT_spike_all)
    freq_woDM, h_char_woDM = hc_curve(strain_FFT_spike_woDM)
    idx_all = 1:step_hchar:length(freq_all)
    idx_woDM = 1:step_hchar:length(freq_woDM)

    @time Plots.plot(freq_all[idx_all], h_char_all[idx_all], ylim=(1e-22, 1e-16), xaxis=:log, yaxis=:log, dpi=300,ylabel=L"$Characteristic$ $Strain$ $(h_c)$", xlabel=L"$f$ $[Hz]$", label="with DM halo");
    Plots.plot!(freq_woDM[idx_woDM], h_char_woDM[idx_woDM], ylim=(1e-22, 1e-16), xaxis=:log, yaxis=:log, dpi=300, label="without DM halo");
    plot3 = Plots.plot!(f_gw/hz_to_invpc, NoiseStrain, label="LISA noise curve");
    Plots.savefig(plot3,"Char_Strain_Plot_NEW.png");
    return nothing
end


# Render and save the three figures.
Semi_Major_Axis_Plot(r1, dr1, phi1, dphi1, t1,
                     r2, dr2, phi2, dphi2, t2)
GW_Strain_plot()
Char_Strain_plot()



# Captured notebook output:
#   3.311930 seconds (2.93 M allocations: 165.225 MiB, 1.88% gc time, 95.54% compilation time)
#   0.130048 seconds (235.13 k allocations: 15.343 MiB)
#   0.011967 seconds (1.32 k allocations: 4.649 MiB)
#   0.001169 seconds (1.35 k allocations: 96.211 KiB)
#   0.007123 seconds (1.40 k allocations: 1.466 MiB)
