In [None]:
# Third-party
from astropy.table import Table
import astropy.coordinates as coord
import astropy.units as u
from astropy.constants import G
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import numpy as np
# plt.style.use('notebook.mplstyle')
plt.style.use('apw-notebook')
%matplotlib inline
import corner
import emcee
from scipy.integrate import quad
try:
    # scipy >= 0.19; scipy.misc.logsumexp was removed in scipy 1.3
    from scipy.special import logsumexp
except ImportError:
    from scipy.misc import logsumexp
import schwimmbad
import tqdm

In [None]:
# Spectroscopic table used below for radial velocities (columns 'Name', 'Vrad').
# NOTE(review): provenance presumably Brewer et al. -- TODO record the source.
brewer_path = '../data/brewer.csv'
brewer = Table.read(brewer_path)

In [None]:
# Display the spectroscopic table; its 'Name' and 'Vrad' columns are used to set the TGAS radial velocities
brewer

In [None]:
tgas = Table.read('../data/tgas.csv')

# HD identifiers of the two TGAS rows.
# NOTE(review): assumes tgas.csv lists HD 240430 first, then HD 240429 -- TODO confirm
tgas['hd_id'] = ['240430', '240429']

# Radial velocity of each star, looked up by name in the spectroscopic table
tgas['rv'] = [brewer['Vrad'][brewer['Name'] == star_name][0]
              for star_name in ('HD 240430', 'HD 240429')]
tgas['rv_error'] = 0.1 # km/s

In [None]:
def get_y_hat(row, names=('ra', 'dec', 'parallax', 'pmra', 'pmdec', 'rv'), units=None):
    """
    Extract the observed data vector of one source.

    Parameters
    ----------
    row : table row or mapping
        Must provide a value for every entry of ``names`` (e.g. an astropy Row).
        Values are interpreted in their native units: deg for ra/dec, mas for
        parallax, mas/yr for proper motions, km/s for rv.
    names : sequence of str, optional
        Columns to extract, in order. Defaults to the six astrometric and
        kinematic quantities ra, dec, parallax, pmra, pmdec, rv.
    units : sequence of astropy units, optional
        Output unit of each entry of ``names``; defaults to the native units.

    Returns
    -------
    y : numpy.ndarray of float, shape (len(names),)

    Raises
    ------
    ValueError
        If ``units`` is given with a different length than ``names``.
    """
    default_units = dict()
    default_units['ra'] = u.degree
    default_units['dec'] = u.degree
    default_units['parallax'] = u.mas
    default_units['pmra'] = u.mas/u.yr
    default_units['pmdec'] = u.mas/u.yr
    default_units['rv'] = u.km/u.s

    if units is None:
        units = [default_units[name] for name in names]
    if len(units) != len(names):
        raise ValueError("units must have the same length as names")

    # attach the native unit, then convert to the requested output unit
    return np.array([(row[name] * default_units[name]).to(unit).value
                     for name, unit in zip(names, units)], dtype=float)

In [None]:
def get_cov(row, names=('ra', 'dec', 'parallax', 'pmra', 'pmdec', 'rv'), units=None):
    """
    Build the observational covariance matrix of one source.

    The diagonal holds the squared ``<name>_error`` values. Off-diagonal terms
    are filled from ``<name1>_<name2>_corr`` columns (``name1`` before
    ``name2`` in ``names``) when the row has them; otherwise they stay zero.

    Parameters
    ----------
    row : table row or mapping
        Must provide ``<name>_error`` for every entry of ``names``; correlation
        columns are optional.
    names : sequence of str, optional
        Quantities to include, in order. Defaults to ra, dec, parallax, pmra,
        pmdec, rv.
    units : sequence of astropy units, optional
        Unit of the uncertainty of each quantity in the output. Defaults to the
        native error units (mas for ra/dec/parallax, mas/yr, km/s).

    Returns
    -------
    C : numpy.ndarray, shape (len(names), len(names))

    Raises
    ------
    ValueError
        If ``units`` is given with a different length than ``names``.
    """
    default_err_units = dict()
    default_err_units['ra'] = u.mas
    default_err_units['dec'] = u.mas
    default_err_units['parallax'] = u.mas
    default_err_units['pmra'] = u.mas/u.yr
    default_err_units['pmdec'] = u.mas/u.yr
    default_err_units['rv'] = u.km/u.s

    if units is None:
        units = [default_err_units[name] for name in names]
    if len(units) != len(names):
        raise ValueError("units must have the same length as names")

    # 1-sigma uncertainties, converted to the requested output units
    sigma = np.array([(row["{}_error".format(name)] * default_err_units[name]).to(unit).value
                      for name, unit in zip(names, units)], dtype=float)
    C = np.diag(sigma**2)

    # Column names present on the row: astropy Rows expose .colnames (note that
    # `x in row` on a Row iterates over *values*), plain mappings use keys().
    colnames = getattr(row, 'colnames', None)
    if colnames is None:
        colnames = row.keys()
    colnames = set(colnames)

    for i, name1 in enumerate(names):
        for j in range(i + 1, len(names)):
            corr_name = "{}_{}_corr".format(name1, names[j])
            if corr_name not in colnames: # skip pairs with no tabulated correlation
                continue

            # C[i,i] and C[j,j] are already in the output units, so the
            # covariance is just rho * sigma_i * sigma_j (no further conversion)
            C[i, j] = C[j, i] = row[corr_name] * np.sqrt(C[i, i] * C[j, j])

    return C

In [None]:
class ProbModel(object):
    """
    Minimal base class for a probabilistic model.

    Subclasses supply ``ln_prior(pars)`` and ``ln_likelihood(pars)``. Calling
    an instance returns the unnormalized log-posterior, so the instance itself
    can be handed to a sampler that expects a log-probability callable.
    """

    def ln_posterior(self, pars):
        """
        Unnormalized log-posterior: log-prior plus the summed log-likelihood.

        An infinite log-prior is returned as-is without evaluating the
        likelihood; a NaN total is mapped to -inf.
        """
        ln_prior_value = self.ln_prior(pars)

        # likelihood is irrelevant when the prior is infinite, so skip it
        if np.isinf(ln_prior_value):
            return ln_prior_value

        ln_post = ln_prior_value + self.ln_likelihood(pars).sum()
        return -np.inf if np.isnan(ln_post) else ln_post

    def __call__(self, pars):
        """Alias for ``ln_posterior``."""
        return self.ln_posterior(pars)

### Component 1: wide binary

The stars form a gravitationally bound wide binary, with their separation drawn from an assumed separation distribution.

### Component 2: co-moving pair

The stars are co-moving but not necessarily *bound*.

### Component 3: chance alignment

The two stars are drawn independently from the field population.

---

In [None]:
# Proper-motion conversion used by w_to_y / y_to_w:
#   mu [mas/yr] = kmspc_to_masyr * v [km/s] / d [pc]   (i.e. ~1000 / 4.74047)
kmspc_to_masyr = 210.94953
masyr_to_kmspc = 1. / kmspc_to_masyr

In [None]:
def w_to_y(w):
    """
    Convert Cartesian phase-space coordinates to observables.

    ``w`` is (x, y, z [pc], vx, vy, vz [km/s]). Returns
    (ra [rad, in [0, 2pi)], dec [rad], parallax [mas], pmra [mas/yr],
    pmdec [mas/yr], rv [km/s]). vx and vy are mapped directly onto the two
    proper-motion components and vz onto rv.
    """
    pos, vel = w[:3], w[3:]
    r = np.linalg.norm(pos) # pc

    ra = np.arctan2(pos[1], pos[0]) % (2*np.pi)
    dec = np.arcsin(pos[2] / r)
    parallax = 1000. / r
    pm_ra = vel[0] / r * kmspc_to_masyr
    pm_dec = vel[1] / r * kmspc_to_masyr

    return np.array([ra, dec, parallax, pm_ra, pm_dec, vel[2]])

def y_to_w(y):
    """
    Convert observables to Cartesian phase-space coordinates (inverse of w_to_y).

    ``y`` is (ra [rad], dec [rad], parallax [mas], pmra [mas/yr],
    pmdec [mas/yr], rv [km/s]); the result is (x, y, z [pc], vx, vy, vz [km/s]).
    All operations are element-wise, so ``y`` may have shape (6, ...).
    """
    ra, dec, parallax = y[0], y[1], y[2]
    pm_ra, pm_dec, rv = y[3], y[4], y[5]

    dist = 1000. / parallax # pc
    cos_dec = np.cos(dec)

    return np.array([dist * np.cos(ra) * cos_dec,
                     dist * np.sin(ra) * cos_dec,
                     dist * np.sin(dec),
                     pm_ra * dist / kmspc_to_masyr,
                     pm_dec * dist / kmspc_to_masyr,
                     rv])

def ln_gaussian(x, mu, var):
    """
    Log of a normal pdf with mean ``mu`` and *variance* ``var``, element-wise.
    """
    chi2 = (x - mu)**2 / var
    return -0.5 * (chi2 + np.log(2*np.pi*var))

In [None]:
class MixtureModel(ProbModel):
    """
    Three-component mixture model for a pair of stars:
    (1) a bound wide binary, (2) a co-moving but unbound pair, and
    (3) a chance alignment of two independent field stars.

    The 16 sampled parameters are described in ``unpack_pars``.
    """

    def __init__(self, tgas_rows, mass=[1., 1.]*u.Msun, mass_err=[0.01,0.01]*u.Msun,
                 V1=(0.1*u.km/u.s)**2, V2=(0.1*u.km/u.s)**2, V3=(25.*u.km/u.s)**2):
        """
        TODO: the right thing to do is to rotate (vra,vdec,vr) to (vx,vy,vz), but
            this is ok at the ~10 m/s level (given their small sky separation)

        TODO: update masses

        tgas_rows : table
            Exactly two rows (one per star) with ra, dec, parallax, pmra, pmdec,
            rv, the matching ``*_error`` columns and optionally ``*_corr`` columns.
        mass : quantity_like
            Assumed mass of each star.
        mass_err : quantity_like
            Standard deviation of the Gaussian mass term in the likelihood.
        V1 : quantity_like
            Velocity variance of bound stars, added to account for eccentricity.
        V2 : quantity_like
            Velocity variance of co-moving star pairs, i.e. something like
            the velocity difference at ~0.1 pc when they become unbound.
        V3 : quantity_like
            Velocity variance assumed for field population, i.e. disk stars.
        """
        assert len(tgas_rows) == 2, "MixtureModel needs exactly two stars"

        # observables (and their units) compared against the model for each star
        names = ['ra', 'dec', 'parallax', 'pmra', 'pmdec', 'rv']
        units = [u.rad, u.rad, u.mas, u.mas/u.yr, u.mas/u.yr, u.km/u.s]

        self.y_hats = []
        self.Covs = []
        self.Cinvs = []
        self._uvecs = []
        self._logdets = []

        for row in tgas_rows:
            y_hat = get_y_hat(row, names=names, units=units)
            Cov = get_cov(row, names=names, units=units)
            # log|2 pi C|: normalization of the 6D Gaussian likelihood
            _, log_det = np.linalg.slogdet(2*np.pi*Cov)

            # unit vector towards the star (not used by the methods below)
            rep = coord.UnitSphericalRepresentation(lon=row['ra']*u.deg,
                                                    lat=row['dec']*u.deg)
            uvec = rep.represent_as(coord.CartesianRepresentation).xyz.value

            self.y_hats.append(y_hat)
            self.Covs.append(Cov)
            self.Cinvs.append(np.linalg.inv(Cov))
            self._uvecs.append(uvec)
            self._logdets.append(log_det)

        # masses
        self.mass = mass.to(u.Msun).value
        self.mass_err = mass_err.to(u.Msun).value
        self._tot_mass = np.sum(self.mass)
        self._Mred = np.prod(self.mass) / self._tot_mass # reduced mass
        self._G = G.to(u.pc/u.Msun*u.km**2/u.s**2).value

        # cosine of the sky separation (not used by the methods below)
        coords = coord.SkyCoord(ra=tgas_rows['ra']*u.deg,
                                dec=tgas_rows['dec']*u.deg)
        self._cos_sep = np.cos(coords[1].separation(coords[0]))

        # some assumed hyperparameters: velocity variances in (km/s)^2
        kms_sq = (u.km/u.s)**2
        self.V1 = V1.to(kms_sq).value
        self.V2 = V2.to(kms_sq).value
        self.V3 = V3.to(kms_sq).value

    # ======
    # Priors
    # ======
    def ln_p_dx(self, dx, a_min=1E-4, a_max=1E1): # pc
        """
        Log-prior on the star1 -> star2 separation vector ``dx`` [pc] for the
        bound component: -inf outside [a_min, a_max], otherwise -3 ln|dx| plus
        a constant.
        """
        a_x = np.linalg.norm(dx)
        if a_x < a_min or a_x > a_max:
            return -np.inf

        ln_p = 0.

        # "radial" term
        # NOTE(review): -ln(a_max^4 - a_min^4) is not the normalization of a^-3
        # on [a_min, a_max]; constants matter when mixture components are
        # compared -- TODO check.
        C = -np.log(a_max**4 - a_min**4)
        ln_p += C - 3*np.log(a_x)

#         # angle terms
#         ln_p += -np.log(2*np.pi)
#         ln_p += -np.log(2.)

#         # Jacobian
#         theta = np.arcsin(dx[2] / a_x)
#         ln_p += 2*np.log(a_x) + np.log(np.sin(theta % np.pi))

        return ln_p

    def ln_p_dv(self, dv, dx): # km/s
        """
        Log-prior on the velocity difference ``dv`` [km/s] for the bound
        component: Gaussian in |dv| with variance V1, centered on the circular
        speed sqrt(G M_tot / |dx|).
        """
        a_x = np.linalg.norm(dx)
        a_v = np.linalg.norm(dv)

        ln_p = 0.

        # mean velocity computed from orbital separation
        mean_dv = np.sqrt(self._G * self._tot_mass / a_x) # km/s

        # "radial" term
        ln_p += ln_gaussian(a_v, mean_dv, self.V1)

#         # angle terms
#         ln_p += -np.log(2*np.pi)
#         ln_p += -np.log(2)

#         # Jacobian
#         theta = np.arcsin(dv[2] / a_v)
#         ln_p += 2*np.log(a_v) + np.log(np.sin(theta % np.pi))

        return ln_p

    def ln_p_x(self, x, x_max=1000., x_min=-1000.): #pc
        """Uniform log-prior on each position coordinate within [x_min, x_max] pc."""
        if np.any(x < x_min) or np.any(x > x_max):
            return -np.inf
        return -len(x) * np.log(x_max-x_min)

    def ln_p_v(self, v, mu=0., var=None): # km/s
        """
        Independent Gaussian log-prior on velocity components with mean ``mu``
        and variance ``var`` (default: the field variance V3).
        """
        if var is None:
            var = self.V3 # field
        return -0.5 * np.sum((v-mu)**2 / var + np.log(2*np.pi*var))

    def unpack_pars(self, pars):
        """
        Split the 16-element parameter vector.

        The sampled parameters are the Cartesian phase-space coordinates of
        star 1 and of star 2 (x, y, z in [pc]; vx, vy, vz in [km/s]), the two
        masses M1, M2 [Msun], and the mixture weights f1, f2. The third weight
        is 1 - (f1 + f2).

        Returns
        -------
        w1, w2 : 6D phase-space vectors of star 1 and star 2
        w_bary : barycenter, computed with the fixed masses ``self.mass``
            (not the sampled M1, M2)
        dw : w2 - w1, i.e. the separation is positive from star 1 to star 2
        [M1, M2] : sampled masses
        mix_weights : [f1, f2, 1 - (f1 + f2)]
        """
        (x1, y1, z1, vx1, vy1, vz1,
         x2, y2, z2, vx2, vy2, vz2,
         M1, M2, # masses
         f1, f2) = pars

        # construct 6D vectors of phase-space coordinates
        w1 = np.array([x1, y1, z1, vx1, vy1, vz1])
        w2 = np.array([x2, y2, z2, vx2, vy2, vz2])
        dw = w2 - w1

        # m2 / (m1 + m2), so that w1 + fac1*dw is the mass-weighted barycenter
        fac1 = self._Mred / self.mass[0]
        w_bary = w1 + fac1*dw

        mix_weights = [f1, f2, 1-(f1+f2)]

        return w1, w2, w_bary, dw, [M1, M2], mix_weights

    def ln_prior1(self, w1, w2, w_bary, dw):
        """
        Compute the log-prior for component 1 of the mixture model:
        wide binary
        """
        ln_p = 0.

        # prior terms
        ln_p += self.ln_p_dx(dw[:3])
        ln_p += self.ln_p_dv(dw[3:], dx=dw[:3])
        ln_p += self.ln_p_x(w_bary[:3])
        ln_p += self.ln_p_v(w_bary[3:])

        if not np.isfinite(ln_p):
            return -np.inf

        return ln_p

    def ln_prior2(self, w1, w2, w_bary, dw):
        """
        Compute the log-prior for component 2 of the mixture model:
        unbound but comoving
        """
        ln_p = 0.

        # prior terms
        ln_p += self.ln_p_x(w1[:3])
        ln_p += self.ln_p_x(w2[:3])
        ln_p += self.ln_p_v(w1[3:])
        ln_p += self.ln_p_v(w2[3:], w1[3:], self.V2)

        if not np.isfinite(ln_p):
            return -np.inf

        return ln_p

    def ln_prior3(self, w1, w2, w_bary, dw):
        """
        Compute the log-prior for component 3 of the mixture model:
        two independent draws from the field population
        """
        ln_p = 0.

        # prior terms
        ln_p += self.ln_p_x(w1[:3])
        ln_p += self.ln_p_x(w2[:3])
        ln_p += self.ln_p_v(w1[3:])
        ln_p += self.ln_p_v(w2[3:])

        if not np.isfinite(ln_p):
            return -np.inf

        return ln_p

    def ln_prior(self, pars):
        """
        Log of the mixture prior sum_k f_k p_k(w1, w2), with a uniform prior on
        the weights over f1, f2 >= 0, f1 + f2 <= 1. Always returns a scalar.
        """
        w1, w2, w_bary, dw, _, mix_weights = self.unpack_pars(pars)

        # uniform prior on weights
        if (mix_weights[0] < 0 or mix_weights[0] > 1. or
            mix_weights[1] < 0 or mix_weights[1] > 1 or
            (mix_weights[0]+mix_weights[1]) > 1):
            return -np.inf

        lnprobs = np.array([self.ln_prior1(w1, w2, w_bary, dw),
                            self.ln_prior2(w1, w2, w_bary, dw),
                            self.ln_prior3(w1, w2, w_bary, dw)], dtype=float)
        weights = np.asarray(mix_weights, dtype=float)

        if np.any(np.isnan(lnprobs)):
            return -np.inf

        # A component with zero density (ln p_k = -inf) or zero weight just drops
        # out of the mixture sum; the point is only rejected when *no* component
        # contributes (previously any -inf component vetoed the whole mixture).
        use = np.isfinite(lnprobs) & (weights > 0)
        if not np.any(use):
            return -np.inf

        return logsumexp(lnprobs[use], b=weights[use])

    def ln_likelihood(self, pars):
        """
        Gaussian log-likelihood of both stars' observed (ra, dec, parallax,
        pmra, pmdec, rv) given their model Cartesian coordinates, plus a
        Gaussian term (std. dev. ``mass_err``) on each sampled mass.
        """
        w1, w2, _, _, masses, _ = self.unpack_pars(pars)

        ln_l = 0.
        for i, w in enumerate((w1, w2)):
            # convert Cartesian to [ra, dec, parallax, mu_ra, mu_dec, rv]
            y = w_to_y(w)

            # difference in data space
            dy = self.y_hats[i] - y

            # both RAs lie in [0, 2pi): bring the residual into [-pi, pi] so a
            # star near RA = 0 is not penalized by a spurious 2pi offset
            if dy[0] > np.pi:
                dy[0] -= 2*np.pi
            elif dy[0] < -np.pi:
                dy[0] += 2*np.pi

            # kinematic data
            ln_l += -0.5 * self._logdets[i] - 0.5 * dy @ self.Cinvs[i] @ dy

            # mass: ln_gaussian takes a variance, so square the standard deviation
            ln_l += ln_gaussian(masses[i], self.mass[i], self.mass_err[i]**2)

        return ln_l

    def get_p0(self, size=1, scale=1E-3, rng=None):
        """
        Draw initial parameter vectors in a small ball around the data.

        Parameters
        ----------
        size : int or tuple of int
            Leading shape of the output, e.g. ``(ntemps, nwalkers)``.
        scale : float
            Ball size: observables are drawn with covariance ``scale**2 * Cov``,
            masses with std ``scale * mass_err``, weights with std ``scale``
            around (0.5, 0.1).
        rng : numpy random generator or RandomState, optional
            Source of randomness (default: the global ``np.random`` state).

        Returns
        -------
        p0 : numpy.ndarray of shape ``size + (16,)`` with length-1 axes squeezed.
        """
        if rng is None:
            rng = np.random

        try:
            size = tuple(size)
        except TypeError: # scalar size
            size = (size,)
        p0 = np.zeros(size+(16,))

        y_hats = [rng.multivariate_normal(yh, scale**2 * Cov, size=size)
                  for yh, Cov in zip(self.y_hats, self.Covs)]

        # y_to_w works element-wise on a (6, ...) array, hence the transposes
        p0[...,:6] = y_to_w(y_hats[0].T).T
        p0[...,6:12] = y_to_w(y_hats[1].T).T
        p0[...,12:14] = rng.normal(self.mass, self.mass_err*scale, size=size+(2,))
        p0[...,14:16] = rng.normal([0.5, 0.1], scale, size=size+(2,))

        return np.squeeze(p0)

In [None]:
# NOTE(review): adopted stellar masses -- TODO cite their source
model = MixtureModel(
    tgas,
    mass=[1., 1.07]*u.Msun,
    mass_err=[0.01, 0.01]*u.Msun,
    V1=(1E-3*u.km/u.s)**2,  # much tighter than the class default of (0.1 km/s)^2
)

In [None]:
# Sanity check: a single starting point should give a finite log-likelihood and log-prior
p0 = model.get_p0()
ndim = p0.size
(model.ln_likelihood(p0), model.ln_prior(p0))

## Parallel-tempered sampling (`emcee.PTSampler`) instead of the standard ensemble sampler

In [None]:
nwalkers = 128
ntemps = 32
# inverse temperatures, log-spaced from 1 (the target posterior) down to 1e-8
betas = np.logspace(0, -8, num=ntemps)

In [None]:
# number of burn-in and production steps for the PT run
n_burnin, n_mcmc = 1024, 4096

In [None]:
%%time

all_p0 = model.get_p0(size=(ntemps,nwalkers))
ndim = all_p0.shape[-1]
with schwimmbad.MultiPool() as pool:
    pt_sampler = emcee.PTSampler(ntemps, nwalkers, ndim, 
                                 model.ln_likelihood, model.ln_prior,
                                 betas=betas, pool=pool)
    
    for res in tqdm.tqdm_notebook(pt_sampler.sample(all_p0, iterations=n_burnin), desc='Burn-in'):
        pass
    
    pos,_,_ = res
    pt_sampler.reset()
    
    for res in tqdm.tqdm_notebook(pt_sampler.sample(pos, iterations=n_mcmc), desc='Production'):
        pass

In [None]:
n_plot_walkers = 32
alpha = 0.1
temp = 0 # temperature index to plot (index 0 corresponds to betas[0])

# parameter indices shown in each column: star 1 (x, y, z, vx, vy, vz),
# star 2 (same), then the masses M1, M2 and the mixture weights f1, f2
column_params = [range(0, 6), range(6, 12), range(12, 16)]
param_labels = ['$x_1$', '$y_1$', '$z_1$', '$v_{x,1}$', '$v_{y,1}$', '$v_{z,1}$',
                '$x_2$', '$y_2$', '$z_2$', '$v_{x,2}$', '$v_{y,2}$', '$v_{z,2}$',
                '$M_1$', '$M_2$', '$f_1$', '$f_2$']

fig,axes = plt.subplots(6, 3, figsize=(15,15), sharex=True)

for col, param_idx in enumerate(column_params):
    for row, k in enumerate(param_idx):
        # chain[temp, walkers, steps, k].T has shape (steps, walkers): one line per walker
        axes[row, col].plot(pt_sampler.chain[temp, :n_plot_walkers, :, k].T, marker='',
                            alpha=alpha, drawstyle='steps-mid', color='k')
        axes[row, col].set_ylabel(param_labels[k])

    # hide every panel this column does not need
    for row in range(len(param_idx), axes.shape[0]):
        axes[row, col].set_visible(False)

    # with sharex only the bottom row shows tick labels; restore them on the
    # last visible panel of the column
    last_row = len(param_idx) - 1
    axes[last_row, col].tick_params(labelbottom=True)
    axes[last_row, col].set_xlabel('step')

fig.tight_layout()

# fig.savefig('../plots/trace_pt.png')

In [None]:
# temperature-0 samples thinned by 8, with all walkers stacked into one 2D array
thinned = pt_sampler.chain[0, :, ::8]
flatchain = thinned.reshape(-1, thinned.shape[-1])
# append the implied third mixture weight f3 = 1 - (f1 + f2) as a final column
f3 = 1-(flatchain[:,-2]+flatchain[:,-1])
flatchain = np.column_stack((flatchain, f3))

In [None]:
# Labels for the columns of `flatchain`: phase space of each star, the two
# masses, and the three mixture weights (the last one is f3 = 1 - f1 - f2)
corner_labels = ['$x_1$ [pc]', '$y_1$ [pc]', '$z_1$ [pc]',
                 '$v_{x,1}$ [km/s]', '$v_{y,1}$ [km/s]', '$v_{z,1}$ [km/s]',
                 '$x_2$ [pc]', '$y_2$ [pc]', '$z_2$ [pc]',
                 '$v_{x,2}$ [km/s]', '$v_{y,2}$ [km/s]', '$v_{z,2}$ [km/s]',
                 r'$M_1$ [M$_\odot$]', r'$M_2$ [M$_\odot$]',
                 '$f_1$ (bound)', '$f_2$ (comoving)', '$f_3$ (field)']
assert len(corner_labels) == flatchain.shape[1], "corner labels out of sync with flatchain columns"

fig = corner.corner(flatchain, bins=64, plot_datapoints=False, labels=corner_labels)
# fig.savefig('../plots/corner.png')

In [None]:
# 3D distance between the two stars [pc] for each posterior sample
dx = np.linalg.norm(flatchain[:,6:9]-flatchain[:,0:3], axis=-1)

fig, ax = plt.subplots()
# log-spaced bins over 1e-3 -- 10 pc; samples outside this range are not shown
ax.hist(dx, bins=np.logspace(-3, 1, 16))
ax.set_xscale('log')
# NOTE(review): reference separation of 0.6 pc; its origin is not documented here -- TODO
ax.axvline(0.6)
ax.set_xlabel('separation [pc]')
ax.set_ylabel('number of samples');

In [None]:
# np.save('pt_chain_temp0.npy', pt_sampler.chain[0])

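## Standard ensemble sampler

For comparison, the same model can be sampled with `emcee.EnsembleSampler`; the cells below are currently commented out.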

In [None]:
# Settings for the (commented-out) EnsembleSampler run below.
# NOTE: these overwrite the PTSampler values of nwalkers / n_burnin / n_mcmc.
n_burnin, n_mcmc = 512, 1024
nwalkers = 256

In [None]:
# %%time

# all_p0 = model.get_p0(size=nwalkers)
# ndim = all_p0.shape[1]
# with schwimmbad.MultiPool() as pool:
#     sampler = emcee.EnsembleSampler(nwalkers, ndim, model, pool=pool)
    
#     for res in tqdm.tqdm_notebook(sampler.sample(all_p0, iterations=n_burnin), desc='Burn-in'):
#         pass
    
#     pos,_,_ = res
#     sampler.reset()
    
#     for res in tqdm.tqdm_notebook(sampler.sample(pos, iterations=n_mcmc), desc='Production'):
#         pass

In [None]:
# n_plot_walkers = 128
# alpha = 0.1

# fig,axes = plt.subplots(6, 3, figsize=(15,15), sharex=True)

# for k in range(0,6):
#     for j in range(n_plot_walkers):
#         axes[k,0].plot(sampler.chain[j,:,k], marker='', alpha=alpha, 
#                        drawstyle='steps-mid', color='k')
        
# for k in range(0,6):
#     for j in range(n_plot_walkers):
#         axes[k,1].plot(sampler.chain[j,:,k+6], marker='', alpha=alpha, 
#                        drawstyle='steps-mid', color='k')
        
# for k in range(0,4):
#     for j in range(n_plot_walkers):
#         axes[k,2].plot(sampler.chain[j,:,k+12], marker='', alpha=alpha, 
#                        drawstyle='steps-mid', color='k')

# fig.tight_layout()

# axes[-1,-1].set_visible(False)

# fig.savefig('../plots/trace.png')

In [None]:
# # flatchain = sampler.flatchain[:,[0,1,2,3,8,9,10]]
# flatchain = sampler.flatchain
# flatchain = np.vstack((flatchain.T, 1-(flatchain[:,-2]+flatchain[:,-1]))).T

# fig = corner.corner(flatchain, bins=64, plot_datapoints=False,
#                     labels=['$r_1$ [pc]', r'$v_{\alpha,1}$ [km/s]', r'$v_{\delta,1}$ [km/s]', '$v_{r,1}$ [km/s]', 
#                             '$r_2$ [pc]', r'$v_{\alpha,2}$ [km/s]', r'$v_{\delta,2}$ [km/s]', '$v_{r,2}$ [km/s]', 
#                             r'$\ln a/{\rm pc}$', '$f_1$ (bound)', '$f_2$ (comoving)', '$f_3$ (field)'])
# del fig