## DATA SETUP

In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# load local version before pip installed version, for debugging
import pathlib
import sys
import os

# sys.path is searched front to back; append() would place the local checkout after
# site-packages, so an installed copy would still be imported first. Put it at the
# front instead, and drop earlier copies so re-running this cell stays idempotent.
local_src_dir = str(pathlib.Path(os.getcwd()).parent.joinpath("src"))
sys.path[:] = [entry for entry in sys.path if entry != local_src_dir]
sys.path.insert(0, local_src_dir)

In [None]:
import yt
import numpy as np
import matplotlib.pyplot as plt

from gallifrey.setup import data_setup
from gallifrey.utilities.math import calculate_pca
from gallifrey.particles import rotated_dataset

In [None]:
# Simulation selection, forwarded to data_setup below.
snapshot: int = 127
resolution: int = 4096
sim_id: str = "09_18"
ngpps_id: str = "ng75"

# Planet categories; each name is also used as a ("stars", <name>) field further down.
planet_categories: list = [
    "Earth", "Super-Earth", "Neptunian",
    "Sub-Giant", "Giant", "D-Burner",
]

# Forwarded as `save=` to the plotting helpers.
save: bool = False

In [None]:
# Build the dataset and associated models for the configured snapshot; the five
# returned objects are unpacked in the order data_setup yields them.
ds, mw, stellar_model, imf, planet_model = data_setup(
    snapshot=snapshot,
    resolution=resolution,
    sim_id=sim_id,
    ngpps_id=ngpps_id,
)

## PLOT SETUP

In [None]:
from planet_maps import plot_maps
from planet_1dprofiles import plot_1dprofiles
from planet_2dprofiles import plot_2dprofiles

In [None]:
def add_star_weighted_field(category, normalize=True):
    """Register a ("stars", f"star_weighted_{category}") derived field on ``ds``.

    The field value is ``data["stars", category] / data["stars", "number"]`` for each
    particle, optionally divided by its maximum.

    Parameters
    ----------
    category : str
        Name of an existing ("stars", ...) field, e.g. an entry of ``planet_categories``.
    normalize : bool, optional
        If True (default), divide the ratio by its maximum so the largest value is 1.
        An all-zero input is returned as zeros, and an empty input is returned unchanged.

    Notes
    -----
    Uses the notebook-level ``ds`` returned by ``data_setup``.
    NOTE(review): yt may evaluate field functions chunk by chunk, in which case the
    maximum used for normalization is per chunk rather than global. Confirm this
    before relying on normalized values.
    """

    def _star_weighted_planets(field, data):
        planets_per_star = data["stars", category] / data["stars", "number"]
        if not normalize:
            return planets_per_star
        # np.amax raises on zero-size arrays, so pass empty chunks straight through.
        if planets_per_star.size == 0:
            return planets_per_star
        max_value = np.amax(planets_per_star)
        if max_value == 0:
            # Multiply by zero so shape, dtype and any attached units are preserved
            # (np.repeat(0, n) would give a bare integer array).
            return planets_per_star * 0
        return planets_per_star / max_value

    ds.add_field(
        ("stars", f"star_weighted_{category}"),
        function=_star_weighted_planets,
        sampling_type="local",
        units="auto",
        dimensions=1,
    )


# Register an un-normalized star_weighted_<category> field for every planet category.
for category in planet_categories:
    add_star_weighted_field(category=category, normalize=False)

## CREATE DATA SOURCE

In [None]:
# Outer radius of the sphere/disk selections below (wrapped with ds.quan(..., "kpc")).
radius = 60

# Run a PCA on star coordinates within 10 kpc and take the last component as the
# normal of the disk selection. NOTE(review): this assumes calculate_pca orders
# components_ by decreasing explained variance, so the last one is the thinnest axis.
inner_star_coordinates = mw.sphere(radius=(10, "kpc"))["stars", "Coordinates"]
normal_vector = calculate_pca(inner_star_coordinates).components_[-1]

In [None]:
# Fields carried into both rotated datasets; defined once so the two stay in sync.
rotated_fields = [
    ("stars", "[Fe/H]"),
    ("stars", "number"),
    *[("stars", category) for category in planet_categories],
    *[("stars", f"star_weighted_{category}") for category in planet_categories],
]

sphere_data = mw.sphere(radius=ds.quan(radius, "kpc"))

# Disk selection (height 0.5 kpc) aligned with the PCA normal, rotated about mw.centre().
# Each call gets its own copy of the field list, as in the original cell.
rotated_disk_data = rotated_dataset(
    mw.disk(
        radius=ds.quan(radius, "kpc"), height=ds.quan(0.5, "kpc"), normal=normal_vector
    ),
    mw.centre(),
    normal_vector,
    list(rotated_fields),
)

rotated_sphere_data = rotated_dataset(
    sphere_data,
    mw.centre(),
    normal_vector,
    list(rotated_fields),
)

In [None]:
fields = [("stars", category) for category in planet_categories]

# Face-on (z-axis) projection of planet counts, weighted by the particles' star number.
plot = yt.ParticleProjectionPlot(
    ds=rotated_sphere_data,
    fields=fields,
    axis="z",
    width=(42, "kpc"),
    deposition="cic",
    weight_field=("stars", "number"),
    density=True,
)

for field in fields:
    plot.set_unit(field, "1/pc**2")

for field in fields:
    plot.set_cmap(field, "kelp")
    plot.set_colorbar_label(
        field, field[-1] + r"s Per Star $\left(1/\mathrm{pc}^2\right)$"
    )

    # Clip each colour scale to the 1st-99th percentile of its positive pixels.
    image_values = np.array(plot.frb[field]).flatten()
    positive_values = image_values[image_values > 0]
    if positive_values.size == 0:
        # No positive pixels: nanpercentile would return NaNs, so keep yt's default limits.
        continue
    percentiles = np.nanpercentile(positive_values, [1, 99])
    plot.set_zlim(field, *percentiles)

fig = plot.export_to_mpl_figure((2, 3))
fig.set_size_inches(18.5, 10.5)
fig.tight_layout()

## FACE-ON MAPS

In [None]:
# NOTE(review): `rotated_data` is never defined in this notebook (NameError on a fresh
# kernel); assumed to mean the rotated disk dataset built above. Confirm.
zplot, zplots = plot_maps(rotated_disk_data, axis="z", save=save)

## SIDE-ON MAPS

In [None]:
# NOTE(review): `rotated_data` is never defined in this notebook (NameError on a fresh
# kernel); assumed to mean the rotated disk dataset built above. Confirm.
# Side-on view: projected along the x axis.
yplot, yplots = plot_maps(rotated_disk_data, axis="x", save=save)

## 1D PROFILES

In [None]:
# NOTE(review): `disk_data` is never defined in this notebook (NameError on a fresh
# kernel). `.sphere(...)` has to be called on a dataset, and disk_height matches the
# 0.5 kpc disk selection, so the rotated disk dataset is assumed. Confirm.
# The centre [0, 0, 0] assumes the rotated dataset is centred on the origin. The bare
# radius 30 is presumably read in the dataset's code units; check that these are kpc.
figs, axes = plot_1dprofiles(
    rotated_disk_data.sphere([0, 0, 0], 30),
    halo=mw,
    disk_height=ds.quan(0.5, "kpc"),
    save=save,
)

## 2D PROFILES

In [None]:
# 2D profiles of the (unrotated) sphere selection, "stellar_age" variant.
age_plots = plot_2dprofiles(
    sphere_data,
    "stellar_age",
    save=save,
)

In [None]:
# 2D profiles of the (unrotated) sphere selection, "metallicity" variant.
fe_plots = plot_2dprofiles(
    sphere_data,
    "metallicity",
    save=save,
)