# Prefix sum

## Definition and relevance to actuarial calculations

For an [associative binary operator](https://en.wikipedia.org/wiki/Associative_property) $\oplus$ and a sequence $x_0, x_1, x_2, \ldots$, the **prefix sum** (also called a *scan*) is the sequence $y_0, y_1, y_2, \ldots$ where

$$
\begin{align*}
y_0 &= x_0 \\
y_1 &= x_0 \oplus x_1 \\
y_2 &= x_0 \oplus x_1 \oplus x_2 \\
&\;\;\vdots
\end{align*}
$$
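The definition translates directly into a sequential loop. This minimal sketch is plain Python with the operator passed in as a function. `prefix_sum` is a hypothetical helper, not part of any library, and the standard library's `itertools.accumulate` computes the same thing. Each output is the previous output combined with the next input, so the loop takes $n-1$ dependent steps.

```python
import operator

def prefix_sum(xs, op):
    """Inclusive prefix sum: ys[i] = xs[0] op xs[1] op ... op xs[i]."""
    ys = []
    for x in xs:
        # The first element starts the running total; later ones are combined into it
        ys.append(x if not ys else op(ys[-1], x))
    return ys

print(prefix_sum([1, 2, 3, 4], operator.add))  # [1, 3, 6, 10]
print(prefix_sum([1, 2, 3, 4], operator.mul))  # [1, 2, 6, 24]
print(prefix_sum([3, 1, 4, 1, 5], max))        # [3, 3, 4, 4, 5]
```

Any associative operator works, so a running maximum is a prefix sum too.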

Let $p_x$ be the probability that a person aged $x$ survives to age $x+1$, and let ${}_np_x$ be the probability that they survive to age $x+n$. The $n$-year survival probabilities are then a prefix sum with $\oplus$ taken to be multiplication:

$$
\begin{align*}
{}_1p_x &= p_x \\
{}_2p_x &= p_x \cdot p_{x+1} \\
{}_3p_x &= p_x \cdot p_{x+1} \cdot p_{x+2} \\
&\;\;\vdots \\
{}_np_x &= \prod_{k=0}^{n-1} p_{x+k}
\end{align*}
$$
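Numerically this is a cumulative product. A minimal sketch follows. The values in `p` are made up for illustration and are not taken from any mortality table.

```python
import jax.numpy as jnp

# Hypothetical one-year survival probabilities p_x, p_{x+1}, ..., p_{x+4}
p = jnp.array([0.99, 0.98, 0.97, 0.96, 0.95])

# n-year survival probabilities: npx[n - 1] = p_x * p_{x+1} * ... * p_{x+n-1}
npx = jnp.cumprod(p)
```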


## Parallelism on the GPU

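[Prefix sums can be computed in parallel](http://courses.csail.mit.edu/18.337/2004/book/Lecture_03-Parallel_Prefix.pdf). Because $\oplus$ is associative, partial results can be combined in any grouping, so an $n$-element scan needs only $O(\log n)$ parallel steps instead of $n-1$ sequential ones. JAX provides a dedicated function for this, [`jax.lax.associative_scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html#jax.lax.associative_scan).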
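To see why associativity buys parallelism, here is a sketch of the Hillis-Steele scan, one of several parallel-prefix algorithms. It is only an illustration. It is not how `jax.lax.associative_scan` is implemented internally, and `hillis_steele_scan` is a hypothetical name. At each step, every element at position $i \ge$ `offset` is combined with the element `offset` positions earlier. All of those combinations are independent, so each of the $\lceil \log_2 n \rceil$ steps is a single vectorized operation.

```python
import jax
import jax.numpy as jnp

def hillis_steele_scan(op, x):
    """Inclusive scan along axis 0 in ceil(log2(n)) vectorized steps.

    Requires `op` to be associative. Each step reads only the previous
    step's array, so all updates within a step can run in parallel.
    """
    n = x.shape[0]
    y = x
    offset = 1
    while offset < n:
        # y[i] <- y[i - offset] op y[i] for every i >= offset, all at once
        y = y.at[offset:].set(op(y[:-offset], y[offset:]))
        offset *= 2
    return y

# Hypothetical one-year survival probabilities
p = jnp.array([0.99, 0.98, 0.97, 0.96, 0.95])
scanned = hillis_steele_scan(jnp.multiply, p)
reference = jax.lax.associative_scan(jnp.multiply, p)  # same values
```

This variant does $O(n \log n)$ total work, compared with $O(n)$ for the sequential loop. Work-efficient algorithms such as Blelloch's bring the total back down to $O(n)$ while keeping logarithmic depth.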

We have seen that survival probabilities are calculated with a cumulative product. [`jax.numpy.cumprod`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.cumprod.html) is the usual way to compute a cumulative product, and it [uses the associative scan in its implementation](https://github.com/google/jax/blob/main/jax/_src/lax/control_flow/loops.py#L1950).
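A quick check that the two routes agree follows. It uses a small batch of made-up survival probabilities, with one row per table and one column per age. Both `jnp.cumprod` and `jax.lax.associative_scan` accept an `axis` argument, so a whole batch of tables can be scanned in one call. The jitted wrapper `survival_curves` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-year survival probabilities: 2 tables (rows) x 4 ages (columns)
p = jnp.array([[0.99, 0.98, 0.97, 0.96],
               [0.995, 0.99, 0.985, 0.98]])

@jax.jit
def survival_curves(p):
    # n-year survival probabilities along the age axis
    return jnp.cumprod(p, axis=1)

via_cumprod = survival_curves(p)
via_scan = jax.lax.associative_scan(jnp.multiply, p, axis=1)
```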

JAX also has a general, sequential scan, [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), which is useful for carrying state through a loop. Because it makes no associativity assumption about the step function, its iterations [cannot be parallelized](https://github.com/google/jax/discussions/10233).
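For contrast, here is the same cumulative product written with `jax.lax.scan`. The step function receives the carried survival probability and the next $p_{x+t}$. It returns the new carry together with the value to record for that step. The iterations run one after another, which is why `lax.scan` is flexible (the step need not be associative) but sequential. The data and the name `step` are hypothetical.

```python
import jax
import jax.numpy as jnp

def step(surv, p_t):
    """One year of the loop; surv is the probability of surviving to the start of the year."""
    surv = surv * p_t
    return surv, surv  # (new carry, value stacked into the output)

# Hypothetical one-year survival probabilities
p = jnp.array([0.99, 0.98, 0.97, 0.96, 0.95])

# Start from certain survival (probability 1) at age x
final, npx = jax.lax.scan(step, jnp.ones((), dtype=p.dtype), p)
# npx matches jnp.cumprod(p); final is the 5-year survival probability
```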


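```python
import jax.numpy as jnp
from pymort import getIdGroup, MortXML

# Table IDs in the pymort group identified by 3299
ids = getIdGroup(3299).ids
# One entry per table: Tables[0] holds the select rates, Tables[1] the ultimate rates
select = jnp.array([MortXML(table_id).Tables[0].Values.unstack().values for table_id in ids])
ultimate = jnp.array([MortXML(table_id).Tables[1].Values.unstack().values for table_id in ids])
```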

