<a href="https://colab.research.google.com/github/MatiasCovarrubias/jaxecon/blob/main/Rbc_CES.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DEQN Solver in Jax: Prelims

This notebook uses the DEQN solver to train a neural network that outputs the optimal policy of an RBC model with a CES production function.



In [None]:
# SETUP CELL: optionally show the GPU, choose the floating-point precision, import all
# dependencies, fetch the DEQN library, and mount Google Drive for saving results.
# Statement order matters here (x64 flag before any arrays, pip install before import,
# git clone before the DEQN imports), so keep this cell as the first code cell.

# BACKEND RELATED: the default backend is CPU (change in Edit -> Notebook settings)
GPU = True # only controls whether `nvidia-smi` is run below to display the attached GPU
if GPU:
  !nvidia-smi

# precision
from jax import numpy as jnp, config as jax_config,  lax, random
double_precision = True
if double_precision:
  # Enable 64-bit JAX types early, before any JAX arrays are created (JAX defaults to 32-bit).
  jax_config.update("jax_enable_x64", True)
  precision = jnp.float64
else:
  precision = jnp.float32
# `precision` is the dtype passed to the economic-model and neural-net constructors later on.

# Imports
import matplotlib.pyplot as plt, flax.linen as nn, pandas as pd, jax, flax, optax, os, json
from flax.training.train_state import TrainState  # Useful dataclass to keep train state
from flax.training import checkpoints
# from jax.scipy.optimize import minimize
# NOTE(review): unpinned install; consider pinning a jaxopt version for reproducibility.
%pip install jaxopt
import jaxopt
from time import time
from typing import Sequence
# Raise an error as soon as a NaN is produced inside a JAX computation. This is a debugging
# aid that adds overhead; consider disabling it for long training runs.
jax_config.update("jax_debug_nans", True)

# Fetch the DEQN library (models, training algorithm, analysis tools) and make it importable.
! git clone https://github.com/MatiasCovarrubias/jaxecon
import sys
sys.path.insert(0,'/content/jaxecon')
from DEQN.neural_nets.neural_nets import NeuralNet
from DEQN.econ_models.rbc_ces import RbcCES_SteadyState, RbcCES
from DEQN.algorithm.simulation import create_episode_simul_fn
from DEQN.algorithm.loss import create_batch_loss_fn
from DEQN.algorithm.epoch_train import create_epoch_train_fn
from DEQN.algorithm.eval import create_eval_fn
from DEQN.analysis.simul_analysis import create_episode_simul_verbose_fn, create_descstats_fn
from DEQN.analysis.stochastic_ss import create_stochss_fn
# Mount Google Drive to store results (a pop up will appear, follow instructions)
from google.colab import drive
drive.mount('/content/drive')
# Root directory on Drive; used as the experiments' save directory in the config below.
working_dir = "/content/drive/MyDrive/Jaxecon/RbcCES/Experiments/Params_compstat/"



Mon May 20 19:39:24 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0              43W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Create Econ model


In [None]:
# Set parameters

beta=0.96
alpha=0.31
delta=0.057
sigma_y=0.8
eps_c=0.5
eps_l=0.5
rho=0.69
phi=5
shock_sd=0.0153

# Find steady state and theta.
# The solver works on an 8-vector: the first 7 entries are steady-state policy levels
# (their logs become `policies_ss`) and the last entry is theta.
econ_model_ss = RbcCES_SteadyState(precision=precision, beta=beta, alpha=alpha, delta=delta, sigma_y=sigma_y, eps_c=eps_c, eps_l=eps_l)
initial_policy = jnp.array([1.0000126, 0.41989392, 2.2284806, 0.12702352, 0.99998146, 0.99996126, 1.1270372, 12.25385], dtype=precision)
ss_tol = 1e-09  # BFGS stopping tolerance on the gradient norm

@jax.jit
def optimize_policy(initial_policy):
    """Minimise the steady-state loss with BFGS, starting from `initial_policy`.

    Returns:
        (ss_policy, state): the minimiser and the final jaxopt BFGS state.
    """
    solver = jaxopt.BFGS(econ_model_ss.loss, tol=ss_tol, verbose=False)
    ss_policy, state = solver.run(initial_policy)
    return ss_policy, state


ss_policy, state = optimize_policy(initial_policy)
# `state.value` is the objective at the solution; `state.error` is the gradient norm
# that BFGS compares against `tol`.
print("loss", state.value)
print("grad norm", state.error)
print("steady state", ss_policy)
# BFGS does not raise when it stops early (e.g. after hitting maxiter), so warn instead of
# silently building the model on an unconverged steady state.
if not float(state.error) <= ss_tol:
    import warnings
    warnings.warn(f"Steady-state solver did not converge: grad norm {float(state.error):.3e} > tol {ss_tol:.0e}")

policies_ss = jnp.log(ss_policy[:7])
theta = ss_policy[-1]
print("theta", theta)
# create econ_model
econ_model = RbcCES(precision=precision, policies_ss=policies_ss, theta=theta, beta=beta, alpha=alpha, delta=delta, sigma_y=sigma_y, eps_c=eps_c, eps_l=eps_l, rho=rho, phi=phi, shock_sd=shock_sd)

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 4.0726188426530533e-10 Stepsize:1.0  Decrease Error:2.0012576681408925e-10  Curvature Error:4.0726188426530533e-10 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.12981847422057885  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.19472771133086827  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.2920915669963024  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 3.1800239501256695e-10 Stepsize:0.43813735049445357  Decrease Error:6.044458097745855e-11  Curvature Error:3.1800239501256695e-10 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0

# Configure experiment

In [None]:
'''Config dictionary'''

# CREATE CONFIG DICT
# Each key appears exactly once: a repeated key in a dict literal silently keeps only the last value.
config = {
    # general
    "seed": 5,                     # seed of the PRNG key used to initialise and train the network
    "exper_name": "May20_correctedv3",
    "save_dir": working_dir,       # checkpoints and results are written to save_dir + exper_name
    "restore": False,              # True if start from restored checkpoint
    "restore_exper_name": "",      # experiment to restore from when "restore" is True

    # neural net
    "layers": [64,64],             # hidden-layer widths (the output layer size comes from the model)

    # learning rate schedule
    "lr_sch_values": [0.0005,0.0005],  # one constant segment per value; the last value starts a cosine decay to lr_end_value
    "lr_sch_transitions": [2000],      # steps (NN updates) at which to switch to the next segment
    "lr_end_value": 1e-7,              # learning rate reached at the end of the cosine decay

    # simulation
    "periods_per_epis": 32,        # periods per simulated episode
    "simul_vol_scale": 1.5,        # scale of volatility while simul
    "init_range": 0,

    # loss calculation
    "mc_draws": 16,                # monte-carlo draws

    # training
    "epis_per_step": 512,          # episodes simulated per step
    "steps_per_epoch": 100,        # steps (NN updates) per epoch
    "n_epochs": 100,               # number of epochs
    "batch_size": 16,              # size of batch of obs to calculate grads
    "checkpoint_frequency": 1000,  # save a checkpoint every this many steps

    "config_eval": {
      "periods_per_epis": 64,      # periods per episode
      "mc_draws": 128,         # number of mc draws
      "simul_vol_scale": 1,        # scale of volatility while simul
      "eval_n_epis": 128,           # episodes to sample for eval
      "init_range": 0
    }
}

#create auxiliary config variables for readability
config["periods_per_step"] = config["periods_per_epis"]*config["epis_per_step"]
config["n_batches"] = config["periods_per_step"]//config["batch_size"]

# PRINT AND PLOT KEY CONFIGS
# Tabulate the network the way run_experiment builds it (hidden layers + one output per
# model action), on an initial observation of the calibrated econ_model.
print("Number of parameters:")
print(NeuralNet(config["layers"] + [econ_model.n_actions], precision).tabulate(
    random.PRNGKey(0),
    econ_model.initial_obs(random.PRNGKey(0))
    ))

print("TOTAL Number of steps (NN updates):", config["steps_per_epoch"]*config["n_epochs"], "\n")


Number of parameters:

[3m                             NeuralNet Summary                              [0m
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule   [0m[1m [0m┃[1m [0m[1minputs     [0m[1m [0m┃[1m [0m[1moutputs    [0m[1m [0m┃[1m [0m[1mparams                [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ NeuralNet │ [2mfloat64[0m[2]  │ [2mfloat64[0m[7]  │                        │
├─────────┼───────────┼─────────────┼─────────────┼────────────────────────┤
│ Dense_0 │ Dense     │ [2mfloat64[0m[2]  │ [2mfloat64[0m[64] │ bias: [2mfloat64[0m[64]      │
│         │           │             │             │ kernel: [2mfloat64[0m[2,64]  │
│         │           │             │             │                        │
│         │           │             │             │ [1m192 [0m[1;2m(1.5 KB)[0m           │
├─────────┼──

# Create experiment
Next, we wrap the entire experiment (training, evaluation, checkpointing, and plotting) in a function that we call below.

In [None]:
def _make_lr_schedule(config):
  """Build the learning-rate schedule described by `config`.

  One constant segment per entry of config["lr_sch_values"][:-1]; the last value starts a
  cosine decay to config["lr_end_value"] that ends at the final training step. Segments
  switch at the step counts listed in config["lr_sch_transitions"].
  """
  total_steps = config["n_epochs"]*config["steps_per_epoch"]
  constant_segments = [optax.constant_schedule(value) for value in config["lr_sch_values"][:-1]]
  final_segment = optax.warmup_cosine_decay_schedule(
    init_value=config["lr_sch_values"][-1],
    peak_value=config["lr_sch_values"][-1],
    warmup_steps=0,
    decay_steps=total_steps-config["lr_sch_transitions"][-1],
    end_value=config["lr_end_value"],
  )
  return optax.join_schedules(
    schedules=constant_segments + [final_segment],
    boundaries=config["lr_sch_transitions"],  # step counts at which to switch segment
  )


def _best_or_none(values, pick):
  """Return pick(values), or None when nothing was recorded (e.g. no checkpoint reached)."""
  return pick(values) if values else None


def _save_curve(steps, values, ylabel, path):
  """Plot `values` against training `steps` and save the figure to `path`."""
  fig, ax = plt.subplots()
  ax.plot(steps, values)
  ax.set_xlabel('Steps (NN updates)')
  ax.set_ylabel(ylabel)
  fig.savefig(path)
  plt.close(fig)


def run_experiment(econ_model, config):
  """Train the DEQN policy network for `econ_model` as specified by `config`.

  Compiles and times the train/eval functions, then runs config["n_epochs"] epochs
  (evaluating before each one). Every config["checkpoint_frequency"] steps it saves a
  checkpoint and records evaluation metrics of the saved parameters. Finally it writes
  results.json and learning-curve plots to config["save_dir"] + config["exper_name"].

  Args:
    econ_model: economic model exposing `n_actions` and `initial_obs(rng)`.
    config: experiment configuration dictionary.

  Returns:
    The final flax TrainState.

  Raises:
    FileNotFoundError: if config["restore"] is True but no checkpoint exists in
      config["save_dir"] + config["restore_exper_name"].
  """
  exper_dir = config['save_dir']+config['exper_name']

  # CREATE NN AND RNGS
  policy_net = NeuralNet(features = config["layers"] + [econ_model.n_actions], precision = precision)
  # Split into 5 keys (the first is unused) so the other streams match runs with the same seed.
  _, rng_pol, rng_econ_model, rng_epoch, rng_eval = random.split(random.PRNGKey(config["seed"]), num=5)

  lr_schedule = _make_lr_schedule(config)

  # INITIALIZE OR RESTORE FULL NN TRAIN STATE
  params = policy_net.init(rng_pol, jnp.zeros_like(econ_model.initial_obs(rng_econ_model)))
  train_state = TrainState.create(apply_fn=policy_net.apply, params=params, tx=optax.adam(lr_schedule))
  if config["restore"]:
    restore_dir = config['save_dir']+config['restore_exper_name']
    # restore_checkpoint silently returns `target` when nothing is found, so check first.
    if checkpoints.latest_checkpoint(restore_dir) is None:
      raise FileNotFoundError(f"No checkpoint found in {restore_dir}")
    # Restoring into the fresh state brings back params, optimizer state and step with the
    # pytree structure optax expects (the optimizer's internal count drives the lr schedule).
    train_state = checkpoints.restore_checkpoint(ckpt_dir=restore_dir, target=train_state)

  # GET TRAIN AND EVAL FUNCTIONS
  simul_fn = jax.jit(create_episode_simul_fn(econ_model,config))
  loss_fn = jax.jit(create_batch_loss_fn(econ_model,config))
  train_epoch_fn = jax.jit(create_epoch_train_fn(econ_model, config, simul_fn, loss_fn))
  eval_fn = jax.jit(create_eval_fn(config, simul_fn, loss_fn))

  # COMPILE AND TIME. JAX dispatches work asynchronously, so block on the outputs;
  # otherwise the timers measure only dispatch, not the actual computation.
  time_start = time()
  jax.block_until_ready(train_epoch_fn(train_state, rng_epoch))  # compiles
  jax.block_until_ready(eval_fn(train_state, rng_epoch))  # compiles
  time_compilation = time() - time_start
  print("Time Elapsed for Compilation:", time_compilation, "seconds")

  time_start = time()
  jax.block_until_ready(train_epoch_fn(train_state, rng_epoch))  # run one epoch
  time_epoch = time() - time_start
  print("Time Elapsed for epoch:", time_epoch, "seconds")

  time_start = time()
  jax.block_until_ready(eval_fn(train_state, rng_epoch))  # run one eval
  time_eval = time() - time_start
  print("Time Elapsed for eval:", time_eval, "seconds")

  time_experiment = (time_epoch + time_eval)*config["n_epochs"]/60
  print("Estimated time for full experiment", time_experiment, "minutes")

  # NOTE: this counts simulated periods (observations) processed per second, not NN updates.
  steps_per_second = config["steps_per_epoch"]*config["periods_per_step"]/time_epoch
  print("Steps per second:", steps_per_second, "st/s")

  # METRICS RECORDED AT EACH CHECKPOINT, together with the step they belong to
  eval_steps, mean_losses, max_losses, mean_accuracy, min_accuracy = [], [], [], [], []

  # RUN ALL THE EPOCHS
  time_start = time()
  for _ in range(config["n_epochs"]):

    # eval
    eval_metrics = eval_fn(train_state, rng_eval)
    print('EVALUATION:\n',
      'Iteration:', train_state.step,
      "Mean Loss:", eval_metrics[0],
      ", Max Loss:", eval_metrics[1],
      ", Mean Acc:", eval_metrics[2],
      ", Min Acc:", eval_metrics[3], "\n"
      ", Mean Accs Foc", eval_metrics[4], "\n"
      ", Min Accs Foc:", eval_metrics[5],
      "\n")

    # run epoch
    train_state, rng_epoch, epoch_metrics = train_epoch_fn(train_state, rng_epoch)
    print('TRAINING:\n',
          'Iteration:', train_state.step,
          ", Mean Loss:", jnp.mean(epoch_metrics[0]),
          ", Max Loss:", jnp.mean(epoch_metrics[1]),
          ", Mean Acc:", jnp.mean(epoch_metrics[2]),
          ", Min Acc:", jnp.min(epoch_metrics[3]),
          ", Learning rate:", lr_schedule(train_state.step),
          "\n"
          )

    # checkpoint, and record metrics of exactly the parameters being saved
    step = int(train_state.step)
    if step >= config["checkpoint_frequency"] and step % config["checkpoint_frequency"] == 0:
      checkpoints.save_checkpoint(ckpt_dir=exper_dir, target=train_state, step=step)
      ckpt_metrics = eval_fn(train_state, rng_eval)
      eval_steps.append(step)
      mean_losses.append(float(ckpt_metrics[0]))
      max_losses.append(float(ckpt_metrics[1]))
      mean_accuracy.append(float(ckpt_metrics[2]))
      min_accuracy.append(float(ckpt_metrics[3]))

  # PRINT RESULTS (None when no checkpoint was reached and nothing was recorded)
  if not eval_steps:
    print("WARNING: no checkpoint step was reached, so no evaluation metrics were recorded.")
  min_mean_loss = _best_or_none(mean_losses, min)
  min_max_loss = _best_or_none(max_losses, min)
  max_mean_acc = _best_or_none(mean_accuracy, max)
  max_min_acc = _best_or_none(min_accuracy, max)
  print("Minimum mean loss attained in evaluation:", min_mean_loss)
  print("Minimum max loss attained in evaluation:", min_max_loss)
  print("Maximum mean accuracy attained in evaluation:", max_mean_acc)
  print("Maximum min accuracy attained in evaluation:", max_min_acc)
  time_fullexp = (time() - time_start)/60
  print("Time Elapsed for Full Experiment:", time_fullexp, "minutes")

  # STORE RESULTS
  results = {
    "exper_name": config["exper_name"],
    "min_mean_loss": min_mean_loss,
    "min_max_loss": min_max_loss,
    "max_mean_acc": max_mean_acc,
    "max_min_acc": max_min_acc,
    "time_full_exp_minutes": time_fullexp,
    "time_epoque_seconds": time_epoch,
    "time_compilation_seconds": time_compilation,
    "steps_per_second": steps_per_second,
    "config": config,
    "mean_losses_list": mean_losses,
    "max_losses_list": max_losses,
    "mean_acc_list": mean_accuracy,
    "min_acc_list": min_accuracy,
    "eval_steps_list": eval_steps,
  }

  # store to json (makedirs also creates missing parent directories)
  os.makedirs(exper_dir, exist_ok=True)
  with open(exper_dir+"/results.json", "w") as write_file:
    json.dump(results, write_file)

  # PLOT LEARNING CURVES against the steps at which the metrics were recorded.
  # Accuracies are plotted as recorded (they are fractions, not percentages).
  _save_curve(eval_steps, mean_losses, 'Mean Losses', exper_dir+'/mean_losses.jpg')
  _save_curve(eval_steps, max_losses, 'Max Losses', exper_dir+'/max_losses.jpg')
  _save_curve(eval_steps, mean_accuracy, 'Mean Accuracy', exper_dir+'/mean_accuracy.jpg')
  _save_curve(eval_steps, min_accuracy, 'Minimum Accuracy', exper_dir+'/min_accuracy.jpg')
  _save_curve(eval_steps, [float(lr_schedule(s)) for s in eval_steps], 'Learning Rate', exper_dir+'/learning_rate.jpg')

  return train_state

# Run experiment
*Finally*, we run the experiment and get the trained parameters, plus summary statistics of the training run.

In [None]:
# Train the policy network; the returned TrainState feeds the analysis cells below.
trained_train_state = run_experiment(econ_model=econ_model, config=config)


Time Elapsed for Compilation: 6.052121639251709 seconds
Time Elapsed for epoch: 0.5464625358581543 seconds
Time Elapsed for eval: 0.0049397945404052734 seconds
Estimated time for full experiment 0.9190038839975992 minutes
Steps per second: 2998192.725924181 st/s
EVALUATION:
 Iteration: 0 Mean Loss: 0.11436209094180974 , Max Loss: 0.7948116957212975 , Mean Acc: 0.7520290058915519 , Min Acc: 0.10847787704325729 
, Mean Accs Foc [0.33811368 0.50984866 0.96659638 0.89222323 0.84613051 0.99261497
 0.71867562] 
, Min Accs Foc: [0.33636888 0.10847788 0.92761977 0.8862874  0.80969721 0.89994026
 0.60340746] 

TRAINING:
 Iteration: 100 , Mean Loss: 0.03273297418924186 , Max Loss: 0.18583739514082426 , Mean Acc: 0.8918663160931923 , Min Acc: 0.05554514026485724 , Learning rate: 0.0005 

EVALUATION:
 Iteration: 100 Mean Loss: 0.002821582791908187 , Max Loss: 0.01894669886718434 , Mean Acc: 0.9591209294616589 , Min Acc: 0.8623529917971904 
, Mean Accs Foc [0.90612665 0.97176735 0.9706826  0.989059

# Analysis

In [None]:
# Settings for the long simulation used to compute descriptive statistics.
config_analysis = dict(
    init_range=0,
    periods_per_epis=2_000_000,
    simul_vol_scale=1,
)

# Settings for the stochastic-steady-state computation.
config_stochss = dict(
    n_draws=2000,
    time_to_converge=500,
    seed=0,
)

rng_analysis = random.PRNGKey(4)

# Build the analysis functions (the simulation and stochastic-SS ones are jitted).
simul_fn_verbose = jax.jit(create_episode_simul_verbose_fn(econ_model, config_analysis))
descstats_fn = create_descstats_fn(econ_model, config_analysis)
stochss_fn = jax.jit(create_stochss_fn(econ_model, config_stochss))

# Simulate a long path with the trained policy and summarise the logged policies.
simul_obs, simul_policies = simul_fn_verbose(trained_train_state, rng_analysis)
simul_policies_logdev = jnp.log(simul_policies)
descstats_df, autocorr_df = descstats_fn(simul_policies_logdev)
print("\n descstat: \n", descstats_df)

# Stochastic steady state, computed from the simulated observations.
stochss = stochss_fn(simul_obs, trained_train_state)
print("\n stoch_ss: \n", stochss)





 descstat:
                C        L        K        I        Y
Mean      0.00006  0.00002  0.00053 -0.00037  0.00015
Sd        0.01489  0.00299  0.01874  0.06545  0.02040
Skewness  0.00873 -0.02430 -0.00604 -0.15695 -0.00758
Kurtosis  0.04262  0.20266  0.04703 -0.01917 -0.01837
Q1       -0.03481 -0.00716 -0.04390 -0.15981 -0.04744
Q25      -0.00995 -0.00195 -0.01196 -0.04430 -0.01370
Q50       0.00000  0.00005  0.00030  0.00143  0.00019
Q75       0.01006  0.00200  0.01316  0.04508  0.01399
Q99       0.03489  0.00715  0.04425  0.14291  0.04734

 stoch_ss:
[-0.00013008  0.00015333  0.00023468  0.00173204  0.00023521  0.00071905
  0.0001409 ]


In [None]:
# LaTeX special characters and their escaped forms. Applied one character at a time,
# so the replacement for a backslash is never escaped a second time.
_LATEX_ESCAPES = {
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
}


def _latex_escape(text):
    """Return str(text) with LaTeX special characters escaped."""
    return "".join(_LATEX_ESCAPES.get(char, char) for char in str(text))


def dataframe_to_latex(df, caption="Descriptive Statistics", label="table:desc_stats", float_format=".5f"):
    """Render a numeric DataFrame as a LaTeX ``table`` environment.

    The first column holds the index labels, followed by one centred column per
    DataFrame column. Index and column labels are converted to strings (so non-string
    labels such as integers work) and LaTeX special characters in them are escaped.
    Cell values are formatted with ``float_format``.

    Args:
        df: DataFrame whose values are all numeric.
        caption: text placed in \\caption{...} (inserted verbatim, not escaped).
        label: key placed in \\label{...}.
        float_format: format spec applied to every cell value (default 5 decimals).

    Returns:
        The LaTeX source as a string, without a trailing newline.
    """
    lines = [
        "\\begin{table}[h!]",
        "\\centering",
        "\\begin{tabular}{l" + "c" * len(df.columns) + "}",
        " & ".join([""] + [_latex_escape(col) for col in df.columns]) + " \\\\",
        "\\hline",
    ]
    for idx in df.index:
        cells = [_latex_escape(idx)] + [format(val, float_format) for val in df.loc[idx]]
        lines.append(" & ".join(cells) + " \\\\")
    lines += [
        "\\hline",
        "\\end{tabular}",
        "\\caption{" + caption + "}",
        "\\label{" + label + "}",
        "\\end{table}",
    ]
    return "\n".join(lines)

# Render the descriptive-statistics table, show it, and store it in the experiment directory.
latex_table = dataframe_to_latex(descstats_df)
print(latex_table)
latex_filename = "".join((config['save_dir'], config['exper_name'], '/descstat_table.tex'))

with open(latex_filename, 'w') as tex_file:
    tex_file.write(latex_table)

\begin{table}[h!]
\centering
\begin{tabular}{lccccc}
 & C & L & K & I & Y \\
\hline
Mean & 0.00006 & 0.00002 & 0.00053 & -0.00037 & 0.00015 \\
Sd & 0.01489 & 0.00299 & 0.01874 & 0.06545 & 0.02040 \\
Skewness & 0.00873 & -0.02430 & -0.00604 & -0.15695 & -0.00758 \\
Kurtosis & 0.04262 & 0.20266 & 0.04703 & -0.01917 & -0.01837 \\
Q1 & -0.03481 & -0.00716 & -0.04390 & -0.15981 & -0.04744 \\
Q25 & -0.00995 & -0.00195 & -0.01196 & -0.04430 & -0.01370 \\
Q50 & 0.00000 & 0.00005 & 0.00030 & 0.00143 & 0.00019 \\
Q75 & 0.01006 & 0.00200 & 0.01316 & 0.04508 & 0.01399 \\
Q99 & 0.03489 & 0.00715 & 0.04425 & 0.14291 & 0.04734 \\
\hline
\end{tabular}
\caption{Descriptive Statistics}
\label{table:desc_stats}
\end{table}


In [None]:
# Stochastic steady state as a one-row LaTeX table, one column per variable.
table_titles = ["Consumption", "Labor", "Capital", "Investment",  "Price", "Price of Capital", " GDP"]
# Fixed 5-decimal formatting, matching the descriptive-statistics table; str(round(x, 5))
# would drop trailing zeros and print small values like 1e-05 in scientific notation.
table_data_strings = [f"{float(value):.5f}" for value in stochss]
if len(table_data_strings) != len(table_titles):
    raise ValueError(f"Expected {len(table_titles)} stochastic steady-state values, got {len(table_data_strings)}")

# Initialize the LaTeX table string (one centred column per title)
latex_table = "\\begin{tabular}{" + "c" * len(table_titles) + "}\n\\hline\n"

# Add column titles
latex_table += " & ".join(table_titles) + " \\\\\n\\hline\n"

# Add the row of values
latex_table += " & ".join(table_data_strings) +  " \\\\\n"

# End the table
latex_table += "\\hline\n \\end{tabular}"

# Print the LaTeX table string
print(latex_table)

# Save into the current experiment's directory (not the restore source, which may be empty).
latex_dir = config['save_dir'] + config['exper_name']
os.makedirs(latex_dir, exist_ok=True)
latex_filename = latex_dir + '/stoch_ss_table.tex'

# Save the LaTeX code to the specified file
with open(latex_filename, 'w') as file:
    file.write(latex_table)

\begin{tabular}{ccccccc}
\hline
Consumption & Labor & Capital & Investment & Price & Price of Capital &  GDP \\
\hline
-0.00013 & 0.00015 & 0.00023 & 0.00173 & 0.00024 & 0.00072 & 0.00014 \\
\hline
 \end{tabular}


In [None]:
# DISCONNECT SESSION (uncomment the next 2 lines to release the Colab runtime after a large run)
# from google.colab import runtime
# runtime.unassign()