In [16]:
import jax
import jax.scipy.stats.nbinom as nbinom


def compute_single_weight(
    reported_data: int, particle_estimate: float | int, r: float | int
) -> float:
    epsilon = 0.005
    weight = nbinom.logpmf(
        k=reported_data,
        n=r,
        p=r / (r + particle_estimate + epsilon),
    )
    return weight


@jax.jit
def compute_single_weight_jit(
    reported_data: int, particle_estimate: float | int, r: float | int
) -> float:
    epsilon = 0.005
    weight = nbinom.logpmf(
        k=reported_data,
        n=r,
        p=r / (r + particle_estimate + epsilon),
    )
    return weight

In [17]:
reports = jax.numpy.ones(500)
key = jax.random.PRNGKey(0)
random_integers = jax.random.randint(key, shape=(500,), minval=0, maxval=11)

In [18]:
print(reports.shape == random_integers.shape)

True


In [19]:
def regular():
    for i in range(500):
        compute_single_weight(reports[i], random_integers[i], r=10)


def jitted():
    for i in range(500):
        compute_single_weight_jit(reports[i], random_integers[i], r=10)

In [24]:
%timeit regular()

1.31 s ± 208 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
%timeit jitted()

186 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
print("JIT function is approximately", round(1.31 / 0.186, 2), "times faster.")

JIT function is approximately 7.04 times faster.
