In [45]:
import pandas as pd
import numpy as np
from typing import Callable
from collections import defaultdict

import functools

class Cash:
    def __init__(self):
        self.cache = {}

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = (func.__name__, args, frozenset(kwargs.items()))
            if cache_key not in self.cache:
                self.cache[cache_key] = func(*args, **kwargs)
            return self.cache[cache_key]

        return wrapper

    def clear_cache(self):
        self.cache.clear()

cash = Cash()

# this is a decorator for memoizing functions. The 'Cash' class is really a cache.
class Cash:
    def __init__(self):
        self.reset()
    def reset(self):
        self.caches = defaultdict(dict)
    def __call__(self, f: Callable[[int], np.array]) -> Callable:
        def g(t):
            if t not in self.caches[f.__name__]:
                self.caches[f.__name__][t] = f(t)
            return self.caches[f.__name__][t]
        return g
cash = Cash()

In [37]:
mp = pd.read_csv("src/model_point_table.csv")
class Mortality:
    def __init__(self, issue_age: pd.Series):
        self.mort_np = pd.read_csv("src/mort_table.csv").drop(columns=["Age"]).to_numpy()
        self.issue_age = issue_age
    def get_annual_rate(self, duration: int):
        return self.mort_np[self.issue_age + duration - 18, min(duration, 5)]
    def get_monthly_rate(self, duration: int):
        return 1 - (1 - self.get_annual_rate(duration))**(1/12)
mortality = Mortality(mp["age_at_entry"])

In [38]:
cash = Cash()
def duration(t: int):
    return t // 12
@cash
def pols_death(t: int):
    return pols_inforce(t) * mortality.get_monthly_rate(duration(t))
@cash
def pols_inforce(t: int):
    if t == 0:
        return np.ones(len(mp))
    return pols_inforce(t - 1) - pols_lapse(t - 1) - pols_death(t - 1) - pols_maturity(t)
def lapse_rate(t: int):
    return max(0.1 - 0.02 * duration(t), 0.02)
@cash
def pols_lapse(t: int):
    return (pols_inforce(t) - pols_death(t)) * (1 - (1 - lapse_rate(t))**(1/12))
@cash
def pols_maturity(t: int):
    return (t == 12 * mp["policy_term"]) * (pols_inforce(t - 1) - pols_lapse(t - 1) - pols_death(t - 1))

In [39]:
def summarize_results(cash: Cash) -> pd.DataFrame:
    res = []
    for function_name, cache in cash.caches.items():
        s = pd.Series({k: sum(v) for k, v in cache.items()}, name=function_name)
        res.append(s)
    return pd.concat(res, axis=1)

In [41]:
cash.reset()
pols_lapse(20 * 12)

0       0.0
1       0.0
2       0.0
3       0.0
4       0.0
       ... 
9995    0.0
9996    0.0
9997    0.0
9998    0.0
9999    0.0
Name: policy_term, Length: 10000, dtype: float64

In [42]:
results = summarize_results(cash)
results["duration"] = results.index // 12

In [43]:
results

Unnamed: 0,pols_lapse,pols_inforce,pols_death,pols_maturity,duration
0,87.411971,10000.000000,0.473439,,0
1,86.643747,9912.114590,0.469271,0.000000,0
2,85.882275,9825.001571,0.465140,0.000000,0
3,85.127496,9738.654156,0.461045,0.000000,0
4,84.379349,9653.065616,0.456986,0.000000,0
...,...,...,...,...,...
236,2.965478,1763.417981,0.501071,0.000000,19
237,2.959648,1759.951432,0.499977,0.000000,19
238,2.953830,1756.491808,0.498885,0.000000,19
239,2.948024,1753.039093,0.497795,0.000000,19


In [47]:
from functools import lru_cache, wraps
from typing import Any, Callable, Dict, Tuple

class LRUCache:
    def __init__(self, maxsize=None):
        self.cache = lru_cache(maxsize)(self.get)

    def get(self, key):
        raise NotImplementedError

    def __getitem__(self, key):
        return self.cache(key)

    def __contains__(self, key):
        return key in self.cache

    def get_cache_key(self, args, kwargs):
        return args

class NumPyArrayLRUCache:
    def __init__(self, maxsize=None):
        self.lrucache = LRUCache(maxsize)
        self.aggregate_sums = {}

    def get(self, key):
        return self.lrucache[key]

    def __getitem__(self, key):
        return self.get(key)

    def __contains__(self, key):
        return key in self.lrucache

    def get_cache_key(self, args, kwargs):
        return self.lrucache.get_cache_key(args, kwargs)

    def __setitem__(self, key, value):
        self.aggregate_sums[key] = value.sum()
        self.lrucache.cache.__wrapped__ = lambda k: value
        self.lrucache.cache.cache_clear()
        self.lrucache.cache(key)


TypeError: 'LRUCache' object is not callable