# Jax profiler test: marginal_util_and_exp_max_value_states_period

In [56]:
import io
import pickle
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
#from toy_models.consumption_retirement_model import (compute_expected_value, compute_next_period_marginal_utility,inverse_marginal_utility_crra, utility_func_crra)
from toy_models.consumption_retirement_model.utility_functions import utility_func_crra, marginal_utility_crra, inverse_marginal_utility_crra
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.exogenous_processes import get_transition_matrix_by_state
from dcegm.pre_processing import get_partial_functions, params_todict
from dcegm.marg_utilities_and_exp_value import marginal_util_and_exp_max_value_states_period
import jax
import jax.numpy as jnp
from jax import profiler
import os
from dcegm.solve import solve_dcegm

## Specify *utility functions*, *params* and *options*

In [57]:
utility_functions = {
    "utility": utility_func_crra,
    "inverse_marginal_utility": inverse_marginal_utility_crra,
    "marginal_utility": marginal_utility_crra,
}

In [58]:
params = """category,name,value,comment
beta,beta,0.95,discount factor
delta,delta,0,disutility of work
utility_function,theta,1,CRRA coefficient
wage,constant,0.75,age-independent labor income
wage,exp,0.04,return to experience
wage,exp_squared,-0.0004,return to experience squared
shocks,sigma,0.25,shock on labor income sigma parameter/standard deviation
shocks,lambda,2.220400e-16,taste shock (scale) parameter
assets,interest_rate,0.05,interest rate on capital
assets,initial_wealth_low,0,lowest level of initial wealth (relevant for simulation)
assets,initial_wealth_high,30,highest level of initial wealth (relevant for simulation)
assets,max_wealth,75,maximum level of wealth
assets,consumption_floor,0.0,consumption floor/retirement safety net (only relevant in the dc-egm retirement model)
"""

In [59]:
params = pd.read_csv(io.StringIO(params), index_col=["category", "name"])
params

Unnamed: 0_level_0,Unnamed: 1_level_0,value,comment
category,name,Unnamed: 2_level_1,Unnamed: 3_level_1
beta,beta,0.95,discount factor
delta,delta,0.0,disutility of work
utility_function,theta,1.0,CRRA coefficient
wage,constant,0.75,age-independent labor income
wage,exp,0.04,return to experience
wage,exp_squared,-0.0004,return to experience squared
shocks,sigma,0.25,shock on labor income sigma parameter/standard...
shocks,lambda,2.2204e-16,taste shock (scale) parameter
assets,interest_rate,0.05,interest rate on capital
assets,initial_wealth_low,0.0,lowest level of initial wealth (relevant for s...


In [60]:
params_dict = params_todict(params)

In [61]:
options = """n_periods: 25
min_age: 20
n_discrete_choices: 1
grid_points_wealth: 100
quadrature_points_stochastic: 10
n_simulations: 10
n_exog_processes: 1
"""

In [62]:
options = yaml.safe_load(options)
options

{'n_periods': 25,
 'min_age': 20,
 'n_discrete_choices': 1,
 'grid_points_wealth': 100,
 'quadrature_points_stochastic': 10,
 'n_simulations': 10,
 'n_exog_processes': 1}

## Specify inputs for function **marginal_util_and_exp_max_value_states_period**

In [63]:
compute_utility,compute_marginal_utility,compute_inverse_marginal_utility, compute_value,compute_next_period_wealth, transition_function = get_partial_functions(params_dict,options,utility_functions,budget_constraint,get_transition_matrix_by_state)

In [64]:
taste_shock_scale = 2.220400e-16
exogenous_savings_grid = pickle.load(open("exogenous_savings_grid.pkl", "rb"))
income_shock_draws = pickle.load(open("income_shock_draws.pkl", "rb"))
income_shock_weights = pickle.load(open("income_shock_weights.pkl", "rb"))
possible_child_states = pickle.load(open("possible_child_states.pkl", "rb"))
choices_child_states = pickle.load(open("choices_child_states.pkl", "rb"))
policies_child_states = pickle.load(open("policies_child_states.pkl", "rb"))
values_child_states = pickle.load(open("values_child_states.pkl", "rb"))

In [65]:
marginal_util, max_exp_value = marginal_util_and_exp_max_value_states_period(compute_next_period_wealth,compute_marginal_utility,compute_value,taste_shock_scale,exogenous_savings_grid,income_shock_draws,income_shock_weights,possible_child_states,choices_child_states,policies_child_states,values_child_states)

In [66]:
# use jax profiler
def profiled_marginal_util_and_exp_max_value_states_period(
    compute_next_period_wealth,
    compute_marginal_utility,
    compute_value,
    taste_shock_scale,
    exogenous_savings_grid,
    income_shock_draws,
    income_shock_weights,
    possible_child_states,
    choices_child_states,
    policies_child_states,
    values_child_states,
):
    return profiler.call(marginal_util_and_exp_max_value_states_period, compute_next_period_wealth,compute_marginal_utility,compute_value,taste_shock_scale,exogenous_savings_grid,income_shock_draws,income_shock_weights,possible_child_states,choices_child_states,policies_child_states,values_child_states)

In [67]:
os.environ['JAX_PROFILE'] = '1'

In [68]:
result = profiled_marginal_util_and_exp_max_value_states_period(
    compute_next_period_wealth,
    compute_marginal_utility,
    compute_value,
    taste_shock_scale,
    exogenous_savings_grid,
    income_shock_draws,
    income_shock_weights,
    possible_child_states,
    choices_child_states,
    policies_child_states,
    values_child_states,
)

AttributeError: module 'jax.profiler' has no attribute 'call'

In [69]:
profiler.print_summary()

AttributeError: module 'jax.profiler' has no attribute 'print_summary'