# Demo of differentiable model for galaxy merging

The model in the `differentiable_smhm` repo is intended to be applied to a _complete_ catalog of subhalos, i.e., to a catalog of every subhalo that ever lived, including those that previously merged with their host halo. The idea behind this is twofold:
1. It is somewhat ambiguous when a subhalo has actually merged with its host, and the details of when that happens depend sensitively upon the subhalo-finding algorithm.
2. Even with perfect knowledge of when the subhalo merged, what matters for predicting the observations is the time when the _satellite galaxy_ merges with its associated central, and the relationship between subhalo and galaxy merging is highly uncertain.

Based on (1) & (2), we choose to parametrize the moment the galaxy merges, and to vary these parameters on an equal footing with the SMHM parameters. As with all ingredients in this modeling approach, this is done probabilistically. The quantity we actually parametrize is $P_{\rm merge}(X_{\rm halo}),$ the probability that a subhalo/galaxy with properties $X_{\rm halo}$ has merged with its associated central. In the model, we _first_ map some value of $M_{\star}$ onto every subhalo that has ever lived, and _then_ we map the quantity $P_{\rm merge}$ onto every subhalo in the catalog. For satellites, we then transfer $P_{\rm merge}\cdot M_{\star}$ into the associated central, so that the mass of the satellite ends up as $(1-P_{\rm merge})\cdot M_{\star}.$ For centrals, $P_{\rm merge}=0.$

Computationally, this mass transfer can be quite fast because we can precompute the _index_ of the central into which each satellite will transfer some portion of its mass. This is straightforward for a single-redshift model like the one we are working on here, but for time-evolving models this will need to be generalized.
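To make the index-based transfer concrete, here is a minimal JAX sketch. The function name `transfer_merged_mass`, the array `central_index`, and the toy catalog are all hypothetical and chosen for illustration; the repo's actual implementation may be organized differently.

```python
import jax.numpy as jnp


def transfer_merged_mass(mstar, p_merge, central_index):
    """Move the merged fraction of each galaxy's mass onto its central.

    mstar         : (n,) stellar mass mapped onto every subhalo
    p_merge       : (n,) merging probability; 0 for centrals
    central_index : (n,) precomputed integer index of each object's
                    central (a central points to itself)
    """
    mass_lost = p_merge * mstar    # portion transferred to the central
    mass_kept = mstar - mass_lost  # (1 - P_merge) * M_star
    # Scatter-add: repeated indices accumulate, so each central
    # collects the contributions of all of its satellites.
    return mass_kept.at[central_index].add(mass_lost)


# Toy catalog (assumed values): object 0 is a central hosting
# satellites 1 and 2; object 3 is a central with no satellites.
mstar = jnp.array([10.0, 2.0, 4.0, 5.0])
p_merge = jnp.array([0.0, 0.5, 0.25, 0.0])
central_index = jnp.array([0, 0, 0, 3])

mstar_final = transfer_merged_mass(mstar, p_merge, central_index)
```

Because the transfer is a single scatter-add over a fixed index array, total stellar mass is conserved and the operation stays differentiable with respect to `p_merge`.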

The code in the `sigmoid_disruption` module controls the behavior of $P_{\rm merge}(X_{\rm halo})$ in terms of model parameters $\theta_{\rm merge},$ where $X_{\rm halo} = \{\nu, M_{\rm host}\}.$ Here $\nu\equiv \log_{10}(V_{\rm max}/V_{\rm peak})$ is the logarithm of the ratio of present-day circular velocity to its peak historical value, and $M_{\rm host}$ is the mass of the host halo. Physically, subhalos experience strong tidal forces as they orbit within the host halo, and these tidal forces eventually disrupt the subhalo and destroy the satellite at the center of the subhalo; thus $\nu$ is a natural proxy for the primary regulator of this process. We also know that satellite-specific processes vary with host halo mass, and so our parametrization allows $M_{\rm host}$ to play an additional role.


### Warmup exercises

1. Fiddle around with the merging model parameters and remake the plot below to build intuition for what each parameter does. It may be easier to dispense with the fancy color-coding and just make a simpler single-curve version.
2. Check your intuitive understanding of the model parameters by reviewing the source code and ensuring that it makes good sense.

Pay special attention to how the two sigmoid functions are stitched together to create the two-dimensional dependence. The basic idea is that we build our parametrization such that $P_{\rm merge}$ has a simple sigmoid dependence upon $\nu.$ That is, the "first-order" dependence of $P_{\rm merge}$ looks like this:
$$P_{\rm merge}(\nu) = p_{\rm low} + \frac{p_{\rm high} - p_{\rm low}}{1 + \exp\left[-k\cdot(\nu-\nu_0)\right]},$$
where $p_{\rm low}$ and $p_{\rm high}$ control the asymptotic behavior of $P_{\rm merge}.$ To capture the additional dependence upon $M_{\rm host},$ we elevate the quantities $p_{\rm low}$ and $p_{\rm high}$ to themselves be functions of $M_{\rm host}.$

The way this works in the source code may not be immediately obvious, but this is important and is worth taking the time to understand in detail, because we will use these same techniques again and again as we continue to build models that capture multivariate dependencies. It may be useful to have a look at [this gist](https://gist.github.com/aphearin/526e8c67e7dd1ed1adeec52fef5b241e) to develop a thorough understanding of how the sigmoid-stitching technique works.
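As a rough illustration of the stitching idea described above, the sketch below nests two sigmoids in $M_{\rm host}$ inside one sigmoid in $\nu.$ This is not the code in `sigmoid_disruption`: the function `p_merge_sketch`, the parameter names, and the parameter values are all hypothetical, and it assumes $k>0$ so that $p_{\rm low}$ is the asymptote at low $\nu.$

```python
import jax.numpy as jnp


def sigmoid(x, x0, k, ymin, ymax):
    """Sigmoid that goes from ymin to ymax, centered at x0 with slope k."""
    return ymin + (ymax - ymin) / (1.0 + jnp.exp(-k * (x - x0)))


def p_merge_sketch(nu, logmhost, params):
    """Illustrative P_merge(nu, M_host); all parameter names are hypothetical."""
    (nu0, k_nu, m0, k_m,
     plow_lo_m, plow_hi_m, phigh_lo_m, phigh_hi_m) = params
    # Inner sigmoids: each asymptote is itself a function of host mass.
    p_low = sigmoid(logmhost, m0, k_m, plow_lo_m, plow_hi_m)
    p_high = sigmoid(logmhost, m0, k_m, phigh_lo_m, phigh_hi_m)
    # Outer sigmoid: the first-order dependence upon nu.
    return sigmoid(nu, nu0, k_nu, p_low, p_high)


# Assumed parameter values, chosen only to give a plausible-looking curve.
params = (-0.5, 10.0, 13.5, 2.0, 0.9, 1.0, 0.0, 0.1)
nu = jnp.linspace(-2.5, 0.0, 50)
p_12 = p_merge_sketch(nu, 12.0, params)  # curve for a 10^12 Msun host
p_15 = p_merge_sketch(nu, 15.0, params)  # curve for a 10^15 Msun host
```

As long as both asymptotes stay within $[0, 1]$ for every host mass, the stitched function is guaranteed to be a valid probability everywhere, which is one reason this construction is convenient.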
    
3. Once more, practice making a few plots of the gradients of the model using `jax.grad`. Don't go overboard with this: by now this is really just for practice since you already studied the mechanics of `jax.grad` in the previous notebook.
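One possible starting point for the gradient exercise is sketched here, using the first-order sigmoid in $\nu$ rather than the full model. The function `p_merge_1d` and the parameter values are assumptions made for illustration; to differentiate the real model, swap in the function from `sigmoid_disruption`.

```python
import jax
import jax.numpy as jnp


def p_merge_1d(params, nu):
    """First-order sigmoid in nu; params = (nu0, k, p_low, p_high)."""
    nu0, k, p_low, p_high = params
    return p_low + (p_high - p_low) / (1.0 + jnp.exp(-k * (nu - nu0)))


# jax.grad differentiates with respect to the first argument (params)
# by default; vmap over nu returns one gradient vector per grid point.
dp_dparams = jax.vmap(jax.grad(p_merge_1d), in_axes=(None, 0))

params = jnp.array([-0.5, 10.0, 0.9, 0.1])  # assumed values
nu = jnp.linspace(-2.5, 0.0, 200)
grads = dp_dparams(params, nu)  # shape (200, 4): one column per parameter
```

Plotting each column of `grads` against `nu` shows where each parameter has leverage: the gradients with respect to $\nu_0$ and $k$ vanish far from the transition, while those with respect to $p_{\rm low}$ and $p_{\rm high}$ sum to one at every $\nu.$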

In [1]:
import matplotlib.cm as cm
from matplotlib import lines as mlines
from matplotlib import pyplot as plt
import numpy as np

In [2]:
from differentiable_smhm.galhalo_models import sigmoid_disruption


In [None]:
n_colors = 500
colors = cm.coolwarm(np.linspace(0, 1, n_colors)) # blue first
n_h = 1_000
zz = np.zeros(n_h)

In [None]:
blue_line = mlines.Line2D([],[],ls='-', c=colors[0], label=r'$M_{\rm host}=10^{12}M_{\odot}$')
red_line = mlines.Line2D([],[],ls='-', c=colors[-1], label=r'$M_{\rm host}=10^{15}M_{\odot}$')

In [None]:
log_vmax_by_vmpeak_arr = np.linspace(-2.5, 0, n_h)
logmhost_plot = np.linspace(12, 15, n_colors)

fig, ax = plt.subplots(1, 1)
xscale = ax.set_xscale('log')
ylim = ax.set_ylim(0.0, 1.02)
xlim = ax.set_xlim(10**log_vmax_by_vmpeak_arr.min(), 1)

for logmhost, c in zip(logmhost_plot, colors):
    dprob = sigmoid_disruption.satellite_disruption_probability(log_vmax_by_vmpeak_arr, zz+logmhost)
    __=ax.plot(10**log_vmax_by_vmpeak_arr, dprob, color=c)

xlabel = ax.set_xlabel(r'$V_{\rm max}/V_{\rm peak}$')
ylabel = ax.set_ylabel(r'$P_{\rm merge}(V_{\rm max}, V_{\rm peak},M_{\rm host})$')
title = ax.set_title(r'${\rm differentiable\ merging\ model}$')
leg = ax.legend(handles=[red_line, blue_line], loc='center left')

fig.savefig('dprob_sats_vs_vmax_ratio.png', bbox_extra_artists=[xlabel, ylabel], bbox_inches='tight', dpi=200) 

### Not-so-warmup exercises

Again let's apply this model to a catalog of simulated subhalos.

1. Repeat the not-so-warmup exercise from the SMHM notebook, but extend your analysis to include the merging parameters.
2. Now calculate a new quantity: $F_{\rm ex-situ}(m_{\star}),$ the _ex-situ fraction_ of central galaxies as a function of $m_{\star}\equiv\log_{10}M_{\star}.$ The ex-situ fraction in this context is defined as the fraction of the central's mass brought in by satellite mergers. Select some bins of $m_{\star},$ and write a new JAX function that is differentiable with respect to the model parameters $\theta_{\rm merge}.$ Again make plots for each of your merging parameters, and verify that the gradients behave in accord with your understanding of the parameters.
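If you get stuck on the ex-situ exercise, the sketch below shows one way the calculation could be structured. Everything in it is an assumption made for illustration: `p_merge_standin` is a one-dimensional stand-in for the real merging model, `central_index` and `is_central` are hypothetical catalog columns, and the Gaussian-weighted bins are just one of several ways to keep bin membership smooth in the parameters.

```python
import jax
import jax.numpy as jnp


def p_merge_standin(theta, nu):
    """Hypothetical stand-in for the merging model: a 1-d sigmoid in nu."""
    nu0, k, p_low, p_high = theta
    return p_low + (p_high - p_low) / (1.0 + jnp.exp(-k * (nu - nu0)))


def exsitu_fraction(theta, nu, mstar, central_index, is_central,
                    bin_centers, sigma=0.25):
    """Kernel-weighted mean ex-situ fraction of centrals near each bin center."""
    p_merge = jnp.where(is_central, 0.0, p_merge_standin(theta, nu))
    # Ex-situ mass accumulated by each central (stays zero for satellites).
    m_exsitu = jnp.zeros_like(mstar).at[central_index].add(p_merge * mstar)
    m_total = mstar + m_exsitu  # meaningful for centrals only
    frac = m_exsitu / m_total
    logsm = jnp.log10(m_total)
    # Gaussian weights in log-mass replace hard bin edges, so that a
    # galaxy's contribution to each bin varies smoothly with theta.
    w = jnp.exp(-0.5 * ((logsm[None, :] - bin_centers[:, None]) / sigma) ** 2)
    w = w * is_central[None, :]  # satellites get zero weight
    return jnp.sum(w * frac[None, :], axis=1) / jnp.sum(w, axis=1)


# Toy catalog (assumed values): centrals at indices 0 and 3.
theta = jnp.array([-0.5, 10.0, 0.9, 0.1])
nu = jnp.array([0.0, -1.0, -0.2, 0.0, -0.6])
mstar = jnp.array([1e10, 1e9, 2e9, 1e11, 1e10])
central_index = jnp.array([0, 0, 0, 3, 3])
is_central = jnp.array([True, False, False, True, False])
bin_centers = jnp.array([10.0, 11.0])

f_exsitu = exsitu_fraction(theta, nu, mstar, central_index, is_central, bin_centers)
# Jacobian with respect to theta: one row per bin, one column per parameter.
jac = jax.jacobian(exsitu_fraction)(theta, nu, mstar, central_index,
                                    is_central, bin_centers)
```

Each column of `jac` can then be plotted against `bin_centers` to check that the gradient for each merging parameter behaves as expected.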