In [None]:
import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d
from tqdm import tqdm

# gala
import gala.coordinates as gc
import gala.dynamics as gd
import gala.integrate as gi
import gala.potential as gp
from gala.units import galactic

from thriftshop.config import galcen_frame, elem_names
from thriftshop.data import load_apogee_sample
from thriftshop.potentials import potentials, galpy_potentials
from thriftshop.objective import TorusImagingObjective

In [None]:
# Load the APOGEE parent sample: `t` is the table and `c` the coordinates
# returned alongside it. They are assumed to be row-aligned: every later cell
# passes them together, and the bootstrap indexes both with the same indices.
t, c = load_apogee_sample('../data/apogee-parent-sample.fits')

# Sort by APOGEE_ID for a deterministic row order. The SAME permutation must be
# applied to `c`; sorting only `t` would pair each star's table row with some
# other star's coordinates.
sort_idx = np.argsort(t['APOGEE_ID'])
t = t[sort_idx]
c = c[sort_idx]

In [None]:
# Transform the sample coordinates into `galcen_frame` (from thriftshop.config)
# and wrap the resulting coordinate data as gala phase-space positions.
# NOTE(review): `w0s` is not used by any of the cells below.
galcen = c.transform_to(galcen_frame)
galcen_data = galcen.data
w0s = gd.PhaseSpacePosition(galcen_data)

In [None]:
# Baseline objective on the MG_FE column with tree_K = 20. The K-scan cell
# below rebinds `obj`, so this instance only matters when cells are run
# selectively.
baseline_tree_K = 20
obj = TorusImagingObjective(t, c, 'MG_FE',
                            tree_K=baseline_tree_K)

In [None]:
# 1-D slice of the objective: vary one parameter over +/-20% of a reference
# vector while holding the others fixed, repeated for several tree_K values
# to see how K shapes the objective.
tree_Ks = [2, 4, 8, 16, 32, 64, 128, 256]

# Reference parameter vector. It lives in the same space as the optimizer's
# `res.x` (presumably mdisk, zsun, vzsun, matching the labels of the
# optimization-results figure; TODO confirm).
x0 = np.array([1.1, -16.6, 7.78])
i = 0  # index of the scanned parameter; the others stay at their x0 values
p_vals = x0[i] * np.linspace(0.8, 1.2, 32)  # same grid for every K

all_p_vals = []
all_obj_vals = []
for tree_K in tree_Ks:
    obj = TorusImagingObjective(t, c, 'MG_FE', tree_K=tree_K)

    obj_vals = []
    for val in tqdm(p_vals):
        p = x0.copy()
        p[i] = val
        obj_vals.append(obj(p))

    obj_vals = np.array(obj_vals)

    all_p_vals.append(p_vals)
    all_obj_vals.append(obj_vals)

In [None]:
# Objective slices from the K scan above. Only the first few K values are
# drawn (presumably to keep the figure readable). The objective is multiplied
# by 1e7 only to get readable tick labels.
# NOTE: a later cell rebinds `tree_Ks`; re-running this cell after that one
# would mislabel the curves.
n_curves = 4
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
for tree_K, p_vals, obj_vals in zip(tree_Ks[:n_curves], all_p_vals, all_obj_vals):
    ax.plot(p_vals, obj_vals * 1e7, label=f"K={tree_K}")
ax.set_xlabel('parameter 0 (mdisk)')
ax.set_ylabel(r'objective $\times 10^{7}$')
ax.set_title('objective slice vs. $K$')
ax.legend()
fig.set_facecolor('w')
fig.tight_layout()

In [None]:
# Full optimization for K = 4, 8, ..., 256, to see how the best-fit parameters
# depend on K. NOTE: this rebinds `tree_Ks` (the slice scan used 2..256).
tree_Ks = 2 ** np.arange(2, 8+1, 1)
results = []
for tree_K in tqdm(tree_Ks):
    obj = TorusImagingObjective(t, c, 'MG_FE', tree_K=tree_K)
    res = obj.minimize()
    if not res.success:
        # Keep the result so `results` stays aligned with `tree_Ks`, but make
        # the failure visible instead of silently plotting it below.
        tqdm.write(f"K={tree_K}: optimization did not report success")
    results.append(res)

In [None]:
# Best-fit parameter vectors stacked into rows, one row per optimization result.
xs = np.asarray([fit.x for fit in results])

In [None]:
# Best-fit parameters vs. K. The grey bars are a fixed +/-7% fractional band
# for visual reference, not fitted uncertainties.
frac_band = 0.07
param_labels = ['mdisk', 'zsun', 'vzsun']

fig, axes = plt.subplots(3, 1, figsize=(8, 10),
                         sharex=True)
for i, (ax, label) in enumerate(zip(axes, param_labels)):
    ax.plot(tree_Ks, xs[:, i])
    # errorbar rejects negative yerr, and parameters can be negative (the
    # slice-scan start vector has -16.6), so use the magnitude of the band.
    ax.errorbar(tree_Ks, xs[:, i], np.abs(frac_band * xs[:, i]),
                zorder=-10, marker='', ls='none', ecolor='#aaaaaa')
    ax.set_ylabel(label)

axes[0].set_title('optimization results')
# `basex` was removed in Matplotlib 3.5; `base` is the supported keyword.
axes[0].set_xscale('log', base=2)
axes[2].set_xlabel('$K$')
fig.set_facecolor('w')

### Bootstrap resampling: spread of the best-fit parameters

In [None]:
# Bootstrap: refit on resamples (drawn with replacement) of the parent sample.
# Each fit starts from the full-sample optimum, and the spread across fits
# shows the spread of the best-fit parameters.
tree_K = 20
boot_K = 128  # number of bootstrap resamples
# Seed the global NumPy RNG (rather than a local Generator) so that any use of
# it inside the objective/minimizer is also seeded.
np.random.seed(42)

# Full sample fit:
obj = TorusImagingObjective(t, c, 'MG_FE', tree_K=tree_K)
full_sample_res = obj.minimize()
if not full_sample_res.success:
    raise RuntimeError(
        "Full-sample optimization did not succeed, so it cannot seed the "
        f"bootstrap fits:\n{full_sample_res}")

results = []
for k in tqdm(range(boot_K)):
    # Resample rows with replacement; `t` and `c` are indexed identically so
    # they stay row-aligned.
    idx = np.random.choice(len(t), len(t), replace=True)
    obj = TorusImagingObjective(t[idx], c[idx], 'MG_FE',
                                tree_K=tree_K)
    res = obj.minimize(x0=full_sample_res.x)
    results.append(res)

# Failed resample fits are kept (so positions still match the draw order) but
# flagged; filter with `boot_success` before computing parameter spreads.
boot_success = np.array([r.success for r in results])
if not boot_success.all():
    print(f"{np.count_nonzero(~boot_success)}/{boot_K} bootstrap fits "
          "did not report success")

In [None]:
# Display the first bootstrap fit result for a quick comparison with the
# full-sample fit shown in the next cell.
results[0]

In [None]:
# Display the full-sample fit result (also the starting point for every
# bootstrap fit).
full_sample_res