# Mortality Tables in Python on the GPU

This article shows how to:

* Use mortality tables in Python
* Run actuarial models on the GPU




In [50]:
# install dependencies
!pip install pymort==0.4.1 xarray jax



I was allocated a Tesla T4 for this run. I usually get a P100, which is faster.

In [51]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-3dfb8f62-d320-e541-afa6-fbe8c79ece36)


## Mortality tables in Python

The Society of Actuaries (SOA) hosts a collection of thousands of mortality tables at [mort.soa.org](https://mort.soa.org/). I developed the [Pymort](https://github.com/actuarialopensource/pymort) package to make working with the tables easy.

Instead of working with over **3000 separate files** given by the SOA, Pymort provides all of the information you need in **3 normalized tables** related by [primary/foreign keys](https://www.ibm.com/docs/en/ida/9.1.1?topic=entities-primary-foreign-keys), a design taken from relational databases.

In [52]:
import pymort
# instantiate the Relational class
db = pymort.Relational()
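
To picture the relational layout, here is a minimal sketch using the standard-library `sqlite3` module. The table and column names (`metadata`, `select_rates`, `ultimate_rates`, `grouping_label`) are hypothetical and chosen only for illustration; Pymort exposes pandas objects, not SQL tables. The rows reuse values for table `3209` that appear in the outputs in this article.

```python
import sqlite3

# Hypothetical schema: one metadata table keyed by id, plus two rate tables
# that point back to it through a foreign key on id.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE metadata (
    id INTEGER PRIMARY KEY, study TEXT, grouping_label TEXT, gender TEXT, risk TEXT
);
CREATE TABLE select_rates (
    id INTEGER REFERENCES metadata(id), age INTEGER, duration INTEGER, vals REAL
);
CREATE TABLE ultimate_rates (
    id INTEGER REFERENCES metadata(id), age INTEGER, vals REAL
);
""")

conn.execute(
    "INSERT INTO metadata VALUES (?, ?, ?, ?, ?)",
    (3209, "2015_VBT", "relative_risk ALB", "female", "nonsmoker RR50"),
)
conn.executemany(
    "INSERT INTO select_rates VALUES (?, ?, ?, ?)",
    [(3209, 18, 1, 0.00018), (3209, 18, 2, 0.00018)],
)
conn.executemany(
    "INSERT INTO ultimate_rates VALUES (?, ?, ?)",
    [(3209, 18, 0.00018), (3209, 19, 0.00018)],
)

# Joining on id attaches the descriptive metadata to every rate row.
rows = conn.execute("""
    SELECT m.study, m.gender, m.risk, s.age, s.duration, s.vals
    FROM select_rates AS s
    JOIN metadata AS m ON m.id = s.id
    ORDER BY s.age, s.duration
""").fetchall()

for row in rows:
    print(row)
```

The same idea drives the `meta.join(db.select)` call used later in this article: the shared `id` is what ties a rate to its study, gender, and risk class.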

Each table has a unique identifier called an `id`. Information associated with each table includes:
* What is the name of the mortality study producing the table? (e.g. 2017 CSO vs. 2015 VBT)
* Is there a grouping that the table belongs to within the study? (e.g. unloaded preferred_structure gender_distinct ANB vs. loaded smoker_distinct gender_blended ALB)
* Gender (male vs. female)
* Risk (smoker vs. nonsmoker)

We call this information about a table the `metadata` and store it as an attribute of our `Relational` object.

In [53]:
meta = db.metadata
meta

| id | study | grouping | gender | risk |
|---|---|---|---|---|
| 3209 | 2015_VBT | relative_risk ALB | female | nonsmoker RR50 |
| 3210 | 2015_VBT | relative_risk ALB | female | nonsmoker RR60 |
| 3211 | 2015_VBT | relative_risk ALB | female | nonsmoker RR70 |
| 3212 | 2015_VBT | relative_risk ALB | female | nonsmoker RR80 |
| 3213 | 2015_VBT | relative_risk ALB | female | nonsmoker RR90 |
| ... | ... | ... | ... | ... |
| 3368 | 2017_CSO | unloaded smoker_distinct gender_distinct ANB | female | smoker |
| 3369 | 2017_CSO | unloaded smoker_distinct gender_distinct ALB | male | nonsmoker |
| 3370 | 2017_CSO | unloaded smoker_distinct gender_distinct ALB | female | nonsmoker |
| 3371 | 2017_CSO | unloaded smoker_distinct gender_distinct ALB | male | smoker |


Tables are generally either `select` or `ultimate`. Select tables depend on the issue age and on the number of years since the contract was issued. This accounts for the lower mortality rates in the early years of a contract, which result from the selective effect of medical underwriting.

In [54]:
db.select

| id | Age | Duration | vals |
|---|---|---|---|
| 3209 | 18 | 1 | 0.00018 |
| 3209 | 18 | 2 | 0.00018 |
| 3209 | 18 | 3 | 0.00017 |
| 3209 | 18 | 4 | 0.00017 |
| 3209 | 18 | 5 | 0.00017 |
| ... | ... | ... | ... |
| 3372 | 95 | 21 | 0.50000 |
| 3372 | 95 | 22 | 0.50000 |
| 3372 | 95 | 23 | 0.50000 |
| 3372 | 95 | 24 | 0.50000 |


Ultimate tables only depend on the current age of the insured and represent the mortality rates once the effects of medical underwriting have worn off.

In [55]:
db.ultimate

| id | Age | vals |
|---|---|---|
| 3209 | 18 | 0.00018 |
| 3209 | 19 | 0.00018 |
| 3209 | 20 | 0.00017 |
| 3209 | 21 | 0.00017 |
| 3209 | 22 | 0.00017 |
| ... | ... | ... |
| 3372 | 116 | 0.50000 |
| 3372 | 117 | 0.50000 |
| 3372 | 118 | 0.50000 |
| 3372 | 119 | 0.50000 |
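
To make the select/ultimate distinction concrete, here is a small plain-Python sketch. The rates and the three-year select period are invented for illustration and do not come from any SOA table; `mortality_rate` is a hypothetical helper, not part of Pymort.

```python
# Toy data, invented for illustration only.
SELECT_PERIOD = 3  # number of policy years covered by the select table

# (issue_age, duration) -> rate, with duration counted from 1
select_rates = {(40, 1): 0.0005, (40, 2): 0.0007, (40, 3): 0.0009}

# attained_age -> rate
ultimate_rates = {43: 0.0012, 44: 0.0014}


def mortality_rate(issue_age, duration):
    """Look up the rate for a policy in its `duration`-th year (1-based)."""
    if duration <= SELECT_PERIOD:
        # Inside the select period the rate depends on issue age and duration.
        return select_rates[(issue_age, duration)]
    # After the select period only the attained age matters.
    attained_age = issue_age + duration - 1
    return ultimate_rates[attained_age]


for duration in range(1, 6):
    print(duration, mortality_rate(40, duration))
```

A policy issued at age 40 uses the select rates for durations 1 to 3, then switches to the ultimate rate for attained age 43 in its fourth year.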


The tables hosted by the SOA have an associated [XML standard](https://mort.soa.org/About.aspx) which is provided to users of Pymort via the `PyXML` class. Pymort has **all the information** for **every** table.

In [56]:
xml = pymort.PyXML(3372)

xml.ContentClassification.TableName

'2017 Unloaded CSO Smoker Distinct Smoker Female ALB '

## Use with GPU

Let's convert the tables from study `2017_CSO` with grouping `loaded smoker_distinct gender_distinct ANB` to tensors on the GPU.

In [57]:
db = pymort.Relational()

# select the tables of interest from the metadata
meta = db.metadata
meta = meta[(meta.study == "2017_CSO") & (meta.grouping == "loaded smoker_distinct gender_distinct ANB")]

# join the select rate table and set the MultiIndex properly to allow for conversion to tensor
select = meta.join(db.select)
select = select.reset_index()
select = select[["gender", "risk", "Age", "Duration", "vals"]]
select = select.set_index(["gender", "risk", "Age", "Duration"])

ult = meta.join(db.ultimate)
ult = ult.reset_index()
ult = ult[["gender", "risk", "Age", "vals"]]
ult = ult.set_index(["gender", "risk", "Age"])

xarr_sel = select.to_xarray()
xarr_ult = ult.to_xarray()

# this is an "xarray", which is basically just a "named tensor"
xarr_sel


The `.to_xarray` method of the dataframe converts the MultiIndex of a Pandas dataframe to the axes of a tensor. Let's convert this xarray to a JAX DeviceArray that runs on the GPU. 

In [58]:
from jax import numpy as jnp
from jax import random
import jax

# the dimensions are from xarray. So [female, male]=[0,1] and [nonsmoker, smoker]=[0,1] for indexing
j_sel = jnp.array(xarr_sel.to_array()).squeeze()
j_ult = jnp.array(xarr_ult.to_array()).squeeze()
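
The label-to-position mapping mentioned in the comment above can be sketched in plain Python. This is only an illustration of the idea, assuming (as that comment states) that each axis is ordered by its sorted labels; the rows and rates are invented, and the `axis` helper is hypothetical.

```python
# Toy long-format rows: (gender, risk, age, rate). Rates are invented.
rows = [
    ("male", "smoker", 18, 0.004),
    ("female", "nonsmoker", 18, 0.001),
    ("male", "nonsmoker", 18, 0.002),
    ("female", "smoker", 18, 0.003),
]


def axis(labels):
    """Map each distinct label to its position along a sorted axis."""
    return {label: i for i, label in enumerate(sorted(set(labels)))}


gender_axis = axis(r[0] for r in rows)
risk_axis = axis(r[1] for r in rows)
age_axis = axis(r[2] for r in rows)

# Allocate a dense gender x risk x age array and scatter the rows into it.
dense = [[[None] * len(age_axis) for _ in risk_axis] for _ in gender_axis]
for gender, risk, age, rate in rows:
    dense[gender_axis[gender]][risk_axis[risk]][age_axis[age]] = rate

print(gender_axis)  # positions of the gender labels
print(risk_axis)    # positions of the risk labels
print(dense[gender_axis["male"]][risk_axis["smoker"]][age_axis[18]])
```

Once the labels are replaced by integer positions, a lookup is a plain array index, which is the form the GPU code below relies on.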


We simulate 10,000,000 model points.

In [60]:
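key = random.PRNGKey(0)
# split into a new key plus 9 subkeys
key, *subkeys = random.split(key, 10)

model_points = 10_000_000

# not used below; getPVDB builds its own in-force vector
pols_inforce = jnp.ones(model_points)
# issue ages 35 to 54 (maxval is exclusive)
issue_age = random.randint(subkeys[0], (model_points,), 35, 55)
# NOTE: gender and risk are both drawn with `key`, so the two samples share
# the same randomness and are not independent. Passing distinct subkeys
# (e.g. subkeys[1] and subkeys[2]) would decouple them, but would also
# change the result printed further down.
gender = random.choice(key, a=jnp.array([0, 1]), shape=(model_points,), p=jnp.array([0.5, 0.5]))
risk = random.choice(key, a=jnp.array([0, 1]), shape=(model_points,), p=jnp.array([0.8, 0.2]))
# face amounts from 200,000 to 1,800,000 in steps of 200,000
face = random.randint(subkeys[4], (model_points,), 1, 10) * 200_000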
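
In the cell above, `gender` and `risk` are both drawn with the same `key`. JAX's random functions are deterministic in the key, so both calls start from the same randomness. The plain-Python sketch below illustrates why that matters. It assumes the sampler maps a stream of uniform draws through the cumulative probabilities (inverse-CDF sampling); that is an assumption about the implementation, and `sample_index` is a hypothetical helper, not a JAX function.

```python
import random


def sample_index(u, probs):
    """Inverse-CDF sampling: first index whose cumulative probability exceeds u."""
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    return len(probs) - 1


n = 10_000
shared = random.Random(0)
uniforms = [shared.random() for _ in range(n)]

# Same uniforms for both attributes (analogous to passing the same key twice).
gender_shared = [sample_index(u, [0.5, 0.5]) for u in uniforms]
risk_shared = [sample_index(u, [0.8, 0.2]) for u in uniforms]
female_smokers_shared = sum(
    g == 0 and r == 1 for g, r in zip(gender_shared, risk_shared)
)

# A separate stream for risk (analogous to using a distinct subkey).
other = random.Random(1)
risk_indep = [sample_index(other.random(), [0.8, 0.2]) for _ in range(n)]
female_smokers_indep = sum(
    g == 0 and r == 1 for g, r in zip(gender_shared, risk_indep)
)

print("female smokers, shared stream:     ", female_smokers_shared)
print("female smokers, independent stream:", female_smokers_indep)
```

Under this assumption, a shared stream produces no female smokers at all, because a draw large enough to select "smoker" is always large enough to select "male". Independent streams remove that coupling.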



In [61]:
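def get_mortality_rate(select, ultimate, gender, risk, issue_age, t: int):
  # t is the 0-based policy year, so t=0 corresponds to Duration 1
  # be careful with the offset on the age: index 0 of the Age axis is age 18
  # jnp.where evaluates both branches, so both lookups run for every t
  return jnp.where(t < select.shape[-1],
                   select[gender, risk, issue_age-18, t],
                   ultimate[gender, risk, (issue_age-18)+t])

def getPVDB(interest_rate, select, ultimate, gender, risk, face, issue_age):
  # present value of death benefits, summed over all model points
  PVDB = 0
  # every policy starts fully in force
  pols_inforce = jnp.ones(gender.shape[0])
  for t in range(100):
    rates = get_mortality_rate(select, ultimate, gender, risk, issue_age, t)
    # expected claims in year t, discounted as if paid at the end of the year
    PVDB += jnp.sum(pols_inforce * rates * face * pow(1 + interest_rate, -(t+1)))
    # remove the expected deaths from the in-force count
    pols_inforce -= pols_inforce * rates
  return PVDB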
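
The loop in `getPVDB` can be checked by hand on a single policy. The sketch below repeats the same recursion in plain Python for one model point, using three invented annual mortality rates; `pvdb_single_policy` is a hypothetical helper written for this illustration. Each year contributes (probability of surviving to the start of the year) × (mortality rate) × (face amount) × (discount factor).

```python
def pvdb_single_policy(rates, face, interest_rate):
    """Present value of death benefits for one policy, claims paid at year end."""
    pv = 0.0
    inforce = 1.0  # probability the policy is still in force
    for t, q in enumerate(rates):
        # expected claim in year t, discounted t + 1 years
        pv += inforce * q * face * (1 + interest_rate) ** -(t + 1)
        # survivors carry over to the next year
        inforce -= inforce * q
    return pv


# Toy inputs, invented for illustration.
rates = [0.01, 0.02, 0.03]
face = 100_000
interest_rate = 0.06

result = pvdb_single_policy(rates, face, interest_rate)

# The same quantity written out term by term.
expected = face * (
    0.01 / 1.06
    + 0.99 * 0.02 / 1.06**2
    + 0.99 * 0.98 * 0.03 / 1.06**3
)

print(result, expected)
```

The vectorized version above does the same thing for all model points at once: `pols_inforce`, `rates`, and `face` are arrays, and `jnp.sum` adds up the per-policy contributions for each year.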

In [62]:
%%time
inputs = [.06, j_sel, j_ult, gender, risk, face, issue_age]
print(getPVDB(*inputs))

1264811900000.0
CPU times: user 1.48 s, sys: 258 ms, total: 1.74 s
Wall time: 1.98 s


Projecting 10,000,000 model points over 100 time steps takes about 2 seconds of wall time on the T4. To continue this discussion, we will need some infrastructure for comparing the performance of actuarial models in a reproducible way.
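
As a starting point for repeatable timing, here is a minimal sketch of a helper built on the standard library. `time_call` and its `sync` parameter are hypothetical names introduced for this illustration, not part of JAX or Pymort. It reports the first call separately from the later ones, since the first call often includes one-off setup cost.

```python
import statistics
import time


def time_call(fn, *args, repeats=5, sync=lambda result: result):
    """Time fn(*args) `repeats` times (repeats >= 2) and summarize wall times.

    `sync` is applied to each result before the clock stops, so that
    asynchronous work can be waited on.
    """
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        sync(fn(*args))
        timings.append(time.perf_counter() - start)
    return {
        "first": timings[0],
        "median_rest": statistics.median(timings[1:]),
    }


# Example with a plain Python function.
summary = time_call(sum, range(1_000_000), repeats=3)
print(summary)
```

With JAX, computations are dispatched asynchronously, so one would pass something like `sync=lambda x: x.block_until_ready()` to make sure the GPU work has finished before the timer stops. In the `%%time` cell above, the `print` call plays that role.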