Skip to content

Commit

Permalink
Extend dcegm interface (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxBlesch committed May 27, 2024
1 parent 7355ba6 commit 6694e41
Show file tree
Hide file tree
Showing 14 changed files with 636 additions and 621 deletions.
223 changes: 25 additions & 198 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from typing import Dict
from typing import Tuple

import numpy as np
from dcegm.interpolation import get_index_high_and_low
from dcegm.interpolation import linear_interpolation_formula
from dcegm.interpolation import interp_value_and_policy_on_wealth
from jax import numpy as jnp
from jax import vmap


def interpolate_value_and_calc_marginal_utility(
def interpolate_value_and_marg_utility_on_next_period_wealth(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
Expand All @@ -19,7 +17,7 @@ def interpolate_value_and_calc_marginal_utility(
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[float, float]:
"""Interpolate marginal utilities.
"""Interpolate value and policy in the child states and compute the marginal value.
Args:
compute_marginal_utility (callable): User-defined function to compute the
Expand Down Expand Up @@ -51,198 +49,27 @@ def interpolate_value_and_calc_marginal_utility(
containing the interpolated value function.
"""
# For all choices, the wealth is the same in the solution
ind_high, ind_low = get_index_high_and_low(
x=endog_grid_child_state_choice, x_new=wealth_beginning_of_period
)
marg_utils, value_interp = vmap(
vmap(
_interpolate_value_and_marg_util,
in_axes=(0, 0, 0, 0, 0, 0, 0, None, None, None, None, None, None),
),
in_axes=(0, 0, 0, 0, 0, 0, 0, None, None, None, None, None, None),
)(
jnp.take(policy_child_state_choice, ind_high),
value_child_state_choice[ind_high],
endog_grid_child_state_choice[ind_high],
jnp.take(policy_child_state_choice, ind_low),
value_child_state_choice[ind_low],
endog_grid_child_state_choice[ind_low],
wealth_beginning_of_period,
compute_utility,
compute_marginal_utility,
endog_grid_child_state_choice[1],
value_child_state_choice[0],
state_choice_vec,
params,
)

return marg_utils, value_interp


def _interpolate_value_and_marg_util(
policy_high: float | jnp.ndarray,
value_high: float | jnp.ndarray,
wealth_high: float | jnp.ndarray,
policy_low: float | jnp.ndarray,
value_low: float | jnp.ndarray,
wealth_low: float | jnp.ndarray,
new_wealth: float | jnp.ndarray,
compute_utility: Callable,
compute_marginal_utility: Callable,
endog_grid_min: float | jnp.ndarray,
value_at_zero_wealth: float | jnp.ndarray,
state_choice_vec: Dict[str, int],
params: Dict[str, float],
) -> Tuple[float, float]:
"""Calculate interpolated marginal utility and value function.
Args:
policy_high (float): Policy function value at the higher end of the
interpolation interval.
value_high (float): Value function value at the higher end of the
interpolation interval.
wealth_high (float): Endogenous wealth grid value at the higher end of the
interpolation interval.
policy_low (float): Policy function value at the lower end of the
interpolation interval.
value_low (float): Value function value at the lower end of the
interpolation interval.
wealth_low (float): Endogenous wealth grid value at the lower end of the
interpolation interval.
new_wealth (float): New endogenous wealth grid value.
compute_marginal_utility (callable): Function for calculating the marginal
utility from consumption level. The input ```params``` is already
partialled in.
endog_grid_min (float): Minimum endogenous wealth grid value.
value_min (float): Minimum value function value.
choice (int): Discrete choice of an agent.
params (dict): Dictionary containing the model parameters.
Returns:
tuple:
- marg_util_interp (float): Interpolated marginal utility function.
- value_interp (float): Interpolated value function.
"""
policy_interp = linear_interpolation_formula(
y_high=policy_high,
y_low=policy_low,
x_high=wealth_high,
x_low=wealth_low,
x_new=new_wealth,
)

value_interp = interp_value_and_check_creditconstraint(
value_high=value_high,
wealth_high=wealth_high,
value_low=value_low,
wealth_low=wealth_low,
new_wealth=new_wealth,
compute_utility=compute_utility,
endog_grid_min=endog_grid_min,
value_at_zero_wealth=value_at_zero_wealth,
state_choice_vec=state_choice_vec,
params=params,
)

marg_utility_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)

return marg_utility_interp, value_interp

# Generate interpolation function for single wealth point where the endogenous grid,
# policy and value are fixed.
def interp_on_single_wealth(wealth):
policy_interp, value_interp = interp_value_and_policy_on_wealth(
wealth=wealth,
endog_grid=endog_grid_child_state_choice,
policy=policy_child_state_choice,
value=value_child_state_choice,
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
marg_utility_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)
return marg_utility_interp, value_interp

# Generate vectorized function for savings and income shock dimension
vector_interp_func = vmap(vmap(interp_on_single_wealth))

marg_utils, value_interp = vector_interp_func(wealth_beginning_of_period)

def interp_value_and_check_creditconstraint(
value_high: float | jnp.ndarray,
wealth_high: float | jnp.ndarray,
value_low: float | jnp.ndarray,
wealth_low: float | jnp.ndarray,
new_wealth: float | jnp.ndarray,
compute_utility: Callable,
endog_grid_min: float,
value_at_zero_wealth: float,
state_choice_vec: Dict[str, int],
params: Dict[str, float],
) -> float | jnp.ndarray:
"""Calculate interpolated marginal utility and value function.
Args:
policy_high (float): Policy function value at the higher end of the
interpolation interval.
value_high (float): Value function value at the higher end of the
interpolation interval.
wealth_high (float): Endogenous wealth grid value at the higher end of the
interpolation interval.
policy_low (float): Policy function value at the lower end of the
interpolation interval.
value_low (float): Value function value at the lower end of the
interpolation interval.
wealth_low (float): Endogenous wealth grid value at the lower end of the
interpolation interval.
new_wealth (float): New endogenous wealth grid value.
endog_grid_min (float): Minimum endogenous wealth grid value.
value_min (float): Minimum value function value.
state_choice_vec (Dict): Dictionary containing a single state and choice.
params (dict): Dictionary containing the model parameters.
Returns:
tuple:
- marg_util_interp (float): Interpolated marginal utility function.
- value_interp (float): Interpolated value function.
"""

value_interp_on_grid = linear_interpolation_formula(
y_high=value_high,
y_low=value_low,
x_high=wealth_high,
x_low=wealth_low,
x_new=new_wealth,
)

value_interp = check_value_if_credit_constrained(
value_interp_on_grid=value_interp_on_grid,
value_at_zero_wealth=value_at_zero_wealth,
new_wealth=new_wealth,
endog_grid_min=endog_grid_min,
params=params,
state_choice_vec=state_choice_vec,
compute_utility=compute_utility,
)
return value_interp


def check_value_if_credit_constrained(
value_interp_on_grid,
value_at_zero_wealth,
new_wealth,
endog_grid_min,
params,
state_choice_vec,
compute_utility,
):
"""This function takes the value interpolated on the solution and checks if it is in
the region, where consume all your wealth is the optimal solution.
This is by construction endog_grid_min. If so, it returns the closed form solution
for the value function, by calculating the utility of consuming all the wealth and
adding the discounted expected value of zero wealth. Otherwise, it returns the
interpolated value function.
"""
utility = compute_utility(
consumption=new_wealth,
params=params,
**state_choice_vec,
)
value_interp_closed_form = utility + params["beta"] * value_at_zero_wealth

credit_constraint = new_wealth <= endog_grid_min
value_interp = (
credit_constraint * value_interp_closed_form
+ (1 - credit_constraint) * value_interp_on_grid
)
return value_interp
return marg_utils, value_interp
125 changes: 125 additions & 0 deletions src/dcegm/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import jax.numpy as jnp
from dcegm.interpolation import interp_policy_on_wealth
from dcegm.interpolation import interp_value_and_policy_on_wealth
from dcegm.interpolation import interp_value_on_wealth


def policy_and_value_for_state_choice_vec(
state_choice_vec,
wealth,
map_state_choice_to_index,
state_space_names,
endog_grid_solved,
policy_solved,
value_solved,
compute_utility,
params,
):
"""Get policy and value for a given state and choice vector.
Args:
state_choice_vec (Dict): Dictionary containing a single state and choice.
model (Model): Model object.
params (Dict): Dictionary containing the model parameters.
Returns:
Tuple[float, float]: Policy and value for the given state and choice vector.
"""
state_choice_tuple = tuple(
state_choice_vec[st] for st in state_space_names + ["choice"]
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]
policy, value = interp_value_and_policy_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
policy=jnp.take(policy_solved, state_choice_index, axis=0),
value=jnp.take(value_solved, state_choice_index, axis=0),
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
return policy, value


def value_for_state_choice_vec(
state_choice_vec,
wealth,
map_state_choice_to_index,
state_space_names,
endog_grid_solved,
value_solved,
compute_utility,
params,
):
"""Get policy and value for a given state and choice vector.
Args:
state_choice_vec (Dict): Dictionary containing a single state and choice.
model (Model): Model object.
params (Dict): Dictionary containing the model parameters.
Returns:
Tuple[float, float]: Policy and value for the given state and choice vector.
"""
state_choice_tuple = tuple(
state_choice_vec[st] for st in state_space_names + ["choice"]
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]

value = interp_value_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
value=jnp.take(value_solved, state_choice_index, axis=0),
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
return value


def policy_for_state_choice_vec(
state_choice_vec,
wealth,
map_state_choice_to_index,
state_space_names,
endog_grid_solved,
policy_solved,
):
"""Get policy and value for a given state and choice vector.
Args:
state_choice_vec (Dict): Dictionary containing a single state and choice.
model (Model): Model object.
params (Dict): Dictionary containing the model parameters.
Returns:
Tuple[float, float]: Policy and value for the given state and choice vector.
"""
state_choice_tuple = tuple(
state_choice_vec[st] for st in state_space_names + ["choice"]
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]

policy = interp_policy_on_wealth(
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
policy=jnp.take(policy_solved, state_choice_index, axis=0),
)

return policy


def get_state_choice_index_per_state(
map_state_choice_to_index, states, state_space_names
):
indexes = map_state_choice_to_index[
tuple((states[key],) for key in state_space_names)
]
# As the code above generates a dummy dimension in the first we eliminate that
return indexes[0]
Loading

0 comments on commit 6694e41

Please sign in to comment.