In [None]:
import pathlib

from astropy.convolution import Gaussian2DKernel, convolve
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic, binned_statistic_2d
from astropy.stats import median_absolute_deviation as MAD
from tqdm.notebook import trange, tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from cmastro import cmaps
from totoro.actions import get_staeckel_aaf

In [None]:
# Gravitational potential model shared by every orbit integration and
# circular-velocity evaluation in this notebook.
potential = gp.MilkyWayPotential()

In [None]:
# Phase-space position adopted for the Sun (position in kpc, velocity in km/s).
sun_w0 = gd.PhaseSpacePosition(
    pos=u.Quantity([-8.1, 0, 0.0206], u.kpc),
    vel=u.Quantity([12.9, 245.6, 7.78], u.km / u.s),
)

# Integrate the Sun's orbit over 100 x 250 Myr and run the action solver on
# the result; `sun_aaf` is the solver's output dictionary.
sun_orbit = potential.integrate_orbit(
    sun_w0,
    dt=2.,
    t1=0,
    t2=100 * 250 * u.Myr,
    Integrator=gi.DOPRI853Integrator,
)
sun_aaf = gd.find_actions(sun_orbit, N_max=8)

In [None]:
# The Sun's actions; the plotting cells pass `action_units[k]` to `to_value`,
# i.e. other orbits' actions are expressed as multiples of the Sun's.
action_units = sun_aaf['actions']
action_units

In [None]:
def get_aafs(w0):
    """Integrate a batch of orbits and solve for their actions and frequencies.

    Every initial condition in ``w0`` is integrated in the notebook-level
    ``potential`` over 100 x 250 Myr with ``dt=1.`` (unitless, so interpreted
    in the potential's own unit system -- presumably Myr; confirm). The action
    solver is then run on each orbit separately.

    Parameters
    ----------
    w0 : gala.dynamics.PhaseSpacePosition
        Initial conditions; ``w0.shape[0]`` is taken as the number of orbits.

    Returns
    -------
    astropy.table.QTable
        One row per orbit for which the action solver succeeded, holding the
        solver output minus the ``'Sn'``, ``'dSn_dJ'`` and ``'nvecs'`` entries,
        plus a ``'periods'`` column equal to ``|2 pi / freqs|``.

    Warns
    -----
    UserWarning
        If the solver raised ``ValueError`` for some orbits. Those orbits are
        left out, so the table can have fewer rows than ``w0`` and row ``i``
        need not correspond to input ``i``.

    Raises
    ------
    ValueError
        If the solver failed for every orbit, so there is nothing to tabulate.
    """
    import warnings

    orbits = potential.integrate_orbit(
        w0, dt=1., t1=0, t2=250*u.Myr * 100,
        Integrator=gi.DOPRI853Integrator)

    n_orbits = w0.shape[0]
    aafs = []
    failed = []

    for i in trange(n_orbits):
        try:
            aaf = gd.find_actions(orbits[:, i], N_max=8)
        except ValueError:
            # Best effort: keep going, but remember which orbit was dropped
            # so the caller is told about it below.
            failed.append(i)
            continue

        # Drop the solver's bulky diagnostic entries; tolerate their absence.
        for k in ['Sn', 'dSn_dJ', 'nvecs']:
            aaf.pop(k, None)
        aafs.append(aaf)

    if failed:
        warnings.warn(
            f"get_aafs: action solving raised ValueError for {len(failed)} of "
            f"{n_orbits} orbits (input indices {failed}); these orbits are "
            "omitted from the returned table.",
            stacklevel=2)

    if not aafs:
        # Without this guard the 'freqs' lookup below fails with an opaque
        # KeyError on an empty table.
        raise ValueError("get_aafs: action solving failed for every orbit")

    aafs = at.QTable(aafs)
    aafs['periods'] = np.abs((2*np.pi*u.rad) / (aafs['freqs']*u.rad))

    return aafs

# Compare $\Omega_R(J_R)$ vs. $\Omega_z(J_z)$

In [None]:
# x positions at which the orbit grids are launched; the sweep cells attach
# kpc to these values and the plots label each curve with R = |x|.
xs = [-12, -8, -4]

In [None]:
npts = 32

# Radial sweep: at each starting x, launch `npts` orbits whose x-velocity is
# log-spaced from 0.5 to 100 km/s, with the local circular velocity in y and
# a small fixed z-velocity.
R_all_aafs = []
for x in xs:
    # Every orbit in the batch starts from the same point (x, 0, 0).
    xyz = np.tile([[x], [0], [0]], (1, npts)) * u.kpc

    vxyz = np.zeros(xyz.shape) * u.km/u.s
    vxyz[0] = np.geomspace(0.5, 100, npts) * u.km/u.s
    vxyz[1] = potential.circular_velocity(xyz)
    vxyz[2] = 1e-1*u.km/u.s

    w0 = gd.PhaseSpacePosition(pos=xyz, vel=vxyz)
    aafs = get_aafs(w0)
    R_all_aafs.append(aafs)

In [None]:
npts = 32

# Vertical sweep: at each starting x, launch `npts` orbits whose z-velocity is
# log-spaced from 0.1 to 100 km/s, with the local circular velocity in y and
# a small fixed x-velocity.
z_all_aafs = []
for x in xs:
    # Every orbit in the batch starts from the same point (x, 0, 0).
    xyz = np.tile([[x], [0], [0]], (1, npts)) * u.kpc

    vxyz = np.zeros(xyz.shape) * u.km/u.s
    vxyz[0] = 0.1*u.km/u.s
    vxyz[1] = potential.circular_velocity(xyz)
    vxyz[2] = np.geomspace(0.1, 100, npts) * u.km/u.s

    w0 = gd.PhaseSpacePosition(pos=xyz, vel=vxyz)
    aafs = get_aafs(w0)
    z_all_aafs.append(aafs)

In [None]:
# NOTE(review): this cell used to be a bare, mis-indented fragment of a loop
# body (an IndentationError) that read `axes`, `k` and `coo` before any cell
# had defined them. It is now a helper, so the cell runs on a fresh kernel;
# call it inside a plotting loop after the curves have been drawn.


def per_gyr_to_myr(values):
    """Map a rate in Gyr^-1 to the matching timescale in Myr (1000 / value).

    The map is its own inverse, so it also converts Myr back to Gyr^-1.
    Zeros map to ``inf`` instead of triggering a division-by-zero warning.
    Always returns a float ndarray with the same shape as the input.
    """
    values = np.asarray(values, dtype=float)
    out = np.full(values.shape, np.inf)
    np.divide(1000.0, values, out=out, where=values != 0)
    return out


def add_timescale_axes(axes, k):
    """Add right-hand reciprocal (Myr) axes to a two-panel figure.

    Parameters
    ----------
    axes : sequence of two matplotlib Axes
        Assumed to show a frequency in Gyr^-1 on ``axes[0]`` and a frequency
        spread in Gyr^-1 on ``axes[1]`` -- confirm against the calling cell.
    k : int
        Key into the notebook-level ``coo`` mapping, used for the axis label.
        ``coo`` is looked up when the function is called, not when it is
        defined.

    Returns
    -------
    tuple
        The two secondary axes, in the same order as ``axes``.
    """
    # A twinx() axis whose limits are set to 1000 / ylim is still linear, so
    # only its endpoints would match the reciprocal relation. A secondary
    # axis with an explicit transform keeps every tick correct and follows
    # later changes to the primary limits.
    ax2 = axes[0].secondary_yaxis(
        'right', functions=(per_gyr_to_myr, per_gyr_to_myr))
    ax2.set_ylabel(f'period, $P_{coo[k]}$ [Myr]')

    ax3 = axes[1].secondary_yaxis(
        'right', functions=(per_gyr_to_myr, per_gyr_to_myr))
    ax3.set_ylabel('phase-mixing time [Myr]')

    return ax2, ax3

In [None]:
# All three lookups are keyed by the column index `k` used to slice the
# action and frequency arrays in the plotting cells.

# Column index -> coordinate letter used in axis labels.
coo = {0: 'R', 2: 'z'}

# Column index -> [low, high] action range shaded on the plots; the plotting
# cells describe it as the 5th and 95th percentiles of local data.
spans = {0: [0.23645763, 9.5840498], 2: [0.51749222, 50.65398411]}

# Column index -> per-starting-position tables from the matching sweep.
all_aafs = {0: R_all_aafs, 2: z_all_aafs}

In [None]:
from scipy.interpolate import interp1d

In [None]:
# One two-panel figure per coordinate: orbital period on top, and below it
# the inverse of the frequency offset from the J = 1 orbit.
for k in [0, 2]:
    fig, axes = plt.subplots(2, 1, figsize=(8, 8), sharex=True)
    ax_period, ax_time = axes

    for aafs, xx in zip(all_aafs[k], xs):
        # Action as a multiple of the Sun's; frequency divided by 2 pi, in Gyr^-1.
        J = aafs['actions'][:, k].to_value(action_units[k])
        f = (aafs['freqs'][:, k]*u.rad / (2*np.pi*u.rad)).to_value(1 / u.Gyr)

        # 1000 / f turns Gyr^-1 into a period in Myr.
        ax_period.plot(J, 1000 / f, label=f'$R={abs(xx):.0f}$ kpc',
                       lw=2, marker='')

        # interp1d raises if J = 1 lies outside the sampled action range.
        f_func = interp1d(J, f)
        delta_f = np.abs(f_func(J) - f_func(1.))
        ax_time.plot(J, 1 / delta_f, lw=2, marker='')

    ax_period.set_xlim(0, 1.5 * spans[k][1])
    ax_period.set_ylim(0, 300)
    ax_time.set_ylim(0, 7)

    ax_period.set_ylabel(f'period, $P_{coo[k]}$ [Myr]')
    ax_time.set_ylabel('spiral-forming timescale [Gyr]')
    ax_time.set_xlabel(f'$J_{coo[k]}$')

    ax_period.legend(loc='upper right', fontsize=14)

    # 5, 95 percentile from local data:
    for ax in axes:
        ax.axvspan(*spans[k], zorder=-10, color='tab:green',
                   alpha=0.1, linewidth=0)

    fig.tight_layout()

### Note: the plots below show the phase-mixing timescale. I don't think that is actually the quantity we want?

In [None]:
# One two-panel figure per coordinate: orbital period on top, and below it
# the inverse of the finite-difference frequency gradient times the upper
# end of the shaded action range.
for k in [0, 2]:
    fig, axes = plt.subplots(2, 1, figsize=(8, 8), sharex=True)
    ax_period, ax_time = axes

    for aafs, xx in zip(all_aafs[k], xs):
        # Action as a multiple of the Sun's; frequency divided by 2 pi, in Gyr^-1.
        J = aafs['actions'][:, k].to_value(action_units[k])
        f = (aafs['freqs'][:, k]*u.rad / (2*np.pi*u.rad)).to_value(1 / u.Gyr)

        # 1000 / f turns Gyr^-1 into a period in Myr.
        ax_period.plot(J, 1000 / f, label=f'$R={abs(xx):.0f}$ kpc',
                       lw=2, marker='')

        # Gradient between neighbouring samples, drawn at the interval midpoints.
        J_mid = 0.5 * (J[:-1] + J[1:])
        abs_dfdJ = np.abs(np.diff(f) / np.diff(J))
        ax_time.plot(J_mid, 1 / (abs_dfdJ * spans[k][1]),
                     lw=2, marker='')

    ax_period.set_xlim(0, 1.5 * spans[k][1])
    ax_period.set_ylim(0, 300)
    ax_time.set_ylim(0, 7)

    ax_period.set_ylabel(f'period, $P_{coo[k]}$ [Myr]')
    inverse_spread_label = ('$' +
                            r'\sigma_{\nu_' + coo[k] + '}^{-1} = ' +
                            r'\left[\frac{{\rm d}\nu_' + coo[k] + r'}' +
                            r'{{\rm d}J_' + coo[k] + r'} \, \sigma_{J_' + coo[k] + r'}\right]^{-1}$ ' +
                            r'[${\rm Gyr}$]')
    ax_time.set_ylabel(inverse_spread_label)
    ax_time.set_xlabel(f'$J_{coo[k]}$')

    ax_period.legend(loc='upper right', fontsize=14)

    # 5, 95 percentile from local data:
    for ax in axes:
        ax.axvspan(*spans[k], zorder=-10, color='tab:green',
                   alpha=0.1, linewidth=0)

    fig.tight_layout()

In [None]:
# One two-panel figure per coordinate: frequency on top, and below it the
# finite-difference frequency gradient times the upper end of the shaded
# action range.
for k in [0, 2]:
    fig, axes = plt.subplots(2, 1, figsize=(8, 8), sharex=True)
    ax_freq, ax_spread = axes

    for aafs, xx in zip(all_aafs[k], xs):
        # Action as a multiple of the Sun's; frequency divided by 2 pi, in Gyr^-1.
        J = aafs['actions'][:, k].to_value(action_units[k])
        f = (aafs['freqs'][:, k]*u.rad / (2*np.pi*u.rad)).to_value(1 / u.Gyr)

        ax_freq.plot(J, f, label=f'$R={abs(xx):.0f}$ kpc',
                     lw=2, marker='')

        # Gradient between neighbouring samples, drawn at the interval midpoints.
        J_mid = 0.5 * (J[:-1] + J[1:])
        abs_dfdJ = np.abs(np.diff(f) / np.diff(J))
        ax_spread.plot(J_mid, abs_dfdJ * spans[k][1],
                       lw=2, marker='')

    ax_freq.set_xlim(0, 1.5 * spans[k][1])

    ax_freq.set_ylabel(fr'frequency, $\nu_{coo[k]}$ ' + r'[${\rm Gyr}^{-1}$]')
    spread_label = ('$' +
                    r'\sigma_{\nu_' + coo[k] + '} = ' +
                    r'\frac{{\rm d}\nu_' + coo[k] + r'}' +
                    r'{{\rm d}J_' + coo[k] + r'} \, \sigma_{J_' + coo[k] + '}$ ' +
                    r'[${\rm Gyr}^{-1}$]')
    ax_spread.set_ylabel(spread_label)
    ax_spread.set_xlabel(f'$J_{coo[k]}$')

    ax_freq.legend(loc='upper right', fontsize=14)

    # 5, 95 percentile from local data:
    for ax in axes:
        ax.axvspan(*spans[k], zorder=-10, color='tab:green',
                   alpha=0.1, linewidth=0)

    fig.tight_layout()

In [None]:
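# Note (a formula, not runnable code): time for two orbits with actions J_1 and
# J_2 to drift apart by one full cycle in angle:
#     T = 2π / (Omega(J_1) - Omega(J_2))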