In [None]:
import os

import astropy.table as at
import astropy.coordinates as coord
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

from gala.units import galactic

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp

from empaf import DensityOrbitModel
from empaf.jax_helpers import designer_func

In [None]:
# Quick look at designer_func for two parameter sets on [0, 5];
# the second curve is drawn with its sign flipped.
grid = np.linspace(0, 5, 128)
curve_a = designer_func(grid, A=0.1, alpha=0.33, x0=3.)
curve_b_flipped = -designer_func(grid, A=0.01, alpha=0.45, x0=3.)
plt.plot(grid, curve_a)
plt.plot(grid, curve_b_flipped)

In [None]:
# designer_func used as a log-profile: shift by -20, exponentiate, show on a log axis.
grid = np.linspace(0, 5, 128)
ln_profile = designer_func(grid, A=-30, alpha=0.55, x0=3., c=0.) - 20.
plt.plot(grid, np.exp(ln_profile))
plt.yscale('log')

In [None]:
grid = np.linspace(0, 3, 128)

def get_ln_dens(x, f0, f3, alpha, x0):
    """Evaluate a log-density profile built from ``designer_func`` at ``x``.

    The amplitude ``A`` and a constant offset are derived from ``f0`` and
    ``f3``. The names suggest target log-density values at x=0 and x=3;
    NOTE(review): whether the profile actually hits those values depends on
    ``designer_func``'s definition, which is not visible here — confirm.

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the profile.
    f0, f3 : float
        Anchor values used to set the amplitude and the offset.
    alpha, x0 : float
        Shape parameters forwarded to ``designer_func``.

    Returns
    -------
    The log-density evaluated at ``x`` (``designer_func`` output plus offset).
    """
    # Normalize A using designer_func at x=3 with unit amplitude and c=0.
    # (Positional args are assumed to be (x, A, alpha, x0) — confirm against
    # designer_func's signature.)
    A = (f3 - f0) / (1 - designer_func(3., 1., alpha, x0, c=0.))
    offset = f0 + A
    # Evaluate at the caller's ``x``; previously this used the global ``grid``
    # and silently ignored the argument.
    return designer_func(x, c=0.0, A=A, alpha=alpha, x0=x0) + offset

# Anchor values shared by the profile and its reference lines so they cannot drift apart.
demo_f0, demo_f3 = 9.5, -20
func_vals = np.exp(get_ln_dens(grid, f0=demo_f0, f3=demo_f3, alpha=0.55, x0=3.))

fig, ax = plt.subplots()
ax.plot(grid, func_vals)
# Horizontal references at exp(f0) and exp(f3), styled to stand apart from the curve.
ax.axhline(np.exp(demo_f0), color='tab:gray', linestyle='--')
ax.axhline(np.exp(demo_f3), color='tab:gray', linestyle=':')
ax.set_yscale('log')
ax.set_xlabel('x')
ax.set_ylabel('exp(ln_dens)');

---

In [None]:
# Initial model with e-parameter signs for orders 2 and 4 (keys presumably
# match the e_params orders set below). Orders 6 and 8 (both -1.0) were
# tried previously and can be added back to e_signs if needed.
init_model = DensityOrbitModel(
    unit_sys=galactic,
    e_signs={2: 1.0, 4: -1.0},
)

In [None]:
# Work on a copy of the initial model and give it a hand-picked state.
model0 = init_model.copy()

# Earlier trial amplitudes: e_params[2]['A'] = 0.1, e_params[4]['A'] = 0.04.
valid_state = {
    "e_params": {
        2: {'A': 0., 'alpha': 0.33, 'x0': 3.},
        4: {'A': 0.2, 'alpha': 0.45, 'x0': 3.},
    },
    "ln_dens_params": {"f0": 9.5, "f3": -20, "alpha": 0.54, "x0": 3.0},
    "Omega": 0.06,
    "z0": 0.0,
    "vz0": 0.0,
}

model0.state = valid_state
model0._validate_state()

In [None]:
# Phase-space evaluation grid. With np.meshgrid's default 'xy' indexing,
# z varies along columns (axis 1) and vz along rows (axis 0).
z_axis = np.linspace(-2, 2, 128)
vz_axis = np.linspace(-0.1, 0.1, 128)
z_grid, vz_grid = np.meshgrid(z_axis, vz_axis)

In [None]:
# pcolormesh options: logarithmic color normalization with limits [1e-1, 3e4].
vlim = dict(
    norm=mpl.colors.LogNorm(vmax=3e4, vmin=1e-1), shading="auto"
)

# Evaluate the model density once and reuse it for both the mesh and the contours
# (previously ln_density was evaluated twice on the same grid).
dens0 = np.exp(model0.ln_density(z=z_grid, vz=vz_grid))

fig, ax = plt.subplots(
    1, 1, figsize=(6, 5), constrained_layout=True
)

cs = ax.pcolormesh(vz_grid, z_grid, dens0, **vlim)
fig.colorbar(cs, ax=ax)

ax.contour(
    vz_grid, z_grid, dens0,
    levels=16,
    colors='k'
)

ax.set_xlabel("$v_z$")
ax.set_ylabel("$z$")
ax.set_title("initial model")