# TODO: redo with 0.5 deg bins or overlapping 1 deg bins

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import astropy.coordinates as coord
from astropy.coordinates.matrix_utilities import rotation_matrix
from astropy.table import Table
from astropy.io import fits
import astropy.units as u
from scipy.special import logsumexp
from scipy.stats import truncnorm
import emcee

import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
import gala.mpl_style

from coordinates import (pal5_c, galcen_frame, 
                         pal5_lead_frame, pal5_trail_frame)

In [None]:
# Load the filtered Pal 5 star catalogue; the ra/dec columns are treated as
# degrees when building the sky coordinates.
t = Table.read('../data/pal5-apw-filtered.fits')
_ra, _dec = (t[col] * u.deg for col in ('ra', 'dec'))
c = coord.SkyCoord(ra=_ra, dec=_dec)
del _ra, _dec

In [None]:
# The same stars expressed in the leading-arm and trailing-arm stream frames.
c_l, c_t = (c.transform_to(frame)
            for frame in (pal5_lead_frame, pal5_trail_frame))

# Stream width as a function of position along the stream ($\phi_1$)

In [None]:
# (N, 2) arrays of (phi1, phi2) in degrees for the leading (Xl) and trailing
# (Xt) frames, with phi1 wrapped at 180 deg so it runs over [-180, 180).
Xl, Xt = (np.column_stack((frame_c.phi1.wrap_at(180*u.deg).degree,
                           frame_c.phi2.degree))
          for frame_c in (c_l, c_t))

In [None]:
phi1_bins = np.arange(0, 20+1e-3, 1.)
phi2_bins = np.arange(-2, 2+1e-3, 0.1)

# ---
# For each arm: a scatter plot of the stars with the phi1 bin edges drawn and
# numbered, followed by a grid of per-bin phi2 histograms.

for X in [Xl, Xt]:
    fig, ax = plt.subplots(1, 1, figsize=(10, 4))

    ax.plot(X[:, 0], X[:, 1],
            marker='o', ls='none',
            color='k', alpha=0.25, ms=2)

    for edge in phi1_bins:
        ax.axvline(edge, marker='', zorder=10,
                   color='tab:blue', alpha=0.8)

    # Number each bin at its centre, just above the plotted phi2 range.
    for i, (lo, hi) in enumerate(zip(phi1_bins[:-1], phi1_bins[1:])):
        ax.text(0.5*(lo+hi), 2.5, str(i), fontsize=14,
                ha='center', va='center')

    ax.set_xlim(0, 20.)
    ax.set_ylim(-2, 2)
    ax.set_aspect('equal')

    ax.set_xlabel(r'$\phi_1$ [deg]')
    ax.set_ylabel(r'$\phi_2$ [deg]')

    fig.tight_layout()
    fig.set_facecolor('w')

    # ---

    # One panel per phi1 bin (20 bins -> 5x4 grid).
    fig, axes = plt.subplots(5, 4, figsize=(16, 16),
                             sharex=True, sharey=True)

    for i, (lo, hi) in enumerate(zip(phi1_bins[:-1], phi1_bins[1:])):
        # Half-open (lo, hi] bins, matching the selection used for the fits,
        # so a star exactly on an edge is not counted in two panels.
        phi1_mask = (X[:, 0] > lo) & (X[:, 0] <= hi)
        ax = axes.flat[i]

        ax.hist(X[phi1_mask, 1],
                bins=phi2_bins)
        ax.set_title(str(i), fontsize=16)

        ax.set_ylim(0, 150)

    fig.set_facecolor('w')
    fig.tight_layout()

In [None]:
def lnnormal(x, mu, std):
    """Return the log of the normal density N(x | mu, std**2).

    Works elementwise on scalars or arrays via NumPy broadcasting.
    """
    quad = -0.5 * (x - mu)**2 / std**2
    return quad - 0.5*np.log(2*np.pi) - np.log(std)

def ln_truncnorm(x, mu, sigma, clip_a, clip_b):
    """Log-density at ``x`` of N(mu, sigma**2) truncated to [clip_a, clip_b].

    ``scipy.stats.truncnorm`` takes its truncation bounds in standard-normal
    units, so the physical clip values are converted before building the
    distribution. Points outside the interval give ``-inf``.
    """
    lower = (clip_a - mu) / sigma
    upper = (clip_b - mu) / sigma
    dist = truncnorm(lower, upper, loc=mu, scale=sigma)
    return dist.logpdf(x)

def lnprior(p):
    """Log-prior for the two-Gaussian stream + quadratic background model.

    Parameters
    ----------
    p : sequence
        ``(a_s1, a_s2, mu_s, lnstd_s1, lnstd_s2, *bg_p)``: the weights of the
        two stream components, their shared mean in phi2, their log standard
        deviations, and the background coefficients.

    Returns
    -------
    float
        Log prior density, or ``-inf`` outside the allowed region.
    """
    a_s1, a_s2, mu_s, lnstd_s1, lnstd_s2, *bg_p = p

    fs = [a_s1, a_s2]
    # Both weights must be fractions in [0, 1) that leave a non-negative
    # background weight (1 - a_s1 - a_s2).
    if not all(0 <= f < 1 for f in fs) or sum(fs) > 1:
        return -np.inf

    # Ordering constraints keep the two stream components distinguishable:
    # component 1 has the larger weight and the smaller width.
    if a_s1 < a_s2 or lnstd_s2 < lnstd_s1:
        return -np.inf

    if not -1 < mu_s < 1:
        return -np.inf

    lp = 0.
    # The same truncated-normal prior on each log-width (one term per width).
    lp += ln_truncnorm(lnstd_s1, -0.5, 1, -2.5, 1.5)
    lp += ln_truncnorm(lnstd_s2, -0.5, 1, -2.5, 1.5)

    # Broad Gaussian prior on each background coefficient.
    for pp in bg_p:
        lp += lnnormal(pp, 0, 5)

    return lp

def lnlike(p, phi2):
    """Per-star log-likelihood of ``phi2`` under the stream + background mixture.

    The model mixes two Gaussians that share the mean ``mu_s`` with a
    quadratic background in ``phi2``. The background is normalised over
    ``[phi2_min, phi2_max]``, which are read from module-level globals.

    Parameters
    ----------
    p : sequence
        ``(a_s1, a_s2, mu_s, lnstd_s1, lnstd_s2, a, b, c)``: stream weights,
        shared mean, log standard deviations, and the coefficients of the
        background ``a*phi2**2 + b*phi2 + c``.
    phi2 : array_like
        Coordinates at which to evaluate the likelihood.

    Returns
    -------
    Log-likelihood evaluated elementwise over ``phi2`` (not summed).
    """
    w1, w2, mean, lnsig1, lnsig2, *poly = p
    w_bg = 1 - w1 - w2

    ln_comp1 = lnnormal(phi2, mean, np.exp(lnsig1)) + np.log(w1)
    ln_comp2 = lnnormal(phi2, mean, np.exp(lnsig2)) + np.log(w2)

    # Quadratic background qa*x**2 + qb*x + qc. The bracket is six times its
    # integral over [phi2_min, phi2_max], so ln_norm = -ln(integral).
    # (A constant background would instead use -ln(phi2_max - phi2_min).)
    qa, qb, qc = poly
    lo, hi = phi2_min, phi2_max
    six_integral = (-6*qc*lo - 3*qb*lo**2 - 2*qa*lo**3 +
                    6*qc*hi + 3*qb*hi**2 + 2*qa*hi**3)
    ln_norm = np.log(6) - np.log(six_integral)
    # NOTE(review): a non-positive density only yields a non-finite value at
    # the phi2 values passed in; a quadratic that dips below zero between
    # data points is not detected here.
    ln_bg = ln_norm + np.log(qa*phi2**2 + qb*phi2 + qc) + np.log(w_bg)

    return logsumexp([ln_comp1, ln_comp2, ln_bg], axis=0)

def lnprob(p, phi2):
    """Log-posterior: ``lnprior(p)`` plus the summed ``lnlike(p, phi2)``.

    Returns ``-inf`` in two cases. If the prior is not finite, the likelihood
    is not evaluated. Otherwise ``-inf`` is returned when any per-star
    log-likelihood is not finite (e.g. the background density is
    non-positive at some ``phi2``).
    """
    log_prior = lnprior(p)
    if np.all(np.isfinite(log_prior)):
        per_star = lnlike(p, phi2)
        if np.all(np.isfinite(per_star)):
            return per_star.sum() + log_prior
    return -np.inf

In [None]:
# MCMC settings: walkers per sampler, burn-in length, and the base length of
# the runs that follow the burn-in.
nwalkers, nburn, nsteps = 64, 1024, 1024

In [None]:
phi2_min = -2.
phi2_max = 2.

# Centres of the phi2 histogram bins, used to seed the stream mean.
phi2_bin_c = 0.5*(phi2_bins[:-1]+phi2_bins[1:])

# Fit every phi1 bin of each arm independently. The leading arm uses the
# first 11 bins and the trailing arm the first 17.
data = dict()
for name, X, _phi1_bins in zip(['lead', 'trail'],
                               [Xl, Xt],
                               [phi1_bins[:12], phi1_bins[:18]]):
    phi2_mask = (X[:, 1] > phi2_min) & (X[:, 1] < phi2_max)

    all_samplers = []
    Ns = []
    for l, r in zip(_phi1_bins[:-1], _phi1_bins[1:]):
        # Stars in the half-open phi1 bin (l, r] inside the fitted phi2 window.
        bin_mask = (X[:, 0] > l) & (X[:, 0] <= r) & phi2_mask
        binX = X[bin_mask]
        Ns.append(bin_mask.sum())

        # Seed mu_s at the peak of the phi2 histogram. Fall back to 0 when the
        # peak lies outside the prior's allowed range (|mu_s| < 1).
        H, _ = np.histogram(binX[:, 1], bins=phi2_bins)
        mu = phi2_bin_c[H.argmax()]
        if np.abs(mu) > 1.:
            mu = 0.
        # (a_s1, a_s2, mu_s, lnstd_s1, lnstd_s2) + background (a, b, c);
        # (0, 0, 1) is a flat background.
        p0 = (0.2, 0., mu,
              np.log(0.1), np.log(0.2)) + (0, 0, 1.)

        p0s = emcee.utils.sample_ball(p0, [1e-3]*len(p0), nwalkers)

        sampler = emcee.EnsembleSampler(nwalkers, len(p0),
                                        log_prob_fn=lnprob,
                                        args=(binX[:, 1], ))

        # Burn in, then restart from a tight ball around the median walker
        # position.
        pos, *_ = sampler.run_mcmc(p0s, nburn, progress=True)
        pos = emcee.utils.sample_ball(np.median(pos, axis=0),
                                      [1e-3]*len(p0), nwalkers)
        # Two more warm-up stages. Each continues from the final walker
        # positions of the previous stage, and its samples are discarded.
        sampler.reset()
        pos, *_ = sampler.run_mcmc(pos, nsteps, progress=True)
        sampler.reset()
        pos, *_ = sampler.run_mcmc(pos, nsteps//2, progress=True)
        # Production run: only these nsteps//2 steps remain in the sampler.
        sampler.reset()
        sampler.run_mcmc(pos, nsteps//2, progress=True)
        print()

        all_samplers.append(sampler)

    data[name] = dict()
    data[name]['X'] = X
    data[name]['samplers'] = all_samplers
    data[name]['phi1_bins'] = _phi1_bins
    data[name]['N'] = np.array(Ns)

In [None]:
# # sampler = data['lead']['samplers'][2]
# sampler = data['lead']['samplers'][0]

# fig, axes = plt.subplots(sampler.ndim, 1, figsize=(8, 2*sampler.ndim),
#                          sharex=True)
# for k in range(sampler.ndim):
#     for walker in sampler.chain[..., k]:
#         axes[k].plot(walker, marker='', drawstyle='steps-mid', 
#                      color='k', alpha=0.2)
        
# fig.tight_layout()

In [None]:
flatchains = dict()
for name, this_data in data.items():
    # Stack the per-bin flat chains: one (n_samples, n_dim) array per bin.
    stacked = np.array([smp.flatchain for smp in this_data['samplers']])
    # Columns 3 and 4 hold the two log-widths; store them as widths instead.
    stacked[..., 3:5] = np.exp(stacked[..., 3:5])
    flatchains[name] = stacked

In [None]:
fig = plt.figure(figsize=(8,4))

# Width of stream component 1 (column 3, already exponentiated) per phi1 bin:
# posterior median with 16th/84th-percentile error bars.
for name, this_data in data.items():
    edges = this_data['phi1_bins']
    phi1_bin_c = 0.5 * (edges[:-1] + edges[1:])

    sig = flatchains[name][..., 3]
    med = np.median(sig, axis=1)
    lo = np.percentile(sig, 16, axis=1)
    hi = np.percentile(sig, 84, axis=1)
    plt.errorbar(phi1_bin_c, med, yerr=(med - lo, hi - med),
                 ls='none', marker='o', label=name)
plt.legend(loc='best', fontsize=15)

plt.xlim(0, 17)
plt.ylim(0, 0.6)

plt.xlabel(r'$\Delta \phi_1$ [deg]')
plt.ylabel(r'$\sigma$ [deg]')
fig.set_facecolor('w')
fig.tight_layout()

In [None]:
# plt.figure()

# for name in data.keys():
#     this_data = data[name]
#     phi1_bin_c = 0.5 * (this_data['phi1_bins'][:-1] + this_data['phi1_bins'][1:])
    
#     flatchain = flatchains[name]
#     med = np.median(flatchain[..., 4], axis=1)
#     err1 = med - np.percentile(flatchain[..., 4], 16, axis=1)
#     err2 = np.percentile(flatchain[..., 4], 84, axis=1) - med
#     plt.errorbar(phi1_bin_c, med, yerr=(err1, err2),
#                  ls='none', marker='o')
    
# plt.xlim(0, 17)
# plt.ylim(0, 1)

In [None]:
fig = plt.figure(figsize=(8,4))

# Stream centroid mu_s (column 2) per phi1 bin: posterior median with
# 16th/84th-percentile error bars.
for name in data.keys():
    this_data = data[name]
    phi1_bin_c = 0.5 * (this_data['phi1_bins'][:-1] + this_data['phi1_bins'][1:])

    flatchain = flatchains[name]
    med = np.median(flatchain[..., 2], axis=1)
    err1 = med - np.percentile(flatchain[..., 2], 16, axis=1)
    err2 = np.percentile(flatchain[..., 2], 84, axis=1) - med
    plt.errorbar(phi1_bin_c, med, yerr=(err1, err2),
                 ls='none', marker='o', label=name)
plt.legend(loc='best', fontsize=15)

plt.xlim(0, 17)
plt.ylim(-1, 1)

plt.xlabel(r'$\Delta \phi_1$ [deg]')
# Raw string: '\D' and '\p' are invalid escape sequences in a plain literal.
plt.ylabel(r'$\Delta \phi_2$ [deg]')
fig.set_facecolor('w')
fig.tight_layout()

In [None]:
fig = plt.figure(figsize=(8,4))

# Total stream fraction f = a_s1 + a_s2 per phi1 bin: posterior median with
# 16th/84th-percentile error bars.
for name in data.keys():
    this_data = data[name]
    phi1_bin_c = 0.5 * (this_data['phi1_bins'][:-1] + this_data['phi1_bins'][1:])

    flatchain = flatchains[name]
    ch = flatchain[..., 0] + flatchain[..., 1]
    med = np.median(ch, axis=1)
    err1 = med - np.percentile(ch, 16, axis=1)
    err2 = np.percentile(ch, 84, axis=1) - med
    plt.errorbar(phi1_bin_c, med, yerr=(err1, err2),
                 ls='none', marker='o', label=name)
plt.legend(loc='best', fontsize=15)

plt.xlim(0, 17)
plt.ylim(0, 0.25)
plt.xlabel(r'$\Delta \phi_1$ [deg]')
plt.ylabel('$f$')
fig.set_facecolor('w')
fig.tight_layout()

In [None]:
# Posterior samples of the number of stream members per phi1 bin for the
# trailing arm: total stream fraction times the number of stars in the bin.
# The arm is named explicitly instead of relying on loop variables left over
# from an earlier cell.
ch = ((flatchains['trail'][..., 0] + flatchains['trail'][..., 1])
      * data['trail']['N'][:, None])

In [None]:
fig = plt.figure(figsize=(8,4))

# Number of stream members per phi1 bin: posterior samples of the total
# stream fraction scaled by the number of stars N fitted in that bin.
for name, this_data in data.items():
    edges = this_data['phi1_bins']
    phi1_bin_c = 0.5 * (edges[:-1] + edges[1:])

    frac = flatchains[name][..., 0] + flatchains[name][..., 1]
    n_members = frac * this_data['N'][:, None]

    med = np.median(n_members, axis=1)
    lo = np.percentile(n_members, 16, axis=1)
    hi = np.percentile(n_members, 84, axis=1)
    plt.errorbar(phi1_bin_c, med, yerr=(med - lo, hi - med),
                 ls='none', marker='o', label=name)
plt.legend(loc='best', fontsize=15)

plt.xlim(0, 17)
plt.xlabel(r'$\Delta \phi_1$ [deg]')
plt.ylabel('$N$ stream members')
fig.set_facecolor('w')
fig.tight_layout()

In [None]:
# Per-bin diagnostic: normalised phi2 histogram of the fitted stars with 128
# random posterior draws of the mixture density overplotted. The grid spans
# the same [phi2_min, phi2_max] window that lnlike normalises over.
_grid = np.linspace(phi2_min, phi2_max, 1000)
for name, this_data in data.items():
    X = this_data['X']
    phi2_mask = (X[:, 1] > phi2_min) & (X[:, 1] < phi2_max)

    edges = this_data['phi1_bins']
    phi1_bin_c = 0.5 * (edges[:-1] + edges[1:])
    for i, (l, r) in enumerate(zip(edges[:-1], edges[1:])):
        # Same half-open (l, r] selection that was used when fitting this bin.
        phi1_mask = (X[:, 0] > l) & (X[:, 0] <= r)
        binX = X[phi1_mask & phi2_mask]
        # Use the raw sampler chain (log-widths, as lnlike expects), not the
        # `flatchains` arrays, whose width columns were exponentiated. Fetch
        # it once rather than re-reading the sampler property for every draw.
        chain = this_data['samplers'][i].flatchain

        plt.figure()
        plt.hist(binX[:, 1], bins=phi2_bins, density=True)

        for k in np.random.choice(len(chain), size=128, replace=False):
            plt.plot(_grid, np.exp(lnlike(chain[k], _grid)),
                     marker='', alpha=0.1, color='tab:orange')
        plt.title("{} {}: at phi1 {:.1f}".format(name, i, phi1_bin_c[i]))
        plt.xlim(phi2_min, phi2_max)
        plt.ylim(0, 1.2)
        plt.xlabel(r'$\phi_2$ [deg]')
        plt.ylabel('density')