In [None]:
import os

# Third-party
from astropy.table import Table
import astropy.coordinates as coord
from astropy.constants import G as _G
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('apw-notebook')
%matplotlib inline
from scipy.special import loggamma
from gala.units import UnitSystem

import pystan
from stderr_helper import suppress_stdout_stderr

Make some fake data

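Tracer number density (a power law in radius):
$$
\nu(r) = \nu_0 \, r^{-\gamma}
$$
with the associated radial integral (e.g. for fixing the normalization $\nu_0$):
$$
\int \nu(r) \, r^2 \, {\rm d}r
$$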

The background potential is Keplerian (a point mass $M$), so the circular velocity is
$$
v_c(r) = \sqrt{\frac{G M}{r}}
$$

I'm going to work in an unusual unit system (milliparsec, year, solar mass, radian):

In [None]:
# Unit system for the whole notebook: milliparsec, year, solar mass, radian.
units = UnitSystem(u.mpc, u.yr, u.Msun, u.rad)

# Gravitational constant as a bare float expressed in `units`.
_G_in_units = _G.decompose(units)
G = _G_in_units.value

In [None]:
# Ground-truth parameters used to generate the fake data.
true_Mbh = 4E6  # Msun
true_gamma = 2.5  # power-law index of the tracer density

In [None]:
def potential(r, Mbh):
    """Return ``G*Mbh / r`` for radius ``r`` and central mass ``Mbh``.

    Note the sign: the value is positive for positive ``r``.
    ``G`` is the module-level gravitational constant.
    """
    mu = G * Mbh
    return mu / r

def rv_to_E(r, v, Mbh):
    """Return ``potential(r, Mbh) - v**2 / 2`` for radius ``r`` and speed ``v``.

    With this sign convention the result is positive when the
    ``potential`` term exceeds the kinetic term.
    """
    kinetic = 0.5 * v**2
    return potential(r, Mbh) - kinetic

def v_c(r, Mbh):
    """Circular velocity ``sqrt(G*Mbh / r)`` at radius ``r``."""
    v_circ_sq = G * Mbh / r
    return np.sqrt(v_circ_sq)

def log_df(E, Mbh, g):
    """Natural log of the power-law distribution function f(E).

    Evaluates::

        ln f = (g - 3/2) ln E - (1/2) ln(8 pi^3) - g ln(G Mbh)
               + ln Gamma(g + 1) - ln Gamma(g - 1/2)

    Parameters
    ----------
    E : scalar or array-like
        Energy in the sign convention of ``rv_to_E``; entries with
        ``E <= 0`` are assigned zero density.
    Mbh : float
        Central mass entering through ``G * Mbh``.
    g : float
        Power-law index.

    Returns
    -------
    numpy.ndarray
        Real part of ``ln f``, with at least one dimension. Entries where
        ``E <= 0`` are ``-inf``.
    """
    E = np.atleast_1d(E)
    unbound = E <= 0

    # Only take the log where it is defined. Evaluating np.log on E <= 0
    # emits RuntimeWarnings and produces NaN/-inf values that were being
    # overwritten afterwards anyway. The placeholder 1.0 is never returned.
    log_E = np.log(np.where(unbound, 1.0, E))

    term1 = (g - 3/2.)*log_E - 0.5*np.log(8*np.pi**3) - g*np.log(G * Mbh)
    term2 = loggamma(g + 1) - loggamma(g - 0.5)
    res = term1 + term2
    res[unbound] = -np.inf
    return res.real

In [None]:
from math import log, exp
import emcee

def log_f(p, Mbh, gamma):
    """Unnormalized log-density over ``(ln r, ln v)`` used as the MCMC target.

    Parameters
    ----------
    p : sequence of two floats
        ``(log_r, log_v)``, the natural logs of radius and speed.
    Mbh : float
        Central mass entering through ``G * Mbh``.
    gamma : float
        Power-law index; the energy term is ``E**(gamma - 1.5)``.

    Returns
    -------
    float
        ``(gamma - 1.5)*ln(E) + log_r + log_v``, or ``-inf`` when ``log_r``
        is outside ``[-6, 5]`` or the energy ``E = G*Mbh/r - v**2/2`` is
        not positive.
    """
    log_r, log_v = p

    # Enforce the hard window on ln(r) before doing any arithmetic.
    if log_r > 5 or log_r < -6:
        return -np.inf

    try:
        r = exp(log_r)
        v = exp(log_v)
        E = -0.5*v**2 + G*Mbh/r
    except OverflowError:
        # math.exp and float ** raise on overflow instead of returning inf.
        # Only an enormous speed can trigger this, which means E <= 0.
        return -np.inf

    # E == 0 must be rejected as well: math.log(0) raises ValueError.
    if E <= 0:
        return -np.inf

    # log_r + log_v is the Jacobian term log(r) + log(v) without the
    # exp -> log round-trip.
    return (gamma - 1.5)*log(E) + log_r + log_v # + 2*log(r) + 2*log(v)

In [None]:
n_walkers, n_dim, n_steps = 32, 2, 4000
sampler = emcee.EnsembleSampler(n_walkers, n_dim, log_f,
                                args=(true_Mbh, true_gamma))

# Start walkers with r, v drawn uniformly from [0.01, 1). Drawing from
# (0, 1) occasionally gives ln(r) < -6 (or ln(0) = -inf), which lies outside
# the support of log_f and makes the assertion below fail at random.
p0 = np.log(np.random.uniform(1e-2, 1, size=(n_walkers, n_dim)))
assert np.all(np.isfinite([log_f(p0[i], true_Mbh, true_gamma)
                           for i in range(len(p0))]))

# Burn-in: keep only the final walker positions, discard the chain.
pos,_,_ = sampler.run_mcmc(p0, n_steps)
sampler.reset()

# Production run continues from the burned-in positions. Restarting from
# p0 here would throw the burn-in away.
_ = sampler.run_mcmc(pos, n_steps)

In [None]:
# One figure per sampled dimension, overlaying the trace of every walker.
for dim in range(sampler.chain.shape[-1]):
    fig, ax = plt.subplots()
    traces = sampler.chain[..., dim]
    for trace in traces:
        ax.plot(trace, marker='', drawstyle='steps-mid', alpha=0.25)

In [None]:
# Keep every 64th step of each walker, stack the walkers end to end, and
# undo the log transform to get radii and speeds.
thinned = sampler.chain[:, ::64]
flat_log_rv = thinned.reshape(-1, thinned.shape[-1])
true_r, true_v = np.exp(flat_log_rv).T

In [None]:
# Histogram of the sampled energies on log-spaced bins.
fig, ax = plt.subplots()
ax.hist(rv_to_E(true_r, true_v, true_Mbh), bins=np.logspace(-1, 4, 32))
ax.set_xscale('log')

In [None]:
# Smallest sampled radius, converted from mpc to au.
u.Quantity(np.min(true_r), u.mpc).to(u.au)

In [None]:
# Largest sampled speed, converted from mpc/yr to km/s.
u.Quantity(np.max(true_v), u.mpc/u.yr).to(u.km/u.s)

In [None]:
# Side-by-side log-log histograms of the sampled radii and speeds.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

panels = [(true_r, np.logspace(-3, 1, 32)),
          (true_v, np.logspace(-3, 3, 32))]
for ax, (values, bin_edges) in zip(axes, panels):
    ax.hist(values, bins=bin_edges)
    ax.set_xscale('log')
    ax.set_yscale('log')

In [None]:
n_data = 32

# Constant uncertainty of 0.01 on every radius and speed.
r_err = 0.01 * np.ones(n_data)
v_err = 0.01 * np.ones(n_data)

# "Observed" values: the first n_data true values plus Gaussian noise.
# r is drawn before v so the random stream is consumed in the same order.
r = np.random.normal(loc=true_r[:n_data], scale=r_err)
v = np.random.normal(loc=true_v[:n_data], scale=v_err)

In [None]:
# Eyeball the noisy energies alongside the noisy radii and speeds.
E_noisy = rv_to_E(r, v, true_Mbh)
E_noisy, r, v

In [None]:
# Data handed to the Stan model: noisy measurements and their uncertainties.
data_dict = {
    'r': r,
    'r_err': r_err,
    'v': v,
    'v_err': v_err,
    'N': n_data,
}

In [None]:
# Compile the Stan program from the file next to this notebook.
sm = pystan.StanModel(file='simple_model.stan')

In [None]:
n_chains = 1

# HACK: pass the noise-free samples straight to Stan as data. This replaces
# the noisy `data_dict` / `n_data` defined in earlier cells.
n_data = 128
data_dict = {
    'true_r': true_r[:n_data],
    'true_v': true_v[:n_data],
    'N': n_data,
}

# One init dict per chain. Only Mbh is initialised; initialising
# true_r / true_v / phi0 / gamma was tried and is currently disabled.
init_dict = [{'Mbh': true_Mbh} for _ in range(n_chains)]

In [None]:
# Short run (32 iterations) as a quick check that sampling starts.
fit = sm.sampling(
    data=data_dict,
    init=init_dict,
    chains=n_chains,
    iter=32,
    algorithm='HMC',
    n_jobs=1,
)

In [None]:
# Full run (8192 iterations) with Stan's console output silenced.
with suppress_stdout_stderr():
    fit = sm.sampling(
        data=data_dict,
        init=init_dict,
        chains=n_chains,
        iter=8192,
        algorithm='HMC',
        n_jobs=1,
    )

In [None]:
# Trace plot for Mbh; the trailing semicolon hides the returned object's repr.
fit.traceplot('Mbh');

In [None]:
plot_pars = ['Mbh']  # 'gamma' is not extracted at the moment
samples = fit.extract(plot_pars)

# Posterior histogram of Mbh with the true value marked in red.
# Linear axes and default bins; log-spaced bins (np.logspace(4, 6, 32)) with
# a log x-axis were tried and are disabled.
# TODO: the analogous histograms for 'a_g' and 'gamma' are disabled.
fig, ax = plt.subplots()
ax.hist(samples['Mbh'], color='#666666')
ax.axvline(true_Mbh, color='r')