In [12]:
%load_ext autoreload

%autoreload 2

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import profiler
import os
from dcegm.solve import solve_dcegm, get_solve_function
from functools import partial
from dcegm.fast_upper_envelope import fast_upper_envelope
import pandas as pd
import yaml
from dcegm.pre_processing import convert_params_to_dict
import numpy as np


# Directory (relative to the notebook's working directory) holding the test
# fixtures read further down: the params CSV, the options YAML, and the
# period_tests/*.csv arrays used for the upper-envelope benchmark.
TEST_RESOURCES_DIR = "../resources/"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Import the toy consumption-retirement model

In [2]:
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.exogenous_processes import (
    get_transition_matrix_by_state,
)
from toy_models.consumption_retirement_model.final_period_solution import (
    solve_final_period_scalar,
)
from toy_models.consumption_retirement_model.state_space_objects import (
    create_state_space,
)
from toy_models.consumption_retirement_model.state_space_objects import (
    get_state_specific_feasible_choice_set,
)
from toy_models.consumption_retirement_model.utility_functions import (
    inverse_marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
    marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
    utiility_func_log_crra,
)
from toy_models.consumption_retirement_model.utility_functions import utility_func_crra

In [3]:
# Utility callables of the toy model, handed to the solver through
# get_solve_function(utility_functions=...) in the benchmark section.
utility_functions = dict(
    utility=utility_func_crra,
    inverse_marginal_utility=inverse_marginal_utility_crra,
    marginal_utility=marginal_utility_crra,
)

In [4]:
# State-space callables of the toy model, handed to the solver through
# get_solve_function(state_space_functions=...). Note that the key
# "get_state_specific_choice_set" is bound to the *feasible* choice-set helper.
state_space_functions = dict(
    create_state_space=create_state_space,
    get_state_specific_choice_set=get_state_specific_feasible_choice_set,
)

In [5]:
# Load the model inputs from retirement_taste_shocks.{csv,yaml} in
# TEST_RESOURCES_DIR (defined once in the configuration cell at the top).
# `params` is indexed by the ("category", "name") columns of the CSV.
params = pd.read_csv(
    TEST_RESOURCES_DIR + "retirement_taste_shocks.csv", index_col=["category", "name"]
)
# Parse the options inside a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
with open(TEST_RESOURCES_DIR + "retirement_taste_shocks.yaml", "rb") as options_file:
    options = yaml.safe_load(options_file)
# Override the number of exogenous states for this benchmark.
options["n_exog_states"] = 1
# Evenly spaced exogenous savings grid from 0 to max_wealth (endpoint included).
exog_savings_grid = jnp.linspace(0, options["max_wealth"], options["n_grid_points"])

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


# Benchmark the solver and the fast upper envelope

In [8]:
# Build the solve callable once; the next cell calls it with `params` and
# times it.
backwards_jit = get_solve_function(
    # Model primitives (toy consumption-retirement model).
    utility_functions=utility_functions,
    budget_constraint=budget_constraint,
    final_period_solution=solve_final_period_scalar,
    transition_function=get_transition_matrix_by_state,
    state_space_functions=state_space_functions,
    # Solver configuration loaded above.
    options=options,
    exog_savings_grid=exog_savings_grid,
)

In [9]:
# Warm-up call (result discarded) so that one-off costs, presumably JIT
# tracing/compilation, are excluded from the timing below.
# .block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
# measurement covers the actual computation rather than just its dispatch.
backwards_jit(params).block_until_ready()
%timeit backwards_jit(params).block_until_ready()

58.8 ms ± 2.36 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
# Reference EGM policy and value arrays (comma-separated 2-D tables) for the
# upper-envelope benchmark; loaded in order pol10 -> policy_egm,
# val10 -> value_egm. The cells below read row 0 as the endogenous grid and
# row 1 as the policy/value candidates (what "10" denotes is not visible
# here — presumably a period index; confirm).
policy_egm, value_egm = (
    jnp.array(np.genfromtxt(TEST_RESOURCES_DIR + csv_name, delimiter=","))
    for csv_name in ("period_tests/pol10.csv", "period_tests/val10.csv")
)

In [21]:
# Fix fast_upper_envelope's iteration count to the number of columns of
# value_egm, then wrap the result with jax.jit.
# NOTE(review): the arrays passed to test_upp_env below are sliced with [1:],
# so they hold value_egm.shape[1] - 1 points; confirm that num_iter is meant
# to be the unsliced column count.
test_upp_env = jit(
    partial(
        fast_upper_envelope,
        num_iter=int(value_egm.shape[1]),
    )
)

In [22]:
test_upp_env(
    endog_grid=policy_egm[0, 1:],
    value=value_egm[1, 1:],
    policy=policy_egm[1, 1:],
    expected_value_zero_savings=value_egm[1, 0],
).block_until_ready()
%timeit test_upp_env(endog_grid=policy_egm[0, 1:], value=value_egm[1, 1:], policy=policy_egm[1, 1:], expected_value_zero_savings=value_egm[1, 0]).block_until_ready()

2.55 ms ± 23.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
