In [1]:
!pip install pymort equinox jaxtyping

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pymort
  Downloading pymort-0.6.0-py3-none-any.whl (12.5 MB)
[K     |████████████████████████████████| 12.5 MB 315 kB/s 
[?25hCollecting equinox
  Downloading equinox-0.6.0-py3-none-any.whl (66 kB)
[K     |████████████████████████████████| 66 kB 1.7 MB/s 
[?25hCollecting jaxtyping
  Downloading jaxtyping-0.1.0-py3-none-any.whl (16 kB)
Collecting typeguard>=2.13.3
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Collecting typing-extensions
  Downloading typing_extensions-4.3.0-py3-none-any.whl (25 kB)
Installing collected packages: typing-extensions, typeguard, pymort, jaxtyping, equinox
  Attempting uninstall: typing-extensions
    Found existing installation: typing-extensions 4.1.1
    Uninstalling typing-extensions-4.1.1:
      Successfully uninstalled typing-extensions-4.1.1
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.7.1
    U

# Assumption tables as Python classes

Before discussing parallel computing, we have to set up our modelpoints and assumption tables.

* Assumption tables can be represented as classes, I use Equinox for working with JAX. 
  * There is [a paper](https://arxiv.org/pdf/2111.00254.pdf) that explains the programming model.
  * I use Equinox as a syntactic sugar for partial functions, where the fixed parameters of the partial function are accessible in an object oriented manner.
* Rate tables are stored as attributes. 
* Rate lookups are instance methods.



In [2]:
import jax.numpy as jnp
import jax
from pymort import getIdGroup, MortXML
import equinox as eqx
from jaxtyping import f, i # declaring what the axes are is nice

class GetQ(eqx.Module):
  """
  Select-and-ultimate mortality rate lookup over a collection of pymort tables.

  Attributes:
    select: select rates stacked over tables, indexed [table_index, issue_age, duration].
    ultimate: ultimate rates stacked over tables, indexed [table_index, attained_age].
  """
  select: f["table_index issue_age duration"]
  ultimate: f["table_index attained_age"]

  def __init__(self, collection_id: int):
    """Load every table in the pymort id group `collection_id` into dense arrays."""
    # Parse each table's XML once and reuse it for both the select and the ultimate part.
    tables = [MortXML(table_id).Tables for table_id in getIdGroup(collection_id).ids]
    self.select = jnp.array([t[0].Values.unstack().values for t in tables])
    self.ultimate = jnp.array([t[1].Values.unstack().values for t in tables])

  def __call__(self, table_index: i["policies"], issue_age: i["policies"], duration_projected: i["timesteps policies"]) -> f["timesteps policies"]:
    """
    Return the mortality rate for each policy at each projected duration.

    Args:
      table_index: per-policy index into the first axis of `select` / `ultimate`.
      issue_age: per-policy issue age in years.
      duration_projected: zero-based policy duration for every (timestep, policy).

    Returns:
      The select rate while the duration is inside the select period
      (`self.select.shape[-1]`), the ultimate rate afterwards.
    """
    # The age axes are 0-indexed, so ages are shifted down by 18.
    # NOTE(review): presumably 18 is the youngest age in these tables - confirm.
    age_offset = issue_age - 18
    # jnp.where evaluates both lookups for every element.
    # NOTE(review): the unused branch can index past the end of its table; this relies on
    # JAX clamping out-of-range indices instead of raising - confirm.
    return jnp.where(
        duration_projected < self.select.shape[-1],
        self.select[table_index, age_offset, duration_projected],
        self.ultimate[table_index, age_offset + duration_projected],
    )

# 3299 is a pymort collection id; getIdGroup(3299) reports it as the 2017_CSO
# "loaded preferred_structure gender_distinct ANB" group of 10 tables (ids 3299-3308).
get_q = GetQ(3299)



In [9]:
# Display the study, grouping, table ids, genders and risk classes behind collection 3299.
getIdGroup(3299)

IdGroup(study='2017_CSO', grouping='loaded preferred_structure gender_distinct ANB', ids=(3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308), genders=('male', 'male', 'male', 'female', 'female', 'female', 'male', 'male', 'female', 'female'), risks=('nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'smoker preferred', 'smoker residual', 'smoker preferred', 'smoker residual'))

Define another module to represent our modelpoints.



In [None]:
class ModelPoints(eqx.Module):
  """Columnar policy data: every field holds one entry per policy."""
  mortality_table_index: i["policies"]
  issue_age: i["policies"]
  duration: i["policies"]
  face: i["policies"]

  def projected_q(self, timesteps: int) -> f["timesteps policies"]:
    """
    Look up the mortality rate q of every policy at each of `timesteps` projection steps.

    Row t of the result uses each policy's current duration advanced by t.
    The lookup goes through the notebook-level `get_q` instance.
    """
    # Shape (timesteps, 1) so that adding the per-policy durations broadcasts to (timesteps, policies).
    steps = jnp.expand_dims(jnp.arange(timesteps), axis=1)
    return get_q(self.mortality_table_index, self.issue_age, steps + self.duration)

Make up some data and construct our classes.

In [None]:
# Every combination of mortality table index in [0, 10), duration in [0, 1) and issue age in [18, 51):
# 10 * 1 * (51 - 18) = 330 modelpoints in total, all starting at duration 0.
# meshgrid works like a cartesian product; ravel turns each coordinate grid into a flat per-policy column.
mortality_table_index, duration, issue_age = (
    grid.ravel()
    for grid in jnp.meshgrid(jnp.arange(10), jnp.arange(1), jnp.arange(18, 51))
)
# every policy gets the same face value
face = 1_000_000 * jnp.ones_like(issue_age)
mp = ModelPoints(mortality_table_index, issue_age, duration, face)

Now the discussion of parallel computing.

# Parallel Prefix Sum

## Prefix sum definition and relevance to actuarial calculations

For a [binary associative operator](https://en.wikipedia.org/wiki/Associative_property) $\oplus$ and a sequence of numbers $x_0,x_1,x_2,...$ the **prefix sum** is a sequence of numbers $y_0, y_1, y_2, ...$ where

$$
\begin{align*}
y_0 &= x_0 \\
y_1 &= x_0 \oplus x_1 \\
y_2 &= x_0 \oplus x_1 \oplus x_2 \\
...
\end{align*}
$$

Let $p_x$ be the probability that a person age $x$ survives to age $x+1$ and $_np_x$ be the probability that they survive to age $x+n$. An example of a prefix sum in actuarial science is the following.

$$
\begin{align*}
p_x &= p_x \\
_2p_x &= p_x \cdot p_{x+1} \\
_3p_x &= p_x \cdot p_{x+1} \cdot p_{x+2} \\
...
\end{align*}
$$


## Parallelism on the GPU

[Prefix sums can be computed in parallel](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.128.6230&rep=rep1&type=pdf). JAX has a special method for this, [`jax.lax.associative_scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.associative_scan.html#jax.lax.associative_scan). 

We have seen that survival probabilities are calculated with a cumulative product. [`jax.numpy.cumprod`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.cumprod.html) is the typical way to take the cumulative product and [uses the associative scan in its implementation](https://github.com/google/jax/blob/main/jax/_src/lax/control_flow/loops.py#L1950).


In [None]:
timesteps = 20
q = mp.projected_q(timesteps)
# npx[t] is the probability of still being alive at the start of year t: a leading row of
# ones followed by the running product of (1 - q), shifted down by one year.
# jnp.cumprod uses parallel prefix sum
npx = jnp.concatenate(
    [
        jnp.ones((1, q.shape[1])),
        jnp.cumprod(1 - q, axis=0)[:-1],
    ]
)

We can calculate the reserve at each point in time using a cumulative sum in reverse, another prefix sum operation that can be parallelized on the time dimension. Reserves for this term product can be entirely parallelized along the time dimension, with no strong data dependencies from one time step to the next.

In [None]:
# Expected claim per year: alive at the start of the year, then death during it, times face.
expected_claims = npx * q * mp.face
v = 1 / 1.02  # one-year discount factor
# boy is "beginning of year", eoy is "end of year"
discounts_boy: f["timesteps 1"] = v ** jnp.expand_dims(jnp.arange(timesteps), axis=1)
discounts_eoy: f["timesteps 1"] = discounts_boy * v
# Level premium per policy: discounted expected claims divided by the discounted survival
# probabilities, so expected premiums and expected claims have equal value at time 0.
annual_premium = (expected_claims * discounts_eoy).sum(axis=0) / (npx * discounts_boy).sum(axis=0)
expected_premiums = npx * annual_premium

def reserves(expected_claims: f["timesteps policies"], expected_premiums: f["timesteps policies"], v: float):
  discounts_boy: f["timesteps 1"] = v ** jnp.arange(timesteps)[:, jnp.newaxis]
  discounts_eoy: f["timesteps 1"] = v * discounts_boy
  discounted_expected_claims = expected_claims * discounts_eoy
  discounted_expected_premiums = expected_premiums * discounts_boy
  net_cashflows = discounted_expected_claims - discounted_expected_premiums
  reserves = jax.lax.cumsum(net_cashflows, reverse=True)
  return jnp.sum(reserves, axis=1)

# annual_premium is set so that discounted expected premiums equal discounted expected claims,
# so the t=0 value should be exactly zero; the printed -.04 is accumulated floating point
# error, tiny next to the other entries (order 1e5 to 1e6).
print(reserves(expected_claims, expected_premiums, v))

[-4.2449951e-02  5.2286897e+05  9.9272012e+05  1.4035186e+06
  1.7658292e+06  2.0858644e+06  2.3629015e+06  2.5952758e+06
  2.7802620e+06  2.9128385e+06  2.9866622e+06  3.0035835e+06
  2.9611145e+06  2.8515672e+06  2.6744210e+06  2.4279880e+06
  2.1049722e+06  1.7013820e+06  1.2181765e+06  6.5228344e+05]


## When parallelism is complicated

[This paper](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.84.2724&rep=rep1&type=pdf) presents the FORA algorithm which solves the recurrence relation $Y_k = a_kY_{k-1} + X_k$, equation 1.23 from [Computation and Modelling in Insurance and Finance](https://www.amazon.com/Computation-Modelling-Insurance-International-Actuarial/dp/0521830486). This algorithm converts the recurrence into a prefix sum operation. In JAX discussions on GitHub [someone has implemented a linear recurrence](https://github.com/google/jax/discussions/9856) and reports performance problems which I haven't yet made heads or tails of. [This stackoverflow answer](https://stackoverflow.com/questions/70085324/cuda-force-instruction-execution-order) gives an implementation in C++. Also [this reference which I already gave](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.128.6230&rep=rep1&type=pdf) covers the FORA algorithm.

In idiomatic JAX, this recurrence is solved with [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) which is not done in parallel and is reportedly ["slower than expected"](https://github.com/google/jax/issues/2491) on the GPU with some optimism from the JAX team surrounding future improvements on GPU. This discussion of "slow" vs "fast" is meaningless until we have reproducible benchmarks that are representative of industry workloads, something I'd like to work on.

Theory aside, here's an example to distinguish when we can vs. can't avoid `jax.lax.scan` (it's faster to avoid it).

## Bank accounts and UL

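We put an amount $p_t$ into a bank account every year for 100 years and accumulate it with interest (the code below uses 10 timesteps to keep the output short). Let the account value at time $t$ be $AV_t$ and the interest rate credited in year $t$ be $i_t$. The recurrence is then $AV_t = AV_{t-1} \cdot (1+i_t) + p_t$.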

Prefix sums can easily calculate the **present value** of the account value at points in time.



In [None]:
key = jax.random.PRNGKey(0)
timesteps = 10
# one integer payment in [1, 10) per timestep
payments = jax.random.randint(key, shape=(timesteps,), minval=1, maxval=10)
v = 1 / 1.02
discounts = v ** jnp.arange(timesteps)
discounted_cashflows = discounts * payments
# A running total of discounted payments is a prefix sum: it gives the
# **present value** of the account value at every point in time.
present_value_of_future_account_values = jnp.cumsum(discounted_cashflows)
present_value_of_future_account_values

DeviceArray([ 9.      , 10.960784, 19.611301, 21.495947, 24.267483,
             25.173214, 28.7251  , 33.0779  , 34.78488 , 38.1319  ],            dtype=float32, weak_type=True)

It feels like the idiomatic thing for the **current value** of account values at points in time is the non-parallel `jax.lax.scan` - https://ericmjl.github.io/dl-workshop/02-jax-idioms/02-loopy-carry.html. For a constant interest rate - 

In [None]:
# Notice that it would be a pain to incorporate a non-constant interest rate into the scan operation.
def accumulate_account_value(res, payment):
  """Scan step: grow the carried account value by one period at the constant rate 1/v, then add `payment`."""
  grown = res * (v ** -1)
  balance = grown + payment
  return balance, balance  # (carry passed to the next step, value recorded for this step)

# Start from an empty account and step through the payments one at a time.
result_init = 0
final, result = jax.lax.scan(accumulate_account_value, init=result_init, xs=payments)
result

DeviceArray([ 9.      , 11.18    , 20.4036  , 22.811672, 26.267904,
             27.793262, 32.34913 , 37.99611 , 40.75603 , 45.57115 ],            dtype=float32)

If interest rates are not constant, we pass a Pytree in as the `xs` argument of scan. Equinox modules are registered as Pytrees, so this should work.

In [None]:
# Derive the sub-key without rebinding `key`: with `key, subkey = split(key)` every re-run of
# this cell advanced `key` and silently produced different rates. The sub-key taken here is
# the same one the original unpacking yields on its first run.
subkey = jax.random.split(key)[1]
# random annual growth factors between 1.01 and 1.09
rates = 1 + jax.random.randint(subkey, (timesteps, ), 1, 10) / 100

class RatesPayments(eqx.Module):
  """Per-timestep inputs of the account value recurrence, bundled so `jax.lax.scan` can slice them together."""
  # growth factor applied to the previous account value at each timestep
  rates: f["timesteps"]
  # amount paid in at each timestep; integer-valued because the payments come from `randint`
  payments: i["timesteps"]

rp = RatesPayments(rates=rates, payments=payments)

def accumulate_account_value2(res, rp: RatesPayments):
  """Scan step: grow the carried account value at this timestep's rate, then add this timestep's payment."""
  # Inside scan, `rp` is one timestep's slice of the module, not the notebook-level `rp`.
  balance = res * rp.rates + rp.payments
  return balance, balance

# The module is passed whole as `xs`; scan slices its array fields along the leading axis.
result_init = 0
final, result = jax.lax.scan(accumulate_account_value2, init=result_init, xs=rp)
# yay it works
result

DeviceArray([ 9.      , 11.539999, 20.6554  , 23.481615, 28.360146,
             30.77815 , 35.701492, 42.486565, 45.336296, 50.696384],            dtype=float32)

Some product mechanics depend on the current account value, so it is important to be able to scan. Even though the accumulation of an account value is a first order linear recurrence, idiomatic JAX uses scan which is more general but does not execute in parallel over the time dimension. The GPU may be fully utilized if enough accounts are being projected in parallel... here is the associative scan implemented for a first order linear recurrence.

In [None]:
# One (rate, payment) row per timestep, with a single all-zero row padded on at the front.
rp_array = jnp.pad(jnp.stack([rp.rates, rp.payments], axis=1), ((1, 0), (0, 0)))

DeviceArray([[0.  , 0.  ],
             [1.09, 9.  ],
             [1.06, 2.  ],
             [1.01, 9.  ],
             [1.04, 2.  ],
             [1.08, 3.  ],
             [1.05, 1.  ],
             [1.03, 4.  ],
             [1.05, 5.  ],
             [1.02, 2.  ],
             [1.03, 4.  ]], dtype=float32)

In [None]:
import jax
import jax.numpy as jnp

# Rebuild the inputs from a fixed seed so this cell runs on its own.
key = jax.random.PRNGKey(0)
timesteps = 10
payments = jax.random.randint(key, shape=(timesteps,), minval=1, maxval=10)
key, subkey = jax.random.split(key)
rates = 1 + jax.random.randint(subkey, shape=(timesteps,), minval=1, maxval=10) / 100
# One (rate, payment) row per timestep, with a single all-zero row padded on at the front.
rp_array = jnp.pad(jnp.stack([rates, payments], axis=1), ((1, 0), (0, 0)))

def scan_operator(ci, cj):
    """Combine two steps of the linear recurrence y -> A * y + b into a single step.

    Each argument holds (A, b) at positions 0 and 1. Applying step `ci` and then
    step `cj` gives A = Ai * Aj and b = Aj * bi + bj. Usable with scan and
    associative scan for a recurrence with a diagonal transition matrix.
    """
    a_i, b_i = ci[0], ci[1]
    a_j, b_j = cj[0], cj[1]
    return jnp.stack([a_i * a_j, a_j * b_i + b_j])

# associative_scan hands the operator whole slices of rows, so map it over the leading axis.
parallel_scan_operator = jax.vmap(scan_operator, in_axes=0, out_axes=0)

# Column 1 of the result holds the accumulated account values.
jax.lax.associative_scan(parallel_scan_operator, rp_array, axis=0)

DeviceArray([[ 0.      ,  0.      ],
             [ 0.      ,  9.      ],
             [ 0.      , 11.09    ],
             [ 0.      , 20.6445  ],
             [ 0.      , 23.05739 ],
             [ 0.      , 27.67141 ],
             [ 0.      , 30.05498 ],
             [ 0.      , 36.15883 ],
             [ 0.      , 41.520416],
             [ 0.      , 45.596436],
             [ 0.      , 51.42029 ]], dtype=float32)

## Things that are worth trying

* Meaningful performance testing benchmarks
* Hardware accelerated first order linear recurrences that have a good developer experience (iff the speedup seems significant).