In [7]:
# `jaxtuary` lives in the parent directory and is not installed as a package, so the
# parent directory is added to sys.path; without this, `import jaxtuary` fails with
# "module not found".
# Approach from: https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
import os
import sys

module_path = os.path.abspath('..')  # absolute path of the parent directory
if module_path not in sys.path:  # guard so re-running the cell does not add duplicates
    sys.path.append(module_path)
# TODO: replace this sys.path edit with an editable install of jaxtuary (e.g. `pip install -e ..`).
import jax.numpy as jnp
from pymort import getIdGroup, MortXML
from jaxtuary import get_q_from_select_ultimate, get_npx_from_q, project_duration

In this example we use the `2017_CSO` study, grouping `loaded preferred_structure gender_distinct ANB`. We can get the `IdGroup` that describes this whole collection by passing the ID of any table in the collection to `getIdGroup`.

In [14]:
# Look up the collection containing table 3299: its study, grouping, and per-table ids, genders and risks.
id_group = getIdGroup(3299)
print(id_group)

IdGroup(study='2017_CSO', grouping='loaded preferred_structure gender_distinct ANB', ids=(3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308), genders=('male', 'male', 'male', 'female', 'female', 'female', 'male', 'male', 'female', 'female'), risks=('nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'smoker preferred', 'smoker residual', 'smoker preferred', 'smoker residual'))


Load the select and ultimate mortality tables for every table in the group into JAX arrays (tensors).

In [15]:
ids = getIdGroup(3299).ids
# Parse each table's XML exactly once and reuse it for both the select table (Tables[0])
# and the ultimate table (Tables[1]); previously every XML was parsed twice.
# `table_id` is used instead of `id` to avoid shadowing the Python builtin.
mort_xmls = [MortXML(table_id) for table_id in ids]
select = jnp.array([xml.Tables[0].Values.unstack().values for xml in mort_xmls])
ultimate = jnp.array([xml.Tables[1].Values.unstack().values for xml in mort_xmls])
print(f"select.shape: {select.shape}") # tableIds [3299, 3308], issue_ages [18, 95], durations [1, 25]
print(f"ultimate.shape: {ultimate.shape}") # tableIds [3299, 3308], attained_ages [18, 120]

select.shape: (10, 78, 25)
ultimate.shape: (10, 103)


These are the model points. What is the best way to represent model points? I don't like having each attribute as a separate variable. Holding them in a Flax module might make sense, if only so they are available globally with autocomplete. Maybe another day.

In [41]:
# Model points: position i in each array describes the same policy.
# mortality_table_index presumably selects a table along the first axis of `select`/`ultimate`
# (i.e. table ids 3299, 3300, 3301 here) — TODO confirm against jaxtuary.
mortality_table_index = jnp.array([0, 1, 2])
issue_age = jnp.array([30, 40, 50])
duration = jnp.array([0, 0, 0])  # all new business
face = 1000 * jnp.array([100, 500, 250])  # face amounts: 100k, 500k, 250k

Traditional actuarial modeling techniques do calculations recursively. In contrast, we compute cashflows for all points in time simultaneously, which lets us parallelize over the time dimension on the GPU. To do this, we build an array of shape (timesteps, model points) that holds the duration of each model point at each point in time.

In [42]:
# Project the durations over 10 timesteps: row t holds every model point's duration at timestep t.
duration_projected = project_duration(duration, 10)
print(f"duration: \n {duration}")
print(f"duration projected: \n {duration_projected}")

duration: 
 [0 0 0]
duration projected: 
 [[0 0 0]
 [1 1 1]
 [2 2 2]
 [3 3 3]
 [4 4 4]
 [5 5 5]
 [6 6 6]
 [7 7 7]
 [8 8 8]
 [9 9 9]]


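Compute the expected death-claim cashflows (face amount × survival probability × mortality rate) for each timestep and model point.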

In [43]:
timesteps = 10

# Mortality rates q for every (timestep, model point). project_duration returns an
# array of shape (timesteps, model points); the per-model-point inputs
# (mortality_table_index, issue_age) are expected to broadcast against it.
q = get_q_from_select_ultimate(
  mortality_table_index,
  issue_age,
  project_duration(duration, timesteps),
  select,
  ultimate
)

# Survival probabilities derived from q. Presumably npx[t] is the probability of being
# alive at the start of timestep t — TODO confirm in jaxtuary.
npx = get_npx_from_q(q)

# Expected death claims per timestep and model point: face amount x P(alive) x q.
expected_claims = face * npx * q
expected_claims

DeviceArray([[ 15.000001,  95.      , 185.      ],
             [ 15.9976  , 174.96675 , 239.82239 ],
             [ 20.99349 , 249.86502 , 336.9265  ],
             [ 23.987522, 294.6933  , 406.25833 ],
             [ 26.979483, 339.4461  , 487.71072 ],
             [ 29.96911 , 389.09988 , 568.70844 ],
             [ 32.95613 , 433.6575  , 624.3945  ],
             [ 37.936962, 488.0628  , 689.5518  ],
             [ 42.912357, 557.2394  , 798.5357  ],
             [ 52.869232, 636.1317  , 943.3463  ]], dtype=float32)

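Compute the level net premium per policy. It is the present value of expected claims (paid at the end of each year) divided by the present value of 1 per year, payable at the start of each year to surviving policyholders.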

In [44]:
# Flat annual valuation interest rate; change here to re-price.
interest_rate = 0.02
discount_factor = 1 / (1 + interest_rate)

# v**t for t = 0..timesteps-1, shaped (timesteps, 1) so it broadcasts across model points.
discounts = discount_factor ** jnp.arange(timesteps)[:, None]
# One extra period of discounting: claims are treated as paid at the end of each timestep.
discounts_lagged = discounts * discount_factor

# Net present value (NPV) of expected outgoing claim cashflows, per model point.
discounted_expected_claims = face * npx * q * discounts_lagged
discounted_expected_claims_NPV_per_policy = jnp.sum(discounted_expected_claims, axis=0)

# Present value of a premium of 1 per timestep, paid at the start of each timestep
# (undiscounted at t = 0) weighted by npx — assumes npx[t] is survival to the start of
# timestep t (TODO confirm in jaxtuary).
premium_annuity_per_policy = jnp.sum(npx * discounts, axis=0)

# Level net premium: PV(expected claims) / PV(premium of 1 per timestep).
net_premium_per_policy = discounted_expected_claims_NPV_per_policy / premium_annuity_per_policy
net_premium_per_policy

DeviceArray([ 28.778994, 350.5083  , 508.1781  ], dtype=float32)