In [1]:
%load_ext autoreload

%autoreload 2

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import profiler
import os
from dcegm.solve import solve_dcegm, get_solve_function
from functools import partial
from dcegm.fast_upper_envelope import fast_upper_envelope, fast_upper_envelope_wrapper
import pandas as pd
import yaml
from dcegm.pre_processing import convert_params_to_dict, get_partial_functions
import numpy as np
import time


# Directory holding the model's .csv/.yaml resource files, relative to the
# notebook's working directory. File paths are built from it by plain string
# concatenation, so the trailing "/" is required.
TEST_RESOURCES_DIR = "../resources/"

## Import the toy consumption-retirement model

In [2]:
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.exogenous_processes import (
    get_transition_matrix_by_state,
)
from toy_models.consumption_retirement_model.final_period_solution import (
    solve_final_period_scalar,
)
from toy_models.consumption_retirement_model.state_space_objects import (
    create_state_space,
    update_state,
)
from toy_models.consumption_retirement_model.state_space_objects import (
    get_state_specific_feasible_choice_set,
)
from toy_models.consumption_retirement_model.utility_functions import (
    inverse_marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
    marginal_utility_crra,
)
from toy_models.consumption_retirement_model.utility_functions import (
    utiility_func_log_crra,
)
from toy_models.consumption_retirement_model.utility_functions import utility_func_crra

from dcegm.state_space import create_state_choice_space

In [3]:
# def _create_state_space_custom(options):
#     """Create state space object and indexer."""
#     n_periods = options["n_periods"]
#     n_choices = options["n_discrete_choices"]  # lagged_choice is a state variable
#     n_exog_states = options["n_exog_states"]
#     # n_experience = options["n_experience"]

#     shape = (n_periods, n_choices, n_exog_states)

#     map_state_to_index = np.full(shape, -9999, dtype=np.int64)
#     _state_space = []

#     i = 0
#     for period in range(n_periods):
#         for choice in range(n_choices):
#             for exog_process in range(n_exog_states):
#                 map_state_to_index[period, choice, exog_process] = i

#                 row = [period, choice, exog_process]
#                 _state_space.append(row)

#                 i += 1

#     state_space = np.array(_state_space, dtype=np.int64)

#     return state_space, map_state_to_index

In [4]:
# def _get_transition_matrix_custom(state, params_dict):
#     return jnp.append(1, jnp.zeros(27 - 1))

In [5]:
# CRRA utility callables from the toy model, bundled under the role names
# that get_solve_function receives via ``utility_functions``.
utility_functions = dict(
    utility=utility_func_crra,
    inverse_marginal_utility=inverse_marginal_utility_crra,
    marginal_utility=marginal_utility_crra,
)

In [6]:
# State-space callables passed to get_solve_function. The commented-out
# ``_create_state_space_custom`` cell is an alternative for "create_state_space".
state_space_functions = dict(
    create_state_space=create_state_space,
    state_specific_choice_set=get_state_specific_feasible_choice_set,
    next_period_endogenous_state=update_state,
)

In [7]:
# Load the model's parameters (CSV indexed by category/name) and its options
# (YAML) from the resources directory.
model = "deaton"

params = pd.read_csv(
    TEST_RESOURCES_DIR + f"{model}.csv", index_col=["category", "name"]
)

N_EXOG_STATES = 1

# Parse the YAML inside a context manager so the file handle is closed
# afterwards (the previous ``safe_load(open(...))`` leaked it).
with open(TEST_RESOURCES_DIR + f"{model}.yaml", "rb") as options_file:
    options = yaml.safe_load(options_file)
options["n_exog_states"] = N_EXOG_STATES

# Optional overrides for experimenting with other model sizes; uncomment as needed.
# options["n_periods"] = 20
# options["n_experience"] = 20
# options["n_discrete_choices"] = 3
# options["min_age"] = 50
# options["max_wealth"] = 50_000
# options["n_grid_points"] = 500

# Evenly spaced exogenous savings grid with n_grid_points values on [0, max_wealth].
exog_savings_grid = jnp.linspace(0, options["max_wealth"], options["n_grid_points"])

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [8]:
# Show the resolved options: the YAML contents plus the n_exog_states override.
options

{'n_periods': 25,
 'min_age': 20,
 'n_discrete_choices': 1,
 'n_grid_points': 500,
 'max_wealth': 75,
 'quadrature_points_stochastic': 10,
 'n_simulations': 10,
 'n_exog_states': 1}

In [9]:
# Build the solver once from the model components; the resulting callable is
# invoked with ``params`` in the cells that follow. The large gap between the
# first-call time and the %timeit result suggests it is compiled on first call.
# To try a custom exogenous process, pass ``_get_transition_matrix_custom``
# (see the commented-out cell) as ``transition_function`` instead.
solve_jit = get_solve_function(
    options=options,
    exog_savings_grid=exog_savings_grid,
    utility_functions=utility_functions,
    budget_constraint=budget_constraint,
    final_period_solution=solve_final_period_scalar,
    state_space_functions=state_space_functions,
    transition_function=get_transition_matrix_by_state,
)

In [10]:
# Time the first call. It presumably includes tracing and compilation, which is
# why it is far slower than the steady-state %timeit measurement.
# perf_counter is monotonic and high-resolution, unlike time.time(), so it is
# the appropriate clock for elapsed-time measurements.
start = time.perf_counter()
jax.block_until_ready(solve_jit(params))
time.perf_counter() - start

7.980986595153809

In [11]:
# Steady-state timing: the previous cell has already made the first call.
# Wrapping in block_until_ready makes each measurement wait for the computation
# to finish rather than only for its dispatch.
# https://github.com/google/jax/discussions/11169

%timeit jax.block_until_ready(solve_jit(params))

22.4 ms ± 774 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Capture a profiler trace of one full solve into /tmp/jax-trace.
# With create_perfetto_link=True, JAX prints a Perfetto UI link (see output)
# and blocks until that link is opened.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    # Profile the solver built earlier in this notebook; ``backwards_jit`` was
    # never defined here and raised NameError on a fresh kernel.
    jax.block_until_ready(solve_jit(params))

Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz


2023-08-26 14:58:14.781338: E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
2023-08-26 14:58:14.808664: E external/xla/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
