# Jax profiler test: marginal_util_and_exp_max_value_states_period

In [4]:
import io
import pickle
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
#from toy_models.consumption_retirement_model import (compute_expected_value, compute_next_period_marginal_utility,inverse_marginal_utility_crra, utility_func_crra)
from toy_models.consumption_retirement_model.utility_functions import utility_func_crra, marginal_utility_crra, inverse_marginal_utility_crra
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.exogenous_processes import get_transition_matrix_by_state
from dcegm.pre_processing import get_partial_functions, params_todict
from dcegm.marg_utilities_and_exp_value import marginal_util_and_exp_max_value_states_period
import jax
import jax.numpy as jnp
from jax import profiler
import os
from dcegm.solve import solve_dcegm

ModuleNotFoundError: No module named 'toy_models'

## Specify *utility functions*, *params* and *options*

In [2]:
utility_functions = {
    "utility": utility_func_crra,
    "inverse_marginal_utility": inverse_marginal_utility_crra,
    "marginal_utility": marginal_utility_crra,
}

In [2]:
params = """category,name,value,comment
beta,beta,0.95,discount factor
delta,delta,0,disutility of work
utility_function,theta,1,CRRA coefficient
wage,constant,0.75,age-independent labor income
wage,exp,0.04,return to experience
wage,exp_squared,-0.0004,return to experience squared
shocks,sigma,0.25,shock on labor income sigma parameter/standard deviation
shocks,lambda,2.220400e-16,taste shock (scale) parameter
assets,interest_rate,0.05,interest rate on capital
assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation)
assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation)
assets,max_wealth,75,maximum level of wealth
assets,consumption_floor,0.0,consumption floor/retirement safety net (only relevant in the dc-egm retirement model)
"""

In [3]:
params = pd.read_csv(io.StringIO(params), index_col=["category", "name"])
params

NameError: name 'pd' is not defined

In [5]:
params_dict = params_todict(params)

In [6]:
options = """n_periods: 25
min_age: 20
n_discrete_choices: 1
grid_points_wealth: 100
quadrature_points_stochastic: 10
n_simulations: 10
n_exog_processes: 1
"""

In [7]:
options = yaml.safe_load(options)
options

{'n_periods': 25,
 'min_age': 20,
 'n_discrete_choices': 1,
 'grid_points_wealth': 100,
 'quadrature_points_stochastic': 10,
 'n_simulations': 10,
 'n_exog_processes': 1}

## Specify inputs for function **marginal_util_and_exp_max_value_states_period**

In [8]:
(
    compute_utility,
    compute_marginal_utility,
    compute_inverse_marginal_utility,
    compute_value,
    compute_next_period_wealth,
    transition_function,
) = get_partial_functions(
    params_dict,
    options,
    utility_functions,
    budget_constraint,
    get_transition_matrix_by_state,
)

In [9]:
taste_shock_scale = 2.220400e-16
exogenous_savings_grid = pickle.load(open("profiling_resources/exogenous_savings_grid.pkl", "rb"))
income_shock_draws = pickle.load(open("profiling_resources/income_shock_draws.pkl", "rb"))
income_shock_weights = pickle.load(open("profiling_resources/income_shock_weights.pkl", "rb"))
possible_child_states = pickle.load(open("profiling_resources/possible_child_states.pkl", "rb"))
choices_child_states = pickle.load(open("profiling_resources/choices_child_states.pkl", "rb"))
policies_child_states = pickle.load(open("profiling_resources/policies_child_states.pkl", "rb"))
values_child_states = pickle.load(open("profiling_resources/values_child_states.pkl", "rb"))

In [1]:
params_dict("lambda")

NameError: name 'params_dict' is not defined

In [10]:
num_states = 5000
states_repeated = jnp.repeat(possible_child_states, num_states, axis=0)
choices_repeated = jnp.repeat(choices_child_states, num_states, axis=0)
policies_repeated = jnp.repeat(policies_child_states, num_states, axis=0)
values_repeated = jnp.repeat(values_child_states, num_states, axis=0)

In [11]:
num_savings_repeats = 5
savings_repeated = jnp.repeat(exogenous_savings_grid, num_savings_repeats, axis=0)

In [14]:
with jax.profiler.trace("jax-trace", create_perfetto_link=True):
    marginal_util, max_exp_value = marginal_util_and_exp_max_value_states_period(
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        savings_repeated,
        income_shock_draws,
        income_shock_weights,
        states_repeated,
        choices_repeated,
        policies_repeated,
        values_repeated,
    )

2023-02-15 10:03:28.014761: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
2023-02-15 10:03:34.162133: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace


ValueError: Invalid trace folder: /Users/viktoriakleinschmidt/Desktop/Arbeit/dc-egm/src/dcegm/sandbox/jax-trace/plugins/profile/2023_02_15_10_03_35

In [15]:
%%timeit
marginal_util, max_exp_value = marginal_util_and_exp_max_value_states_period(
    compute_next_period_wealth,
    compute_marginal_utility,
    compute_value,
    taste_shock_scale,
    savings_repeated,
    income_shock_draws,
    income_shock_weights,
    states_repeated,
    choices_repeated,
    policies_repeated,
    values_repeated,
)

2.46 s ± 214 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [1]:
# use jax profiler
def profiled_marginal_util_and_exp_max_value_states_period(
    compute_next_period_wealth,
    compute_marginal_utility,
    compute_value,
    taste_shock_scale,
    exogenous_savings_grid,
    income_shock_draws,
    income_shock_weights,
    possible_child_states,
    choices_child_states,
    policies_child_states,
    values_child_states,
):
    return profiler.call(
        marginal_util_and_exp_max_value_states_period,
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        exogenous_savings_grid,
        income_shock_draws,
        income_shock_weights,
        possible_child_states,
        choices_child_states,
        policies_child_states,
        values_child_states,
    )

In [12]:
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):

    result = profiled_marginal_util_and_exp_max_value_states_period(
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        exogenous_savings_grid,
        income_shock_draws,
        income_shock_weights,
        possible_child_states,
        choices_child_states,
        policies_child_states,
        values_child_states,
    )
    result.block_until_ready()

ValueError: Invalid trace folder: /tmp/jax-trace/plugins/profile/2023_02_15_11_28_40

In [19]:
pip install tensorflow tensorboard-plugin-profile

Collecting tensorflow
  Downloading tensorflow-2.11.0-cp39-cp39-macosx_10_14_x86_64.whl (244.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.3/244.3 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tensorboard-plugin-profile
  Downloading tensorboard_plugin_profile-2.11.1-py3-none-any.whl (5.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hCollecting tensorflow-io-gcs-filesystem>=0.23.1
  Downloading tensorflow_io_gcs_filesystem-0.30.0-cp39-cp39-macosx_10_14_x86_64.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[?25hCollecting tensorboard<2.12,>=2.11
  Downloading tensorboard-2.11.2-py3-none-any.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m00:01

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.1/77.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting oauthlib>=3.0.0
  Downloading oauthlib-3.2.2-py3-none-any.whl (151 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m151.7/151.7 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorboard-plugin-wit, pyasn1, libclang, flatbuffers, wrapt, werkzeug, termcolor, tensorflow-io-gcs-filesystem, tensorflow-estimator, tensorboard-data-server, rsa, pyasn1-modules, protobuf, oauthlib, keras, h5py, gviz-api, grpcio, google-pasta, gast, cachetools, astunparse, absl-py, tensorboard-plugin-profile, requests-oauthlib, markdown, google-auth, google-auth-oauthlib, tensorboard, tensorflow
Successfully installed absl-py-1.4.0 astunparse-1.6.3 cachetools-5.3.0 flatbuffers-23.1.21 gast-0.4.0 google-auth-2.16.0 google-auth-oauthlib-0.4.6 google-pasta-0.2.0 grpcio-1.51.1 gviz-api-1.10.0 h5py-3.8.0 keras-2.11.0 libclang-15.

In [23]:
# from ChatGPT conversation
import tensorflow as tf
from jax import jit
# Trace the function to create a computation graph
traced_computation = jit(profiled_marginal_util_and_exp_max_value_states_period)
# Use TensorBoard to visualize the computation graph
with tf.summary.create_file_writer("logs").as_default():
    tf.summary.trace_on(graph=True, profiler=True)
    marginal_util, max_exp_value = marginal_util_and_exp_max_value_states_period(
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        savings_repeated,
        income_shock_draws,
        income_shock_weights,
        states_repeated,
        choices_repeated,
        policies_repeated,
        values_repeated,
    )
    with tf.summary.record_if(True):
        tf.summary.trace_export(
            name="my_computation_graph",
            step=0,
            profiler_outdir="logs"
        )


Instructions for updating:
use `tf.profiler.experimental.start` instead.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Instructions for updating:
`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.
Instructions for updating:
`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.


In [30]:
result=jax.xla_computation(marginal_util_and_exp_max_value_states_period)(
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        savings_repeated,
        income_shock_draws,
        income_shock_weights,
        states_repeated,
        choices_repeated,
        policies_repeated,
        values_repeated,
    )

TypeError: Cannot interpret '<CompiledFunction of functools.partial(<function budget_constraint at 0x13fdb4ee0>, params_dict={'beta': 0.95, 'delta': 0.0, 'theta': 1.0, 'constant': 0.75, 'exp': 0.04, 'exp_squared': -0.0004, 'sigma': 0.25, 'lambda': 2.2204e-16, 'interest_rate': 0.05, 'initial_wealth_low': 0.0, 'initial_wealth_high': 30.0, 'max_wealth': 75.0, 'consumption_floor': 0.0}, options={'n_periods': 25, 'min_age': 20, 'n_discrete_choices': 1, 'grid_points_wealth': 100, 'quadrature_points_stochastic': 10, 'n_simulations': 10, 'n_exog_processes': 1})>' as a data type