In [7]:
# Make the parent directory importable: without this, importing the local
# `jaxtuary` package fails with "module not found" (presumably the package lives
# one directory above this notebook's working directory — confirm).
import os # https://stackoverflow.com/questions/34478398/import-local-function-from-a-module-housed-in-another-directory-with-relative-im
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
# The `jaxtuary` import below relies on the sys.path tweak above.
import jax.numpy as jnp
from pymort import getIdGroup, MortXML
# NOTE(review): get_q_from_select_ultimate is redefined later in this notebook,
# which shadows this imported version once that cell has run.
from jaxtuary import get_q_from_select_ultimate, get_npx_from_q

## Loading mortality tables

For this we use the study `2017_CSO` with the grouping `loaded preferred_structure gender_distinct ANB`. Passing the id of any table in this collection to `getIdGroup` returns an `IdGroup` describing the whole collection: its table ids, genders and risk classes.

In [8]:
# Any table id in the collection identifies the whole group; 3299 is its first table.
print(getIdGroup(3299))

IdGroup(study='2017_CSO', grouping='loaded preferred_structure gender_distinct ANB', ids=(3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308), genders=('male', 'male', 'male', 'female', 'female', 'female', 'male', 'male', 'female', 'female'), risks=('nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'smoker preferred', 'smoker residual', 'smoker preferred', 'smoker residual'))


Load the select and ultimate rates of every table in the collection into stacked `jnp` arrays (tensors).

In [9]:
# Table ids of every table in the collection that table 3299 belongs to.
ids = getIdGroup(3299).ids
# Parse each table's XML once and reuse it for both lookups below.
# `table_id` avoids shadowing the builtin `id`.
mort_tables = [MortXML(table_id).Tables for table_id in ids]
# Tables[0] holds the select rates and Tables[1] the ultimate rates; each is
# unstacked into a dense matrix and all tables are stacked along a new first axis.
select = jnp.array([tables[0].Values.unstack().values for tables in mort_tables])
ultimate = jnp.array([tables[1].Values.unstack().values for tables in mort_tables])
assert select.shape[0] == ultimate.shape[0] == len(ids), "one select and one ultimate table per id expected"
print(f"select.shape: {select.shape}") # tableIds [3299, 3308], issue_ages [18, 95], durations [1, 25]
print(f"ultimate.shape: {ultimate.shape}") # tableIds [3299, 3308], attained_ages [18, 120]



select.shape: (10, 78, 25)
ultimate.shape: (10, 103)


## Policyholder attributes

In [10]:
# Three model points (policies); every vector below has one entry per policy.
mortality_table_index = jnp.arange(3)  # rows of the first axis of `select` / `ultimate`
issue_age = jnp.array([30, 40, 50])
duration = jnp.zeros_like(issue_age)  # 0 everywhere: new business
face = 1000 * jnp.array([100, 500, 250])  # face amounts in units of 1

## The time dimension

Traditional actuarial modeling techniques do calculations recursively. In contrast, we compute cashflows for all points in time simultaneously. This allows us to parallelize over the time dimension on the GPU. 

To project the initial `duration` into an array of shape (timesteps, modelpoints), we use broadcasting. Broadcasting is explained in detail in the [NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules). That page even works through this exact problem: an outer sum of two 1-D arrays.

In [11]:
# How many years are we projecting?
# Projection horizon in years: one row per year in the projected tensors.
timesteps = 10
print(f"duration: \n {duration}")
# duration (one entry per policy) + a (timesteps, 1) column of year offsets
# broadcasts to (timesteps, modelpoints): row t is every policy's duration t years on.
duration_projected = duration + jnp.arange(timesteps)[:, jnp.newaxis]
print(f"duration_projected: \n {duration_projected}")

duration: 
 [0 0 0]
duration_projected: 
 [[0 0 0]
 [1 1 1]
 [2 2 2]
 [3 3 3]
 [4 4 4]
 [5 5 5]
 [6 6 6]
 [7 7 7]
 [8 8 8]
 [9 9 9]]


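Compute the mortality rates `q`, the survival factors `npx`, and the expected death-claim cashflows for every projection year and policy.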

In [10]:
def get_q_from_select_ultimate(
    mortality_table_index: jnp.ndarray,
    issue_age: jnp.ndarray,
    duration: jnp.ndarray,
    select: jnp.ndarray,
    ultimate: jnp.ndarray,
    min_age_select=18,
    min_age_ultimate=18,
):
    """
    Get the mortality rates q from a select/ultimate mortality table.

    The lookup is vectorised: ``mortality_table_index``, ``issue_age`` and
    ``duration`` only need to be mutually broadcastable (e.g. a projected
    ``duration`` of shape (timesteps, modelpoints) against per-policy vectors
    of shape (modelpoints,)). The result has their broadcast shape.

    Args:
        mortality_table_index: index into the first axis of ``select`` and ``ultimate``.
        issue_age: issue age of each policy. Assumed to lie inside the select
            table's issue-age range starting at ``min_age_select`` (not checked).
        duration: 0-based policy duration; 0 reads the first select column.
        select: select rates indexed as ``[table, issue_age - min_age_select, duration]``.
        ultimate: ultimate rates indexed as ``[table, attained_age - min_age_ultimate]``,
            where attained age is ``issue_age + duration``.
        min_age_select: youngest issue age in ``select``.
        min_age_ultimate: youngest attained age in ``ultimate``.

    Returns:
        Select rates while ``duration < select.shape[-1]`` (inside the select
        period), and ultimate rates at attained age ``issue_age + duration``
        afterwards. Attained ages beyond the end of the ultimate table reuse
        its last rate.
    """
    select_period = select.shape[-1]
    last_ultimate_index = ultimate.shape[-1] - 1
    # jnp.where evaluates BOTH branches for every element, so the select lookup
    # is also performed past the select period (and the ultimate lookup can run
    # past the oldest age). Clamp explicitly instead of relying on JAX's implicit
    # out-of-bounds gather clamping:
    # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing
    # The values are unchanged, and NumPy-style indexing would raise IndexError here.
    select_duration = jnp.minimum(duration, select_period - 1)
    ultimate_index = jnp.minimum(
        (issue_age - min_age_ultimate) + duration, last_ultimate_index
    )
    return jnp.where(
        duration < select_period,
        select[mortality_table_index, issue_age - min_age_select, select_duration],
        ultimate[mortality_table_index, ultimate_index],
    )

# Mortality rate for every (projection year, policy) pair: duration_projected is
# (timesteps, modelpoints) and broadcasts against the per-policy vectors.
q = get_q_from_select_ultimate(
    mortality_table_index=mortality_table_index,
    issue_age=issue_age,
    duration=duration_projected,
    select=select,
    ultimate=ultimate,
)

npx = get_npx_from_q(q)

# Expected death claims per year and policy: face amount x npx x q.
# NOTE(review): npx appears to be survival to the START of each year (first-row
# cashflows equal face * q); confirm against get_npx_from_q.
cashflows = face * npx * q

print("cashflows: \n", cashflows)

cashflows: 
 [[ 15.000001  95.       185.      ]
 [ 15.9976   174.96675  239.82239 ]
 [ 20.99349  249.86502  336.9265  ]
 [ 23.987522 294.6933   406.25833 ]
 [ 26.979483 339.4461   487.71072 ]
 [ 29.96911  389.09988  568.70844 ]
 [ 32.95613  433.6575   624.3945  ]
 [ 37.936962 488.0628   689.5518  ]
 [ 42.912357 557.2394   798.5357  ]
 [ 52.869232 636.1317   943.3463  ]]


## Discount the cashflows by broadcasting

How do we represent the discounts? We could store them as a tensor of shape (timesteps, ) but the q tensor is of shape (timesteps, modelpoints).  

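* `(timesteps, ) * (timesteps, modelpoints)` is **not broadcastable**. Broadcasting aligns trailing dimensions, so `timesteps` is compared with `modelpoints`. If the two happened to be equal, it would broadcast silently, but along the wrong axis.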
* `(timesteps, 1) * (timesteps, modelpoints)` **is broadcastable** and results in shape `(timesteps, modelpoints)`.

Understanding the [general broadcasting rule](https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules) removes the need for `from jaxtuary import project_duration`, because that helper is just a one-line broadcast:

```py
def project_duration(duration: jnp.ndarray, timesteps: int):
    return duration[None, :] + jnp.arange(timesteps)[:, None]
```

In [1]:
# Concrete shapes for the bullet points above (timesteps=10, modelpoints=3).
per_year = jnp.ones(10)                         # (timesteps,)
per_year_column = per_year[:, jnp.newaxis]      # (timesteps, 1)
per_year_per_policy = jnp.ones((10, 3))         # (timesteps, modelpoints)
# Only the (timesteps, 1) column broadcasts correctly against (timesteps, modelpoints).
print(per_year.shape, per_year_column.shape, (per_year_column * per_year_per_policy).shape)

In [44]:
# Annual discount factor v = 1 / (1 + i) at i = 2%.
discount_factor = 1/(1.02)
# v**t as a (timesteps, 1) column so it broadcasts across the model-point columns.
discounts = discount_factor ** jnp.arange(timesteps)[:, None]
# v**(t+1): claims are discounted one period further than start-of-year amounts,
# i.e. treated as paid at the end of the projection year.
discounts_lagged = discount_factor * discounts

# `cashflows` (face * npx * q from the projection cell) discounted back to time 0.
discounted_expected_claims = cashflows * discounts_lagged
# Net present value (NPV) of outgoing claim cashflows, one value per policy.
discounted_expected_claims_NPV_per_policy = discounted_expected_claims.sum(axis=0)
# Level net premium = claims NPV / present value of 1 paid at the start of every
# year the policy is still in force (npx * v**t).
net_premium_per_policy = discounted_expected_claims_NPV_per_policy / (npx * discounts).sum(axis=0)
net_premium_per_policy

DeviceArray([ 28.778994, 350.5083  , 508.1781  ], dtype=float32)