In [None]:
import os

# Third-party
from astropy.table import Table
import astropy.coordinates as coord
from astropy.constants import G as _G
import astropy.units as u
import emcee
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('apw-notebook')
%matplotlib inline
from scipy.special import loggamma, gamma as Gamma
from scipy.integrate import quad
from gala.units import UnitSystem

import pystan
from stderr_helper import suppress_stdout_stderr

Make some fake data

Tracer density:
$$
\nu(r) = \nu_0 \, r^{-\gamma} \\
\int \nu(r) \, r^2 \, {\rm d}r
$$

The background potential is Keplerian, so the circular velocity is
$$
v_c(r) = \sqrt{\frac{G M}{r}}
$$

I'm going to work in an unusual unit system (milliparsec, year, solar mass, radian):

In [None]:
# Unit system used throughout the notebook: milliparsec, year, solar mass, radian.
units = UnitSystem(u.mpc, u.yr, u.Msun, u.rad)
# Gravitational constant as a bare float in the unit system above; the
# functions defined below read this module-level `G`.
G = _G.decompose(units).value

In [None]:
n_data = 64  # number of (r, v) samples kept from the input file
Mbh = 4E6  # mass entering G*Mbh, in the mass unit of `units` (Msun)
gamma = 3.5  # tracer density power-law index: nu(r) ~ r^-gamma
# NOTE(review): `a` and `b` are not referenced by any cell shown here;
# presumably the inner/outer radii bounding the tracer population — confirm.
a = 2. # mpc
b = 1000. # mpc = 1 pc

In [None]:
# Load the stacked array from rv.npy, keep the first n_data columns, and unpack
# its rows as radii and speeds. The unpacking requires exactly two rows.
# NOTE(review): values are presumably in mpc and mpc/yr to match `units` — confirm.
true_r, true_v = np.load('rv.npy')[:,:n_data]

In [None]:
# Per-point uncertainties: a constant 0.01 for every radius and every speed.
r_err = np.full(n_data, 0.01)
v_err = np.full(n_data, 0.01)

# Draw the noisy "observed" values from a seeded generator so that a fresh
# kernel (Restart & Run All) reproduces exactly the same fake data set.
noise_rng = np.random.RandomState(42)
r = noise_rng.normal(true_r, r_err)
v = noise_rng.normal(true_v, v_err)

## Test with `emcee`

First, let's implement a simplified version of the model and use `emcee` to sample over `Mbh`, `gamma`, and `beta` (any of these can be frozen to a fixed value when the model is constructed):

In [None]:
def rel_potential(r, Mbh):
    """Return ``G * Mbh / r`` for radius ``r`` and mass ``Mbh``.

    Uses the module-level gravitational constant ``G``.
    """
    GM = G * Mbh
    return GM / r

def rv_to_E(r, v, Mbh):
    """Return ``G * Mbh / r - v**2 / 2`` for radius ``r`` and speed ``v``.

    The result is positive when ``v`` is below ``sqrt(2 * G * Mbh / r)``.
    Uses the module-level gravitational constant ``G``.
    """
    kinetic = 0.5 * v**2
    potential = G * Mbh / r
    return potential - kinetic

def log_df(E, L, Mbh, gamma, beta):
    """Natural log of the distribution function f(E, L).

    Evaluates the log of

        L**(-2*beta) * E**(gamma - beta - 1.5) * 2**beta
        / ((2*pi)**1.5 * (G*Mbh)**(gamma - 2*beta))
        * Gamma(gamma - 2*beta + 1) / Gamma(gamma - beta - 0.5)

    and returns its real part. ``E`` and ``L`` may be scalars or arrays;
    non-positive values give NaN / -inf through ``np.log``.
    """
    # Numerator, term by term (summed in the same order as written here).
    L_term = -2*beta*np.log(L)
    E_term = (gamma - beta-1.5) * np.log(E)
    two_term = beta*np.log(2)
    log_numer = L_term + E_term + two_term

    log_denom = 1.5*np.log(2*np.pi) + (gamma - 2*beta) * np.log(G*Mbh)
    log_gamma_ratio = loggamma(gamma - 2*beta + 1) - loggamma(gamma - beta - 0.5)

    total = log_numer - log_denom + log_gamma_ratio
    return total.real

# def df(E, L, Mbh, gamma, beta):
#     num = L**(-2*beta) * E**(gamma - beta - 1.5)
#     den = np.sqrt(8*np.pi**3 * 2**(-2*beta)) * (G*Mbh)**(gamma-2*beta)
#     gams = Gamma(gamma - 2*beta + 1) / Gamma(gamma - beta - 0.5)
#     return num/den*gams

In [None]:
from math import exp, log

def func(v, r, Mbh, gamma):
    """Integrand ``f * v**2`` at scalar speed ``v`` and radius ``r``.

    The distribution function is evaluated with ``L = 1`` and ``beta = 0``.
    Returns 0 when either ``rv_to_E(r, v, Mbh)`` or ``v`` is non-positive.
    """
    E = rv_to_E(r, v, Mbh)
    if E <= 0 or v <= 0:
        return 0
    # Work in log space and exponentiate once at the end.
    log_integrand = log_df(E, 1., Mbh, gamma, 0) + 2*log(v)
    return exp(log_integrand)

# vs = np.linspace(0, 32, 1024)
# plt.plot(vs, [func(v, 1., Mbh, gamma) for v in vs])

# Integrate f * v^2 over speed on a log-spaced grid of radii.
rs = np.logspace(-3, 2, 128)
dens = []
# The loop variable is called `radius`, not `r`, so that it does not overwrite
# the module-level array `r` holding the noisy radii.
for radius in rs:
    # `func` returns 0 once E <= 0, i.e. for v >= sqrt(2*G*Mbh/radius), so the
    # integrand has compact support; integrating over that finite range gives
    # the same integral as [0, inf) without asking quad to handle an infinite
    # interval that is mostly zeros.
    v_max = np.sqrt(2 * G * Mbh / radius)
    val, _ = quad(func, a=0, b=v_max, args=(radius, Mbh, gamma))
    dens.append(val)

In [None]:
# Compare the density obtained by integrating the DF against the expected
# power law. The reference slope uses `gamma` so it tracks the configured value
# instead of repeating the literal 3.5.
plt.figure(figsize=(6,6))
plt.loglog(rs, dens, label='integrated DF')
plt.loglog(rs, rs**-gamma, label='r^-gamma')
plt.xlabel('r [mpc]')
plt.ylabel('density (unnormalized)')
plt.legend();

In [None]:
class Model:
    """Log-posterior over (Mbh, gamma, beta) given radii and speeds.

    Any parameter passed to the constructor as a non-None keyword is frozen at
    that value. The remaining parameters, in ``parameters`` order, make up the
    vector accepted by ``__call__`` (the signature emcee expects).
    """

    parameters = ['Mbh', 'gamma', 'beta']

    def __init__(self, r, v, **kw):
        """Store the data and record which parameters are frozen.

        Parameters
        ----------
        r, v : array-like
            Radii and speeds; their logarithms are taken here.
        **kw
            Optional fixed values keyed by parameter name. Names not in
            ``parameters`` and values equal to None are ignored.
        """
        self.frozen = {name: kw[name]
                       for name in self.parameters
                       if kw.get(name) is not None}

        self.r = r
        self.v = v
        # Additive term 2 ln r + 2 ln v, independent of the parameters, so it
        # is computed once up front.
        self._jac =  2*np.log(self.r) + 2*np.log(self.v)

    def pack(self, **kwargs):
        """Return an array with one entry per name in ``parameters``.

        Values given in ``kwargs`` take precedence over frozen values; a name
        found in neither yields None.
        """
        values = []
        for name in self.parameters:
            fallback = self.frozen.get(name)
            values.append(kwargs.get(name, fallback))
        return np.array(values)

    def unpack(self, vec):
        """Map a vector of free-parameter values to a full parameter dict.

        Frozen parameters take their stored value; the others consume
        ``vec`` left to right in ``parameters`` order.
        """
        pars = {}
        n_used = 0
        for name in self.parameters:
            if name not in self.frozen:
                pars[name] = vec[n_used]
                n_used += 1
                continue
            pars[name] = self.frozen[name]
        return pars

    def ln_likelihood(self, **pars):
        """Return the per-point log-likelihood (log DF plus the ``_jac`` term)."""
        Mbh = pars['Mbh']
        E = rv_to_E(self.r, self.v, Mbh)
        L = 1. # HACK

        ln_f = log_df(E, L, Mbh, pars['gamma'], pars['beta'])

        # HACK: uses the term precomputed from the stored r and v.
        ln_f += self._jac
        return ln_f

    def ln_prior(self, **pars):
        """Flat prior: 0 inside the box on Mbh and gamma, -inf outside.

        ``beta`` is not constrained here.
        """
        bounds = (('Mbh', 1E6, 1E7),
                  ('gamma', 0.5, 5.))
        for name, lower, upper in bounds:
            value = pars[name]
            if value < lower or value > upper:
                return -np.inf
        return 0.

    def ln_posterior(self, **pars):
        """Return prior plus summed likelihood, or -inf if either is non-finite."""
        lp = self.ln_prior(**pars)
        if not np.isfinite(lp):
            return -np.inf

        ll = self.ln_likelihood(**pars)
        if not np.all(np.isfinite(ll)):
            return -np.inf

        return lp + ll.sum()

    def __call__(self, p):
        """Evaluate the log-posterior at free-parameter vector ``p``."""
        return self.ln_posterior(**self.unpack(p))

In [None]:
# Build the model on the loaded (un-perturbed) radii and speeds, with beta
# frozen at 0 so that only Mbh and gamma remain free parameters.
model = Model(r=true_r, v=true_v, beta=0.)

In [None]:
# Scan the summed log-likelihood over gamma with Mbh held at the configured
# value (`Mbh`, rather than repeating the literal 4E6) and beta = 0.
gs = np.linspace(0.5, 4.5, 64)
lls = []
for g in gs:
    lls.append(model.ln_likelihood(Mbh=Mbh, gamma=g, beta=0.).sum())

plt.plot(gs, lls)
plt.axvline(gamma, color='r')  # configured gamma, for reference
plt.xlabel('gamma')
plt.ylabel('summed ln likelihood');

In [None]:
n_walkers = 32

# Initial walker positions. Columns follow the order of the model's free
# parameters: column 0 is Mbh, column 1 is gamma (beta is frozen in `model`).
# The array is sized from n_walkers so changing the walker count cannot leave
# p0 with a mismatched shape.
p0 = np.zeros((n_walkers, 2))
p0[:,0] = np.random.normal(5E6, 1E4, n_walkers)
p0[:,1] = np.random.uniform(3., 5, n_walkers)

sampler = emcee.EnsembleSampler(n_walkers, p0.shape[1], model)
# Burn-in: run 1000 steps, keep only the final positions, discard the chain.
pos,_,_ = sampler.run_mcmc(p0, 1000)
sampler.reset()
# Production run of 4000 steps starting from the burned-in positions.
_ = sampler.run_mcmc(pos, 4000)

In [None]:
# One figure per sampled dimension, overlaying the trace of every walker.
n_dim = sampler.chain.shape[-1]
for dim in range(n_dim):
    plt.figure()
    traces = sampler.chain[..., dim]
    for walker in traces:
        plt.plot(walker, marker='', drawstyle='steps', alpha=0.2)

---

## Now with Stan

In [None]:
# Noisy radii/speeds, their uncertainties, and the number of points, collected
# into one dict (presumably the data block for the Stan model — confirm).
# NOTE(review): `data_dict` is reassigned in a later cell (to the true_r /
# true_v values) before any sampling call, so this version is never fit.
data_dict = dict(r=r, r_err=r_err,
                 v=v, v_err=v_err,
                 N=n_data)

In [None]:
# Compile the Stan program stored in simple_model.stan.
sm = pystan.StanModel(file='simple_model.stan')

In [None]:
n_chains = 1
# init_dict = [dict(true_r=r,
#                   true_v=v,
#                   Mbh=true_Mbh,
# #                   phi0=50000.,
# #                   gamma=true_gamma
#                  )
#              for _ in range(n_chains)]

# Data handed to Stan: the un-perturbed radii and speeds plus their count.
data_dict = {
    'true_r': true_r[:n_data],
    'true_v': true_v[:n_data],
    'N': n_data,
}

# One initialization dict per chain, each starting at the configured values.
init_dict = [{'Mbh': Mbh, 'gamma': gamma} for _ in range(n_chains)]

In [None]:
# Short run (iter=32) with console output left visible; `fit` is overwritten
# by the longer run in the next cell.
# NOTE(review): presumably a smoke test that the model initializes — confirm.
fit = sm.sampling(data=data_dict, algorithm='HMC', iter=32, init=init_dict, 
                  chains=n_chains, n_jobs=1)

In [None]:
# Full run: same call as the short run but with iter=8192.
# NOTE(review): suppress_stdout_stderr comes from the local stderr_helper
# module (not shown); presumably it silences the sampler's console output.
with suppress_stdout_stderr():
    fit = sm.sampling(data=data_dict, algorithm='HMC', iter=8192, init=init_dict, 
                      chains=n_chains, n_jobs=1)

In [None]:
# Trace plot for Mbh; the return value is bound to `_` so the cell does not
# echo its repr.
_ = fit.traceplot('Mbh')

In [None]:
# Pull the posterior draws for the parameters to plot (currently Mbh only).
plot_pars = ['Mbh']# 'gamma']
samples = fit.extract(plot_pars)

# Histogram of the Mbh draws with the configured Mbh marked by a red line.
plt.figure()
plt.hist(samples['Mbh'], color='#666666') #, bins=np.logspace(4, 6, 32))
plt.axvline(Mbh, color='r')
# plt.xscale('log')

# NOTE(review): the commented-out blocks below refer to names (a_g,
# true_alpha, true_gamma) that are not defined in any cell shown here.
# plt.figure()
# plt.hist(samples['a_g'], color='#666666', bins=np.linspace(1, 30, 16))
# plt.axvline(true_alpha/true_gamma, color='r')

# plt.figure()
# plt.hist(samples['gamma'], color='#666666', bins=np.linspace(0, 1, 12))
# plt.axvline(true_gamma, color='r')