In [2]:
from functools import partial
import jax
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Power-law redshift evolution: R(z) proportional to (1 + z)**gamma

@jax.jit
def log_power_law(z, gamma):
    """Log of the power-law redshift evolution, i.e. log((1 + z)**gamma).

    ``log1p`` is used so the result stays accurate for small ``z``.
    """
    log_one_plus_z = jnp.log1p(z)
    return gamma * log_one_plus_z

# Madau-Dickinson evolution: R(z) proportional to (1 + z)**gamma / (1 + ((1 + z)/(1 + zp))**(gamma + kappa))

@jax.jit
def log_madau_dickinson(z, gamma, kappa, zp):
    """Log of the unnormalised Madau-Dickinson redshift evolution.

    Computes log[(1 + z)**gamma / (1 + ((1 + z) / (1 + zp))**(gamma + kappa))]:
    a rising (1 + z)**gamma term and a turnover around ``zp`` controlled by ``kappa``.
    """
    rising = gamma * jnp.log1p(z)
    ratio = (1. + z) / (1. + zp)
    turnover = jnp.log1p(ratio ** (gamma + kappa))
    return rising - turnover

@jax.jit
def log_madau_dickinson_norm(z, R0, gamma, kappa, zp):
    """Log of the Madau-Dickinson rate normalised so that the rate equals ``R0`` at z = 0.

    The unnormalised shape psi(z) = (1+z)**gamma / (1 + ((1+z)/(1+zp))**(gamma+kappa))
    has psi(0) = 1 / (1 + (1+zp)**(-gamma-kappa)). Adding log(1 + (1+zp)**(-gamma-kappa))
    divides out psi(0), so exp(result) at z = 0 is exactly ``R0``.
    """
    log_inv_psi0 = jnp.log1p((1. + zp) ** (-gamma - kappa))
    # The unnormalised shape is `log_madau_dickinson` (this function previously referred
    # to an undefined `_log_phiMD` and raised NameError on every call).
    return jnp.log(R0) + log_inv_psi0 + log_madau_dickinson(z, gamma, kappa, zp)

In [4]:
import logging
logger = logging.getLogger(__name__)

class dummy_rate:
    """Base rate model: a constant rate of one with no free parameters.

    Subclasses override ``__init__`` (to set ``name`` and default ``lambda_r``) and
    ``compute_rate``; they inherit the parameter handling implemented here.

    Attributes
    ----------
    name : str
        Human-readable model name used in log messages.
    lambda_r : dict
        Current parameter values: Python floats, or equal-length 1-D arrays after a
        vectorised ``update_params`` call.
    rate_keys : dict_keys
        Live view of the keys of ``lambda_r`` (the valid parameter names).
    fiducial_lambda_r : dict
        Copy of ``lambda_r`` taken by ``update_fiducial``.
    vectorize_computation : bool
        Whether the current parameters are arrays of length > 1.
    """

    def __init__(self, fiducial_lambda_r = None):
        self.name      = 'dummy rate'
        self.lambda_r  = {}
        self.rate_keys = self.lambda_r.keys()
        self.vectorize_computation = False
        self.update_fiducial(fiducial_lambda_r)
        logger.info(f'Created `{self.name}` model with fiducial parameters: {self.fiducial_lambda_r}')

    def _check_keys(self, params, arg_name):
        """Raise ValueError if ``params`` contains keys that are not model parameters."""
        invalid_keys = [k for k in params if k not in self.rate_keys]
        if invalid_keys:
            raise ValueError(f"`{arg_name}` contains invalid keys {invalid_keys}. Valid keys are {self.rate_keys}.")

    def update_fiducial(self, fiducial_lambda_r=None):
        """Override default parameters and snapshot the result as the fiducial values.

        Inherited unchanged by all subclasses.

        Parameters
        ----------
        fiducial_lambda_r : dict or None
            Mapping from parameter name to scalar value. ``None`` or an empty dict
            keeps the current values.

        Raises
        ------
        ValueError
            If a key is not a model parameter, or any value is a JAX array.
        """
        if fiducial_lambda_r:
            self._check_keys(fiducial_lambda_r, 'fiducial_lambda_r')
            if any(isinstance(v, jnp.ndarray) for v in fiducial_lambda_r.values()):
                raise ValueError("Use only `float` parameters to initialize the class.")
            self.lambda_r.update(fiducial_lambda_r)
        # Snapshot used to fill in parameters a vectorised update does not touch.
        self.fiducial_lambda_r = self.lambda_r.copy()

    def update_params(self, lambda_r):
        """Set new parameter values, optionally vectorised over many parameter points.

        Parameters
        ----------
        lambda_r : dict
            Mapping from parameter name to a scalar or 1-D array; all arrays must share
            the same length ``n``. With ``n == 1`` the values are stored as floats and
            unlisted parameters keep their current scalar value (any left-over arrays
            from an earlier vectorised call are reset to the fiducial value). With
            ``n > 1`` unlisted parameters are broadcast from their fiducial value.
            An empty dict is a no-op.

        Raises
        ------
        ValueError
            If a key is not a model parameter or the arrays differ in length.
        """
        self._check_keys(lambda_r, 'lambda_r')
        if not lambda_r:
            # Nothing to update (and no length to read below).
            return

        up_params = {k: jnp.atleast_1d(v) for k, v in lambda_r.items()}
        lengths = {v.shape[0] for v in up_params.values()}
        if len(lengths) != 1:
            raise ValueError("All parameter arrays must have the same length")
        (n_points,) = lengths

        self.vectorize_computation = n_points > 1

        if n_points == 1:
            self.lambda_r.update({k: float(v[0]) for k, v in up_params.items()})
            # A previous vectorised call may have left the other parameters as arrays,
            # which would break the scalar code path; reset them to fiducial scalars.
            for k in list(self.rate_keys):
                if k not in up_params and jnp.ndim(self.lambda_r[k]) > 0:
                    self.lambda_r[k] = self.fiducial_lambda_r[k]
        else:
            # Broadcast parameters that are not being updated to the same length,
            # filled with their fiducial values.
            non_up_params = {k: jnp.full(n_points, self.fiducial_lambda_r[k])
                             for k in self.rate_keys if k not in up_params}
            self.lambda_r.update(up_params)
            self.lambda_r.update(non_up_params)

    def compute_rate(self, z):
        """Return a rate of one at every redshift, with the same shape as ``z``."""
        return jnp.ones_like(z)
     
#######################################
    
class power_law(dummy_rate):
    """Rate evolving as (1 + z)**gamma.

    Parameters
    ----------
    fiducial_lambda_r : dict or None
        Optional overrides of the default parameters ``{'gamma': 2.7}``.
    """

    def __init__(self, fiducial_lambda_r = None):
        self.name = 'power law rate'
        self.lambda_r = {'gamma': 2.7}
        self.rate_keys = self.lambda_r.keys()
        self.vectorize_computation = False
        self.update_fiducial(fiducial_lambda_r)
        logger.info(f'Created `{self.name}` model with fiducial parameters: {self.fiducial_lambda_r}')

    def compute_rate(self, z):
        """Evaluate the power-law rate at redshift(s) ``z``.

        With vectorised parameters the output gains a leading parameter axis. A ``z``
        with fewer than 3 dimensions is shared by all parameter points; otherwise its
        leading axis is mapped alongside the parameters.
        """
        params = self.lambda_r.values()
        if not self.vectorize_computation:
            return jnp.exp(log_power_law(z, *params))
        z_axis = None if jnp.atleast_1d(z).ndim < 3 else 0
        log_rate = jax.vmap(log_power_law, in_axes=(z_axis, 0))(z, *params)
        return jnp.exp(log_rate)
    
class madau_dickinson(dummy_rate):
    """Unnormalised Madau-Dickinson rate.

    Parameters
    ----------
    fiducial_lambda_r : dict or None
        Optional overrides of the defaults ``{'gamma': 2.7, 'kappa': 3., 'zp': 2.}``.
    """

    def __init__(self, fiducial_lambda_r = None):
        self.name = 'Madau-Dickinson rate'
        self.lambda_r = {'gamma': 2.7, 'kappa': 3., 'zp': 2.}
        self.rate_keys = self.lambda_r.keys()
        self.vectorize_computation = False
        self.update_fiducial(fiducial_lambda_r)
        logger.info(f'Created `{self.name}` model with fiducial parameters: {self.fiducial_lambda_r}')

    def compute_rate(self, z):
        """Evaluate the Madau-Dickinson rate at redshift(s) ``z``.

        With vectorised parameters the output gains a leading parameter axis. A ``z``
        with fewer than 3 dimensions is shared by all parameter points; otherwise its
        leading axis is mapped alongside the parameters.
        """
        params = self.lambda_r.values()
        if not self.vectorize_computation:
            return jnp.exp(log_madau_dickinson(z, *params))
        z_axis = None if jnp.atleast_1d(z).ndim < 3 else 0
        log_rate = jax.vmap(log_madau_dickinson, in_axes=(z_axis, 0, 0, 0))(z, *params)
        return jnp.exp(log_rate)
    
class madau_dickinson_norm(dummy_rate):
    """Madau-Dickinson rate scaled to equal ``R0`` at z = 0.

    Parameters
    ----------
    fiducial_lambda_r : dict or None
        Optional overrides of the defaults
        ``{'R0': 1000, 'gamma': 2.7, 'kappa': 3., 'zp': 2.}``.
    """

    def __init__(self, fiducial_lambda_r = None):
        self.name = 'normalized Madau-Dickinson rate'
        self.lambda_r = {'R0': 1000, 'gamma': 2.7, 'kappa': 3., 'zp': 2.}
        self.rate_keys = self.lambda_r.keys()
        self.vectorize_computation = False
        self.update_fiducial(fiducial_lambda_r)
        logger.info(f'Created `{self.name}` model with fiducial parameters: {self.fiducial_lambda_r}')

    def compute_rate(self, z):
        """Evaluate the normalised Madau-Dickinson rate at redshift(s) ``z``.

        With vectorised parameters the output gains a leading parameter axis. A ``z``
        with fewer than 3 dimensions is shared by all parameter points; otherwise its
        leading axis is mapped alongside the parameters.
        """
        params = self.lambda_r.values()
        if not self.vectorize_computation:
            return jnp.exp(log_madau_dickinson_norm(z, *params))
        z_axis = None if jnp.atleast_1d(z).ndim < 3 else 0
        log_rate = jax.vmap(log_madau_dickinson_norm, in_axes=(z_axis, 0, 0, 0, 0))(z, *params)
        return jnp.exp(log_rate)

In [6]:
lambda_r = 

rate = model_rate(lambda_r, model = 'PL')


zs = jnp.linspace(0.1, 0.8, 5000)

print(rate.log_rate(zs))

%timeit rate.log_rate(zs)

NameError: name 'check_keys' is not defined