In [None]:
import astropy.units as u
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('apw-notebook')
%matplotlib inline

import gala.dynamics as gd
import gala.potential as gp
import gala.integrate as gi
from gala.units import galactic

In [None]:
# Milky Way potential model, viewed from a frame that rotates at a constant
# rate about the z-axis (Omega_z = -42 km/s/kpc), combined into a Hamiltonian.
pot = gp.MilkyWayPotential()
frame = gp.ConstantRotatingFrame(
    Omega=[0, 0, -42.] * u.km / u.s / u.kpc,
    units=galactic,
)
H = gp.Hamiltonian(frame=frame, potential=pot)

In [None]:
# Third-party
import numpy as np
from scipy.optimize import minimize

def tube_grid_xz(EJ, hamiltonian, dx=1., dz=1.):
    """Build a grid of initial conditions in the x-z plane at fixed energy.

    The x grid runs from 0.1 up to the point on the x-axis where
    ``hamiltonian.energy`` matches ``EJ`` (the zero-velocity curve, ZVC).
    For every x, the z grid runs from 0.1 up to the point along that column
    where the energy matches ``EJ``. Each grid point gets ``y = 0``,
    ``v_x = v_z = 0`` and the ``v_y`` that reproduces the requested energy.

    Parameters
    ----------
    EJ : float
        Target energy; compared against ``hamiltonian.energy(...).value``.
    hamiltonian : object
        Must provide ``energy(arr)``, ``potential.energy(xyz)`` and
        ``frame.parameters['Omega']`` (e.g. the ``gp.Hamiltonian`` with a
        rotating frame built in this notebook).
    dx, dz : float, optional
        Grid spacing along x and z.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(6, N)``. Rows 0-2 are x, y, z (y is always 0) and
        rows 3-5 are v_x, v_y, v_z (only v_y is non-zero). ``v_y`` is NaN
        wherever the energy equation has no real solution.

    Raises
    ------
    ValueError
        If the ZVC boundary cannot be located on the x-axis or along one of
        the grid columns, or if the resulting x grid is empty.
    """

    def _energy_mismatch(x, z):
        # Squared difference between the target energy and the energy
        # evaluated at (x, 0, z); zero exactly on the ZVC.
        # NOTE(review): a position-only (1, 3) array is handed to
        # hamiltonian.energy, exactly as before -- confirm this is the
        # input layout the Hamiltonian expects.
        return (EJ - hamiltonian.energy(np.array([[x, 0, z]])).value[0])**2

    # find maximum x on z=0
    res = minimize(lambda p: _energy_mismatch(p[0], 0),
                   x0=[10.], method='powell')
    if not res.success:
        raise ValueError("Failed to find boundary of ZVC on x-axis.")
    # Powell returns an array (0-d or length-1 depending on the SciPy
    # version); reduce it to a plain float before using it as a grid bound.
    max_x = float(np.ravel(res.x)[0])

    xgrid = np.arange(0.1, max_x + dx, dx)
    if xgrid.size == 0:
        # Previously this surfaced later as an unrelated NameError.
        raise ValueError("ZVC boundary on x-axis (x={}) gives an empty "
                         "x grid.".format(max_x))

    # compute ZVC boundary for each x
    z_guess = 25.
    columns = []
    for xg in xgrid:
        # find maximum allowed z along x=xg
        res = minimize(lambda p, xg=xg: _energy_mismatch(xg, p[0]),
                       x0=[z_guess], method='powell')
        max_z = float(np.abs(np.ravel(res.x)[0]))

        # A result identical to the initial guess means the optimizer
        # never moved, which is treated as a failure.
        if not res.success or max_z == z_guess:
            # Diagnostic plot of the objective along this column.
            vals = np.linspace(0.1, 100)
            plt.clf()
            plt.plot(vals, [_energy_mismatch(xg, val) for val in vals])
            plt.show()
            raise ValueError("Failed to find boundary of ZVC for x={}.".format(xg))

        zgrid = np.arange(0.1, max_z, dz)
        xs = np.zeros_like(zgrid) + xg
        columns.append(np.vstack((xs, zgrid)))

    # Stack all columns at once instead of growing the array in the loop.
    xz = np.hstack(columns)

    xyz = np.zeros((3, xz.shape[-1]))
    xyz[0] = xz[0]
    xyz[2] = xz[1]

    # now, for each grid point, compute the y velocity
    vxyz = np.zeros_like(xyz)
    Omz = hamiltonian.frame.parameters['Omega'][2].to(1/u.Myr).value
    Phi = hamiltonian.potential.energy(xyz).value

    # With only v_y non-zero at (x, 0, z), and assuming the rotating-frame
    # energy is E_J = |v|^2/2 + Phi - Omega.L, the energy equation reads
    #     v_y**2 - 2*Omz*x*v_y + 2*(Phi - EJ) = 0,
    # whose larger root is taken here. The factor of 2 on (Phi - EJ) was
    # missing before.
    vxyz[1] = Omz * xz[0] + np.sqrt(Omz**2 * xz[0]**2 - 2. * (Phi - EJ))

    return np.concatenate((xyz, vxyz), axis=0)

In [None]:
# Reference phase-space point: position (8, 0, 0) kpc, velocity (0, 170, 35) km/s.
w0 = gd.PhaseSpacePosition(
    pos=[8., 0, 0] * u.kpc,
    vel=[0, 170., 35.] * u.km / u.s,
)

# Energy of that point under H; shown as the cell output.
H.energy(w0).value[0]

In [None]:
# Generate the x-z grid of initial conditions at a fixed energy.
# NOTE(review): -0.088904381 is presumably the energy displayed by the
# previous cell (H.energy(w0)) -- confirm before changing w0.
w = tube_grid_xz(EJ=-0.088904381, hamiltonian=H)

In [None]:
# Show where the generated grid points sit in the x-z plane.
# Rows 0 and 2 of `w` hold the x and z coordinates of each grid point.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(w[0], w[2])
ax.set_xlim(-1, 40)
ax.set_ylim(-1, 40)
# Lengths follow the `galactic` unit system used for the Hamiltonian
# (assumed to be kpc -- confirm).
ax.set_xlabel('$x$ [kpc]')
ax.set_ylabel('$z$ [kpc]');