In [None]:
import pandas as pd
import numpy as np

In [31]:
class ActuarialCalculator:
    """

    """

    # =================================================================================
    # --- 1. INITIALIZER & SETUP METHODS
    # =================================================================================

    def __init__(self,
                 interest_rate: float,
                 mode: str,
                 data: dict,
                 select_period: int = 2,
                 radix_age: int = 20,
                 radix_lx: int = 100_000,
                 max_age: int = 120):
        """
        Initializes the calculator based on the chosen mode.

        Args:
            interest_rate (float): The effective annual interest rate.
            mode (str): Mortality data mode. One of: 'ultimate_tabular', 'select_tabular', 'parametric_makeham'.
            data (dict): Data for the chosen mode.
            select_period (int, optional): Select period in years for select tables.
            radix_age, radix_lx, max_age (int, optional): Params for parametric generation.
        """
        # --- Interest Rate Setup ---
        self.i = interest_rate
        self.v = 1 / (1 + self.i)
        self.d = 1 - self.v
        self.d_m = {m: m * (1 - self.v**(1/m)) for m in range(1, 13)} # Pre-calculates d^(m) for m=1 to 12
        self.i_m = {m: m * ((1 + self.i)**(1/m) - 1) for m in range(1, 13)}

        # --- Mortality Data Setup (Dispatcher) ---
        self.mode = mode
        self.is_select = mode in ['select_tabular', 'parametric_makeham']
        self.select_table = None
        self.ultimate_table = None
        self.select_period = select_period

        if mode == 'ultimate_tabular':
            self._setup_tabular_ultimate(data)
        elif mode == 'select_tabular':
            self._setup_tabular_select(data)
        elif mode == 'parametric_makeham':
            self._setup_parametric_makeham(data, radix_age, radix_lx, max_age)
        else:
            raise ValueError(f"Mode '{mode}' is not recognized.")


        self._original_data = data
        self._original_mode = mode

        # --- Finalize Ultimate Table (Common for all modes) ---
        if self.ultimate_table is None:
            raise Exception("Ultimate table could not be constructed.")

        self.ultimate_table['dx'] = self.ultimate_table['lx'].diff(periods=-1)
        self.ultimate_table.loc[self.ultimate_table.index.max(), 'dx'] = self.ultimate_table['lx'].iloc[-1]
        self.ultimate_table['qx'] = self.ultimate_table['dx'] / self.ultimate_table['lx']
        self.max_age = self.ultimate_table.index.max()

    def _setup_tabular_ultimate(self, data: dict):
        """Helper for ultimate table from a DataFrame."""
        if 'table' not in data or not isinstance(data['table'], pd.DataFrame):
            raise ValueError("For 'ultimate_tabular' mode, data must contain a 'table' DataFrame.")
        self.ultimate_table = data['table'].set_index('age').copy()

    def _setup_tabular_select(self, data: dict):
        """Helper for select table from a DataFrame, with optional ultimate table."""
        if 'select' not in data or not isinstance(data['select'], pd.DataFrame):
            raise ValueError("For 'select_tabular' mode, data must contain a 'select' DataFrame.")

        self.select_table = data['select'].set_index('age').copy()

        if self.select_period is None:
            self.select_period = sum(1 for col in self.select_table.columns if str(col).startswith('l_['))

        if 'ultimate' in data and data['ultimate'] is not None:
            self.ultimate_table = data['ultimate'].set_index('age').copy()
        else:
            ultimate_age_col = self.select_table.index + self.select_period
            ultimate_lx_col = self.select_table[f'l_[x]+{self.select_period-1}']
            self.ultimate_table = pd.DataFrame({'lx': ultimate_lx_col.values}, index=ultimate_age_col)

    def _setup_parametric_makeham(self, data: dict, radix_age: int, radix_lx: int, max_age: int):
        """Helper that generates both tables from Makeham's Law and book formula 3.15."""
        if 'ultimate_params' not in data or 'select_params' not in data:
            raise ValueError("Parametric mode requires 'ultimate_params' and 'select_params'.")

        # Step 1: Generate the ultimate table.
        self.ultimate_table = self._generate_ultimate_table_from_makeham(
            params=data['ultimate_params'],
            radix_age=radix_age, radix_lx=radix_lx, max_age=max_age
        )


        self.max_age = self.ultimate_table.index.max()
        # Step 2: Generate the select table using the direct formula approach.
        self.select_table = self._generate_select_table_from_formula(
            ultimate_params=data['ultimate_params'],
            select_params=data['select_params']
        )

    def _generate_ultimate_table_from_makeham(self, params: dict, radix_age: int, radix_lx: int, max_age: int) -> pd.DataFrame:
        """Generates an ultimate life table DataFrame using Makeham's Law."""
        A, B, c = params['A'], params['B'], params['c']
        ln_c = np.log(c)

        ages = np.arange(radix_age, max_age + 1)
        px = np.exp(-A - (B * (c**ages) * (c - 1)) / ln_c)

        lx_values = np.zeros_like(ages, dtype=float)
        lx_values[0] = radix_lx
        for i in range(len(ages) - 1):
            lx_values[i+1] = lx_values[i] * px[i]

        return pd.DataFrame({'lx': lx_values}, index=ages)

    def _calculate_t_p_select(self, x: int, t: float, ultimate_params: dict, select_params: dict) -> float:
        """Implements the direct formula (3.15 from the book) to calculate t_p_[x]."""
        if t == 0:
            return 1.0

        A, B, c = ultimate_params['A'], ultimate_params['B'], ultimate_params['c']
        m = select_params['modifier']

        log_m = np.log(m)
        log_m_div_c = np.log(m / c)

        term_A = (1 - m**t) / log_m
        term_B = (c**t - m**t) / log_m_div_c

        exponent = (m**(2-t))* (term_A * A + term_B * B * (c**x))
        return np.exp(exponent)


    def _generate_select_table_from_formula(self, ultimate_params: dict, select_params: dict) -> pd.DataFrame:
        """
        Generates the select table for any select period 's' using the general
        backward calculation method.
        """
        s = self.select_period
        valid_ages = self.ultimate_table.index[self.ultimate_table.index <= self.max_age - s]
        select_df = pd.DataFrame(index=valid_ages)

        # Step 1: Get the final values l_{x+s} from the ultimate table
        l_x_plus_s = np.array([self.ultimate_table.loc[x + s, 'lx'] for x in valid_ages])

        # Step 2: Pre-calculate all t_p_[x] values from t=0 to s to avoid repeated calculations
        # This is an optimization for performance.
        t_p_select_matrix = {
            t: np.array([self._calculate_t_p_select(x, t, ultimate_params, select_params) for x in valid_ages])
            for t in range(s + 1)
        }

        # Step 3: Loop through each year of the select period to calculate the columns
        for i in range(s):
            # We want to calculate the column l_[x]+i

            # Calculate the required survival probability: (s-i)_p_[x]+i = (s_p_[x]) / (i_p_[x])
            p_s = t_p_select_matrix[s]
            p_i = t_p_select_matrix[i]

            survival_prob = np.divide(p_s, p_i, out=np.zeros_like(p_s), where=p_i != 0)

            # Determine the column name, e.g., 'l_[x]' for i=0 or 'l_[x]+1' for i=1
            col_name = f'l_[x]' if i == 0 else f'l_[x]+{i}'

            # Apply the general backward formula: l_[x]+i = l_{x+s} / survival_prob
            select_df[col_name] = np.divide(l_x_plus_s, survival_prob, out=np.zeros_like(l_x_plus_s), where=survival_prob != 0)

        return select_df



    # =================================================================================
    # --- 2. Helper METHODS For Select vs. Ultimate tables
    # =================================================================================

    def _get_lx(self, x: int, t: int = 0) -> float:
        """
        The central helper method to get the number of lives at a certain point.
        It intelligently handles select vs. ultimate tables.

        Args:
            x (int): The age at selection. For ultimate tables, this is just the age.
            t (int): The duration in years since age x. Defaults to 0.

        Returns:
            float: The number of lives (lx value).
        """
        attained_age = x + t

        # Case 1: The table is a simple ultimate table.
        if not self.is_select:
            if attained_age > self.max_age or attained_age not in self.ultimate_table.index:
                return 0.0
            return self.ultimate_table.loc[attained_age, 'lx']

        # Case 2: The table is a select & ultimate table.
        else:
            # Check if we are still within the selection period.
            if t < self.select_period:
                col_name = f'l_[x]' if t == 0 else f'l_[x]+{t}'

                # --- THE FIX IS HERE ---
                # First, check if the selection age 'x' exists in the table's index.
                if x in self.select_table.index:
                    # Use .loc to correctly select the row by its index label.
                    row = self.select_table.loc[x]
                    # Then, use .get() on the resulting Series to safely get the column value.
                    return row.get(col_name, 0.0)
                else:
                    # If the selection age is not in our table, there are no survivors.
                    return 0.0
            else:
                # The selection period is over; use the ultimate table for the attained age.
                if attained_age > self.max_age or attained_age not in self.ultimate_table.index:
                    return 0.0
                return self.ultimate_table.loc[attained_age, 'lx']


    def _get_dx(self, x: int, t: int = 0) -> float:
          """
          Calculates the number of deaths (dx) on the fly for any cohort.
          It uses the _get_lx helper to handle select vs. ultimate logic.

          Args:
              x (int): The age at selection.
              t (int): The duration in years since age x.

          Returns:
              float: The number of deaths during the year, d_[x]+t or d_x+t.
          """
          # Get the number of lives at the start of the year
          lx_start = self._get_lx(x, t)

          # Get the number of lives at the end of the year
          lx_end = self._get_lx(x, t + 1)

          return lx_start - lx_end

    def validate_inputs(self, x: int, n: int = None):
          """
          Validates age and optional term length for both select and ultimate tables.
          """
          # --- Step 1: Validate the starting age x ---
          if self.is_select:
              # For a select table, x is the selection age.
              # It must be within the range of the select_table's index.
              if self.select_table is None or not (self.select_table.index.min() <= x <= self.select_table.index.max()):
                  raise ValueError(f"Selection age x ({x}) is out of the select table's range.")
          else:
              # For an ultimate table, x is the current age.
              if not (self.ultimate_table.index.min() <= x <= self.max_age):
                  raise ValueError(f"Age x ({x}) is out of the ultimate table's range.")

          # --- Step 2: Validate the term n, if provided ---
          if n is not None:
              # The final attained age (x+n) must not exceed the model's absolute maximum age.
              if x + n > self.max_age:
                  raise ValueError(f"Attained age x + n ({x + n}) exceeds the limiting age {self.max_age}.")

              if n < 1:
                  raise ValueError("Term n must be at least 1 year.")




## ----------------------------------------------------------------------------------------------------------------------------------------------------------------
##                                                                       Insurance Benefit


    ## ------------------------------------------------------------------
    ##  1: Whole Life Insurance
    ## ------------------------------------------------------------------

    def whole_life_insurance(self, x: int, m: int = 1) -> float:
        """
        Calculates the EPV of a whole life insurance using an efficient,
        pre-calculation (caching) approach for both ultimate and select tables.
        Calculates the EPV of a whole life insurance.
        For m-thly payments (m>1), it uses the UDD approximation.

        Args:
            x (int): The starting age (age at selection for select tables).
            m (int): Number of payment periods per year. Defaults to 1 (annual).

        Returns:
            float: The EPV of the whole life insurance.
        """
        # Step 1: Validate inputs
        self.validate_inputs(x)
        if not isinstance(m, int) or m < 1:
            raise ValueError("m must be a positive integer.")

        # Step 2: Ensure the necessary caches are populated. This is done only once.
        if not hasattr(self, '_ultimate_Ax_cache'):
            self._calculate_ultimate_Ax()
        if self.is_select and not hasattr(self, '_select_Ax_cache'):
            self._calculate_all_select_Ax()

        # Step 3: Get the annual insurance value from the appropriate cache.
        if self.is_select:
            annual_epv = self._select_Ax_cache.get(x, 0.0)
        else:
            annual_epv = self._ultimate_Ax_cache.get(x, 0.0)

        # Step 4: Adjust for m-thly payments if necessary.
        if m == 1:
            return annual_epv
        else:
            # ... (m-thly adjustment logic remains the same) ...
            if self.i == 0: return annual_epv
            i_m_val = self.i_m.get(m)
            if i_m_val is None:
                i_m_val = m * ((1 + self.i)**(1/m) - 1)
                self.i_m[m] = i_m_val
            approximation_factor = self.i / i_m_val
            return annual_epv * approximation_factor

    # --- Backward-recursion helpers that populate the whole life insurance caches ---

    def _calculate_ultimate_Ax(self):
        """
        Calculates whole life insurance values for all ages in the ultimate table
        and caches them in a dictionary.
        """
        self._ultimate_Ax_cache = {}
        max_age = self.max_age

        self._ultimate_Ax_cache[max_age] = self.v

        for age in range(max_age - 1, self.ultimate_table.index.min() - 1, -1):
            q_age = self.ultimate_table.loc[age, 'qx']
            p_age = 1 - q_age
            Ax_plus_1 = self._ultimate_Ax_cache[age + 1]
            self._ultimate_Ax_cache[age] = self.v * (q_age + p_age * Ax_plus_1)

    def _calculate_all_select_Ax(self):
        """
        Calculates A_[x] for ALL selection ages using vectorized backward recursion
        and caches the results in a pandas Series. This is the optimal method.
        """
        s = self.select_period
        valid_ages = self.select_table.index

        # Start the recursion with the ultimate values at the end of the select period.
        Ax_next_vector = np.array([self._ultimate_Ax_cache.get(x + s, 0.0) for x in valid_ages])

        # Loop backwards from t = s-1 down to 0
        for t in range(s - 1, -1, -1):
            # Get the vectors of l_x and d_x for all ages for the current duration 't'
            lx_t_vector = np.array([self._get_lx(x, t) for x in valid_ages])
            dx_t_vector = np.array([self._get_dx(x, t) for x in valid_ages])

            # Calculate p and q vectors
            q_select_vector = np.divide(dx_t_vector, lx_t_vector, out=np.zeros_like(dx_t_vector), where=lx_t_vector!=0)
            p_select_vector = 1 - q_select_vector

            # Apply the vectorized recursive formula
            current_Ax_vector = self.v * (q_select_vector + p_select_vector * Ax_next_vector)

            # The result of this iteration becomes the 'next' vector for the previous one.
            Ax_next_vector = current_Ax_vector

        # After the loop, the final vector contains all A_[x] values. Store it.
        self._select_Ax_cache = pd.Series(Ax_next_vector, index=valid_ages)




    ## ------------------------------------------------------------------
    ##  2: Term Insurance
    ## ------------------------------------------------------------------

    def term_insurance(self, x: int, n: int, m: int = 1) -> float:
        """
        Calculates the EPV of an n-year term insurance.

        It handles both annual (m=1) and m-thly (m>1) cases.
        For m=1 Formula is: (1/lx) * sum(v^(k+1) * d_{x+k}) for k=0 to n-1
        For m-thly payments, it uses the UDD approximation: A^(m) ≈ (i / i^(m)) * A.

        Args:
            x (int): The age of the insured.
            n (int): The term of the policy in years.
            m (int): Number of payments per year. Defaults to 1 (annual).

        Returns:
            float: The EPV of the term insurance.
        """
        # Step 1: Validate inputs.
        self.validate_inputs(x, n)
        if not isinstance(m, int) or m < 1:
            raise ValueError("m must be a positive integer.")

        # --- Step 2: Always calculate the annual term insurance value first. ---
        lx_start = self._get_lx(x, t=0)
        if lx_start == 0:
            return 0.0

        dx_values = [self._get_dx(x, k) for k in range(n)]

        k_range = np.arange(n)
        v_powers = self.v ** (k_range + 1)

        discounted_sum = np.sum(dx_values * v_powers)
        annual_epv = discounted_sum / lx_start

        # --- Step 3: Check the payment frequency 'm' ---
        if m == 1:
            # If payments are annual, we are done.
            return annual_epv
        else:
            # If payments are m-thly, apply the UDD approximation factor.
            # Check for zero interest rate.
            if self.i == 0:
                return annual_epv # When i=0, i^(m)=0, ratio is 1.

            # The i_m dictionary was pre-calculated in __init__ for efficiency.
            i_m_val = self.i_m[m]
            approximation_factor = self.i / i_m_val

            return annual_epv * approximation_factor


    ## ------------------------------------------------------------------
    ##  3: Pure Endowment
    ## ------------------------------------------------------------------

    def pure_endowment(self, x: int, n: int) -> float:

      """Calculates the EPV of an n-year pure endowment of 1.

      This method computes the value of a payment of 1 made in n years,
      contingent on survival. It automatically handles both ultimate (_nE_x)
      and select (_nE_[x]) calculations based on the table provided at initialization.

      The underlying formula is v^n * n_p_x.

      Args:
          x (int): The starting age. For an ultimate table, this is the person's
                  current age. For a select table, this is the age at selection.
          n (int): The term of the pure endowment in years.

      Returns:
          float: The Expected Present Value (EPV) of the pure endowment.
      """
      # Step 1: Validate inputs
      self.validate_inputs(x, n)

      # Step 2: Get the number of lives at the start using our helper.
      # This returns l_[x] for select tables or l_x for ultimate tables.
      lx_start = self._get_lx(x, t=0)

      # If no one is alive at the start, the value is 0.
      if lx_start == 0:
          return 0.0

      # Step 3: Get the number of lives surviving after n years.
      # The helper automatically finds l_[x]+n or l_{x+n} from the correct table.
      lx_future = self._get_lx(x, t=n)

      # Step 4: Calculate survival probability and discount factor.
      n_px = lx_future / lx_start
      discount_factor = self.v ** n

      # Step 5: Calculate the final EPV.
      epv = discount_factor * n_px

      return epv


    ## ------------------------------------------------------------------
    ##  4: Endowment Insurance
    ## ------------------------------------------------------------------

    def endowment_insurance(self, x: int, n: int, m: int = 1) -> float:
        """Calculates the EPV of an n-year endowment insurance of 1.

        An n-year endowment insurance provides a payment of 1 at the end of the
        year of death if it occurs within n years, or a payment of 1 at the
        end of n years if the insured survives.

        This method correctly handles both select (A_{[x]:n|}) and ultimate (A_{x:n|})
        cases, as well as m-thly payments for the death benefit portion (using the
        UDD approximation).

        Formula: A_{x:n|} = (Term Insurance) + (Pure Endowment)

        Args:
            x (int): The starting age. For an ultimate table, this is the person's
                    current age. For a select table, this is the age at selection.
            n (int): The term of the policy in years.
            m (int): Number of payment periods per year for the death benefit.
                    Defaults to 1 (annual).

        Returns:
            float: The Expected Present Value (EPV) of the endowment insurance.
        """
        # Step 1: Validate inputs.
        self.validate_inputs(x, n)

        # Step 2: Calculate the pure endowment component
        pure_endowment_component = self.pure_endowment(x, n)

        # Step 3: Calculate the term insurance component
        # This handles both annual (m=1) and m-thly (m>1) cases automatically.
        term_insurance_component = self.term_insurance(x, n, m=m)

        print("--- endowment_insurance DEBUG ---")
        print(f"Pure Endowment Component: {pure_endowment_component}")
        print(f"Term Insurance Component: {term_insurance_component}")
        print("---------------------------------")


        # Step 4: The total EPV is simply the sum of the two components.
        return pure_endowment_component + term_insurance_component


    ## ------------------------------------------------------------------
    ##  5: Deferred Insurance
    ## ------------------------------------------------------------------

    def deferred_insurance(self, x: int, u: int, n: int = None, m: int = 1) -> float:
        """Calculates the EPV of a deferred insurance.

        This single method handles both deferred term and deferred whole life
        insurance based on whether the term 'n' is provided. It also handles
        m-thly payments using the UDD approximation.

        - If 'n' is provided: Calculates deferred term insurance (u|A_x:n|).
        - If 'n' is None: Calculates deferred whole life insurance (u|A_x).

        Args:
            x (int): The initial age (age at selection for select tables).
            u (int): The deferral period in years.
            n (int, optional): The term of insurance coverage after deferral.
                              If None, a whole life insurance is assumed. Defaults to None.
            m (int): Number of payment periods per year. Defaults to 1 (annual).

        Returns:
            float: The EPV of the deferred insurance.
        """
        # 1. Validate inputs. Note that 'n or 0' handles the case where n is None.
        self.validate_inputs(x, u + (n or 0))
        if u < 0: raise ValueError("Deferral period u must be non-negative.")

        # 2. Calculate the annual version of the deferred insurance (EPV_annual)
        # This section uses an if/else to decide which type of insurance to calculate.

        if n is not None:
            # --- Case 1: Deferred Term Insurance ---
            if n < 1: raise ValueError("Term n must be at least 1 year.")

            # We use the robust direct summation method
            lx_start = self._get_lx(x, t=0)
            if lx_start == 0: return 0.0

            k_range = np.arange(u, u + n)
            dx_values = [self._get_dx(x, k) for k in k_range]
            v_powers = self.v ** (k_range + 1)

            discounted_sum = np.sum(dx_values * v_powers)
            annual_epv = discounted_sum / lx_start

        else:
            # --- Case 2: Deferred Whole Life Insurance ---

            # Formula: uE_x * A_{x+u}
            # Calculate the pure endowment (actuarial discount factor)
            u_E_x = self.pure_endowment(x, u)

            # Calculate the whole life insurance value at the future age
            # Note: We always use the annual (m=1) value for the core calculation.
            future_insurance_val = self.whole_life_insurance(x + u, m=1)

            annual_epv = u_E_x * future_insurance_val

        # 3. Adjust for m-thly payments if necessary
        if m == 1:
            return annual_epv
        else:
            if self.i == 0: return annual_epv

            i_m_val = self.i_m.get(m)
            if i_m_val is None:
                i_m_val = m * ((1 + self.i)**(1/m) - 1)
                self.i_m[m] = i_m_val

            approximation_factor = self.i / i_m_val
            return annual_epv * approximation_factor



## ----------------------------------------------------------------------------------------------------------------------------------------------------------------
##                                                                          Annuities



    ## ----------------------------------------------------------------------------------------------------------------------------------------------------------------
    ##                                                                   Certain annuities
    ## -----------------------------------------------------------------------------------------------------------------------------------------------------------------

    ## ------------------------------------------------------------------
    ##  Helper: Annuity Certain
    ## ------------------------------------------------------------------

    def annuity_certain(self, n: int, m: int = 1, timing: str = 'due'):
        """
        Calculates the value of an annuity-certain.
        Handles annual (m=1) and m-thly (m>1) cases for both due and immediate.

        Args:
            n (int): The term of the annuity in years.
            m (int): Number of payments per year. Defaults to 1.
            timing (str): Type of annuity, 'due' or 'immediate'. Defaults to 'due'.

        Returns:
            float: The present value of the annuity-certain.
        """
        if n < 0:
            raise ValueError("Term n must be non-negative.")
        if n == 0:
            return 0.0
        if self.i == 0: # Edge case for zero interest
            return n if timing == 'due' else n # This is a simplification

        v_n = self.v ** n

        if timing == 'due':
            if m == 1: # annual due
                return (1 - v_n) / self.d
            else: # m-thly due
                return (1 - v_n) / self.d_m[m]
        elif timing == 'immediate':
            if m == 1: # annual immediate
                return (1 - v_n) / self.i
            else: # m-thly immediate
                return (1 - v_n) / self.i_m[m]
        else:
            raise ValueError("Timing must be 'due' or 'immediate'.")



## ----------------------------------------------------------------------------------------------------------------------------------------------------------------
##                                                                          Life annuities
## -----------------------------------------------------------------------------------------------------------------------------------------------------------------


    ## ------------------------------------------------------------------
    ##  1:  Whole Life Annuity
    ## ------------------------------------------------------------------

    def whole_life_annuity(self, x: int, m: int = 1, timing: str = 'due') -> float:
        """Calculates the EPV of a whole life annuity.

        This unified method handles annual (m=1) and m-thly (m>1) payments,
        as well as due (payment at start of period) and immediate (payment at end
        of period) timing. It automatically works for both ultimate (a_x) and
        select (a_[x]) lives by leveraging the whole_life_insurance method.

        Args:
            x (int): The starting age (age at selection for select tables).
            m (int): Number of payments per year. Defaults to 1 (annual).
            timing (str): 'due' or 'immediate'. Defaults to 'due'.

        Returns:
            float: The EPV of the whole life annuity.
        """
        # 1. Validate inputs
        self.validate_inputs(x)
        if timing not in ['due', 'immediate']:
            raise ValueError("Timing must be 'due' or 'immediate'.")
        if not isinstance(m, int) or m < 1 or (m > 1 and m not in self.d_m):
            raise ValueError(f"m ({m}) must be a valid integer for payment frequency.")
        if self.i == 0:
            raise ValueError("Formula is not applicable for zero interest rate.")

        # 2. Get the corresponding whole life insurance value.
        # This single call correctly gets A_x, A_[x], A_x^(m), or A_[x]^(m).
        insurance_epv = self.whole_life_insurance(x, m=m)

        # 3. Always calculate the annuity-DUE value first, based on the insurance value.
        if m == 1:
            # Annual case: a_due = (1 - A_annual) / d
            annuity_due_epv = (1 - insurance_epv) / self.d
        else:
            # m-thly case: a_due^(m) = (1 - A^(m)) / d^(m)
            annuity_due_epv = (1 - insurance_epv) / self.d_m[m]

        # 4. Adjust for timing if the request is for an immediate annuity.
        if timing == 'due':
            return annuity_due_epv
        else: # timing == 'immediate'
            # Apply the identity: a = a_due - (first_payment)
            return annuity_due_epv - (1 / m)


    ## ------------------------------------------------------------------
    ##  2: Term Annuity
    ## ------------------------------------------------------------------

    def term_annuity(self, x: int, n: int, m: int = 1, timing: str = 'due') -> float:
        """Calculates the EPV of an n-year term life annuity.

        This unified method handles annual (m=1) and m-thly (m>1) payments,
        as well as due and immediate timing. It automatically works for both
        ultimate (a_x:n|) and select (a_[x]:n|) lives by leveraging other
        refactored methods.

        Args:
            x (int): The starting age (age at selection for select tables).
            n (int): The term of the annuity in years.
            m (int): Number of payments per year. Defaults to 1 (annual).
            timing (str): 'due' or 'immediate'. Defaults to 'due'.

        Returns:
            float: The EPV of the term life annuity.
        """
        # 1. Validate inputs
        self.validate_inputs(x, n)
        if timing not in ['due', 'immediate']:
            raise ValueError("Timing must be 'due' or 'immediate'.")
        if not isinstance(m, int) or m < 1 or (m > 1 and m not in self.d_m):
            raise ValueError(f"m ({m}) must be a valid integer for payment frequency.")
        if self.i == 0 and 'due' in timing: # Due formula fails at i=0
            raise ValueError("Formula is not applicable for zero interest rate.")

        # 2. Get the corresponding endowment insurance value.
        # This single call handles select/ultimate and annual/m-thly.
        insurance_epv = self.endowment_insurance(x, n, m=m)

        # 3. Always calculate the annuity-DUE value first.
        if m == 1:
            annuity_due_epv = (1 - insurance_epv) / self.d
        else:
            annuity_due_epv = (1 - insurance_epv) / self.d_m[m]

        # 4. Adjust for timing if the request is for an immediate annuity.
        if timing == 'due':
            return annuity_due_epv
        else: # timing == 'immediate'
            # Get the pure endowment component for the adjustment
            n_E_x = self.pure_endowment(x, n)

            # Apply the identity: a_immediate = a_due - (1/m) * (1 - nEx)
            return annuity_due_epv - (1 / m) * (1 - n_E_x)


    ## ------------------------------------------------------------------
    ##  3: Deferred annuities
    ## ------------------------------------------------------------------

    def deferred_annuity(self, x: int, u: int, n: int = None, m: int = 1, timing: str = 'due') -> float:
        """Calculates the EPV of a deferred annuity using the difference method.

        This single method handles all deferred annuity cases:
        - Whole Life (n is None) vs. Term (n is specified).
        - Annual (m=1) vs. m-thly (m>1).
        - Due vs. Immediate timing.

        Formulas used:
        - Deferred Whole Life: a_x - a_x:u|
        - Deferred Term: a_x:u+n| - a_x:u|

        Args:
            x (int): The starting age (age at selection for select tables).
            u (int): The deferral period in years.
            n (int, optional): The term of the annuity after deferral.
                              If None, a whole life annuity is assumed.
            m (int): Number of payments per year. Defaults to 1.
            timing (str): 'due' or 'immediate'. Defaults to 'due'.

        Returns:
            float: The EPV of the deferred annuity.
        """
        # 1. Validate inputs
        self.validate_inputs(x, u + (n or 0))
        if u < 1:
            raise ValueError("Deferral period u must be at least 1 year.")

        # 2. Logic branches based on whether it's a whole life or term annuity
        if n is None:
            # --- Case 1: Deferred Whole Life Annuity (a_x - a_x:u|) ---

            # Calculate the whole life annuity from age x
            whole_life_annuity_val = self.whole_life_annuity(x, m=m, timing=timing)

            # Calculate the temporary annuity for the deferral period
            temp_annuity_val = self.term_annuity(x, u, m=m, timing=timing)

            return whole_life_annuity_val - temp_annuity_val

        else:
            # --- Case 2: Deferred Term Annuity (a_x:u+n| - a_x:u|) ---
            if n < 1:
                raise ValueError("Term n must be at least 1 year.")

            # Calculate the temporary annuity for the full period (u+n)
            long_temp_annuity_val = self.term_annuity(x, u + n, m=m, timing=timing)

            # Calculate the temporary annuity for the deferral period
            short_temp_annuity_val = self.term_annuity(x, u, m=m, timing=timing)

            return long_temp_annuity_val - short_temp_annuity_val

    ## ------------------------------------------------------------------
    ##  4: Guaranteed annuities
    ## ------------------------------------------------------------------

    def guaranteed_annuity(self, x: int, n_guarantee: int, k_term: int = None, m: int = 1, timing: str = 'due') -> float:
          """Calculates the EPV of an annuity with a guaranteed period.

          This unified method handles all variations:
          - Whole Life Guaranteed (k_term is None).
          - Term Guaranteed (k_term is an integer > n_guarantee).
          - Annual (m=1) vs. m-thly (m>1).
          - Due vs. Immediate timing.

          Formula: EPV = (Annuity Certain for n) + (Deferred Life Annuity after n)

          Args:
              x (int): The starting age (age at selection for select tables).
              n_guarantee (int): The guaranteed payment period in years.
              k_term (int, optional): The total term of the annuity. If provided,
                                      a term annuity is calculated. If None, a
                                      whole life annuity is assumed. Defaults to None.
              m (int): Number of payments per year. Defaults to 1.
              timing (str): 'due' or 'immediate'. Defaults to 'due'.

          Returns:
              float: The EPV of the guaranteed annuity.
          """
          # 1. Validate inputs
          if k_term is not None:
              self.validate_inputs(x, k_term)
              if k_term <= n_guarantee:
                  raise ValueError("Total term (k_term) must be greater than the guarantee period (n_guarantee).")
          else:
              self.validate_inputs(x, n_guarantee)

          # 2. Calculate Component 1: The n-year annuity-certain.
          # This part is paid regardless of survival.
          guaranteed_part = self.annuity_certain(n=n_guarantee, m=m, timing=timing)

          # 3. Calculate Component 2: The deferred life annuity.
          # This part is contingent on survival after n_guarantee years.
          # We use our unified 'deferred_annuity' method.

          if k_term is None:
              # Case A: Deferred WHOLE LIFE annuity
              # It's deferred for 'n_guarantee' years and runs for the rest of life.
              deferred_part = self.deferred_annuity(
                  x, u=n_guarantee, n=None, m=m, timing=timing
              )
          else:
              # Case B: Deferred TERM annuity
              # It's deferred for 'n_guarantee' years and runs for 'k_term - n_guarantee' years.
              deferred_term_length = k_term - n_guarantee
              deferred_part = self.deferred_annuity(
                  x, u=n_guarantee, n=deferred_term_length, m=m, timing=timing
              )

          # 4. The total EPV is the sum of the two components.
          return guaranteed_part + deferred_part


    ## ------------------------------------------------------------------
    ##  5: Geometrically increasing annuities
    ## ------------------------------------------------------------------

    def geometrically_increasing_annuity(self, x: int, j: float, n: int = None, m: int = 1, timing: str = 'due') -> float:
          """
          Calculates the EPV of a geometrically increasing life annuity.

          Payments increase at a rate 'j' per year. This method uses the adjusted
          interest rate (i*) approach for efficient calculation. It handles all
          variations: whole life/term, annual/m-thly, due/immediate, and
          select/ultimate tables.

          Args:
              x (int): The starting age (age at selection for select tables).
              j (float): The geometric rate of increase (e.g., 0.02 for 2%).
              n (int, optional): The term of the annuity. If None, a whole life
                                annuity is assumed. Defaults to None.
              m (int): Number of payments per year. Defaults to 1.
              timing (str): 'due' or 'immediate'. Defaults to 'due'.

          Returns:
              float: The EPV of the geometrically increasing annuity.
          """
          # 1. Validate inputs
          self.validate_inputs(x, n)
          if 1 + j <= 0:
              raise ValueError("1+j must be positive.")

          # 2. Calculate the adjusted interest rate, i*
          i_star = (self.i - j) / (1 + j)

          # 3. Create a temporary calculator instance with the new interest rate
          # It uses the original data and mode to replicate the life table setup.
          temp_calc = ActuarialCalculator(
              interest_rate=i_star,
              mode=self._original_mode,
              data=self._original_data,
              select_period=self.select_period,
              radix_age=self.ultimate_table.index.min(), # Get original radix age
              radix_lx=self.ultimate_table['lx'].iloc[0], # Get original radix lx
              max_age=self.max_age
          )

          # 4. Call the standard level annuity method on the temporary calculator
          if n is None:
              # Whole Life case
              return temp_calc.whole_life_annuity(x, m=m, timing=timing)
          else:
              # Term case
              return temp_calc.term_annuity(x, n, m=m, timing=timing)

In [32]:
#Example usage
# Demo: build a parametric (Makeham) select calculator, compute A_[20],
# then print samples of the calculator's private Ax caches.

# 1. Define the parameters based on the book's example
# 'ultimate_params' holds the Makeham constants A, B, c; 'select_params' is
# consumed by the select setup (presumably 'modifier' scales select-period
# mortality; confirm in _setup_parametric_makeham).
data = {
    'ultimate_params': {'A': 0.00022, 'B': 2.7e-6, 'c': 1.124},
    'select_params':   {'modifier':0.9}
}

# 2. Create an instance of the calculator
# 5% effective annual interest, 2-year select period.
print("Creating calculator instance...")
calc = ActuarialCalculator(
    interest_rate=0.05,
    mode='parametric_makeham',
    data=data,
    select_period=2
)
print("Instance created.")

# 3. Trigger the cache creation
# The caches are created lazily (the first time they are needed).
# Calling whole_life_insurance will trigger the creation of both caches.
print("\nCalculating A_[20] to trigger cache generation...")
A_20_select = calc.whole_life_insurance(x=20)
print(f"Value of A_[20] is: {A_20_select:.5f}")


# 4. Now, you can access and view the cache attributes
# These are private attributes; they are read here for demonstration only.
# hasattr() guards against a cache that has not been built yet.
print("\n\n--- Viewing Cache Contents ---")

# Check if the ultimate cache exists and print a sample
if hasattr(calc, '_ultimate_Ax_cache'):
    print("\n## Ultimate Ax Cache (Sample):")
    # Convert dict to Series for nice printing. head() shows the first 5
    # entries in insertion order; in the recorded run these are ages 120
    # down to 116, i.e. the cache is filled from the oldest age backwards.
    ultimate_cache_series = pd.Series(calc._ultimate_Ax_cache)
    print(ultimate_cache_series.head())
else:
    print("\nUltimate cache not created yet.")

# Check if the select cache exists and print a sample
# Printed via .head() directly, with no conversion; in the recorded run it is
# already a Series indexed by selection age x (the value at 20 matches A_[20]).
if hasattr(calc, '_select_Ax_cache'):
    print("\n## Select Ax Cache (Sample):")
    print(calc._select_Ax_cache.head())
else:
    print("\nSelect cache not created yet.")

Creating calculator instance...
Instance created.

Calculating A_[20] to trigger cache generation...
Value of A_[20] is: 0.04918


--- Viewing Cache Contents ---

## Ultimate Ax Cache (Sample):
120    0.952381
119    0.950436
118    0.949516
117    0.948406
116    0.947033
dtype: float64

## Select Ax Cache (Sample):
20    0.049175
21    0.051399
22    0.053731
23    0.056176
24    0.058740
dtype: float64


In [33]:
# Ages at which to tabulate whole-life annuity EPVs from `calc`.
ages = [20, 40, 60, 80]

# One column per variant: annual (m=1) and quarterly (m=4) payments, each as
# annuity-due and annuity-immediate. float() normalises each EPV (possibly a
# NumPy scalar) to a built-in float for display.
# For whole life, due - immediate = 1/m, which the recorded output reflects
# (a difference of 1.0 for annual, 0.25 for quarterly).
Example_usage = { 'age' : ages,
                 'ax_due': [float(calc.whole_life_annuity(x=age, m=1, timing='due')) for age in ages ] ,
                 'ax_im' : [float(calc.whole_life_annuity(x=age, m=1, timing='immediate')) for age in ages ],
                 'ax(4)_due':[float(calc.whole_life_annuity(x=age, m=4, timing='due')) for age in ages ],
                 'ax(4)_im':[float(calc.whole_life_annuity(x=age, m=4, timing='immediate')) for age in ages ]}

# Show the table indexed by age.
Example_usage_table = pd.DataFrame(Example_usage)
print(Example_usage_table.set_index('age'))

        ax_due      ax_im  ax(4)_due   ax(4)_im
age                                            
20   19.967315  18.967315  19.588312  19.338312
40   18.459564  17.459564  18.080279  17.830279
60   14.913398  13.913398  14.533454  14.283454
80    8.597221   7.597221   8.216103   7.966103
