In [1]:
import jax
from jax import random
import jax.numpy as jnp
import numpy as onp
import pymort

This gives us the mortality table as a JAX array.

In [9]:
maleXML = pymort.load("2017 Loaded CSO Composite Male ALB  - t3289")
femaleXML = pymort.load("2017 Loaded CSO Composite Female ALB  - t3290")
maleRates = jnp.array(onp.squeeze(maleXML.Tables[1].Values.to_numpy()))
femaleRates = jnp.array(onp.squeeze(femaleXML.Tables[1].Values.to_numpy()))
# this is what we want
mortalityRates = jnp.array([maleRates, femaleRates])

Now we generate the model points.

In [13]:
key = random.PRNGKey(0)
key, *subkeys = random.split(key, 10)

model_points = 1000

pols_inforce = jnp.ones(model_points)
issue_age = random.randint(subkeys[0], (model_points,), 35, 55)
gender = random.choice(key, a=jnp.array([0, 1]), shape=(model_points,), p=jnp.array([0.2, 0.8]))
face = random.randint(subkeys[4], (model_points,), 1, 10) * 200_000

This function calculates the present value of death benefits.

In [14]:
def getPVDB(interest_rate, mortalityRates, issue_age, gender, face):
    PVDB = 0
    pols_inforce = jnp.ones(model_points)
    for t in range(120):
        rates = mortalityRates[gender, issue_age + t]
        PVDB += jnp.sum(pols_inforce * rates * face * pow(1 + interest_rate, -(t+1)))
        pols_inforce -= pols_inforce * rates
    return PVDB

We can calculate the present value of death benefits.

In [19]:
inputs = [.06, mortalityRates, issue_age, gender, face]
getPVDB(*inputs)

DeviceArray(1.3774626e+08, dtype=float32)

What is novel is that you can calculate the derivative with respect to assumptions.

In [26]:
getPVDB_grad = jax.value_and_grad(getPVDB, argnums=(0,1))
PVDB, (d_di, d_dmort) = getPVDB_grad(*inputs)
print('d_di', d_di)

d_di -3722314800.0


Supposing you have a high dimensional space of assumptions/margins you might even use gradient descent to set the "most adverse" assumption.