In [None]:
import pathlib

from astropy.convolution import Gaussian2DKernel, convolve
import astropy.coordinates as coord
from astropy.io import ascii, fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic, binned_statistic_2d
from IPython.display import HTML
from astropy.stats import median_absolute_deviation as MAD
from matplotlib.animation import FuncAnimation

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
from gala.units import galactic

from pyia import GaiaData
from cmastro import cmaps

from helpers import plot_spiral, vcirc, galcen_frame, solar_action_units

In [None]:
# See: Setup.ipynb
# Precomputed catalog (presumably Gaia EDR3 x 2MASS with orbital actions, per the filename).
data_path = pathlib.Path('../data/').resolve()
_cache_file = data_path / 'edr3-2mass-actions.fits'
data = at.Table.read(_cache_file)

# Per-star CMD selection masks. Assumed to be row-aligned with `data`
# (same stars, same order) -- TODO confirm in Setup.ipynb.
cmd_masks_file = data_path / 'cmd-masks.fits'
cmd_masks = at.Table.read(cmd_masks_file)

# join_type='exact' makes hstack raise if the two tables differ in length,
# instead of the default 'outer' behavior of silently padding the shorter
# table with masked rows (which would misalign stars and masks).
data = at.hstack((data, cmd_masks), join_type='exact')

In [None]:
# TESS asteroseismic catalog cross-matched to Gaia EDR3 (presumably Hon et al. 2021,
# per the filename).
tess = at.Table.read(
    pathlib.Path('~/data/Asteroseismology/TESS_Hon2021_GaiaEDR3.fits').expanduser(),
    hdu=1)

# Keep only rows with massflag == 1, and only the columns used downstream.
tess = tess[tess['massflag'] == 1]
tess = tess['TIC', 'source_id', 'mass', 'e_mass']

# Left join: every star in `data` is kept; stars without a TESS match get masked
# mass values. Repeated source_ids in `tess` would silently duplicate rows of
# `data`, so check that the row count is unchanged.
# NOTE: by default astropy's join returns rows sorted by the join key.
_n_before_join = len(data)
data = at.join(data, tess, join_type='left', keys='source_id')
assert len(data) == _n_before_join, 'TESS cross-match duplicated rows of data'

In [None]:
# Wrap the merged table in pyia's GaiaData. Later cells read columns as
# attributes (e.g. g.J_R, g.theta_z, g.Om_z) and build coordinates with
# g.get_skycoord().
g = GaiaData(data)
# Display the number of stars in the merged sample.
len(g)

In [None]:
# Boolean CMD selection (column 'ms_cmd_mask' from cmd-masks.fits). It is
# combined with the L_z / J_R cuts in the selections below.
cmd_mask = data['ms_cmd_mask']

## Coordinate transforms

In [None]:
# SkyCoord for every star. Velocities are presumably included, since the next
# cell reads velocity differentials and v_z from the transformed frame.
c = g.get_skycoord()

In [None]:
# Transform to the Galactocentric frame configured in helpers.galcen_frame.
galcen = c.transform_to(galcen_frame)

# Cylindrical velocity components in km/s: radial (vR) and azimuthal (vphi = R * dphi/dt).
cyl = galcen.cylindrical
vcyl = cyl.differentials['s']
vR = vcyl.d_rho.to_value(u.km/u.s)
vphi = (cyl.rho * vcyl.d_phi).to_value(u.km/u.s, u.dimensionless_angles())

# Angular momentum via gala; the z-component drives the guiding-radius bins below.
w0 = gd.PhaseSpacePosition(galcen.data)
L = w0.angular_momentum()
Lz = L[2]

In [None]:
# Bin edges for the (v_z, z) plane, presumably in km/s and kpc: unit steps over
# +/-75 and 0.025 steps over +/-1.5. The +1e-3 keeps the upper edge in the arange.
# NOTE(review): not referenced by any visible cell.
_vz_edges = np.arange(-75, 75 + 1e-3, 1.)
_z_edges = np.arange(-1.5, 1.5 + 1e-3, 0.025)
vz_z_bins = (_vz_edges, _z_edges)

In [None]:
# Solar angular momentum used as the L_z unit (240 km/s at 8.1 kpc).
Lz_sun = 240*u.km/u.s * 8.1*u.kpc
# Sign flip so that Lz_solar is positive for the bulk of the sample (the
# histogram below spans 0-2); presumably Lz < 0 for disk stars in galcen_frame.
Lz_solar = -Lz.to_value(Lz_sun)

# Overlapping 2 kpc-wide guiding-radius bins starting at 5, 6, ..., 10 kpc,
# converted to L_z bins via L_z = R_g * vcirc (vcirc from helpers).
Rg_bins = [[x, x+2] for x in np.arange(5, 10+1, 1)] * u.kpc
Lz_bins = (Rg_bins * vcirc).to_value(Lz_sun)
# Guiding radius of each star under the same R_g = L_z / vcirc relation.
Rg_kpc = (Lz_solar * Lz_sun / vcirc).to_value(u.kpc)

# Star counts per L_z bin, using the same (lower, upper] convention as the
# spiral selections later in the notebook.
print([((Lz_solar > m1) & (Lz_solar <= m2)).sum()
       for m1, m2 in Lz_bins])

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Bin edges are drawn with axvline on each axes explicitly. plt.ylim() would
# read the *current* axes (the last one created by plt.subplots), not `ax`.
ax = axes[0]
ax.hist(Lz_solar,
        bins=np.linspace(0, 2, 32));
ax.set_yscale('log')
ax.set_xlabel(r'$L_z$ [$L_{z,\odot}$]');
for edge in np.ravel(Lz_bins):
    ax.axvline(edge, color='tab:blue')

ax = axes[1]
ax.hist(Rg_kpc,
        bins=np.linspace(3, 15, 51));
ax.set_yscale('log')
ax.set_xlabel('$R_g$ [kpc]');
for edge in np.ravel(Rg_bins.to_value(u.kpc)):
    ax.axvline(edge, color='tab:blue')

## Action distributions

In [None]:
# 5th and 84th percentiles of J_R and J_z in solar action units, ignoring NaNs.
# NOTE(review): [5, 84] is asymmetric -- was [16, 84] intended?
_JR_in_sun = g.J_R.to_value(solar_action_units[0])
_Jz_in_sun = g.J_z.to_value(solar_action_units[2])
_pct_levels = [5, 84]
print(np.nanpercentile(_JR_in_sun, _pct_levels),
      np.nanpercentile(_Jz_in_sun, _pct_levels))

In [None]:
# Robust spread of J_R and J_z in solar action units: 1.5 x median absolute
# deviation, ignoring NaNs. Both values are printed on one line.
_JR_sun_units = g.J_R.to_value(solar_action_units[0])
_Jz_sun_units = g.J_z.to_value(solar_action_units[2])
print(*(1.5 * MAD(_vals, ignore_nan=True)
        for _vals in (_JR_sun_units, _Jz_sun_units)))

In [None]:
# Action distributions for stars within 10% of the solar angular momentum.
# Lz_solar = -Lz / Lz_sun, so solar-like stars have Lz_solar close to 1.
# Comparing the raw Lz against Lz_sun instead would select almost nothing,
# because the sign flip means the two have opposite signs.
Lz_mask = np.abs(Lz_solar - 1.) < 0.1

fig, axes = plt.subplots(1, 2, figsize=(10, 5),
                         sharex=True, sharey=True)

axes[0].hist(g.J_R.to_value(solar_action_units[0])[Lz_mask],
             bins=np.linspace(0., 10, 128),
             density=True);
axes[1].hist(np.sqrt(g.J_z.to_value(solar_action_units[2])[Lz_mask]),
             bins=np.linspace(0., 10, 128),
             density=True);

axes[0].set_xlabel(r'$J_R$ [$J_{R, \odot}$]')
axes[1].set_xlabel(r'$\sqrt{J_z / J_{z, \odot}}$')

fig.tight_layout()

## Spiral plots

In [None]:
fig, ax = plt.subplots(
    1, 1,
    constrained_layout=True
)

# Third guiding-radius bin from Lz_bins / Rg_bins, combined with the CMD selection.
(Lz1, Lz2) = Lz_bins[2]
(Rg1, Rg2) = Rg_bins[2]
Lz_mask = (Lz_solar > Lz1) & (Lz_solar <= Lz2) & cmd_mask
ax.set_title(f'$R_g = {Rg1.value:.0f}$–${Rg2.value:.0f}$ kpc')

plot_spiral(
    galcen.v_z[Lz_mask],
    galcen.z.to(u.kpc)[Lz_mask],
    colorbar=False, xlabel='', ylabel='',
    ax=ax)

# plot_spiral was told to leave the labels blank, so set them here.
ax.set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')
ax.set_ylabel(f'$z$ [{u.kpc:latex_inline}]')

In [None]:
# One column per guiding-radius bin.
#   Top row:    number density in the (v_z, z) plane.
#   Bottom row: shuffle-subtracted contrast in the
#               (sqrt(J_z) cos theta_z, sqrt(J_z) sin theta_z) plane.
ncols = len(Lz_bins)
fig, all_axes = plt.subplots(
    2, ncols, figsize=(ncols*4 + 0.5, 4 * 2),
    sharex='row', sharey='row',
    constrained_layout=True
)

for i, ((Lz1, Lz2), (Rg1, Rg2)) in enumerate(zip(Lz_bins, Rg_bins)):
    axes = all_axes[:, i]

    Lz_mask = (Lz_solar > Lz1) & (Lz_solar <= Lz2) & cmd_mask
    axes[0].set_title(f'$R_g = {Rg1.value:.0f}$–${Rg2.value:.0f}$ kpc')

    plot_spiral(
        galcen.v_z[Lz_mask],
        galcen.z.to(u.kpc)[Lz_mask],
        colorbar=False, xlabel='', ylabel='',
        ax=axes[0])

    plot_spiral(
        g.J_z[Lz_mask],
        g.theta_z[Lz_mask],
        bins=np.linspace(-7.5, 7.5, 128),
        ax=axes[1], shuffle_subtract=True,
        colorbar=False, xlabel='', ylabel='',
        pcolor_kw=dict(cmap='magma_r', vmin=-0.5, vmax=0.5)
    )

# Rows show different projections, so each row gets its own labels:
# x labels on every panel of the row, y labels on the first column.
for ax in all_axes[0]:
    ax.set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')
for ax in all_axes[1]:
    ax.set_xlabel(r'$\sqrt{J_z}\,\cos\theta_z$')
all_axes[0, 0].set_ylabel(f'$z$ [{u.kpc:latex_inline}]')
all_axes[1, 0].set_ylabel(r'$\sqrt{J_z}\,\sin\theta_z$')

# Right-hand labels name the statistic shown in each row.
_labels = ['number density', r'$\delta n$']
for ax, label in zip(all_axes[:, -1], _labels):
    ax.set_ylabel(label)
    ax.yaxis.set_label_position('right')

## Radial angle $\theta_R$

In [None]:
# Stars within 20% of the solar L_z, with the CMD cut and theta_R strictly
# inside (0, 2 pi) rad.
Lz_mask = (
    (np.abs(Lz_solar - 1.) < 0.2) &
    cmd_mask &
    (g.theta_R > 0) & (g.theta_R < 2*np.pi*u.rad)
)

fig, axes = plt.subplots(1, 2, figsize=(15, 6),
                         constrained_layout=True)

# Radial action in units of 40 km/s x 1 kpc.
JR_solar = g.J_R.to_value(40*u.km/u.s * 1*u.kpc)

# Left: median sqrt(J_R) in the (v_z, z) plane. The label reflects the sqrt.
plot_spiral(
    galcen.v_z[Lz_mask],
    galcen.z.to(u.kpc)[Lz_mask],
    arr=np.sqrt(JR_solar[Lz_mask]),
    colorbar_label=r'$\sqrt{J_R}$',
    ax=axes[0], arr_statistic='median'
);

# Right: median cos(theta_R) in the same plane.
plot_spiral(
    galcen.v_z[Lz_mask],
    galcen.z.to(u.kpc)[Lz_mask],
    arr=np.cos(g.theta_R.value[Lz_mask]),
    colorbar_label=r'$\cos\theta_R$',
    pcolor_kw=dict(vmin=-1, vmax=1, cmap='cma:laguna'),
    ax=axes[1], arr_statistic='median'
);

In [None]:
# Same selection as the (v_z, z) version: |Lz_solar - 1| < 0.2, CMD cut,
# theta_R strictly inside (0, 2 pi) rad.
mask = (
    (np.abs(Lz_solar - 1.) < 0.2) &
    cmd_mask &
    (g.theta_R > 0) & (g.theta_R < 2*np.pi*u.rad)
)

fig, axes = plt.subplots(1, 2, figsize=(15, 6),
                         constrained_layout=True)

# Radial action in units of 40 km/s x 1 kpc.
JR_solar = g.J_R.to_value(40*u.km/u.s * 1*u.kpc)

# Left: median sqrt(J_R) in the (J_z, theta_z) projection. The label reflects the sqrt.
plot_spiral(
    g.J_z[mask],
    g.theta_z[mask],
    arr=np.sqrt(JR_solar[mask]),
    colorbar_label=r'$\sqrt{J_R}$',
    pcolor_kw=dict(vmin=0.5, vmax=1.6, cmap='cma:hesperia_r'),
    ax=axes[0], arr_statistic='median'
);

# Right: median cos(theta_R) in the same projection.
plot_spiral(
    g.J_z[mask],
    g.theta_z[mask],
    arr=np.cos(g.theta_R.value[mask]),
    colorbar_label=r'$\cos\theta_R$',
    pcolor_kw=dict(vmin=-1, vmax=1, cmap='cma:laguna'),
    ax=axes[1], arr_statistic='median'
);

for ax in axes:
    ax.set_xlim(-7, 7)
    ax.set_ylim(-7, 7)

In [None]:
# Stars within 20% of the solar L_z, with the CMD cut and theta_R strictly
# inside (0, 2 pi) rad.
mask = (
    (np.abs(Lz_solar - 1.) < 0.2) &
    cmd_mask &
    (g.theta_R > 0) & (g.theta_R < 2*np.pi*u.rad)
)

# Cartesian components of the point at radius sqrt(J_R) and angle theta_R
# (J_R in units of 40 km/s x 1 kpc). Computed for all stars; selected below.
JR_solar = g.J_R.to_value(40*u.km/u.s * 1*u.kpc)
cosarr = np.sqrt(JR_solar) * np.cos(g.theta_R.value)
sinarr = np.sqrt(JR_solar) * np.sin(g.theta_R.value)

fig, axes = plt.subplots(1, 2, figsize=(15, 6),
                         constrained_layout=True)

# Positions and colour arrays are all indexed with the same `mask`, so their
# lengths always match regardless of which masks earlier cells left behind.
style = dict(vmin=-0.5, vmax=0.5, cmap='magma_r')
plot_spiral(
    galcen.v_z[mask],
    galcen.z.to(u.kpc)[mask],
    arr=cosarr[mask],
    colorbar_label=r'$\sqrt{J_R} \, \cos\theta_R$',
    pcolor_kw=style,
    ax=axes[0]);

plot_spiral(
    galcen.v_z[mask],
    galcen.z.to(u.kpc)[mask],
    arr=sinarr[mask],
    colorbar_label=r'$\sqrt{J_R} \, \sin\theta_R$',
    pcolor_kw=style,
    ax=axes[1]);

In [None]:
# (sqrt(J_z) cos theta_z, sqrt(J_z) sin theta_z) projection coloured by the median
# of sqrt(J_R) cos theta_R (left) and sqrt(J_R) sin theta_R (right). Uses the
# mask / cosarr / sinarr defined in the previous cell.
fig, axes = plt.subplots(1, 2, figsize=(15, 6), constrained_layout=True)

style = dict(vmin=-0.8, vmax=0.8, cmap='magma_r')
_panels = (
    (cosarr, r'$\sqrt{J_R} \, \cos\theta_R$'),
    (sinarr, r'$\sqrt{J_R} \, \sin\theta_R$'),
)
for _ax, (_arr, _label) in zip(axes, _panels):
    plot_spiral(
        g.J_z[mask],
        g.theta_z[mask],
        arr=_arr[mask],
        colorbar_label=_label,
        pcolor_kw=style,
        ax=_ax,
        arr_statistic='median',
    )

In [None]:
# Eight equal-width bins of sqrt(J_R) cos(theta_R) over [-1, 1], as (lower, upper) pairs.
JRthR_bins = np.linspace(-1, 1, 9)
JRthR_bins = np.array(list(zip(JRthR_bins[:-1], JRthR_bins[1:])))

# One panel per bin, laid out on two rows.
ncols = len(JRthR_bins)
fig, axes = plt.subplots(
    2, ncols//2, figsize=(ncols//2*4 + 0.5, 4.5 * 2),
    sharex=True, sharey=True,
    constrained_layout=True
)

# Guiding radii 7-10 kpc, expressed in solar-normalized L_z.
Lz1, Lz2 = ([7, 10]*u.kpc * vcirc).to_value(Lz_sun)

for i, (v1, v2) in enumerate(JRthR_bins):
    ax = axes.flat[i]
    mask = (
        (Lz_solar > Lz1) &
        (Lz_solar <= Lz2) &
        cmd_mask &
        (cosarr > v1) &
        (cosarr <= v2)
    )
    # cosarr is sqrt(J_R) * cos(theta_R); the title names that quantity.
    ax.set_title(rf'${v1:.2f} < \sqrt{{J_R}}\,\cos\theta_R < {v2:.2f}$', fontsize=14)

    ax, mesh = plot_spiral(
        g.J_z[mask],
        g.theta_z[mask],
        bins=np.linspace(-7.5, 7.5, 128),
        ax=ax, shuffle_subtract=True,
        colorbar=False, xlabel='', ylabel='',
        pcolor_kw=dict(cmap='magma_r', vmin=-0.5, vmax=0.5)
    )

for ax in axes[-1]:
    ax.set_xlabel(r'$\sqrt{J_z}\,\cos\theta_z$')
for ax in axes[:, 0]:
    ax.set_ylabel(r'$\sqrt{J_z}\,\sin\theta_z$')

# Shared colorbar for the shuffle-subtracted density contrast.
cb = fig.colorbar(mesh, ax=axes, aspect=30)
cb.set_label(r'$\delta n$')

In [None]:
# Eight equal-width bins of cos(theta_R) over [-1, 1], stored as (lower, upper) rows.
_thR_edges = np.linspace(-1, 1, 9)
thR_bins = np.column_stack((_thR_edges[:-1], _thR_edges[1:]))

# One panel per bin on two rows.
ncols = len(thR_bins)
_per_row = ncols // 2
fig, axes = plt.subplots(
    2, _per_row,
    figsize=(_per_row * 4 + 0.5, 4.5 * 2),
    sharex=True, sharey=True,
    constrained_layout=True,
)

# Guiding radii 7-9 kpc in solar-normalized L_z; low radial action only (J_R < 0.2).
Lz1, Lz2 = ([7, 9]*u.kpc * vcirc).to_value(Lz_sun)
_cos_thR = np.cos(g.theta_R)
_base_sel = (Lz_solar > Lz1) & (Lz_solar <= Lz2) & cmd_mask & (JR_solar < 0.2)

for ax, (th1, th2) in zip(axes.flat, thR_bins):
    mask = _base_sel & (_cos_thR > th1) & (_cos_thR <= th2)
    ax.set_title(rf'${th1:.1f} < \cos\theta_R < {th2:.1f}$')

    ax, mesh = plot_spiral(
        g.J_z[mask],
        g.theta_z[mask],
        bins=np.linspace(-7.5, 7.5, 128),
        ax=ax, shuffle_subtract=True,
        colorbar=False, xlabel='', ylabel='',
        pcolor_kw=dict(cmap='magma_r', vmin=-0.5, vmax=0.5),
    )

for ax in axes[-1]:
    ax.set_xlabel(r'$\sqrt{J_z}\,\cos\theta_z$')
for ax in axes[:, 0]:
    ax.set_ylabel(r'$\sqrt{J_z}\,\sin\theta_z$')

cb = fig.colorbar(mesh, ax=axes, aspect=30)

---

## Animations

In [None]:
nframes = 40
dx = 0.1
# Left edges of sliding cos(theta_R) windows: step 2/nframes, each dx wide.
x_bin_l = np.arange(-1, 1+1e-3 - dx, 2/nframes)  # x = cos(theta_R)
x_bins = list(zip(x_bin_l, x_bin_l + dx))

# Overlapping guiding-radius windows (kpc) for the animated panels, as solar-normalized L_z.
anim_Lz_bins = (np.array([[5, 8], [6.5, 9.5], [8, 11]]) * u.kpc * vcirc).to_value(Lz_sun)

_n_panels = len(anim_Lz_bins)
fig, axes = plt.subplots(
    1, _n_panels,
    figsize=(_n_panels * 3 + 2, 5),
    sharex=True, sharey=True,
    constrained_layout=True,
)

for ax in axes:
    ax.set_xlabel(r'$\sqrt{J_z}\,\cos\theta_z$')
axes[0].set_ylabel(r'$\sqrt{J_z}\,\sin\theta_z$')

def func(i, meshes=None):
    """Draw (or update) animation frame ``i``.

    Selects stars with cos(theta_R) in ``x_bins[i]``. For each guiding-radius
    window in ``anim_Lz_bins``, it plots the shuffle-subtracted density in the
    (sqrt(J_z) cos theta_z, sqrt(J_z) sin theta_z) plane on the matching panel
    of ``axes``.

    Parameters
    ----------
    i : int
        Frame index into ``x_bins``.
    meshes : list, optional
        Meshes returned by an earlier call. When given, each is passed to
        ``plot_spiral`` as ``mesh=`` (presumably so it updates the existing
        mesh instead of drawing a new one -- confirm in helpers.plot_spiral).

    Returns
    -------
    list
        The mesh returned by ``plot_spiral`` for each panel.
    """
    if i == 0:
        for j, ax in enumerate(axes):
            rg1, rg2 = (anim_Lz_bins[j] * Lz_sun / vcirc).to_value(u.kpc)
            ax.set_title(rf'${rg1:.1f} < R_g < {rg2:.1f}$ kpc', fontsize=18)

    x1, x2 = x_bins[i]
    # Window edges step by 2/nframes (finer than 0.1), so show two decimals;
    # one decimal would round neighbouring frames to the same label.
    fig.suptitle(rf'${x1:.2f} < \cos\theta_R < {x2:.2f}$', fontsize=22)

    # Same for every panel in this frame: compute the theta_R window once.
    cos_thR = np.cos(g.theta_R)
    in_window = cmd_mask & (cos_thR > x1) & (cos_thR <= x2)

    _meshes = []
    for k, (ax, (Lz1, Lz2)) in enumerate(zip(axes, anim_Lz_bins)):
        mask = (Lz_solar > Lz1) & (Lz_solar <= Lz2) & in_window

        # Only pass mesh= when reusing meshes from an earlier frame.
        mesh_kw = {} if meshes is None else dict(mesh=meshes[k])
        ax, mesh = plot_spiral(
            g.J_z[mask],
            g.theta_z[mask],
            bins=np.linspace(-7.5, 7.5, 128),
            ax=ax, shuffle_subtract=True,
            colorbar=False, xlabel='', ylabel='',
            pcolor_kw=dict(cmap='magma_r', vmin=-0.5, vmax=0.5),
            **mesh_kw
        )
        _meshes.append(mesh)

    return _meshes

# Draw frame 0 once to create the meshes; later frames reuse them via fargs.
meshes = func(0)

# blit=False: func also changes the figure suptitle every frame. Blitting only
# redraws the returned mesh artists, so the title would go stale in live displays.
anim = FuncAnimation(fig, func, frames=len(x_bins),
                     fargs=(meshes,), blit=False)

In [None]:
# Write the animation to a GIF (for inline display use HTML(anim.to_html5_video())).
# Create the output directory first so a fresh checkout doesn't fail on save.
plots_path = pathlib.Path('../plots')
plots_path.mkdir(parents=True, exist_ok=True)
anim.save(str(plots_path / 'spiral-costhetaR-Jzcosthz.gif'))

**TODO:** at a fixed radial phase $\theta_R$, also slice in $J_R$?

In [None]:
# J_R distribution (units of 40 km/s x 1 kpc, log-spaced bins) for stars within
# 20% of the solar L_z, with the CMD cut and theta_R strictly inside (0, 2 pi) rad.
# The selection is rebuilt here instead of relying on whichever Lz_mask an
# earlier cell happened to leave in the kernel.
JR_solar = g.J_R.to_value(40*u.km/u.s * 1*u.kpc)
JR_hist_mask = (
    (np.abs(Lz_solar - 1.) < 0.2) &
    cmd_mask &
    (g.theta_R > 0) & (g.theta_R < 2*np.pi*u.rad)
)

fig, ax = plt.subplots(1, 1)
ax.hist(JR_solar[JR_hist_mask], bins=np.logspace(-2, 1, 64));
ax.set_xscale('log')
ax.set_xlabel(r'$J_R$ [40 km s$^{-1}$ kpc]')
ax.set_ylabel('number of stars');

## Spiral density: $J_R$ selection

In [None]:
# Radial action in units of 40 km/s x 1 kpc.
JR_solar = g.J_R.to_value(40*u.km/u.s * 1*u.kpc)

# Stars within 10% of the solar L_z with the CMD cut, split into a low-J_R
# sample (J_R < 0.1) and a high-J_R sample (J_R >= 1).
_Lz_cmd = (np.abs(Lz_solar - 1.) < 0.1) & cmd_mask
mask1 = _Lz_cmd & (JR_solar < 0.1)
mask2 = _Lz_cmd & (JR_solar >= 1)

# Same colour scale for both figures so they can be compared directly.
style = dict(cmap='magma_r', vmin=-0.75, vmax=0.75)
_titles = [r'$J_R < 0.1$', r'$J_R \geq 1$']

for mask, _title in zip([mask1, mask2], _titles):
    fig, ax = plt.subplots(1, 1, figsize=(6.5, 6))
    plot_spiral(
        g.J_z[mask],
        g.theta_z[mask],
        pcolor_kw=style,
        shuffle_subtract=True,
        ax=ax);
    # Without a title the two figures are indistinguishable.
    ax.set_title(_title)

In [None]:
# Stars within 10% of the solar L_z with the CMD cut.
mask = (
    (np.abs(Lz_solar - 1.) < 0.1) &
    cmd_mask
)

# Radial and vertical actions in fixed reference units
# (J_R in 40 km/s x 1 kpc, J_z in 15 km/s x 0.5 kpc).
JR_vals = (g.J_R.to_value(40*u.km/u.s * 1*u.kpc))[mask]
Jz_vals = (g.J_z.to_value(15*u.km/u.s * 0.5*u.kpc))[mask]

fig, ax = plt.subplots(1, 1, figsize=(6, 6))

ax.hist2d(JR_vals, Jz_vals, bins=np.logspace(-2, 1, 128),
          norm=mpl.colors.LogNorm());
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel(r'$J_R$ [40 km s$^{-1}$ kpc]')
ax.set_ylabel(r'$J_z$ [7.5 km s$^{-1}$ kpc]');

---

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Vertical orbital frequency Omega_z / 2 pi (in 1/Gyr) against guiding radius.
ax.hist2d(Rg_kpc, g.Om_z.to_value(1/u.Gyr) / (2*np.pi),
          bins=(np.linspace(0, 25, 128),
                np.logspace(0, 1.3, 128)),
          norm=mpl.colors.LogNorm());
ax.set_ylim(1, 20)
ax.set_yscale('log')
ax.set_xlabel('$R_g$ [kpc]')
ax.set_ylabel(r'$\Omega_z / 2\pi$ [Gyr$^{-1}$]');

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Radial orbital frequency Omega_R / 2 pi (in 1/Gyr) against guiding radius.
ax.hist2d(Rg_kpc, g.Om_R.to_value(1/u.Gyr) / (2*np.pi),
          bins=(np.linspace(0, 25, 128),
                np.logspace(0, 1.3, 128)),
          norm=mpl.colors.LogNorm());
ax.set_ylim(1, 20)
ax.set_yscale('log')
ax.set_xlabel('$R_g$ [kpc]')
ax.set_ylabel(r'$\Omega_R / 2\pi$ [Gyr$^{-1}$]');