# Build your own decorator

## Why decorators

* Python already uses them to cache functions with `@functools.cache`
* At least three people are using them to build open source life insurance software.
    * https://github.com/fumitoh/modelx
    * https://github.com/acturtle
    * https://github.com/actuarialopensource/benchmarks


We assume you already know how decorators work. If not, there is a great guide on [Real Python](https://realpython.com/primer-on-python-decorators/). Let's discuss a couple of topics of interest.

## Clear the cache for all formulas at once

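As the number of formulas grows, we need to be able to clear all of the caches at the same time. `@functools.cache` does not offer this directly: every decorated function has its own cache, which can only be cleared through that function's own `cache_clear()`.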
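
To make the limitation concrete, here is a minimal sketch of the same kind of model written with `@functools.cache`. The `rate` dictionary is a hypothetical stand-in for a model assumption and is not part of the examples in this notebook.

```python
from functools import cache

# Hypothetical stand-in for a model assumption that may change between runs
rate = {"mortality": 0.01}

@cache
def alive(t):
    if t <= 0:
        return 1.0
    return alive(t - 1) - dead(t - 1)

@cache
def dead(t):
    return alive(t) * rate["mortality"]

first = dead(9)
rate["mortality"] = 0.02
stale = dead(9)  # the cache still holds the value computed with the old rate

# Every cached function has to be cleared one by one
alive.cache_clear()
dead.cache_clear()
fresh = dead(9)  # recomputed with the new rate

print(f"{first=}, {stale=}, {fresh=}")
```

With two formulas this is manageable, but every new formula adds another `cache_clear()` call that is easy to forget.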



In [93]:
from functools import wraps

class CashBasic:
    def __init__(self):
        self.cache_clear()

    def cache_clear(self):
        self.caches = {}

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            key = (func, args, frozenset(kwargs.items()))
            if key not in self.caches:
                self.caches[key] = func(*args, **kwargs)
            return self.caches[key]
        return wrapper

class Mortality:
    def __init__(self, mortality_rate):
        self.mortality_rate = mortality_rate

mort = Mortality(.01)

cash_basic = CashBasic()

@cash_basic
def dead(t):
    return alive(t) * mort.mortality_rate

@cash_basic
def alive(t):
    if t <= 0:
        return 1
    return alive(t-1) - dead(t-1)

print(f"{dead(9)=}")
cash_basic.cache_clear() # We don't need to clear each function individually like with @functools.cache
mort.mortality_rate = .02
print(f"{dead(9)=}")


dead(9)=0.00913517247483641
dead(9)=0.016674955242603


## Clear unused cache values at runtime

If we are certain that, once we have calculated the values for timestep `t`, no values from timestep `t-1` are needed any more, we can clear the cache for timestep `t-1`.

`modelx` has something similar to this. There are many ways to accomplish this. Here we take an approach where users must manually clear the cache for a particular timestep.

We will discuss this more later, but vectorizing calculations improves performance significantly. This comes at the cost of increased memory consumption, since the cached values are large vectors. The next example is vectorized using NumPy, and includes logic for calculating the memory consumption of the model.



In [57]:
from collections import defaultdict
from functools import wraps
import numpy as np

class CashMemoryOptimized:
    def __init__(self):
        self.cache_clear()

    def cache_clear(self):
        self.caches = defaultdict(dict)
        self.max_cache_size = 0
        self.cache_misses = 0

    def get_cache_size(self):
        total = 0
        for timestep_cache in self.caches.values():
            for np_array in timestep_cache.values():
                total += np_array.nbytes
        self.max_cache_size = max(self.max_cache_size, total)

    def cache_clear_at_timestep(self, t):
        self.caches.pop(t, None) # Does nothing if timestep t was never cached

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            key = (func, args, frozenset(kwargs.items()))
            has_timestep = (len(args) > 0) and type(args[0]) == int # We are assuming each function has an int timestep as first arg
            timestep_arg = args[0] if has_timestep else None
            if key not in self.caches[timestep_arg]:
                self.cache_misses += 1
                self.caches[timestep_arg][key] = func(*args, **kwargs)
                self.get_cache_size() # recalculate max cache size
            return self.caches[timestep_arg][key]
        return wrapper

class Mortality:
    def __init__(self, mortality_rate):
        self.mortality_rate = mortality_rate

mort = Mortality(np.linspace(.01, .02, 1000))

cash_optimized = CashMemoryOptimized()

@cash_optimized
def dead(t):
    return alive(t) * mort.mortality_rate

@cash_optimized
def alive(t):
    if t <= 0:
        return np.ones(len(mort.mortality_rate))
    return alive(t-1) - dead(t-1)


Below we see that a reduction of over 99% in peak cache memory is possible. The run time also improves by more than we expected, and we are not sure why.

In [92]:
import time

max_timestep = 240
number_of_policies = 500_000
mort.mortality_rate = np.linspace(.01, .02, number_of_policies)

print(f"{number_of_policies=}\n")

print("#### Unoptimized statistics ####")
cash_optimized.cache_clear()
start_time_unoptimized = time.time()
result_unoptimized = np.sum(dead(max_timestep))
print("--- %s seconds ---" % (time.time() - start_time_unoptimized))
print(f"{result_unoptimized=}")
unoptimized_memory_consumption_in_bytes = cash_optimized.max_cache_size
print(f"Memory consumption {unoptimized_memory_consumption_in_bytes/(10**9)} GB")
print(f"{cash_optimized.cache_misses=}\n")

print("#### Optimized statistics ####")
cash_optimized.cache_clear()
start_time_optimized = time.time()
for t in range(1,max_timestep+1):
    dead(t)
    cash_optimized.cache_clear_at_timestep(t-1)
result_optimized = np.sum(dead(max_timestep))
print("--- %s seconds ---" % (time.time() - start_time_optimized))
print(f"{result_optimized=}")
optimized_memory_consumption_in_bytes = cash_optimized.max_cache_size
print(f"Memory consumption {optimized_memory_consumption_in_bytes/(10**9)} GB")
print(f"{cash_optimized.cache_misses=}")

print("\n#### Memory savings ####")
print(f"1 - optimized/unoptimized = {1-optimized_memory_consumption_in_bytes/unoptimized_memory_consumption_in_bytes}")


number_of_policies=500000

#### Unoptimized statistics ####
--- 1.906851053237915 seconds ---
result_unoptimized=221.0718846279971
Memory consumption 1.928 GB
cash_optimized.cache_misses=482

#### Optimized statistics ####
--- 0.6406657695770264 seconds ---
result_optimized=221.0718846279971
Memory consumption 0.016 GB
cash_optimized.cache_misses=482

#### Memory savings ####
1 - optimized/unoptimized = 0.991701244813278
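
The memory figures above can be checked by hand. The sketch below assumes that every cached value is a `float64` array with one element per policy (the default dtype of `np.linspace` and `np.ones`), and that only arrays held in the cache are counted, not temporaries created during a calculation.

```python
import numpy as np

number_of_policies = 500_000
max_timestep = 240

# One cached vector: 8 bytes per float64 element
bytes_per_array = number_of_policies * np.dtype(np.float64).itemsize

# Without clearing: alive(t) and dead(t) are kept for t = 0..max_timestep
arrays_without_clearing = 2 * (max_timestep + 1)

# With clearing: at the peak, alive and dead are held for two adjacent timesteps
arrays_with_clearing = 4

unoptimized_gb = arrays_without_clearing * bytes_per_array / 10**9
optimized_gb = arrays_with_clearing * bytes_per_array / 10**9
savings = 1 - optimized_gb / unoptimized_gb

print(f"{unoptimized_gb=}, {optimized_gb=}, {savings=}")
```

The number of arrays without clearing, 482, is also the number of cache misses reported in both runs: clearing old timesteps does not force any value to be recalculated.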


## Aggregate and store results at runtime

If we are clearing cache values at runtime, we won't have immediate access to the cached values to generate tables with quantities like

* `[np.sum(dead(t)) for t in range(max_timestep+1)]`
* `[np.sum(alive(t)) for t in range(max_timestep+1)]`

Let's look at one way to do this.

In [100]:
from collections import defaultdict
from functools import wraps
import pandas as pd

class CashAggregated:
    def __init__(self):
        self.cache_clear()

    def cache_clear(self):
        self.caches = {}
        self.stored_values = defaultdict(dict)

    def __call__(self, storage_func=None):
        def decorator_factory(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                key = (func, args, frozenset(kwargs.items()))
                if key not in self.caches:
                    self.caches[key] = func(*args, **kwargs)
                if storage_func is not None:
                    self.stored_values[func.__name__][args[0]] = storage_func(self.caches[key])
                return self.caches[key]
            return wrapper
        return decorator_factory

mort = Mortality(.01 * np.ones(1000))

cash_aggregated = CashAggregated()

@cash_aggregated(lambda x: np.sum(x))
def dead(t):
    return alive(t) * mort.mortality_rate

@cash_aggregated(lambda x: np.sum(x))
def alive(t):
    if t <= 0:
        return np.ones(len(mort.mortality_rate))
    return alive(t-1) - dead(t-1)

dead(10)
pd.DataFrame(cash_aggregated.stored_values)


| t | alive | dead |
|---|-------|------|
| 0 | 1000.0 | 10.0 |
| 1 | 990.0 | 9.9 |
| 2 | 980.1 | 9.801 |
| 3 | 970.299 | 9.70299 |
| 4 | 960.59601 | 9.60596 |
| 5 | 950.99005 | 9.5099 |
| 6 | 941.480149 | 9.414801 |
| 7 | 932.065348 | 9.320653 |
| 8 | 922.744694 | 9.227447 |
| 9 | 913.517247 | 9.135172 |


## Summary

* You will certainly want to be able to clear the cache for all formulas at once.
* It would be nice if you could limit memory consumption. This is most easily accomplished if all formulas at timestep `t` only depend on timestep `t-1`.
* You can aggregate a result when it is calculated and store it in a special format for displaying it later.
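
The three points above can be combined in a single decorator. The following is only a sketch under the same assumptions as the earlier examples: the timestep is an `int` passed as the first positional argument, and all formulas at timestep `t` depend only on timestep `t-1`. The class name `CashCombined` is hypothetical and does not appear elsewhere in this notebook.

```python
from collections import defaultdict
from functools import wraps
import numpy as np

class CashCombined:  # hypothetical name, used only for this sketch
    def __init__(self):
        self.cache_clear()

    def cache_clear(self):
        self.caches = defaultdict(dict)         # timestep -> {key: cached value}
        self.stored_values = defaultdict(dict)  # function name -> {timestep: aggregate}

    def cache_clear_at_timestep(self, t):
        # Drop the large cached values; the small aggregates are kept
        self.caches.pop(t, None)

    def __call__(self, storage_func=None):
        def decorator(func):
            @wraps(func)
            def wrapper(t, *args, **kwargs):
                key = (func, t, args, frozenset(kwargs.items()))
                if key not in self.caches[t]:
                    value = func(t, *args, **kwargs)
                    self.caches[t][key] = value
                    if storage_func is not None:
                        # Aggregate once, at the moment the value is calculated
                        self.stored_values[func.__name__][t] = storage_func(value)
                return self.caches[t][key]
            return wrapper
        return decorator

mortality_rate = 0.01 * np.ones(1000)
cash_combined = CashCombined()

@cash_combined(np.sum)
def dead(t):
    return alive(t) * mortality_rate

@cash_combined(np.sum)
def alive(t):
    if t <= 0:
        return np.ones(len(mortality_rate))
    return alive(t - 1) - dead(t - 1)

for t in range(1, 11):
    dead(t)
    cash_combined.cache_clear_at_timestep(t - 1)

print(sorted(cash_combined.caches))          # timesteps still held in the cache
print(cash_combined.stored_values["alive"])  # aggregates for every timestep
```

After the loop only the last timestep remains in the cache, while `stored_values` still holds an aggregate for every timestep that was calculated.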