In [1]:
%load_ext autoreload

%autoreload 2

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import profiler
import os
from dcegm.solve import solve_dcegm, get_solve_function
from functools import partial
from dcegm.fast_upper_envelope import fast_upper_envelope, fast_upper_envelope_wrapper
import pandas as pd
import yaml
from dcegm.pre_processing import convert_params_to_dict, get_partial_functions
import numpy as np


# Relative path to the model parameter/option files and the period test fixtures.
# It is resolved against the kernel's working directory, so run the notebook from
# the folder it lives in.
TEST_RESOURCES_DIR = "../resources/"

# Import the toy model components

In [2]:
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.exogenous_processes import (
    get_transition_matrix_by_state,
)
from toy_models.consumption_retirement_model.final_period_solution import (
    solve_final_period_scalar,
)
from toy_models.consumption_retirement_model.state_space_objects import (
    create_state_space,
)
from toy_models.consumption_retirement_model.state_space_objects import (
    get_state_specific_feasible_choice_set,
)
from toy_models.consumption_retirement_model.utility_functions import (
    inverse_marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
    marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
    utiility_func_log_crra,
)
from toy_models.consumption_retirement_model.utility_functions import utility_func_crra

In [3]:
# CRRA utility, its marginal utility, and the inverse marginal utility, keyed by role.
utility_functions = dict(
    utility=utility_func_crra,
    inverse_marginal_utility=inverse_marginal_utility_crra,
    marginal_utility=marginal_utility_crra,
)

In [4]:
# Toy-model hooks for building the state space and each state's choice set.
state_space_functions = dict(
    create_state_space=create_state_space,
    get_state_specific_choice_set=get_state_specific_feasible_choice_set,
)

# Timing the full model solve

In [5]:
# Load parameters and options for the chosen model from TEST_RESOURCES_DIR
# (defined once in the first cell; not redefined here).
model = "retirement_taste_shocks"

params = pd.read_csv(
    TEST_RESOURCES_DIR + f"{model}.csv", index_col=["category", "name"]
)
# Use a context manager so the options file handle is closed after parsing
# instead of being leaked until garbage collection.
with open(TEST_RESOURCES_DIR + f"{model}.yaml", "rb") as options_file:
    options = yaml.safe_load(options_file)
# Override the stored option: this benchmark uses a single exogenous state.
options["n_exog_states"] = 1
# Savings grid: n_grid_points equally spaced points from 0 to max_wealth.
exog_savings_grid = jnp.linspace(0, options["max_wealth"], options["n_grid_points"])

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [46]:
# Assemble the solver once from the toy-model components; the returned callable
# is invoked with the parameter DataFrame in the timing cell that follows.
backwards_jit = get_solve_function(
    state_space_functions=state_space_functions,
    utility_functions=utility_functions,
    budget_constraint=budget_constraint,
    transition_function=get_transition_matrix_by_state,
    final_period_solution=solve_final_period_scalar,
    options=options,
    exog_savings_grid=exog_savings_grid,
)

In [47]:
# Warm-up call: the first invocation absorbs one-time costs (e.g. tracing and
# compilation, if get_solve_function returns a jitted callable -- the name suggests
# so), so %timeit below measures steady-state runs. block_until_ready() waits for
# JAX's asynchronous dispatch to finish, so each timed loop includes the computation.
backwards_jit(params).block_until_ready()
%timeit backwards_jit(params).block_until_ready()

3.91 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Timing the upper envelope

In [6]:
def _load_period_fixture(relative_path):
    """Read a comma-separated fixture under TEST_RESOURCES_DIR into a JAX array."""
    return jnp.array(np.genfromtxt(TEST_RESOURCES_DIR + relative_path, delimiter=","))


# Stored policy and value arrays that feed the upper-envelope benchmark below.
policy_egm = _load_period_fixture("period_tests/pol10.csv")
value_egm = _load_period_fixture("period_tests/val10.csv")

params_dict = convert_params_to_dict(params)

# Unpack the model-specific helper functions built from the toy-model components.
(
    compute_utility,
    compute_marginal_utility,
    compute_inverse_marginal_utility,
    compute_value,
    compute_next_period_wealth,
    compute_upper_envelope,
    transition_vector_by_state,
) = get_partial_functions(
    options,
    exogenous_transition_function=get_transition_matrix_by_state,
    user_budget_constraint=budget_constraint,
    user_utility_functions=utility_functions,
)

In [14]:
# num_iter is bound through partial, so jit treats it as a fixed Python int
# (taken from the value array's column count) rather than a traced argument.
# fast_upper_envelope_wrapper, which additionally takes choice, params and
# compute_value, can be benchmarked the same way.
test_upp_env = jit(partial(fast_upper_envelope, num_iter=int(value_egm.shape[1])))

In [15]:
test_upp_env(
    endog_grid=policy_egm[0, 1:],
    value=value_egm[1, 1:],
    policy=policy_egm[1, 1:],
    expected_value_zero_savings=value_egm[1, 0],
).block_until_ready()
%timeit test_upp_env(endog_grid=policy_egm[0, 1:], value=value_egm[1, 1:], policy=policy_egm[1, 1:], expected_value_zero_savings=value_egm[1, 0]).block_until_ready()

2.5 ms ± 45.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Timing `jax.lax.scan`: float carry (+1.5) vs. integer carry (+1)

In [26]:
def body_1(carry, _it_step):
    new_carry = carry + 1.5
    return new_carry, new_carry


def body_2(carry, _it_step):
    new_carry = carry + 1
    return new_carry, new_carry

test_body_1 = jit(lambda start: jax.lax.scan(body_1, start, xs=None, length=5000))
test_body_2 = jit(lambda start: jax.lax.scan(body_2, start, xs=None, length=5000))

In [27]:
test_body_1(1.0)
%timeit test_body_1(1.0)

23.3 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [28]:
test_body_2(1)
%timeit test_body_2(1)

12 µs ± 59.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
