# Prefix sum

## Definition and relevance to actuarial calculations

For a [binary associative operator](https://en.wikipedia.org/wiki/Associative_property) $\oplus$ and a sequence of numbers $x_0, x_1, x_2, \ldots$, the **prefix sum** is the sequence of numbers $y_0, y_1, y_2, \ldots$ where

$$
\begin{align*}
y_0 &= x_0 \\
y_1 &= x_0 \oplus x_1 \\
y_2 &= x_0 \oplus x_1 \oplus x_2 \\
&\ \,\vdots
\end{align*}
$$
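
As a point of reference, the definition can be written as a plain sequential loop. This is a minimal sketch rather than a tuned implementation; `prefix_scan` is a hypothetical helper name introduced here for illustration, and the inputs are made-up integers.

```python
from operator import add, mul


def prefix_scan(op, xs):
    """Sequential inclusive prefix sum: y_i = x_0 (+) x_1 (+) ... (+) x_i."""
    ys = []
    for x in xs:
        # The first output is x_0 itself; every later output extends the previous one.
        ys.append(x if not ys else op(ys[-1], x))
    return ys


sums = prefix_scan(add, [1, 2, 3, 4])      # [1, 3, 6, 10]
products = prefix_scan(mul, [1, 2, 3, 4])  # [1, 2, 6, 24]
```

Each output reuses the one before it, which is why the computation looks inherently serial at first glance.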

An actuarial example: when the operator is multiplication and the sequence is the one-year survival probabilities $p_x, p_{x+1}, p_{x+2}, \ldots$, the prefix sum gives the multi-year survival probabilities:

$$
\begin{align*}
{}_1p_x &= p_x \\
{}_2p_x &= p_x \cdot p_{x+1} \\
{}_3p_x &= p_x \cdot p_{x+1} \cdot p_{x+2} \\
&\ \,\vdots
\end{align*}
$$
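
The same calculation can be sketched for a single life with the standard library. The mortality rates below are hypothetical values chosen only to make the arithmetic easy to follow, and `t_p_x` is an illustrative variable name.

```python
from itertools import accumulate
from operator import mul

# Hypothetical one-year mortality rates q_x, q_{x+1}, q_{x+2}.
q = [0.1, 0.2, 0.3]

# One-year survival probabilities p_{x+k} = 1 - q_{x+k}.
p = [1.0 - qi for qi in q]

# Running product: element t-1 holds the t-year survival probability from age x.
t_p_x = list(accumulate(p, mul))
# Approximately [0.9, 0.72, 0.504], up to floating-point rounding.
```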

## Parallelism on the GPU

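It is not obvious that parallelism helps in calculating prefix sums, since each $y_i$ appears to depend on $y_{i-1}$. Because $\oplus$ is associative, however, partial results can be grouped in any way, and a [parallel prefix algorithm](http://courses.csail.mit.edu/18.337/2004/book/Lecture_03-Parallel_Prefix.pdf) can produce all $n$ outputs in $O(\log n)$ parallel steps.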



Which operations happen in parallel and which don't?
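
One way to see the answer is a Hillis-Steele-style scan, sketched below in plain Python. This is an illustrative textbook variant and not necessarily the algorithm `associative_scan` uses internally; `hillis_steele_scan` is a hypothetical helper name. Within a round, every combination reads only values from the previous round, so those combinations are independent and could run in parallel. The rounds themselves must run one after another.

```python
from operator import add


def hillis_steele_scan(op, xs):
    """Inclusive scan in ceil(log2(n)) rounds; returns (result, number_of_rounds)."""
    ys = list(xs)
    n = len(ys)
    offset = 1
    rounds = 0
    while offset < n:
        # Parallel part: each new ys[i] depends only on the previous round's values,
        # so all n entries of this list could be computed at the same time.
        # The left operand comes first, so non-commutative operators stay in order.
        ys = [ys[i] if i < offset else op(ys[i - offset], ys[i]) for i in range(n)]
        # Sequential part: the next round needs this round's complete result.
        offset *= 2
        rounds += 1
    return ys, rounds


result, rounds = hillis_steele_scan(add, [1, 2, 3, 4, 5, 6, 7, 8])
# result == [1, 3, 6, 10, 15, 21, 28, 36], rounds == 3
```

For eight inputs the sequential loop needs seven dependent steps, while this variant needs three rounds, at the cost of more total operations.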

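```python
import jax.numpy as jnp
from jax.lax import associative_scan
from jaxtyping import Array, Float

# associative_scan calls this on slices of `p` taken along axis 0,
# so both arguments carry a leading axis in front of the policy axis.
def survivorship(
    p1: Float[Array, "... policies"], p2: Float[Array, "... policies"]
) -> Float[Array, "... policies"]:
    return p1 * p2

# One-year mortality rates: rows are successive years, columns are policies.
q = jnp.array([[0.1, 0.1], [0.2, 0.2]])
p = 1 - q  # one-year survival probabilities
associative_scan(survivorship, p)  # cumulative product down each column
```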


```
DeviceArray([[0.9       , 0.9       ],
             [0.71999997, 0.71999997]], dtype=float32)
```