In [None]:
import itertools
import os
import pickle

import astropy.coordinates as coord
from astropy.convolution import convolve, Gaussian2DKernel
from astropy.io import fits
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.stats import binned_statistic_2d
import corner

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from thriftshop.potentials import potentials
from thriftshop.config import vcirc, rsun
from thriftshop.actions import safe_get_actions, get_w0s_with_same_actions

# Pin astropy's default Galactocentric frame parameters to the 'v4.0' set so the
# heliocentric -> Galactocentric transform does not change with the installed astropy version.
coord.galactocentric_frame_defaults.set('v4.0');

In [None]:
# Read the APOGEE parent sample and keep stars with GAIA_PARALLAX > 0.5.
# Parallaxes are turned into distances as 1000 / parallax pc (i.e. treated as mas),
# so this keeps stars within 2 kpc by inverse parallax.
parent_sample = at.Table.read('../data/apogee-parent-sample.fits')
has_good_parallax = parent_sample['GAIA_PARALLAX'] > 0.5
t = parent_sample[has_good_parallax]
len(t)

In [None]:
# Heliocentric 6D phase-space coordinates from APOGEE + Gaia, with inverse parallax as distance.
helio_distance = 1000 / t['GAIA_PARALLAX'] * u.pc
c = coord.SkyCoord(
    ra=t['RA'] * u.deg,
    dec=t['DEC'] * u.deg,
    distance=helio_distance,
    pm_ra_cosdec=t['GAIA_PMRA'] * u.mas/u.yr,
    pm_dec=t['GAIA_PMDEC'] * u.mas/u.yr,
    radial_velocity=t['VHELIO_AVG'] * u.km/u.s,
)
# Galactocentric Cartesian coordinates, using the frame defaults configured at import time.
galcen = c.transform_to(coord.Galactocentric)

In [None]:
# Vertical position [kpc] and vertical velocity [km/s] of every star, plus the abundance
# used to colour the phase-space maps.
# NOTE(review): if the catalog stores missing abundances as a sentinel value (rather than
# NaN / a mask), those stars will bias the binned means below — confirm.
z, vz = galcen.z.to_value(u.kpc), galcen.v_z.to_value(u.km/u.s)
elem = t['MG_FE']  # alternative: elem = t['MN_FE']

In [None]:
# Quick look at the vertical phase space (v_z, z) of the whole sample:
# left panel = mean abundance per bin, right panel = number of stars per bin (log scale).
zlim = 2  # kpc
vlim = 100.  # km/s
vstep = 4  # km/s
zstep = 75 / 1e3  # kpc (75 pc)
vzz_bins = (np.arange(-vlim, vlim+1e-3, vstep),
            np.arange(-zlim, zlim+1e-3, zstep))

# Same axis-label format as the final orbit/abundance figure.
vz_label = f'$v_z$ [{u.km/u.s:latex_inline}]'
z_label = f'$z$ [{u.kpc:latex_inline}]'

fig, axes = plt.subplots(1, 2, figsize=(12, 5),
                         constrained_layout=True)

stat = binned_statistic_2d(vz, z, elem, statistic='mean',
                           bins=vzz_bins)

# Clip the colour scale to the 15th-85th percentile so outliers don't wash it out.
vmin, vmax = np.percentile(elem, [15, 85])

ax = axes[0]
# .T because the binned arrays are indexed [vz_bin, z_bin] but pcolormesh wants [row=y, col=x].
cs = ax.pcolormesh(stat.x_edge, stat.y_edge, stat.statistic.T,
                   cmap='cividis', vmin=vmin, vmax=vmax)
cb = fig.colorbar(cs, ax=ax, aspect=40)
cb.set_label(f'mean {elem.name}')
ax.set_xlabel(vz_label)
ax.set_ylabel(z_label)

ax = axes[1]
H, *_ = np.histogram2d(vz, z, bins=vzz_bins)
cs = ax.pcolormesh(stat.x_edge, stat.y_edge, H.T,
                   cmap='cividis',
                   norm=mpl.colors.LogNorm(1, 3e2))
cb = fig.colorbar(cs, ax=ax, aspect=40)
cb.set_label('number of stars')
ax.set_xlabel(vz_label)

fig.set_facecolor('w')

# ---

# fig, ax = plt.subplots(figsize=(6, 5))

# stat = binned_statistic_2d(vz, z, elem, statistic='mean',
#                            bins=vzz_bins)

# vmin, vmax = np.percentile(elem, [15, 85])

# HH = stat.statistic.copy()
# HH[H < 3] = np.nan
# cs = ax.pcolormesh(stat.x_edge, stat.y_edge, 
#                    HH.T, cmap='cividis', vmin=vmin, vmax=vmax)
# cb = fig.colorbar(cs, ax=ax, aspect=40)

# ax.set_xlabel('v_z')
# ax.set_ylabel('z')
# fig.tight_layout()

### Initial conditions for our comparison orbits

In [None]:
# Two initial conditions at the solar position (x = -R_sun, y = z = 0) with
# velocity (15, v_circ, v_z) km/s; they differ only in v_z (20 vs 45 km/s).
rsun_kpc = rsun.to_value(u.kpc)
vcirc_kms = vcirc.to_value(u.km/u.s)
sun_positions = np.array([[-rsun_kpc, 0, 0],
                          [-rsun_kpc, 0, 0]])
start_velocities = np.array([[15, vcirc_kms, 20.],
                             [15, vcirc_kms, 45.]])
fiducial_w0 = gd.PhaseSpacePosition(pos=sun_positions.T * u.kpc,
                                    vel=start_velocities.T * u.km/u.s)

In [None]:
# Initial conditions for each potential; later cells index the result by the keys of
# `potentials`. Per the helper's name, they presumably share the actions of `fiducial_w0`.
# NOTE(review): staeckel=True presumably selects a Staeckel-based action estimate — confirm
# in thriftshop.actions.
w0s = get_w0s_with_same_actions(fiducial_w0, staeckel=True)

In [None]:
# Cache the per-potential initial conditions. Create the cache directory first so this
# cell also works on a fresh checkout where ../cache does not exist yet.
os.makedirs('../cache', exist_ok=True)
with open('../cache/w0s.pkl', 'wb') as f:
    pickle.dump(w0s, f)

In [None]:
# Integrate each potential's initial conditions in that same potential from 0 to 6 Gyr
# with 0.5 Myr steps, printing the key and the initial Cartesian velocities as we go.
orbits = {}
for name, w0 in w0s.items():
    print(name)
    print(w0.v_xyz.T)
    pot = potentials[name]
    orbits[name] = pot.integrate_orbit(w0, dt=0.5*u.Myr, t1=0, t2=6*u.Gyr)

### Compute actions for these orbits with the Sanders & Binney method to compare:

In [None]:
# Sanders & Binney actions for every initial condition, stored as one Quantity array per
# potential (rows follow the order of the initial conditions in w0s[name]).
sanders_actions = {}
for name in potentials.keys():
    pot = potentials[name]
    name_w0s = w0s[name]
    sanders_actions[name] = u.Quantity([
        safe_get_actions(pot, name_w0s[n], N_max=8)['actions']
        for n in range(name_w0s.shape[0])
    ])

In [None]:
# Orbit 1: actions of the first initial condition in each potential, labelled by potential key
{name: acts[0] for name, acts in sanders_actions.items()}

In [None]:
# Orbit 2: actions of the second initial condition in each potential, labelled by potential key
{name: acts[1] for name, acts in sanders_actions.items()}

In [None]:
# Cache the Sanders & Binney actions. Create the cache directory first so this cell
# also works on a fresh checkout where ../cache does not exist yet.
os.makedirs('../cache', exist_ok=True)
with open('../cache/w0s-actions.pkl', 'wb') as f:
    pickle.dump(sanders_actions, f)

### Visualize the orbits:

In [None]:
# Left-to-right panel order for the orbit figures. Numeric keys are titled as
# '<key> M_disk'; non-numeric keys (e.g. 'fiducial') are shown verbatim.
sorted_keys = ['0.4', 'fiducial', '1.6']

In [None]:
plot_zlim = 1.75  # kpc
plot_vzlim = 100  # km/s

# Two figures with one panel per potential in `sorted_keys`:
#   1. vertical phase space (v_z, z)
#   2. meridional plane (cylindrical rho, z)
# Each entry: (orbit -> object to plot, components, units, x-limits).
orbit_projections = [
    (lambda orbit: orbit, ['v_z', 'z'], [u.km/u.s, u.kpc],
     (-plot_vzlim, plot_vzlim)),
    (lambda orbit: orbit.cylindrical, ['rho', 'z'], [u.kpc, u.kpc],
     (7., 10.)),
]

for to_plot, components, plot_units, xlim in orbit_projections:
    fig, axes = plt.subplots(1, 3, figsize=(15, 5),
                             sharex=True, sharey=True)

    for k, ax in zip(sorted_keys, axes):
        _ = to_plot(orbits[k]).plot(components, axes=[ax],
                                    auto_aspect=False, units=plot_units)
        # Numeric keys are titled '<key> M_disk'; anything else (e.g. 'fiducial') verbatim.
        try:
            ax.set_title(f'${float(k):.1f}' + r' \, {\rm M}_{\rm disk}$')
        except ValueError:
            ax.set_title(f'{k}')

    # Panels share y, so only the leftmost one keeps its y label.
    for ax in axes[1:]:
        ax.set_ylabel('')

    axes[0].set_xlim(*xlim)
    axes[0].set_ylim(-plot_zlim, plot_zlim)

    fig.tight_layout()

In [None]:
def symmetric_bin_edges(lim, step):
    """Return bin edges of width ``step`` that are mirror-symmetric about 0.

    0 is always an edge and the edges cover at least [-lim, lim]. A plain
    ``np.arange(-lim, lim + eps, step)`` only lands on +lim when 2*lim is a
    multiple of ``step``. Otherwise the grid stops short of +lim and reflecting
    v_z -> -v_z or z -> -z does not map bins onto bins.

    Parameters
    ----------
    lim : float
        Half-width that the edges must cover.
    step : float
        Bin width.

    Returns
    -------
    numpy.ndarray
        Sorted edges ``[-n*step, ..., -step, 0, step, ..., n*step]`` with ``n*step >= lim``.
    """
    # The small tolerance keeps exact multiples (e.g. lim/step == 25) from rounding up.
    n_half = int(np.ceil(lim / step - 1e-9))
    half = step * np.arange(n_half + 1)  # 0, step, ..., n_half*step
    return np.concatenate((-half[:0:-1], half))


zlim = 1.75  # kpc
vlim = 75.  # km/s (v_z is in km/s)
vstep = 4.  # km/s
zstep = 75 / 1e3  # kpc (75 pc)
vzz_bins = (symmetric_bin_edges(vlim, vstep),
            symmetric_bin_edges(zlim, zstep))

In [None]:
# Reflect every star into all four (sign of v_z, sign of z) quadrants, presumably to
# exploit the expected mirror symmetry of the vertical phase space and boost the counts
# per bin (TODO confirm). Each mirrored array is built with a single concatenate instead
# of being grown inside a loop. Quadrant order is (-,-), (-,+), (+,-), (+,+).
sign_pairs = list(itertools.product([-1, 1], [-1, 1]))
quad_vz = np.concatenate([i * vz for i, _ in sign_pairs])
quad_z = np.concatenate([j * z for _, j in sign_pairs])
quad_elem = np.concatenate([np.asarray(elem, dtype=float) for _ in sign_pairs])

# Mean abundance per (v_z, z) bin: mirrored sample (quad_*) and original sample.
quad_stat = binned_statistic_2d(quad_vz, quad_z,
                                quad_elem, statistic='mean',
                                bins=vzz_bins)
stat = binned_statistic_2d(vz, z,
                           elem, statistic='mean',
                           bins=vzz_bins)

# Star counts per bin on the same grid, used to mask sparsely populated bins when plotting.
quad_counts, *_ = np.histogram2d(quad_vz, quad_z, bins=vzz_bins)
counts, *_ = np.histogram2d(vz, z, bins=vzz_bins)

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(17, 5),
                         constrained_layout=True,
                         sharex=True, sharey=True)

# Which binned sample to show: the original stars (False) or the quadrant-mirrored
# sample (True). The low-count mask must come from the same sample as the statistic.
use_mirrored = False
plot_stat, plot_counts = (quad_stat, quad_counts) if use_mirrored else (stat, counts)

# Copy so the masking below does not overwrite the statistic stored on the result object.
H = plot_stat.statistic.copy()
H[plot_counts <= 1] = np.nan  # hide bins holding at most one star

# Fixed colour range for [Mg/Fe]. Earlier alternatives were the 5th-95th percentile of elem,
# or (-0.4, 0.2) when plotting [Fe/H].
vmin, vmax = (-0.02, 0.1)

for ax in axes:
    cs = ax.pcolormesh(plot_stat.x_edge, plot_stat.y_edge,
                       H.T, cmap='magma', vmin=vmin, vmax=vmax)
    ax.set_xlabel(f'$v_z$ [{u.km/u.s:latex_inline}]')

cb = fig.colorbar(cs, ax=axes, aspect=40)
cb.set_label(r'$[{\rm Mg} / {\rm Fe}]$')

# Overlay one orbit per potential on the first three panels; the fourth panel shows
# the binned data alone.
for k, color, ax in zip(sorted_keys, ['w', 'w', 'w'], axes):
    o = orbits[k]
    o_z = o.z.to_value(u.kpc)
    o_vz = o.v_z.to_value(u.km/u.s)
    ax.plot(o_vz, o_z, marker='', color=color, alpha=0.5)

    try:
        ax.set_title(f'${float(k):.1f}' + r' \, {\rm M}_{\rm disk}$')
    except ValueError:
        ax.set_title(f'{k}')

axes[0].set_xlim(-vlim, vlim)
axes[0].set_ylim(-zlim, zlim)
axes[0].set_ylabel(f'$z$ [{u.kpc:latex_inline}]')

# Take v_circ from the shared config rather than hard-coding it in the title.
vcirc_kms = vcirc.to_value(u.km/u.s)
fig.suptitle(r'$v_{\rm circ} = ' + f'{vcirc_kms:.0f}'
             + r'\,{\rm km}\,{\rm s}^{-1}$ for all potentials',
             fontsize=24)

fig.set_facecolor('w')
os.makedirs('../plots', exist_ok=True)  # savefig fails if the output directory is missing
fig.savefig('../plots/orbits-Mg-Fe.png', dpi=250)