In [1]:
%load_ext autoreload

%autoreload 2

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax import profiler
import os
from dcegm.backward_induction import solve_dcegm, get_solve_function
from functools import partial
import pandas as pd
import yaml

# from dcegm.pre_processing import convert_params_to_dict, get_partial_functions
import numpy as np
import jax

# Enable 64-bit (float64/int64) arrays; JAX defaults to 32-bit. This must be
# set before any JAX arrays are created so they pick up the 64-bit dtypes.
jax.config.update("jax_enable_x64", True)


# Root of the test resources, relative to the notebook's working directory.
TEST_RESOURCES_DIR = "../resources/"

## Import toy model


In [2]:
from toy_models.consumption_retirement_model.budget_functions import budget_constraint
from toy_models.consumption_retirement_model.utility_functions import (
    create_utility_function_dict,
    create_final_period_utility_function_dict,
)

from toy_models.consumption_retirement_model.state_space_objects import (
    create_state_space_function_dict,
)

In [3]:
model = "retirement_with_shocks"
model_dir = TEST_RESOURCES_DIR + f"replication_tests/{model}/"

# Parameters are stored with a (category, name) index; flatten them into a
# plain {name: value} dict, dropping the category level.
params_table = pd.read_csv(
    model_dir + "params.csv",
    index_col=["category", "name"],
)
params = (
    params_table.reset_index()[["name", "value"]].set_index("name")["value"].to_dict()
)

# Use a context manager so the YAML file handle is closed after parsing
# (a bare `open(...)` passed to safe_load is never closed).
with open(model_dir + "options.yaml", "rb") as options_file:
    _raw_options = yaml.safe_load(options_file)

options = {
    # NOTE: this aliases `_raw_options`, so the "n_choices" key added below is
    # visible through both names.
    "model_params": _raw_options,
    "state_space": {
        "n_periods": 25,
        "choices": list(range(_raw_options["n_discrete_choices"])),
    },
}
options["model_params"]["n_choices"] = _raw_options["n_discrete_choices"]

# Exogenous savings grid: n_grid_points evenly spaced values on [0, max_wealth].
exog_savings_grid = jnp.linspace(
    0,
    options["model_params"]["max_wealth"],
    options["model_params"]["n_grid_points"],
)

## Time the overall solve

In [4]:
# Display the flattened {name: value} parameter dict that is passed to the solver below.
params

{'beta': 0.9523809523809524,
 'delta': 0.35,
 'rho': 1.95,
 'constant': 0.75,
 'exp': 0.04,
 'exp_squared': -0.0002,
 'sigma': 0.35,
 'lambda': 0.2,
 'interest_rate': 0.05,
 'initial_wealth_low': 0.0,
 'initial_wealth_high': 30.0,
 'max_wealth': 50.0,
 'consumption_floor': 0.001}

In [5]:
# Build the model-specific callables first (in the same order the original
# call evaluated them), then assemble the solver from them.
state_space_functions = create_state_space_function_dict()
utility_functions = create_utility_function_dict()
utility_functions_final_period = create_final_period_utility_function_dict()

backward_jit = get_solve_function(
    options=options,
    exog_savings_grid=exog_savings_grid,
    state_space_functions=state_space_functions,
    utility_functions=utility_functions,
    budget_constraint=budget_constraint,
    utility_functions_final_period=utility_functions_final_period,
)

Update function for state space not given. Assume states only change with an increase of the period and lagged choice.
The batch size of the backwards induction is  3


In [6]:
# https://github.com/google/jax/discussions/11169
# Warm-up call: presumably `backward_jit` is jit-compiled (per its name), so the
# first call pays tracing/compilation; running it once keeps that out of %timeit.
# block_until_ready is needed because JAX dispatches work asynchronously.
jax.block_until_ready(backward_jit(params))
%timeit jax.block_until_ready(backward_jit(params))

35.6 ms ± 201 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# https://github.com/google/jax/discussions/11169

jax.block_until_ready(backward_jit(params))
%timeit jax.block_until_ready(backwards_jit(params))

## Time the upper envelope

In [None]:
# Reference EGM output read from CSV (file names suggest period 10 — TODO confirm).
policy_egm = jnp.array(
    np.genfromtxt(TEST_RESOURCES_DIR + "period_tests/pol10.csv", delimiter=",")
)
value_egm = jnp.array(
    np.genfromtxt(TEST_RESOURCES_DIR + "period_tests/val10.csv", delimiter=",")
)
# `params` is already a flat {name: value} dict (it is built with `.to_dict()`
# when loaded), and `convert_params_to_dict` is not imported in this notebook,
# so calling it raised a NameError. Take a shallow copy instead.
params_dict = dict(params)

In [None]:
# `fast_upper_envelope` was never imported in this notebook, so this cell raised
# a NameError. NOTE(review): module path assumed from the dcegm package layout
# (the notebook already imports from `dcegm`) — confirm.
from dcegm.fast_upper_envelope import fast_upper_envelope

# Fix `num_iter` up front so the jitted callable only takes the array inputs.
test_upp_env = jit(partial(fast_upper_envelope, num_iter=int(value_egm.shape[1])))

In [None]:
jax.block_until_ready(
    test_upp_env(
        endog_grid=policy_egm[0, 1:],
        value=value_egm[1, 1:],
        policy=policy_egm[1, 1:],
        expected_value_zero_savings=value_egm[1, 0],
    )
)

%timeit jax.block_until_ready(test_upp_env(endog_grid=policy_egm[0, 1:], value=value_egm[1, 1:], policy=policy_egm[1, 1:], expected_value_zero_savings=value_egm[1, 0]))

# Profiling

## Perfetto UI

In [None]:
trace_dir = "/tmp/jax-trace"

# create_perfetto_link=True prints a Perfetto link and blocks until it is opened.
with jax.profiler.trace(trace_dir, create_perfetto_link=True):
    # Everything inside this context is captured, including the input slicing.
    upp_env_out = test_upp_env(
        endog_grid=policy_egm[0, 1:],
        value=value_egm[1, 1:],
        policy=policy_egm[1, 1:],
        expected_value_zero_savings=value_egm[1, 0],
    )
    # Wait for the asynchronous computation to finish while still tracing.
    jax.block_until_ready(upp_env_out)

## Timing `jax.lax.scan` vs. an unrolled Python loop

In [None]:
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

In [None]:
def vectr_function(carry, array_in):
    """Scan-style step: shift every leading-axis slice of ``array_in`` by ``carry``.

    Returns ``(total, shifted)``, where ``shifted`` is the vmapped result and
    ``total`` is its sum, so it can be threaded on as the next carry.
    """

    def add_carry(element):
        return element + carry

    shifted = jax.vmap(add_carry)(array_in)
    total = shifted.sum()
    return total, shifted

In [None]:
def loop_for(segments):
    """Apply ``vectr_function`` to each segment in turn, threading the carry.

    This is the Python-loop counterpart of ``lax.scan``; under ``jax.jit``
    the loop is unrolled. Returns the list of per-segment outputs (the final
    carry is discarded).
    """
    running = 0.0
    outputs = []
    for chunk in segments:
        running, shifted = vectr_function(running, chunk)
        outputs.append(shifted)
    return outputs

In [None]:
# Ragged workload as (segment length, number of segments), in order.
_segment_layout = [(1000, 200), (800, 30), (500, 50)]
segments = [
    jnp.arange(length)
    for length, repeats in _segment_layout
    for _ in range(repeats)
]

In [None]:
jax.block_until_ready(jax.jit(loop_for)(segments))
%timeit jax.block_until_ready(jax.jit(loop_for)(segments))

In [None]:
# Concatenate all ragged segments into one 1-D array, then re-chunk it into
# equal rows of 100 so it can be stacked along the leading axis for `lax.scan`.
flat_segments, _unravel_segments = ravel_pytree(segments)
segments_even = flat_segments.reshape(-1, 100)

In [None]:
jax.block_until_ready(jax.lax.scan(vectr_function, 0.0, xs=segments_even))
%timeit jax.block_until_ready(jax.lax.scan(vectr_function, 0.0, xs=segments_even))

In [None]:
def body_1(carry, _it_step):
    new_carry = carry + 1.5
    return new_carry, new_carry


def body_2(carry, _it_step):
    new_carry = carry + 1
    return new_carry, new_carry


test_body_1 = jit(lambda start: jax.lax.scan(body_1, start, xs=None, length=5000))
test_body_2 = jit(lambda start: jax.lax.scan(body_2, start, xs=None, length=5000))

In [None]:
test_body_1(1.0)
%timeit test_body_1(1.0)

In [None]:
test_body_2(1)
%timeit test_body_2(1)