# Instructions to run all calculations from the command line

## Downloading SV3 Fuji Data

All clustering catalogs can be downloaded from https://data.desi.lbl.gov/desi/survey/catalogs/SV3/LSS/fuji/LSScats/EDAbeta/

NOTE: use `/global/cfs/cdirs/` instead of https://data.desi.lbl.gov/ on NERSC

### Clustering catalog data (north and south fields)
- `BGS_BRIGHT_N_clustering.dat.fits`
- `BGS_BRIGHT_S_clustering.dat.fits`

### Clustering catalog randoms (18 realizations, north and south fields)
- `BGS_BRIGHT_N_{i}_clustering.ran.fits` for `{i}` in 0-17
- `BGS_BRIGHT_S_{i}_clustering.ran.fits` for `{i}` in 0-17
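
The file lists above can also be enumerated programmatically. The following is a minimal sketch: it assumes that every file sits directly under the `EDAbeta/` directory given above and that the NERSC path is obtained by a plain prefix swap, as the note describes. The helper names `catalog_filenames` and `to_nersc_path` are hypothetical and not part of galtab.

```python
BASE_URL = "https://data.desi.lbl.gov/desi/survey/catalogs/SV3/LSS/fuji/LSScats/EDAbeta/"
URL_PREFIX = "https://data.desi.lbl.gov/"
NERSC_PREFIX = "/global/cfs/cdirs/"


def catalog_filenames(num_randoms=18):
    """List the BGS_BRIGHT data and random clustering catalog filenames."""
    names = []
    for field in ("N", "S"):
        names.append(f"BGS_BRIGHT_{field}_clustering.dat.fits")
        names.extend(f"BGS_BRIGHT_{field}_{i}_clustering.ran.fits"
                     for i in range(num_randoms))
    return names


def to_nersc_path(url):
    """Swap the public URL prefix for the NERSC filesystem prefix."""
    assert url.startswith(URL_PREFIX)
    return NERSC_PREFIX + url[len(URL_PREFIX):]


urls = [BASE_URL + name for name in catalog_filenames()]
print(len(urls))  # 2 data files + 2 x 18 random files
print(to_nersc_path(urls[0]))
```

The resulting list can be passed to any download tool, or mapped through `to_nersc_path` when working on NERSC.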

The fastspecphot catalog I used can be downloaded from

**Warning**: If you use a version of fastspecphot newer than ~Feb 2023, note that I think the assumed Hubble parameter was switched from h=0.7 to h=1.0, which changes h-dependent quantities such as absolute magnitudes.
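
If two catalog versions do assume different values of h, absolute magnitudes can be put on a common footing with the standard distance-modulus shift: distances scale as 1/h, so M(h) = M(h=1) + 5 log10(h). The sketch below only applies that shift; it does not account for any other differences between catalog versions, and the function name `convert_abs_mag` is hypothetical.

```python
import math


def convert_abs_mag(abs_mag, h_from, h_to):
    """Convert an absolute magnitude between two assumed values of h.

    Uses M(h) = M(h=1) + 5*log10(h), which follows from distances
    scaling as 1/h.
    """
    return abs_mag + 5.0 * math.log10(h_to / h_from)


# A magnitude of -21.0 computed with h=0.7, re-expressed for h=1.0
m_h1 = convert_abs_mag(-21.0, h_from=0.7, h_to=1.0)
print(round(m_h1, 3))
```

The shift from h=0.7 to h=1.0 is about 0.77 mag toward fainter magnitudes, which is large compared with the 0.5 mag spacing between the luminosity thresholds used here.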

## Data cleaning

Generate clean DESI data files with: `python -m galtab.paper2.clean_desi_data`

This generates new data files that have been "cleaned" of stars, spectra with z < 0, duplicates, non-BGS-Bright targets, and sources with DELTACHI2 < 25. By default, all cleaned data are placed in `~/data/DESI/SV3/clean_fuji/`.
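
The cuts listed above amount to a boolean mask followed by de-duplication. The sketch below applies them to a toy structured array. It is not the code in `galtab.paper2.clean_desi_data`: the column names `TARGETID`, `Z`, and `SPECTYPE` are assumptions about the input catalog, and `IS_BGS_BRIGHT` is a hypothetical stand-in for the actual targeting selection.

```python
import numpy as np

# Toy catalog; only DELTACHI2 is named in the text above, the other
# column names are assumptions made for illustration
cat = np.array(
    [(1, 0.15, 40.0, "GALAXY", True),
     (2, -0.001, 80.0, "GALAXY", True),   # z < 0
     (3, 0.0001, 90.0, "STAR", True),     # star
     (4, 0.20, 10.0, "GALAXY", True),     # DELTACHI2 < 25
     (5, 0.25, 60.0, "GALAXY", False),    # not a BGS bright target
     (6, 0.12, 30.0, "GALAXY", True),
     (1, 0.15, 40.0, "GALAXY", True)],    # duplicate of the first row
    dtype=[("TARGETID", "i8"), ("Z", "f8"), ("DELTACHI2", "f8"),
           ("SPECTYPE", "U6"), ("IS_BGS_BRIGHT", "?")])

keep = ((cat["SPECTYPE"] != "STAR") & (cat["Z"] >= 0)
        & (cat["DELTACHI2"] >= 25) & cat["IS_BGS_BRIGHT"])
clean = cat[keep]

# Drop duplicates, keeping the first occurrence of each TARGETID
_, first = np.unique(clean["TARGETID"], return_index=True)
clean = clean[np.sort(first)]
print(clean["TARGETID"])
```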

## Counting Randoms in Cylinders (for sky-completeness masking)

```
job.sh (executed using `sbatch job.sh`) -> Output: desi_rand_counts.npy
---------------------------------------
#!/bin/bash
#SBATCH --qos=shared
#SBATCH --job-name=rand-cic
#SBATCH --constraint=haswell
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
#SBATCH --time=0:40:00
#SBATCH --account=desi

# half the 64 total CPUs in a haswell node = 32 (2 CPUs per core => 16 physical cores)

srun -n 1 -c 16 --cpu-bind=cores python -m galtab.paper2.count_desi_randoms --progress --output desi_rand_counts.npy --data-dir /global/homes/a/apearl/data/clean_fuji/ --num-threads 16 --force-no-mpi
```

Download `desi_rand_counts.npy` and place it in `~/data/DESI/SV3/clean_fuji`
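
To illustrate why random counts in cylinders trace sky completeness, here is a brute-force sketch. It is not the implementation in `galtab.paper2.count_desi_randoms`: it assumes Cartesian positions with the line of sight along the z-axis, a uniform cubic toy volume, and hypothetical cylinder dimensions.

```python
import numpy as np


def counts_in_cylinders(centers, points, radius, half_length):
    """Count points inside a line-of-sight cylinder around each center.

    Assumes Cartesian positions with the line of sight along the z-axis.
    Brute force: O(len(centers) * len(points)) in time and memory.
    """
    dxy = centers[:, None, :2] - points[None, :, :2]
    rp2 = np.sum(dxy**2, axis=-1)
    dz = np.abs(centers[:, None, 2] - points[None, :, 2])
    inside = (rp2 < radius**2) & (dz < half_length)
    return inside.sum(axis=1)


rng = np.random.default_rng(0)
num_rand, box = 200_000, 100.0
randoms = rng.uniform(0, box, size=(num_rand, 3))
galaxies = np.array([[50.0, 50.0, 50.0],   # well inside the volume
                     [0.0, 50.0, 50.0]])   # on a face of the volume
radius, half_length = 5.0, 20.0
counts = counts_in_cylinders(galaxies, randoms, radius, half_length)

# Expected count for a fully covered cylinder: density * pi * R^2 * 2L
expected = num_rand / box**3 * np.pi * radius**2 * 2 * half_length
completeness = counts / expected
print(counts, completeness)
```

The cylinder centered on the face of the volume contains roughly half the expected number of randoms, so a cut on the random count flags it as incomplete.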

## Jackknife Observations (n + wp + CiC)

This work was done in `~/Paper2Data/desi_observations/`

Perform all calculations with:
```
# Full CiC Distribution
# =====================
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p0.npz --abs-mr-max " -20.0" --wp-rand-frac 0.1
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p5.npz --abs-mr-max " -20.5" --wp-rand-frac 0.1
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_z0p2-0p3.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --zmin 0.2 --zmax 0.3

# CiC Moments 1-5
# ===============
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p0_kmax5.npz --abs-mr-max " -20.0" --wp-rand-frac 0.1 --cic-kmax 5
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p5_kmax5.npz --abs-mr-max " -20.5" --wp-rand-frac 0.1 --cic-kmax 5
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_kmax5.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --cic-kmax 5
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_z0p2-0p3_kmax5.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --zmin 0.2 --zmax 0.3 --cic-kmax 5

# No CiC
# ======
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p0_kmax0.npz --abs-mr-max " -20.0" --wp-rand-frac 0.1 --cic-kmax 0
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p5_kmax0.npz --abs-mr-max " -20.5" --wp-rand-frac 0.1 --cic-kmax 0
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_kmax0.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --cic-kmax 0
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_z0p2-0p3_kmax0.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --zmin 0.2 --zmax 0.3 --cic-kmax 0

# cylinder_half_length = pimax = 40.0 (CiC Moments 1-5)
# =====================================================
python -m galtab.paper2.desi_observables --cylinder-half-length 40.0 --pimax 40.0 -vpn4 -o desi_obs_20p0_hl40_kmax5.npz --abs-mr-max " -20.0" --wp-rand-frac 0.1 --cic-kmax 5
python -m galtab.paper2.desi_observables --cylinder-half-length 40.0 --pimax 40.0 -vpn4 -o desi_obs_20p5_hl40_kmax5.npz --abs-mr-max " -20.5" --wp-rand-frac 0.1 --cic-kmax 5
python -m galtab.paper2.desi_observables --cylinder-half-length 40.0 --pimax 40.0 -vpn4 -o desi_obs_21p0_hl40_kmax5.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --cic-kmax 5
python -m galtab.paper2.desi_observables --cylinder-half-length 40.0 --pimax 40.0 -vpn4 -o desi_obs_21p0_z0p2-0p3_hl40_kmax5.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --zmin 0.2 --zmax 0.3 --cic-kmax 5

# For comparison to Kuan:
# =======================
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p0_kuansample.npz --abs-mr-max " -20.0" --wp-rand-frac 0.1 --kuan-mags --zmin 0.02 --zmax 0.106
python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_kuansample.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --kuan-mags --zmin 0.02 --zmax 0.159
```
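
Each output file stores a mean data vector and a covariance matrix estimated by jackknife resampling over sky regions. As a reminder of how a delete-one jackknife covariance is formed, here is a generic sketch. Whether `galtab.paper2.desi_observables` uses exactly this normalization is an assumption, and the function name `jackknife_mean_cov` is hypothetical.

```python
import numpy as np


def jackknife_mean_cov(loo_estimates):
    """Mean and covariance from leave-one-out (delete-one) estimates.

    loo_estimates has shape (n_regions, n_statistics); row i is the data
    vector measured with region i removed.
    """
    x = np.asarray(loo_estimates, dtype=float)
    n = x.shape[0]
    mean = x.mean(axis=0)
    resid = x - mean
    cov = (n - 1) / n * resid.T @ resid
    return mean, cov


# Toy example: the statistic is the mean of one value per region
rng = np.random.default_rng(1)
values = rng.normal(size=20)
loo = np.array([[np.delete(values, i).mean()] for i in range(20)])
mean, cov = jackknife_mean_cov(loo)
print(mean[0], np.sqrt(cov[0, 0]))
```

For a simple mean, this estimator reproduces the familiar variance of the mean, `values.var(ddof=1) / len(values)`, which makes it a convenient sanity check.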

## Testing the parameter importance of n, wp, and CiC moments

This work was done in `~/Paper2Data/desi_results/`

Run `python -m galtab.paper2.importance --num-samples 10_000`

## Testing galtab accuracy vs runtime performance parameters

This work was also done in `~/Paper2Data/desi_results/`

Run `python -m galtab.paper2.accuracy_vs_runtime --num-gt-trials 25 --num-ht-trials 2500`

## Fitting HOD parameters

All fitting was performed by MCMC. To construct the MCMC chains for the -21.0 threshold sample, for example, create a directory at `~/Paper2Data/desi_results/21p0_results`. From within this directory, you can run 3000 iterations x 20 walkers with:

`python -m galtab.paper2.param_sampler -v 3000 ../../desi_observations/desi_obs_21p0.npz .`

# Imports and Loading data

In [None]:
from time import time
from contextlib import contextmanager
import pathlib

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.stats
import scipy.optimize
import scipy.special
import sklearn.ensemble
import sklearn.inspection

import shap
import emcee
import corner
from tqdm.notebook import tqdm
from astropy.io import fits
import astropy.cosmology
import halotools.mock_observables as htmo
import halotools.empirical_models as htem
import halotools.sim_manager as htsm

import mocksurvey as ms
import galtab.obs
import galtab.paper2.desi_sv3_pointings
import galtab.paper2.desi_observables
import galtab.paper2.param_sampler
import galtab.paper2.param_config

In [None]:
# Load clean DESI catalogs
# ========================
def load_fastphot(dirname=pathlib.Path.home() / "data/DESI/SV3/clean_fuji", load_rands=False):
    fastphot_fn = pathlib.Path(dirname) / "fastphot.npy"
    randcyl_fn = pathlib.Path(dirname) / "desi_rand_counts.npy"
    rands_fn = pathlib.Path(dirname) / "rands" / "rands.npy"

    data = np.load(fastphot_fn)
    randcyl = np.concatenate(np.load(randcyl_fn, allow_pickle=True))
    
    rands = None
    if load_rands:
        rands = np.load(rands_fn)
    
    region_masks = [galtab.paper2.desi_sv3_pointings.select_region(
        i, data["RA"], data["DEC"]) for i in range(20)]

    return data, rands, randcyl, region_masks

# fastphot, desi_rands, randcyl, region_masks = load_fastphot(load_rands=True)
fastphot, desi_rands, randcyl, region_masks = load_fastphot()

In [None]:
n_mask = slice(0, 1)
wp_mask = slice(1, 13)
cic_mask = slice(13, None)
not_cic_mask = slice(None, 13)
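
These slices split a concatenated data vector into its number density, wp, and CiC parts. A small sketch with a toy vector follows; the length of the CiC part (5 here) is an arbitrary choice for illustration.

```python
import numpy as np

n_mask = slice(0, 1)
wp_mask = slice(1, 13)
cic_mask = slice(13, None)

# Toy data vector: 1 number density + 12 wp bins + 5 CiC entries
vector = np.arange(18.0)
n, wp, cic = vector[n_mask], vector[wp_mask], vector[cic_mask]
print(n.shape, wp.shape, cic.shape)

# Concatenating the pieces in order recovers the original vector
rebuilt = np.concatenate([n, wp, cic])
print(np.array_equal(rebuilt, vector))
```

Because `cic_mask` is open-ended, the same slices work for the full CiC distribution and for the shorter moment-based vectors.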

In [None]:
# Load my DESI summary statistics
# ===============================
filenames = [
    "desi_obs_20p0.npz", "desi_obs_20p0_puff.npz",
    "desi_obs_20p0_kmax5.npz", "desi_obs_20p0_kuansample.npz",
    "desi_obs_20p5.npz", "desi_obs_20p5_puff.npz", "desi_obs_20p5_kmax5.npz",
    "desi_obs_21p0.npz", "desi_obs_21p0_puff.npz",
    "desi_obs_21p0_kmax5.npz", "desi_obs_21p0_kuansample.npz",
    "desi_obs_21p0_z0p2-0p3.npz", "desi_obs_21p0_z0p2-0p3_puff.npz",
    "desi_obs_21p0_z0p2-0p3_kmax5.npz",
    "desi_obs_20p0_kmax0.npz", "desi_obs_20p5_kmax0.npz", "desi_obs_21p0_kmax0.npz",
    "desi_obs_21p0_z0p2-0p3_kmax0.npz",
    "desi_obs_20p0_hl40_kmax5.npz", "desi_obs_20p5_hl40_kmax5.npz",
    "desi_obs_21p0_hl40_kmax5.npz", "desi_obs_21p0_z0p2-0p3_hl40_kmax5.npz",
]

desi_obs = {}
desi_mean = {}
desi_err = {}
for file in filenames:
    filestart, fileend = "desi_obs_", ".npz"
    assert file.startswith(filestart)
    assert file.endswith(fileend)
    key = file[len(filestart):-len(fileend)]
    path = pathlib.Path.home() / f"Paper2Data/desi_observations/{file}"
    try:
        desi_obs[key] = np.load(path, allow_pickle=True)
    except FileNotFoundError:
        print(f"Couldn't find {file} at", path)
    else:
        assert desi_obs[key]["slice_n"].tolist() == n_mask
        assert desi_obs[key]["slice_wp"].tolist() == wp_mask
        assert desi_obs[key]["slice_cic"].tolist().start == cic_mask.start
        desi_mean[key] = desi_obs[key]["mean"]
        desi_err[key] = np.sqrt(np.diag(desi_obs[key]["cov"]))
        desi_mean[key][n_mask] = desi_mean[key][n_mask] * desi_obs[key]["effective_area_sqdeg"] / 173.4924
        desi_err[key][n_mask] = desi_err[key][n_mask] * desi_obs[key]["effective_area_sqdeg"] / 173.4924

rp_edges = desi_obs["20p0"]["rp_edges"]
cic_edges = desi_obs["20p0"]["cic_edges"]
rp_cens = np.sqrt(rp_edges[:-1] * rp_edges[1:])
cic_cens = 0.5 * (cic_edges[:-1] + cic_edges[1:])
print(*desi_obs["20p0"].keys(), sep=", ")
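
The bin centers above use a geometric mean for `rp` and an arithmetic mean for CiC. The sketch below shows why that pairing is natural, using hypothetical bin edges rather than the ones stored in the `.npz` files: the geometric mean is the midpoint in log space, which suits log-spaced `rp` bins, while integer counts fall at the arithmetic midpoints of half-integer edges.

```python
import numpy as np

# Hypothetical log-spaced rp bins and half-integer CiC bin edges
rp_edges = np.logspace(-1, 1, 5)
cic_edges = np.arange(-0.5, 4.5)

rp_cens = np.sqrt(rp_edges[:-1] * rp_edges[1:])
cic_cens = 0.5 * (cic_edges[:-1] + cic_edges[1:])

# Geometric means are midpoints in log space
log_mid = 0.5 * (np.log10(rp_edges[:-1]) + np.log10(rp_edges[1:]))
print(np.allclose(np.log10(rp_cens), log_mid))
print(cic_cens)
```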

In [None]:
cosmo = astropy.cosmology.Cosmology.from_format(desi_obs["20p0"]["cosmo"].tolist(), format="mapping")
desi_obs["20p0"]["effective_area_sqdeg"]

In [None]:
desi_obs.keys()

In [None]:
# Load results: accuracy_vs_runtime, importance, (eventually MCMC too)
# ====================================================================
gt_avr_results, ht_avr_results = np.load(
    pathlib.Path.home() / "Paper2Data/desi_results/accuracy_runtime_results.npy", allow_pickle=True)
gt_avr_results = pd.DataFrame(gt_avr_results)
ht_avr_results = pd.DataFrame(ht_avr_results)

knames = [name for name in gt_avr_results.columns if name.startswith("k") and name[1:].isdigit()]
cicnames = ["CiC" + name[1:] for name in knames]
for kname, cicname in zip(knames, cicnames):
    gt_avr_results[cicname] = gt_avr_results[kname]
    ht_avr_results[cicname] = ht_avr_results[kname]

importance_results = np.load(pathlib.Path.home() / "Paper2Data/desi_results/importance_results.npz")

# Define plotting functions

In [None]:
@contextmanager
def autoscale_turned_off(ax=None, x=True, y=True):
    ax = ax or plt.gca()
    lims = [ax.get_xlim(), ax.get_ylim()]
    yield

    if x:
        ax.set_xlim(*lims[0])
    if y:
        ax.set_ylim(*lims[1])

def plot_desi_obs(puff=False, highz=False, kuansample=False, kmax=None, figax=None, threshes=None, color_start=0,
                  plot_n=True, plot_wp=True, plot_cic=True, nolines=False, alpha=0.35, cic_offset_start=None,
                  logrp_offset_start=None, plot_wp_errorbars=False, override_label=None, xscale_cic=False):
    if threshes is None:
        threshes = [-20.0, -20.5, -21.0]
    if figax is None:
        fig, axes = plt.subplots(ncols=3, figsize=(13, 4), gridspec_kw={"width_ratios": [1, 5, 5], "wspace": 0.3})
    else:
        fig, axes = figax
    ind = color_start
    highzlabel = " (high z)" if highz else ""
    highz = "_z0p2-0p3" if highz else ""
    puff = "_puff" if puff else ""
    kuansample = "_kuansample" if kuansample else ""
    kmax_str = "" if kmax is None else f"_kmax{kmax}"
    cic_offset = 0 if cic_offset_start is None else cic_offset_start
    logrp_offset = 0 if logrp_offset_start is None else logrp_offset_start
    for thresh in threshes:
        thresh_string = f"{-thresh:.1f}".replace(".", "p")
        key = f"{thresh_string}{kuansample}{highz}{puff}{kmax_str}"
        color = f"C{ind}"
        ind += 1
        n = desi_mean[key][n_mask]
        wp = desi_mean[key][wp_mask]
        cic = desi_mean[key][cic_mask]
        n_err = desi_err[key][n_mask]
        wp_err = desi_err[key][wp_mask]
        cic_err = desi_err[key][cic_mask]
        if override_label is None:
            label = f"${thresh:.1f}${highzlabel}"
        else:
            if isinstance(override_label, int):
                label = None
            else:
                label = override_label

        # n
        if plot_n:
            axes[0].set_xlabel("", fontsize=14)
            axes[0].set_ylabel("$n \\; [h^3 {\\rm Mpc}^{-3}]$", fontsize=14)
            axes[0].xaxis.set_visible(False)
            if not nolines:
                axes[0].axhline(n, color=color)
            axes[0].fill_between([-100, 100], n + n_err, n - n_err, alpha=alpha,
                                 color=color, label=label)
            axes[0].set_xlim(0, 10)
        # wp
        if plot_wp:
            axes[1].set_xlabel("$r_{\\rm p} \\; [h^{-1} {\\rm Mpc}]$", fontsize=14)
            axes[1].set_ylabel("$r_{\\rm p} \\times w_{\\rm p}(r_{\\rm p}) \\; [h^{-2} {\\rm Mpc}^2]$", fontsize=14)
            axes[1].semilogx()
            axes[1].xaxis.set_major_formatter(plt.matplotlib.ticker.ScalarFormatter())
            axes[1].set_ylim(50, 200)
            if not nolines:
                axes[1].plot(rp_cens * 10**logrp_offset, rp_cens * wp, color=color)
            if plot_wp_errorbars:
                axes[1].errorbar(rp_cens * 10**logrp_offset, rp_cens * wp, rp_cens * wp_err,
                                 alpha=alpha, color=color, label=label)
            else:
                axes[1].fill_between(rp_cens * 10**logrp_offset, rp_cens * (wp + wp_err),
                                     rp_cens * (wp - wp_err), alpha=alpha,
                                     color=color, label=label)
        # CiC
        if plot_cic:
            axes[2].semilogy()
            if kmax is None:
                axes[2].set_xlabel("$N_{\\rm CiC}$", fontsize=14)
                if xscale_cic:
                    axes[2].set_ylabel("$(N_{\\rm CiC} + 1) \\times dP/dN_{\\rm CiC}$", fontsize=14)
                else:
                    axes[2].set_ylabel("$dP/dN_{\\rm CiC}$", fontsize=14)
                axes[2].set_xlim(-5, 100)
                if not xscale_cic:
                    axes[2].set_ylim(0.0001, 0.5)
                else:
                    axes[2].set_ylim(bottom=0.004, top=0.6)
                x = cic_cens
                xscale = (x + 1) if xscale_cic else x ** 0
                axes[2].fill_between(x + cic_offset, xscale * (cic + cic_err), xscale * (cic - cic_err),
                                     alpha=alpha, color=color, label=label)
            else:
                axes[2].set_xticks(np.arange(1, kmax + 1))
                axes[2].set_xlabel("Moment number", fontsize=14)
                axes[2].set_ylabel("CiC Moments", fontsize=14)
                x = np.arange(1, kmax + 1)
                axes[2].errorbar(x + cic_offset, cic, cic_err, alpha=alpha,
                                     color=color, label=label, fmt="none", capsize=4)
            if not nolines:
                xscale = (x + 1) if xscale_cic else x ** 0
                axes[2].plot(x + cic_offset, xscale * cic, color=color)
        if cic_offset_start is not None:
            cic_offset += 0.1
        if logrp_offset_start is not None:
            logrp_offset += 0.05

    for x in axes:
        x.tick_params(axis='both', which='major', labelsize=13)
    return fig, axes

from galtab.paper2.param_sampler import ParamSampler, BetterMultivariateNormal
def get_bestfit(name, num_best_fits=1):
    sampler_file = pathlib.Path.home() / f"Paper2Data/desi_results/results_{name}/sampler.npy"
    backend_file = pathlib.Path.home() / f"Paper2Data/desi_results/results_{name}/emcee_backend.hdf5"
    backend = emcee.backends.HDFBackend(backend_file, read_only=True)

    sampler = np.load(sampler_file, allow_pickle=True)[0]
    my_blob = pd.DataFrame(sampler.blob)

    args = np.argsort(my_blob["loglike"]).values[::-1][:num_best_fits]
    return my_blob, args

def plot_bestfit_obs(name, kmax=None, figax=None, threshes=None, color_ind=0, label=None, alpha=0.9, lw=3, ls="-",
                     plot_n=True, plot_wp=True, plot_cic=True, print_chi2=True, cic_offset=0, bring_front=False,
                     num_best_fits=1, xscale_cic=False):
    my_blob, args = get_bestfit(name=name, num_best_fits=num_best_fits)

    if figax is None:
        fig, axes = plt.subplots(ncols=3, figsize=(13, 4), gridspec_kw={"width_ratios": [1, 5, 5], "wspace": 0.3})
    else:
        fig, axes = figax

    color = f"C{color_ind}"
    zorder = np.inf if bring_front else None

    for i, arg in enumerate(args):
        n = my_blob.iloc[arg]["n"]
        wp = my_blob.iloc[arg]["wp"]
        cic = my_blob.iloc[arg]["cic"]
        if print_chi2 and i == 0:
            obs_name = name.replace("_htcic", "").replace("_noassembias", "")
            dist = galtab.paper2.param_sampler.BetterMultivariateNormal(
                mean=desi_obs[obs_name]["mean"], cov=desi_obs[obs_name]["cov"], allow_singular=True)
            res = (np.array([n, *wp, *cic]) - desi_obs[obs_name]["mean"])

            # stats.Covariance.from_eigendecomposition(np.linalg.eigh(cov))
            cov_pinv = np.linalg.pinv(desi_obs[obs_name]["cov"])
            my_chi2 = res @ (cov_pinv @ res)
            my_dof = np.linalg.matrix_rank(cov_pinv)
            my_pval = scipy.stats.chi2(df=my_dof).sf(my_chi2)
            print(f"*My* Chi2 (dof={my_dof}) for {name} = {my_chi2:.4f}; p-val = {my_pval:.4f}")

            normres = res * dist.norm
            chi2 = np.sum(np.square(np.dot(normres, dist.cov_object.U)))
            dof = dist.cov_object.rank
            # chi2 = -2 * my_blob.iloc[arg]["loglike"]
            # dof = 1 + len(wp) + len(cic)
            pval = scipy.stats.chi2(df=dof).sf(chi2)
            zscore = scipy.stats.norm.ppf(1 - pval/2)
            print(f"Chi2 value (dof={dof}) for {name} = {chi2:.4f}; p-val = {pval:.4f} "
                  f"({zscore:.4f} sigma tension)")
            print("param_dict =", repr(my_blob.iloc[arg]["param_dict"]))
            # print(cic)
            print()
        # n_err = desi_err[key][n_mask]
        # wp_err = desi_err[key][wp_mask]
        # cic_err = desi_err[key][cic_mask]

        # n
        if plot_n:
            axes[0].set_xlabel("", fontsize=14)
            axes[0].set_ylabel("$n \\; [h^3 {\\rm Mpc}^{-3}]$", fontsize=14)
            axes[0].xaxis.set_visible(False)
            axes[0].axhline(n, label=label, color=color, lw=lw, ls=ls, alpha=alpha, zorder=zorder)
            # axes[0].fill_between([-100, 100], n + n_err, n - n_err, alpha=0.4, color=color)
            axes[0].set_xlim(0, 10)
        # wp
        if plot_wp:
            axes[1].set_xlabel("$r_{\\rm p} \\; [h^{-1} {\\rm Mpc}]$", fontsize=14)
            axes[1].set_ylabel("$r_{\\rm p} \\times w_{\\rm p}(r_{\\rm p}) \\; [h^{-2} {\\rm Mpc}^2]$", fontsize=14)
            axes[1].semilogx()
            axes[1].xaxis.set_major_formatter(plt.matplotlib.ticker.ScalarFormatter())
            axes[1].set_ylim(50, 200)
            axes[1].plot(rp_cens, rp_cens * wp, lw=lw, ls=ls, label=label, color=color, alpha=alpha, zorder=zorder)
            # axes[1].fill_between(rp_cens, rp_cens * (wp + wp_err), rp_cens * (wp - wp_err), alpha=0.4, color=color)
        # CiC
        if plot_cic:
            axes[2].semilogy()
            if kmax is None:
                axes[2].set_xlabel("$N_{\\rm CiC}$", fontsize=14)
                if xscale_cic:
                    axes[2].set_ylabel("$(N_{\\rm CiC} + 1) \\times dP/dN_{\\rm CiC}$", fontsize=14)
                else:
                    axes[2].set_ylabel("$dP/dN_{\\rm CiC}$", fontsize=14)
                axes[2].set_xlim(-5, 100)
                if not xscale_cic:
                    axes[2].set_ylim(0.0001, 0.5)
                else:
                    axes[2].set_ylim(bottom=0.004, top=0.6)
                x = cic_cens
            else:
                axes[2].set_xticks(np.arange(1, kmax + 1))
                axes[2].set_xlabel("Moment number", fontsize=14)
                axes[2].set_ylabel("CiC Moments", fontsize=14)
                x = np.arange(1, kmax + 1)
            if kmax is None:
                xscale = (x + 1) if xscale_cic else x ** 0
                axes[2].plot(x + cic_offset, xscale * cic, lw=lw, ls=ls, label=label,
                             color=color, alpha=alpha, zorder=zorder)
            else:
                # axes[2].plot(x, cic, "o", label=label, color=color)
                axes[2].plot(x + cic_offset, cic, lw=lw, ls=ls, label=label, color=color, alpha=alpha, zorder=zorder)
            # axes[2].fill_between(cic_cens, cic + cic_err, cic - cic_err, alpha=0.4, color=color)

    for x in axes:
        x.tick_params(axis='both', which='major', labelsize=13)
    return fig, axes
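
The "*My* Chi2" printed by `plot_bestfit_obs` uses a pseudo-inverse of the covariance matrix and takes the degrees of freedom from its rank, which keeps the statistic well defined when the covariance is singular. The sketch below reproduces that calculation on a toy rank-deficient covariance; the function name `chi2_pinv` and the toy matrices are hypothetical.

```python
import numpy as np


def chi2_pinv(model, mean, cov):
    """Chi-squared and degrees of freedom using a pseudo-inverse covariance."""
    res = np.asarray(model) - np.asarray(mean)
    cov_pinv = np.linalg.pinv(cov)
    chi2 = res @ (cov_pinv @ res)
    dof = np.linalg.matrix_rank(cov_pinv)
    return chi2, dof


# Toy rank-2 covariance in 3 dimensions: cov = a @ a.T with a of shape (3, 2)
a = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 2.0]])
cov = a @ a.T
mean = np.zeros(3)
# Residual built from a unit shift along each of the two latent directions
model = a @ np.array([1.0, 1.0])
chi2, dof = chi2_pinv(model, mean, cov)
print(chi2, dof)
```

Here the residual corresponds to a one-sigma shift along each of the two independent directions, so the chi-squared equals the number of degrees of freedom. A p-value then follows from `scipy.stats.chi2(df=dof).sf(chi2)`, as in the function above.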

In [None]:
def plot_footprint(calc, region_num=12, fig=None):
    nrows, ncols = 1, 2
    if fig is None:
        fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(13, 5))
    else:
        fig, ax = fig, np.array(fig.axes).reshape((nrows, ncols))
    ra, dec = fastphot["RA"], fastphot["DEC"]
    ax[0].grid()
    ax[0].scatter(ra, dec, s=0.01, zorder=-10)

    midra = [np.mean(x[0]) for x in galtab.paper2.desi_sv3_pointings.lims]
    middec = [np.mean(x[1]) for x in galtab.paper2.desi_sv3_pointings.lims]
    for i in range(len(midra)):
        x = midra[i]
        if i==12:
            x -= 1.5
        if i==14:
            x += 1.5
        ax[0].text(x, middec[i], i+1, ha="center", va="center", fontsize=13)

    # Mask entire redshift range used in analysis, rather than one individual sample
    mask_z = (0.1 <= calc.fastphot["Z"]) & (calc.fastphot["Z"] <= 0.3)
    ax[1].scatter(*calc.data_rdx[calc.randcic_mask & calc.region_masks[region_num-1] & mask_z][:, :2].T,
                  s=5, zorder=-10)
    ax[1].scatter(*calc.data_rdx[~calc.randcic_mask & calc.region_masks[region_num-1] & mask_z][:, :2].T,
                  s=5, zorder=-10)

    for a in ax:
        a.invert_xaxis()
        a.set_xlabel("RA [deg]", fontsize=14)
        a.set_ylabel("DEC [deg]", fontsize=14)

    ax[1].text(0.01, 0.01, f"region_num={region_num}", fontsize=14, transform=ax[1].transAxes, color="k")
    ax[1].text(0.01, 0.99, f"cylinders $<${calc.randcic_cut / (calc.num_rand_files * 2500)*100:.0f}% \ncomplete",
               fontsize=14, transform=ax[1].transAxes, color="C1", va="top")
    ax[0].tick_params(axis='both', which='major', labelsize=13)
    ax[1].tick_params(axis='both', which='major', labelsize=13)

    # Rasterize scatter plots (which have been given zorders less than 0)
    ax[0].set_rasterization_zorder(0)
    ax[1].set_rasterization_zorder(0)
    return fig

def plot_redshift_vs_rmag(redshift, rmag, gridsize=100):
#     plt.hexbin(redshift, rmag,
#                mincnt=1, extent=(0.02, 0.3, -25, -15),
#                norm=plt.matplotlib.colors.LogNorm(), gridsize=gridsize)

    plt.scatter(redshift, rmag, s=0.002, alpha=0.6, color="grey", zorder=-10)
    plt.xlim(0.02, 0.32)
    plt.ylim(-25, -15)
    samples = [
        # {"zmin": 0.1, "zmax": 0.2, "rmag_thresh": -19.5},
        {"zmin": 0.1, "zmax": 0.2, "rmag_thresh": -20.0},
        {"zmin": 0.1, "zmax": 0.2, "rmag_thresh": -20.5},
        {"zmin": 0.1, "zmax": 0.2, "rmag_thresh": -21.0},
        {"zmin": 0.2, "zmax": 0.3, "rmag_thresh": -21.0},
        # {"zmin": 0.2, "zmax": 0.3, "rmag_thresh": -20.5},
    ]
    fill_colors = ["none", "none", "none", "none"]
    hatch_types = ["", "\\\\", "//", "//"]
    zorders = [203, 202, 201, 200]
    for i, sample in enumerate(samples):
        style = f"C{i}"
        zmin, zmax, rmag_thresh = sample["zmin"], sample["zmax"], sample["rmag_thresh"]
        zorder, hatch = zorders[i], hatch_types[i]
        label = f"${zmin}<z<{zmax}, \\; M_r<{rmag_thresh}$"
        with autoscale_turned_off():
            plt.fill_between([zmin, zmax], [-1000]*2, [rmag_thresh]*2, hatch=hatch, color=fill_colors[i],
                             edgecolor=f"C{i}", alpha=0.75, lw=3, label=label, zorder=zorder)
#         plt.plot([zmin, zmax], [rmag_thresh]*2, style,
#                  lw=2, alpha=0.9, label=f"S{i+1}: $\\rm {zmin}<z<{zmax}, \\; M_r<{rmag_thresh}$")
#         plt.plot([zmax]*2, [rmag_thresh, plt.gca().get_ylim()[0]],
#                  style, lw=2, alpha=0.9)
    plt.legend(frameon=False)
    plt.xlabel("redshift", fontsize=16)
    plt.ylabel("$M_{r, {\\rm SDSS}, 0.1} - 5\\log{h}$", fontsize=16)
    plt.ylim(-16, -24)
    plt.gca().tick_params(axis='both', which='major', labelsize=14)
    
    plt.gca().set_rasterization_zorder(0)
    return plt.gcf()

def plot_accuracy_vs_runtime(kmax=5, plot_ht_mean_err=True, multiply_time_by=1, figsize=None,
                             remove_max_weight_time_plot=False, labelsize=18, touchaxes=False,
                             percent_error=False):
    param_name = "max_weight"
    xparams = ["max_weight", "time"]
    metric_names = ["time", *(f"CiC{i}" for i in range(1, kmax+1))]
    simname = "bolplanck"
    sqiomw = False
    queries = [
        "min_quant==1e-5",
        "min_quant==1e-4",
        "min_quant==1e-3",
        "min_quant==1e-2",
              ]
    # queries = ["and max_quant == 0.9 and n_mc==5",
    #            "and 0.95 < max_quant < 0.97 and n_mc==5",
    #            "and 0.97 < max_quant < 0.99 and n_mc==5",
    #            "and 0.99 < max_quant < 0.999 and n_mc==5",
    #            "and 0.999 < max_quant < 0.9997 and n_mc==5",
    #            "and 0.9997 < max_quant < 0.9999 and n_mc==5"]

    if touchaxes:
        figsize = (6, 4 + 9/5*kmax) if figsize is None else figsize
        fig, ax = plt.subplots(figsize=figsize, nrows=len(metric_names), ncols=len(xparams),
                               gridspec_kw=dict(hspace=0, wspace=0))
    else:
        figsize = (8, 4 + 13/5*kmax) if figsize is None else figsize
        fig, ax = plt.subplots(figsize=figsize, nrows=len(metric_names), ncols=len(xparams))

    ht_time_bolplanck = ht_avr_results.query("simname=='bolplanck' and not delmock")["time"].mean(numeric_only=True)
    for x in np.ravel(ax[:, :]):
        x.set_xscale("log")
        x.xaxis.set_major_formatter(plt.matplotlib.ticker.ScalarFormatter())
    for x in np.ravel(ax[:1, :]):
        x.set_yscale("log")
        x.yaxis.set_major_formatter(plt.matplotlib.ticker.ScalarFormatter())
    for i, param in enumerate(xparams):
        ax[0, i].axhline(ht_time_bolplanck, color="k", ls="--")
        # ax[1, i].axhline(ht_frac_error[simname], color="k", ls="--")
        # ax[1, i].axhline(best_desi_frac_error[simname], color="C0", ls="--")
        # ax[2, i].axhline(0, color="k", ls="--")

        for j, metric in enumerate(metric_names):
            if param == "max_weight":
                ax[j, i].invert_xaxis()
                ax[j, i].axvline(0.05, color="C1", ls="--")
            elif param == "time":
                ax[j, i].axvline(ht_time_bolplanck, color="k", ls="--")
            if metric not in ["time", "frac_error", "frac_bias", "rchi2"]:
                q = ht_avr_results.query(f"simname=='{simname}'")[metric]
                mean = q.mean()
            for k, query in enumerate(queries):
                color = f"C{k}"
                fullquery = f"simname=='{simname}' and {query} and sqiomw=={sqiomw}"
                q = gt_avr_results.query(fullquery).groupby(
                    param_name).mean(numeric_only=True).reset_index()
                spread = gt_avr_results.query(fullquery).groupby(
                    param_name).std(numeric_only=True, ddof=1).reset_index()
                if not param == metric:
                    q["time"] *= multiply_time_by
                    if percent_error and metric not in ["time", "frac_error", "frac_bias", "rchi2"]:
                        # mean = q[metric].iloc[0]
                        value = 100*(q[metric] - mean)/mean
                        error = 100*spread[metric]/mean
                    else:
                        value = q[metric]
                        error = spread[metric]
                    if not k:
                        # Plot the spread of trials for the k=0 query
                        ax[j, i].fill_between(q[param], value + error, value - error,
                                              color="grey", alpha=0.6)
                    ax[j, i].plot(q[param], value, color=color, label=query.replace("==", "="))

                    if not touchaxes or i == 0:
                        if percent_error:
                            ax[j, i].set_ylabel(metric + " [% err]", fontsize=labelsize)
                        else:
                            ax[j, i].set_ylabel(metric, fontsize=labelsize)
                    else:
                        ax[j, i].yaxis.set_ticklabels([])
                    xlabel = "time [sec]" if param == "time" else param
                    if not touchaxes or j+1 == len(metric_names):
                        ax[j, i].set_xlabel(xlabel, fontsize=labelsize)
                    else:
                        ax[j, i].xaxis.set_ticklabels([])
                else:
                    ax[j, i].set_visible(False)
            if metric not in ["time", "frac_error", "frac_bias", "rchi2"]:
                q = ht_avr_results.query(f"simname=='{simname}'")[metric]
                mean, err = q.mean(), q.std(ddof=1)
                mean_err = err / np.sqrt(len(q))
                if percent_error:
                    ax[j, i].axhline(0, color="k", ls="--")
                    with autoscale_turned_off(ax[j, i], y=False):
                        ax[j, i].fill_between([-1e16, 1e16], [100*err/mean]*2, [-100*err/mean]*2,
                                              color="grey", alpha=0.3)
                        if plot_ht_mean_err:
                            ax[j, i].fill_between([-1e16, 1e16], [100*mean_err/mean]*2, [-100*mean_err/mean]*2,
                                                  color="grey", alpha=0.8)
                else:
                    ax[j, i].axhline(mean, color="k", ls="--")
                    with autoscale_turned_off(ax[j, i], y=False):
                        ax[j, i].fill_between([-1e16, 1e16], [mean + err]*2, [mean - err]*2,
                                              color="grey", alpha=0.3)
                        if plot_ht_mean_err:
                            ax[j, i].fill_between([-1e16, 1e16], [mean + mean_err]*2, [mean - mean_err]*2,
                                                  color="grey", alpha=0.8)
    if remove_max_weight_time_plot:
        for x in ax[0, :]:
            x.set_visible(False)
        ax[1, 1].legend(frameon=False, fontsize=12)
    else:
        ax[0, 0].legend(frameon=False, fontsize=12)

    for x in ax.ravel():
        x.tick_params(axis='both', which='major', labelsize=15)
    plt.tight_layout()

def plot_importance(shap_values, hod_param_names, perm_result=None):
    figsize = (5*len(hod_param_names), 14)
    fig, axes = plt.subplots(nrows=2, ncols=int(len(hod_param_names)/2), figsize=figsize,
                             gridspec_kw=dict(hspace=0.5))
    for i, hod_param_name in enumerate(hod_param_names):
        j = int(i * 2 / len(hod_param_names))
        i = int(i % (len(hod_param_names)/2))
        mathname = hod_param_name.replace(
            "mean_occupation_centrals_assembias_param1", "A_{cen}").replace(
            "mean_occupation_satellites_assembias_param1", "A_{sat}").replace(
            "logMmin", "\\log M_{min}").replace(
            "sigma_logM", "\\sigma_{\\log M}").replace(
            "alpha", "\\alpha").replace(
            "logM1", "\\log M_1").replace(
            "logM0", "\\log M_0")# .replace("_", "\\_")
        if perm_result is not None:
            mean, std = perm_result[hod_param_name].importances_mean, perm_result[hod_param_name].importances_std
            axes[j, i].errorbar(feature_labels, mean, yerr=std, fmt="o",
                label="permutation loss")
        mean = np.abs(shap_values[hod_param_name].values).mean(axis=0)
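        # Standard error of the mean: divide by sqrt(number of samples), not number of HOD parameters
        num_samples = len(shap_values[hod_param_name].values)
        std = np.abs(shap_values[hod_param_name].values).std(axis=0, ddof=1) / np.sqrt(num_samples)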
        axes[j, i].errorbar(feature_labels, mean, yerr=std, fmt="o", markersize=10,
                            markerfacecolor="none", 
                            # label="mean(|SHAP|)",
                            # label=hod_param_name,
        )

        axes[j, i].axvline(0, color="k", ls="--")
        axes[j, i].axvline(-n_rp_bins - 1, color="k", ls="--")
        axes[j, i].set_ylabel("Feature importance", fontsize=29, labelpad=-4)
        axes[j, i].set_xlabel("n | wp bins | CiC moments", fontsize=29)

        # title = f"Importance for $\\rm {mathname}$"
        # axes[j, i].set_title(title, fontdict=dict(fontsize=24))
        axes[j, i].semilogy()
        # axes[0, i].legend(frameon=True, fontsize=14, loc=(1, 0))

        feature_tick_labels = [x.strip("wpCi") for x in feature_labels_str]
        feature_tick_labels = ["" if x in {"2", "4", "6", "8", "10", "12"} else x for x in feature_tick_labels]
        axes[j, i].set_xticks(feature_labels, feature_tick_labels)
        axes[j, i].set_title(f"$\\rm {mathname}$", fontsize=29)

#         ax = axes[j + 1, i]
#         plt.sca(ax)
#         shap.plots.beeswarm(shap_values[hod_param_name], max_display=7, plot_size=figsize, show=False,
#                             color_bar_label=f"$\\rm {mathname}$ value")
#         ax.set_xlabel(ax.get_xlabel(), fontsize=24)
#         ax.set_yticks(ax.get_yticks(), ["others", *ax.get_yticklabels()[1:]])
#         ax.set_frame_on(True)

    for ax in fig.axes:
        # Tick label size
        ax.tick_params(labelsize=25)

def plot_shap_beeswarm(shap_values, hod_param_names, perm_result=None):
    figsize = (5*len(hod_param_names), 14)
    fig, axes = plt.subplots(nrows=2, ncols=int(len(hod_param_names)/2), figsize=figsize,
                             gridspec_kw=dict(hspace=0.5))
    for i, hod_param_name in enumerate(hod_param_names):
        j = int(i * 2 / len(hod_param_names))
        i = int(i % (len(hod_param_names)/2))
        mathname = hod_param_name.replace(
            "mean_occupation_centrals_assembias_param1", "A_{cen}").replace(
            "mean_occupation_satellites_assembias_param1", "A_{sat}").replace(
            "logMmin", "\\log M_{min}").replace(
            "sigma_logM", "\\sigma_{\\log M}").replace(
            "alpha", "\\alpha").replace(
            "logM1", "\\log M_1").replace(
            "logM0", "\\log M_0")# .replace("_", "\\_")

        ax = axes[j, i]
        plt.sca(ax)
        shap.plots.beeswarm(shap_values[hod_param_name], max_display=7, plot_size=figsize, show=False)
        ax.set_xlabel(f"SHAP value (impact on $\\rm {mathname}$)", fontsize=29)
        ax.set_yticks(ax.get_yticks(), ["others", *ax.get_yticklabels()[1:]])
        ax.set_frame_on(True)
        ax.set_title(f"$\\rm {mathname}$", fontsize=29)

    for ax in fig.axes:
        # Tick label size
        ax.tick_params(labelsize=25)
    for ax in fig.axes[-6:]:
        # Colorbar label sizes
        ax.set_ylabel(ax.get_ylabel(), fontsize=29, labelpad=-20)

def get_chain(name="20p5_kmax5", discard=100, assembias_params_only=False,
              no_assembias_params=False, no_logm0=False, verbose=True):
    dirname = pathlib.Path.home() / f"Paper2Data/desi_results/results_{name}"
    chain_param_names = np.load(dirname / "sampler.npy", allow_pickle=True)[0].param_names
    param_names = list(galtab.paper2.param_config.kuan_params[-20.5].keys())
    m1_ind, mmin_ind, m0_ind = param_names.index("logM1"), param_names.index("logMmin"), param_names.index("logM0")
    chain_param_order = [chain_param_names.index(x.replace("logM0", "logM0_quant")) for x in param_names]

    param_names[param_names.index("logMmin")] = "$\\log M_{\\rm min}$"
    param_names[param_names.index("sigma_logM")] = "$\\sigma_{\\log M}$"
    param_names[param_names.index("alpha")] = "$\\alpha$"
    param_names[param_names.index("logM1")] = "$\\log M_1$"
    param_names[param_names.index("logM0")] = "$\\log M_0$"
    param_names[param_names.index("mean_occupation_centrals_assembias_param1")] = "$A_{\\rm cen}$"
    param_names[param_names.index("mean_occupation_satellites_assembias_param1")] = "$A_{\\rm sat}$"

    param_slice = list(range(len(param_names)))
    if no_assembias_params:
        assert not assembias_params_only
        param_slice.remove(param_names.index("$A_{\\rm cen}$"))
        param_slice.remove(param_names.index("$A_{\\rm sat}$"))
        if no_logm0:
            param_slice.remove(param_names.index("$\\log M_0$"))
    elif assembias_params_only:
        param_slice.remove(param_names.index("$\\log M_{\\rm min}$"))
        param_slice.remove(param_names.index("$\\sigma_{\\log M}$"))
        param_slice.remove(param_names.index("$\\alpha$"))
        param_slice.remove(param_names.index("$\\log M_1$"))
        param_slice.remove(param_names.index("$\\log M_0$"))
    elif no_logm0:
        param_slice.remove(param_names.index("$\\log M_0$"))

    file = dirname / "emcee_backend.h5"
    backend = emcee.backends.HDFBackend(file, read_only=True)
    original_chain = backend.get_chain(flat=True, discard=discard)
    chain = original_chain.copy()
    for i in range(len(param_names)):
        chain[:, i] = original_chain[:, chain_param_order[i]]

    # Convert logM0_quant (sampling parameter) -> logM0 (model parameter)
    logm0_range = chain[:, m1_ind] - (logm0_min := chain[:, mmin_ind])
    chain[:, m0_ind] = logm0_min + chain[:, m0_ind] * logm0_range
    
    if verbose:
        print(f"{name}: number of trial points (after discarding {discard*backend.shape[0]}) = "
              f"{len(chain)} w/ acceptance frac = {backend.accepted.mean()/backend.iteration:.4f}")

    return dict(chain=chain, param_names=param_names, param_slice=param_slice)

def plot_mcmc_posterior(name="20p5_kmax5", discard=100, fig=None, assembias_params_only=False,
                        no_assembias_params=False, no_logm0=False, verbose=True, **kwargs):
    chain_results = get_chain(name, discard, assembias_params_only, no_assembias_params, no_logm0, verbose)
    chain, param_names, param_slice = [chain_results[x] for x in ["chain", "param_names", "param_slice"]]
    fig = corner.corner(chain[:, param_slice], labels=np.array(param_names)[param_slice],
                        label_kwargs={"fontsize": 16},
                        quantiles=[0.16, 0.84], levels=[0.68, 0.95], fig=fig, **kwargs)
    return fig

def plot_hod_evolution(puffs="21p0"):
    fig, axes = plt.subplots(nrows=3, figsize=(5, 10), gridspec_kw=dict(hspace=0))
    sigma_name = "$\\sigma_{\\log M}$"
    inv_alpha_name = "$\\log(2) / \\alpha$"
    asat_name = "$A_{\\rm sat}$"

    threshes = np.array([-20.0, -20.5, -21.0, -21.0])
    names = ["20p0", "20p5", "21p0", "21p0_z0p2-0p3"]
    names = [f"{x}_puff" if x in puffs else x for x in names]
    redshifts = np.array([0.15, 0.15, 0.15, 0.25])
    chain_results = [get_chain(name, verbose=False) for name in names]

    colors = ["C0" if x < 0.2 else "C1" for x in redshifts]

    for i, param_name in enumerate(["$\\log M_{\\rm min}$", "$\\log M_1$", "$A_{\\rm cen}$"]):
        other_param = [None, []]
        lows, meds, highs = [], [], []
        shade_lows, shade_highs = [], []
        for j, results in enumerate(chain_results):
            param_ind = results["param_names"].index(param_name)
            sigma_ind = results["param_names"].index(sigma_name)
            alpha_ind = results["param_names"].index(inv_alpha_name.replace("\\log(2) / ", ""))
            asat_ind = results["param_names"].index(asat_name)

            param = results["chain"][:, param_ind]
            low, med, high = np.quantile(param, [0.16, 0.5, 0.84])
            lows.append(low)
            meds.append(med)
            highs.append(high)
            if param_name == "$\\log M_{\\rm min}$":
                sigma = np.median(results["chain"][:, sigma_ind])
                shade_lows.append(med - sigma)
                shade_highs.append(med + sigma)
            elif param_name == "$\\log M_1$":
                inv_alpha = np.log10(2) / results["chain"][:, alpha_ind]
                other_param[0] = "$\\langle N_{\\rm sat} \\rangle \\approx 2$"
                other_param[1].append(np.quantile(param + inv_alpha, [0.16, 0.5, 0.84]))
                shade_highs.append(other_param[1][-1][1])
            elif param_name == "$A_{\\rm cen}$":
                asat = results["chain"][:, asat_ind]
                other_param[0] = asat_name
                other_param[1].append(np.quantile(asat, [0.16, 0.5, 0.84]))
        lows, meds, highs = np.array([lows, meds, highs])
        shade_lows = np.array(shade_lows)
        shade_highs = np.array(shade_highs)

        # Shaded region (sigma) only for z=0.15
        if len(shade_lows) and len(shade_highs):
            axes[i].fill_between(threshes[:3] + (redshifts[:3] - 0.2)/3,
                                 shade_highs[:3], shade_lows[:3], color=colors[0], alpha=0.35)

        label = param_name if param_name.startswith("$A") else None
        axes[i].errorbar(threshes[:3] + (redshifts[:3] - 0.2)/3, meds[:3],
                         [meds[:3] - lows[:3], highs[:3] - meds[:3]], label=label,
                         color=colors[0], ecolor=colors[0], fmt="o-", capsize=4)
        axes[i].errorbar(threshes[-1] + (redshifts[-1] - 0.2)/3, meds[-1],
                         [[meds[-1] - lows[-1]], [highs[-1] - meds[-1]]],
                         color=colors[-1], ecolor=colors[-1], fmt="o", capsize=4)

#        if param_name.startswith("$A"):
#             axes[i].text(threshes[0], meds[0], "\n  " + param_name, fontsize=16,
#                          rotation=np.arctan((meds[1]-meds[0])/(threshes[1]-threshes[0]))*180/np.pi + 180,
#                          va="top", rotation_mode="anchor", transform_rotates_text=True, color=colors[0])


        axes[i].invert_xaxis()
        axes[i].set_xlabel("$\\rm M_r$ threshold", fontsize=15)
        axes[i].set_ylabel(param_name.replace("_{\\rm cen}", ""), fontsize=15)

        # Label the widths of shaded regions (both sigma and 1/alpha)
        if len(shade_highs):
            axes[i].arrow(threshes[:2].mean(), meds[:2].mean(), 0,
                          shade_highs[:2].mean() - meds[:2].mean(), length_includes_head=True,
                          color=colors[0], head_width=0.02, overhang=0.9, alpha=0.6, ls="-", lw=1)
            axes[i].arrow(threshes[:2].mean(), meds[:2].mean(), 0, -1e-5, length_includes_head=True,
                          color=colors[0], head_width=0.02, overhang=0.9, alpha=0.6, ls="-", lw=1)
            s = inv_alpha_name if param_name == "$\\log M_1$" else sigma_name
            axes[i].text(np.average(threshes[:2], weights=[1, 1.1]),
                         np.average([shade_highs[:2].mean(), meds[:2].mean()], weights=[1.5, 1]),
                         s, fontsize=16, color=colors[0], alpha=0.6)
        if other_param[0] is not None:
            olows, omeds, ohighs = np.array(other_param[1]).T
            axes[i].errorbar(threshes[:3] + (redshifts[:3] - 0.13)/3, omeds[:3],
                             [omeds[:3] - olows[:3], ohighs[:3] - omeds[:3]], label=other_param[0],
                             color=colors[0], ecolor=colors[0], fmt="D--", capsize=4, mfc="none")
            axes[i].errorbar(threshes[-1] + (redshifts[-1] - 0.13)/3, omeds[-1],
                             [[omeds[-1] - olows[-1]], [ohighs[-1] - omeds[-1]]],
                             color=colors[-1], ecolor=colors[-1], fmt="D", capsize=4, mfc="none")
            loc = (0.1, 0.02) if param_name.startswith("$A") else "upper left"
            axes[i].legend(frameon=True, fontsize=14, loc=loc)

    # Dashed line at Acen = 0
    axes[-1].axhline(0, color="k", ls="--", zorder=-np.inf)
    axes[0].text(0.95, 0.15, f"$\\rm z = {redshifts[0]:.2f}$", color=colors[0], fontsize=16,
                 transform=axes[0].transAxes, ha="right", va="bottom")
    axes[0].text(0.95, 0.05, f"$\\rm z = {redshifts[-1]:.2f}$", color=colors[-1], fontsize=16,
                 transform=axes[0].transAxes, ha="right", va="bottom")

    # Manually adjust limits as needed
    # axes[1].set_ylim(top=14.05, bottom=12.65)
    for x in axes:
        x.tick_params(axis='both', which='major', labelsize=13)

def print_bestfit_table(names=["20p5"], print_names=["-20.5"], discard=100, assembias_params_only=False,
                        no_assembias_params=False, no_logm0=False, verbose=False, **kwargs):
    print_header = True
    for name, print_name in zip(names, print_names):
        for no_cic in [False, True, False]:
            if no_cic:
                name = name + "_kmax0"
                print_name = "(no CiC)"
            elif name.endswith("_kmax0"):
                name = name.replace("_kmax0", "_noassembias")
                print_name = "(no $A_{\\rm bias}$)"
            # Get best-fits from the n, wp, and cic values stored in my_blob
            ################################################################
            bfname = name.replace("_puff", "") if "kmax0" in name else name
            my_blob, arg = get_bestfit(name=bfname, num_best_fits=1)
            arg = arg[0]
            my_blob = my_blob.iloc[arg]
            n, wp, cic = my_blob["n"], my_blob["wp"], my_blob["cic"]

            obs_name = bfname.replace("_htcic", "").replace("_noassembias", "")
            dist = galtab.paper2.param_sampler.BetterMultivariateNormal(
                mean=desi_obs[obs_name]["mean"], cov=desi_obs[obs_name]["cov"], allow_singular=True)
            res = (np.array([n, *wp, *cic]) - desi_obs[obs_name]["mean"])

            normres = res * dist.norm
            chi2 = np.sum(np.square(np.dot(normres, dist.cov_object.U)))
            dof = dist.cov_object.rank
            # chi2 = -2 * my_blob["loglike"]
            # dof = 1 + len(wp) + len(cic)
            pval = scipy.stats.chi2(df=dof).sf(chi2)
            zscore = scipy.stats.norm.ppf(1 - pval/2)
            #################################################################

            dirname = pathlib.Path.home() / f"Paper2Data/desi_results/results_{bfname}"
            chain_param_names = np.load(dirname / "sampler.npy", allow_pickle=True)[0].param_names
            param_names = list(galtab.paper2.param_config.kuan_params[-20.5].keys())
            if print_name == "(no $A_{\\rm bias}$)":
                param_names.remove("mean_occupation_centrals_assembias_param1")
                param_names.remove("mean_occupation_satellites_assembias_param1")
            param_names_copy = param_names.copy()
            m1_ind, mmin_ind, m0_ind = param_names.index("logM1"), param_names.index("logMmin"), param_names.index("logM0")
            chain_param_order = [chain_param_names.index(x.replace("logM0", "logM0_quant")) for x in param_names]

            param_names[param_names.index("logMmin")] = "$\\log M_{\\rm min}$"
            param_names[param_names.index("sigma_logM")] = "$\\sigma_{\\log M}$"
            param_names[param_names.index("alpha")] = "$\\alpha$"
            param_names[param_names.index("logM1")] = "$\\log M_1$"
            param_names[param_names.index("logM0")] = "$\\log M_0$"
            if print_name != "(no $A_{\\rm bias}$)":
                param_names[param_names.index("mean_occupation_centrals_assembias_param1")] = "$A_{\\rm cen}$"
                param_names[param_names.index("mean_occupation_satellites_assembias_param1")] = "$A_{\\rm sat}$"
            if print_header:
                print_header = False
                colnames = ["Threshold"] + param_names + ["AIC", "$\\chi^2$", "DoF", "$p$ value", "Tension"]
                header = " & ".join([f"\\colhead{{{x}}}" for x in colnames])
                header = f"\\tablehead{{\n{header}\n}}"
                print(header)
                print("\\startdata")

            param_slice = list(range(len(param_names)))

            if no_assembias_params:
                assert not assembias_params_only
                param_slice.remove(param_names.index("$A_{\\rm cen}$"))
                param_slice.remove(param_names.index("$A_{\\rm sat}$"))
                if no_logm0:
                    param_slice.remove(param_names.index("$\\log M_0$"))
            elif assembias_params_only:
                param_slice.remove(param_names.index("$\\log M_{\\rm min}$"))
                param_slice.remove(param_names.index("$\\sigma_{\\log M}$"))
                param_slice.remove(param_names.index("$\\alpha$"))
                param_slice.remove(param_names.index("$\\log M_1$"))
                param_slice.remove(param_names.index("$\\log M_0$"))
            elif no_logm0:
                param_slice.remove(param_names.index("$\\log M_0$"))

            file = dirname / "emcee_backend.h5"
            backend = emcee.backends.HDFBackend(file, read_only=True)
            original_chain = backend.get_chain(flat=True, discard=discard)
            chain = original_chain.copy()
            for i in range(len(param_names)):
                chain[:, i] = original_chain[:, chain_param_order[i]]

            # Convert logM0_quant (sampling parameter) -> logM0 (model parameter)
            logm0_range = chain[:, m1_ind] - (logm0_min := chain[:, mmin_ind])
            chain[:, m0_ind] = logm0_min + chain[:, m0_ind] * logm0_range

            if verbose:
                print(f"{name}: number of trial points (after discarding {discard*backend.shape[0]}) = "
                      f"{len(chain)} w/ acceptance frac = {backend.accepted.mean()/backend.iteration:.4f}")

            param_string = " & ".join(map(lambda x: format(x, ".3f"), chain[arg]))
            if print_name == "(no $A_{\\rm bias}$)":
                param_string += " & & "

            newlinespace = "[5pt]" if print_name == "(no $A_{\\rm bias}$)" else ""
            num_params = len(chain[arg])
            if verbose:
                print(f"Number of parameters = {num_params}")
            aic = 2 * num_params - 2 * my_blob["loglike"]
            print(f"{print_name} & {param_string} & {aic:.2f} & {chi2:.2f} & {dof} & {pval:.3f} & ${zscore:.2f}\\sigma$\\\\{newlinespace}")
    print("\\enddata")

def print_conf_interval_table(names=["20p5"], print_names=["-20.5"], discard=100, assembias_params_only=False,
                              no_assembias_params=False, no_logm0=False, verbose=False, **kwargs):
    print_header = True
    for name, print_name in zip(names, print_names):
        for no_cic in [False, True, False]:
            if no_cic:
                name = name + "_kmax0"
                print_name = "(no CiC)"
            elif name.endswith("_kmax0"):
                name = name.replace("_kmax0", "_noassembias")
                print_name = "(no $A_{\\rm bias}$)"
            # Get best-fits from the n, wp, and cic values stored in my_blob
            ################################################################
            bfname = name.replace("_puff", "") if "kmax0" in name else name
            my_blob, arg = get_bestfit(name=bfname, num_best_fits=1)
            arg = arg[0]
            my_blob = my_blob.iloc[arg]
            n, wp, cic = my_blob["n"], my_blob["wp"], my_blob["cic"]

            obs_name = bfname.replace("_htcic", "").replace("_noassembias", "")
            dist = galtab.paper2.param_sampler.BetterMultivariateNormal(
                mean=desi_obs[obs_name]["mean"], cov=desi_obs[obs_name]["cov"], allow_singular=True)
            res = (np.array([n, *wp, *cic]) - desi_obs[obs_name]["mean"])

            normres = res * dist.norm
            chi2 = np.sum(np.square(np.dot(normres, dist.cov_object.U)))
            dof = dist.cov_object.rank
            # chi2 = -2 * my_blob.iloc[arg]["loglike"]
            # dof = 1 + len(wp) + len(cic)
            pval = scipy.stats.chi2(df=dof).sf(chi2)
            zscore = scipy.stats.norm.ppf(1 - pval/2)
            #################################################################

            dirname = pathlib.Path.home() / f"Paper2Data/desi_results/results_{bfname}"
            chain_param_names = np.load(dirname / "sampler.npy", allow_pickle=True)[0].param_names
            param_names = list(galtab.paper2.param_config.kuan_params[-20.5].keys())
            if print_name == "(no $A_{\\rm bias}$)":
                param_names.remove("mean_occupation_centrals_assembias_param1")
                param_names.remove("mean_occupation_satellites_assembias_param1")
            param_names_copy = param_names.copy()
            m1_ind, mmin_ind, m0_ind = param_names.index("logM1"), param_names.index("logMmin"), param_names.index("logM0")
            chain_param_order = [chain_param_names.index(x.replace("logM0", "logM0_quant")) for x in param_names]

            param_names[param_names.index("logMmin")] = "$\\log M_{\\rm min}$"
            param_names[param_names.index("sigma_logM")] = "$\\sigma_{\\log M}$"
            param_names[param_names.index("alpha")] = "$\\alpha$"
            param_names[param_names.index("logM1")] = "$\\log M_1$"
            param_names[param_names.index("logM0")] = "$\\log M_0$"
            if print_name != "(no $A_{\\rm bias}$)":
                param_names[param_names.index("mean_occupation_centrals_assembias_param1")] = "$A_{\\rm cen}$"
                param_names[param_names.index("mean_occupation_satellites_assembias_param1")] = "$A_{\\rm sat}$"
            if print_header:
                print_header = False
                colnames = ["Threshold"] + param_names
                header = " & ".join([f"\\colhead{{{x}}}" for x in colnames])
                header = f"\\tablehead{{\n{header}\n}}"
                print(header)
                print("\\startdata")

            param_slice = list(range(len(param_names)))

            if no_assembias_params:
                assert not assembias_params_only
                param_slice.remove(param_names.index("$A_{\\rm cen}$"))
                param_slice.remove(param_names.index("$A_{\\rm sat}$"))
                if no_logm0:
                    param_slice.remove(param_names.index("$\\log M_0$"))
            elif assembias_params_only:
                param_slice.remove(param_names.index("$\\log M_{\\rm min}$"))
                param_slice.remove(param_names.index("$\\sigma_{\\log M}$"))
                param_slice.remove(param_names.index("$\\alpha$"))
                param_slice.remove(param_names.index("$\\log M_1$"))
                param_slice.remove(param_names.index("$\\log M_0$"))
            elif no_logm0:
                param_slice.remove(param_names.index("$\\log M_0$"))

            file = dirname / "emcee_backend.h5"
            backend = emcee.backends.HDFBackend(file, read_only=True)
            original_chain = backend.get_chain(flat=True, discard=discard)
            chain = original_chain.copy()
            for i in range(len(param_names)):
                chain[:, i] = original_chain[:, chain_param_order[i]]
            logprob = backend.get_log_prob(flat=True, discard=discard)

            # Convert logM0_quant (sampling parameter) -> logM0 (model parameter)
            logm0_range = chain[:, m1_ind] - (logm0_min := chain[:, mmin_ind])
            chain[:, m0_ind] = logm0_min + chain[:, m0_ind] * logm0_range

            if verbose:
                print(f"{name}: number of trial points (after discarding {discard*backend.shape[0]}) = "
                      f"{len(chain)} w/ acceptance frac = {backend.accepted.mean()/backend.iteration:.4f}")

            def format_conf_int(quants_16_50_84):
                q16, q50, q84 = quants_16_50_84
                upper, lower = q84 - q50, q50 - q16
                return f"${q50:.3f}_{{-{lower:.3f}}}^{{+{upper:.3f}}}$"
            param_string = " & ".join(map(format_conf_int, np.percentile(chain, [16, 50, 84], axis=0).T))
            if print_name == "(no $A_{\\rm bias}$)":
                param_string += " & & "

            newlinespace = "[5pt]" if print_name == "(no $A_{\\rm bias}$)" else ""
            print(f"{print_name} & {param_string}\\\\{newlinespace}")
    print("\\enddata")

def prob_of_positive_acen(name, discard=100):
    dirname = pathlib.Path.home() / f"Paper2Data/desi_results/results_{name}"
    chain_param_names = np.load(dirname / "sampler.npy", allow_pickle=True)[0].param_names
    param_names = list(galtab.paper2.param_config.kuan_params[-20.5].keys())
    param_names_copy = param_names.copy()
    m1_ind, mmin_ind, m0_ind = param_names.index("logM1"), param_names.index("logMmin"), param_names.index("logM0")
    chain_param_order = [chain_param_names.index(x.replace("logM0", "logM0_quant")) for x in param_names]

    param_names[param_names.index("logMmin")] = "$\\log M_{\\rm min}$"
    param_names[param_names.index("sigma_logM")] = "$\\sigma_{\\log M}$"
    param_names[param_names.index("alpha")] = "$\\alpha$"
    param_names[param_names.index("logM1")] = "$\\log M_1$"
    param_names[param_names.index("logM0")] = "$\\log M_0$"
    param_names[param_names.index("mean_occupation_centrals_assembias_param1")] = "$A_{\\rm cen}$"
    param_names[param_names.index("mean_occupation_satellites_assembias_param1")] = "$A_{\\rm sat}$"

    file = dirname / "emcee_backend.h5"
    backend = emcee.backends.HDFBackend(file, read_only=True)
    original_chain = backend.get_chain(flat=True, discard=discard)
    chain = original_chain.copy()
    for i in range(len(param_names)):
        chain[:, i] = original_chain[:, chain_param_order[i]]

    acen = chain[:, param_names.index("$A_{\\rm cen}$")]
    return 1 - scipy.stats.percentileofscore(acen, 0) / 100

def autocorr_length(name, discard=100):
    dirname = pathlib.Path.home() / f"Paper2Data/desi_results/results_{name}"
    chain_param_names = np.load(dirname / "sampler.npy", allow_pickle=True)[0].param_names
    param_names = list(galtab.paper2.param_config.kuan_params[-20.5].keys())
    param_names_copy = param_names.copy()
    m1_ind, mmin_ind, m0_ind = param_names.index("logM1"), param_names.index("logMmin"), param_names.index("logM0")
    chain_param_order = [chain_param_names.index(x.replace("logM0", "logM0_quant")) for x in param_names]

    param_names[param_names.index("logMmin")] = "$\\log M_{\\rm min}$"
    param_names[param_names.index("sigma_logM")] = "$\\sigma_{\\log M}$"
    param_names[param_names.index("alpha")] = "$\\alpha$"
    param_names[param_names.index("logM1")] = "$\\log M_1$"
    param_names[param_names.index("logM0")] = "$\\log M_0$"
    param_names[param_names.index("mean_occupation_centrals_assembias_param1")] = "$A_{\\rm cen}$"
    param_names[param_names.index("mean_occupation_satellites_assembias_param1")] = "$A_{\\rm sat}$"

    file = dirname / "emcee_backend.h5"
    backend = emcee.backends.HDFBackend(file, read_only=True)
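    # flat=False keeps the (nsteps, nwalkers, ndim) shape expected by integrated_time
    original_chain = backend.get_chain(flat=False, discard=discard)
    chain = original_chain.copy()
    for i in range(len(param_names)):
        # Reorder parameters along the last axis (axis 1 indexes walkers here)
        chain[..., i] = original_chain[..., chain_param_order[i]]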

    tau = emcee.autocorr.integrated_time(chain, quiet=True)
    return np.max(tau)

In [None]:
cosmo

# Tables and numbers needed for the paper

In [None]:
# Print LaTeX tables for best-fits and confidence intervals
# =========================================================

# Table 2:
print_bestfit_table(["20p0", "20p5", "21p0_puff", "21p0_z0p2-0p3"],
                    ["-20.0", "-20.5", "-21.0", "-21.0 (high z)"])
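
The "$p$ value" and "Tension" columns of Table 2 come from converting a $\chi^2$ and its degrees of freedom into a two-sided Gaussian significance. A minimal sketch of that conversion follows. The helper name `chi2_tension` and the input numbers are made up for illustration; the real values come from the best-fit blobs.

```python
import scipy.stats

def chi2_tension(chi2, dof):
    """Return (p-value, equivalent two-sided Gaussian z-score)."""
    pval = scipy.stats.chi2(df=dof).sf(chi2)
    # isf(p/2) is mathematically ppf(1 - p/2), but it keeps precision for tiny p-values
    zscore = scipy.stats.norm.isf(pval / 2)
    return pval, zscore

# Illustrative inputs only (not results from the paper)
pval, zscore = chi2_tension(chi2=25.0, dof=15)
print(f"p = {pval:.3f}, tension = {zscore:.2f} sigma")
```

Using `isf` rather than `ppf(1 - pval/2)` is just a numerical-stability preference; both give the same answer for moderate p-values.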

In [None]:
# Table 3:
print_conf_interval_table(["20p0", "20p5", "21p0_puff", "21p0_z0p2-0p3"],
                          ["-20.0", "-20.5", "-21.0", "-21.0 (high z)"])
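
Each entry of Table 3 is a posterior median with 16th and 84th percentile offsets, written as a LaTeX sub/superscript pair. A standalone sketch of that formatting is shown here on a toy sample; this `format_conf_int` takes raw samples rather than precomputed quantiles, which is a simplification for illustration.

```python
import numpy as np

def format_conf_int(samples):
    """Format median with 16th/84th percentile offsets as a LaTeX string."""
    q16, q50, q84 = np.percentile(samples, [16, 50, 84])
    lower, upper = q50 - q16, q84 - q50
    return f"${q50:.3f}_{{-{lower:.3f}}}^{{+{upper:.3f}}}$"

# Toy sample: the integers 0..100, so the p-th percentile is exactly p
print(format_conf_int(np.arange(101.0)))  # prints $50.000_{-34.000}^{+34.000}$
```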

In [None]:
# Probability that A_cen > 0
# ==========================
print([prob_of_positive_acen(name) for name in ["20p0", "20p5", "21p0_puff", "21p0_z0p2-0p3"]])

In [None]:
print([prob_of_positive_acen(name) for name in ["20p0_kmax0", "20p5_kmax0", "21p0_kmax0", "21p0_z0p2-0p3_kmax0"]])
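
As a sanity check on what `prob_of_positive_acen` reports: `1 - percentileofscore(acen, 0) / 100` is simply the fraction of posterior samples with $A_{\rm cen} > 0$, provided no sample is exactly zero. The sketch below demonstrates this on synthetic Gaussian draws, which stand in for a real chain.

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
# Synthetic stand-in for the A_cen column of a chain (not real posterior samples)
acen = rng.normal(loc=0.5, scale=1.0, size=10_000)

prob_from_percentile = 1 - scipy.stats.percentileofscore(acen, 0) / 100
prob_direct = np.mean(acen > 0)
print(prob_from_percentile, prob_direct)
```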

In [None]:
# Integrated Autocorrelation Time
# ===============================
print([autocorr_length(name) for name in ["20p0", "20p5", "21p0", "21p0_z0p2-0p3"]])

In [None]:
# Autocorrelation length using halotools instead of galtab CiC
# implementation (should be longer since halotools is noisy)
# ============================================================
print([autocorr_length(name) for name in ["20p0_htcic", "20p5_htcic", "21p0_htcic", "21p0_z0p2-0p3_htcic"]])

In [None]:
# Autocorrelation lengths using only wp(rp), and not CiC
# ======================================================
print([autocorr_length(name) for name in ["20p0_kmax0", "20p5_kmax0", "21p0_kmax0", "21p0_z0p2-0p3_kmax0"]])
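
One thing to keep in mind when reordering parameters for the autocorrelation estimate: with `flat=False`, emcee returns the chain with shape `(nsteps, nwalkers, ndim)`, so parameters live on the last axis and axis 1 indexes walkers. The toy array below (made-up shapes and ordering) shows the difference between the two indexing patterns.

```python
import numpy as np

nsteps, nwalkers, ndim = 4, 3, 2
chain = np.arange(nsteps * nwalkers * ndim, dtype=float).reshape(nsteps, nwalkers, ndim)
order = [1, 0]  # hypothetical reordering: swap the two parameters

reordered = chain[..., order]      # permutes parameters (last axis)
walker_picked = chain[:, order]    # selects walkers 1 and 0 instead (axis 1)

print(reordered.shape, walker_picked.shape)
```

For a flattened chain of shape `(nsamples, ndim)` the two patterns coincide, which is why `chain[:, i]` is safe there but not in the unflattened case.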

# Plot Figures

## Fig 1: desi-footprint

In [None]:
print(250**3 / ms.util.volume(173.3052, [0.1, 0.2], cosmo))
print(400**3 / ms.util.volume(173.3052, [0.1, 0.2], cosmo))
# Bolshoi-Planck is 5.513x the volume of 0.1 < z < 0.2 in SV3
# SMDPL is 22.581x the volume of 0.1 < z < 0.2 in SV3
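
For reference, the ratios above can be roughly reproduced without `ms.util.volume`. The sketch below assumes a flat ΛCDM cosmology with $\Omega_m = 0.31$, which is a placeholder rather than the notebook's `cosmo`, and works in Mpc/h so that box side lengths of 250 and 400 Mpc/h can be cubed directly. The helper names are made up for illustration, and the exact numbers will shift slightly with the cosmology.

```python
import numpy as np
from scipy.integrate import quad

def comoving_distance(z, om0=0.31):
    """Line-of-sight comoving distance in Mpc/h for flat LCDM (assumed Om0)."""
    hubble_dist = 2997.92458  # c / (100 km/s/Mpc), in Mpc/h
    integrand = lambda zp: 1.0 / np.sqrt(om0 * (1 + zp)**3 + 1 - om0)
    return hubble_dist * quad(integrand, 0, z)[0]

def survey_volume(area_sqdeg, zlim, om0=0.31):
    """Comoving volume in (Mpc/h)^3 of a sky patch between two redshifts."""
    full_sky_sqdeg = 4 * np.pi * (180 / np.pi)**2
    rlo, rhi = (comoving_distance(z, om0) for z in zlim)
    return area_sqdeg / full_sky_sqdeg * 4 / 3 * np.pi * (rhi**3 - rlo**3)

vol = survey_volume(173.3052, [0.1, 0.2])
ratio_bp = 250**3 / vol     # box side 250 Mpc/h
ratio_smdpl = 400**3 / vol  # box side 400 Mpc/h
print(ratio_bp, ratio_smdpl)
```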

In [None]:
# python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_kmax5.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --apply-pip-weights --cic-kmax 5
kw = dict(abs_mr_max=-21.0, wp_rand_frac=0.1, apply_pip_weights=True, dont_apply_pip_weights=False,
          cic_kmax=5, cosmo=galtab.paper2.param_config.cosmo, data_dir=pathlib.Path.home() / "data/DESI/SV3/clean_fuji/",
          progress=True, verbose=True, first_n=None, num_threads=1, zmin=galtab.paper2.param_config.zmin,
          zmax=galtab.paper2.param_config.zmax, logmmin=-np.inf, passive_evolved_mags=False, kuan_mags=False,
          rp_edges=galtab.paper2.param_config.rp_edges, pimax=galtab.paper2.param_config.pimax,
          cic_edges=galtab.paper2.param_config.cic_edges, proj_search_radius=galtab.paper2.param_config.proj_search_radius,
          cylinder_half_length=galtab.paper2.param_config.cylinder_half_length, purity_factor=1.0, effective_area_sqdeg=None)
calc = galtab.paper2.desi_observables.ObservableCalculator(**kw)

In [None]:
print("Average cylinder sky completeness:", np.mean(calc.randcyl_density[calc.randcic_mask] / calc.num_rand_files / 2500))

In [None]:
print("Average fiber completeness (-21.0):", np.mean(1 / calc.weights))

In [None]:
galtab.paper2.param_config.zmax

In [None]:
# python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p0.npz --abs-mr-max " -20.0" --wp-rand-frac 0.1 --apply-pip-weights --cic-kmax 5
kw20p0 = dict(abs_mr_max=-20.0, wp_rand_frac=0.1, apply_pip_weights=True, dont_apply_pip_weights=False,
          cic_kmax=5, cosmo=galtab.paper2.param_config.cosmo, data_dir=pathlib.Path.home() / "data/DESI/SV3/clean_fuji/",
          progress=True, verbose=True, first_n=None, num_threads=1, zmin=galtab.paper2.param_config.zmin,
          zmax=galtab.paper2.param_config.zmax, logmmin=-np.inf, passive_evolved_mags=False, kuan_mags=False,
          rp_edges=galtab.paper2.param_config.rp_edges, pimax=galtab.paper2.param_config.pimax,
          cic_edges=galtab.paper2.param_config.cic_edges, proj_search_radius=galtab.paper2.param_config.proj_search_radius,
          cylinder_half_length=galtab.paper2.param_config.cylinder_half_length, purity_factor=1.0, effective_area_sqdeg=None)
calc20p0 = galtab.paper2.desi_observables.ObservableCalculator(**kw20p0)
print("Average fiber completeness (-20.0):", np.mean(1 / calc20p0.weights))

# python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_20p5.npz --abs-mr-max " -20.5" --wp-rand-frac 0.1 --apply-pip-weights --cic-kmax 5
kw20p5 = dict(abs_mr_max=-20.5, wp_rand_frac=0.1, apply_pip_weights=True, dont_apply_pip_weights=False,
          cic_kmax=5, cosmo=galtab.paper2.param_config.cosmo, data_dir=pathlib.Path.home() / "data/DESI/SV3/clean_fuji/",
          progress=True, verbose=True, first_n=None, num_threads=1, zmin=galtab.paper2.param_config.zmin,
          zmax=galtab.paper2.param_config.zmax, logmmin=-np.inf, passive_evolved_mags=False, kuan_mags=False,
          rp_edges=galtab.paper2.param_config.rp_edges, pimax=galtab.paper2.param_config.pimax,
          cic_edges=galtab.paper2.param_config.cic_edges, proj_search_radius=galtab.paper2.param_config.proj_search_radius,
          cylinder_half_length=galtab.paper2.param_config.cylinder_half_length, purity_factor=1.0, effective_area_sqdeg=None)
calc20p5 = galtab.paper2.desi_observables.ObservableCalculator(**kw20p5)
print("Average fiber completeness (-20.5):", np.mean(1 / calc20p5.weights))

# python -m galtab.paper2.desi_observables -vpn4 -o desi_obs_21p0_z0p2-0p3.npz --abs-mr-max " -21.0" --wp-rand-frac 0.1 --apply-pip-weights --cic-kmax 5 --zmin 0.2 --zmax 0.3
kw21p0hz = dict(abs_mr_max=-21.0, wp_rand_frac=0.1, apply_pip_weights=True, dont_apply_pip_weights=False,
          cic_kmax=5, cosmo=galtab.paper2.param_config.cosmo, data_dir=pathlib.Path.home() / "data/DESI/SV3/clean_fuji/",
          progress=True, verbose=True, first_n=None, num_threads=1, zmin=0.2,
          zmax=0.3, logmmin=-np.inf, passive_evolved_mags=False, kuan_mags=False,
          rp_edges=galtab.paper2.param_config.rp_edges, pimax=galtab.paper2.param_config.pimax,
          cic_edges=galtab.paper2.param_config.cic_edges, proj_search_radius=galtab.paper2.param_config.proj_search_radius,
          cylinder_half_length=galtab.paper2.param_config.cylinder_half_length, purity_factor=1.0, effective_area_sqdeg=None)
calc21p0hz = galtab.paper2.desi_observables.ObservableCalculator(**kw21p0hz)
print("Average fiber completeness (-21.0 high-z):", np.mean(1 / calc21p0hz.weights))
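
The "average fiber completeness" printed in the cells above is the mean of the inverse weights. As a minimal sketch of that arithmetic, assuming each weight is the reciprocal of a galaxy's probability of receiving a fiber (the weight values below are made up for illustration and are not DESI data):

```python
from statistics import mean

# Hypothetical per-galaxy weights (illustrative values only).
# Assumption: weight = 1 / (probability that the galaxy received a fiber).
weights = [1.0, 1.0, 2.0, 4.0]

# Same quantity as np.mean(1 / calc.weights) in the cells above
avg_completeness = mean(1.0 / w for w in weights)
print(f"Average fiber completeness: {avg_completeness:.4f}")  # -> 0.6875
```

Note that this is the mean of the inverse weights, not the inverse of the mean weight (which would give 0.5 for these values).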

In [None]:
samples = [dict(thresh=-20.0, zmin=0.1, zmax=0.2),
           dict(thresh=-20.5, zmin=0.1, zmax=0.2),
           dict(thresh=-21.0, zmin=0.1, zmax=0.2),
           dict(thresh=-21.0, zmin=0.2, zmax=0.3)]
for i, sample in enumerate(samples):
    thresh, zmin, zmax = sample["thresh"], sample["zmin"], sample["zmax"]
    if i > 0:
        print()
    print(f"S{i+1}: M_R <= {thresh}; {zmin} <= z <= {zmax}\n=================================")
    print("N_tot =", np.sum((calc.fastphot["abs_rmag_0p1"] <= thresh) & (zmin <= calc.fastphot["Z"]) & (calc.fastphot["Z"] <= zmax)))
    print("N_cyl =", np.sum((calc.fastphot["abs_rmag_0p1"] <= thresh) & (zmin <= calc.fastphot["Z"]) & (calc.fastphot["Z"] <= zmax) & calc.randcic_mask))

In [None]:
plot_footprint(calc, region_num=11)
plt.savefig("desi-footprint.png", bbox_inches="tight", dpi=200)
plt.savefig("desi-footprint.pdf", bbox_inches="tight", dpi=200)
plt.show()

## Fig 2: desi-luminosity-z

In [None]:
plot_redshift_vs_rmag(fastphot["Z"], fastphot["abs_rmag_0p1"])

plt.savefig("desi-luminosity-z.png", bbox_inches="tight", dpi=200)
plt.savefig("desi-luminosity-z.pdf", bbox_inches="tight", dpi=200)
plt.show()

## Fig 3: hod-feature-importance

In [None]:
cic_kmax = 10

param_names = list(galtab.paper2.param_config.kuan_params[-20.5].keys())
cic_moment_numbers = importance_results["cic_moment_numbers"][:cic_kmax]
hod_samples = importance_results["hod_samples"]
nobs_samples = importance_results["nobs_samples"]

n_rp_bins = len(galtab.paper2.param_config.rp_edges) - 1
feature_labels = [-n_rp_bins - 2, *range(-n_rp_bins, 0), *cic_moment_numbers]
feature_labels_str = ["CiC" + str(x) if x > 0 else
                      "n" if x == -n_rp_bins - 2 else
                      "wp" + str(x + n_rp_bins + 1) for x in feature_labels]
hod_df = pd.DataFrame(dict(zip(param_names, hod_samples.T)))
nobs_df = pd.DataFrame(dict(zip(feature_labels_str, nobs_samples.T)))

rf = {}
for hod_param_name in param_names:
    rf[hod_param_name] = sklearn.ensemble.RandomForestRegressor(
        n_estimators=20, max_leaf_nodes=500)  # up to 20 * 500 = 10,000 leaves per RF - use fewer to make SHAP run faster
    rf[hod_param_name].fit(nobs_df, hod_df[hod_param_name])
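
The integer encoding behind `feature_labels` is easy to misread, so here is the same labeling logic run on small stand-in sizes. The values of `n_rp_bins` and `cic_moment_numbers` below are hypothetical; the real ones come from `galtab.paper2.param_config.rp_edges` and `importance_results`.

```python
# Hypothetical sizes for illustration only
n_rp_bins = 3
cic_moment_numbers = [1, 2, 3]

# Same encoding as the cell above: one sentinel value for n,
# negative integers for the wp bins, positive integers for the CiC moments
feature_labels = [-n_rp_bins - 2, *range(-n_rp_bins, 0), *cic_moment_numbers]
feature_labels_str = ["CiC" + str(x) if x > 0 else
                      "n" if x == -n_rp_bins - 2 else
                      "wp" + str(x + n_rp_bins + 1) for x in feature_labels]
print(feature_labels_str)
# -> ['n', 'wp1', 'wp2', 'wp3', 'CiC1', 'CiC2', 'CiC3']
```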

In [None]:
print(*[est.get_n_leaves() for est in rf[hod_param_name].estimators_])

# Score each forest against the parameter it was trained to predict
perm_result = {name: sklearn.inspection.permutation_importance(
    rf[name], nobs_df, hod_df[name]) for name in param_names}
shap_values = {name: shap.TreeExplainer(rf[name])(nobs_df) for name in param_names}
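
For readers unfamiliar with permutation importance, the idea can be sketched with the standard library alone: shuffle one feature column at a time and record how much the model's error grows against the target it was fit to. This is a simplified stand-in, not scikit-learn's implementation: it uses the increase in mean squared error, whereas `sklearn.inspection.permutation_importance` by default uses the drop in the estimator's own score (R^2 for regressors). The `model` below is a hypothetical fixed function standing in for a fitted forest.

```python
import random

def mse(model, X, y):
    """Mean squared error of model predictions against targets y."""
    return sum((model(row) - t) ** 2 for row, t in zip(X, y)) / len(y)

def permutation_importance_sketch(model, X, y, n_repeats=5, seed=0):
    """Mean increase in MSE when each feature column is shuffled."""
    rng = random.Random(seed)
    baseline = mse(model, X, y)
    importances = []
    for col in range(len(X[0])):
        increases = []
        for _ in range(n_repeats):
            shuffled = [row[col] for row in X]
            rng.shuffle(shuffled)
            X_perm = [row[:col] + [v] + row[col + 1:]
                      for row, v in zip(X, shuffled)]
            increases.append(mse(model, X_perm, y) - baseline)
        importances.append(sum(increases) / n_repeats)
    return importances

# Synthetic data: the target depends only on feature 0
data_rng = random.Random(1)
X = [[data_rng.random(), data_rng.random()] for _ in range(200)]
y = [3.0 * row[0] for row in X]

def model(row):
    # Hypothetical stand-in for a fitted regressor
    return 3.0 * row[0]

importances = permutation_importance_sketch(model, X, y)
print(importances)  # feature 0 is important, feature 1 is not
```

Because the importance is measured relative to a target, each model should be scored against the same target it was trained on.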

In [None]:
hod_param_names = ["logMmin", "sigma_logM", "mean_occupation_centrals_assembias_param1",
                   "logM1", "alpha", "mean_occupation_satellites_assembias_param1"]
plot_importance(shap_values, hod_param_names, perm_result=None)
plt.suptitle("     Central HOD parameters" + "\n"*13 + "     Satellite HOD parameters", fontsize=32, weight="bold", y=0.94)

plt.savefig("hod-feature-importance.png", bbox_inches="tight")
plt.savefig("hod-feature-importance.pdf", bbox_inches="tight")
plt.show()

### ~~Fig 3b~~ Fig 9: hod-feature-shap-beeswarm

In [None]:
plot_shap_beeswarm(shap_values, hod_param_names, perm_result=None)

# fig = plt.gcf()
# fig.canvas.draw_idle()  # necessary to get the proper positions
# for cax in fig.axes[-6:]:
#     pos = cax.get_position()
#     print(pos.width, pos.x0, pos.x1)
#     pos.x1 += 5 * pos.width
#     cax.set_position(pos)
#     cax.set_aspect(30)
plt.suptitle("Central HOD parameters" + "\n"*13 + "Satellite HOD parameters", fontsize=32, weight="bold", y=0.94)


plt.savefig("hod-feature-shap-beeswarm.png", bbox_inches="tight")
plt.savefig("hod-feature-shap-beeswarm.pdf", bbox_inches="tight")
plt.show()

## Fig 4: cictabulator-cartoon

- Mostly done in PowerPoint, but the bar plot below uses real CICTabulator data

In [None]:
plt.subplots(figsize=(5, 5))
plt.bar(cic_cens, desi_mean["20p5"][cic_mask], width=0.9*np.diff(cic_edges))

mean, std = desi_mean["20p5_kmax5"][cic_mask][:2]
plt.axvline(mean, color="k", ls="--")
plt.arrow(mean, 0.1, std/2, 0, color="k", lw=0.5, head_length=0.5, head_width=0.005)
plt.arrow(mean, 0.1, -std/2, 0, color="k", lw=0.5, head_length=0.5, head_width=0.005)
plt.text(mean+std/50, 0.15, "$\\rm \\mu_{CiC}$", fontsize=14)
plt.text(mean+std/6, 0.105, "$\\rm \\sigma_{CiC}$", fontsize=14)
plt.gca().xaxis.set_visible(False)
plt.gca().yaxis.set_visible(False)
plt.xlim(-0.6, 9.6)
# plt.savefig("cictabulator-cartoon-barplot.png", bbox_inches="tight")
plt.show()

## Fig 5: mcmc-posterior-20p5

In [None]:
fig = plot_mcmc_posterior("20p5", show_titles=True)
for ax in fig.axes:
    ax.set_title(ax.get_title(), fontsize=16)
    ax.set_xlabel(ax.get_xlabel(), fontsize=18)
    ax.set_ylabel(ax.get_ylabel(), fontsize=18)
    ax.tick_params(labelsize=14)
plt.savefig("mcmc-posterior-20p5.png", bbox_inches="tight")
plt.savefig("mcmc-posterior-20p5.pdf", bbox_inches="tight")
plt.show()

## Fig 6: corner-only-assembias-vs-kmax0

In [None]:
fig = plot_mcmc_posterior("20p5_kmax0", show_titles=True, assembias_params_only=True, color="C0", plot_datapoints=False)
titles = fig.axes[0].get_title(), fig.axes[3].get_title()
plot_mcmc_posterior("20p5", show_titles=True, assembias_params_only=True, plot_datapoints=False,
                    weights=198_000/58_000*np.ones(58_000), fig=fig)

fig.axes[1].text(0.5, 0.5, "Fiducial (-20.5)", color="k", ha="center", fontsize=16,
                 transform=fig.axes[1].transAxes)
fig.axes[1].text(0.5, 0.65, "No CiC", color="C0", ha="center", fontsize=16,
                 transform=fig.axes[1].transAxes)
fig.axes[0].text(0.5, 1.2, titles[0], color="C0", ha="center", fontsize=14, transform=fig.axes[0].transAxes)
fig.axes[3].text(0.5, 1.2, titles[1], color="C0", ha="center", fontsize=14, transform=fig.axes[3].transAxes)
plt.savefig("corner-only-assembias-20p5-vs-kmax0.png", bbox_inches="tight")
plt.savefig("corner-only-assembias-20p5-vs-kmax0.pdf", bbox_inches="tight")
plt.show()
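
A guess at the intent of `weights=198_000/58_000*np.ones(58_000)` in the cell above: if the `20p5_kmax0` chain holds 198,000 samples and the fiducial `20p5` chain holds 58,000 (an assumption; these counts are only inferred from the argument), then a constant weight of 198,000/58,000 per fiducial sample gives both chains the same total weight, so unnormalized histograms of the two have comparable heights when overlaid. The arithmetic, as a sketch:

```python
# Assumed chain lengths, read off the weights argument (not confirmed)
n_reference = 198_000   # samples in the chain plotted first
n_overlay = 58_000      # samples in the chain plotted on top

# One constant weight per overlay sample, as in the cell above
weight = n_reference / n_overlay
total_overlay_weight = weight * n_overlay

print(round(weight, 3), round(total_overlay_weight))
```

If either chain is rerun with a different number of samples, the hard-coded counts need updating; taking them from the lengths of the loaded sample arrays would avoid that.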

## Fig 7: desi-summary-stats

### galtab [mc-mode] CiC distribution (*this* one is going in the paper)

In [None]:
# Same, but scale P(CiC) by CiC^2
fig, axes = plt.subplots(ncols=3, nrows=4, figsize=(13, 18), gridspec_kw={"width_ratios": [1, 5, 5], "wspace": 0.3})

all_threshes = [-20.0, -20.5, -21.0, -21.0]
all_highz = [False, False, False, True]
for i in range(len(all_threshes)):
    figax = fig, axes[i]
    # Draw the other samples faintly behind the highlighted one
    for j in range(len(all_threshes)):
        if j == i:
            continue
        plot_desi_obs(nolines=True, highz=all_highz[j], threshes=[all_threshes[j]], figax=figax,
                      color_start=7, override_label=True, alpha=0.15, xscale_cic=True)
    threshes = [all_threshes[i]]
    highz = all_highz[i]
    plot_desi_obs(nolines=True, highz=highz, threshes=threshes, figax=figax, color_start=i, xscale_cic=True)

for i in range(len(all_threshes)):
    axes[i][-1].legend(frameon=False, fontsize=14)
    figax = fig, axes[i]
    plot_bestfit_obs("20p0", figax=figax, color_ind=0, print_chi2=i==0, bring_front=i==0, lw=1+2*(i==0), xscale_cic=True)
    plot_bestfit_obs("20p5", figax=figax, color_ind=1, print_chi2=i==1, bring_front=i==1, lw=1+2*(i==1), xscale_cic=True)
    plot_bestfit_obs("21p0_puff", figax=figax, color_ind=2, print_chi2=i==2, bring_front=i==2, lw=1+2*(i==2), xscale_cic=True)
    plot_bestfit_obs("21p0_z0p2-0p3", figax=figax, color_ind=3, print_chi2=i==3, bring_front=i==3, lw=1+2*(i==3), xscale_cic=True)

plt.savefig("desi-summary-stats.png", bbox_inches="tight")
plt.savefig("desi-summary-stats.pdf", bbox_inches="tight")
plt.show()

### [below is the single row version of this figure I use for talks]

In [None]:
figax = plt.subplots(ncols=3, nrows=1, figsize=(13, 4.5), gridspec_kw={"width_ratios": [1, 5, 5], "wspace":0.3})
name, thresh = "20p0", -20.0
color = 0
highz = False
plot_desi_obs(nolines=True, highz=highz, threshes=[thresh], figax=figax,
              color_start=color, xscale_cic=True)
plot_bestfit_obs(name, figax=figax, color_ind=color, xscale_cic=True, print_chi2=True)
plot_bestfit_obs(name + "_noassembias", figax=figax, color_ind=color, xscale_cic=True, print_chi2=True, ls="--")
name, thresh = "21p0_puff", -21.0
color = 3
highz = False
plot_desi_obs(nolines=True, highz=highz, threshes=[thresh], figax=figax,
              color_start=color, xscale_cic=True)
plot_bestfit_obs(name, figax=figax, color_ind=color, xscale_cic=True, print_chi2=True)
# plot_bestfit_obs(name + "_noassembias", figax=figax, color_ind=color, xscale_cic=True, print_chi2=True, ls="--")
plot_bestfit_obs(name[:-5] + "_kmax0", plot_cic=False, figax=figax, color_ind=color, xscale_cic=True, print_chi2=True, ls="--")
figax[-1][-1].plot([], [], color="grey", lw=3, ls="--", label="No $A_{\\rm bias}$")
figax[-1][-1].legend(frameon=False, fontsize=14)
# plt.savefig("desi-summary-stats-quick-with-fits.png", bbox_inches="tight")
plt.show()

## Fig 8: hod-evolution-punchline

- Evolution of Acen, Mmin (+ sigma), and M1 (+ alpha) with luminosity and redshift

In [None]:
plot_hod_evolution()
plt.savefig("hod-evolution-punchline.png", bbox_inches="tight")
plt.savefig("hod-evolution-punchline.pdf", bbox_inches="tight")
plt.show()

# Appendix Figures

## Fig 9 (Appendix A): hod-feature-shap-beeswarm

- [see directly below Fig 3 for this figure]

## ~~Fig 4~~ Fig 10 (Appendix B): tabulator-accuracy-vs-runtime

In [None]:
# Smaller version just for talks
plot_accuracy_vs_runtime(kmax=2, remove_max_weight_time_plot=True, percent_error=True)
fig = plt.gcf()
for ax in fig.axes[-2:]:
    ax.set_ylim(top=-0.5*ax.get_ylim()[0])
for ax in fig.axes[-4:-2]:
    ax.set_ylim(top=-0.4*ax.get_ylim()[0])
    ax.set_yticks([-2, -1, 0, 1])
# plt.savefig("tabulator-accuracy-kmax2.png", bbox_inches="tight")
plt.show()

In [None]:
plot_accuracy_vs_runtime(kmax=5, multiply_time_by=1, plot_ht_mean_err=False, remove_max_weight_time_plot=True,
                         touchaxes=True, percent_error=True)
fig = plt.gcf()
for ax in fig.axes[-2:]:
    ax.set_ylim(-55, 55)
    ax.set_yticks([-50, -25, 0, 25, 50])
for ax in fig.axes[-4:-2]:
    ax.set_ylim(-30, 30)
    ax.set_yticks([-20, -10, 0, 10, 20])
for ax in fig.axes[-6:-4]:
    ax.set_ylim(-12, 12)
    ax.set_yticks([-10, -5, 0, 5, 10])
for ax in fig.axes[-8:-6]:
    ax.set_ylim(top=-0.5*ax.get_ylim()[0])
    ax.set_yticks([-4, -2, 0, 2])
for ax in fig.axes[-10:-8]:
    ax.set_ylim(top=-0.4*ax.get_ylim()[0])
    ax.set_yticks([-2, -1, 0, 1])

# plt.savefig("tabulator-accuracy-vs-runtime.png", bbox_inches="tight")
# plt.savefig("tabulator-accuracy-vs-runtime.pdf", bbox_inches="tight")
plt.show()

## ~~Fig 8~~ Fig 11 (Appendix B): corner-only-assembias-vs-htcic

In [None]:
fig = plot_mcmc_posterior("20p5_htcic", show_titles=True, assembias_params_only=True, color="C1", plot_datapoints=False)
titles = fig.axes[0].get_title(), fig.axes[3].get_title()
plot_mcmc_posterior("20p5", show_titles=True, assembias_params_only=True, plot_datapoints=False, fig=fig)

fig.axes[1].text(0.5, 0.5, "Fiducial (-20.5)", color="k", ha="center", fontsize=16,
                 transform=fig.axes[1].transAxes)
fig.axes[1].text(0.5, 0.65, "halotools CiC", color="C1", ha="center", fontsize=16,
                 transform=fig.axes[1].transAxes)
fig.axes[0].text(0.5, 1.2, titles[0], color="C1", ha="center", fontsize=14, transform=fig.axes[0].transAxes)
fig.axes[3].text(0.5, 1.2, titles[1], color="C1", ha="center", fontsize=14, transform=fig.axes[3].transAxes)
plt.savefig("corner-only-assembias-20p5-vs-htcic.png", bbox_inches="tight")
plt.savefig("corner-only-assembias-20p5-vs-htcic.pdf", bbox_inches="tight")
plt.show()