# Quickstart Tutorial

`galtab` is a general approach for calculating the expectation value of
counts-in-cells statistics for a given halo catalog and halo occupation
distribution (HOD) model. It pretabulates placeholder galaxies inside each halo,
which makes its predictions fast and deterministic. This is ideal for the many
repeated likelihood evaluations of an MCMC.
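The appeal of pretabulation can be illustrated with a toy sketch. It is not `galtab`'s implementation: it assumes hypothetical halo masses, gives each halo a single placeholder central whose weight is a Zheng07-style occupation probability, and ignores satellites and positions entirely. Once the placeholders are tabulated, a new parameter choice only changes the weights, so the expectation value is a cheap, deterministic weighted sum, whereas a Monte Carlo population changes with every seed.

```python
import math

import numpy as np

# Vectorized error function built from the standard library (avoids a SciPy dependency)
erf = np.vectorize(math.erf, otypes=[float])


def mean_ncen(log_mass, logMmin, sigma_logM):
    """Zheng07-style central occupation: 0.5 * [1 + erf((log M - logMmin) / sigma)]."""
    return 0.5 * (1.0 + erf((log_mass - logMmin) / sigma_logM))


# Tabulate once: one placeholder central per (hypothetical) halo
rng = np.random.default_rng(0)
log_halo_mass = rng.uniform(11.0, 15.0, size=20_000)


def expected_ncen_total(logMmin, sigma_logM=0.39):
    """Deterministic: the expectation is the sum of the placeholder weights."""
    return mean_ncen(log_halo_mass, logMmin, sigma_logM).sum()


def sampled_ncen_total(logMmin, seed, sigma_logM=0.39):
    """Stochastic: each placeholder becomes a galaxy with probability equal to its weight."""
    weights = mean_ncen(log_halo_mass, logMmin, sigma_logM)
    draws = np.random.default_rng(seed).random(weights.size)
    return int((draws < weights).sum())


print(expected_ncen_total(12.79))  # the same value on every call
print(sampled_ncen_total(12.79, seed=1), sampled_ncen_total(12.79, seed=2))  # generally differ
```

The sigma_logM default of 0.39 is only an illustrative value; the halo masses and weights are made up for the example.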

This [tutorial](https://github.com/AlanPearl/galtab/blob/main/docs/source/notebooks/intro.ipynb)
will demonstrate some basic Counts-in-Cylinders (CiC) calculations
using the intended `galtab` workflow.

To cite `galtab`, learn more implementation details, and explore an example science
use case, check out https://arxiv.org/abs/2309.08675.

## Prerequisites

All of the following are `pip`-installable (the indented packages are required by `galtab`):

- `galtab`
    - `numpy`
    - `jax`
    - `astropy`
    - `halotools`
- `matplotlib`
- `jupyterlab`
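Before opening the notebook, a quick way to confirm that these packages are importable is a check like the following. It is a small convenience sketch, not part of `galtab`: it only tests whether each module can be found, not whether the installed versions are compatible, and it assumes each package's import name matches its `pip` name (true for the packages listed here).

```python
import importlib.util

# Import names of the packages listed above
PREREQUISITES = ["galtab", "numpy", "jax", "astropy", "halotools",
                 "matplotlib", "jupyterlab"]


def missing_packages(names=PREREQUISITES):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


print("Missing packages:", missing_packages() or "none")
```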

After installing the above, you should be able to run the following cells.
In these cells, we:

- import the required packages
- download the Bolshoi-Planck (`bolplanck`) z=0 halotools catalog
- set our cosmology and CiC parameters
- choose an HOD model
- load the simulation data

```python
import numpy as np
import matplotlib.pyplot as plt
import jax

from astropy import cosmology
import halotools.empirical_models as htem
import halotools.sim_manager as htsm
import halotools.mock_observables as htmo

import galtab
```

```python
# Download the example halotools catalog (Bolshoi-Planck, Rockstar halos, z=0);
# this only needs to be done once
htsm.DownloadManager().download_processed_halo_table(
    'bolplanck', 'rockstar', 0.0)
```

```python
# Set our CiC parameters (all lengths are in Mpc/h)
proj_search_radius = 2.0
cylinder_half_length = 10.0
cic_edges = np.arange(-0.5, 16)  # unit-width bins centered on N_CiC = 0, 1, ..., 15

# Set our cosmology and HOD model
cosmo = cosmology.Planck13
hod = htem.PrebuiltHodModelFactory("zheng07", threshold=-21)

# Load Bolshoi-Planck simulation halos at z=0
halocat = htsm.CachedHaloCatalog(simname="bolplanck", redshift=0)
halocat.halo_table[:5]
```
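The edges `np.arange(-0.5, 16)` sit at half-integers, so each unit-width bin contains exactly one integer neighbor count, from 0 to 15. Counts above 15 fall outside the last edge and are not tallied by `np.histogram`. A quick check with a handful of made-up counts:

```python
import numpy as np

cic_edges = np.arange(-0.5, 16)  # -0.5, 0.5, ..., 15.5
cic_cens = 0.5 * (cic_edges[:-1] + cic_edges[1:])
print(len(cic_edges), len(cic_cens))  # 17 16
print(cic_cens[:4])  # [0. 1. 2. 3.]

# Made-up integer neighbor counts; 16 lies beyond the last edge and is not tallied
counts = np.array([0, 1, 1, 3, 15, 16])
hist, _ = np.histogram(counts, bins=cic_edges)
print(hist[:4], hist[-1])  # [1 2 0 1] 1
```

Because the bins have unit width, `density=True` turns the tallies into probabilities that sum to one over N_CiC = 0 to 15, normalized by the tallied counts only.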

## Calculate CiC the standard way with `halotools`

- Populate the halo catalog with galaxies drawn probabilistically from the HOD model
- Compute the number of neighbors within a cylinder around each galaxy
- Tally up a histogram of the neighbor counts for a given set of CiC bins
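What `htmo.counts_in_cylinders` computes in the second step can be sketched with a brute-force NumPy version. This is an O(N²) illustration on random points, not halotools' implementation (which uses efficient pair counting); it assumes cylinders aligned with the z axis (the line of sight), a periodic cubic box, and strict inequalities at the cylinder boundary.

```python
import numpy as np


def brute_force_cic(xyz, proj_search_radius, cylinder_half_length, period):
    """Count neighbors inside a z-aligned cylinder around each point (self excluded)."""
    # Pairwise separations with the minimum-image convention for a periodic box
    delta = xyz[:, None, :] - xyz[None, :, :]
    delta -= period * np.round(delta / period)
    rp = np.hypot(delta[..., 0], delta[..., 1])  # projected separation in the x-y plane
    pi = np.abs(delta[..., 2])                   # line-of-sight separation along z
    inside = (rp < proj_search_radius) & (pi < cylinder_half_length)
    return inside.sum(axis=1) - 1  # subtract each point's count of itself


# Hypothetical sample: 500 uniformly random points in a 50 Mpc/h box
rng = np.random.default_rng(42)
period = 50.0
xyz = rng.uniform(0, period, size=(500, 3))

counts = brute_force_cic(xyz, 2.0, 10.0, period)
cic_edges = np.arange(-0.5, 16)
p_cic = np.histogram(counts, bins=cic_edges, density=True)[0]
```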

```python
# Choose your HOD parameters (in this case, we will keep them the same)
hod.param_dict.update({})

# Populate the mock with model galaxies and get their Cartesian coordinates,
# with redshift-space distortions applied along the z axis
hod.populate_mock(halocat, seed=0)
galaxies = hod.mock.galaxy_table
xyz = htmo.return_xyz_formatted_array(
    galaxies["x"], galaxies["y"], galaxies["z"], velocity=galaxies["vz"],
    velocity_distortion_dimension="z", period=halocat.Lbox, cosmology=cosmo
)

# Compute CiC, subtracting 1 so that no galaxy counts itself as a neighbor
cic_counts = htmo.counts_in_cylinders(
    xyz, xyz, proj_search_radius, cylinder_half_length,
    period=halocat.Lbox) - 1
cic_halotools = np.histogram(cic_counts, bins=cic_edges, density=True)[0]
cic_halotools
```
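The `velocity_distortion_dimension="z"` argument moves each galaxy along the line of sight according to its peculiar velocity. At z=0 with lengths in Mpc/h, the shift is v_z / (100 km/s), because H0 = 100 h km/s/Mpc. The sketch below covers only that z=0 case, assuming velocities in km/s and an illustrative box size, and wraps positions back into the periodic box; `return_xyz_formatted_array` handles the general case using the cosmology passed to it.

```python
import numpy as np


def apply_rsd_z0(z, vz, period):
    """Shift comoving z positions (Mpc/h) into redshift space at redshift zero.

    Assumes vz in km/s; with H0 = 100 h km/s/Mpc the shift is vz / 100 in Mpc/h.
    """
    z_redshift_space = z + vz / 100.0
    return np.mod(z_redshift_space, period)  # wrap back into the periodic box


z = np.array([10.0, 249.0, 0.5])        # Mpc/h
vz = np.array([300.0, 200.0, -100.0])   # km/s
z_rsd = apply_rsd_z0(z, vz, period=250.0)  # values: 13.0, 1.0, 249.5
```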

## Now let's do it the `galtab` way

```python
# Give the Tabulator the halo catalog and a fiducial HOD model
gtab = galtab.GalaxyTabulator(halocat, hod)

# Prepare the CICTabulator to make predictions
cictab = galtab.CICTabulator(gtab, proj_search_radius, cylinder_half_length,
                             bin_edges=cic_edges)

# Choose your HOD parameters (in this case, we will keep them the same)
hod.param_dict.update({})

# Predict CiC for this model
cic_galtab = cictab.predict(hod)
cic_galtab
```

### Optionally, write the CIC tabulation to disk for later use:

```python
# Pickle the CICTabulator - creates a large file named `cictab.pickle`:
cictab.save("cictab.pickle")

# And load it back with:
cictab = galtab.CICTabulator.load("cictab.pickle")
gtab = cictab.galtabulator
```

## Plot the `galtab` vs. `halotools` comparison

- `galtab` predicts the CiC expectation value (smooth + deterministic)
- `halotools` draws a CiC realization (noisy + stochastic)
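To see why one curve is smooth and the other noisy, consider a toy case in which every cell's count is Poisson with mean λ. This is an assumption made only for illustration; real CiC distributions are not Poisson because galaxies cluster. The expected histogram is then the Poisson probability mass function, the histogram of a single finite sample scatters around it, and the average over many realizations converges to it.

```python
import math

import numpy as np

lam = 2.0  # illustrative mean count per cell
cic_edges = np.arange(-0.5, 16)

# Expectation value: the Poisson PMF at each integer count (deterministic)
expected = np.array([lam**n * math.exp(-lam) / math.factorial(n) for n in range(16)])

rng = np.random.default_rng(0)


def realization(n_cells=2000):
    """Histogram of one finite sample of Poisson counts (stochastic)."""
    counts = rng.poisson(lam, size=n_cells)
    return np.histogram(counts, bins=cic_edges, density=True)[0]


single = realization()
averaged = np.mean([realization() for _ in range(200)], axis=0)

print(np.abs(single - expected).max())    # scatter of one realization
print(np.abs(averaged - expected).max())  # much smaller after averaging
```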

```python
cic_cens = 0.5 * (cic_edges[:-1] + cic_edges[1:])
plt.semilogy(cic_cens, cic_galtab, label="galtab", lw=3)
plt.semilogy(cic_cens, cic_halotools, label="halotools", lw=3, ls="--")
plt.legend(frameon=False)
plt.xlabel("$N_{\\rm CiC}$")
plt.ylabel("$P(N_{\\rm CiC})$")
plt.show()
```

## **In Development:** Differentiate CiC w.r.t. the HOD parameter $\log M_{\rm min}$

`galtab` is implemented in JAX, so it is portable to GPUs and (in principle)
differentiable, provided your HOD model is also compatible with JAX. Unfortunately,
this requires a few modifications to `halotools` models. Here, we use the
`JaxZheng07Cens` and `JaxZheng07Sats` models, originally implemented
for the [JaxTabCorr](https://github.com/AlanPearl/JaxTabCorr) project.

We can construct a composite HOD model with our JAX-compatible mean
occupation functions, which we call `hod_jax`. This model allows us to
differentiate `cictab.predict` with `jax.grad`.
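Why does a smooth mean occupation function make gradients possible? The Zheng07 central occupation, ⟨N_cen⟩ = ½[1 + erf((log M − log M_min)/σ_log M)], is smooth in log M_min, with derivative −exp(−u²)/(σ_log M √π), where u = (log M − log M_min)/σ_log M. The plain-Python sketch below (illustrative parameter values, not the `JaxZheng07Cens` code) checks this analytic derivative against a central finite difference. Applied to a JAX implementation of the same formula, `jax.grad` would compute this derivative automatically.

```python
import math


def mean_ncen(log_mass, logMmin, sigma_logM):
    """Zheng07 central occupation <N_cen>(M) for a single halo."""
    return 0.5 * (1.0 + math.erf((log_mass - logMmin) / sigma_logM))


def dmean_ncen_dlogMmin(log_mass, logMmin, sigma_logM):
    """Analytic derivative of <N_cen> with respect to logMmin."""
    u = (log_mass - logMmin) / sigma_logM
    return -math.exp(-u * u) / (sigma_logM * math.sqrt(math.pi))


# Illustrative values (assumed for this example)
log_mass, logMmin, sigma = 13.0, 12.79, 0.39
h = 1e-5
finite_diff = (mean_ncen(log_mass, logMmin + h, sigma)
               - mean_ncen(log_mass, logMmin - h, sigma)) / (2 * h)
analytic = dmean_ncen_dlogMmin(log_mass, logMmin, sigma)
print(analytic, finite_diff)  # agree to many decimal places
```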

*Note:* Don't apply `jax.jit` directly to `cictab.predict`, since it contains
some code that can't be compiled. Rest assured that the most expensive
computations are still compiled automatically and run on a GPU if one is available.

```python
from galtab.jaxhalotools import JaxZheng07Cens, JaxZheng07Sats

# Create a JAX-compatible composite HOD model
def make_hod_jax():
    return htem.HodModelFactory(
        centrals_occupation=JaxZheng07Cens(threshold=-21),
        satellites_occupation=JaxZheng07Sats(threshold=-21),
        centrals_profile=htem.TrivialPhaseSpace(),
        satellites_profile=htem.NFWPhaseSpace()
    )

# Define a function that predicts P(N_CiC = 1)
def calc_cic1(logMmin=12.79):
    hod_jax = make_hod_jax()
    hod_jax.param_dict.update({"logMmin": logMmin})
    return cictab.predict(hod_jax, warn_p_over_1=False)[1]

# Define the derivative of calc_cic1 with respect to logMmin
diff_cic1 = jax.grad(calc_cic1)

# Note that we shouldn't make logMmin much lower than that of our fiducial
# model. If desired, make more conservative choices for the fiducial parameters,
# e.g., low logMmin / logM1 / logM0 values and large sigma_logM values
for logmmin in np.linspace(11.0, 15.0, 20):
    value = calc_cic1(logmmin)
    derivative = diff_cic1(logmmin)

    plt.plot(logmmin, value, "bo")
    plt.quiver(logmmin, value, 1, derivative, angles="xy")

plt.xlabel("$\\log M_{\\rm min}$")
plt.ylabel("$P(N_{\\rm CiC} = 1)$")
plt.show()
```

## `jax.grad` (the arrows in the above plot) isn't working *yet*...

- I actually wasn't expecting the above to work perfectly, because it uses the Monte Carlo mode, which isn't perfectly continuous in the HOD parameters
- But derivatives of the analytic-mode moments aren't working either...
- TODO: Figure out what's going wrong
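The first point can be illustrated with a toy comparison. This is a hypothetical setup in plain NumPy, not `galtab`'s Monte Carlo mode: if galaxies are drawn by comparing fixed uniform random numbers against ⟨N_cen⟩, the sampled count is a step function of logMmin. Its derivative is zero almost everywhere and undefined at the jumps, so a gradient taken through the draws is uninformative, whereas the expectation (the sum of ⟨N_cen⟩) varies smoothly.

```python
import math

import numpy as np

erf = np.vectorize(math.erf, otypes=[float])
rng = np.random.default_rng(0)
log_mass = rng.uniform(11.0, 15.0, size=5000)  # hypothetical halo masses
uniform_draws = rng.random(log_mass.size)       # fixed random numbers (fixed seed)


def mean_ncen(log_mass, logMmin, sigma_logM):
    """Zheng07-style central occupation probability."""
    return 0.5 * (1.0 + erf((log_mass - logMmin) / sigma_logM))


def expected_total(logMmin, sigma_logM=0.39):
    """Smooth in logMmin: sum of occupation probabilities."""
    return mean_ncen(log_mass, logMmin, sigma_logM).sum()


def sampled_total(logMmin, sigma_logM=0.39):
    """Piecewise constant in logMmin: number of draws below their probability."""
    return int((uniform_draws < mean_ncen(log_mass, logMmin, sigma_logM)).sum())


h = 1e-8
fd_expected = (expected_total(12.79 + h) - expected_total(12.79 - h)) / (2 * h)
fd_sampled = (sampled_total(12.79 + h) - sampled_total(12.79 - h)) / (2 * h)
print(fd_expected)  # negative and finite
print(fd_sampled)   # expected to be 0.0: such a tiny shift is very unlikely to flip any draw
```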