In [None]:
from math import log, pi
import astropy.coordinates as coord
from astropy.constants import G as _G
import astropy.units as u
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('apw-notebook')
%matplotlib inline
import emcee

from gala.units import UnitSystem
import gala.potential as gp

import starfish

In [None]:
# Unit system for the whole notebook: milliparsec, year, solar mass, radian.
units = UnitSystem(u.mpc, u.yr, u.Msun, u.rad)

In [None]:
# Gravitational constant in the notebook unit system (mpc^3 / (Msun yr^2)),
# stripped to a bare float for use in plain numpy arithmetic below.
G = units.decompose(_G).to_value()

In [None]:
def log_df(_E, M, a):
    """Natural log of the isotropic Hernquist distribution function f(E).

    Uses the closed form in the dimensionless binding energy
    Ẽ = E a / (G M), valid for 0 < Ẽ < 1:

        f = sqrt(Ẽ) / (1 - Ẽ)^2
            * [(1 - 2Ẽ)(8Ẽ^2 - 8Ẽ - 3) + 3 arcsin(sqrt(Ẽ)) / sqrt(Ẽ (1 - Ẽ))]
            / (sqrt(2) (2 pi)^3 (G M a)^{3/2})

    This equals Hernquist's (1990) f(E) divided by M (no overall mass factor).

    Parameters
    ----------
    _E : float or ndarray
        Binding energy (positive for bound orbits, E = -Phi - v^2/2).
    M, a : float
        Total mass and scale radius of the Hernquist profile.

    Returns
    -------
    float or ndarray
        log f(E); non-finite (with numpy warnings) at or outside Ẽ in (0, 1).
    """
    eps = _E*a / (G*M)
    log_norm = 0.5*log(2.) + 3*log(2*pi) + 1.5*log(G*M*a)
    prefactor = 0.5*np.log(eps) - 2*np.log(1-eps)
    polynomial = (1-2*eps)*(8*eps**2-8*eps-3)
    arcsin_part = 3*np.arcsin(np.sqrt(eps))/np.sqrt(eps*(1-eps))
    return prefactor + np.log(polynomial + arcsin_part) - log_norm

def E_circ(r, M, a):
    """Binding energy of a circular orbit at radius r.

    E = -Phi(r) - v_c^2 / 2 with Phi = -G M / (r + a) and
    v_c^2 = G M r / (r + a)^2.
    """
    r_plus_a = r + a
    return G*M / r_plus_a * (1 - r/(2*r_plus_a))

def v_circ(r, M, a):
    """Speed of a circular orbit at radius r in a Hernquist potential.

    v_c = sqrt(G M_enc(r) / r) = sqrt(G M r / (r + a)^2).

    Parameters
    ----------
    r : float or ndarray
        Radius (same units as ``a``).
    M, a : float
        Total mass and scale radius of the Hernquist profile.

    Returns
    -------
    float or ndarray
        Circular speed in the notebook unit system.
    """
    return np.sqrt(G*M * r / (r+a)**2)

def M_enc(r, M, a):
    """Mass enclosed within radius r for a Hernquist sphere.

    M_enc(r) = M r^2 / (r + a)^2, which tends to M as r -> infinity.
    """
    numerator = M * r**2
    return numerator / (r+a)**2

def r_max(E, M, a):
    """Radius at which a particle with binding energy E has zero speed.

    Solves E = G M / (r + a) for r; this is an upper bound on the radii an
    orbit of that energy can reach.
    """
    return -a + G*M/E

def rv_to_E(r, v, M, a):
    """Binding energy E = -Phi(r) - v^2 / 2 in a Hernquist potential."""
    psi = G*M/(r+a)
    kinetic = 0.5*v**2
    return psi - kinetic

def sample_r(a, r_min=0, r_max=1E20, size=1, rng=None):
    """Draw radii from a Hernquist density profile by inverse-transform sampling.

    The cumulative mass fraction of a Hernquist sphere is u = r^2 / (r + a)^2,
    so r = a sqrt(u) / (1 - sqrt(u)). Drawing u uniformly between the mass
    fractions at ``r_min`` and ``r_max`` restricts the samples to that shell.

    Parameters
    ----------
    a : float
        Scale radius.
    r_min, r_max : float, optional
        Radial range to sample within (same units as ``a``). Note that the
        ``r_max`` parameter shadows the module-level ``r_max`` function here.
    size : int or tuple of ints, optional
        Output shape, passed through to ``uniform``.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness, for reproducible draws. Defaults to the global
        ``np.random`` state (the original behaviour).

    Returns
    -------
    ndarray
        Sampled radii, in the same units as ``a``.
    """
    if rng is None:
        rng = np.random
    mass_frac = rng.uniform(M_enc(r_min, 1., a),
                            M_enc(r_max, 1., a),
                            size=size)
    sqrt_u = np.sqrt(mass_frac)
    # Invert u = (r / (r + a))^2, taking the positive root.
    return a*sqrt_u / (1 - sqrt_u)

In [None]:
M = 1E6 # Msun
a = 1. # mpc
pot = gp.HernquistPotential(m=M*u.Msun, c=a*u.mpc, units=units)

# Maximum binding energy to consider is that of a circular orbit at r = r_min;
# the sampler below rejects any state more tightly bound than this.
r_min = 1E-1  # mpc
max_E = E_circ(r_min, M, a)
# Display the cut in dimensionless units, E / (GM/a).
max_E / (G*M/a)

In [None]:
# Dimensionless energy grid on the open interval (0, 1): log_df is singular at
# both endpoints (log(0) and division by zero), so drop them.
E = np.linspace(0, 1, 1026)[1:-1] * (G*M/a)

fig,ax = plt.subplots(1, 1, figsize=(6,5))
ax.set_xlim(0, 1)

# log_df divides out (GMa)^{3/2}; add it back so the curve is dimensionless
# and matches the axis label (this equals GM/a only when a = 1).
ax.plot(E / (G*M/a), np.exp(log_df(E, M, a) + 1.5*np.log(G*M*a)), marker='')
ax.set_yscale('log')
ax.set_ylim(1e-7, 1e4)

ax.set_xlabel(r'$\mathcal{E}/(GM/a)$')
# The quantity itself is plotted on a log axis, so no log10 in the label.
ax.set_ylabel(r'$(GMa)^{3/2} \, f(\mathcal{E})$')

# Sample from the DF using MCMC

In [None]:
def lnprob(p, M, a, E_max=None):
    """Log target density for emcee, in the coordinates (ln r, ln v).

    For an isotropic DF the density of particles in (r, v) is proportional to
    f(E) * 4 pi r^2 * 4 pi v^2. Sampling in ln r and ln v adds the Jacobian
    r * v, so log p = log f(E) + 2 ln r + 2 ln v + ln r + ln v + const.

    Parameters
    ----------
    p : sequence of two floats
        (ln r, ln v).
    M, a : float
        Total mass and scale radius of the Hernquist profile.
    E_max : float, optional
        Largest binding energy accepted; more tightly bound states are
        rejected. Defaults to the notebook-level ``max_E``.

    Returns
    -------
    float
        Log density up to a constant, or -inf outside the support.
    """
    if E_max is None:
        E_max = max_E

    ln_r, ln_v = p
    r = np.exp(ln_r)
    v = np.exp(ln_v)

    E = rv_to_E(r, v, M, a)

    # Reject unbound states (E <= 0) and states beyond the energy cut.
    if not 0 < E < E_max:
        return -np.inf

    df_ = log_df(E, M, a)

    if not np.isfinite(df_):
        return -np.inf

    # 2 ln r + 2 ln v: phase-space volume r^2 v^2; ln r + ln v: Jacobian of
    # the log transform.
    return df_ + 2*ln_r + 2*ln_v + ln_r + ln_v

In [None]:
# 32 walkers started near circular orbits with r ~ N(1, 0.1), then moved into
# the (ln r, ln v) coordinates that lnprob works in.
p0 = np.empty((32, 2))
p0[:, 0] = np.random.normal(loc=1., scale=0.1, size=p0.shape[0])
p0[:, 1] = v_circ(r=p0[:, 0], M=M, a=a)
p0[:] = np.log(p0)

In [None]:
# Pass nwalkers, ndim and the log-probability function positionally: their
# keyword names differ between emcee 2 (dim, lnpostfn) and emcee 3
# (ndim, log_prob_fn), so keyword arguments break on one version or the other.
sampler = emcee.EnsembleSampler(p0.shape[0], p0.shape[1], lnprob,
                                args=(M, a))

In [None]:
# Advance every walker 2**14 = 16384 steps; the returned state is not needed.
_ = sampler.run_mcmc(p0, 2**14)

In [None]:
# Trace plots, one figure per parameter, to judge burn-in and mixing by eye.
# Label order matches the unpacking p = (ln_r, ln_v) in lnprob.
param_labels = [r'$\ln r$', r'$\ln v$']
for dim in range(p0.shape[1]):
    fig, ax = plt.subplots()
    for walker in sampler.chain[..., dim]:
        ax.plot(walker, marker='', drawstyle='steps-mid', alpha=0.2)
    ax.set_xlabel('step')
    ax.set_ylabel(param_labels[dim])

In [None]:
# Drop the first 200 steps of each walker as burn-in, keep every 8th step,
# and pool all walkers into a single (n_samples, n_dim) array.
flatchain = sampler.chain[:, 200::8].reshape(-1, p0.shape[1])
r_samples, v_samples = np.exp(flatchain.T)

In [None]:
# Log-spaced radial bins (mpc). Use geometric bin centres (natural for log
# bins) and exact spherical-shell volumes instead of 4 pi r^2 dr.
bins = np.logspace(-2, 2.5, 64)
bin_ctr = np.sqrt(bins[:-1] * bins[1:])
V = 4/3. * np.pi * (bins[1:]**3 - bins[:-1]**3)

H,_ = np.histogram(r_samples, bins)

# Each of the N samples represents M/N of mass, so H * (M/N) / V is a mass
# density in the same units as pot.density (Msun / mpc^3). Raw H/V would be
# offset from the analytic curve by the factor M/N. A deficit at small r is
# likely, because the sampler rejects orbits bound more tightly than max_E.
rho_samples = H * (M / r_samples.size) / V
rho_true = [pot.density([x,0,0.]).value[0] for x in bin_ctr]

fig, ax = plt.subplots()
ax.plot(bin_ctr, rho_samples, label='MCMC samples')
ax.plot(bin_ctr, rho_true, label='Hernquist')

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel(r'$r$ [mpc]')
ax.set_ylabel(r'$\rho$ [$\mathrm{M}_\odot\,\mathrm{mpc}^{-3}$]')
ax.legend();

In [None]:
E_samples = rv_to_E(r_samples, v_samples, M, a)

# Show energies in units of GM/a, the same scaling as the DF figure. The
# sampler only accepts 0 < E < max_E.
fig, ax = plt.subplots()
ax.hist(E_samples / (G*M/a), bins='auto')
ax.set_xlabel(r'$\mathcal{E}/(GM/a)$')
ax.set_ylabel('number of samples');

In [None]:
fig, ax = plt.subplots()
ax.hist(v_samples, bins='auto')
# v_samples are bare floats in the notebook unit system; label with its speed unit.
ax.set_xlabel('$v$ [{}]'.format(units['speed']))
ax.set_ylabel('number of samples');

## Now get 3D coordinates assuming isotropy:

In [None]:
# further downsample:
r_subset = r_samples[::32]
v_subset = v_samples[::32]

w0 = starfish.rv_to_3d_isotropic(r_subset*units['length'], 
                                 v_subset*units['speed'])

w = pot.integrate_orbit(w0, dt=0.1, n_steps=4000)
w = w[-1].represent_as(coord.PhysicsSphericalRepresentation)

Compare the initial and final density profiles (before and after orbit integration):

In [None]:
Hi,_ = np.histogram(r_subset, bins)
# Assumes w.r is in the potential's length unit (mpc), matching `bins`.
Hf,_ = np.histogram(w.r.value, bins)

# Convert counts to a mass density comparable with pot.density: each of the
# N integrated particles carries mass M/N.
m_particle = M / r_subset.size

fig, ax = plt.subplots()
ax.plot(bin_ctr, [pot.density([x,0,0.]).value[0] for x in bin_ctr],
        marker='', label='Hernquist')
ax.plot(bin_ctr, m_particle * Hi / V, marker='', linestyle='--', label='initial')
ax.plot(bin_ctr, m_particle * Hf / V, marker='', linestyle='--', label='final')

ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel(r'$r$ [mpc]')
ax.set_ylabel(r'$\rho$ [$\mathrm{M}_\odot\,\mathrm{mpc}^{-3}$]')
ax.legend();