# Learnable Distributions

In [1]:
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

def dgompertz(x, eta, b):
    """This function calculates the probability density function (PDF) of the Gompertz distribution.

    Args:
        x (_type_): _description_
        eta (_type_): _description_
        b (_type_): _description_

    Returns:
        _type_: _description_
    """
    return b * eta * jnp.exp(eta + b * x - eta * jnp.exp(b * x))

def cgompertz(x, eta, b):
    """This function calculates the cumulative distribution function (CDF) of the Gompertz distribution.

    Args:
        x (_type_): _description_
        eta (_type_): _description_
        b (_type_): _description_

    Returns:
        _type_: _description_
    """
    return 1 - jnp.exp(-eta * (jnp.exp(b * x) - 1))

def qgompertz(p, eta, b):
    """This function calculates the quantile function (inverse CDF) of the Gompertz distribution.

    Args:
        p (_type_): _description_
        eta (_type_): _description_
        b (_type_): _description_

    Returns:
        _type_: _description_
    """
    return jnp.log(-jnp.log(1 - p) / eta + 1) / b

def rgompertz(key, n, eta, b):
    """This function generates random samples from the Gompertz distribution.

    Args:
        key (_type_): _description_
        n (_type_): _description_
        eta (_type_): _description_
        b (_type_): _description_

    Returns:
        _type_: _description_
    """
    u = random.uniform(key, shape=(n,))
    return jnp.log(-jnp.log(1 - u) / eta + 1) / b

In [2]:
import numpyro.distributions as dist
from numpyro.distributions import constraints

class Gompertz2(dist.Distribution):
    def __init__(self, eta, b, validate_args=None):
        self.eta = eta
        self.b = b
        batch_shape = jnp.broadcast_shapes(jnp.shape(eta), jnp.shape(b))
        super(Gompertz2, self).__init__(batch_shape=batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        return rgompertz(key, sample_shape + self.shape(), self.eta, self.b)

    def log_prob(self, value):
        return jnp.log(dgompertz(value, self.eta, self.b))

    def cdf(self, value):
        return cgompertz(value, self.eta, self.b)

    def icdf(self, q):
        return qgompertz(q, self.eta, self.b)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Define the model
def model(data=None):
    # Prior distributions for eta and b
    eta = numpyro.sample('eta', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))

    # Define the Gompertz distribution
    gompertz = Gompertz2(eta, b)

    # If data is provided, use it to compute the likelihood
    if data is not None:
        numpyro.sample('obs', gompertz, obs=data)

# Generate some data
key = random.PRNGKey(0)
eta_true = 1.0
b_true = 0.5
n_samples = 100
single_sample = rgompertz(key, 1, eta_true, b_true)
data = jnp.repeat(single_sample, n_samples)

# Run MCMC
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(key, data=data)

# Print the summary of the posterior distribution
mcmc.print_summary()


sample: 100%|██████████| 2000/2000 [00:03<00:00, 569.74it/s, 7 steps of size 1.42e-02. acc. prob=0.74]   



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b      9.14      0.42      9.23      8.50      9.77     42.25      1.02
       eta      0.00      0.00      0.00      0.00      0.00     50.16      1.02

Number of divergences: 80


In [4]:
def model2(data=None):
    # Prior distributions for eta and b
    eta = numpyro.sample('eta', dist.Normal(0, 1))
    b = numpyro.sample('b', dist.Normal(0, 1))

    # Define the Gompertz distribution
    gompertz = Gompertz(eta, b)

    # If data is provided, use it to compute the likelihood
    if data is not None:
        numpyro.sample('obs', gompertz, obs=data)

# Run MCMC
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(key, data=data)

# Print the summary of the posterior distribution
mcmc.print_summary()

sample: 100%|██████████| 2000/2000 [00:03<00:00, 582.35it/s, 7 steps of size 1.42e-02. acc. prob=0.74]  



                mean       std    median      5.0%     95.0%     n_eff     r_hat
         b      9.14      0.42      9.23      8.50      9.77     42.25      1.02
       eta      0.00      0.00      0.00      0.00      0.00     50.16      1.02

Number of divergences: 80
