In [None]:
import pickle

import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.collections import LineCollection
import numpy as np
from scipy.interpolate import interp1d
from scipy.integrate import simps

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic
from gala.mpl_style import hesperia, laguna, turbo

from thriftshop.config import rsun, vcirc, fig_path, cache_path
from thriftshop.config import plot_config as pc
from thriftshop.potentials import potentials, galpy_potentials
from thriftshop.galpy_helpers import get_staeckel_aaf

# Pin astropy's Galactocentric frame parameters (solar position/velocity) to
# the 'v4.0' parameter set so coordinate transformations do not silently
# change with the astropy version's current defaults.
coord.galactocentric_frame_defaults.set('v4.0');

In [None]:
def _load_cached_pickle(filename):
    """Return the object unpickled from ``cache_path / filename``.

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the
    file; this is only acceptable if these cache files are generated locally
    by this project (presumably so — confirm).
    """
    with open(cache_path / filename, 'rb') as fh:
        return pickle.load(fh)


w0s = _load_cached_pickle('w0s.pkl')
w0s_actions = _load_cached_pickle('w0s-actions.pkl')

# Pick one model potential by its string key, together with the initial
# conditions and actions cached for that same key.
name = '1.0'
potential, w0, actions = potentials[name], w0s[name], w0s_actions[name]

In [None]:
# Integrate all initial conditions in `w0` from t=0 to 2 Gyr in 0.1 Myr
# steps with the DOPRI853 integrator.
orbits = potential.integrate_orbit(
    w0,
    Integrator=gi.DOPRI853Integrator,
    t1=0,
    t2=2 * u.Gyr,
    dt=0.1 * u.Myr,
)

In [None]:
# Direct phase-space estimate of the vertical action of orbit index 1. The
# result, `tmp_Jz`, sets the extent of the sqrt(J_z) cos/sin panel:
#     J_z = (1 / 2 pi) \oint v_z dz  ~=  2 * \int_{z > 0} v_z dz / (2 pi)
# The factor of 2 assumes the z < 0 half of the orbit contributes as much as
# the z > 0 half (potential symmetric about the midplane — assumed).
Jz_orbit = orbits[:, 1]
above = Jz_orbit.z.value > 0  # sign test only, so units do not matter

# Start at the first sample above the plane that follows an upward midplane
# crossing, so the excursion is complete even if the orbit begins at z > 0.
up_crossings = np.flatnonzero(~above[:-1] & above[1:]) + 1
if up_crossings.size == 0:
    raise RuntimeError("orbit 1 never crosses the midplane upward; "
                       "cannot estimate J_z from phase space")
i_start = up_crossings[0]

# Length of the run of samples above the plane: argmin of a boolean array is
# the index of its first False. above[i_start] is True, so 0 means the orbit
# never falls back below the plane within the integration time.
n_above = np.argmin(above[i_start:])
if n_above == 0:
    raise RuntimeError("orbit 1 does not return below the midplane within "
                       "the integration time; cannot estimate J_z")

# Every sample of the excursion, including its last point above the plane.
excursion = Jz_orbit[i_start:i_start + n_above]
tmp_z = excursion.z.to_value(pc['zunit'])
tmp_vz = excursion.v_z.to_value(pc['vunit'])
tmp_Jz = 2 * simps(y=tmp_vz, x=tmp_z) / (2 * np.pi)

In [None]:
def _staeckel_aaf_table(orbit):
    """Staeckel action/angle output for one orbit, packed into a QTable."""
    aaf = get_staeckel_aaf(orbit, galpy_potentials[name])
    # Each entry is transposed before tabulating; presumably this puts the
    # time axis first (matching the later `['angles'][:, 2]` indexing) —
    # confirm against get_staeckel_aaf.
    return at.QTable({key: aaf[key].T for key in aaf})


aafs = [_staeckel_aaf_table(orbits[:, i]) for i in range(orbits.shape[1])]

In [None]:
def cmap_line(x, y, c, ax, lw=1, cmap='cividis', vmin=0, vmax=2*np.pi, rasterize=True):
    """Draw the curve (x, y) on ``ax`` as line segments colored by ``c``.

    Adapted from the matplotlib gallery "Multicolored lines" example:
    https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/multicolored_line.html

    Parameters
    ----------
    x, y : array_like, shape (N,)
        Vertex coordinates of the curve.
    c : array_like, shape (N,) or (N-1,)
        Color values. With one value per vertex (N values), segment ``i`` is
        colored by ``c[i]``, the value at its starting vertex. With N-1
        values, they are used one per segment.
    ax : matplotlib.axes.Axes
        Axes to draw on. The caller is expected to set the axis limits.
    lw : float
        Line width.
    cmap : str or Colormap
        Colormap for the segments.
    vmin, vmax : float
        Color normalization range (default ``[0, 2 pi]``, suited to angles).
    rasterize : bool
        If True, mark the collection as antialiased and rasterized, so it is
        rendered as an image inside vector output formats.

    Returns
    -------
    lc : LineCollection
        The artist added to ``ax``.
    smap : ScalarMappable
        Mappable sharing the line's norm and colormap, e.g. for a colorbar.
    """
    # (N, 1, 2) vertices -> (N-1, 2, 2) pairs of consecutive vertices.
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    c = np.asarray(c)
    if len(c) == len(segments) + 1:
        # One value per vertex: color each segment by its starting vertex.
        # matplotlib would do this implicitly (it ignores the surplus value);
        # trimming makes the segment/color correspondence explicit.
        c = c[:-1]

    norm = plt.Normalize(vmin, vmax)
    lc = LineCollection(segments, cmap=cmap, norm=norm)
    lc.set_array(c)
    lc.set_linewidth(lw)
    ax.add_collection(lc)

    if rasterize:
        lc.set_antialiased(True)
        lc.set_rasterized(True)

    # Reuse the collection's resolved colormap so the mappable always matches
    # what was drawn (whether `cmap` was a name or a Colormap instance).
    smap = mpl.cm.ScalarMappable(norm=norm, cmap=lc.get_cmap())

    return lc, smap

In [None]:
# Extent of the sqrt(J_z) cos/sin panel: twice sqrt of the directly
# estimated vertical action. Ticks span [-8, 8] with major spacing 4 and
# minor spacing 2.
Jlim = 2 * np.sqrt(tmp_Jz)
Jticks, Jminorticks = (np.arange(-8, 8 + 1, step) for step in (4, 2))

In [None]:
# Display the minor R tick positions from the plot config (they are applied
# to the first panel of the figure below).
# NOTE(review): inspection-only cell; consider removing from the final notebook.
pc['Rminorticks']

In [None]:
def _apply_axis_layout(ax, xlim, xticks, xminorticks, ylim, yticks, yminorticks):
    """Set limits and major/minor tick positions on both axes of ``ax``.

    Limits are set before ticks on each axis: matplotlib's set_xticks /
    set_yticks expand the view limits if a tick lies outside them.
    """
    ax.set_xlim(xlim)
    ax.set_xticks(xticks)
    ax.set_xticks(xminorticks, minor=True)
    ax.set_ylim(ylim)
    ax.set_yticks(yticks)
    ax.set_yticks(yminorticks, minor=True)


fig, axes = plt.subplots(1, 3, figsize=(15.5, 4.9),
                         constrained_layout=True)
n_orbits = orbits.shape[1]
zlim = (-pc['zlim'], pc['zlim'])

# --- Panel 1: (R, z), each orbit colored by theta_z ---

ax = axes[0]
for i in range(n_orbits):
    o = orbits[:, i]
    # NOTE(review): R is converted with pc['zunit'] (and labelled with it
    # below) — presumably so both axes share one length unit; confirm.
    R = o.cylindrical.rho.to_value(pc['zunit'])
    z = o.z.to_value(pc['zunit'])
    tz = aafs[i]['angles'][:, 2].to_value(u.rad)
    _, smap = cmap_line(R, z, tz, ax=ax, lw=2.5)

_apply_axis_layout(ax, pc['Rlim'], pc['Rticks'], pc['Rminorticks'],
                   zlim, pc['zticks'], pc['zminorticks'])
ax.set_xlabel(f'$R$ [{pc["zunit"]:latex_inline}]')
ax.set_ylabel(f'$z$ [{pc["zunit"]:latex_inline}]')

# --- Panel 2: (v_z, z), each orbit colored by theta_z ---

ax = axes[1]
for i in range(n_orbits):
    o = orbits[:, i]
    z = o.z.to_value(pc['zunit'])
    vz = o.v_z.to_value(pc['vunit'])
    tz = aafs[i]['angles'][:, 2].to_value(u.rad)
    _, smap = cmap_line(vz, z, tz, ax=ax, lw=2.5)

_apply_axis_layout(ax, (-pc['vlim'], pc['vlim']), pc['vticks'], pc['vminorticks'],
                   zlim, pc['zticks'], pc['zminorticks'])
ax.set_xlabel(f'$v_z$ [{pc["vunit"]:latex_inline}]')
ax.set_ylabel(f'$z$ [{pc["zunit"]:latex_inline}]')

# Inset zooming into a small (v_z, z) window of the last orbit, the "outer"
# orbit (assumed from the orbit ordering — confirm). The orbit is selected
# explicitly rather than reusing variables left over from the loop above.
i_outer = n_orbits - 1
outer = orbits[:, i_outer]
axins = ax.inset_axes([0.03, 0.03, 0.33, 0.33])
cmap_line(outer.v_z.to_value(pc['vunit']),
          outer.z.to_value(pc['zunit']),
          aafs[i_outer]['angles'][:, 2].to_value(u.rad),
          ax=axins, lw=2)
axins.set_xlim(39+0.2, 49-0.2)
axins.set_ylim(-0.12, 0.12)
axins.xaxis.set_visible(False)
axins.yaxis.set_visible(False)

ax.indicate_inset_zoom(axins)

# --- Panel 3: circle of radius sqrt(<J_z>) per orbit, colored by theta_z ---

ax = axes[2]
theta = np.linspace(0, 2*np.pi, 512)  # same angle grid for every orbit
for i in range(n_orbits):
    Jz = np.mean(aafs[i]['actions'][:, 2]).to_value(pc["Junit"])
    sqrt_Jz = np.sqrt(Jz)
    _, smap = cmap_line(sqrt_Jz * np.cos(theta), sqrt_Jz * np.sin(theta), theta,
                        ax=ax, lw=2.5)

# Reference ray along theta_z = 0 plus a curved arrow showing the direction
# of increasing theta_z.
anno_color = '#888888'
ax.plot([0, 10], [0, 0], marker='', ls='-', color=anno_color, zorder=-10, alpha=0.6)

_th = 20 * u.deg
_r = 7.5
arrsty = mpl.patches.ArrowStyle.Simple(head_length=0.6,
                                       head_width=0.4,
                                       tail_width=0.12)
arrow = mpl.patches.FancyArrowPatch(
    (_r, 0), (_r*np.cos(_th), _r*np.sin(_th)), mutation_scale=15,
    connectionstyle="arc3,rad=0.15",
    arrowstyle=arrsty,
    linewidth=0, color=anno_color, zorder=10, alpha=1.)
ax.add_patch(arrow)
ax.text(_r*np.cos(_th), _r*np.sin(_th), r'$+\theta_z$',
        color='#555555', alpha=1, fontsize=18,
        ha='center', va='bottom')

_apply_axis_layout(ax, (-Jlim, Jlim), Jticks, Jminorticks,
                   (-Jlim, Jlim), Jticks, Jminorticks)
ax.set_xlabel(r'$\sqrt{J_z} \, \cos\left(\theta_z\right)$ ' + f'[{pc["Junit"]:latex_inline}]')
ax.set_ylabel(r'$\sqrt{J_z} \, \sin\left(\theta_z\right)$ ' + f'[{pc["Junit"]:latex_inline}]')

# --- Shared colorbar, title, output ---

# Every cmap_line call uses the same norm and colormap, so the mappable
# from the last call represents all three panels.
cb = fig.colorbar(smap, ax=axes, aspect=40)
cb.set_label(r'$\theta_z$ [rad]')

fig.suptitle("Two orbits in different phase-space coordinates and projections",
             fontsize=22)

fig.set_facecolor('w')
fig.savefig(fig_path / 'zvz-orbit-demo.pdf', dpi=400)