In [11]:
import numpy as np
import jax
import jax.numpy as jnp

from collections import OrderedDict
import h5py

from diffsmhm.galhalo_models.sigmoid_smhm import (
    logsm_from_logmhalo_jax,
    DEFAULT_PARAM_VALUES as smhm_params
)

from diffsmhm.galhalo_models.sigmoid_smhm_sigma import (
    logsm_sigma_from_logmhalo_jax,
    DEFAULT_PARAM_VALUES as smhm_sigma_params
)

from diffsmhm.diff_stats.cpu.tw_kernels import (                                
    tw_kern_mstar_bin_weights_and_derivs_cpu                                    
)   

## About this Notebook

Demonstrates how we use JAX to compute a weight for each object, together with the derivatives of that weight with respect to the model parameters.
Note that this notebook uses a simplified model to make mock data generation easier.

In [34]:
# generate mock data: log halo masses drawn uniformly over a broad range
N_HALOS = 1000
LOGMHALO_MIN, LOGMHALO_MAX = 8.0, 16.0
np.random.seed(42)

loghalomass = np.random.uniform(LOGMHALO_MIN, LOGMHALO_MAX, size=N_HALOS)

# munge params: flatten each model's default-parameter mapping into an array
# (in the mapping's iteration order) so JAX can differentiate w.r.t. it.
smhm_params_array = np.array(list(smhm_params.values()), dtype=np.float32)
# The scatter model must be evaluated with its *own* defaults, not the mean
# model's. NOTE: the two arrays need not have the same length, so Jacobians
# taken w.r.t. each of them live on different parameter axes.
smhm_sigma_params_array = np.array(list(smhm_sigma_params.values()), dtype=np.float32)

# define grads with jax: forward-mode Jacobians w.r.t. the parameter array
# (argnums=1), i.e. d(output per halo) / d(params).
stellar_mass_jac = jax.jacfwd(logsm_from_logmhalo_jax, argnums=1)
stellar_mass_sigma_jac = jax.jacfwd(logsm_sigma_from_logmhalo_jax, argnums=1)

In [45]:
# compute mean and sigma of log stellar mass at each halo mass (float32 for the CPU kernel)
stellar_mass = np.array(
    logsm_from_logmhalo_jax(loghalomass, smhm_params_array), dtype=np.float32
)
stellar_mass_sigma = np.array(
    logsm_sigma_from_logmhalo_jax(loghalomass, smhm_sigma_params_array),
    dtype=np.float32,
)

# compute gradients, transposed to (n_params, n_halos)
stellar_mass_grad_own_params = np.array(
    stellar_mass_jac(loghalomass, smhm_params_array), dtype=np.float32
).T
stellar_mass_sigma_grad_own_params = np.array(
    stellar_mass_sigma_jac(loghalomass, smhm_sigma_params_array), dtype=np.float32
).T

# Place both Jacobians on a common parameter axis ordered
# [smhm params..., smhm_sigma params...]. The mean depends only on the smhm
# params and the scatter only on the sigma params, so each Jacobian is
# zero-padded over the other model's rows. Afterwards row g of both arrays is
# the derivative w.r.t. the same parameter, and the shapes match.
stellar_mass_grad = np.vstack(
    [stellar_mass_grad_own_params, np.zeros_like(stellar_mass_sigma_grad_own_params)]
)
stellar_mass_sigma_grad = np.vstack(
    [np.zeros_like(stellar_mass_grad_own_params), stellar_mass_sigma_grad_own_params]
)

In [46]:
# compute weight of each halo for the stellar-mass bin below, plus its
# derivatives w.r.t. the model parameters. `w` and `dw` are preallocated and
# presumably filled in place by the kernel (its return value is not used).

# NOTE(review): assumes the kernel combines row g of both Jacobians as the
# derivative w.r.t. the same parameter, so their shapes must agree; a mismatch
# would otherwise go unnoticed. TODO confirm against the tw_kernels source.
assert stellar_mass_sigma_grad.shape == stellar_mass_grad.shape, (
    "mean and scatter Jacobians must share a parameter axis: "
    f"{stellar_mass_grad.shape} vs {stellar_mass_sigma_grad.shape}"
)
assert stellar_mass_grad.shape[1] == len(loghalomass), "Jacobian halo axis mismatch"

# one weight per halo; one derivative row per parameter
w = np.zeros(len(loghalomass), dtype=np.float32)
dw = np.zeros((stellar_mass_grad.shape[0], len(loghalomass)), dtype=np.float32)

# lower and upper edge of the single stellar-mass bin (same units as stellar_mass)
mass_bin_edges = np.array([9.0, 10.0], dtype=np.float32)

tw_kern_mstar_bin_weights_and_derivs_cpu(
    stellar_mass,
    stellar_mass_grad,
    stellar_mass_sigma,
    stellar_mass_sigma_grad,
    mass_bin_edges[0], mass_bin_edges[1],
    w,
    dw
)