# DC-EGM med Max's Model

## load packages and dependencies

In [30]:
#Dependencies
import jax.numpy as jnp
import numpy as np
from dcegm.solve import get_solve_func_for_model
from dcegm.pre_processing.setup_model import setup_model

from typing import Tuple

import pandas as pd
import matplotlib.pyplot as plt



from dcegm.sim_interface import get_sol_and_sim_func_for_model
from dcegm.simulation.simulate import simulate_all_periods
from dcegm.simulation.sim_utils import create_simulation_df

## Initiate parameters - to be estimated

In [31]:
params = {}
params["interest_rate"] = 0.02
params["sigma"] = 1
params["lambda"] = 1
params["beta"] = 0.96
params["rho"]=0.79488
params["gamma"]=jnp.array([0.005,0.015,0.02])
params["beta0"]=0.5
params["beta1"]=0.003
params["beta2"]=-0.00003
params["kappa1"]=0.05083
params["kappa2"]=0.00008

## Options for model - Choices and states

In [32]:
options = {
    "model_params": {
        #"quadrature_points_stochastic": 5, 
        "n_quad_points_stochastic": 5,
        "hours": jnp.array([0,10,15,35]), #list 
        "min_age": 30,
    },
    "state_space": {
        "n_periods": 45,
        "choices": np.arange(4), # 4 choices
        "continuous_states": {
            "wealth": np.linspace(0, 50, 100),
        },
    },
}

## Defining utility

In [33]:
def flow_util(consumption, choice, params, period, options):
    rho = params["rho"]
    age = options["start_age"] + period
    hours = options["hours"][choice]
    gamma = params["gamma"][choice-1] #remove the first element

    # for disutility of work
    working = choice > 0
    age_1=age < 35
    age_2=age > 45
    
    disutil = jnp.where(working, 1, 0) * (1 + (params["kappa1"] * age) * jnp.where(age_1,1,0) + (params["kappa2"] * age**2) * jnp.where(age_2,1,0)) * (hours*gamma)

    u = consumption ** (1 - rho) / (1 - rho) - disutil #working*gamma*hours*(1+(kappa1*age)*age_1+(kappa2*age_2)**2*age_2) #jax.lax.select(working, gamma, 0) - if a NaN included
    return u

def marginal_utility(consumption, params):
    rho = params["rho"]
    u_prime = consumption ** (-rho)
    return u_prime


def inverse_marginal_utility(marginal_utility, params):
    rho = params["rho"]
    return marginal_utility ** (-1 / rho)


utility_functions = {
    "utility": flow_util,
    "inverse_marginal_utility": inverse_marginal_utility,
    "marginal_utility": marginal_utility,
}


## Budget constraint

Need to define the budget constraint, including the old age benefits and the selfpaid pensions.
Suspect error in applying the choice, since no-one will work in the model (probably super low income?)

Hvilke units på penge?

In [47]:
def budget_dcegm(
    lagged_choice,
    savings_end_of_previous_period,
    income_shock_previous_period,
    params,
    options,
    period,
):  # noqa: 100
    interest_factor = 1 + params["interest_rate"]
    age = options["min_age"] + period
    
    wage = (jnp.exp(params["beta0"] + params["beta1"] * age + params["beta2"] * age ** 2) + income_shock_previous_period)
    hours = options["hours"][lagged_choice]
    threshold = 6.0
    base_rate = 0.37
    top_rate = 0.50
    tax = jnp.where(wage <= threshold, base_rate, top_rate)
    oldage = jnp.where((age > 67) & (hours == 0), 1.4, 0)
    #tax = 0.4
    resource = (
        interest_factor * savings_end_of_previous_period
        + wage*hours*tax + oldage*tax #include pensions
    )

    return jnp.maximum(resource, 0.5), aux_dict


In [35]:
def budget_with_aux(
    period,
    lagged_choice,
    savings_end_of_previous_period,
    income_shock_previous_period,
    options,
    params,
):
    wealth, shock, income = budget_constraint_raw(
        period,
        lagged_choice,
        savings_end_of_previous_period,
        income_shock_previous_period,
        options,
        params,
    )
    aux_dict = {
        "income": income,
    }
    return wealth, aux_dict


In [43]:
def budget_constraint_raw(
    period,
    lagged_choice,
    savings_end_of_previous_period,
    income_shock_previous_period,
    options,
    params,
):
    # Calculate stochastic labor income
    income_from_previous_period = _calc_stochastic_income(
        period=period,
        lagged_choice=lagged_choice,
        wage_shock=income_shock_previous_period,
        min_age=options["min_age"],
        constant=params["beta0"],
        exp=params["beta1"],
        exp_squared=params["beta2"],
    )

    wealth_beginning_of_period = (
        income_from_previous_period
        + (1 + params["interest_rate"]) * savings_end_of_previous_period
    )

    # Retirement safety net, only in retirement model, but we require to have it always
    # as a parameter
    wealth_beginning_of_period = jnp.maximum(
        wealth_beginning_of_period, 0.5 # lowest amount possible 
    )

    return (
        wealth_beginning_of_period,
        income_shock_previous_period,
        income_from_previous_period,
    )

In [44]:
def _calc_stochastic_income(
    period,
    lagged_choice,
    wage_shock,
    min_age,
    constant,
    exp,
    exp_squared,
):
    # For simplicity, assume current_age - min_age = experience
    age = period + min_age

    # Determinisctic component of income depending on experience:
    # constant + alpha_1 * age + alpha_2 * age**2
    exp_coeffs = jnp.array([constant, exp, exp_squared])
    labor_income = exp_coeffs @ (age ** jnp.arange(len(exp_coeffs)))
    working_income = jnp.exp(labor_income + wage_shock)

    return (1 - lagged_choice) * working_income

## Final period util - for solving model

In [48]:
def final_period_utility(wealth: float, choice: int, params, period, options) -> Tuple[float, float]:
    return flow_util(wealth, choice, params, period, options)


def marginal_final(wealth):
    return marginal_utility(wealth, params)


final_period_utility = {
    "utility": final_period_utility,
    "marginal_utility": marginal_final,
}

## Solve the model

In [46]:
def test_sim_and_sol_model(options):
    
    model_with_aux = setup_model(
        options=options,
        utility_functions=utility_functions,
        utility_functions_final_period=final_period_utility,
        budget_constraint=budget_with_aux,
    )

    
    n_agents = 1_000

    states_initial = {
        "period": jnp.zeros(n_agents, dtype=jnp.int32),       # Every individual mins at period 0 (age 30)
        "lagged_choice": jnp.full(n_agents, 3, dtype=jnp.int32),  # Every individual starts with choice 3 (work fulltime)
    }
    n_periods = options["state_space"]["n_periods"]
    seed = 132


    df_aux = create_simulation_df(output_dict_aux["sim_dict"])
    # # First check that income is in df_aux columns
    #assert "income" in df_aux.columns

    # Now drop the column and check that the rest is exactly the same
    #df_aux = df_aux.drop(columns=["income"])
    #assert df_aux.equals(df_without_aux)

test_sim_and_sol_model(options)

State specific choice set not provided. Assume all choices are available in every state.
Update function for state space not given. Assume states only change with an increase of the period and lagged choice.
Sparsity condition not provided. Assume all states are valid.
Starting state space creation
State space created.

Starting state-choice space creation and child state mapping.
State, state-choice and child state mapping created.

Start creating batches for the model.
The batch size of the backwards induction is  16
Model setup complete.



  value_solved = jnp.full(
  policy_solved = jnp.full(
  endog_grid_solved = jnp.full(


TypeError: only integer scalar arrays can be converted to a scalar index

In [49]:
model = setup_model(
    options=options,
    utility_functions=utility_functions,
    utility_functions_final_period=final_period_utility,
    budget_constraint=budget_dcegm,
)

State specific choice set not provided. Assume all choices are available in every state.
Update function for state space not given. Assume states only change with an increase of the period and lagged choice.
Sparsity condition not provided. Assume all states are valid.
Starting state space creation
State space created.

Starting state-choice space creation and child state mapping.
State, state-choice and child state mapping created.

Start creating batches for the model.
The batch size of the backwards induction is  16
Model setup complete.



In [50]:
solve_func = get_solve_func_for_model(model)

In [51]:
value_solved, policy_solved, endog_grid_solved = solve_func(params)
print(value_solved)

  value_solved = jnp.full(
  policy_solved = jnp.full(
  endog_grid_solved = jnp.full(


TypeError: only integer scalar arrays can be converted to a scalar index

## Initial values for simulating. 

In [11]:
# Select number of individuals for simulation
n_individuals = 5000

# set initial states for each individual
states_initial = {
    "period": jnp.zeros(n_individuals, dtype=jnp.int32),       # Every individual starts at period 0 (age 30)
    "lagged_choice": jnp.full(n_individuals, 3, dtype=jnp.int32),  # Every individual starts with choice 3 (work fulltime)
}

# Set wealth at beginning of period, which is the starting wealth for every individual. 
wealth_initial = jnp.full(n_individuals, 3.0)   # Every individual starts with 300k wealth - to be adjusted based on the actual moments

## Simulate model

In [12]:
simulation = simulate_all_periods(states_initial=states_initial,
            wealth_initial=wealth_initial,
            n_periods=45,
            params=params,
            seed=123,
            endog_grid_solved=endog_grid_solved,
            policy_solved=policy_solved,
            value_solved=value_solved,
            model=model)

## Convert simulation results to DataFrame

In [20]:
df = create_simulation_df(simulation)

#print to excel
#df.to_excel("/Users/frederiklarsen/Library/Mobile Documents/com~apple~CloudDocs/KU/Speciale/Resultater/simulation_output.xlsx")
print(df)

ModuleNotFoundError: No module named 'openpyxl'

## Load moments generated from data

vægt = varians af empiriske momenter'
hent varians for momenter hjem


W = var(datamoment)^-1
w_qq = 1/var(x)/n

In [17]:
df_edu_1 = pd.read_csv(r"/Users/frederiklarsen/Downloads/Transfer_703047_090425/moments_udd1.txt")
#rename alder to age
df_edu_1.rename(columns={"ALDER": "age"}, inplace=True)
df_edu_2 = pd.read_csv(r"/Users/frederiklarsen/Downloads/Transfer_703047_090425/moments_udd2.txt")
#rename alder to age
df_edu_2.rename(columns={"ALDER": "age"}, inplace=True)
df_edu_3 = pd.read_csv(r"/Users/frederiklarsen/Downloads/Transfer_703047_090425/moments_udd3.txt")
#rename alder to age
df_edu_3.rename(columns={"ALDER": "age"}, inplace=True)


# print head
print(df_edu_1.head())
print(df_edu_2.head())
print(df_edu_3.head())

   age  prob_work   hours_0   hours_1   hours_2   hours_3  avg_wealth  \
0   30   0.693223  0.306777  0.097603  0.111916  0.483704    0.666128   
1   31   0.691881  0.308119  0.087089  0.099696  0.505097    0.738634   
2   32   0.704009  0.295991  0.083747  0.099067  0.521194    0.764240   
3   33   0.715057  0.284943  0.076142  0.103338  0.535577    0.839890   
4   34   0.720249  0.279751  0.072010  0.096433  0.551806    0.897418   

   work_work  nowork_nowork  avg_wage    avg_hours  var_wage  skew_wage  pens  
0        NaN            NaN  3.160401  1646.132913  2.266921   5.501251   NaN  
1   0.638971       0.248335  3.332977  1679.738401  2.705283   6.952155   NaN  
2   0.645700       0.243295  3.446961  1690.752006  2.522834   3.530664   NaN  
3   0.664658       0.233028  3.561403  1705.418763  3.038437   5.679844   NaN  
4   0.674020       0.232027  3.696943  1719.543811  3.370498   6.299609   NaN  
   age  prob_work   hours_0   hours_1   hours_2   hours_3  avg_wealth  \
0   30  

## Plots

In [None]:
# plot the wealth over age for df_edu_1
for var in df_edu_1.columns:
    plt.figure(figsize=(10, 6))
    plt.plot(df_edu_1["age"], df_edu_1[var], label=" ", color="blue")
    plt.title(" ")
    plt.xlabel("Age")
    # give ylabel to clumn name
    plt.ylabel(var)
    plt.legend()
    plt.grid() 

plt.show()



Traceback (most recent call last):
  File "/Users/frederiklarsen/.vscode/extensions/ms-python.python-2025.0.0-darwin-arm64/python_files/python_server.py", line 133, in exec_user_input
    retval = callable_(user_input, user_globals)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 12, in <module>
  File "/opt/anaconda3/envs/.conda_env/lib/python3.11/site-packages/matplotlib/pyplot.py", line 614, in show
    return _get_backend_mod().show(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/.conda_env/lib/python3.11/site-packages/matplotlib/backend_bases.py", line 3547, in show
    cls.mainloop()
  File "/opt/anaconda3/envs/.conda_env/lib/python3.11/site-packages/matplotlib/backends/backend_macosx.py", line 179, in start_main_loop
    with _allow_interrupt_macos():
  File "/opt/anaconda3/envs/.conda_env/lib/python3.11/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "/opt/anaconda3/envs/.conda_env/lib/pyt

In [None]:
#plot wealth over time
plt.figure(figsize=(10, 6))
plt.plot(df["period"], df["wealth_beginning_of_period"], label="Wealth", color='orange')
plt.xlabel("Period")
plt.ylabel("Wealth")
plt.title("Wealth Over Time")
plt.legend()
plt.grid()
plt.savefig("/Users/frederiklarsen/Library/Mobile Documents/com~apple~CloudDocs/KU/Speciale/figurer/wealth_over_time.png")
plt.show()

Traceback (most recent call last):
  File "/opt/anaconda3/envs/.conda_env/lib/python3.11/site-packages/pandas/core/indexes/base.py", line 3805, in get_loc
    return self._engine.get_loc(casted_key)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "index.pyx", line 167, in pandas._libs.index.IndexEngine.get_loc
  File "index.pyx", line 196, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 7081, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 7089, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'period'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/frederiklarsen/.vscode/extensions/ms-python.python-2025.0.0-darwin-arm64/python_files/python_server.py", line 133, in exec_user_input
    retval = callable_(user_input, user_globals)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<stri