In [None]:
from functools import partial
from os import path

import astropy.coordinates as coord
from astropy.table import Table
import astropy.units as u
from astropy.io import ascii
from astropy.io import fits
from astropy.wcs import WCS
import reproject

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import h5py
from scipy.spatial import cKDTree
from sklearn.neighbors import KDTree

from pyia import GaiaData
import gala.dynamics as gd
import gala.coordinates as gc
import gala.potential as gp
from gala.mpl_style import center_emph, center_deemph

from dustmaps.sfd import SFDQuery
from helpers import get_ext

import pygmmis

In [None]:
# Load the catalog FITS file (path is relative to this notebook) into a GaiaData object.
g = GaiaData('../data/data-joined.fits')
# Sky coordinates for every source, requested with distance=False.
# NOTE(review): presumably this yields positions without a distance component -- confirm in pyia.
c = g.get_skycoord(distance=False)
# The same positions transformed into gala's MagellanicStream frame.
mag_c = c.transform_to(gc.MagellanicStream)

## Dust-correct Gaia colors:

In [None]:
# Query the SFD dust map at every source position in `c`.
# NOTE(review): the result is presumably E(B-V) per source, as the name suggests -- confirm in dustmaps.
sfd = SFDQuery()
ebv = sfd.query(c)

In [None]:
# Stack whatever get_ext returns for the three Gaia bands and transpose, so
# that the first axis of Agaia indexes sources and the second indexes bands.
# NOTE(review): assumes get_ext returns one array per band in (G, BP, RP)
# order -- confirm in helpers.get_ext.
Agaia = np.transpose(
    np.vstack(
        get_ext(
            g.phot_g_mean_mag.value,
            g.phot_bp_mean_mag.value,
            g.phot_rp_mean_mag.value,
            ebv,
        )
    )
)

In [None]:
# Subtract column 0/1/2 of Agaia from the G/BP/RP mean magnitudes respectively.
G0, BP0, RP0 = (
    getattr(g, f'phot_{band}_mean_mag').value - Agaia[:, col]
    for col, band in enumerate(('g', 'bp', 'rp'))
)

# Colors built from the corrected magnitudes.
bprp0 = BP0 - RP0
bpg0 = BP0 - G0

## Define cluster and control fields

In [None]:
# Control field: sources whose MagellanicStream longitude L falls strictly
# inside either of two windows, (67, 70.2) deg or (58.25, 60.25) deg.
control_mask = np.logical_or(
    np.logical_and(mag_c.L < 70.2*u.deg, mag_c.L > 67*u.deg),
    np.logical_and(mag_c.L < 60.25*u.deg, mag_c.L > 58.25*u.deg),
)
# Display the number of selected sources next to 8x the summed window widths.
# NOTE(review): the factor 8 is presumably the field's extent in the other
# coordinate (in deg), making the second number an area -- confirm.
control_mask.sum(), 8 * (70.2-67 + 60.25-58.25)

In [None]:
# Circular aperture: radius first, then the center on the sky.
cl_rad = 1.7 * u.deg
cluster_c = coord.SkyCoord(179.5, -28.8, unit=u.deg)

# Sources closer to the center than the aperture radius.
cluster_mask = c.separation(cluster_c) < cl_rad
# Display the selected count next to pi * r^2 for the aperture.
cluster_mask.sum(), np.pi * cl_rad**2

## Define feature arrays:

In [None]:
# Cov = np.zeros((cluster_mask.sum(), 2))
# Gerr = g.phot_g_mean_flux_over_error
# bprperr = np.sqrt(g.phot_bp_mean_flux_over_error**2 + g.phot_rp_mean_flux_over_error**2)

In [None]:
# Covariance matrices restricted to the cluster-field sources.
full_cov = g.get_cov()[cluster_mask]
# Keep the 2x2 sub-block at indices 3:5 on both matrix axes.
# NOTE(review): presumably the (pmra, pmdec) block of the Gaia covariance --
# confirm the parameter ordering used by pyia's get_cov.
Cov = full_cov[:, 3:5, 3:5]

# Feature matrix: one row per cluster-field source, columns (pmra, pmdec).
X = np.vstack([
    getattr(g, name).value[cluster_mask]
    for name in ('pmra', 'pmdec')
]).T

In [None]:
def select_func(data):
    """Boolean mask of rows whose first two columns both lie strictly in (-10, 10).

    Parameters
    ----------
    data : array
        2D array; only columns 0 and 1 are inspected.

    Returns
    -------
    array of bool
        One entry per row of ``data``; True where both columns pass the cut.
    """
    lo, hi = -10, 10
    first, second = data[:, 0], data[:, 1]
    in_first = (first < hi) & (first > lo)
    in_second = (second < hi) & (second > lo)
    return in_first & in_second

In [None]:
# use the covariance of the nearest neighbor.
def covar_tree_cb(coords, tree, covar):
    """Return the covariance of each query point's nearest neighbor in ``tree``.

    Parameters
    ----------
    coords : array-like
        Query positions, passed unchanged to ``tree.query``.
    tree : object
        Spatial index exposing ``query(coords, k=1)`` that returns a
        ``(distances, indices)`` pair.
    covar : array
        Per-point covariances, indexed along the first axis by the neighbor
        indices that ``tree`` returns.

    Returns
    -------
    array
        Entries of ``covar`` selected by the flattened nearest-neighbor
        indices, one per query point.
    """
    # Only the neighbor indices are needed; the distances are discarded.
    _, ind = tree.query(coords, k=1)
    # Flatten so the lookup yields one covariance per query point whether the
    # tree reports indices shaped (n, 1) or (n,).
    return covar[np.ravel(ind)]

# Nearest-neighbor index over the rows of X, then bind the tree and the
# matching covariances into covar_tree_cb so that only `coords` remains as an
# argument; the resulting callable is what gets passed as covar_callback.
tree = KDTree(X, leaf_size=100)
covar_cb = partial(covar_tree_cb, tree=tree, covar=Cov)

In [None]:
# Mixture model created with K=4 and D=2.
# NOTE(review): K is presumably the number of components and D the data
# dimensionality (X has two columns) -- confirm against pygmmis.
gmm = pygmmis.GMM(K=4, D=2)

In [None]:
# Fit the mixture to X with k-means initialisation. The per-source covariances
# are passed directly, and the selection function and nearest-neighbor
# covariance lookup are handed over as callbacks.
# NOTE(review): no random generator or seed is passed here, so repeated runs
# may differ -- confirm whether that is intended.
logL, U = pygmmis.fit(
    gmm,
    X,
    init_method='kmeans',
    covar=Cov,
    sel_callback=select_func,
    covar_callback=covar_cb,
)

In [None]:
# Draw as many points from the mixture as X has rows, passing the same
# selection and covariance callbacks that were given to the fit.
samples, covar_samples, N_orig = pygmmis.draw(
    gmm,
    len(X),
    sel_callback=select_func,
    covar_callback=covar_cb,
)

In [None]:
# 64 evenly spaced bin edges over [-10, 10], the same range select_func keeps.
bins = np.linspace(-10, 10, 64)

# Two side-by-side panels: X on the left, the drawn samples on the right.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

for i, data in enumerate([X, samples]):
    ax = axes[i]
    # 2D histogram of column 0 against column 1, using the same edges on both axes.
    H, xe, ye = np.histogram2d(data[:, 0], data[:, 1], bins=bins)
    # H is indexed [x-bin, y-bin]; it is transposed before being handed to pcolormesh.
    ax.pcolormesh(xe, ye, H.T, cmap='magma')