In [None]:
import pathlib

from astropy.convolution import Gaussian2DKernel, convolve
import astropy.coordinates as coord
from astropy.io import ascii, fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic, binned_statistic_2d
from IPython.display import HTML
from astropy.stats import median_absolute_deviation as MAD
from tqdm.notebook import tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from pyia import GaiaData
from cmastro import cmaps

In [None]:
# Circular speed used below to turn J_phi into a guiding-center radius.
vcirc = 229 * (u.km / u.s)

# Default Galactocentric frame, shown for reference.
galcen_frame = coord.Galactocentric()
galcen_frame

In [None]:
mw = gp.MilkyWayPotential()

# Integrate an approximate solar orbit to define a natural scale for J_z.
sun_pos = [-8.1, 0, 0.0206] * u.kpc
sun_vel = [12.9, 245.6, 7.78] * u.km / u.s
sun_w0 = gd.PhaseSpacePosition(sun_pos, sun_vel)
sun_orbit = mw.integrate_orbit(sun_w0, dt=0.5, n_steps=4000)

# Action unit = (max |z|) * (max |v_z|) along the solar orbit.
sun_z_max = np.abs(sun_orbit.z).max()
sun_vz_max = np.abs(sun_orbit.v_z).max().to(u.km / u.s)
Jz_unit = sun_z_max * sun_vz_max
Jz_unit

In [None]:
# Input files are produced by Setup.ipynb.
data_path = pathlib.Path('../../data/').resolve()
_cache_file = data_path / 'edr3-2mass-actions.fits'
cmd_masks_file = data_path / 'cmd-masks.fits'

actions_table = at.Table.read(_cache_file)
cmd_masks = at.Table.read(cmd_masks_file)

# hstack places the mask columns next to the action columns row by row;
# this assumes both files are row-aligned (same stars, same order).
data = at.hstack((actions_table, cmd_masks))

g = GaiaData(data)

In [None]:
c = g.get_skycoord()

# Guiding-center radius: R_g = |J_phi / v_circ|, in kpc.
Rg = np.abs(g.J_phi / vcirc).to_value(u.kpc)

# Polar embedding of (J_z, theta_z): radius sqrt(J_z / Jz_unit), angle theta_z.
sqrt_Jz = np.sqrt(g.J_z.to_value(Jz_unit))
xx = sqrt_Jz * np.cos(g.theta_z)
yy = sqrt_Jz * np.sin(g.theta_z)

# Stars whose guiding radius lies strictly between 7 and 9 kpc.
Rg_mask = (Rg > 7) & (Rg < 9)

# Simulate a phase-space spiral in a fixed 1D ${\rm sech}^2$ slab potential

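$$
\begin{aligned}
\rho(z) &= \frac{1}{2b}\,{\rm sech}^2(z/b)\\
\Phi(z) &= 2\pi G\, b\, \log\cosh(z/b)\\
M(<z) &= \frac{1}{2}\,\tanh(z/b)\\
\frac{{\rm d}\Phi}{{\rm d} z} &= 2\pi G\, \tanh(z/b)
\end{aligned}
$$

Here $M(<z)$ is the mass between $0$ and $z$ for the unit-normalized density, and $\Phi$ follows from Poisson's equation ${\rm d}^2\Phi/{\rm d}z^2 = 4\pi G\rho$. The simulation below works in units where $2\pi G b = 1$. In those units $\Phi(z) = \log\cosh(z/b)$ and ${\rm d}\Phi/{\rm d}z = \tanh(z/b)/b$, which is the force implemented in `numbagrad`.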

In [None]:
from numba import jit

In [None]:
@jit('void(f8[:,:], f8, f8[:,:])', nopython=True)
def numbagrad(w, b, g):
    """Write the time derivative of phase-space array ``w`` into ``g`` in place.

    Parameters
    ----------
    w : float64 2D array
        Row 0 holds positions z and row 1 velocities v_z, one column per
        particle (as built with ``np.stack((z, vz))`` below).
    b : float64
        Slab scale parameter; the acceleration is -tanh(z/b)/b.
    g : float64 2D array, same number of columns as ``w``
        Output buffer, overwritten: g[0] = v_z, g[1] = -tanh(z/b)/b.
    """
    for i in range(w.shape[1]):
        # dz/dt = v_z
        g[0, i] = w[1, i]
        # dv_z/dt = -dPhi/dz for Phi = log(cosh(z/b))
        g[1, i] = -np.tanh(w[0, i] / b) / b


class FastGrad:
    """Callable ``F(t, w)`` returning the phase-space derivative of ``w``.

    Wraps ``numbagrad`` with an output buffer allocated once in ``__init__``
    and reused on every call.

    Parameters
    ----------
    N : int
        Number of particles (columns of ``w``).
    b : float
        Slab scale parameter forwarded to ``numbagrad``.
    """

    def __init__(self, N, b):
        # Store b on the instance. Previously it was ignored and __call__ read
        # the module-level ``b``. Cast to float because numbagrad is compiled
        # with an explicit f8 signature.
        self.b = float(b)
        self._grad = np.zeros((2, N))

    def __call__(self, t, w):
        # ``t`` is unused (autonomous system) but required by the integrator API.
        # NOTE: returns the internal buffer, which is overwritten on every call;
        # copy the result if it must outlive the next call.
        numbagrad(w, self.b, self._grad)
        return self._grad

In [None]:
N = 100_000
b = 250

# Seed the global NumPy RNG so that this cell, and every later cell drawing
# from np.random (the Gaussian reference sample, the bootstrap resampling),
# is reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)

# Inverse-CDF sample of a sech^2(z/s) density: z = s * arctanh(2u - 1).
# NOTE(review): the markdown density has s = b, but this uses s = 2b (the
# trailing factor of 2). Confirm that the wider initial slab is intentional.
z = np.arctanh(2. * np.random.uniform(size=N) - 1) * b * 2.
vz = np.random.normal(size=N)

In [None]:
fastgrad = FastGrad(N=N, b=b)

# Phase-space array with rows (z, v_z), one column per particle.
w0 = np.stack((z, vz))

# Smoke test: the derivative is finite for every sampled particle.
grad_check = fastgrad(0., w0)
assert np.isfinite(grad_check).all()

In [None]:
# gala's DOPRI853 integrator driving the numba-backed FastGrad derivative.
integrator = gi.DOPRI853Integrator(fastgrad)

In [None]:
# Initial conditions: the sampled positions, with every velocity shifted by
# a uniform +0.3 kick.
init_w0 = np.stack((w0[0], w0[1] + 0.3))
init_pos, init_vel = init_w0

# Normalizations taken from the unkicked sample; later cells divide z and v_z
# by these to put both axes on a common dimensionless scale.
vscale = np.std(w0[1])
xscale = np.std(w0[0])

In [None]:
# Integrate every particle for 2000 steps of dt=5. No unit system is passed,
# so time is in whatever units FastGrad's force implies. Later cells index
# snapshots as orbits[k] (time step k) and particles along axis 1.
orbits = integrator.run(init_w0, dt=5., n_steps=2000)

# Project onto an angular Fourier expansion in the $(z, v_z)$ plane

In [None]:
# Isotropic unit-Gaussian reference sample with rows (x, y), used below as a
# "no spiral" baseline for the Fourier diagnostics.
# WARNING: this rebinds ``w0``. The simulation initial conditions built above
# are no longer reachable under that name once this cell has run.
w0 = np.random.normal(loc=0, scale=1, size=(100_000, 2)).T

In [None]:
from collections import defaultdict

def spirality(xx, yy, radius_bins, m_max=3):
    """Angular Fourier amplitudes and phases of a 2D point set, per annulus.

    For each annulus ``radius_bins[i] <= r < radius_bins[i+1]`` (with
    ``r = sqrt(xx**2 + yy**2)`` and ``theta = arctan2(yy, xx)``), compute

        c_m = mean(exp(i * m * theta)) over the points in that annulus,

    for m = 1 .. m_max.

    Parameters
    ----------
    xx, yy : ndarray or astropy Quantity
        Cartesian coordinates of the points (same shape).
    radius_bins : array-like
        Annulus edges; consecutive pairs define each annulus.
    m_max : int, optional
        Highest azimuthal order computed.

    Returns
    -------
    Ns : ndarray, shape (n_annuli,)
        Number of points in each annulus.
    amps : ndarray, shape (n_annuli, m_max)
        |c_m| per annulus (0 for an empty annulus).
    phases : ndarray, shape (n_annuli, m_max)
        arg(c_m) in radians (0 for an empty annulus).
    """
    # .view(np.ndarray) strips astropy units so the masks and angles are plain arrays.
    r = np.sqrt(xx**2 + yy**2).view(np.ndarray)
    theta_z = np.arctan2(yy, xx).view(np.ndarray)
    ms = np.arange(1, m_max + 1)

    Ns, amps, phases = [], [], []
    for r1, r2 in zip(radius_bins[:-1], radius_bins[1:]):
        theta_in = theta_z[(r >= r1) & (r < r2)]
        Ns.append(theta_in.size)

        if theta_in.size == 0:
            # Empty annulus: report zero coefficients rather than 0/0 = nan.
            coeffs = np.zeros(len(ms), dtype=complex)
        else:
            # Average over the annulus population, not the whole sample, so the
            # amplitude measures angular asymmetry within this annulus alone
            # (previously it also scaled with the annulus' share of points).
            coeffs = np.exp(1j * np.outer(ms, theta_in)).mean(axis=1)

        amps.append(np.abs(coeffs))
        phases.append(np.angle(coeffs))

    return np.array(Ns), np.array(amps), np.array(phases)

In [None]:
# Annulus edges, evenly spaced in sqrt(r).
radius_bins = np.arange(np.sqrt(0.1), np.sqrt(4.5), 0.2)**2
# Annulus centers: midpoint of each pair of consecutive edges. The previous
# expression added radius_bins[:-1] to itself, giving the left edges.
radii = 0.5 * (radius_bins[:-1] + radius_bins[1:])

hist_bins = np.linspace(-4, 4, 128)


def _draw_density(ax, horiz, vert):
    """Log-scaled grey 2D histogram of (horiz, vert) on ``hist_bins``."""
    counts, h_edges, v_edges = np.histogram2d(horiz, vert, bins=hist_bins)
    ax.pcolormesh(h_edges, v_edges, counts.T,
                  norm=mpl.colors.LogNorm(),
                  cmap='Greys')


fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Left: reference sample ``w0``. Right: simulation snapshot k, with
# v_z / vscale on the horizontal axis and z / xscale on the vertical axis.
_draw_density(axes[0], w0[0], w0[1])

k = 1000
_draw_density(axes[1],
              orbits[k].v_x1.value / vscale,
              orbits[k].x1.value / xscale)

# Overlay the annulus edges used by spirality() on both panels.
for ax in axes:
    for edge in radius_bins:
        ax.add_patch(mpl.patches.Circle((0, 0), radius=edge,
                                        facecolor='none',
                                        edgecolor='tab:red',
                                        linewidth=1,
                                        alpha=0.75))
    ax.set_aspect('equal')

In [None]:
m_max = 8

fig, axes = plt.subplots(2, 2, figsize=(10, 10),
                         sharex='row', sharey='row')


def _plot_modes(ax_amp, ax_phase, counts, mode_amps, mode_phases, label_modes):
    """Plot the total and per-m amplitudes (ax_amp) and phases (ax_phase)
    against annulus center, keeping only annuli with more than 100 points."""
    keep = counts > 100
    r_keep = radii[keep]
    ax_amp.plot(r_keep, np.sqrt(np.sum(mode_amps[keep]**2, axis=1)), lw=2)
    for m_idx in range(m_max):
        extra = {'label': f'm={m_idx+1}'} if label_modes else {}
        ax_amp.plot(r_keep, mode_amps[keep, m_idx], **extra)
        ax_phase.plot(r_keep, mode_phases[keep, m_idx])


# Left column: Gaussian reference sample.
# NOTE: this rebinds ``N`` (previously the particle count) to per-annulus counts.
N, amps, phases = spirality(w0[0], w0[1],
                            radius_bins=radius_bins,
                            m_max=m_max)
_plot_modes(axes[0, 0], axes[1, 0], N, amps, phases, label_modes=True)
axes[0, 0].legend(fontsize=14)

# Right column: simulation snapshot k in (z / xscale, v_z / vscale).
k = 2000
N, amps, phases = spirality(orbits[k].x1 / xscale,
                            orbits[k].v_x1 / vscale,
                            radius_bins=radius_bins,
                            m_max=m_max)
_plot_modes(axes[0, 1], axes[1, 1], N, amps, phases, label_modes=False)

In [None]:
# m_max = 3

# new_amps = []
# ks = np.arange(0, orbits.shape[0], 10)
# for k in ks:
#     xx = orbits[k].x1 / xscale
#     yy = orbits[k].v_x1 / vscale
#     N, amps, phases = spirality(
#         xx, yy, 
#         radius_bins=radius_bins,
#         m_max=m_max)
    
#     amp = np.sqrt(np.sum(amps**2, axis=1))
#     new_amps.append(amp)
    
# new_amps = np.array(new_amps)

In [None]:
m_max = 3

# Bootstrap the m=1 and m=2 Fourier amplitudes per annulus for the selected
# snapshots (indices into ``orbits``; e.g. np.arange(0, orbits.shape[0], 200)).
bootstrap_N = 128
bootstrap_amps = []
ks = [-1]
for k in tqdm(ks):
    # Dedicated names: reusing ``xx``/``yy`` here would overwrite the Gaia
    # action-angle coordinates, which the wedge cell further down still reads.
    sim_x = orbits[k].x1 / xscale
    sim_v = orbits[k].v_x1 / vscale

    trials = []
    for _ in range(bootstrap_N):
        idx = np.random.choice(len(sim_x), size=len(sim_x))
        N, amps, phases = spirality(
            sim_x[idx], sim_v[idx],
            radius_bins=radius_bins,
            m_max=m_max)
        trials.append(amps[:, 0:2])  # keep m = 1 and m = 2 only

    bootstrap_amps.append(trials)

# Shape: (len(ks), bootstrap_N, n_annuli, 2)
bootstrap_amps = np.array(bootstrap_amps)

In [None]:
# Bootstrap mean and scatter per annulus (rows) and Fourier mode (columns)
# for the first snapshot in ``ks``.
vals = np.mean(bootstrap_amps, axis=1)[0]
errs = np.std(bootstrap_amps, axis=1)[0]

# Inverse-variance weighted mean over annuli (per mode) and its uncertainty.
# Annuli with zero or non-finite scatter (e.g. empty in every resample) would
# get infinite weight and turn the result into nan, so they get zero weight.
usable = np.isfinite(errs) & (errs > 0)
weights = 1.0 / np.where(usable, errs, np.inf)**2
weighted_vals = np.where(usable, vals, 0.0)

np.sum(weighted_vals * weights, axis=0) / np.sum(weights, axis=0), np.sqrt(1 / np.sum(weights, axis=0))

In [None]:
i = -1  # which entry of ``ks`` to show

mean_amps = np.mean(bootstrap_amps, axis=1)
std_amps = np.std(bootstrap_amps, axis=1)

if bootstrap_amps.ndim == 3:
    # One amplitude per annulus.
    plt.errorbar(radii, mean_amps[i], std_amps[i])
elif bootstrap_amps.ndim == 4:
    # One curve per retained Fourier mode (last axis).
    for n in range(bootstrap_amps.shape[-1]):
        plt.errorbar(radii, mean_amps[i, ..., n], std_amps[i, ..., n])

In [None]:
# Map of bootstrap-mean amplitude: rows = snapshots in ``ks``, columns = annuli.
# (The name was previously misspelled ``boostrap_amps`` -> NameError.)
mean_amps_map = np.mean(bootstrap_amps, axis=1)
if mean_amps_map.ndim == 3:
    # Per-mode amplitudes on the last axis: combine them in quadrature so the
    # map is 2D, as pcolormesh requires.
    mean_amps_map = np.sqrt(np.sum(mean_amps_map**2, axis=-1))

plt.figure(figsize=(4, 10))
plt.pcolormesh(mean_amps_map)

# Plot the "Classic" asymmetry parameters

In [None]:
def _sign_asymmetry(q):
    """(N[q > 0] - N[q <= 0]) / (N[q > 0] + N[q <= 0]), counted along axis 1."""
    n_pos = (q > 0).sum(axis=1)
    n_nonpos = (q <= 0).sum(axis=1)
    return (n_pos - n_nonpos) / (n_pos + n_nonpos)


# Sign asymmetry of z and of v_z across particles, one value per time step.
Az = _sign_asymmetry(orbits.x1)
Avz = _sign_asymmetry(orbits.v_x1)

In [None]:
# z asymmetry (red), v_z asymmetry (blue), and their quadrature sum (purple).
curves = [
    (Az, 'tab:red'),
    (Avz, 'tab:blue'),
    (np.sqrt(Avz**2 + Az**2), 'tab:purple'),
]
for series, colour in curves:
    plt.plot(series, marker='', color=colour)

---

## OLD: earlier exploration (kept for reference)

In [None]:
def crazy_plots(xx, yy, radius_bins, ax=None):
    """Overlay density-normalized polar-angle histograms, one per annulus.

    Parameters
    ----------
    xx, yy : ndarray or astropy Quantity
        Cartesian coordinates of the points (same shape).
    radius_bins : array-like
        Annulus edges; consecutive pairs define r1 <= r < r2.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure/axes is created when None.

    Returns
    -------
    ax : the Axes that was drawn on.
    """
    r = np.sqrt(xx**2 + yy**2).view(np.ndarray)
    # arctan2 returns angles in (-pi, pi]. Wrap to [0, 2pi) so every point falls
    # inside the [0, 2pi] histogram range; otherwise all negative angles (half
    # the plane) are silently dropped.
    theta = np.mod(np.arctan2(yy, xx).view(np.ndarray), 2 * np.pi)

    if ax is None:
        _, ax = plt.subplots()

    angle_bins = np.linspace(0, 2 * np.pi, 64)
    for r1, r2 in zip(radius_bins[:-1], radius_bins[1:]):
        in_annulus = (r >= r1) & (r < r2)
        ax.hist(theta[in_annulus],
                bins=angle_bins,
                density=True,
                histtype='step')

    return ax

In [None]:
# Per-annulus angle histograms for snapshot k (a stray trailing "s" after the
# closing parenthesis was a SyntaxError).
# NOTE(review): ``k`` is whatever an earlier cell last set (the bootstrap loop
# leaves it at -1). Confirm which snapshot is intended.
ax = crazy_plots(
    orbits[k].x1 / xscale,
    orbits[k].v_x1 / vscale,
    radius_bins=radius_bins)
ax.set_xlim(0, 2*np.pi)

# Wedges idea

In [None]:
# Selection: CMD mask ``ms_cmd_mask`` (presumably main sequence; confirm),
# J_z > 0, and guiding radius in the 7-9 kpc annulus (Rg_mask).
mask = g.ms_cmd_mask & (g.J_z > 0) & Rg_mask
bins = np.linspace(-6, 6, 201)

# Rebuild (sqrt(J_z) cos theta_z, sqrt(J_z) sin theta_z) for the selected stars
# directly from ``g``. The globals ``xx``/``yy`` are not used because the
# bootstrap cell above may have rebound them to simulation data, which would
# make this cell depend on execution order.
sqrt_Jz_sel = np.sqrt(g.J_z[mask].to_value(Jz_unit))
theta_sel = g.theta_z[mask]
x_sel = sqrt_Jz_sel * np.cos(theta_sel)
y_sel = sqrt_Jz_sel * np.sin(theta_sel)

fig, ax = plt.subplots(1, 1, figsize=(6, 6),
                       constrained_layout=True)

H, xe, ye = np.histogram2d(x_sel, y_sel, bins=bins)
ax.pcolormesh(xe, ye, H.T,
              norm=mpl.colors.LogNorm(),
              cmap='Greys')

# Wedge outlines: outer radius ``radius``, radial width ``radius_step``
# clipped to ``radius`` so the innermost ring does not get a negative inner
# radius (same clipping as wedge_gen).
# NOTE: this rebinds ``radii`` (previously the Fourier annulus centers).
dang = 45.
radius_step = 1.
radii = np.arange(0.5, 5.5, radius_step)
for radius in radii:
    for ang in np.arange(0, 360, dang):
        w = mpl.patches.Wedge((0, 0), radius, ang, ang+dang,
                              width=min(radius_step, radius),
                              facecolor='none',
                              edgecolor='k', linewidth=1)
        ax.add_patch(w)

# Highlight the conjugate wedge pairs (ang, ang + 180) in one ring, one colour
# per pair.
colors = plt.get_cmap('tab10').colors
angs = np.arange(0, 180, dang)
radius = radii[3]
for i, ang in enumerate(angs):
    for j in [0, 180]:
        w = mpl.patches.Wedge((0, 0), radius, ang+j, ang+dang+j,
                              width=min(radius_step, radius),
                              facecolor=colors[i],
                              edgecolor='none', alpha=0.4)
        ax.add_patch(w)

ax.set_xlim(bins.min(), bins.max())
ax.set_ylim(bins.min(), bins.max())

ax.set_xlabel(r'$\sqrt{J_z} \, \cos(\theta_z)$')
ax.set_ylabel(r'$\sqrt{J_z} \, \sin(\theta_z)$')

ax.set_aspect('equal')

In [None]:
def wedge_gen(radii, d_angle):
    """Yield pairs of wedges related by a 180-degree rotation, ring by ring.

    For every ``radii[i]`` except the last, the ring width is
    ``radii[i+1] - radii[i]`` clipped to ``radii[i]``. For every start angle in
    ``[0, 180)`` stepped by ``d_angle``, this yields
    ``(radius, ang, [wedge, wedge_rotated_by_180_degrees])``.

    NOTE(review): ``radii[i]`` is the wedge's *outer* radius, so each ring spans
    [radii[i] - width, radii[i]] and ``radii[-1]`` only sets the last width.
    Confirm this is the intended binning.

    Parameters
    ----------
    radii : sequence of float
    d_angle : astropy angular Quantity

    Yields
    ------
    radius : float
    ang : float, start angle in degrees
    conjugate_wedges : list of two matplotlib.patches.Wedge
    """
    ring_specs = [(outer, nxt - outer) for outer, nxt in zip(radii[:-1], radii[1:])]

    step_deg = d_angle.to_value(u.degree)
    start_angles = np.arange(0, 180, step_deg)

    for outer, ring_width in ring_specs:
        for ang in start_angles:
            pair = [
                mpl.patches.Wedge((0, 0), outer,
                                  ang + flip, ang + step_deg + flip,
                                  width=min(ring_width, outer))
                for flip in (0, 180)
            ]
            yield outer, ang, pair
            

from collections import defaultdict
def spirality(xx, yy, radii, d_angle=45*u.deg, summary=True):
    """Wedge-pair counting asymmetry of a 2D point set.

    NOTE: this redefines (shadows) the Fourier-based ``spirality`` defined
    earlier; every cell after this one gets this version.

    For each conjugate wedge pair from ``wedge_gen(radii, d_angle)``, computes
    (n1 - n2) / (n1 + n2), where n1 and n2 count the points inside each wedge.
    The result is nan, with a RuntimeWarning, if both wedges are empty.

    Returns
    -------
    If ``summary``: ndarray with one value per entry of ``radii``, the
    quadrature sum of that radius' pair asymmetries. Radii that produced no
    wedges, including the last entry, give 0.
    Otherwise: defaultdict mapping radius -> list of pair asymmetries.
    """
    points = np.stack((xx, yy)).T

    diffs = defaultdict(list)
    for radius, _, (wedge_a, wedge_b) in wedge_gen(radii, d_angle):
        n_a = wedge_a.get_path().contains_points(points).sum()
        n_b = wedge_b.get_path().contains_points(points).sum()
        diffs[radius].append((n_a - n_b) / (n_a + n_b))

    if not summary:
        return diffs

    return np.array([np.sqrt(np.sum(np.array(diffs[r])**2)) for r in radii])

In [None]:
# derp = np.stack((xx[mask], yy[mask])).T

# plt.figure(figsize=(6, 6))
# for radius, ang, wedges in wedge_gen(np.arange(0.5, 5.5, 1.), 
#                                      45*u.deg):
#     fuck = wedges[0].get_path().contains_points(derp)
#     ls, = plt.plot(derp[fuck, 0], derp[fuck, 1], ls='none')
    
#     fuck = wedges[1].get_path().contains_points(derp)
#     plt.plot(derp[fuck, 0], derp[fuck, 1], ls='none', color=ls.get_color())

## "Classic" asymmetry parameters

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6),
                         sharex=True, sharey=True,
                         constrained_layout=True)

# Left: ``w0`` scaled by (vscale, xscale). Right: snapshot 1000.
# NOTE(review): ``w0`` was rebound to a unit-Gaussian sample earlier, so the
# left panel is not the simulation's initial conditions. Confirm intent.
panel_data = [
    (w0[1] / vscale, w0[0] / xscale),
    (orbits[1000].v_x1.value / vscale, orbits[1000].x1.value / xscale),
]
for ax, (horiz, vert) in zip(axes, panel_data):
    H, xe, ye = np.histogram2d(horiz, vert, bins=np.linspace(-4, 4, 128))
    ax.pcolormesh(xe, ye, H.T,
                  norm=mpl.colors.LogNorm(),
                  cmap='Greys')

In [None]:
k = 1000
snapshot = orbits[k]
snapshot_asym = spirality(snapshot.x1 / xscale,
                          snapshot.v_x1 / vscale,
                          radii=np.arange(0.5, 4.5, 0.5),
                          summary=True)
plt.plot(snapshot_asym)

In [None]:
reference_asym = spirality(
    w0[0] / xscale,
    w0[1] / vscale,
    radii=np.arange(0.5, 4.5, 0.25),
    summary=True,
)
plt.plot(reference_asym)

In [None]:
# One scalar asymmetry per sampled snapshot. spirality(summary=True) returns
# one value per radius, which cannot be stored in a scalar slot (ValueError:
# setting an array element with a sequence). Reduce it with the same
# quadrature sum that summary=True applies within each radius.
ks = np.arange(0, orbits.shape[0], 200)
spiral_asym = np.zeros(len(ks))
for i, k in enumerate(ks):
    per_radius = spirality(orbits[k].x1 / xscale,
                           orbits[k].v_x1 / vscale,
                           radii=np.arange(0.5, 4.5, 0.5),
                           summary=True)
    spiral_asym[i] = np.sqrt(np.sum(per_radius**2))

In [None]:
# Sweep the ring spacing and record (number of rings, quadrature-summed wedge
# asymmetry) to see how the total grows with the number of measurements.
# NOTE(review): this passes (v_z, z) as (xx, yy), the reverse of the other
# calls, and ``w0`` is the Gaussian reference sample at this point.
ring_steps = np.linspace(0.1, 1, 32)
Ns = np.zeros(len(ring_steps))
vals = np.zeros_like(ring_steps)
for i, step in enumerate(ring_steps):
    diffs_by_radius = spirality(w0[1] / vscale,
                                w0[0] / xscale,
                                radii=np.arange(0.5, 4.5, step),
                                summary=False)
    Ns[i] = len(diffs_by_radius)
    # summary=False returns a dict {radius: [pair asymmetries]}, which does not
    # support ``**``; flatten it first. Empty wedge pairs yield nan, so use
    # nansum to ignore them.
    all_diffs = np.array([d for per_ring in diffs_by_radius.values() for d in per_ring],
                         dtype=float)
    vals[i] = np.sqrt(np.nansum(all_diffs**2))

In [None]:
# Quadratic fit of total asymmetry vs number of rings (coefficients, highest power first).
np.polyfit(Ns, vals, deg=2)

In [None]:
# Total wedge asymmetry against the number of rings used.
fig_rings = plt.figure(figsize=(6, 6))
plt.plot(Ns, vals)

In [None]:
# Wedge asymmetry of each sampled snapshot against its time-step index.
plt.plot(ks, spiral_asym)

In [None]:
# NOTE(review): duplicate of the previous cell; candidate for removal.
plt.plot(ks, spiral_asym)