Experiment with sampling from distribution functions

In [None]:
from astropy.constants import G
import astropy.units as u
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import root
from scipy.integrate import quad
import emcee

from gala.mpl_style import mpl_style
import gala.potential as gp
from gala.units import dimensionless

# NOTE(review): 'apw-notebook' does not appear to be a built-in matplotlib
# style sheet; presumably it must be installed locally. The `mpl_style`
# imported from gala.mpl_style is not used anywhere in the visible source.
plt.style.use('apw-notebook')
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

### Hernquist model in units with $G = M = a = 1$

In [None]:
# Hernquist potential with unit mass and unit scale radius (c presumably
# plays the role of a), expressed in gala's dimensionless unit system.
pot = gp.HernquistPotential(units=dimensionless, m=1.0, c=1.0)

In [None]:
def sample_radii(pot, size=1, r_min=0.*u.kpc, r_max=np.inf*u.kpc):
    """Draw radii from a Hernquist profile by inverse-transform sampling.

    The enclosed-mass profile is hard-coded for a Hernquist model with
    G = M = a = 1, i.e. M(<r)/M = r**2 / (1 + r)**2, and radii are treated
    as kpc.

    Parameters
    ----------
    pot : object
        Not read; kept for interface compatibility. The profile used is the
        hard-coded Hernquist one described above.
    size : int
        Number of radii to draw.
    r_min, r_max : astropy Quantity (length)
        Radial range to sample. ``r_max`` may be infinite.

    Returns
    -------
    astropy Quantity
        Array of ``size`` radii in kpc.
    """
    r_min = r_min.to(u.kpc).value
    r_max = r_max.to(u.kpc).value

    def m_enc(rr):
        # Hernquist enclosed-mass fraction for a = 1 (total mass normalized to 1).
        return rr**2 / (1 + rr)**2

    m_min = m_enc(r_min)  # equals 0 when r_min == 0
    # M(<r)/M -> 1 as r -> inf, but inf**2 / inf**2 evaluates to nan,
    # so the infinite upper limit is special-cased.
    m_max = 1. if np.isinf(r_max) else m_enc(r_max)

    m = np.random.uniform(m_min, m_max, size=size)

    # Invert M(<r)/M = m exactly instead of root-finding per sample:
    # sqrt(m) = r / (1 + r)  =>  r = sqrt(m) / (1 - sqrt(m)).
    # (m == 1 could only arise from floating-point rounding and would give inf.)
    sqrt_m = np.sqrt(m)
    return sqrt_m / (1. - sqrt_m) * u.kpc

In [None]:
# Draw 10,000 radii from the hard-coded G=M=a=1 Hernquist profile.
r = sample_radii(pot, size=10000)

In [None]:
# Place every sampled radius on the x-axis: (x, y, z) = (r, 0, 0),
# shaped (3, N) and carrying the same unit as r.
xyz = np.vstack((r.value, np.zeros((2, r.size)))) * r.unit

## Check that the sampled radii follow the Hernquist density profile

In [None]:
# Log-spaced radial bins spanning 1e-3 to 1e3 kpc.
bins = np.logspace(-3, 3, 32)
# np.histogram on a Quantity requires bins in matching units, so histogram
# the bare kpc values against the unitless bin edges.
H, _ = np.histogram(r.to_value(u.kpc), bins=bins)

# Volume of each spherical shell, and a representative radius per bin:
# the geometric mean is the natural centre of a logarithmic bin.
V = 4/3*np.pi*(bins[1:]**3 - bins[:-1]**3)
bin_cen = np.sqrt(bins[1:] * bins[:-1])

# Analytic density evaluated along the x-axis, divided by the total mass so
# it is comparable to the fractional number density of the samples.
q = np.zeros((3, len(bin_cen)))
q[0] = bin_cen
plt.plot(bin_cen, pot.density(q) / pot.parameters['m'], marker=None, lw=2., ls='--')

# Empirical density: counts per shell volume as a fraction of all samples.
plt.loglog(bin_cen, H/V/r.size, marker=None)

plt.xlabel('$r$ [kpc]')
plt.ylabel('$n(r)$ [kpc$^{-3}$]')

In [None]:
def hernquist_df(curly_E_tilde):
    """Vectorized Hernquist distribution function in G = M = a = 1 units.

    Parameters
    ----------
    curly_E_tilde : float or numpy.ndarray
        Dimensionless binding energy. The expression is finite only for
        0 < curly_E_tilde < 1; elsewhere NumPy yields nan/inf (with warnings).

    Returns
    -------
    float or numpy.ndarray
        f(curly_E_tilde), with the same shape as the input.
    """
    eps = curly_E_tilde
    sqrt_eps = np.sqrt(eps)

    polynomial = (1 - 2*eps)*(8*eps**2 - 8*eps - 3)
    arcsin_term = 3*np.arcsin(sqrt_eps) / np.sqrt(eps*(1-eps))
    prefactor = (np.sqrt(2)*(2*np.pi)**3)**-1 * sqrt_eps / (1-eps)**2

    return prefactor * (polynomial + arcsin_term)

### Compare to Figure 4.2 of Binney & Tremaine (2008; BT08)

In [None]:
# Dimensionless-energy grid; hernquist_df is non-finite at and beyond the
# edges of 0 < E < 1, and those points are not drawn as a finite curve.
eee = np.linspace(0, 1.5, 256)
log_f = np.log10(hernquist_df(eee))

plt.plot(eee, log_f, marker=None)
plt.xlim(0, 1.5)
plt.ylim(-7, 4)
plt.xlabel(r'$\tilde{\mathcal{E}}$')
plt.ylabel(r'$\log_{10}[(GMa)^{3/2} f]$')

In [None]:
import math

In [None]:
def _hernquist_df(E):
    """Scalar Hernquist distribution function in G = M = a = 1 units.

    Same formula as the vectorized ``hernquist_df``, using the ``math``
    module for speed on scalars. ``E`` is the dimensionless binding energy
    and must satisfy 0 < E < 1: E == 0 or E == 1 raises ZeroDivisionError,
    and E > 1 raises ValueError (math domain error).
    """
    A = (math.sqrt(2)*(2*math.pi)**3)**-1 * math.sqrt(E) / (1-E)**2
    term1 = (1 - 2*E)*(8*E**2 - 8*E - 3)
    term2 = 3*math.asin(math.sqrt(E)) / math.sqrt(E*(1-E))

    return A * (term1 + term2)


def _potential(r):
    """Hernquist potential -1 / (r + 1) in G = M = a = 1 units."""
    return -1. / (r + 1.)


def vel_dist(v, r):
    """Unnormalized speed distribution v**2 f(E) at radius ``r``.

    Returns 0 where the dimensionless binding energy -E = -(v**2/2 + Phi(r))
    lies outside (0, 1), i.e. for unbound speeds.
    """
    E = 0.5*v**2 + _potential(r)
    curly_E = -E
    if curly_E <= 0. or curly_E >= 1.:
        return 0.
    return v**2 * _hernquist_df(curly_E)


def ln_vel_dist(p, r):
    """Log of ``vel_dist`` for MCMC sampling of the speed at radius ``r``.

    Parameters
    ----------
    p : sequence of float
        Walker position; ``p[0]`` is the speed v.
    r : float
        Radius in G = M = a = 1 units.

    Returns
    -------
    float
        log(v**2 f(E)), or -inf for v <= 0, for binding energies outside
        (0, 1), or where the DF evaluates to a non-positive number.
    """
    v = p[0]
    if v <= 0.:
        return -math.inf

    curly_E = -(0.5*v**2 + _potential(r))
    if curly_E <= 0. or curly_E >= 1.:
        return -math.inf

    df = _hernquist_df(curly_E)
    # Near curly_E -> 0 the two DF terms nearly cancel, and rounding can make
    # the result zero or slightly negative; treat that as zero probability
    # rather than letting math.log raise.
    if df <= 0.:
        return -math.inf

    return math.log(df) + 2*math.log(v)

In [None]:
# Unnormalized speed distribution v**2 f(E) at r = 8 (G=M=a=1 units),
# sampled on a uniform grid 0 <= v <= 1.
vs = np.linspace(0, 1., 1024)
speed_pdf = list(map(lambda speed: vel_dist(speed, 8.), vs))
plt.plot(vs, speed_pdf)

# TODO: Generate a velocity for each sampled $r$, build full phase-space positions and velocities, integrate the orbits, and compute their eccentricities

In [None]:
# Sample the speed distribution at a single radius, r[1], with an ensemble of
# walkers. The sampler's arguments are passed positionally because the
# keyword names differ between emcee 2 (dim, lnpostfn) and emcee 3
# (ndim, log_prob_fn); emcee 3 rejects the old keywords.
n_walkers = 32
sampler = emcee.EnsembleSampler(n_walkers, 1, ln_vel_dist, args=(r[1].value,))
# Start every walker in a tight ball around v = 0.1.
p0 = np.random.normal(0.1, 1E-3, size=(n_walkers, 1))
# Burn-in run; discard those samples and continue from the final positions.
pos, _, _ = sampler.run_mcmc(p0, 128)
sampler.reset()
_ = sampler.run_mcmc(pos, 128)

In [None]:
# Histogram of all post-burn-in speed samples, flattened over walkers and steps.
plt.hist(x=sampler.flatchain)

In [None]:
# Same histogram as the previous cell (flattened speed samples).
plt.hist(x=sampler.flatchain)

In [None]:
# Overlay each walker's speed trace across steps to eyeball mixing.
for walker_trace in sampler.chain[:, :, 0]:
    plt.plot(walker_trace, marker=None, drawstyle='steps', alpha=0.1)

In [None]:
# The walkers sample the speed v directly: the sampler was built around a
# log-probability whose single parameter is v, evaluated at radius r[1].
# The final walker positions are therefore speeds at r[1]. Recover the
# matching specific energy E = v**2/2 + Phi(r) from the G=M=a=1 Hernquist
# potential, instead of broadcasting 32 walker values against 1024 positions.
v = sampler.chain[:, -1, 0]
E = 0.5 * v**2 + _potential(r[1].value)

In [None]:
# Inspect the energies computed in the previous cell.
E

In [None]:
# Inspect the potential at the first 1024 sampled positions.
# NOTE(review): xyz carries kpc units while `pot` was built with
# dimensionless units -- confirm gala accepts/converts this as intended.
pot.potential(xyz[:,:1024])

In [None]:
# Inspect the speeds computed two cells above.
v

In [None]:
# The model is in G=M=a=1 (dimensionless) units, so these speeds have no
# physical scale and cannot be converted to km/s; histogram the plain values
# (np.asarray also strips any dimensionless Quantity wrapper).
plt.hist(np.asarray(v))

In [None]:
def sample_velocities(pot, size=1, Emin=1E-2, Emax=0.99):
    