# Net premium calculation from scratch in JAX

Using plain JAX and a library for loading mortality tables, we calculate net premiums for a few 5-year term life policies.


```python
import jax.numpy as jnp
from pymort import getIdGroup, MortXML
```

## Loading mortality tables

Here we use the study `2017_CSO loaded preferred_structure gender_distinct ANB`. We can get the [pymort](https://github.com/actuarialopensource/pymort) object that represents this collection by referencing any of the [table ids](https://mort.soa.org/) belonging to the collection.

```python
print(getIdGroup(3299))
```

```
IdGroup(study='2017_CSO', grouping='loaded preferred_structure gender_distinct ANB', ids=(3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308), genders=('male', 'male', 'male', 'female', 'female', 'female', 'male', 'male', 'female', 'female'), risks=('nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'smoker preferred', 'smoker residual', 'smoker preferred', 'smoker residual'))
```


Load the select and ultimate mortality tables for every table id in the group into JAX arrays.

```python
ids = getIdGroup(3299).ids
# Load each table once; Tables[0] holds the select rates, Tables[1] the ultimate rates
tables = [MortXML(table_id) for table_id in ids]
select = jnp.array([t.Tables[0].Values.unstack().values for t in tables])
ultimate = jnp.array([t.Tables[1].Values.unstack().values for t in tables])
print(f"select.shape: {select.shape}")  # table ids 3299..3308, issue ages 18..95, durations 1..25
print(f"ultimate.shape: {ultimate.shape}")  # table ids 3299..3308, attained ages 18..120
```

```
select.shape: (10, 78, 25)
ultimate.shape: (10, 103)
```


## Policyholder attributes

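```python
# Positions in `ids`: 3299, 3300, 3301 = male nonsmoker super_preferred, preferred, residual
mortality_table_index = jnp.array([0, 1, 2])
issue_age = jnp.array([30, 40, 50])
duration = jnp.array([0, 0, 0])  # policy years completed; 0 = new business
face = 1000 * jnp.array([100, 500, 250])  # death benefits: 100,000, 500,000 and 250,000
```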
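The integers in `mortality_table_index` are positions along the first axis of `select` and `ultimate`, in the order of `ids`. A minimal sketch, in plain Python, that picks them by gender and risk class instead; the tuples are copied from the `getIdGroup(3299)` output above, and `table_position` and `chosen_index` are illustrative names, not part of pymort:

```python
# Fields copied from the getIdGroup(3299) output above
group_ids = (3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308)
genders = ('male', 'male', 'male', 'female', 'female', 'female',
           'male', 'male', 'female', 'female')
risks = ('nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual',
         'nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual',
         'smoker preferred', 'smoker residual', 'smoker preferred', 'smoker residual')

# Map (gender, risk class) -> position along the first axis of `select` / `ultimate`
table_position = {(g, r): i for i, (g, r) in enumerate(zip(genders, risks))}

# The three model points above, described by class instead of by raw index
classes = [('male', 'nonsmoker super_preferred'),
           ('male', 'nonsmoker preferred'),
           ('male', 'nonsmoker residual')]
chosen_index = [table_position[c] for c in classes]

print(chosen_index)                          # [0, 1, 2]
print([group_ids[i] for i in chosen_index])  # [3299, 3300, 3301]
```

The resulting list can be passed to `jnp.array` in place of the hand-written indices.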

## The time dimension

Traditional actuarial modeling techniques do calculations recursively, stepping forward one period at a time. In contrast, we compute the cashflows for all points in time at once with array operations, which allows parallelization over the time dimension on the GPU.
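To make the contrast concrete, here is a minimal sketch with made-up rates (not taken from the 2017 CSO tables): a Python loop that builds each timestep's survival probability from the previous one matches a single cumulative product along the time axis, which is the form the decrement cell below uses.

```python
import jax.numpy as jnp

# Made-up mortality rates, shape (timesteps, modelpoints)
q_demo = jnp.array([[0.001, 0.002],
                    [0.002, 0.003],
                    [0.003, 0.004]])

# Recursive style: survival to the start of timestep t is built from timestep t-1
rows = [jnp.ones(q_demo.shape[1])]
for t in range(1, q_demo.shape[0]):
    rows.append(rows[-1] * (1 - q_demo[t - 1]))
npx_loop = jnp.stack(rows)

# Array style: one cumulative product over time, shifted down one row (first row = 1)
npx_vec = jnp.concatenate(
    [jnp.ones((1, q_demo.shape[1])), jnp.cumprod(1 - q_demo, axis=0)[:-1]]
)

print(jnp.allclose(npx_loop, npx_vec))  # True
```

If a calculation genuinely needs a recursion, `jax.lax.scan` is the JAX primitive for expressing it without a Python loop.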

Take the initial `duration` vector of shape `(modelpoints,)` and turn it into a `duration_projected` matrix of shape `(timesteps, modelpoints)`, where each row represents a different timestep.

We do this with broadcasting: adding a `(timesteps, 1)` column to a `(modelpoints,)` vector produces a `(timesteps, modelpoints)` matrix. Broadcasting is explained in detail [here](https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules), including this exact problem, and `jax.numpy` follows the same rules.

```python
timesteps = 5 # The policy is a 5-year policy
print(f"duration: \n {duration}")
time_axis = jnp.arange(timesteps)[:, jnp.newaxis]
print(f"time_axis: \n {time_axis}")
duration_projected = time_axis + duration
print(f"duration_projected: \n {duration_projected}")
```

```
duration: 
 [0 0 0]
time_axis: 
 [[0]
 [1]
 [2]
 [3]
 [4]]
duration_projected: 
 [[0 0 0]
 [1 1 1]
 [2 2 2]
 [3 3 3]
 [4 4 4]]
```
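Because every policy in this example is new business, the output does not show the broadcasting at work very clearly. A minimal sketch with hypothetical in-force durations (the names `inforce_duration` and `inforce_projected` are illustrative) shows each column starting from its own duration:

```python
import jax.numpy as jnp

timesteps = 5
# Hypothetical in-force block: policies already 0, 3 and 10 years past issue
inforce_duration = jnp.array([0, 3, 10])

time_axis = jnp.arange(timesteps)[:, jnp.newaxis]  # shape (5, 1)
inforce_projected = time_axis + inforce_duration   # (5, 1) + (3,) -> (5, 3)
print(inforce_projected)
# [[ 0  3 10]
#  [ 1  4 11]
#  [ 2  5 12]
#  [ 3  6 13]
#  [ 4  7 14]]
```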


## Decrements and claims

Look up the mortality rates `q` for every timestep and model point. Then calculate `npx`, the probability that a policy in force at duration 0 survives to the beginning of each timestep. With this indexing, `npx * q` is the probability that a policy in force at duration 0 dies during that timestep, and multiplying it by the payout on death (the contract's `face` amount) gives the expected claims.
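In standard actuarial notation, as a sketch: let $[x]$ be a life selected at issue age $x$, $q_{[x]+t}$ the rate that `q` looks up for timestep $t$ (it becomes the ultimate rate $q_{x+t}$ after the select period), and $F$ the `face` amount. Row $t$ of `npx` and of `cashflows` is then:

$$
{}_0p_{[x]} = 1, \qquad {}_tp_{[x]} = \prod_{k=0}^{t-1}\left(1 - q_{[x]+k}\right), \qquad \text{cashflow}_t = F \cdot {}_tp_{[x]} \cdot q_{[x]+t}.
$$

The product is what `jnp.cumprod(1 - q, axis=0)` computes; dropping its last row and prepending a row of ones shifts it so that row $t$ covers timesteps $0$ to $t-1$.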

```python
# Use select rates while the duration is inside the select table (25 durations), ultimate after
q = jnp.where(
    duration_projected < select.shape[-1],
    select[mortality_table_index, issue_age - 18, duration_projected],  # (table, issue-age offset, duration)
    ultimate[mortality_table_index, (issue_age - 18) + duration_projected],  # (table, attained-age offset)
)

# Probability of surviving from duration 0 to the start of each timestep (first row is 1)
npx = jnp.concatenate([jnp.ones((1, q.shape[1])), jnp.cumprod(1 - q, axis=0)[:-1]])

# Expected death claims per timestep and model point
cashflows = face * npx * q

print("cashflows: \n", cashflows)
```

```
cashflows: 
 [[ 15.000001  95.       185.      ]
 [ 15.9976   174.96675  239.82239 ]
 [ 20.99349  249.86502  336.9265  ]
 [ 23.987522 294.6933   406.25833 ]
 [ 26.979483 339.4461   487.71072 ]]
```
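One subtlety in the `jnp.where` above: both branches are evaluated for every element, so if a projection ran past the 25 select durations, the select lookup would index out of bounds. NumPy would raise an `IndexError` there, while JAX's documented behavior (see the "Sharp Bits" guide) is to clamp out-of-bounds indices on retrieval, so the select branch quietly returns the last select rate and `jnp.where` discards it. A sketch on made-up toy tables showing the clamping, plus an explicit `jnp.minimum` clip that gives the same result without relying on it:

```python
import jax.numpy as jnp

# Toy tables with made-up rates: 1 table, 2 issue ages, 3 select durations
select_toy = jnp.array([[[0.10, 0.11, 0.12],
                         [0.20, 0.21, 0.22]]])
# Ultimate rates indexed by attained-age offset 0..6
ultimate_toy = jnp.array([[0.30, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36]])

table = jnp.array([0])               # one model point using table 0
age_offset = jnp.array([0])          # issue-age offset (issue_age - 18 in the notebook)
dur = jnp.arange(5)[:, jnp.newaxis]  # durations 0..4, shape (5, 1)

# Index 5 is past the last select duration (2); JAX clamps it instead of raising
print(select_toy[0, 0, 5], select_toy[0, 0, 2])  # the same value twice

n_select = select_toy.shape[-1]
q_toy = jnp.where(
    dur < n_select,
    select_toy[table, age_offset, dur],      # clamped for dur >= 3, then discarded
    ultimate_toy[table, age_offset + dur],
)

# Same result without relying on clamping: clip the duration index explicitly
q_toy_explicit = jnp.where(
    dur < n_select,
    select_toy[table, age_offset, jnp.minimum(dur, n_select - 1)],
    ultimate_toy[table, age_offset + dur],
)
print(q_toy[:, 0])  # select rates for durations 0-2, then ultimate rates 0.33 and 0.34
```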


## Discount the cashflows by broadcasting

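Broadcasting is awesome! It also handles discounting: `discounts` is a `(timesteps, 1)` column of powers of `discount_factor` that scales every model point at once. Claims are assumed to be paid at the end of the year of death, so they are discounted one extra period (`discounts_lagged`). Premiums are payable at the start of each year while the policy is in force, so the net premium is the present value of expected claims divided by the present value of a premium of 1 per year (`jnp.sum(npx * discounts, axis=0)`).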
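Written out as a sketch (with `npx` as ${}_tp_{[x]}$, `q` as $q_{[x]+t}$, `face` as $F$, `discount_factor` as $v = 1/1.02$, and $n = 5$ timesteps), the last line of the next cell applies the equivalence principle, setting the present value of premiums equal to the present value of expected claims:

$$
P \sum_{t=0}^{n-1} v^{t}\, {}_tp_{[x]} = F \sum_{t=0}^{n-1} v^{t+1}\, {}_tp_{[x]}\, q_{[x]+t}
\quad\Longrightarrow\quad
P = \frac{F \sum_{t=0}^{n-1} v^{t+1}\, {}_tp_{[x]}\, q_{[x]+t}}{\sum_{t=0}^{n-1} v^{t}\, {}_tp_{[x]}}.
$$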

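```python
discount_factor = 1 / 1.02  # annual discount factor at 2% interest
# discount_factor**t for t = 0..timesteps-1, shape (timesteps, 1): start of each year
discounts = discount_factor ** jnp.arange(timesteps)[:, jnp.newaxis]
# discount_factor**(t+1): claims are paid at the end of the year of death
discounts_lagged = discounts * discount_factor

discounted_expected_claims = face * npx * q * discounts_lagged
# Net present value (NPV) of expected death claims (outgoing cashflows) per policy
discounted_expected_claims_NPV_per_policy = jnp.sum(discounted_expected_claims, axis=0)
# Divide by the present value of a premium of 1 paid at the start of each year in force
net_premium_per_policy = discounted_expected_claims_NPV_per_policy / jnp.sum(npx * discounts, axis=0)
net_premium_per_policy
```

```
DeviceArray([ 20.070742, 224.05084 , 322.29498 ], dtype=float32)
```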
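As a sketch of where this can go, the steps above can be bundled into one function and compiled with `jax.jit`. The name `net_premium` and its signature are hypothetical, and the check at the end uses made-up flat rates rather than the CSO tables: with q = 0.01 in both years, face 1,000 and no interest, the expected claims are 10 + 9.9 = 19.9 and the premium annuity is 1 + 0.99 = 1.99, so the net premium is 10.

```python
import jax
import jax.numpy as jnp


@jax.jit
def net_premium(q, face, discount_factor):
    """Hypothetical helper bundling the steps of this notebook.

    q: mortality rates, shape (timesteps, modelpoints)
    face: death benefits, shape (modelpoints,)
    discount_factor: annual discount factor, e.g. 1 / 1.02
    """
    timesteps = q.shape[0]  # shapes are static under jit, so this is a Python int
    # Survival to the start of each timestep (first row is 1)
    npx = jnp.concatenate([jnp.ones((1, q.shape[1])), jnp.cumprod(1 - q, axis=0)[:-1]])
    # Start-of-year discount factors, shape (timesteps, 1)
    discounts = discount_factor ** jnp.arange(timesteps)[:, jnp.newaxis]
    # Claims discounted to the end of the year of death
    pv_claims = jnp.sum(face * npx * q * discounts * discount_factor, axis=0)
    # Premiums of 1 at the start of each year while in force
    pv_annuity = jnp.sum(npx * discounts, axis=0)
    return pv_claims / pv_annuity


# Made-up check: flat q = 0.01 for 2 years, face 1000, no discounting
print(net_premium(jnp.full((2, 1), 0.01), jnp.array([1000.0]), 1.0))  # approximately 10
```

Called as `net_premium(q, face, 1 / 1.02)` with the `q` and `face` arrays from the cells above, it should reproduce `net_premium_per_policy` up to float32 rounding, since it performs the same operations.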