In [None]:
import os
import sys

import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
import numpy as np

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.potential as gp
from gala.units import galactic

if "../scripts" not in sys.path:
    sys.path.append("../scripts")
from streamsubhalosim import (
    get_in_stream_frame,
    StreamSubhaloSimulation,
    get_stream_track,
    get_new_basis,
    get_subhalo_w0,
)
from streamsubhaloplot import plot_sky_projections

In [None]:
# Milky Way potential model (gala YAML file). The default is the original
# author's local path; set the MW_POTENTIAL_PATH environment variable to point
# at your own copy so this notebook can run on other machines.
MW_POTENTIAL_PATH = os.path.expanduser(
    os.environ.get(
        "MW_POTENTIAL_PATH",
        "/Users/apricewhelan/projects/gaia-actions/potentials/MilkyWayPotential2022.yml",
    )
)
if not os.path.isfile(MW_POTENTIAL_PATH):
    raise FileNotFoundError(
        f"Milky Way potential file not found: {MW_POTENTIAL_PATH!r}. "
        "Set the MW_POTENTIAL_PATH environment variable to its location."
    )
mw = gp.load(MW_POTENTIAL_PATH)

In [None]:
# Final (present-day) phase-space position of the stream progenitor.
dist = 20.  # z-coordinate of the progenitor [kpc]
# NOTE(review): in the original cell `dist` was defined but never used (z was a
# literal 20). It now sets z below with the same value; confirm that `dist`
# was meant to control this coordinate.
pos = [-8, 0, dist] * u.kpc
vcirc = mw.circular_velocity(pos)[0]
# Velocity along +y, perpendicular to `pos` (which lies in the x-z plane),
# with magnitude 1.3x the local circular speed.
wf = gd.PhaseSpacePosition(pos=pos, vel=[0, 1.3, 0] * vcirc)

In [None]:
# Simulation configuration: the progenitor ends at `wf` in the `mw` potential.
# Timing, timestep, particle count and seed are passed straight through to
# StreamSubhaloSimulation (see its definition for their exact semantics).
sim_kwargs = dict(
    mw_potential=mw,
    final_prog_w=wf,
    M_stream=5e4 * u.Msun,
    t_pre_impact=4 * u.Gyr,
    t_post_impact=200 * u.Myr,
    dt=0.5 * u.Myr,
    n_particles=2,  # presumably particles released per step — confirm
    seed=42,
)
sim = StreamSubhaloSimulation(**sim_kwargs)

In [None]:
# Unperturbed stream run. The first (stream, progenitor) pair is used to pick
# the impact site; the stream from the second pair is kept for later comparison.
# NOTE(review): the epoch each pair refers to is defined in run_init_stream — confirm.
init_run, final_init_run = sim.run_init_stream()
init_stream, init_prog = init_run
final_init_stream, _ = final_init_run

impact_site = sim.get_impact_site(
    init_stream,
    init_prog,
    prog_dist=10 * u.kpc,  # presumably distance from the progenitor along the stream — confirm
)

In [None]:
# Unperturbed stream with the chosen impact site overplotted in red.
fig = init_stream.plot()
impact_site.plot(axes=fig.axes, autolim=False, color='r')
fig.suptitle("Unperturbed stream and impact site (red)");

In [None]:
# Subhalo model: a Hernquist sphere whose scale radius scales as sqrt(mass),
# normalized to 1.005 kpc at 1e8 Msun and then halved (marked "MAGIC" by the author).
# NOTE(review): verify the 1.005 kpc normalization against the source of this
# relation (e.g. Erkal et al. 2016 quote 1.05 kpc) — confirm it is intended.
M_subhalo = 5e7 * u.Msun
mass_ratio = M_subhalo / (1e8 * u.Msun)
c_subhalo = 1.005 * u.kpc * mass_ratio ** 0.5 / 2.0
subhalo_potential = gp.HernquistPotential(m=M_subhalo, c=c_subhalo, units=galactic)

In [None]:
# Subhalo initial conditions relative to the impact site. The impact parameter
# is half the subhalo scale radius; phi / vphi / vz are interpreted by
# get_subhalo_w0 (see its definition for the geometry).
b = 0.5 * c_subhalo
subhalo_w0 = get_subhalo_w0(
    impact_site,
    b=b,
    phi=90 * u.deg,
    vphi=100 * u.km / u.s,
    vz=0 * u.km / u.s,
)
# Attach a gala StaticFrame to the subhalo initial conditions
# (presumably required by run_perturbed_stream — confirm).
subhalo_w0.frame = gp.StaticFrame(units=galactic)

In [None]:
# Integration window and timestep used around the impact.
# The window is 32 * subhalo_dx / subhalo_dv, where subhalo_dx is the larger of
# the impact parameter b and the scale radius c_subhalo, and subhalo_dv is the
# subhalo speed relative to the impact site. It is rounded to whole Myr.
subhalo_dv = np.linalg.norm(subhalo_w0.v_xyz - impact_site.v_xyz)
subhalo_dx = np.max(u.Quantity([b, c_subhalo]))
t_buffer_impact = np.round((32 * subhalo_dx / subhalo_dv).to(u.Myr), decimals=0)

# Target timestep: cover the window in ~256 steps (rounded to 0.1 Myr).
tmp_impact_dt = np.round((t_buffer_impact / 256).to(u.Myr), decimals=1)

# The impact timestep is restricted to integer subdivisions sim.dt / k,
# k = 1..19, and the candidate closest to the target is chosen. If the target
# is below sim.dt / 19, the smallest candidate is used.
dt_factors = np.arange(1, 20, 1)
dts = sim.dt / dt_factors
best_idx = np.abs(dts - tmp_impact_dt).argmin()
dt_factor = dt_factors[best_idx]
impact_dt = dts[best_idx]

t_buffer_impact, impact_dt

In [None]:
# Re-run the stream with the subhalo flyby included. The four returned items are
# the perturbed stream, an element discarded here, the progenitor, and
# (presumably) the final simulation time.
perturbed_run = sim.run_perturbed_stream(
    subhalo_w0, subhalo_potential, t_buffer_impact, impact_dt
)
final_stream, _, final_prog, final_t = perturbed_run

In [None]:
# Display the configured pre-impact duration next to the final time from the perturbed run.
(sim.t_pre_impact, final_t)

In [None]:
# Carry the impact-site orbit from t_pre_impact to final_t with sim.H
# (presumably the gala Hamiltonian of the host potential), at the base timestep.
final_impact_site = sim.H.integrate_orbit(
    impact_site,
    t1=sim.t_pre_impact,
    t2=final_t,
    dt=sim.dt,
)

In [None]:
# Perturbed stream at the final time, with the evolved impact site (red)
# and the progenitor (green) marked.
fig = final_stream.plot()
final_impact_site[-1].plot(axes=fig.axes, autolim=False, color='r', marker='o')
final_prog.plot(axes=fig.axes, autolim=False, color='g', marker='o')
fig.suptitle("Perturbed stream: impact site (red), progenitor (green)");

In [None]:
# Express the perturbed stream in a frame built from the final progenitor and
# impact site, then put the unperturbed stream in that same frame (the first
# result is passed as `stream_frame`).
stream_sfr = get_in_stream_frame(
    final_stream,
    prog=final_prog,
    impact=final_impact_site[-1],
)
init_stream_sfr = get_in_stream_frame(final_init_stream, stream_frame=stream_sfr)

# Longitude range for the unperturbed-stream track
# (presumably degrees — confirm against get_stream_track).
TRACK_LON_LIM = (-45, 45)
tracks = get_stream_track(init_stream_sfr, lon_lim=TRACK_LON_LIM)

In [None]:
# Sky projections of the perturbed stream in the stream frame.
sky_fig = plot_sky_projections(stream_sfr)

In [None]:
# Same sky projections, with the unperturbed-stream tracks overlaid for comparison.
sky_fig_tracks = plot_sky_projections(stream_sfr, tracks=tracks)