# Mutable state and the cache



## If mutable state, clear all cache values

Suppose the mortality rate comes from a mutable object. This might happen if we want to test different mortality rates within a single program.

In [None]:
from functools import cache

class Mortality:
    def __init__(self, mortality_rate):
        self.mortality_rate = mortality_rate

mort = Mortality(.01)

@cache
def dead(t):
    return alive(t) * mort.mortality_rate

@cache
def alive(t):
    if t <= 0:
        return 1
    return alive(t-1) - dead(t-1)

print("Initial")
print(f"{mort.mortality_rate = }")
print(f"{alive(9)=}")
print(f"{dead(9)=}")
mort.mortality_rate = .02
print("INCORRECT: mortality rate change, no clearing of cache")
print(f"{mort.mortality_rate = }")
print(f"{alive(9)=}")
print(f"{dead(9)=}")
dead.cache_clear()
print("INCORRECT: mortality rate change, only clear dead cache")
print(f"{mort.mortality_rate = }")
print(f"{alive(9)=}")
print(f"{dead(9)=}")
dead.cache_clear()
alive.cache_clear()
print("CORRECT: mortality rate change, clear all caches")
print(f"{mort.mortality_rate = }")
print(f"{alive(9)=}")
print(f"{dead(9)=}")

Initial
mort.mortality_rate = 0.01
alive(9)=0.9135172474836409
dead(9)=0.00913517247483641
INCORRECT: mortality rate change, no clearing of cache
mort.mortality_rate = 0.02
alive(9)=0.9135172474836409
dead(9)=0.00913517247483641
INCORRECT: mortality rate change, only clear dead cache
mort.mortality_rate = 0.02
alive(9)=0.9135172474836409
dead(9)=0.01827034494967282
CORRECT: mortality rate change, clear all caches
mort.mortality_rate = 0.02
alive(9)=0.8337477621301499
dead(9)=0.016674955242603


### Notes on testing

With this approach we end up managing global state throughout our tests to ensure each test has the proper state.
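
A minimal sketch of what that management can look like, assuming the tests are written with the standard-library `unittest` module. The test class `TestDead` is hypothetical; `Mortality`, `mort`, `alive`, and `dead` mirror the cell above.

```python
import unittest
from functools import cache


class Mortality:
    def __init__(self, mortality_rate):
        self.mortality_rate = mortality_rate


mort = Mortality(.01)  # global mutable state shared by every test


@cache
def dead(t):
    return alive(t) * mort.mortality_rate


@cache
def alive(t):
    if t <= 0:
        return 1
    return alive(t - 1) - dead(t - 1)


class TestDead(unittest.TestCase):  # hypothetical test class
    def setUp(self):
        # Every test must reset the global state AND every cache,
        # otherwise results leak from one test into the next.
        mort.mortality_rate = .01
        dead.cache_clear()
        alive.cache_clear()

    def test_base_rate(self):
        # alive(1) = 1 - 0.01 = 0.99
        self.assertAlmostEqual(dead(1), 0.99 * .01)

    def test_doubled_rate(self):
        mort.mortality_rate = .02
        # alive(1) = 1 - 0.02 = 0.98
        self.assertAlmostEqual(dead(1), 0.98 * .02)
```

If the `setUp` method is removed, the outcome depends on the order in which the tests run, which is the kind of bookkeeping the rest of this notebook tries to avoid.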

## Eliminate global mutable state with pure functions

[According to Software Engineering Stack Exchange](https://softwareengineering.stackexchange.com/questions/148108/why-is-global-state-so-evil), "You could probably write an entire book on why global state is bad." In our example, we can eliminate the global mutable state by passing the mortality rate into the functions. When a function's output is entirely determined by its inputs and it has no side effects, it is said to be a **pure function**. Pure functions are easier to test because each call can be checked in isolation, with no surrounding state to set up.

In [None]:
@cache
def dead_pure(t, mortality_rate: float):
    return alive_pure(t, mortality_rate) * mortality_rate

@cache
def alive_pure(t, mortality_rate: float):
    if t <= 0:
        return 1
    return alive_pure(t-1, mortality_rate) - dead_pure(t-1, mortality_rate)

print("Same result as before, but fewer bugs, more testable")
print(f"{dead_pure(9, .01) = }")
print(f"{dead_pure(9, .02) = }")

Same result as before, but fewer bugs, more testable
dead_pure(9, .01) = 0.00913517247483641
dead_pure(9, .02) = 0.016674955242603
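
Because the pure versions depend only on their arguments, a check needs no setup or teardown. The following is a minimal sketch using plain `assert` statements. The expected values are derived by hand from the recursion, and `math.isclose` is used so the check does not depend on exact floating-point output.

```python
import math
from functools import cache


@cache
def dead_pure(t, mortality_rate: float):
    return alive_pure(t, mortality_rate) * mortality_rate


@cache
def alive_pure(t, mortality_rate: float):
    if t <= 0:
        return 1
    return alive_pure(t - 1, mortality_rate) - dead_pure(t - 1, mortality_rate)


# alive(t) works out to (1 - rate) ** t, so dead(t) is rate * (1 - rate) ** t.
assert math.isclose(dead_pure(9, .01), .01 * .99 ** 9)
assert math.isclose(dead_pure(9, .02), .02 * .98 ** 9)

# Calls with different rates cannot interfere with each other, in either order.
assert math.isclose(dead_pure(9, .01), .01 * .99 ** 9)
```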


### Challenges with pure functions

What if the mortality rate is actually a function depending on `issue_age`, `gender`, `smoker_status`, and `t`? Now the functions take many more arguments, and writing the code becomes challenging.

```
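# `mortality_rate` is assumed to be a rate lookup defined elsewhere.
def dead(t, issue_age, gender, smoker_status):
    return alive(t, issue_age, gender, smoker_status) * mortality_rate(t, issue_age, gender, smoker_status)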

def alive(t, issue_age, gender, smoker_status):
    if t <= 0:
        return 1
    return alive(t-1, issue_age, gender, smoker_status) - dead(t-1, issue_age, gender, smoker_status)
```

### Simplifying arguments to a pure function

We can reduce the number of arguments passed to a pure function by bundling them into a single argument.

In [None]:
class ModelPoint:
    def __init__(self, issue_age: int, gender: int, smoker_status: int):
        self.issue_age = issue_age
        self.gender = gender
        self.smoker_status = smoker_status

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'issue_age={self.issue_age!r}, '
                f'gender={self.gender!r}, '
                f'smoker_status={self.smoker_status!r})')

mp = ModelPoint(40, 1, 0)

def mortality_rate_func(mp):
    base_rate = .01
    factor = 1
    if mp.issue_age > 50:
        factor += 1
    if mp.gender == 0:
        factor -= .5
    if mp.smoker_status == 1:
        factor += 1.5
    return base_rate * factor

@cache
def dead_pure_object(t, mp: ModelPoint):
    return alive_pure_object(t, mp) * mortality_rate_func(mp)

@cache
def alive_pure_object(t, mp: ModelPoint):
    if t <= 0:
        return 1
    return alive_pure_object(t-1, mp) - dead_pure_object(t-1, mp)

print(mp)
print(f"{dead_pure_object(9, mp)=}")
print(f"Change the modelpoint")
mp.issue_age = 60
print(mp)
print(f"{dead_pure_object(9, mp)=}")
print("Oh no, the function isn't changed, let's clear the caches")
dead_pure_object.cache_clear()
alive_pure_object.cache_clear()
print(f"{dead_pure_object(9, mp)=}")
print("Ok, that is correct now")

ModelPoint(issue_age=40, gender=1, smoker_status=0)
dead_pure_object(9, mp)=0.00913517247483641
Change the modelpoint
ModelPoint(issue_age=60, gender=1, smoker_status=0)
dead_pure_object(9, mp)=0.00913517247483641
Oh no, the function isn't changed, let's clear the caches
dead_pure_object(9, mp)=0.016674955242603
Ok, that is correct now


## Objects as cache keys

The cache needed to be cleared in the above code block because memoization works as follows.

* Check if the function has already been called with the same arguments.
* If it has already been calculated for these same arguments, return the known value.
* Else, calculate the value.
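
The steps above can be sketched as a hand-rolled decorator. This is a simplified illustration, not how `functools.cache` is actually implemented. The names `memoize`, `known`, and `square` are hypothetical, and only positional, hashable arguments are handled.

```python
from functools import wraps


def memoize(func):  # hypothetical, simplified stand-in for functools.cache
    known = {}  # maps the tuple of arguments to the computed value

    @wraps(func)
    def wrapper(*args):
        # Steps 1 and 2: if these arguments were seen before, return the known value.
        if args in known:
            return known[args]
        # Step 3: otherwise calculate the value and remember it for next time.
        value = func(*args)
        known[args] = value
        return value

    wrapper.known = known  # exposed only so the example can inspect it
    return wrapper


calls = []


@memoize
def square(x):
    calls.append(x)  # record each time the body really runs
    return x * x


square(3)
square(3)  # served from the dictionary, the body does not run again
square(4)
print(calls)               # [3, 4]
print(list(square.known))  # [(3,), (4,)]
```

The lookup `args in known` relies on the hash and equality of each argument, which is why the kind of object passed in matters.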

The difference between `dead_pure` and `dead_pure_object` lies in what the cache uses as keys. First, consider `dead_pure`:

```
@cache
def dead_pure(t, mortality_rate: float):
    return alive_pure(t, mortality_rate) * mortality_rate
```

The arguments for `dead_pure(9, .01)` and `dead_pure(9, .02)` are recognized by the cache as distinct values. We can see below that the cache contains 20 distinct values, `currsize=20`. There isn't a public way to access the actual cache, but it would have 20 keys total: `[(0, .01), (0, .02), (1, .01), (1, .02), ..., (9, .02)]`.

In [None]:
dead_pure.cache_info()

CacheInfo(hits=0, misses=20, maxsize=None, currsize=20)

For `dead_pure_object`, the `ModelPoint` object itself is part of the cache key, so the keys look like `[(0, mp_object), (1, mp_object), ..., (9, mp_object)]`. Because `ModelPoint` does not define `__eq__` or `__hash__`, Python compares and hashes it by identity, not by the values of its attributes. If we mutate `mp_object`, the cache still returns the cached value because it has already seen the `(9, mp_object)` pair as arguments to the function.

In [None]:
dead_pure_object.cache_info()

CacheInfo(hits=0, misses=10, maxsize=None, currsize=10)
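
One way to sidestep the stale-key problem is to make the model point immutable and hashable by value, so that changing a field means creating a new object and therefore a new cache key. The sketch below uses a frozen `dataclass` from the standard library. It is an alternative shown for illustration, not the approach used in this notebook, and the names `FrozenModelPoint` and `rate` are hypothetical.

```python
from dataclasses import dataclass, replace, FrozenInstanceError
from functools import cache


# frozen=True together with the default eq=True generates __hash__ from the fields.
@dataclass(frozen=True)
class FrozenModelPoint:  # hypothetical immutable variant of ModelPoint
    issue_age: int
    gender: int
    smoker_status: int


@cache
def rate(mp: FrozenModelPoint):  # hypothetical cached function keyed on the model point
    return .02 if mp.issue_age > 50 else .01


mp = FrozenModelPoint(40, 1, 0)
print(rate(mp))  # 0.01

try:
    mp.issue_age = 60  # mutation in place is rejected
except FrozenInstanceError:
    print("cannot mutate a frozen model point")

older = replace(mp, issue_age=60)  # builds a new object, hence a new cache key
print(rate(older))  # 0.02
print(rate.cache_info().currsize)  # 2
```

With this design, two model points with equal field values share a single cache entry, and a stale entry can no longer be reached through a mutated object.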

## Overview

### Clearing the cache

With `functools.cache`, we have to clear the cache of every function individually. This is not acceptable when there might be hundreds of formulas. How are we supposed to clear them all?

In the next notebook we discuss implementing our own cache that can clear all of the functions at once.

Clearing the cache is needed if our cached function relies on a reference to an object which has mutated.

### Global state

Actuaries like to write functions with a single parameter, the timestep. This relies on global mutable state.

```py
# I don't like it
def alive(t):
    ...

def dead(t):
    ...
```

This convention likely comes from proprietary actuarial modeling software.

In Python, it will create challenges in unit testing the application. I can't speak to the effectiveness of unit testing in proprietary actuarial software because there is little public information available, but I can imagine they have challenges.

### Making functions pure

We can pass the model point into each function, which actuaries familiar with proprietary software might not like because they have to type more.

```py
# Almost good, good enough? Up to you.
def alive(t, mp: ModelPoint):
    ...

def dead(t, mp: ModelPoint):
    ...
```

At least if we pass the ModelPoint directly into the function, the result of the function is not determined by data defined outside the scope of the function.

With cached functions, the cache itself is global state, so a function that takes a mutable object is only guaranteed to reflect its current inputs if the cache is clear.