# Jax profiler test: marginal_util_and_exp_max_value_states_period

In [1]:
%load_ext autoreload

%autoreload 2

import io
import pickle
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import profiler
import os
from dcegm.solve import solve_dcegm
from functools import partial
import numba
from dcegm.fast_upper_envelope import fast_upper_envelope


TEST_RESOURCES_DIR = "../resources/"

In [2]:
policy_egm = np.genfromtxt(TEST_RESOURCES_DIR + "period_tests/pol10.csv", delimiter=",")
value_egm = np.genfromtxt(TEST_RESOURCES_DIR + "period_tests/val10.csv", delimiter=",")

In [3]:
test_upp_env = jit(partial(fast_upper_envelope, num_iter=int(value_egm.shape[1])))

In [7]:
test_upp_env(
    endog_grid=policy_egm[0, 1:],
    value=value_egm[1, 1:],
    policy=policy_egm[1, 1:],
    expected_value_zero_savings=value_egm[1, 0],
)
%timeit test_upp_env(endog_grid=policy_egm[0, 1:], value=value_egm[1, 1:], policy=policy_egm[1, 1:], expected_value_zero_savings=value_egm[1, 0])

565 µs ± 6.92 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
jnp.searchsorted(np.array([1, 2, 2, 4]), 1.5)

Array(1, dtype=int32)

In [47]:
jnp.maximum(jnp.array([1, 2, 3]), 50)

Array([50, 50, 50], dtype=int32)

In [66]:
def test_func(a):
    for i in 
    return jax.lax.select(a, 5.5, np.nan)

In [67]:
test_jit = jit(test_func)

In [70]:
test_jit(True)
%timeit test_jit(True)

3.23 µs ± 101 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [35]:
jnp.where(True, 1, 2)

TypeError: object of type 'bool' has no len()

In [2]:
def test_loop(array_to_loop):
    test_list = [0] * array_to_loop.shape[0]
    for i, num in enumerate(array_to_loop):
        ind = num > 10
        test_list[i] = num * ind + 1 * (1 - ind)
    return test_list

In [3]:
def test_loop_2(array_to_loop):
    for i, num in enumerate(array_to_loop):
        ind = num > 10
        array_to_loop.at[i].set(num * ind + 1 * (1 - ind))
    return array_to_loop

In [61]:
test_jit = jit(test_loop)
test_jit_2 = jit(test_loop_2)
test_array = jnp.arange(100)

In [62]:
test_array_2 = jnp.arange(10, 110)

In [63]:
%timeit test_jit(test_array_2)

127 µs ± 988 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [64]:
%timeit test_jit_2(test_array_2)

97.8 µs ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [45]:
numba_loop = numba.jit(test_loop_2, nopython=True)

In [46]:
test_np_array = np.arange(10, 110)

In [47]:
numba_loop(test_np_array)

array([  1,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,
        23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
        36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
        49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,
        62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
        75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,
        88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100,
       101, 102, 103, 104, 105, 106, 107, 108, 109])

In [48]:
%timeit numba_loop(test_np_array)

594 ns ± 4.91 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [67]:
False * 1

0

## Specify *utility functions*, *params* and *options*

In [2]:
utility_functions = {
    "utility": utility_func_crra,
    "inverse_marginal_utility": inverse_marginal_utility_crra,
    "marginal_utility": marginal_utility_crra,
}

In [3]:
params = """category,name,value,comment
beta,beta,0.95,discount factor
delta,delta,0,disutility of work
utility_function,theta,1,CRRA coefficient
wage,constant,0.75,age-independent labor income
wage,exp,0.04,return to experience
wage,exp_squared,-0.0004,return to experience squared
shocks,sigma,0.25,shock on labor income sigma parameter/standard deviation
shocks,lambda,2.220400e-16,taste shock (scale) parameter
assets,interest_rate,0.05,interest rate on capital
assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation)
assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation)
assets,max_wealth,75,maximum level of wealth
assets,consumption_floor,0.0,consumption floor/retirement safety net (only relevant in the dc-egm retirement model)
"""

In [4]:
params = pd.read_csv(io.StringIO(params), index_col=["category", "name"])
params

Unnamed: 0_level_0,Unnamed: 1_level_0,value,comment
category,name,Unnamed: 2_level_1,Unnamed: 3_level_1
beta,beta,0.95,discount factor
delta,delta,0.0,disutility of work
utility_function,theta,1.0,CRRA coefficient
wage,constant,0.75,age-independent labor income
wage,exp,0.04,return to experience
wage,exp_squared,-0.0004,return to experience squared
shocks,sigma,0.25,shock on labor income sigma parameter/standard...
shocks,lambda,2.2204e-16,taste shock (scale) parameter
assets,interest_rate,0.05,interest rate on capital
assets,initial_wealth_low,0.0,lowest level of initial wealth (relevant for s...


In [5]:
params_dict = params_todict(params)

In [6]:
options = """n_periods: 25
min_age: 20
n_discrete_choices: 1
grid_points_wealth: 100
quadrature_points_stochastic: 10
n_simulations: 10
n_exog_processes: 1
"""

In [7]:
options = yaml.safe_load(options)
options

{'n_periods': 25,
 'min_age': 20,
 'n_discrete_choices': 1,
 'grid_points_wealth': 100,
 'quadrature_points_stochastic': 10,
 'n_simulations': 10,
 'n_exog_processes': 1}

## Specify inputs for function **marginal_util_and_exp_max_value_states_period**

In [8]:
(
    compute_utility,
    compute_marginal_utility,
    compute_inverse_marginal_utility,
    compute_value,
    compute_next_period_wealth,
    transition_function,
) = get_partial_functions(
    params_dict,
    options,
    utility_functions,
    budget_constraint,
    get_transition_matrix_by_state,
)

In [9]:
taste_shock_scale = 2.220400e-16
exogenous_savings_grid = pickle.load(
    open("profiling_resources/exogenous_savings_grid.pkl", "rb")
)
income_shock_draws = pickle.load(
    open("profiling_resources/income_shock_draws.pkl", "rb")
)
income_shock_weights = pickle.load(
    open("profiling_resources/income_shock_weights.pkl", "rb")
)
possible_child_states = pickle.load(
    open("profiling_resources/possible_child_states.pkl", "rb")
)
choices_child_states = pickle.load(
    open("profiling_resources/choices_child_states.pkl", "rb")
)
policies_child_states = pickle.load(
    open("profiling_resources/policies_child_states.pkl", "rb")
)
values_child_states = pickle.load(
    open("profiling_resources/values_child_states.pkl", "rb")
)

In [10]:
num_states = 50000
states_repeated = jnp.repeat(possible_child_states, num_states, axis=0)
choices_repeated = jnp.repeat(choices_child_states, num_states, axis=0)
policies_repeated = jnp.repeat(policies_child_states, num_states, axis=0)
values_repeated = jnp.repeat(values_child_states, num_states, axis=0)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [11]:
num_savings_repeats = 5
savings_repeated = jnp.repeat(exogenous_savings_grid, num_savings_repeats, axis=0)

In [12]:
partial_func = partial(
    marginal_util_and_exp_max_value_states_period,
    compute_next_period_wealth=compute_next_period_wealth,
    compute_marginal_utility=compute_marginal_utility,
    compute_value=compute_value,
)
partil_jit = jax.jit(partial_func)

In [13]:
jited_vmap = jit(
    vmap(
        vmap(
            vmap(compute_next_period_wealth, in_axes=(0, None, None)),
            in_axes=(None, 0, None),
        ),
        in_axes=(None, None, 0),
    )
)

In [23]:
def loop_func(states, savings, income, states_per_period, runs):
    for i in range(runs):
        result = vmap(
            vmap(
                vmap(compute_next_period_wealth, in_axes=(0, None, None)),
                in_axes=(None, 0, None),
            ),
            in_axes=(None, None, 0),
        )(states[:states_per_period], savings, income)
    return result

In [42]:
jitted_lopp_func = jit(
    partial(loop_func, states_per_period=int(num_states / 500), runs=500)
)

In [43]:
%timeit jitted_lopp_func(states=states_repeated, savings=savings_repeated, income=income_shock_draws).block_until_ready()

145 µs ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
%timeit jited_vmap(states_repeated, savings_repeated, income_shock_draws).block_until_ready()

288 ms ± 21.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
int_inner = 10000
int_outer = 50000


def func_vmaped_0(non_arg):
    result = 10 + non_arg
    #     result = vmap(lambda a : a + 10, in_axes=(0))()
    return result


outer_range = jnp.arange(int_outer)
jitted_vmap = jax.jit(jax.vmap(func_vmaped_0, in_axes=(0)))

%timeit jitted_vmap(np.arange(int_outer * int_inner)).block_until_ready()

1.55 s ± 99.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
np.sum(np.zeros((int_outer, int_inner)) + 10, np.arange

array([[10., 10., 10., ..., 10., 10., 10.],
       [10., 10., 10., ..., 10., 10., 10.],
       [10., 10., 10., ..., 10., 10., 10.],
       ...,
       [10., 10., 10., ..., 10., 10., 10.],
       [10., 10., 10., ..., 10., 10., 10.],
       [10., 10., 10., ..., 10., 10., 10.]])

In [26]:
jnp?

In [28]:
range_inner = jnp.arange(int_inner)


def func_vmaped_2(non_arg):
    result_0 = jax.vmap(lambda a: a + 10 + non_arg, in_axes=(0))(range_inner[:5000])
    result = jax.vmap(lambda a: a + 10 + non_arg, in_axes=(0))(range_inner[:5000])
    return jnp.concatenate([result_0, result])


outer_range = jnp.arange(int_outer)
jitted_vmap = jax.jit(jax.vmap(func_vmaped_2, in_axes=(0)))
entry_array = jnp.arange(int_outer)
%timeit jitted_vmap(entry_array).block_until_ready()

434 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [56]:
def func_vmaped_3(non_arg):
    range_inner = jnp.arange(int_inner)
    result = jnp.array([])
    for i in range_inner:
        result = jnp.append(result, i + 10 + non_arg)
    return result


outer_range = jnp.arange(int_outer)
jitted_vmap = jax.jit(vmap(func_vmaped_3, in_axes=(0)))

%timeit jitted_vmap(np.arange(int_outer)).block_until_ready()


KeyboardInterrupt



In [52]:
with jax.profiler.trace("jax-trace", create_perfetto_link=True):
    marginal_util, max_exp_value = marginal_util_and_exp_max_value_states_period(
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        savings_repeated,
        income_shock_draws,
        income_shock_weights,
        states_repeated,
        choices_repeated,
        policies_repeated,
        values_repeated,
    )

2023-02-17 21:12:23.695047: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-02-17 21:12:24.906046: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-17 21:12:24.906150: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-02-17 21:12:25.416486: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/internal/python_hooks.cc:398] Can't import tensorflow.python.profiler.trace
2023-02-17 21:12:25.785160: E external/org_tensorflow/tensorflow/compiler/xla/python/profiler/i

In [1]:
# use jax profiler
def profiled_marginal_util_and_exp_max_value_states_period(
    compute_next_period_wealth,
    compute_marginal_utility,
    compute_value,
    taste_shock_scale,
    exogenous_savings_grid,
    income_shock_draws,
    income_shock_weights,
    possible_child_states,
    choices_child_states,
    policies_child_states,
    values_child_states,
):
    return profiler.call(
        marginal_util_and_exp_max_value_states_period,
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        exogenous_savings_grid,
        income_shock_draws,
        income_shock_weights,
        possible_child_states,
        choices_child_states,
        policies_child_states,
        values_child_states,
    )

In [12]:
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    result = profiled_marginal_util_and_exp_max_value_states_period(
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        exogenous_savings_grid,
        income_shock_draws,
        income_shock_weights,
        possible_child_states,
        choices_child_states,
        policies_child_states,
        values_child_states,
    )
    result.block_until_ready()

ValueError: Invalid trace folder: /tmp/jax-trace/plugins/profile/2023_02_15_11_28_40

In [19]:
pip install tensorflow tensorboard-plugin-profile

Collecting tensorflow
  Downloading tensorflow-2.11.0-cp39-cp39-macosx_10_14_x86_64.whl (244.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.3/244.3 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tensorboard-plugin-profile
  Downloading tensorboard_plugin_profile-2.11.1-py3-none-any.whl (5.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hCollecting tensorflow-io-gcs-filesystem>=0.23.1
  Downloading tensorflow_io_gcs_filesystem-0.30.0-cp39-cp39-macosx_10_14_x86_64.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[?25hCollecting tensorboard<2.12,>=2.11
  Downloading tensorboard-2.11.2-py3-none-any.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m00:01

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.1/77.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting oauthlib>=3.0.0
  Downloading oauthlib-3.2.2-py3-none-any.whl (151 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m151.7/151.7 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorboard-plugin-wit, pyasn1, libclang, flatbuffers, wrapt, werkzeug, termcolor, tensorflow-io-gcs-filesystem, tensorflow-estimator, tensorboard-data-server, rsa, pyasn1-modules, protobuf, oauthlib, keras, h5py, gviz-api, grpcio, google-pasta, gast, cachetools, astunparse, absl-py, tensorboard-plugin-profile, requests-oauthlib, markdown, google-auth, google-auth-oauthlib, tensorboard, tensorflow
Successfully installed absl-py-1.4.0 astunparse-1.6.3 cachetools-5.3.0 flatbuffers-23.1.21 gast-0.4.0 google-auth-2.16.0 google-auth-oauthlib-0.4.6 google-pasta-0.2.0 grpcio-1.51.1 gviz-api-1.10.0 h5py-3.8.0 keras-2.11.0 libclang-15.

In [23]:
# from ChatGPT conversation
import tensorflow as tf
from jax import jit

# Trace the function to create a computation graph
traced_computation = jit(profiled_marginal_util_and_exp_max_value_states_period)
# Use TensorBoard to visualize the computation graph
with tf.summary.create_file_writer("logs").as_default():
    tf.summary.trace_on(graph=True, profiler=True)
    marginal_util, max_exp_value = marginal_util_and_exp_max_value_states_period(
        compute_next_period_wealth,
        compute_marginal_utility,
        compute_value,
        taste_shock_scale,
        savings_repeated,
        income_shock_draws,
        income_shock_weights,
        states_repeated,
        choices_repeated,
        policies_repeated,
        values_repeated,
    )
    with tf.summary.record_if(True):
        tf.summary.trace_export(
            name="my_computation_graph", step=0, profiler_outdir="logs"
        )

Instructions for updating:
use `tf.profiler.experimental.start` instead.
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Instructions for updating:
`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.
Instructions for updating:
`tf.python.eager.profiler` has deprecated, use `tf.profiler` instead.


In [30]:
result = jax.xla_computation(marginal_util_and_exp_max_value_states_period)(
    compute_next_period_wealth,
    compute_marginal_utility,
    compute_value,
    taste_shock_scale,
    savings_repeated,
    income_shock_draws,
    income_shock_weights,
    states_repeated,
    choices_repeated,
    policies_repeated,
    values_repeated,
)

TypeError: Cannot interpret '<CompiledFunction of functools.partial(<function budget_constraint at 0x13fdb4ee0>, params_dict={'beta': 0.95, 'delta': 0.0, 'theta': 1.0, 'constant': 0.75, 'exp': 0.04, 'exp_squared': -0.0004, 'sigma': 0.25, 'lambda': 2.2204e-16, 'interest_rate': 0.05, 'initial_wealth_low': 0.0, 'initial_wealth_high': 30.0, 'max_wealth': 75.0, 'consumption_floor': 0.0}, options={'n_periods': 25, 'min_age': 20, 'n_discrete_choices': 1, 'grid_points_wealth': 100, 'quadrature_points_stochastic': 10, 'n_simulations': 10, 'n_exog_processes': 1})>' as a data type