In [None]:
import warnings
import math

import astropy.units as u
import astropy.coordinates as coord
from astropy.constants import G as _G
import emcee
import numpy as np
import gala.dynamics as gd
import gala.potential as gp
from gala.units import galactic
import matplotlib.pyplot as plt
try:
    plt.style.use('apw-notebook')
except OSError:
    # 'apw-notebook' appears to be a personal style sheet that is not installed in a
    # fresh environment; keep the notebook runnable with matplotlib's default style.
    warnings.warn("matplotlib style 'apw-notebook' not available; using default style")
%matplotlib inline

from scipy.misc import derivative
from scipy.integrate import quad
from scipy.interpolate import interp1d
from scipy.optimize import root

In [None]:
# Newton's constant as a bare float in gala's `galactic` unit system, so it can be
# combined directly with the plain-number masses and radii used below.
G = _G.decompose(bases=galactic).value

In [None]:
# --- background potential: spherical NFW halo --------------------------------
a = 20.0       # NFW scale radius r_s [kpc]
vc_nfw = 0.2   # NFW velocity parameter v_c [kpc/Myr]
_nfw = gp.SphericalNFWPotential(units=galactic, r_s=a, v_c=vc_nfw)

# --- tracer population: Hernquist density profile -----------------------------
m_h = 1e11     # total tracer mass [Msun]
c = 15.0       # Hernquist scale radius [kpc]
_hernquist = gp.HernquistPotential(units=galactic, c=c, m=m_h)

In [None]:
def nfw_potential(r):
    """Evaluate the background NFW potential ``_nfw`` at the Cartesian point (r, 0, 0).

    Goes through gala's low-level ``_value`` so both ``r`` and the result are raw
    numbers in the `galactic` unit system (presumably an ndarray -- TODO confirm shape).
    """
    point = np.array([r, 0., 0.])
    return _nfw._value(point)

def nfw_r(phi):
    """Invert the background NFW potential: find the radius r with Phi_NFW(r) = phi.

    Solves ``nfw_potential(r) - phi = 0`` with ``scipy.optimize.root`` starting
    from r = 10.  The potential is evaluated on the x-axis of a spherical model,
    so a root found at negative x is an equally valid solution; its absolute
    value is returned so callers always receive a non-negative radius.

    Parameters
    ----------
    phi : float
        Target potential value (galactic units).

    Returns
    -------
    float
        Radius (kpc), or NaN when the root finder reports failure (for example
        when ``phi`` lies outside the range the potential actually takes).
    """
    res = root(lambda x: nfw_potential(x[0]) - phi, 10.)
    if not res.success:
        return np.nan
    return abs(res.x[0])
# Quick check: radius at which the NFW potential equals -0.12.
nfw_r(phi=-0.12)

In [None]:
def hernquist_density(phi):
    """Hernquist tracer density at the radius where the NFW potential equals ``phi``.

    The radius comes from inverting the background potential with ``nfw_r``;
    the density is the Hernquist profile with total mass ``m_h`` and scale
    radius ``c``::

        rho(r) = m_h / (2 pi c^3) / [ (r/c) * (1 + r/c)^3 ]

    NaN propagates through when ``nfw_r`` cannot find a radius.
    """
    s = nfw_r(phi) / c                       # dimensionless radius r / c
    central_norm = m_h / (2*np.pi*c**3)
    return central_norm / (s * (1 + s)**3)
# Quick check: tracer density at the radius where the NFW potential is -0.12.
hernquist_density(phi=-0.12)

In [None]:
# Sanity check: finite-difference d(rho)/d(phi) at a single potential value.
# The central difference is symmetric in dx, so taking |dx| (as the other
# derivative cells do) gives the same result.
phi = -0.012
derivative(hernquist_density, phi, dx=np.abs(phi * 1E-3))

In [None]:
def integrand(phi, H):
    """Integrand of the Eddington inversion, (d rho / d phi) / sqrt(phi - H).

    ``phi`` is the integration variable (a potential value) and ``H`` is the
    energy at which the distribution function is being evaluated.  The density
    slope is a finite difference of ``hernquist_density`` with a step of 1e-3
    of ``phi``; the central difference is symmetric, so only ``|dx|`` matters.
    """
    step = np.abs(1E-3 * phi)
    slope = derivative(hernquist_density, phi, dx=step)
    return slope / np.sqrt(phi - H)

In [None]:
# Spot check of the Eddington DF at one energy: differentiate the inner
# integral  int_H^0 integrand(phi, H) dphi  with respect to H.
some_E = -0.01
derivative(lambda H: quad(integrand, H, 0, args=(H,))[0],
           some_E,
           dx=np.abs(some_E * 1E-4))

In [None]:
n_df = 256
# Dimensionless binding energy curlyE = -E / v_c**2: the grid runs from weakly
# bound (1e-2) to deeply bound (5.1) orbits.
curlyE = np.linspace(1E-2, 5.1, n_df)
energy_grid = -curlyE * vc_nfw**2


def _eddington_df(energy):
    """Unnormalised Eddington DF at ``energy``.

    Returns d/dH [ int_H^0 integrand(phi, H) dphi ] evaluated at H = energy,
    using a central difference with step 1e-4 * |energy|.
    """
    def inner(H):
        return quad(integrand, H, 0, args=(H,))[0]
    return derivative(inner, energy, dx=np.abs(1E-4*energy))


df = np.array([_eddington_df(energy) for energy in energy_grid])

# The numerical DF can come out zero, negative or NaN where the quadrature breaks
# down; the log of those is -inf/NaN, which the interpolation cell masks out with
# np.isfinite.  Silence the expected RuntimeWarnings rather than spam the output.
with np.errstate(divide='ignore', invalid='ignore'):
    log_df = np.log(df)

In [None]:
# Only positive DF values are visible on the log axis.
plt.semilogy(curlyE, df)
plt.xlabel(r"$\mathcal{E} = -E / v_c^2$")
plt.ylabel(r"$f(\mathcal{E})$ (unnormalized)");

In [None]:
# Interpolate log f(E) using only the grid points with a finite log-DF; linear
# extrapolation (interp1d's default kind) covers energies just off the grid.
_finite = np.isfinite(log_df)
log_df_interp = interp1d(energy_grid[_finite], log_df[_finite], fill_value="extrapolate")

In [None]:
# Probe the interpolated log-DF at E = -2 G m_h / a.
log_df_interp(-(2. * G * m_h) / a)

## First, sample radii from a Hernquist density profile

In [None]:
def sample_radii(size=1, r_min=0.*u.kpc, r_max=np.inf*u.kpc):
    """Draw radii from the Hernquist tracer profile by inverse-transform sampling.

    A Hernquist sphere with scale radius ``c`` has enclosed-mass fraction
    ``m(r) = M(<r) / M = r**2 / (r + c)**2``.  Drawing ``m`` uniformly between
    ``m(r_min)`` and ``m(r_max)`` and inverting in closed form,
    ``r = c * sqrt(m) / (1 - sqrt(m))``, yields radii distributed like the
    Hernquist density restricted to ``[r_min, r_max)``.

    Parameters
    ----------
    size : int, optional
        Number of radii to draw.
    r_min, r_max : `~astropy.units.Quantity`, optional
        Radial limits in any length unit.  The defaults cover the whole profile.

    Returns
    -------
    r : `~numpy.ndarray`
        Array of shape ``(size,)`` holding radii in kpc.
    """
    r_min = r_min.to(u.kpc).value
    r_max = r_max.to(u.kpc).value

    def mass_fraction(radius):
        # M(<r)/M_tot for a Hernquist profile with scale radius c [kpc]; -> 1 as r -> inf
        if np.isinf(radius):
            return 1.
        return (radius / (radius + c))**2

    m = np.random.uniform(mass_fraction(r_min), mass_fraction(r_max), size=size)
    sqrt_m = np.sqrt(m)
    # closed-form inverse of m(r); m < 1 because uniform() samples [low, high)
    return c * sqrt_m / (1. - sqrt_m)

In [None]:
# Seed NumPy's global RNG so the sampled radii (and the later np.random draws for
# MCMC starting points and isotropic angles) are reproducible on Restart & Run All.
np.random.seed(42)
r = sample_radii(1024)  # tracer radii [kpc]

### Check that the sampled radii follow the Hernquist profile

In [None]:
bins = np.logspace(-1, 4, 32)
H, _ = np.histogram(r, bins=bins)

V = 4./3. * np.pi * (bins[1:]**3 - bins[:-1]**3)  # shell volumes [kpc^3]
bin_cen = (bins[1:] + bins[:-1]) / 2.

# gala expects positions as (3, N); place the bin centres on the x-axis
q = np.zeros((3, len(bin_cen)))
q[0] = bin_cen
plt.plot(bin_cen, _hernquist.density(q) / m_h, marker=None, lw=2., ls='--',
         label='Hernquist model')
# empirical number density per unit volume, normalised by the number of samples
plt.loglog(bin_cen, H / V / r.size, marker=None, label='sampled radii')

plt.xlabel('$r$ [kpc]')
plt.ylabel('$n(r)$')
plt.legend();

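## Now sample speeds from the distribution function (DF)

At each sampled radius $r$, draw a speed with MCMC from $p(v \mid r) \propto v^2 f(E)$, where $E = \tfrac{1}{2}v^2 + \Phi_\mathrm{NFW}(r)$.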

In [None]:
def ln_vel_dist(p, r):
    """Log of the (unnormalised) speed distribution at radius ``r``.

    For the isotropic Eddington DF the speed distribution at fixed radius is
    p(v | r) propto v**2 f(E), with E = v**2 / 2 + Phi_NFW(r).  This is the
    log-probability handed to emcee.

    Parameters
    ----------
    p : sequence of float
        Walker position; ``p[0]`` is the speed (kpc/Myr).
    r : float
        Radius (kpc) at which the speed is sampled.

    Returns
    -------
    float
        ``log f(E) + 2 log v``, or ``-inf`` for non-positive speeds and for
        unbound energies (``E >= 0``), where the DF vanishes.
    """
    v = p[0]
    if v <= 0.:
        return -np.inf

    E = 0.5*v**2 + nfw_potential(r)
    # The DF is only defined for bound orbits (E < 0; the potential -> 0 at
    # infinity).  Without this cut log_df_interp would extrapolate into E >= 0.
    if np.any(E >= 0.):
        return -np.inf

    log_f = log_df_interp(E) + 2*np.log(v)
    # emcee wants a scalar; nfw_potential / interp1d may return 0-d or length-1 arrays
    return np.asarray(log_f).item()

In [None]:
nwalkers = 32
nsteps = 128
# NaN marks radii whose speed was never sampled (e.g. after the loop aborts
# below) instead of a 0 that would turn into an infinite crossing time later.
v = np.full_like(r, np.nan)

with warnings.catch_warnings():
    # Promote any warning raised while sampling to an exception so a
    # misbehaving likelihood stops the run instead of yielding silent junk.
    warnings.filterwarnings('error')

    for i in range(len(r)):
        # walkers start at small positive speeds around 1e-3 kpc/Myr
        p0 = np.abs(np.random.normal(1E-3, 1E-4, (nwalkers, 1)))
        # Positional (nwalkers, ndim, log-prob) works with emcee 2 (dim=/lnpostfn=)
        # and emcee 3 (ndim=/log_prob_fn=); the emcee-2 keywords fail on emcee 3.
        sampler = emcee.EnsembleSampler(nwalkers, 1, ln_vel_dist, args=(r[i],))

        try:
            _ = sampler.run_mcmc(p0, nsteps)
        except Warning:
            print("Failed!", i)
            break
        # take the final position of the first walker as this radius's speed
        v[i] = sampler.chain[0, -1, 0]

In [None]:
# Trace of every walker's speed for the most recent sampler (the last radius):
# each column of the transposed (nsteps, nwalkers) array is drawn as one line.
plt.plot(sampler.chain[:, :, 0].T, drawstyle='steps', alpha=0.4, marker=None);

In [None]:
# Skip any radii whose speed was not sampled (NaN).
plt.hist(v[np.isfinite(v)], bins=np.linspace(0, 0.3, 20))
plt.xlabel("$v$ [kpc/Myr]")
plt.ylabel("$N$");

In [None]:
def r_v_to_3d(r, v):
    """Give each radius and each speed an independent, isotropically random direction.

    Parameters
    ----------
    r, v : `~numpy.ndarray`
        Radii and speeds; each value becomes the length of a random 3D vector.

    Returns
    -------
    xyz, v_xyz : `~astropy.units.Quantity`
        Dimensionless Cartesian components (``.xyz`` of a CartesianRepresentation)
        for the positions and the velocities respectively.
    """
    def isotropic(lengths):
        # phi ~ U(0, 2 pi) and cos(theta) ~ U(-1, 1) -> directions uniform on the sphere
        n = lengths.size
        azimuth = np.random.uniform(0, 2*np.pi, size=n)
        polar = np.arccos(2*np.random.uniform(size=n) - 1)
        rep = coord.PhysicsSphericalRepresentation(phi=azimuth*u.radian,
                                                   theta=polar*u.radian,
                                                   r=lengths*u.one)
        return rep.represent_as(coord.CartesianRepresentation).xyz

    # positions first, then velocities: same order of RNG draws as drawing them inline
    return isotropic(r), isotropic(v)

In [None]:
# Attach random isotropic directions to every (r, v) pair -> Cartesian arrays.
(xyz, vxyz) = r_v_to_3d(r=r, v=v)

In [None]:
# Bundle positions and velocities into a gala phase-space object (indexed per tracer below).
w0 = gd.CartesianPhaseSpacePosition(vel=vxyz, pos=xyz)

In [None]:
# Crossing time r / v sets each orbit's integration step (dt = t_cross / 100).
# v may contain 0 or NaN if speed sampling aborted; keep the resulting inf/NaN quiet.
with np.errstate(divide='ignore', invalid='ignore'):
    t_cross = r / v
# NaN rather than 0 for skipped orbits, so they cannot masquerade as circular
# orbits (ecc = 0) or as tracers at the centre (r_f = 0).
ecc = np.full_like(t_cross, np.nan)
r_f = np.full_like(t_cross, np.nan)

for i in range(len(t_cross)):
    if not (np.isfinite(t_cross[i]) and t_cross[i] > 0):
        # no usable time step for this tracer
        continue
    w = _nfw.integrate_orbit(w0[i], dt=t_cross[i]/100., n_steps=2000)
    ecc[i] = w.eccentricity()
    # distance from the origin at the final time step
    r_f[i] = np.sqrt(np.sum(w.pos[:, -1]**2)).value

In [None]:
plt.hist(ecc[np.isfinite(ecc)])
plt.xlabel("eccentricity")
plt.ylabel("$N$");

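## Check the final radial distribution

After integrating each orbit in the NFW potential, compare the tracers' final radii with the Hernquist profile they were drawn from (dashed line).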

In [None]:
bins = np.logspace(-1, 3, 32)
# only orbits that actually produced a final radius
r_f_finite = r_f[np.isfinite(r_f)]
H, _ = np.histogram(r_f_finite, bins=bins)

V = 4./3. * np.pi * (bins[1:]**3 - bins[:-1]**3)  # shell volumes [kpc^3]
bin_cen = (bins[1:] + bins[:-1]) / 2.

# gala expects positions as (3, N); place the bin centres on the x-axis
q = np.zeros((3, len(bin_cen)))
q[0] = bin_cen
plt.plot(bin_cen, _hernquist.density(q) / m_h, marker=None, lw=2., ls='--',
         label='Hernquist model')
# normalise by the number of histogrammed radii, not the original sample size
plt.loglog(bin_cen, H / V / r_f_finite.size, marker=None, label='final radii')

plt.xlabel('$r$ [kpc]')
plt.ylabel('$n(r)$')
plt.legend();