In [15]:
import sys
sys.path.insert(0, "submodules/dc-egm/src/")
import numpy as np

In [2]:
from dcegm.solve import get_solve_function

In [8]:
# Model options for the test run.
options_test = dict(
    # mandatory
    n_periods=5,
    n_discrete_choices=3,
    n_exog_states=1,
    # custom
    belief_update_increment=0.05,
    n_policy_states=3,
    resolution_period=3,
)

# Preference parameters consumed by the utility functions.
params_dict_test = dict(
    mu=0.5,  # curvature: utility is c ** (1 - mu) / (1 - mu)
    delta=4,  # utility cost subtracted when choice == 1 (working)
)

# State space (does not need to be JAX-compatible)

In [21]:
def create_state_space(options):
    """Enumerate the admissible discrete states of the model.

    Each row of the returned array is one state with the columns
    ``[period, lagged_choice, experience, policy_state, exog_state]``.

    Admissibility rules applied:

    * Only periods ``0 .. resolution_period`` are enumerated; later periods
      (if ``n_periods > resolution_period + 1``) get no states.
    * Lagged choice ``2`` (retirement) only appears from the resolution
      period on.
    * Lagged choice ``0`` (did not work last period) is incompatible with
      ``experience == period`` (having worked in every period so far).

    Args:
        options (dict): Model options. Reads ``"n_periods"``,
            ``"n_discrete_choices"``, ``"n_exog_states"`` and
            ``"resolution_period"``.

    Returns:
        np.ndarray: Integer array of shape ``(n_states, 5)``. If no state is
        admissible, the shape is ``(0, 5)``.
    """
    n_periods = options["n_periods"]
    n_choices = options["n_discrete_choices"]
    n_exog_states = options["n_exog_states"]
    resolution_period = options["resolution_period"]

    states = []

    # NOTE(review): periods after the resolution period are dropped entirely,
    # so with the test options (n_periods=5, resolution_period=3) period 4 has
    # no states at all -- confirm this is intended.
    for period in range(min(n_periods, resolution_period + 1)):
        for lagged_choice in range(n_choices):
            # You cannot retire before the resolution period.
            if lagged_choice == 2 and period < resolution_period:
                continue
            for exp in range(period + 1):
                # If you did not work last period, you cannot have worked in
                # every period of your life so far.
                if lagged_choice == 0 and exp == period:
                    continue
                # NOTE(review): policy_state takes period + 1 values and does
                # not use options["n_policy_states"] -- confirm this is intended.
                for policy_state in range(period + 1):
                    for exog_state in range(n_exog_states):
                        states.append(
                            [period, lagged_choice, exp, policy_state, exog_state]
                        )

    # reshape keeps the 2-D layout even when no state is admissible.
    return np.array(states, dtype=int).reshape(-1, 5)
                
            

In [22]:
state_space = create_state_space(options_test)

# Display only the states of a single period (column 0 holds the period).
period_to_inspect = 1
state_space[np.equal(state_space[:, 0], period_to_inspect)]

array([[1, 0, 0, 0, 0],
       [1, 0, 0, 1, 0],
       [1, 1, 0, 0, 0],
       [1, 1, 0, 1, 0],
       [1, 1, 1, 0, 0],
       [1, 1, 1, 1, 0]])

# Utility functions

In [22]:
def utility_func(consumption, choice, params_dict):
    """Per-period utility: CRRA consumption utility minus a work penalty.

    u(c, d) = c ** (1 - mu) / (1 - mu) - delta * [d == 1]

    Args:
        consumption: Consumption level (scalar or array).
        choice: Discrete choice; only ``choice == 1`` incurs the ``delta`` cost.
        params_dict (dict): Must contain ``"mu"`` and ``"delta"``.

    Returns:
        Utility with the same shape as the broadcast of the inputs.
    """
    crra_coef = params_dict["mu"]
    work_cost = params_dict["delta"]
    consumption_utility = consumption ** (1 - crra_coef) / (1 - crra_coef)
    return consumption_utility - work_cost * (choice == 1)


def marg_utility(consumption, params_dict):
    """Marginal utility of consumption, u'(c) = c ** (-mu).

    Args:
        consumption: Consumption level (scalar or array).
        params_dict (dict): Must contain ``"mu"``.

    Returns:
        Marginal utility with the same shape as ``consumption``.
    """
    return consumption ** -params_dict["mu"]


def inverse_marginal(marginal_utility, params_dict):
    """Invert u'(c) = c ** (-mu): return the consumption m ** (-1 / mu).

    Args:
        marginal_utility: Marginal-utility value(s) (scalar or array).
        params_dict (dict): Must contain ``"mu"``.

    Returns:
        Consumption level(s) whose marginal utility equals the input.
    """
    exponent = -1 / params_dict["mu"]
    return marginal_utility ** exponent


# Bundle the utility primitives; key names are presumably the ones dc-egm
# looks up in `utility_functions` -- confirm against the library.
utility_functions = dict(
    utility=utility_func,
    inverse_marginal_utility=inverse_marginal,
    marginal_utility=marg_utility,
)
    

# Budget constraint

In [None]:
def budget_constraint(
    state_beginning_of_period,
    savings_end_of_previous_period,
    income_shock_previous_period,
    params,
    options,
):
    """Budget constraint of the model (placeholder, not implemented yet).

    TODO: implement. Until then this stub keeps the cell syntactically valid
    so the notebook survives "Restart Kernel -> Run All".

    Args:
        state_beginning_of_period: Presumably one row of the state space
            (``[period, lagged_choice, experience, policy_state, exog_state]``)
            -- confirm.
        savings_end_of_previous_period: Savings carried over from the
            previous period.
        income_shock_previous_period: Income shock realized in the previous
            period.
        params (dict): Model parameters.
        options (dict): Model options.

    Raises:
        NotImplementedError: Always, until the budget constraint is written.
    """
    raise NotImplementedError("budget_constraint is not implemented yet.")

In [3]:
get_solve_function?

Signature:
get_solve_function(
    options: Dict[str, int],
    exog_savings_grid: jax.Array,
    utility_functions: Dict[str, Callable],
    budget_constraint: Callable,
    state_space_functions: Dict[str, Callable],
    final_period_solution: Callable,
    transition_function: Callable,
) -> Callable
