Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend dcegm interface #101

Merged
merged 14 commits into from
May 27, 2024
223 changes: 25 additions & 198 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
from typing import Dict
from typing import Tuple

import numpy as np
from dcegm.interpolation import get_index_high_and_low
from dcegm.interpolation import linear_interpolation_formula
from dcegm.interpolation import interp_value_and_policy_on_wealth
from jax import numpy as jnp
from jax import vmap


def interpolate_value_and_calc_marginal_utility(
def interpolate_value_and_marg_utility_on_next_period_wealth(
compute_marginal_utility: Callable,
compute_utility: Callable,
state_choice_vec: Dict[str, int],
Expand All @@ -19,7 +17,7 @@ def interpolate_value_and_calc_marginal_utility(
value_child_state_choice: jnp.ndarray,
params: Dict[str, float],
) -> Tuple[float, float]:
"""Interpolate marginal utilities.
"""Interpolate value and policy in the child states and compute the marginal value.

Args:
compute_marginal_utility (callable): User-defined function to compute the
Expand Down Expand Up @@ -51,198 +49,27 @@ def interpolate_value_and_calc_marginal_utility(
containing the interpolated value function.

"""
# For all choices, the wealth is the same in the solution
ind_high, ind_low = get_index_high_and_low(
x=endog_grid_child_state_choice, x_new=wealth_beginning_of_period
)
marg_utils, value_interp = vmap(
vmap(
_interpolate_value_and_marg_util,
in_axes=(0, 0, 0, 0, 0, 0, 0, None, None, None, None, None, None),
),
in_axes=(0, 0, 0, 0, 0, 0, 0, None, None, None, None, None, None),
)(
jnp.take(policy_child_state_choice, ind_high),
value_child_state_choice[ind_high],
endog_grid_child_state_choice[ind_high],
jnp.take(policy_child_state_choice, ind_low),
value_child_state_choice[ind_low],
endog_grid_child_state_choice[ind_low],
wealth_beginning_of_period,
compute_utility,
compute_marginal_utility,
endog_grid_child_state_choice[1],
value_child_state_choice[0],
state_choice_vec,
params,
)

return marg_utils, value_interp


def _interpolate_value_and_marg_util(
policy_high: float | jnp.ndarray,
value_high: float | jnp.ndarray,
wealth_high: float | jnp.ndarray,
policy_low: float | jnp.ndarray,
value_low: float | jnp.ndarray,
wealth_low: float | jnp.ndarray,
new_wealth: float | jnp.ndarray,
compute_utility: Callable,
compute_marginal_utility: Callable,
endog_grid_min: float | jnp.ndarray,
value_at_zero_wealth: float | jnp.ndarray,
state_choice_vec: Dict[str, int],
params: Dict[str, float],
) -> Tuple[float, float]:
"""Calculate interpolated marginal utility and value function.

Args:
policy_high (float): Policy function value at the higher end of the
interpolation interval.
value_high (float): Value function value at the higher end of the
interpolation interval.
wealth_high (float): Endogenous wealth grid value at the higher end of the
interpolation interval.
policy_low (float): Policy function value at the lower end of the
interpolation interval.
value_low (float): Value function value at the lower end of the
interpolation interval.
wealth_low (float): Endogenous wealth grid value at the lower end of the
interpolation interval.
new_wealth (float): New endogenous wealth grid value.
compute_marginal_utility (callable): Function for calculating the marginal
utility from consumption level. The input ```params``` is already
partialled in.
endog_grid_min (float): Minimum endogenous wealth grid value.
value_min (float): Minimum value function value.
choice (int): Discrete choice of an agent.
params (dict): Dictionary containing the model parameters.

Returns:
tuple:

- marg_util_interp (float): Interpolated marginal utility function.
- value_interp (float): Interpolated value function.

"""
policy_interp = linear_interpolation_formula(
y_high=policy_high,
y_low=policy_low,
x_high=wealth_high,
x_low=wealth_low,
x_new=new_wealth,
)

value_interp = interp_value_and_check_creditconstraint(
value_high=value_high,
wealth_high=wealth_high,
value_low=value_low,
wealth_low=wealth_low,
new_wealth=new_wealth,
compute_utility=compute_utility,
endog_grid_min=endog_grid_min,
value_at_zero_wealth=value_at_zero_wealth,
state_choice_vec=state_choice_vec,
params=params,
)

marg_utility_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)

return marg_utility_interp, value_interp

# Generate interpolation function for single wealth point where the endogenous grid,
# policy and value are fixed.
def interp_on_single_wealth(wealth):
policy_interp, value_interp = interp_value_and_policy_on_wealth(
wealth=wealth,
endog_grid=endog_grid_child_state_choice,
policy=policy_child_state_choice,
value=value_child_state_choice,
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
marg_utility_interp = compute_marginal_utility(
consumption=policy_interp, params=params, **state_choice_vec
)
return marg_utility_interp, value_interp

# Generate vectorized function for savings and income shock dimension
vector_interp_func = vmap(vmap(interp_on_single_wealth))

marg_utils, value_interp = vector_interp_func(wealth_beginning_of_period)

def interp_value_and_check_creditconstraint(
value_high: float | jnp.ndarray,
wealth_high: float | jnp.ndarray,
value_low: float | jnp.ndarray,
wealth_low: float | jnp.ndarray,
new_wealth: float | jnp.ndarray,
compute_utility: Callable,
endog_grid_min: float,
value_at_zero_wealth: float,
state_choice_vec: Dict[str, int],
params: Dict[str, float],
) -> float | jnp.ndarray:
"""Calculate interpolated marginal utility and value function.

Args:
policy_high (float): Policy function value at the higher end of the
interpolation interval.
value_high (float): Value function value at the higher end of the
interpolation interval.
wealth_high (float): Endogenous wealth grid value at the higher end of the
interpolation interval.
policy_low (float): Policy function value at the lower end of the
interpolation interval.
value_low (float): Value function value at the lower end of the
interpolation interval.
wealth_low (float): Endogenous wealth grid value at the lower end of the
interpolation interval.
new_wealth (float): New endogenous wealth grid value.
endog_grid_min (float): Minimum endogenous wealth grid value.
value_min (float): Minimum value function value.
state_choice_vec (Dict): Dictionary containing a single state and choice.
params (dict): Dictionary containing the model parameters.

Returns:
tuple:

- marg_util_interp (float): Interpolated marginal utility function.
- value_interp (float): Interpolated value function.

"""

value_interp_on_grid = linear_interpolation_formula(
y_high=value_high,
y_low=value_low,
x_high=wealth_high,
x_low=wealth_low,
x_new=new_wealth,
)

value_interp = check_value_if_credit_constrained(
value_interp_on_grid=value_interp_on_grid,
value_at_zero_wealth=value_at_zero_wealth,
new_wealth=new_wealth,
endog_grid_min=endog_grid_min,
params=params,
state_choice_vec=state_choice_vec,
compute_utility=compute_utility,
)
return value_interp


def check_value_if_credit_constrained(
value_interp_on_grid,
value_at_zero_wealth,
new_wealth,
endog_grid_min,
params,
state_choice_vec,
compute_utility,
):
"""This function takes the value interpolated on the solution and checks if it is in
the region, where consume all your wealth is the optimal solution.

This is by construction endog_grid_min. If so, it returns the closed form solution
for the value function, by calculating the utility of consuming all the wealth and
adding the discounted expected value of zero wealth. Otherwise, it returns the
interpolated value function.

"""
utility = compute_utility(
consumption=new_wealth,
params=params,
**state_choice_vec,
)
value_interp_closed_form = utility + params["beta"] * value_at_zero_wealth

credit_constraint = new_wealth <= endog_grid_min
value_interp = (
credit_constraint * value_interp_closed_form
+ (1 - credit_constraint) * value_interp_on_grid
)
return value_interp
return marg_utils, value_interp
125 changes: 125 additions & 0 deletions src/dcegm/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import jax.numpy as jnp
from dcegm.interpolation import interp_policy_on_wealth
from dcegm.interpolation import interp_value_and_policy_on_wealth
from dcegm.interpolation import interp_value_on_wealth


def policy_and_value_for_state_choice_vec(
state_choice_vec,
wealth,
map_state_choice_to_index,
state_space_names,
endog_grid_solved,
policy_solved,
value_solved,
compute_utility,
params,
):
"""Get policy and value for a given state and choice vector.

Args:
state_choice_vec (Dict): Dictionary containing a single state and choice.
model (Model): Model object.
params (Dict): Dictionary containing the model parameters.

Returns:
Tuple[float, float]: Policy and value for the given state and choice vector.

"""
state_choice_tuple = tuple(

Check warning on line 29 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L29

Added line #L29 was not covered by tests
state_choice_vec[st] for st in state_space_names + ["choice"]
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]
policy, value = interp_value_and_policy_on_wealth(

Check warning on line 34 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L33-L34

Added lines #L33 - L34 were not covered by tests
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
policy=jnp.take(policy_solved, state_choice_index, axis=0),
value=jnp.take(value_solved, state_choice_index, axis=0),
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
return policy, value

Check warning on line 43 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L43

Added line #L43 was not covered by tests


def value_for_state_choice_vec(
state_choice_vec,
wealth,
map_state_choice_to_index,
state_space_names,
endog_grid_solved,
value_solved,
compute_utility,
params,
):
"""Get policy and value for a given state and choice vector.

Args:
state_choice_vec (Dict): Dictionary containing a single state and choice.
model (Model): Model object.
params (Dict): Dictionary containing the model parameters.

Returns:
Tuple[float, float]: Policy and value for the given state and choice vector.

"""
state_choice_tuple = tuple(

Check warning on line 67 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L67

Added line #L67 was not covered by tests
state_choice_vec[st] for st in state_space_names + ["choice"]
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]

Check warning on line 71 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L71

Added line #L71 was not covered by tests

value = interp_value_on_wealth(

Check warning on line 73 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L73

Added line #L73 was not covered by tests
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
value=jnp.take(value_solved, state_choice_index, axis=0),
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
)
return value

Check warning on line 81 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L81

Added line #L81 was not covered by tests


def policy_for_state_choice_vec(
state_choice_vec,
wealth,
map_state_choice_to_index,
state_space_names,
endog_grid_solved,
policy_solved,
):
"""Get policy and value for a given state and choice vector.

Args:
state_choice_vec (Dict): Dictionary containing a single state and choice.
model (Model): Model object.
params (Dict): Dictionary containing the model parameters.

Returns:
Tuple[float, float]: Policy and value for the given state and choice vector.

"""
state_choice_tuple = tuple(

Check warning on line 103 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L103

Added line #L103 was not covered by tests
state_choice_vec[st] for st in state_space_names + ["choice"]
)

state_choice_index = map_state_choice_to_index[state_choice_tuple]

Check warning on line 107 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L107

Added line #L107 was not covered by tests

policy = interp_policy_on_wealth(

Check warning on line 109 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L109

Added line #L109 was not covered by tests
wealth=wealth,
endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0),
policy=jnp.take(policy_solved, state_choice_index, axis=0),
)

return policy

Check warning on line 115 in src/dcegm/interface.py

View check run for this annotation

Codecov / codecov/patch

src/dcegm/interface.py#L115

Added line #L115 was not covered by tests


def get_state_choice_index_per_state(
map_state_choice_to_index, states, state_space_names
):
indexes = map_state_choice_to_index[
tuple((states[key],) for key in state_space_names)
]
# As the code above generates a dummy dimension in the first we eliminate that
return indexes[0]
Loading
Loading