In [29]:
from jax import vmap, jit
import pickle
import pandas as pd
import io
import yaml
from functools import partial
import jax.numpy as jnp

from toy_models.consumption_retirement_model.utility_functions import (
    utility_func_crra,
    marginal_utility_crra,
    inverse_marginal_utility_crra,
)
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.exogenous_processes import (
    get_transition_matrix_by_state,
)
from dcegm.pre_processing import get_partial_functions, params_todict

In [10]:
utility_functions = {
    "utility": utility_func_crra,
    "inverse_marginal_utility": inverse_marginal_utility_crra,
    "marginal_utility": marginal_utility_crra,
}

In [18]:
params = """category,name,value,comment
beta,beta,0.95,discount factor
delta,delta,0,disutility of work
utility_function,theta,1,CRRA coefficient
wage,constant,0.75,age-independent labor income
wage,exp,0.04,return to experience
wage,exp_squared,-0.0004,return to experience squared
shocks,sigma,0.25,shock on labor income sigma parameter/standard deviation
shocks,lambda,2.220400e-16,taste shock (scale) parameter
assets,interest_rate,0.05,interest rate on capital
assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation)
assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation)
assets,max_wealth,75,maximum level of wealth
assets,consumption_floor,0.0,consumption floor/retirement safety net (only relevant in the dc-egm retirement model)
"""

params = pd.read_csv(io.StringIO(params), index_col=["category", "name"])
params_dict = params_todict(params)

In [19]:
options = """n_periods: 25
min_age: 20
n_discrete_choices: 1
grid_points_wealth: 100
quadrature_points_stochastic: 10
n_simulations: 10
n_exog_processes: 1
"""

options = yaml.safe_load(options)
options

{'n_periods': 25,
 'min_age': 20,
 'n_discrete_choices': 1,
 'grid_points_wealth': 100,
 'quadrature_points_stochastic': 10,
 'n_simulations': 10,
 'n_exog_processes': 1}

In [20]:
(
    compute_utility,
    compute_marginal_utility,
    compute_inverse_marginal_utility,
    compute_value,
    compute_next_period_wealth,
    transition_function,
) = get_partial_functions(
    params_dict,
    options,
    utility_functions,
    budget_constraint,
    get_transition_matrix_by_state,
)

In [21]:
possible_child_states = pickle.load(
    open("profiling_resources/possible_child_states.pkl", "rb")
)
exogenous_savings_grid = pickle.load(
    open("profiling_resources/exogenous_savings_grid.pkl", "rb")
)
income_shock_draws = pickle.load(
    open("profiling_resources/income_shock_draws.pkl", "rb")
)

In [30]:
num_states = 50000
states_repeated = jnp.repeat(possible_child_states, num_states, axis=0)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [36]:
num_savings_repeats = 5
savings_repeated = jnp.repeat(exogenous_savings_grid, num_savings_repeats, axis=0)

In [37]:
jited_vmap = jit(
    vmap(
        vmap(
            vmap(compute_next_period_wealth, in_axes=(0, None, None)),
            in_axes=(None, 0, None),
        ),
        in_axes=(None, None, 0),
    )
)

In [38]:
def loop_func(states, savings, income, states_per_period, runs):
    for i in range(runs):
        result = vmap(
            vmap(
                vmap(compute_next_period_wealth, in_axes=(0, None, None)),
                in_axes=(None, 0, None),
            ),
            in_axes=(None, None, 0),
        )(states[:states_per_period], savings, income)
    return result

In [39]:
jitted_lopp_func = jit(
    partial(loop_func, states_per_period=int(num_states / 500), runs=500)
)

In [40]:
%timeit jitted_lopp_func(states=states_repeated, savings=savings_repeated, income=income_shock_draws).block_until_ready()

134 µs ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [41]:
%timeit jited_vmap(states_repeated, savings_repeated, income_shock_draws).block_until_ready()

305 ms ± 19.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
