In [1]:
# Imports used throughout the notebook (stdlib, JAX/NumPy/pandas, YAML for the specs file).
import os
import time
import jax

import numpy as np
import pandas as pd
import sys
import yaml
# Enable 64-bit (double-precision) types in JAX. Kept in the very first cell,
# right after importing JAX, so it is in effect before any model code below
# creates JAX arrays.
jax.config.update("jax_enable_x64", True)
import pickle

In [None]:
# Mount drive. Make sure everything(including submodules) are there
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd drive/MyDrive/colab_notebooks/policy_unc/analysis

In [6]:
# Project directory containing `src/` and `submodules/`, two levels above the
# notebook's working directory. When running the notebook from that directory
# itself (e.g. locally), use `os.getcwd() + "/"` instead.
analysis_path = os.getcwd() + "/../../"

# Make the project code and the dcegm submodule importable. Each entry is
# inserted at the front of sys.path (src/ ends up first, as before). The
# membership check keeps re-runs of this cell from piling up duplicates.
for rel_src_dir in ("submodules/dcegm/src/", "src/"):
    src_dir = analysis_path + rel_src_dir
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)

from set_paths import create_path_dict
path_dict = create_path_dict(analysis_path)

In [10]:
# Load the estimated parameters from the estimation results folder. The file
# is closed deterministically via `with`. Only unpickle files you produced
# yourself: pickle.load can execute arbitrary code from a crafted file.
with open(path_dict["est_results"] + "est_params.pkl", "rb") as params_file:
    params = pickle.load(params_file)

In [20]:
from model_code.policy_states_belief import (
    update_specs_exp_ret_age_trans_mat,
    expected_SRA_probs_estimation,
)
from model_code.budget_equation import create_savings_grid
from model_code.specify_model import specify_model
from dcegm.solve import get_solve_func_for_model

# Benchmark how model specification, solve-function creation and repeated
# solves scale with the SRA grid step. Each step is written into the specs
# file, which `specify_model` then re-reads (`load_model=False`).
savings_grid = create_savings_grid()
grid_size_steps = [3.5, 1, 0.5, 0.25, 0.1, 0.05, 0.025, 0.01]
n_exps = len(grid_size_steps)

# Labels for the repeated solves; the first solve also includes JIT compilation.
solve_labels = ["First", "Second", "Third", "Fourth"]
# Columns: [specify_model, get_solve_func_for_model, one column per solve].
times = np.zeros((n_exps, 2 + len(solve_labels)))
times_path = path_dict["intermediate_data"] + "times.csv"
specs_path = path_dict["specs"]


def _record_time(id_exp, col, seconds):
    """Store one timing in `times` and persist the whole table to times.csv.

    Saving after every stage keeps partial results on disk if a later,
    larger experiment fails.
    """
    times[id_exp, col] = seconds
    np.savetxt(times_path, times, delimiter=",")


# Keep the original specs text so the file can be restored afterwards, even if
# an experiment fails. Otherwise the project's specs would be left with
# whatever grid size was tried last.
with open(specs_path) as specs_file:
    original_specs_text = specs_file.read()

try:
    for id_exp, grid_size in enumerate(grid_size_steps):
        specs = yaml.safe_load(original_specs_text)
        specs["SRA_grid_size"] = grid_size
        with open(specs_path, "w") as specs_file:
            yaml.safe_dump(specs, specs_file)

        # Specify the model. Each stage is timed on its own with a monotonic
        # clock, so the saving/printing of the previous stage is not included.
        start = time.perf_counter()
        model, options, params = specify_model(
            path_dict=path_dict,
            update_spec_for_policy_state=update_specs_exp_ret_age_trans_mat,
            policy_state_trans_func=expected_SRA_probs_estimation,
            params=params,
            load_model=False,
        )
        n_policy_states = options["model_params"]["n_policy_states"]
        spec_time = time.perf_counter() - start
        _record_time(id_exp, 0, spec_time)
        print(f"Specified model for {n_policy_states} took {spec_time}")

        # Get solve function
        start = time.perf_counter()
        solve_func = get_solve_func_for_model(model, savings_grid, options)
        get_solve_time = time.perf_counter() - start
        _record_time(id_exp, 1, get_solve_time)
        print(f"Got solve function for {n_policy_states} took {get_solve_time}")

        # Solve repeatedly. block_until_ready makes the timing wait for JAX's
        # asynchronous computation instead of measuring only the dispatch.
        for i_solve, label in enumerate(solve_labels):
            start = time.perf_counter()
            jax.block_until_ready(solve_func(params))
            solve_time = time.perf_counter() - start
            _record_time(id_exp, 2 + i_solve, solve_time)
            print(f"{label} solve for {n_policy_states} took {solve_time}")

        print(f"Total time for {n_policy_states} is {np.sum(times[id_exp, :])}")
finally:
    # Restore the specs file byte-for-byte to its pre-benchmark contents.
    with open(specs_path, "w") as specs_file:
        specs_file.write(original_specs_text)

3
8
15
29
71
141
281
701
