In [20]:
import sys
sys.path.insert(0, "submodules/dc-egm/src/")
import numpy as np
import jax.numpy as jnp

In [2]:
from dcegm.solve import get_solve_function

In [29]:
options_test = {
     # mandatory
    "n_periods": 50, # 25 + 50 = 75
    "n_discrete_choices": 3,
    "n_exog_states": 1,
    "quadrature_points_stochastic": 5,
    # custom: model structure
    "start_age": 25,
    "resolution_age": 60,
    # custom: policy environment
    "minimum_SRA": 67, 
    "maximum_retirement_age": 72,
    "unemployment_benefits": 5,
    "pension_point_value": 0.3,
    "early_retirement_penalty": 0.036,
    # custom: params estimated outside model
    "belief_update_increment": 0.05,
    "gamma_0": 10,
    "gamma_1": 1,
    "gamma_2": -0.1,
    "interest_rate": 0.03
}

params_dict_test = {
    "mu": 0.5,
    "delta": 4 
}

# State space - No need to be jax compatible

In [18]:
def create_state_space(options):
    n_periods = options["n_periods"]
    n_choices = options["n_discrete_choices"]
    n_exog_states = options["n_exog_states"]
    resolution_age = options["resolution_age"]
    start_age = options["start_age"]
    belief_update_increment = options["belief_update_increment"]
    
    # The highest policy state, we consider belongs to the expectation of the youngest.
    n_policy_states = resolution_age - start_age

    # minimum retirement age is 4 years before the lowest statutory ret age 
    min_ret_age = options["minimum_SRA"] - 4
    # maximum (conceivable) retirement age is given by lowest SRA plus the projection of the youngest
    max_ret_age = options["maximum_retirement_age"]
    # number of possible actual retirement ages
    n_ret_ages = max_ret_age - min_ret_age + 1
    
    # shape = (n_periods, n_choices, n_exog_states)
    state_space = []
    
    shape = (n_periods, n_choices, n_periods, n_policy_states, n_ret_ages, 1)
    
    map_state_to_index = np.full(shape, fill_value=-9999, dtype=np.int64)
    i = 0

    for period in range(n_periods):
        for lag_choice in range(n_choices):
            # You cannot have more experience than your age
            for exp in range(period + 1):
                # The policy state we need to consider increases by one increment per period.
                for policy_state in range(period + 1):
                    for actual_retirement_id in range(n_ret_ages):
                        age = start_age + period
                        actual_retirement_age = min_ret_age + actual_retirement_id
                        # You cannot retire before the earliest retirement age
                        if (age <= min_ret_age) & (lag_choice == 2):
                            continue
                        # After the maximum retirement age, you must be retired
                        elif (age > max_ret_age) & (lag_choice != 2):
                            continue
                        # If you weren't retired last period, your actual retirement age is kept at minimum
                        elif (lag_choice != 2) & (actual_retirement_id > 0):
                            continue
                        # If you are retired, your actual retirement age can at most be your current age    
                        elif (lag_choice == 2) & (age <= actual_retirement_age):
                            continue
                        # Starting from resolution age, there is no more adding of policy states. 
                        elif policy_state > n_policy_states - 1:
                            continue
                        # If you have not worked last period, you can't have worked all your live 
                        elif (lag_choice != 1) & (period == exp) & (period > 0):
                            continue
                        else:
                            state_space += [[period, lag_choice, exp, policy_state, actual_retirement_age, 0]]
                            map_state_to_index[period, lag_choice, exp, policy_state, actual_retirement_id, 0] = i
                            i += 1

    return np.array(state_space), map_state_to_index


def get_choice_set(state, map_state_to_index):
    #Todo: replace hard coded parameters
    #Todo: everything dependent on policy state (e.g. can only retire 4 years before SRA)
    
    # if you're younger than min SRA, you cannot retire
    if state[0] < 63 - 25:
        return np.array([0, 1])
     # After the maximum retirement age, you must be retired
    elif state[0] > 73 - 25:
        return np.array([2])
    # retirement is absorbing
    elif state[1] == 2:
        return np.array([2])
    else:
        return np.array([0, 1, 2])
    
def update_state(state, choice):
    state_next = state.copy()
    
    # age increases by one
    state_next[0] += 1
    
    # Set choice as lag choice in next state
    state_next[1] = choice
    
    # experience increases by one if working
    if choice == 1:
        state_next[2] += 1
    
    # expected SRA increases by one increment
    if state[0] + 25 < 60:
        state_next[3] += 1
        
    return state_next

state_space_functions = {
        "create_state_space": create_state_space,
        "get_state_specific_choice_set": get_choice_set,
        "update_endog_state_by_state_and_choice": update_state,
    }
            

# Utility funcs

In [15]:
def utility_func(consumption, choice, params_dict):
    mu = params_dict["mu"]
    delta = params_dict["delta"]
    is_working = choice == 1
    utility = consumption ** (1- mu) / (1 - mu) - delta * is_working
    return utility


def marg_utility(consumption, params_dict):
    mu = params_dict["mu"]
    marg_util = consumption ** -mu
    return marg_util


def inverse_marginal(marginal_utility, params_dict):
    mu = params_dict["mu"]
    return marginal_utility ** (-1/mu)


utility_functions = {
        "utility": utility_func,
        "inverse_marginal_utility": inverse_marginal,
        "marginal_utility": marg_utility,
    }
    

# Last Period Utility (e.g. Bequest)

In [17]:
def solve_final_period_scalar(
    state_vec,
    choice,
    begin_of_period_resources,
    params,
    options,
    compute_utility,
    compute_marginal_utility,
):
    """Compute optimal consumption policy and value function in the final period.

    In the last period, everything is consumed, i.e. consumption = savings.

    Args:
        state (np.ndarray): 1d array of shape (n_state_variables,) containing the
            period-specific state vector.
        choice (int): The agent's choice in the current period.
        begin_of_period_resources (float): The agent's begin of period resources.
        compute_utility (callable): Function for computation of agent's utility.
        compute_marginal_utility (callable): Function for computation of agent's
        params (dict): Dictionary of model parameters.
        options (dict): Options dictionary.

    Returns:
        tuple:

        - consumption (float): The agent's consumption in the final period.
        - value (float): The agent's value in the final period.
        - marginal_utility (float): The agent's marginal utility .

    """
    
    # eat everything
    consumption = begin_of_period_resources
    
    # utility & marginal utility of eating everything
    value = compute_utility(consumption=begin_of_period_resources, choice=choice, params_dict=params)
    
    marginal_utility = compute_marginal_utility(
        consumption=begin_of_period_resources, params_dict=params
    )

    return marginal_utility, value, consumption

# Budget Equation

In [6]:
def budget_constraint(state_beginning_of_period, # s_t, with d_{t-1} at s_t[1]
                      savings_end_of_previous_period, # A_{t-1}
                      income_shock_previous_period, # epsilon_{t - 1}
                      params,
                      options):
    
    
    # fetch necessary parameters (gammas for wage, lambdas & ERP for pension)
    gamma_0 = options["gamma_0"]
    gamma_1 = options["gamma_1"]
    gamma_2 = options["gamma_2"]
    lambd = options["pension_point_value"]
    ERP = options["early_retirement_penalty"]
    
    # read out state
    age = state_beginning_of_period[0]
    lag_choice = state_beginning_of_period[1]
    experience = state_beginning_of_period[2]
    SRA_at_resolution = options["minimum_SRA"] + state_beginning_of_period[3]*options["belief_update_increment"]
    actual_retirement_age = 68 # B: das hier fehlt noch im state space
    
    # calculate applicable SRA and pension deduction/increase factor 
    # (malus for early retirement, bonus for late retirement)
    
    pension_factor = 1 - (actual_retirement_age - SRA_at_resolution)*ERP 
    
    # decision bools
    is_unemployed = lag_choice==0 
    is_worker = lag_choice==1
    is_retired = lag_choice==2
    
    # decision-specific income
    unemployment_benefits = options["unemployment_benefits"]
    labor_income = gamma_0 + gamma_1*experience + gamma_2*experience**2 + income_shock_previous_period 
    retirement_income = lambd*experience*pension_factor 
    
    income = is_unemployed * unemployment_benefits + is_worker * labor_income + is_retired * retirement_income
    
    # calculate beginning of period wealth M_t
    wealth = (1 + options["interest_rate"]) * savings_end_of_previous_period + income
    
    return wealth
    

# State Space and Budget Tests

In [7]:
# state-space test

state_space, indexer = create_state_space(options_test)


budget_constraint(state_beginning_of_period=state_space[47226,:], # s_t, with d_{t-1} at s_t[1]
                      savings_end_of_previous_period=10, # A_{t-1}
                      income_shock_previous_period=0.5, # epsilon_{t - 1}
                      params=params_dict_test,
                      options=options_test)


13.219000000000001

In [8]:
indexer[35,2,20,10,1,0]

-9999

In [9]:
state_space[61043,:]


array([42,  2, 11,  0, 65,  0])

In [10]:
options_test

{'n_periods': 50,
 'n_discrete_choices': 3,
 'n_exog_states': 1,
 'start_age': 25,
 'resolution_age': 60,
 'minimum_SRA': 67,
 'maximum_retirement_age': 72,
 'unemployment_benefits': 5,
 'pension_point_value': 0.3,
 'early_retirement_penalty': 0.036,
 'belief_update_increment': 0.05,
 'gamma_0': 10,
 'gamma_1': 1,
 'gamma_2': -0.1,
 'interest_rate': 0.03}

# Dumm exog process

In [11]:
def dummy_exog(state, params_dict):
    return np.array([1])

In [12]:
state_space, map_state_to_index = create_state_space(options_test)
period_to_inspect = 0

state_space[state_space[:, 0] == period_to_inspect]

array([[ 0,  0,  0,  0, 63,  0],
       [ 0,  1,  0,  0, 63,  0]])

# Exogenous Savings Grid

In [25]:
savings_grid=jnp.arange(start=0,stop=100,step=0.5)

# Call DCEGM

In [26]:
get_solve_function?

In [30]:
get_solve_function(options=options_test, 
                  exog_savings_grid=savings_grid,
                  utility_functions=utility_functions,
                  budget_constraint=budget_constraint,
                  state_space_functions=state_space_functions,
                  final_period_solution=solve_final_period_scalar,
                  transition_function=dummy_exog)

ValueError: too many values to unpack (expected 3)