In [None]:
from os import path
import sys
# if '/mnt/home/apricewhelan/projects/stellarstreams/' not in sys.path:
#     sys.path.append('/mnt/home/apricewhelan/projects/stellarstreams/')
if '/Users/adrian/projects/stellarstreams/' not in sys.path:
    sys.path.append('/Users/adrian/projects/stellarstreams/')

# Third-party
import astropy.coordinates as coord
from astropy.table import Table, vstack
from astropy.io import fits, ascii
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
from scipy.optimize import minimize

import emcee
from pyia import GaiaData
import schwimmbad

import gala.coordinates as gc
import gala.dynamics as gd
from gala.dynamics import mockstream
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic
from gala.mpl_style import center_emph

from potential import default_mw

In [None]:
# Potential model used for the orbit integrations and stream generation below
# (`default_mw` comes from the project-local `potential` module).
mw = default_mw
# Galactocentric frame with the Galactic-center distance set to 8.1 kpc; no
# other frame parameter is overridden here.
galcen_frame = coord.Galactocentric(galcen_distance=8.1*u.kpc)

In [None]:
# Inverse-variance weighted mean (pmra, pmdec) of the rows in the
# Odenkirchen2002_gaia.csv table whose two proper-motion components are both
# negative. The cell displays the pair of weighted means.
tbly = Table.read('/Users/adrian/data/streams/Pal5/Odenkirchen2002_gaia.csv')
both_negative = (tbly['pmra'] < 0) & (tbly['pmdec'] < 0)
tbly = tbly[both_negative]

pmra_var = tbly['pmra_error']**2
pmdec_var = tbly['pmdec_error']**2

pmra_mean = np.sum(tbly['pmra'] / pmra_var) / np.sum(1 / pmra_var)
pmdec_mean = np.sum(tbly['pmdec'] / pmdec_var) / np.sum(1 / pmdec_var)
(pmra_mean, pmdec_mean)

In [None]:
# Load the Odenkirchen2009_gaia.csv table. Stacking it with the 2002 table is
# currently disabled (commented-out lines), so `t` is just the 2009 table.
# t1 = Table.read('/Users/adrian/data/streams/Pal5/Odenkirchen2002_gaia.csv')
t2 = Table.read('/Users/adrian/data/streams/Pal5/Odenkirchen2009_gaia.csv')
# t = vstack((t1, t2))
t = t2
# Wrap the table in a pyia GaiaData object; its helper methods and column
# attributes are used in the following cells.
g = GaiaData(t)

In [None]:
# Build a coordinate object from the table with `distance=False`, attaching
# the `vr_a` column as the radial velocity in km/s ...
c_icrs = g.get_skycoord(distance=False, radial_velocity=g.vr_a * u.km/u.s)
# ... and transform it to the Pal5PriceWhelan18 frame (phi1, phi2 and the
# matching proper-motion components used throughout the notebook).
c = c_icrs.transform_to(gc.Pal5PriceWhelan18)

In [None]:
# Per-star covariance matrices from pyia. The [3:5, 3:5] sub-block is what gets
# passed on as the proper-motion covariance.
# NOTE(review): presumably indices 3 and 4 are pmra/pmdec -- confirm against
# the ordering documented for pyia's `get_cov`.
C_icrs = g.get_cov()
# Transform that 2x2 block to the Pal5PriceWhelan18 frame
C = gc.transform_pm_cov(c_icrs, C_icrs[:, 3:5, 3:5], gc.Pal5PriceWhelan18)
# Per-star uncertainties: square roots of the two diagonal elements
pm1_err = np.sqrt(C[:, 0, 0])
pm2_err = np.sqrt(C[:, 1, 1])

In [None]:
# `pal5_c` is imported from the project-local `coordinates` module, kept under
# the name `pal5_icrs`, and then the name `pal5_c` is rebound to the same
# coordinate transformed to the Pal5PriceWhelan18 frame. Later cells rely on
# the rebound (stream-frame) `pal5_c`.
# NOTE(review): the name `pal5_icrs` suggests the imported object is in ICRS;
# confirm in `coordinates.py`.
from coordinates import pal5_c
pal5_icrs = pal5_c
pal5_c = pal5_icrs.transform_to(gc.Pal5PriceWhelan18)

In [None]:
# --- Selection masks ---------------------------------------------------------
# Two (phi1, radial velocity) regions that are rejected (negated in `mask`):
mask2 = (((c.phi1 < -2*u.deg) & (c.radial_velocity < -62*u.km/u.s)) |
         ((c.phi1 > -1*u.deg) & (c.radial_velocity > -56*u.km/u.s)))

# the APW by-eye mask of hackiness (also negated in `mask`)
mask3 = ((c.phi1 < -5*u.deg) | (c.pm_phi1_cosphi2 < 3.55*u.mas/u.yr))

# Final selection: |pm_phi2| < 2 mas/yr, outside both rejection regions,
# -80 < radial velocity < -40 km/s and pm_phi1_cosphi2 < 5 mas/yr.
mask = ((np.abs(c.pm_phi2) < 2*u.mas/u.yr) & np.logical_not(mask2) & np.logical_not(mask3) &
        (c.radial_velocity < -40*u.km/u.s) & (c.radial_velocity > -80*u.km/u.s) &
        (c.pm_phi1_cosphi2 < 5*u.mas/u.yr))

# Shared errorbar style (also reused by later plotting cells)
style = dict(marker='o', color='k', ls='none', ecolor='#aaaaaa')

# --- Overview figure: one panel per component, all against phi1 --------------
# Each entry: (y label, per-star values, per-star errors or None,
#              value for `pal5_c` (red marker), y limits).
# Labels take the unit from the plotted quantity itself, so they cannot drift
# from what `.value` actually returns.
panels = [
    ('phi2 [deg]',
     c.phi2.degree, None,
     pal5_c.phi2.degree, (-1, 1)),
    ('pm_phi1_cosphi2 [{}]'.format(c.pm_phi1_cosphi2.unit),
     c.pm_phi1_cosphi2.value, pm1_err,
     pal5_c.pm_phi1_cosphi2.value, (0, 8)),
    ('pm_phi2 [{}]'.format(c.pm_phi2.unit),
     c.pm_phi2.value, pm2_err,
     pal5_c.pm_phi2.value, (-4, 4)),
    # radial-velocity error: catalog error with 1 added in quadrature
    ('radial_velocity [{}]'.format(c.radial_velocity.unit),
     c.radial_velocity.value, np.sqrt(g.vr_a_err**2 + 1**2),
     pal5_c.radial_velocity.value, (-80, -30)),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(10, 12),
                         sharex=True)

for ax, (ylabel, values, errors, cluster_value, ylim) in zip(axes, panels):
    ax.errorbar(c.phi1.value[mask], values[mask],
                yerr=None if errors is None else errors[mask],
                **style)
    ax.scatter(pal5_c.phi1.degree, cluster_value, zorder=10, color='tab:red')
    ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)

axes[-1].set_xlabel('phi1 [{}]'.format(c.phi1.unit));

In [None]:
# Phase-space position of `pal5_c` in the Galactocentric frame; used as the
# initial condition for the orbit integration in the next cell.
w0 = gd.PhaseSpacePosition(pal5_c.transform_to(galcen_frame).data)

In [None]:
# Integrate the orbit with a negative time step for 6000 steps (i.e. backwards
# from `w0`).
# NOTE(review): `dt` carries no explicit unit -- presumably it is interpreted
# in the potential's unit system; confirm.
orbit = mw.integrate_orbit(w0, dt=-0.5, n_steps=6000)
# Generate a mock stream along the time-reversed orbit (`orbit[::-1]`) for a
# 2e4 Msun progenitor, with `release_every=16`.
stream = gd.mockstream.fardal_stream(mw, orbit[::-1], release_every=16, prog_mass=2e4*u.Msun)
# Express the stream particles in the Pal5PriceWhelan18 frame, using the same
# Galactocentric frame as the rest of the notebook.
stream_c = stream.to_coord_frame(gc.Pal5PriceWhelan18, galactocentric_frame=galcen_frame)

In [None]:
# Overlay the mock-stream particles on the masked data, one panel per
# component, all against phi1.
# Each entry: (y label, per-star values, per-star errors or None,
#              value for `pal5_c` (red marker), mock-stream values, y limits).
# Labels take the unit from the plotted quantity itself.
stream_style = dict(marker='o', ls='none', ms=1.5, color='#666666', alpha=0.5)

panels = [
    ('phi2 [deg]',
     c.phi2.degree, None,
     pal5_c.phi2.degree, stream_c.phi2.degree, (-1, 1)),
    ('pm_phi1_cosphi2 [{}]'.format(c.pm_phi1_cosphi2.unit),
     c.pm_phi1_cosphi2.value, pm1_err,
     pal5_c.pm_phi1_cosphi2.value, stream_c.pm_phi1_cosphi2.value, (2.5, 5)),
    ('pm_phi2 [{}]'.format(c.pm_phi2.unit),
     c.pm_phi2.value, pm2_err,
     pal5_c.pm_phi2.value, stream_c.pm_phi2.value, (-1, 2)),
    # radial-velocity error: catalog error with 1 added in quadrature
    ('radial_velocity [{}]'.format(c.radial_velocity.unit),
     c.radial_velocity.value, np.sqrt(g.vr_a_err**2 + 1**2),
     pal5_c.radial_velocity.value, stream_c.radial_velocity.value, (-80, -30)),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(10, 12),
                         sharex=True)

for ax, (ylabel, values, errors, cluster_value, stream_values, ylim) in zip(axes, panels):
    ax.errorbar(c.phi1.value[mask], values[mask],
                yerr=None if errors is None else errors[mask],
                **style)
    ax.scatter(pal5_c.phi1.degree, cluster_value, zorder=10, color='tab:red')
    ax.plot(stream_c.phi1.degree, stream_values, **stream_style)
    ax.set_ylim(*ylim)
    ax.set_ylabel(ylabel)

# x axis is shared, so setting the range on one panel applies to all
axes[0].set_xlim(-10, 10)
axes[-1].set_xlabel('phi1 [{}]'.format(c.phi1.unit));

In [None]:
# Assemble the per-star table in the Pal5PriceWhelan18 frame. Every component
# gets a value column plus a matching `<name>_ivar` inverse-variance column.
data = Table()

# Sky positions: the same fixed 1 mas uncertainty (converted to deg) for all rows
data['phi1'] = c.phi1
data['phi1_ivar'] = 1 / (1*u.mas).to(u.deg)**2

data['phi2'] = c.phi2
data['phi2_ivar'] = 1 / (1*u.mas).to(u.deg)**2

# No per-star distance is used: every row gets the `pal5_c` distance with a
# 1 kpc uncertainty.
# NOTE(review): `.value` is re-labelled as kpc here -- confirm that
# `pal5_c.distance` is actually stored in kpc.
data['distance'] = np.repeat(pal5_c.distance.value, len(c)) * u.kpc
data['distance_ivar'] = 1 / (1.*u.kpc)**2

# Proper motions, with the frame-transformed uncertainties `pm1_err`/`pm2_err`
# labelled as mas/yr
data['pm_phi1_cosphi2'] = c.pm_phi1_cosphi2
data['pm_phi1_cosphi2_ivar'] = 1 / (pm1_err*u.mas/u.yr)**2 

data['pm_phi2'] = c.pm_phi2
data['pm_phi2_ivar'] = 1 / (pm2_err*u.mas/u.yr)**2

# Radial velocities: variance is `vr_a_err`**2 + 1, labelled as (km/s)^2
data['radial_velocity'] = c.radial_velocity
data['radial_velocity_ivar'] = 1 / ((t['vr_a_err']**2 + 1) * (u.km/u.s)**2)

# Replace masked entries by 0 and keep only the rows selected by `mask`
od_data = data.filled(fill_value=0)[mask]

### Add in the on-sky stream track fit points:

In [None]:
# Hard-coded (phi1, phi2) sky-track points on the phi1 < 0 side, stored as
# `trail_data`. Positions get a 1 mas uncertainty; every other component is a
# zero-valued placeholder with zero inverse variance (i.e. no information).
trail_phi1 = [-13.5, -12., -9.75, -7.5, -6.75, -5.25, -3.75, -3., -0.75] * u.deg
trail_phi2 = [1.21, 0.93344847, 0.44480698, 0.1561218, 0.04700351,
              -0.24184338, -0.24479993, -0.3153449, -0.23257328] * u.deg

data = Table()

data['phi1'] = trail_phi1
data['phi1_ivar'] = 1 / (1*u.mas).to(u.deg)**2

data['phi2'] = trail_phi2
data['phi2_ivar'] = 1 / (1*u.mas).to(u.deg)**2

n_points = len(data['phi1'])
for comp in ('distance', 'pm_phi1_cosphi2', 'pm_phi2', 'radial_velocity'):
    data[comp] = np.full(n_points, 0)
    data[comp + '_ivar'] = np.full(n_points, 0)

trail_data = data

In [None]:
# Hard-coded (phi1, phi2) sky-track points on the phi1 > 0 side, stored as
# `lead_data`. Positions get a 1 mas uncertainty; every other component is a
# zero-valued placeholder with zero inverse variance (i.e. no information).
lead_phi1 = [1.15, 3.4, 4.15, 5.65, 6.4, 7.9] * u.deg
lead_phi2 = [0.29975416, 0.64209922, 0.86378061,
             1.37319047, 1.59490276, 2.35329475] * u.deg

data = Table()

data['phi1'] = lead_phi1
data['phi1_ivar'] = 1 / (1*u.mas).to(u.deg)**2

data['phi2'] = lead_phi2
data['phi2_ivar'] = 1 / (1*u.mas).to(u.deg)**2

n_points = len(data['phi1'])
for comp in ('distance', 'pm_phi1_cosphi2', 'pm_phi2', 'radial_velocity'):
    data[comp] = np.full(n_points, 0)
    data[comp + '_ivar'] = np.full(n_points, 0)

lead_data = data

In [None]:
# Table that is actually fit: only the hard-coded sky-track points are stacked.
# The variant that also includes the per-star `od_data` rows is disabled.
# data = vstack((od_data, trail_data, lead_data))
data = vstack((trail_data, lead_data))

---

In [None]:
from gala.dynamics.mockstream import fardal_stream

In [None]:
from scipy.interpolate import InterpolatedUnivariateSpline
from scipy.stats import binned_statistic

def get_stream_track(stream_c,
                     phi1_lim=[-180, 180]*u.deg,
                     phi1_binsize=1*u.deg,
                     units=None):
    """Bin stream particles in longitude and spline the per-bin mean and std.

    Parameters
    ----------
    stream_c
        Coordinate object for the stream particles. It must provide
        ``get_representation_component_names()`` (positions and, with ``'s'``,
        velocities), ``.spherical.lon`` and one attribute per component name.
    phi1_lim : angle-like quantity, length 2
        Lower and upper edge of the longitude range that is binned.
    phi1_binsize : angle-like quantity
        Width of the longitude bins.
    units : dict, optional
        Maps component names to the unit each component is converted to
        before binning. Components missing from the dict are used in their
        native unit. The dict passed in is NOT modified.

    Returns
    -------
    mean_tracks, std_tracks : dict, dict
        For every component except the first (longitude) one, an
        ``InterpolatedUnivariateSpline`` through the per-bin mean / standard
        deviation, as a function of the bin-center longitude (in the
        ``'phi1'`` unit). Bins whose statistic is not finite are skipped.
    """
    # All position and velocity component names; the first is the longitude.
    component_names = (
        list(stream_c.get_representation_component_names().keys()) +
        list(stream_c.get_representation_component_names('s').keys()))
    lon_name = component_names[0]

    # Work on a private copy: the caller's dict (typically reused across many
    # likelihood evaluations) must not be written to.
    units = dict() if units is None else dict(units)
    if 'phi1' not in units:
        units['phi1'] = getattr(stream_c, lon_name).unit

    phi1 = stream_c.spherical.lon.wrap_at(180*u.deg).to_value(units['phi1'])
    phi1_lim = phi1_lim.to_value(units['phi1'])
    phi1_binsize = phi1_binsize.to_value(units['phi1'])

    # The small offset makes the upper limit itself eligible as a bin edge.
    phi1_bins = np.arange(phi1_lim[0], phi1_lim[1]+1e-8, phi1_binsize)
    # HACK:
    #phi1_bins = np.concatenate((np.arange(phi1_lim[0], -1, phi1_binsize),
    #                            np.arange(-1, 1, phi1_binsize/8),
    #                            np.arange(1, phi1_lim[1], phi1_binsize)))
    phi1_binc = 0.5 * (phi1_bins[:-1] + phi1_bins[1:])

    def _binned_spline(values, statistic):
        # Per-bin statistic -> spline through the bins where it is finite
        # (empty bins give NaN and are dropped).
        binned = binned_statistic(phi1, values,
                                  bins=phi1_bins, statistic=statistic).statistic
        finite = np.isfinite(binned)
        return InterpolatedUnivariateSpline(phi1_binc[finite], binned[finite])

    mean_tracks = dict()
    std_tracks = dict()

    for k in component_names[1:]:
        val = getattr(stream_c, k)
        if k in units:
            val = val.to_value(units[k])
        else:
            val = val.value

        mean_tracks[k] = _binned_spline(val, 'mean')
        std_tracks[k] = _binned_spline(val, 'std')

    return mean_tracks, std_tracks

In [None]:
def ln_normal(x, mu, std):
    """Log-density of a normal distribution with mean `mu` and standard
    deviation `std`, evaluated at `x`."""
    quadratic = -0.5 * (x - mu)**2 / std**2
    half_ln_2pi = 0.5 * np.log(2 * np.pi)
    return quadratic - half_ln_2pi - np.log(std)

def ln_normal_ivar(x, mu, ivar):
    """Log-density of a normal distribution with mean `mu`, parametrized by
    the inverse variance `ivar`, evaluated at `x`."""
    quadratic = -0.5 * (x - mu)**2 * ivar
    half_ln_2pi = 0.5 * np.log(2 * np.pi)
    return quadratic - half_ln_2pi + 0.5 * np.log(ivar)

def get_ivar(ivar, extra_var):
    """Inverse variance after adding `extra_var` to the variance 1/`ivar`.

    Written as ivar / (1 + extra_var * ivar) so that ivar == 0 (no
    information) maps to 0 without dividing by zero.
    """
    inflation = 1 + extra_var * ivar
    return ivar / inflation

def _plot_likelihood_diagnostics(stream_c, tracks, data, frame_comp_names,
                                 extra_var, phi1_lim):
    """Draw the two 5-panel diagnostic figures used by `ln_likelihood`.

    Figure 1 shows, per component, the non-zero data values, the mock-stream
    particles and the mean track. Figure 2 shows data-minus-track residuals
    for the points with positive inverse variance.
    """
    comp_names = frame_comp_names[1:]  # skip phi1 (it is the x axis)
    grid = np.linspace(phi1_lim[0].value, phi1_lim[1].value, 1024)

    # -- data, stream particles and mean track --
    fig, axes = plt.subplots(5, 1, figsize=(8, 12),
                             sharex=True)

    for i, name in enumerate(comp_names):
        ax = axes[i]

        # zero-valued entries are placeholders, not measurements
        has_value = data[name] != 0
        ax.plot(data['phi1'][has_value], data[name][has_value],
                marker='o', ls='none', color='k', ms=4)

        ax.plot(stream_c.phi1.wrap_at(180*u.deg).degree,
                getattr(stream_c, name).value,
                marker='o', ls='none', color='tab:blue', ms=2, alpha=0.4, zorder=-100)

        ax.plot(grid, tracks[name](grid), marker='', color='tab:orange', alpha=0.5)

        ax.set_ylabel(name, fontsize=12)

    ax.set_xlim(phi1_lim.value)
    for ax, ylim in zip(axes, [(-1.5, 3), (20, 25), (2, 5.5), (-1, 2), (-75, -20)]):
        ax.set_ylim(*ylim)
    fig.set_facecolor('w')

    # -- residuals --
    fig, axes = plt.subplots(5, 1, figsize=(8, 12),
                             sharex=True)

    for i, name in enumerate(comp_names):
        ax = axes[i]

        ivar = get_ivar(data[name+'_ivar'],
                        extra_var[name])
        measured = ivar > 0.
        ax.errorbar(data['phi1'][measured],
                    data[name][measured] - tracks[name](data['phi1'][measured]),
                    yerr=1/np.sqrt(ivar[measured]),
                    marker='o', ls='none', color='k', ecolor='#aaaaaa')
        ax.axhline(0.)
        ax.set_ylabel(name, fontsize=12)

    ax.set_xlim(phi1_lim.value)
    for ax, ylim in zip(axes, [(-1, 1), (-4, 4), (-2, 2), (-2, 2), (-10, 10)]):
        ax.set_ylim(*ylim)
    fig.set_facecolor('w')


def ln_likelihood(p, phi1, pot, data, data_units, frame_comp_names, extra_var, plot=False):
    """Log-likelihood of the data given a mock stream generated from `p`.

    Parameters
    ----------
    p : sequence of 6 floats
        ``(phi2, distance, pm_phi1_cosphi2, pm_phi2, radial_velocity, lnM)``
        for the progenitor. ``phi2`` is in ``data_units['phi2']``; distance,
        proper motions and radial velocity are taken as kpc, mas/yr and km/s;
        ``lnM`` is the natural log of the progenitor mass in Msun.
    phi1
        Fixed phi1 of the progenitor, passed straight to the
        ``Pal5PriceWhelan18`` frame constructor.
    pot
        Potential used for the orbit integration and stream generation.
    data
        Table with a ``'phi1'`` column and, for each fitted component, a
        ``<name>`` and ``<name>_ivar`` column.
    data_units : dict
        Units per component, forwarded to `get_stream_track`.
    frame_comp_names : list of str
        Component names of the frame, phi1 first (only used for plotting).
    extra_var : dict
        Extra variance per component (plain numbers) added to the data
        variance; the stream's own per-bin variance is added on top.
    plot : bool, optional
        If True, also draw the diagnostic figures.

    Returns
    -------
    float
        The summed log-likelihood, or ``-inf`` when the progenitor mass is
        outside (8e3, 4e5) Msun.

    Notes
    -----
    Reads the module-level `galcen_frame`. Earlier variants that also fitted
    the solar velocity and the halo mass/flattening are disabled; `p` must
    have exactly six entries.
    """
    phi2, dist, pm1, pm2, rv, lnM = p

    # Hard prior on the progenitor mass
    M_pal5 = np.exp(lnM)
    if not 8e3 < M_pal5 < 4e5:
        return -np.inf

    c = gc.Pal5PriceWhelan18(phi1=phi1, phi2=phi2*data_units['phi2'],
                             distance=dist*u.kpc,
                             pm_phi1_cosphi2=pm1*u.mas/u.yr,
                             pm_phi2=pm2*u.mas/u.yr,
                             radial_velocity=rv*u.km/u.s)
    w0 = gd.PhaseSpacePosition(c.transform_to(galcen_frame).data)

    # Integrate the orbit and generate the stream - set these parameters!:
    orbit = pot.integrate_orbit(w0, dt=-1, n_steps=6000)
    stream = fardal_stream(pot, orbit[::-1], prog_mass=M_pal5*u.Msun, release_every=8)
    stream_c = stream.to_coord_frame(gc.Pal5PriceWhelan18, galactocentric_frame=galcen_frame)

    phi1_lim = [-30, 30]*u.deg

    tracks, stds = get_stream_track(stream_c,
                                    phi1_lim=phi1_lim,
                                    phi1_binsize=1.5*u.deg,
                                    units=data_units)

    if plot:
        _plot_likelihood_diagnostics(stream_c, tracks, data, frame_comp_names,
                                     extra_var, phi1_lim)

    lls = []
    # for name in frame_comp_names[1:]: # skip phi1
    for name in ['phi2']: # HACK: just fit sky track
        # total variance = data variance + stream width^2 + extra variance
        ivar = get_ivar(data[name+'_ivar'],
                        stds[name](data['phi1'])**2 + extra_var[name])
        ll = ln_normal_ivar(tracks[name](data['phi1']),
                            data[name], ivar)
        # rows with ivar == 0 give -inf; turn them into NaN so nansum skips them
        ll[~np.isfinite(ll)] = np.nan
        lls.append(ll)

    return np.nansum(lls, axis=0).sum()

def neg_ln_likelihood(*args, **kwargs):
    """Negative of `ln_likelihood` (same arguments), for use as a
    minimization objective."""
    ln_like = ln_likelihood(*args, **kwargs)
    return -ln_like

In [None]:
# Starting point: the `pal5_c` values for the five kinematic parameters plus
# ln(mass) for a 2.5e4 progenitor.
p0 = [
    pal5_c.phi2.degree,
    pal5_c.distance.kpc,
    pal5_c.pm_phi1_cosphi2.value,
    pal5_c.pm_phi2.value,
    pal5_c.radial_velocity.value,
    np.log(2.5e4),
]
# 11.1, 232.24, 7.25,
#      np.log(mw['halo'].parameters['m'].value), 1] 

# Unit in which each component is handed to the likelihood
data_units = dict(phi1=u.deg, phi2=u.deg, distance=u.kpc,
                  pm_phi1_cosphi2=u.mas/u.yr, pm_phi2=u.mas/u.yr,
                  radial_velocity=u.km/u.s)

# Extra variance added per component inside the likelihood
extra_var = {
    'phi2': (0.02 * u.deg)**2,
    'distance': (3 * u.kpc)**2,
    'pm_phi1_cosphi2': (0.25 * u.mas/u.yr)**2,
    'pm_phi2': (0.5 * u.mas/u.yr)**2,
    'radial_velocity': (1 * u.km/u.s)**2,
}

# Position component names followed by velocity component names
position_names = list(pal5_c.get_representation_component_names().keys())
velocity_names = list(pal5_c.get_representation_component_names('s').keys())
frame_comp_names = position_names + velocity_names

# Same extra variances as plain numbers, in (data unit)^2
_extra_var = {name: var.to_value(data_units[name]**2)
              for name, var in extra_var.items()}

args = (pal5_c.phi1, mw, data, data_units, frame_comp_names, _extra_var)

In [None]:
# Evaluate the likelihood at the starting point and draw the diagnostic figures
ln_likelihood(p0, *args, plot=True)

In [None]:
%%time
# Minimize the negative log-likelihood starting from `p0`, using scipy's
# 'powell' method.
# NOTE(review): the disabled L-BFGS-B call below lists eight bounds while `p0`
# has six entries -- the bounds need updating before it can be re-enabled.
# res = minimize(neg_ln_likelihood, x0=p0, args=args,
#                method='L-BFGS-B',
#                bounds=[(-0.1, 0.1), (20, 35), (None, None), (-1, 1), (-70, -40),
#                        (0, 20), (220, 260), (0, 15)])
res = minimize(neg_ln_likelihood, x0=p0, args=args,
               method='powell')

In [None]:
# Display the optimizer result object
res

In [None]:
# Evaluate the likelihood at the optimizer's solution and draw the diagnostic figures
ln_likelihood(res.x, *args, plot=True)

In [None]:
# First five entries of the solution (all parameters except the final ln-mass)
res.x[:5]

In [None]:
# First five entries of the starting point `p0`, for comparison with the solution
p0[:5]

In [None]:
# Display the Galactocentric frame definition used for all transformations
galcen_frame