<a href="https://colab.research.google.com/github/MatiasCovarrubias/jaxecon/blob/main/RbcProdNet_Sep13_2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DEQN Solver in Jax: Prelims

This notebook trains a neural net to output the optimal policy of a nonlinear multi-sector RBC model.



In [None]:
!pip install --upgrade jax jaxlib

# PRECISION
from jax import numpy as jnp, lax, random, config
double_precision = True
if double_precision:
  config.update("jax_enable_x64", True)
  precision = jnp.float64
else:
  precision = jnp.float32

# IMPORTS
import matplotlib.pyplot as plt, numpy as np, pandas as pd, jax, flax, optax, os, json
import flax.linen as nn
from flax.training.train_state import TrainState  # Useful dataclass to keep train state
from flax.training import checkpoints
from flax.core import freeze, unfreeze
from time import time
from typing import Sequence
config.update("jax_debug_nans", True)  # raise an error as soon as a NaN appears (adds overhead; consider disabling for long runs)

print("DEVICES: ", jax.devices())

# MOUNT GOOGLE DRIVE (a pop up will appear, follow instructions)
from google.colab import drive
drive.mount('/content/drive')

# DEFINE WORKING DIR
working_dir = "/content/drive/MyDrive/Jaxecon/RbcProdNet/"

Found existing installation: jax 0.5.3
Uninstalling jax-0.5.3:
  Successfully uninstalled jax-0.5.3
Found existing installation: jaxlib 0.5.3
Uninstalling jaxlib-0.5.3:
  Successfully uninstalled jaxlib-0.5.3
Collecting jax==0.6.0
  Downloading jax-0.6.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib==0.6.0
  Downloading jaxlib-0.6.0-cp312-cp312-manylinux2014_x86_64.whl.metadata (1.2 kB)
Downloading jax-0.6.0-py3-none-any.whl (2.3 MB)
Downloading jaxlib-0.6.0-cp312-cp312-manylinux2014_x86_64.whl (87.8 MB)
Installing collected packages: jaxlib, jax
Successfully installed jax-0.6.0 jaxlib-0.6.0




DEVICES:  [CudaDevice(id=0)]
Mounted at /content/drive


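The first time we run this notebook, we need to create a DataFrame that will store the results of the different experiments. The cell below saves it as an empty CSV in the working directory and leaves an existing file untouched.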

In [None]:
# Create an empty DataFrame with the desired columns
columns = [
    "exper_name",
    "comment",
    "min_loss",
    "max_mean_acc",
    "max_min_acc",
    "Time for Full Experiment (m)",
    "Time for epoch (s)",
    "Time for Compilation (s)",
    "Steps per second",
    "config",
    "Losses_list",
    "mean_accuracy_list",
    "min_accuracy_list",
]

# Create an empty DataFrame with the specified columns
df = pd.DataFrame(columns=columns)

# Specify the file path for the CSV file in your working directory
csv_filename = os.path.join(working_dir, 'experiment_results.csv')
# Save the DataFrame to a CSV file if it doesn't exist
if not os.path.isfile(csv_filename):
  df.to_csv(csv_filename, index=False)
  print(f"New experiments csv saved as {csv_filename}")
else:
  print("The experiments csv already exists")


The experiments csv already exists
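Each experiment then contributes one row to this file. A minimal sketch of one way to append a row with pandas, assuming a hypothetical helper `append_experiment` and only a subset of the columns above; it writes to a temporary directory instead of `working_dir`:

```python
import os
import tempfile
import pandas as pd

# Hypothetical subset of the experiment-results columns defined above
columns = ["exper_name", "comment", "min_loss", "max_mean_acc", "max_min_acc"]

def append_experiment(csv_path, row):
    """Append one experiment's summary as a new row of the results CSV (hypothetical helper)."""
    write_header = not os.path.isfile(csv_path)   # header only when the file is new
    pd.DataFrame([row], columns=columns).to_csv(
        csv_path, mode="a", header=write_header, index=False
    )

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "experiment_results.csv")
    append_experiment(path, {"exper_name": "run_1", "comment": "baseline",
                             "min_loss": 1e-6, "max_mean_acc": 0.999, "max_min_acc": 0.98})
    append_experiment(path, {"exper_name": "run_2", "comment": "wider net",
                             "min_loss": 5e-7, "max_mean_acc": 0.9995, "max_min_acc": 0.99})
    results = pd.read_csv(path)   # read back while the temporary file still exists
```

Columns that hold lists (such as `Losses_list`) are written as text, so they need to be parsed when the CSV is read back.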


# Create Neural Net Policy

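First, we use Flax to create the neural net. Its output is the sum of a baseline log-linear policy, obtained by applying the solution matrix `C` to the state and rescaling by `policies_sd`, and a residual MLP with ReLU hidden layers whose final layer is initialized to zero, so at initialization the network reproduces the log-linear policy. The output is a normalized log deviation of the policies from steady state; positivity of the policy variables is guaranteed inside the model, which exponentiates the de-normalized output.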

See https://flax.readthedocs.io/en/latest/getting_started.html
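To see why the zero-initialized output layer matters, here is a NumPy sketch of the same forward pass with small hypothetical shapes and random stand-ins for `C` and `policies_sd` (not the loaded matrices). At initialization the output equals the baseline exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_policies, hidden = 4, 3, [8, 8]

# Hypothetical stand-ins for the loaded matrices C and policies_sd
C = rng.normal(size=(n_policies, n_states))
policies_sd = rng.uniform(0.5, 2.0, size=n_policies)

# Hidden layers get random weights; the output layer starts at exactly zero
sizes = [n_states] + hidden
hidden_params = [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
                 for m, n in zip(sizes[:-1], sizes[1:])]
W_out = np.zeros((sizes[-1], n_policies))
b_out = np.zeros(n_policies)

def policy(x):
    """Baseline log-linear policy plus a ReLU MLP residual (mirrors NeuralNet)."""
    baseline = (x @ C.T) * policies_sd
    h = x
    for W, b in hidden_params:
        h = np.maximum(h @ W + b, 0.0)   # ReLU hidden layer
    return baseline + h @ W_out + b_out

x = rng.normal(size=(5, n_states))       # a batch of 5 hypothetical normalized states
out = policy(x)
```

Because the output kernel starts at zero, the gradient with respect to the hidden-layer weights is zero at the very first step; the output layer does receive a nonzero gradient, and once it moves away from zero the hidden layers start learning too.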

In [None]:
class NeuralNet(nn.Module):
  features: Sequence[int]   # widths of the hidden layers
  C: jnp.ndarray            # log-linear policy matrix, shape (n_policies, n_states)
  policies_sd: jnp.ndarray  # shape (n_policies,)
  param_dtype = precision

  @nn.compact
  def __call__(self, x):
    # Flatten any leading batch dimensions so slicing is consistent
    x_2d = x.reshape(-1, x.shape[-1])  # (batch, n_states)

    # Baseline loglinear policy
    baseline = x_2d @ self.C.T
    baseline = baseline * self.policies_sd[None, :]  # (batch, n_policies)

    # Residual MLP
    h = x_2d
    for feat in self.features:
        h = nn.relu(nn.Dense(feat, param_dtype=self.param_dtype)(h))

    # Zero-initialized output layer: at initialization the net returns the baseline policy
    residual = nn.Dense(
        self.C.shape[0],
        kernel_init=nn.initializers.zeros,
        bias_init=nn.initializers.zeros,
        param_dtype=self.param_dtype
    )(h)

    output = baseline + residual

    # Restore the input's leading dimensions (a single state returns a 1D policy)
    return output.reshape(x.shape[:-1] + (output.shape[-1],))

# Create Economic Model

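We represent the model as a class with four main methods: `initial_obs`, which draws a first (normalized) state; `step`, which advances one period given a state, a policy and a shock; `expect_realization`, which computes one realization of the expectation term given next period's state and policy; and `loss`, which returns the residuals of the equilibrium conditions (and accuracy metrics) given a state, an expectation term and a policy. Helper methods `sample_shock`, `mc_shocks` and `ir_shocks` draw shocks for simulation, Monte Carlo expectations and impulse responses.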
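A minimal sketch of how the four methods combine into the loss at a single state, using a hypothetical one-state `ToyModel` (not the RBC model) whose residual is zero when the policy equals its own expected next-period value. In the notebook the two list comprehensions are `jax.vmap` calls:

```python
import numpy as np

class ToyModel:
    """Hypothetical one-state model with the same four-method interface as Model."""

    def __init__(self, rho=0.9):
        self.rho = rho

    def initial_obs(self, rng):
        return rng.normal(size=1)                 # random initial state

    def step(self, state, policy, shock):
        return self.rho * state + shock           # toy law of motion (ignores the policy)

    def expect_realization(self, state_next, policy_next):
        return policy_next                        # term inside E_t[...] for one shock draw

    def loss(self, state, expect, policy):
        residual = policy / expect - 1            # relative residual of "policy = E_t[policy_next]"
        return np.mean(residual ** 2)

def state_loss(env, policy_fn, state, shocks):
    """Per-state loss: today's policy vs. a Monte Carlo expectation over next-period shocks."""
    policy = policy_fn(state)
    next_states = [env.step(state, policy, e) for e in shocks]
    realizations = [env.expect_realization(s, policy_fn(s)) for s in next_states]
    expect = np.mean(realizations, axis=0)        # average across shock draws
    return env.loss(state, expect, policy)

env = ToyModel()
policy_fn = lambda s: np.exp(0.5 * s)             # hypothetical positive policy
state0 = env.initial_obs(np.random.default_rng(0))
loss_ss = state_loss(env, policy_fn, np.zeros(1), np.zeros(2))             # no shocks at the steady state
loss_mc = state_loss(env, policy_fn, np.zeros(1), np.array([0.1, -0.1]))   # two antithetic shocks
```

With zero shocks the toy policy satisfies its condition exactly; with shocks, Jensen's inequality makes the expectation of the convex policy exceed today's value, so the loss is positive.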

## Import Parameters of the Model

In [None]:
import scipy.io as sio
model_name = "Feb21_24_baselinev3.mat"
model_data = sio.loadmat(working_dir + "Model_Data/RbcProdNet_SolData_" + model_name, simplify_cells=True)
modparams = model_data["SolData"]["parameters"]    # dictionary with the model parameters
policies_ss = model_data["SolData"]["policies_ss"]  # steady-state policies (in logs)
k_ss = model_data["SolData"]["k_ss"]                # steady-state capital (in logs)
policies_sd = model_data["SolData"]["policies_sd"]  # standard deviations used to normalize policies
states_sd = model_data["SolData"]["states_sd"]      # standard deviations used to normalize states
shocks_sd = model_data["SolData"]["shocks_sd"]      # standard deviations of the shocks
A = model_data["SolData"]["A"]                      # matrices of the log-linear solution
B = model_data["SolData"]["B"]
C = model_data["SolData"]["C"]                      # log-linear policy matrix (NeuralNet baseline)
D = model_data["SolData"]["D"]
n_sectors = modparams["parn_sectors"]
print(modparams)
del model_data

{'parn_sectors': 37, 'parbeta': 0.96, 'pareps_c': 0.5, 'pareps_l': 0.5, 'parphi': 2, 'partheta': 282791.12733419565, 'parsigma_c': 0.5, 'parsigma_m': 0.1, 'parsigma_q': 0.5, 'parsigma_y': 0.8, 'parsigma_I': 0.5, 'parsigma_l': 0.5, 'paralpha': array([0.76246851, 0.81447933, 0.2173877 , 0.35041499, 0.51072498,
       0.49043508, 0.43432306, 0.48266322, 0.46770221, 0.45357891,
       0.50864942, 0.45744396, 0.30623142, 0.49284854, 0.61563524,
       0.33601341, 0.27585449, 0.55552307, 0.23069028, 0.65399183,
       0.72390174, 0.48730601, 0.54020845, 0.47152072, 0.42829889,
       0.70186202, 0.59354188, 0.96764559, 0.57588724, 0.20876475,
       0.33779216, 0.16464859, 0.45550776, 0.57440692, 0.4837216 ,
       0.4069147 , 0.41336966]), 'pardelta': array([0.07113207, 0.03325417, 0.12865278, 0.08514028, 0.07672778,
       0.05764028, 0.07075   , 0.08320417, 0.13533333, 0.10309861,
       0.15317917, 0.11067917, 0.08234722, 0.0942625 , 0.06625972,
       0.06867222, 0.06729028, 0.08019444,

In [None]:
print(states_sd)

[0.04515133 0.03403735 0.08149502 0.10456063 0.05156925 0.05730716
 0.05469752 0.03207556 0.05441389 0.04682447 0.04630113 0.05072055
 0.06580466 0.04258033 0.05302759 0.07543571 0.05762596 0.05643485
 0.07725362 0.05537577 0.03626287 0.0572563  0.05564264 0.04074429
 0.04199542 0.03781637 0.03965085 0.03402221 0.03517175 0.04551916
 0.0584624  0.02805981 0.04183397 0.03928181 0.04694842 0.05608965
 0.03324947 0.1965228  0.05434032 0.05390058 0.10868998 0.07450848
 0.08911767 0.06404581 0.16582568 0.0904109  0.07320349 0.14056454
 0.10458124 0.06094005 0.07292626 0.04934042 0.07508559 0.040071
 0.0742004  0.05010725 0.29439812 0.08114295 0.07334797 0.03927505
 0.04591136 0.0449825  0.07102711 0.04402118 0.02309391 0.0373194
 0.05784268 0.03522574 0.05849283 0.02823388 0.04865723 0.04147002
 0.04459083 0.05860165]


## Create Model class
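For reference, these are the dynamic equations encoded in `step`, `expect_realization` and the capital first-order condition (`K_loss`) in `loss`, read off the code below. Variables are in levels, $a = \log A$, and all operations are element-wise across sectors:

$$K_{t+1} = (1-\delta)K_t + I_t - \frac{\phi}{2}\left(\frac{I_t}{K_t}-\delta\right)^2 K_t$$

$$a_{t+1} = \rho\, a_t + \varepsilon_{t+1}, \qquad \varepsilon_{t+1}\sim\mathcal{N}(0,\Sigma_A)$$

$$P^k_t = \beta\,\mathbb{E}_t\left[P_{t+1}A_{t+1}^{\frac{\sigma_y-1}{\sigma_y}}\left(\frac{\mu Q_{t+1}}{Y_{t+1}}\right)^{\frac{1}{\sigma_q}}\left(\frac{\alpha Y_{t+1}}{K_{t+1}}\right)^{\frac{1}{\sigma_y}} + P^k_{t+1}\left(1-\delta+\frac{\phi}{2}\left(\frac{I_{t+1}^2}{K_{t+1}^2}-\delta^2\right)\right)\right]$$

`loss` reports the last condition as the relative residual $P^k_t / \left(\beta\,\mathbb{E}_t[\cdot]\right) - 1$, where the expectation is approximated by averaging `expect_realization` over the Monte Carlo shocks from `mc_shocks`.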

In [None]:
class Model():
  """A JAX implementation of an RBC model."""

  def __init__(self,
    modparams = modparams, k_ss=k_ss, policies_ss=policies_ss, state_sd=states_sd,
    shocks_sd=shocks_sd, policies_sd=policies_sd, A=A, B=B, C=C, D=D):

    self.alpha = jnp.array(modparams["paralpha"], dtype = precision)
    self.beta = jnp.array(modparams["parbeta"], dtype = precision)
    self.delta = jnp.array(modparams["pardelta"], dtype = precision)
    self.rho = jnp.array(modparams["parrho"], dtype = precision)
    self.eps_c = jnp.array(modparams["pareps_c"], dtype = precision)
    self.eps_l = jnp.array(modparams["pareps_l"], dtype = precision)
    self.phi = jnp.array(modparams["parphi"], dtype = precision)
    self.theta = jnp.array(modparams["partheta"], dtype = precision)
    self.sigma_c = jnp.array(modparams["parsigma_c"], dtype = precision)
    self.sigma_m = jnp.array(modparams["parsigma_m"], dtype = precision)
    self.sigma_q = jnp.array(modparams["parsigma_q"], dtype = precision)
    self.sigma_y = jnp.array(modparams["parsigma_y"], dtype = precision)
    self.sigma_I = jnp.array(modparams["parsigma_I"], dtype = precision)
    self.sigma_l = jnp.array(modparams["parsigma_l"], dtype = precision)
    self.xi = jnp.array(modparams["parxi"], dtype = precision)
    self.mu = jnp.array(modparams["parmu"], dtype = precision)
    self.Gamma_M = jnp.array(modparams["parGamma_M"], dtype = precision)
    self.Gamma_I = jnp.array(modparams["parGamma_I"], dtype = precision)
    self.Sigma_A = jnp.array(modparams["parSigma_A"], dtype = precision)
    self.n_sectors = modparams["parn_sectors"]
    self.state_ss = jnp.concatenate([k_ss,jnp.zeros(shape=(1*self.n_sectors,), dtype = precision)])
    self.policies_ss = jnp.array(policies_ss, dtype = precision)

    self.A = A
    self.B = B
    self.C = C
    self.D = D
    self.state_sd = jnp.array(state_sd, dtype = precision)
    self.policies_sd = jnp.array(policies_sd, dtype = precision)
    self.dim_policies = len(policies_ss)
    self.dim_states = len(self.state_ss)
    self.n_actions = len(policies_ss)
    self.L_cholesky = jnp.linalg.cholesky(self.Sigma_A)

  def initial_obs(self, rng, range=1):
    """ Draw a random initial (normalized) state around the steady state; `range` is a percentage """

    rng_k, rng_a, _, _ = random.split(rng,4)                                      # only the first two keys are used
    K_ss = jnp.exp(self.state_ss[:self.n_sectors])                                # StSt capital in levels
    # capital is drawn in [(1-range/100)*K_ss, (1+range/300)*K_ss]: note the asymmetric upper bound
    K_init = random.uniform(rng_k, shape=(self.n_sectors,), minval=(1-range/100)*K_ss, maxval=(1+range/300)*K_ss)
    A_init = random.uniform(rng_a, shape=(self.n_sectors,), minval=(1-range/100), maxval=(1+range/100))
    state_init_notnorm = jnp.concatenate([jnp.log(K_init),jnp.log(A_init)])
    state_init = (state_init_notnorm-self.state_ss)/self.state_sd                       # normalize
    return state_init


  def step(self, state, policy, shock):
    """ A period step of the model, given current obs, the shock and policy_params """

    state_notnorm = state*self.state_sd + self.state_ss                                 # denormalize obs
    K = jnp.exp(state_notnorm[:self.n_sectors])                                   # extract k and put in levels
    a = state_notnorm[self.n_sectors:]
    a_next = self.rho*a + shock
    policy_notnorm = jnp.exp(policy*self.policies_sd+self.policies_ss)                           # denormalize policy

    I = policy_notnorm[6*self.n_sectors:7*self.n_sectors]
    K_tplus1 = (1-self.delta)*K + I - (self.phi/2) * (I/K - self.delta)**2 * K     # update K
    state_next_notnorm = jnp.concatenate([jnp.log(K_tplus1),a_next])            # calculate next obs (not normalized)
    state_next = (state_next_notnorm-self.state_ss)/self.state_sd                       # normalize

    return state_next


  def expect_realization(self, obs_next, policy_next):
    """ A realization (given a shock) of the expectation terms in system of equation """

    # Process observation
    obs_next_notnorm = obs_next*self.state_sd + self.state_ss# denormalize
    K_next = jnp.exp(obs_next_notnorm[:self.n_sectors]) # put in levels
    a_next = obs_next_notnorm[self.n_sectors:2*self.n_sectors]
    A_next = jnp.exp(a_next)

    # Calculate tplus1 policies
    policy_next_notnorm = jnp.exp(policy_next*self.policies_sd+self.policies_ss)                           # denormalize policy
    Pk_next = policy_next_notnorm[2*self.n_sectors:3*self.n_sectors]
    I_next = policy_next_notnorm[6*self.n_sectors:7*self.n_sectors]
    P_next = policy_next_notnorm[8*self.n_sectors:9*self.n_sectors]
    Q_next = policy_next_notnorm[9*self.n_sectors:10*self.n_sectors]
    Y_next = policy_next_notnorm[10*self.n_sectors:11*self.n_sectors]

    # Solve for the expectation term in the FOC for Ktplus1
    expect_realization = (P_next*A_next**((self.sigma_y-1)/self.sigma_y) * (self.mu*Q_next/Y_next)**(1/self.sigma_q) *(self.alpha*Y_next/K_next)**(1/self.sigma_y)
      + Pk_next*((1-self.delta) + self.phi/2*(I_next**2 / K_next**2-self.delta**2)))

    # stop_gradient: the expectation term is treated as a constant when differentiating the loss
    return jax.lax.stop_gradient(expect_realization)

  def loss(self, state, expect, policy):
    """ Calculate loss associated with observing obs, having policy_params, and expectation exp """

    # Process observation
    state_notnorm = state*self.state_sd + self.state_ss# denormalize
    K = jnp.exp(state_notnorm[:self.n_sectors]) # put in levels
    a = state_notnorm[self.n_sectors:]
    A = jnp.exp(a)

    # Process policy
    policy_notnorm = jnp.exp(policy*self.policies_sd+self.policies_ss)                           # denormalize policy
    C = policy_notnorm[:self.n_sectors]
    L = policy_notnorm[self.n_sectors:2*self.n_sectors]
    Pk = policy_notnorm[2*self.n_sectors:3*self.n_sectors]
    Pm = policy_notnorm[3*self.n_sectors:4*self.n_sectors]
    M = policy_notnorm[4*self.n_sectors:5*self.n_sectors]
    Mout = policy_notnorm[5*self.n_sectors:6*self.n_sectors]
    I = policy_notnorm[6*self.n_sectors:7*self.n_sectors]
    Iout = policy_notnorm[7*self.n_sectors:8*self.n_sectors]
    P = policy_notnorm[8*self.n_sectors:9*self.n_sectors]
    Q = policy_notnorm[9*self.n_sectors:10*self.n_sectors]
    Y = policy_notnorm[10*self.n_sectors:11*self.n_sectors]
    Cagg = policy_notnorm[11*self.n_sectors]
    Lagg = policy_notnorm[11*self.n_sectors+1]
    Yagg = policy_notnorm[11*self.n_sectors+2]
    Iagg = policy_notnorm[11*self.n_sectors+3]
    Magg = policy_notnorm[11*self.n_sectors+4]

    # get steady state prices to aggregate Y, I and M
    Pss = jnp.exp(self.policies_ss[8*self.n_sectors:9*self.n_sectors])
    Pkss = jnp.exp(self.policies_ss[2*self.n_sectors:3*self.n_sectors])
    Pmss = jnp.exp(self.policies_ss[3*self.n_sectors:4*self.n_sectors])
    capadj_term = 1-self.phi*(I/K-self.delta)

    # auxiliary variables
    Pagg = (self.xi.T @ P ** (1 - self.sigma_c)) ** (1 / (1 - self.sigma_c))
    MgUtCagg = (Cagg - self.theta * 1 / (1 + self.eps_l ** (-1)) * Lagg ** (1 + self.eps_l ** (-1))) ** (-self.eps_c ** (-1))

    # key variables for loss function
    MgUtCmod = MgUtCagg * (Cagg * self.xi / C) ** (1 / self.sigma_c)
    MgUtLmod = MgUtCagg * self.theta * Lagg ** (self.eps_l ** -1) * (L / Lagg) ** (1 / self.sigma_l)
    MPLmod = P * A**((self.sigma_y-1)/self.sigma_y) * (self.mu * Q / Y) ** (1 / self.sigma_q) * ((1 - self.alpha) * Y / L) ** (1 / self.sigma_y)
    MPKmod = self.beta * expect
    Pmdef = (self.Gamma_M.T @ P ** (1 - self.sigma_m)) ** (1 / (1 - self.sigma_m))
    Mmod = (1 - self.mu) * (Pm / P) ** (-self.sigma_q) * Q
    Moutmod = P ** (-self.sigma_m) * jnp.dot(self.Gamma_M, Pm**self.sigma_m * M)
    Pkdef = (self.Gamma_I.T @ P ** (1 - self.sigma_I)) ** (1 / (1 - self.sigma_I)) * capadj_term**(-1)
    Ioutmod = P ** (-self.sigma_I) * jnp.dot( self.Gamma_I,Pk**self.sigma_I * I * capadj_term**(self.sigma_I) )
    Qrc = C + Mout + Iout
    Qdef = ( self.mu**(1/self.sigma_q) * Y**((self.sigma_q-1)/self.sigma_q) + (1-self.mu)**(1/self.sigma_q) * M**((self.sigma_q-1)/self.sigma_q) ) ** (self.sigma_q/(self.sigma_q-1))
    Ydef = A * ( self.alpha**(1/self.sigma_y) * K**((self.sigma_y-1)/self.sigma_y) + (1-self.alpha)**(1/self.sigma_y) * L**((self.sigma_y-1)/self.sigma_y) ) ** (self.sigma_y/(self.sigma_y-1))
    Caggdef = ( (self.xi**(1/self.sigma_c)).T @ C**((self.sigma_c-1)/self.sigma_c) ) ** (self.sigma_c/(self.sigma_c-1))
    Laggdef = jnp.sum( L**((self.sigma_l+1)/self.sigma_l) ) ** (self.sigma_l/(self.sigma_l+1))
    Yaggdef = jnp.sum(Y * Pss)
    Iaggdef = jnp.sum(I * Pkss)
    Maggdef = jnp.sum(M * Pmss)

    C_loss = P/MgUtCmod - 1
    L_loss = MgUtLmod/MPLmod - 1
    K_loss = Pk/MPKmod - 1
    Pm_loss = Pm/Pmdef - 1
    M_loss = M/Mmod - 1
    Mout_loss = Mout/Moutmod - 1
    Pk_loss = Pk/Pkdef - 1
    Iout_loss = Iout/Ioutmod - 1
    Qrc_loss = Q/Qrc - 1
    Qdef_loss = Q/Qdef - 1
    Ydef_loss = Y/Ydef - 1
    Caggdef_loss = jnp.array([Cagg/Caggdef - 1])
    Laggdef_loss = jnp.array([Lagg/Laggdef - 1])
    Yaggdef_loss = jnp.array([Yagg/Yaggdef - 1])
    Iaggdef_loss = jnp.array([Iagg/Iaggdef - 1])
    Maggdef_loss = jnp.array([Magg/Maggdef - 1])

    foc_losses = [C_loss, L_loss, K_loss, Pm_loss, M_loss, Mout_loss, Pk_loss,
                  Iout_loss, Qrc_loss, Qdef_loss, Ydef_loss, Caggdef_loss,
                  Laggdef_loss, Yaggdef_loss, Iaggdef_loss, Maggdef_loss]
    losses_array = jnp.concatenate(foc_losses, axis=0)

    # Calculate aggregate losses and metrics
    mean_loss = jnp.mean(losses_array**2)
    mean_accuracy = jnp.mean(1-jnp.abs(losses_array))
    min_accuracy = jnp.min(1-jnp.abs(losses_array))
    # one accuracy per block of equilibrium conditions, in the order of foc_losses
    mean_accuracies_focs = jnp.array([jnp.mean(1-jnp.abs(res)) for res in foc_losses])
    min_accuracies_focs = jnp.array([jnp.min(1-jnp.abs(res)) for res in foc_losses])

    return mean_loss, mean_accuracy, min_accuracy, mean_accuracies_focs, min_accuracies_focs

  # def sample_shock(self, rng):
  #   """ sample one realization of the shock """
  #   z = jax.random.normal(rng, shape=(self.n_sectors,))
  #   shock = self.L_cholesky @ z
  #   return shock

  # def mc_shocks(self, rng=random.PRNGKey(0), mc_draws=8):
  #   """ sample mc_draws realizations of the shock (for monte-carlo) """
  #   z = jax.random.normal(rng, shape=(mc_draws, self.n_sectors))
  #   mc_shocks = jax.vmap(lambda zi: self.L_cholesky @ zi)(z)
  #   return mc_shocks

  def sample_shock(self, rng):
    """ sample one realization of the shock """
    return jax.random.multivariate_normal(rng, jnp.zeros((self.n_sectors,)), self.Sigma_A)

  def mc_shocks(self, rng=random.PRNGKey(0), mc_draws=8):
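    """
    Sample mc_draws shocks for the Monte Carlo expectation, aiming to reduce variance with:
    - Latin Hypercube Sampling: each standard-normal marginal of z is stratified into equal-probability cells
    - Antithetic variates (when mc_draws is even): every draw z is paired with -z, so the sample mean is
      exactly zero; this reduces variance when the integrand is close to monotone in the shock
    The draws are mapped to N(0, Sigma_A) with the Cholesky factor L_cholesky.
    """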

    # Latin Hypercube Sampling for better space-filling with nonlinear functions
    def latin_hypercube_sample(key, n_samples, n_dims):
        """Generate Latin Hypercube samples"""
        keys = random.split(key, n_dims)

        # Create permutations for each dimension
        perms = jnp.stack([
            random.permutation(keys[i], n_samples)
            for i in range(n_dims)
        ], axis=1)

        # Add uniform noise within each cell
        key_uniform = random.fold_in(key, 1)
        uniform_noise = random.uniform(
            key_uniform,
            shape=(n_samples, n_dims)
        )

        # Create LHS samples in [0,1]^d
        lhs_samples = (perms + uniform_noise) / n_samples
        return lhs_samples

    # Decide on sampling strategy
    use_antithetic = (mc_draws % 2 == 0)

    if use_antithetic:
        # Generate half samples with LHS, create antithetic pairs
        n_base = mc_draws // 2
        key1, key2 = random.split(rng)

        # Latin hypercube sampling for base samples
        u_lhs = latin_hypercube_sample(key1, n_base, self.n_sectors)

        # Transform to standard normal
        u_lhs = jnp.clip(u_lhs, 1e-6, 1 - 1e-6)
        z_base = jax.scipy.stats.norm.ppf(u_lhs)

        # Create antithetic pairs: the second half mirrors the first
        z = jnp.vstack([z_base, -z_base])

    else:
        # Full Latin Hypercube Sampling
        u_lhs = latin_hypercube_sample(rng, mc_draws, self.n_sectors)
        u_lhs = jnp.clip(u_lhs, 1e-6, 1 - 1e-6)
        z = jax.scipy.stats.norm.ppf(u_lhs)

    # Optional: Add controlled noise for highly discontinuous functions
    # This can help explore around discontinuities
    # key_noise = random.fold_in(rng, 2)
    # noise = 0.1 * random.normal(key_noise, shape=z.shape)
    # z = z + noise

    # Transform to target distribution
    if hasattr(self, 'L_cholesky'):
        mc_shocks = jax.vmap(lambda zi: self.L_cholesky @ zi)(z)
    else:
        L_cholesky = jnp.linalg.cholesky(self.Sigma_A)
        mc_shocks = jax.vmap(lambda zi: L_cholesky @ zi)(z)

    return mc_shocks


  def ir_shocks(self):
    """ (Optional) Define a set of shocks sequences that are of interest"""
    # ir_shock_1 = jnp.array([-1]+[0 for i in range(40)])
    # ir_shock_2 = jnp.array([1]+[0 for i in range(40)])
    ir_shock_1 = jnp.zeros(shape=(40,self.n_sectors), dtype = precision).at[0,0].set(-1)
    ir_shock_2 = jnp.zeros(shape=(40,self.n_sectors), dtype = precision).at[0,0].set(1)

    return jnp.array([ir_shock_1, ir_shock_2])
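
The two properties that `mc_shocks` relies on can be checked outside JAX. The following NumPy sketch re-implements the sampler under stated assumptions: a hypothetical 2x2 covariance instead of `Sigma_A`, and the standard-library `statistics.NormalDist().inv_cdf` in place of `jax.scipy.stats.norm.ppf`. Antithetic pairs make the sample mean exactly zero, and Latin Hypercube Sampling places one point in each stratum of every dimension:

```python
import numpy as np
from statistics import NormalDist

def latin_hypercube(rng, n_samples, n_dims):
    """One point in each of n_samples equal-width strata of [0, 1), for every dimension."""
    # column j is an independent random permutation of the strata 0..n_samples-1
    perms = np.stack([rng.permutation(n_samples) for _ in range(n_dims)], axis=1)
    return (perms + rng.uniform(size=(n_samples, n_dims))) / n_samples

def antithetic_lhs_normals(rng, n_draws, n_dims):
    """Standard-normal draws: LHS for the first half, mirrored (-z) for the second half."""
    assert n_draws % 2 == 0, "antithetic pairing needs an even number of draws"
    u = np.clip(latin_hypercube(rng, n_draws // 2, n_dims), 1e-6, 1 - 1e-6)
    z_base = np.vectorize(NormalDist().inv_cdf)(u)     # inverse-CDF map to N(0, 1)
    return np.vstack([z_base, -z_base])

rng = np.random.default_rng(0)
Sigma = np.array([[1.0, 0.5], [0.5, 2.0]])             # hypothetical shock covariance
L_chol = np.linalg.cholesky(Sigma)
z = antithetic_lhs_normals(rng, 8, 2)
shocks = z @ L_chol.T      # row-wise L_chol @ z_i, i.e. draws with covariance Sigma
u_check = latin_hypercube(np.random.default_rng(1), 5, 3)
```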


## Test the environment
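We check that the model's methods behave as expected. In particular, at the steady-state observation with steady-state policies and zero shocks, the next observation should again be the steady state and every equilibrium residual should be (numerically) zero, so all accuracies should equal 1.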
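The metrics printed by these tests are all built from relative residuals (value implied by the policy over value implied by the equilibrium condition, minus one). A small NumPy illustration with hypothetical numbers for three conditions:

```python
import numpy as np

# Hypothetical values for three equilibrium conditions
policy_values = np.array([1.00, 2.02, 0.99])
implied_values = np.array([1.00, 2.00, 1.00])

residuals = policy_values / implied_values - 1    # relative errors, as in Model.loss
mean_loss = np.mean(residuals ** 2)               # training objective
mean_accuracy = np.mean(1 - np.abs(residuals))    # average accuracy across conditions
min_accuracy = np.min(1 - np.abs(residuals))      # worst condition: 0.99 means a 1% error
```

At the steady state every residual should be zero, which gives `mean_loss = 0` and both accuracies equal to 1.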

In [None]:
env_test = Model()
rng_test = random.PRNGKey(4)

# STEADY STATE POLICY  AND STEP WITH RANDOM NEURAL NET
obs_ss_norm = jnp.zeros_like(env_test.state_ss)                                   # get StSt obs
print("type of zeros like ss:", obs_ss_norm.dtype)
obs_test = obs_ss_norm                                                               # set obs that we will test to StSt
nn_test = NeuralNet([32,32] + [env_test.dim_policies], C, policies_sd)                             # initialize NN class
nn_policy_test = nn_test.apply                                                  # initialize policy fn
params_test = nn_test.init(rng_test, env_test.initial_obs(rng_test))            # initialize params
policy_test = nn_policy_test(params_test,obs_test)                              # get policy
print("type of nn_policy:", policy_test.dtype)
policy_test_ss = jnp.zeros_like(env_test.policies_ss)                           # normalized StSt policy (all zeros)
# print("policy in steady state", policy_test)
# at the StSt obs the untrained net returns the StSt policy (baseline and residual are both zero)
next_obs_sspol = env_test.step(obs_test, policy_test, jnp.zeros_like(env_test.sample_shock(rng_test)))
print("next obs with policy_ss", next_obs_sspol)
print("type of next obs", next_obs_sspol.dtype)

# print("next obs with policy = zeros (so steady state)", env_test.step(
#     obs_test, policy_test_ss, jnp.zeros_like(env_test.sample_shock(rng_test))))    # get next obs for StSt policy

# STEADY STATE LOSS WITH RANDOM NEURAL NET
expect_test = env_test.expect_realization(obs_ss_norm, policy_test)             # calculate expectations
mean_loss, mean_accuracy, min_accuracy, mean_acc_kfoc, min_acc_kfoc = env_test.loss(
    obs_ss_norm, expect_test, policy_test)                                      # calculate loss given expect. and policy
print("type of mean_loss:", mean_loss.dtype)
# print("loss with random policy in StSt: \n",
#       ", Mean_loss:", mean_loss,
#       ", Mean_accuracy:", mean_accuracy,
#       ", Min_accuracy:", min_accuracy)

# POLICY AND STEP OVER INITIAL OBS
obs_test = env_test.initial_obs(rng_test)                                       # get init obs (not necessarily == StSt obs)
# print("initial obs", obs_test)

# # apply a step
policy_test = nn_policy_test(params_test,obs_test)                              # get policy for first step
# print("policy in first step", policy_test)
shock_test = env_test.sample_shock(rng_test)                                    # get shock
# print("Realization of shock", shock_test)
next_obs_test = env_test.step(obs_test, policy_test, shock_test)                # make a step
# print("next obs first step", next_obs_test)

#  LOSS IN INITIAL OBS
# First, we calculate expectations
mc_shocks_test = env_test.mc_shocks(rng_test, mc_draws = 1280)                  # get mc shock
mc_nextobs_test = jax.vmap(env_test.step, in_axes = (None,None,0))(
    obs_test, policy_test, mc_shocks_test)                                      # next obs given policy and for each shock in mc_shocks
# mc_nextobs_test = jax.vmap(env_test.step, in_axes = (None,0))(obs_ss_norm, mc_shocks_test)
mc_nextpols_test = nn_policy_test(params_test, mc_nextobs_test)
expect_test = jnp.mean(jax.vmap(env_test.expect_realization)(
        mc_nextobs_test, mc_nextpols_test), axis=0)                             # calculate expectations
# print("expect in oneperiodtest", expect_test)
# Second, we calculate loss given expectations and policy
mean_loss, mean_accuracy, min_accuracy, mean_acc_kfoc, min_acc_kfoc = env_test.loss(
    obs_test, expect_test, policy_test)                                            # calculate loss at the initial obs
print("test one period: \n",
      ", Mean_loss:", mean_loss,
      ", Mean_accuracy:", mean_accuracy,
      ", Min_accuracy:", min_accuracy)

# LOSS IN STST WITH STST POLICY
obs_test = obs_ss_norm
policy_test = jnp.zeros_like(policy_test)                                       # normalized policies = 0 correspond to the StSt policies
mc_nextobs_test = jax.vmap(env_test.step, in_axes = (None,None,0))(
    obs_test, policy_test, jnp.zeros_like(mc_shocks_test))                      # next obs given policy, with all shocks set to zero
print("next obs for montecarlo with StSt policies",
      "(should be an array with multiple obs =0)", mc_nextobs_test)
mc_nextpols_test = jnp.zeros_like(mc_nextpols_test)                             # StSt policies for mc_nextobs
expect_test = jnp.mean(jax.vmap(env_test.expect_realization)(
    mc_nextobs_test, mc_nextpols_test), axis=0)                                 # calculate expectation
print("expectation", expect_test)
mean_loss, mean_accuracy, min_accuracy, mean_acc_kfoc, min_acc_kfoc = env_test.loss(
    obs_test, expect_test, policy_test)                                         # calculate loss
print("test that StSt. policies give 0 loss in StSt. obs: \n",
      ", Mean_loss:", mean_loss,
      ", Mean_accuracy:", mean_accuracy,
      ", Min_accuracy:", min_accuracy)

# DELETE VARIABLES
del env_test
del rng_test
del obs_ss_norm
del obs_test
del next_obs_test
del nn_test
del nn_policy_test
del params_test
del policy_test
del expect_test
del mc_shocks_test
del mc_nextobs_test
del mc_nextpols_test
del mean_loss
del mean_accuracy
del min_accuracy
del mean_acc_kfoc
del min_acc_kfoc


type of zeros like ss: float64
type of nn_policy: float64
next obs with policy_ss [-2.45889331e-15  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  3.99054793e-15  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.0000

# Create Simulation function
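`create_episode_simul_fn` rolls the model forward with `jax.lax.scan`: the carry is the current observation, and the value returned for each period is stacked into the episode. The JAX documentation describes `lax.scan` as equivalent to the pure-Python loop below. In this sketch a hypothetical AR(1) update stands in for `env.step`, and the scan runs over shocks instead of per-period RNG keys:

```python
import numpy as np

def scan(f, init, xs):
    """Reference semantics of jax.lax.scan: thread a carry, stack the per-step outputs."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

rho = 0.9
def period_step(state, shock):
    state_next = rho * state + shock    # hypothetical stand-in for env.step(state, policy, shock)
    return state_next, state_next       # (new carry, value stored for this period)

shocks = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
last_state, episode = scan(period_step, np.zeros(2), shocks)
```

As in the notebook, the returned episode starts at the first post-shock observation; the initial observation itself is not included.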

In [None]:
def create_episode_simul_fn(env, config):

  if config["proxy_sampler"]:
    # proxy sampler: simulate with env.step_loglinear instead of the policy-driven env.step
    def sample_epis_obs(train_state, epis_rng):
      "sample obs of an episode"
      init_rng, periods_rng = random.split(epis_rng)                              # separate keys for init obs and periods
      init_obs = env.initial_obs(init_rng, config["init_range"])
      period_rngs = random.split(periods_rng, config["periods_per_epis"])
      def period_step(env_obs, period_rng):
        period_shock = config["simul_vol_scale"]*env.sample_shock(period_rng)     # sample this period's shock
        obs_next = env.step_loglinear(env_obs, period_shock)                      # apply period step
        return obs_next, obs_next # (new carry, stored output), as lax.scan expects
      _, epis_obs = lax.scan(period_step, init_obs, jnp.stack(period_rngs)) # epis_obs: (periods_per_epis, n_states)
      return epis_obs

  else:
    def sample_epis_obs(train_state, epis_rng):
      "sample obs of an episode"
      init_rng, periods_rng = random.split(epis_rng)                              # separate keys for init obs and periods
      init_obs = env.initial_obs(init_rng, config["init_range"])
      period_rngs = random.split(periods_rng, config["periods_per_epis"])
      def period_step(env_obs, period_rng):
        policy = train_state.apply_fn(train_state.params, env_obs)                # current policy from the neural net
        period_shock = config["simul_vol_scale"]*env.sample_shock(period_rng)     # sample this period's shock
        obs_next = env.step(env_obs, policy, period_shock)                        # apply period step
        return obs_next, obs_next # (new carry, stored output), as lax.scan expects
      _, epis_obs = lax.scan(period_step, init_obs, jnp.stack(period_rngs)) # epis_obs: (periods_per_epis, n_states)
      return epis_obs

  return sample_epis_obs

## Test simulation function

In [None]:
#CREATE ENV,  TRAIN_STATE AND RNG
env_test = Model()
nn_test = NeuralNet([2,2] + [env_test.n_actions], C, policies_sd)   # NeuralNet also requires C and policies_sd
rng_test = random.PRNGKey(1)

# CREATE CONFIG
config_test = {
    "periods_per_epis": 100,      # periods per episode
    "simul_vol_scale": 1,        # scale of volatility while simul
    "proxy_sampler": False,       # True requires env.step_loglinear, which the Model class above does not define
    "init_range": 5,
}

# GET FUNCTIONS
episode_simul_fn = create_episode_simul_fn(env_test, config_test)
train_state_test = TrainState.create(apply_fn=nn_test.apply, params=nn_test.init(rng_test, env_test.initial_obs(rng_test)), tx=optax.adam(0.05))
epis_rng, loss_rng = random.split(rng_test, 2)
epis_obs = episode_simul_fn(train_state_test, epis_rng)
print("last observation of simulation: \n", epis_obs[-1])


# Create Loss function
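`batch_loss_fn` applies `period_loss` to every observation in the batch with `jax.vmap`, so each metric gains a leading batch axis that is then reduced: losses and mean accuracies are averaged, minimum accuracies take the worst case across periods, and the per-equation arrays are reduced along `axis=0` so that one value per block of equilibrium conditions remains. A NumPy sketch of that reduction with hypothetical numbers (4 observations, 3 equation blocks):

```python
import numpy as np

# Hypothetical per-observation outputs of period_loss for a batch of 4 observations
mean_losses = np.array([1e-4, 4e-4, 2e-4, 1e-4])
min_accuracies = np.array([0.99, 0.97, 0.98, 0.995])
mean_accs_foc = np.array([[0.99, 0.98, 1.00],
                          [0.97, 0.99, 0.99],
                          [0.98, 0.99, 1.00],
                          [0.99, 0.98, 0.99]])         # shape (batch, n_equation_blocks)

mean_loss = np.mean(mean_losses)                       # scalar objective that gets differentiated
min_accuracy = np.min(min_accuracies)                  # worst equation in the worst period
mean_accs_per_foc = np.mean(mean_accs_foc, axis=0)     # one entry per equation block
```

Returning `(mean_loss, metrics)` is the shape that `jax.value_and_grad(..., has_aux=True)` expects: only the first element is differentiated and the metrics are passed through.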

In [None]:
def create_batch_loss_fn(env, config):

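  # NOTE: both proxy branches call env.step_loglinear (the first also env.policy_loglinear), which the Model class above does not define
  if config["proxy_mcsampler"] and config["proxy_futurepol"]: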
    def batch_loss_fn(params, train_state, batch_obs, loss_rng):
      """Loss function of a batch of obs."""
      period_mc_rngs = random.split(loss_rng, batch_obs.shape[0])
      batch_policies = train_state.apply_fn(params, batch_obs) # get the policies for the entire obs batch.
      # batch_policies = jax.vmap(env.policy_loglinear)(batch_obs) # get the policies for the entire obs batch.
      def period_loss(obs, policy, period_mc_rng):
        """Loss function for an individual period."""
        mc_shocks = env.mc_shocks(period_mc_rng, config["mc_draws"])
        mc_nextobs = jax.vmap(env.step_loglinear, in_axes = (None,0))(obs, mc_shocks)
        # print("shape of mc_nextobs")
        mc_nextpols = jax.vmap(env.policy_loglinear)(mc_nextobs)
        # print("shape of mc_nexpols", )
        expect = jnp.mean(jax.vmap(env.expect_realization)(mc_nextobs, mc_nextpols), axis=0)
        mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc = env.loss(obs, expect, policy) # calculate loss
        return mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc

      # parallelize calculation of period_loss for the entire batch
      mean_losses, mean_accuracies, min_accuracies, mean_accs_foc, min_accs_foc = jax.vmap(period_loss)(batch_obs, batch_policies, jnp.stack(period_mc_rngs))
      mean_loss = jnp.mean(mean_losses)                   # average across periods
      mean_accuracy = jnp.mean(mean_accuracies)           # average across periods
      min_accuracy = jnp.min(min_accuracies)              # min accross periods and across eqs within period
      mean_accs_foc = jnp.mean(mean_accs_foc,axis=0)
      min_accs_foc = jnp.min(min_accs_foc,axis=0)
      metrics = mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc # pass as auxiliary info
      # metrics = jnp.array([mean_losses, mean_accuracies, min_accuracies]) # pass as auxiliary info
      return mean_loss, metrics

  elif config["proxy_mcsampler"] and not config["proxy_futurepol"]:
    def batch_loss_fn(params, train_state, batch_obs, loss_rng):
      """Loss function of a batch of obs."""
      period_mc_rngs = random.split(loss_rng, batch_obs.shape[0])
      batch_policies = train_state.apply_fn(params, batch_obs) # get the policies for the entire obs batch.

      def period_loss(obs, policy, period_mc_rng):
        """Loss function for an individual period."""
        mc_shocks = env.mc_shocks(period_mc_rng, config["mc_draws"])
        mc_nextobs = jax.vmap(env.step_loglinear, in_axes = (None,0))(obs, mc_shocks)
        mc_nextpols = train_state.apply_fn(params, mc_nextobs)
        expect = jnp.mean(jax.vmap(env.expect_realization)(mc_nextobs, mc_nextpols), axis=0)
        mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc = env.loss(obs, expect, policy) # calculate loss
        return mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc

      # parallelize calculation of period_loss for the entire batch
      mean_losses, mean_accuracies, min_accuracies, mean_accs_foc, min_accs_foc = jax.vmap(period_loss)(batch_obs, batch_policies, jnp.stack(period_mc_rngs))
      mean_loss = jnp.mean(mean_losses)                   # average across periods
      mean_accuracy = jnp.mean(mean_accuracies)           # average across periods
      min_accuracy = jnp.min(min_accuracies)              # min across periods and across eqs within period
      mean_accs_foc = jnp.mean(mean_accs_foc,axis=0)      # per-FOC average across periods
      min_accs_foc = jnp.min(min_accs_foc,axis=0)         # per-FOC min across periods
      metrics = mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc # pass as auxiliary info
      return mean_loss, metrics

  elif not config["proxy_mcsampler"] and config["proxy_futurepol"]:
    def batch_loss_fn(params, train_state, batch_obs, loss_rng):
      """Loss function of a batch of obs."""
      period_mc_rngs = random.split(loss_rng, batch_obs.shape[0])
      batch_policies = train_state.apply_fn(params, batch_obs) # get the policies for the entire obs batch.

      def period_loss(obs, policy, period_mc_rng):
        """Loss function for an individual period."""
        mc_shocks = env.mc_shocks(period_mc_rng, config["mc_draws"])
        mc_nextobs = jax.vmap(env.step, in_axes = (None,None,0))(obs, policy, mc_shocks)
        mc_nextpols = jax.vmap(env.policy_loglinear)(mc_nextobs)
        expect = jnp.mean(jax.vmap(env.expect_realization)(mc_nextobs, mc_nextpols), axis=0)
        mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc = env.loss(obs, expect, policy) # calculate loss
        return mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc

      # parallelize calculation of period_loss for the entire batch
      mean_losses, mean_accuracies, min_accuracies, mean_accs_foc, min_accs_foc = jax.vmap(period_loss)(batch_obs, batch_policies, jnp.stack(period_mc_rngs))
      mean_loss = jnp.mean(mean_losses)                   # average across periods
      mean_accuracy = jnp.mean(mean_accuracies)           # average across periods
      min_accuracy = jnp.min(min_accuracies)              # min across periods and across eqs within period
      mean_accs_foc = jnp.mean(mean_accs_foc,axis=0)      # per-FOC average across periods
      min_accs_foc = jnp.min(min_accs_foc,axis=0)         # per-FOC min across periods
      metrics = mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc # pass as auxiliary info
      return mean_loss, metrics
  else:
    def batch_loss_fn(params, train_state, batch_obs, loss_rng):
      """Loss function of a batch of obs."""
      period_mc_rngs = random.split(loss_rng, batch_obs.shape[0])
      batch_policies = train_state.apply_fn(params, batch_obs) # get the policies for the entire obs batch.

      def period_loss(obs, policy, period_mc_rng):
        """Loss function for an individual period."""
        mc_shocks = env.mc_shocks(period_mc_rng, config["mc_draws"])
        mc_nextobs = jax.vmap(env.step, in_axes = (None,None,0))(obs, policy, mc_shocks)
        mc_nextpols = train_state.apply_fn(params, mc_nextobs)
        expect = jnp.mean(jax.vmap(env.expect_realization)(mc_nextobs, mc_nextpols), axis=0)
        mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc = env.loss(obs, expect, policy) # calculate loss
        return mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc

      # parallelize calculation of period_loss for the entire batch
      mean_losses, mean_accuracies, min_accuracies, mean_accs_foc, min_accs_foc = jax.vmap(period_loss)(batch_obs, batch_policies, jnp.stack(period_mc_rngs))
      mean_loss = jnp.mean(mean_losses)                   # average across periods
      mean_accuracy = jnp.mean(mean_accuracies)           # average across periods
      min_accuracy = jnp.min(min_accuracies)              # min across periods and across eqs within period
      mean_accs_foc = jnp.mean(mean_accs_foc,axis=0)      # per-FOC average across periods
      min_accs_foc = jnp.min(min_accs_foc,axis=0)         # per-FOC min across periods
      metrics = mean_loss, mean_accuracy, min_accuracy, mean_accs_foc, min_accs_foc # pass as auxiliary info
      return mean_loss, metrics

  return batch_loss_fn
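All four branches follow the same pattern and differ only in two choices: proxy_mcsampler selects the law of motion used for the Monte Carlo next states (env.step_loglinear vs. env.step with the current policy), and proxy_futurepol selects the next-period policy (env.policy_loglinear vs. the network). In every case the code draws mc_draws shocks, maps the current observation to next-period observations, evaluates the next-period policy, and averages env.expect_realization over the draws to approximate the conditional expectation passed to env.loss. The sketch below isolates that vmap pattern on a toy AR(1) model. The functions step, policy and integrand and the parameters RHO and SIGMA are hypothetical stand-ins, not the notebook's model; the toy is chosen so the expectation has a closed form (exp(0.5 * (1.5 * SIGMA)**2) when obs = 0) against which the Monte Carlo estimate can be checked.

```python
import jax
import jax.numpy as jnp
from jax import random

RHO, SIGMA = 0.9, 0.01  # hypothetical AR(1) persistence and shock volatility

def step(obs, shock):
    # hypothetical law of motion (stands in for env.step_loglinear)
    return RHO * obs + SIGMA * shock

def policy(obs):
    # hypothetical next-period policy (stands in for the network or the loglinear proxy)
    return jnp.exp(0.5 * obs)

def integrand(next_obs, next_pol):
    # hypothetical term inside the conditional expectation (stands in for env.expect_realization)
    return next_pol * jnp.exp(next_obs)

def mc_expectation(obs, rng, n_draws):
    """Monte Carlo estimate of E[integrand(s', policy(s')) | s = obs]."""
    shocks = random.normal(rng, (n_draws,) + obs.shape)
    next_obs = jax.vmap(step, in_axes=(None, 0))(obs, shocks)   # (n_draws, n_obs)
    next_pols = jax.vmap(policy)(next_obs)                      # (n_draws, n_obs)
    return jnp.mean(jax.vmap(integrand)(next_obs, next_pols), axis=0)

def batch_expectations(batch_obs, rng, n_draws):
    # one independent key per observation, as in batch_loss_fn
    rngs = random.split(rng, batch_obs.shape[0])
    return jax.vmap(lambda o, k: mc_expectation(o, k, n_draws))(batch_obs, rngs)

batch_obs = jnp.zeros((4, 1))  # 4 observations at the steady state
est = batch_expectations(batch_obs, random.PRNGKey(0), 20_000)
print(est)
```

With 20,000 draws the Monte Carlo standard error here is roughly 1e-4, so every entry should lie close to the analytic value of about 1.0001.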

## Test loss fn

We calculate the loss and metrics of a single episode. To do so, we set batch_size equal to periods_per_epis.

In [None]:
#CREATE ENV,  TRAIN_STATE AND RNG
env_test = Model()
nn_test = NeuralNet([32,32] + [env_test.n_actions])
rng_test = random.PRNGKey(1)

# CREATE CONFIG
config_test = {
    "periods_per_epis": 8,      # periods per episode
    "simul_vol_scale": 1,        # scale of volatility while simul
    "mc_draws": 32,             # only applies if shock is continuous
    "proxy_sampler": True,
    "proxy_mcsampler": True,
    "proxy_futurepol": True,
    "init_range": 5,
}
config_test["batch_size"] = config_test["periods_per_epis"]

# GET FUNCTIONS
episode_simul_fn = create_episode_simul_fn(env_test, config_test)
episode_loss_fn = create_batch_loss_fn(env_test, config_test)

train_state_test = TrainState.create(apply_fn=nn_test.apply, params=nn_test.init(rng_test, env_test.initial_obs(rng_test)), tx=optax.adam(0.05))
epis_rng, loss_rng = random.split(rng_test, 2)
epis_obs = episode_simul_fn(train_state_test, epis_rng)

loss, epis_metrics = episode_loss_fn(train_state_test.params, train_state_test, epis_obs, loss_rng)
print("loss:", loss)
print("epis_metrics:", epis_metrics)

# Create Epoch Training function



In [None]:
def get_epoch_train_fn(env, config):
  episode_simul_fn = create_episode_simul_fn(env, config)
  batch_loss_fn = create_batch_loss_fn(env, config)

  def batch_train_fn(train_state, batch_obs, loss_rng):
    grad_fn = jax.value_and_grad(batch_loss_fn, has_aux=True)
    (_, batch_metrics), grads = grad_fn(train_state.params, train_state, batch_obs, loss_rng)
    grads = jax.lax.pmean(grads, axis_name="batch")
    train_state = train_state.apply_gradients(grads=grads)
    return train_state, batch_metrics

  def step_train_fn(train_state, step_rng):
    # split independent keys for simulation, Monte Carlo draws and shuffling
    epis_key, loss_key, perm_key = random.split(step_rng, 3)
    epis_rng = random.split(epis_key, config["epis_per_step"])
    loss_rng = random.split(loss_key, config["n_batches"])
    step_obs = jax.vmap(episode_simul_fn, in_axes=(None,0))(train_state, jnp.stack(epis_rng))
    step_obs = step_obs.reshape(config["periods_per_step"], env.state_ss.shape[0]) # combine all periods in one axis
    step_obs = random.permutation(perm_key, step_obs, axis=0)                   # reshuffle obs at random
    step_obs = step_obs.reshape(config["n_batches"], config["batch_size"] ,env.state_ss.shape[0]) # reshape into batches
    train_state, step_metrics = jax.vmap(batch_train_fn, in_axes=(None,0,0), out_axes=(None,0), axis_name="batch")(train_state, step_obs, jnp.stack(loss_rng))
    losses, mean_accuracies, min_accuracies,_, _ = step_metrics
    loss = jnp.mean(losses)
    mean_accuracy = jnp.mean(mean_accuracies)
    min_accuracy = jnp.min(min_accuracies)
    metrics = loss, mean_accuracy, min_accuracy
    return train_state, metrics

  def epoch_train_fn(train_state, epoch_rng):
    """Repeat step_train_fn steps_per_epoch times with lax.scan to complete an epoch."""
    epoch_rng, *step_rngs = random.split(epoch_rng, config["steps_per_epoch"] + 1)
    train_state, epoch_metrics = lax.scan(step_train_fn, train_state, jnp.stack(step_rngs))
    return train_state, epoch_rng, epoch_metrics

  return epoch_train_fn
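In step_train_fn, batch_train_fn is vmapped over the n_batches axis with axis_name="batch", and jax.lax.pmean averages the gradients across that axis before apply_gradients. The averaged gradient no longer varies along the axis, which is why the updated train state can be returned unbatched (out_axes=(None, 0)): each step applies a single update with the mean of the per-batch gradients. When all batches have the same size, this equals the gradient of the mean loss over all periods_per_step observations, so batch_size mainly changes how the computation is laid out (and how the Monte Carlo keys are assigned) rather than the number of updates. The toy check below illustrates the equivalence on a hypothetical linear least-squares loss; loss, batch_grad and the data are illustrative assumptions, not part of the notebook.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # hypothetical per-batch loss: mean squared error of a linear model with target 1
    return jnp.mean((x @ w - 1.0) ** 2)

def batch_grad(w, xb):
    g = jax.grad(loss)(w, xb)
    # average across the vmapped "batch" axis; the result no longer varies along it
    return jax.lax.pmean(g, axis_name="batch")

w = jnp.array([0.5, -0.25])
x = jnp.arange(16.0).reshape(8, 2) / 10.0   # 8 observations, 2 features
xb = x.reshape(4, 2, 2)                      # 4 batches of 2 observations

# out_axes=None is valid because pmean makes the output identical across batches
g_step = jax.vmap(batch_grad, in_axes=(None, 0), out_axes=None, axis_name="batch")(w, xb)
g_full = jax.grad(loss)(w, x)                # gradient of the mean loss over all 8 observations
print(g_step, g_full)
```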


## Test the training function

We can run one epoch and inspect the results. Adjust the epoch parameters to evaluate how good the starting point is. You can also print values inside the update function; because it is traced by lax.scan and jax.vmap (and jitted in run_experiment), use jax.debug.print rather than print to see runtime values. An important check is to print the gradients inside batch_train_fn and make sure they are not zero for an entire layer. This is especially relevant when using pre-trained models.
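A minimal sketch for the zero-gradient check, assuming the gradients are the usual nested Flax parameter dictionary. grad_norms computes a per-leaf L2 norm and returns arrays, so it can be called inside traced code (for example, passed to jax.debug.print inside batch_train_fn); zero_grad_leaves lists the leaves whose gradient is identically zero and converts to Python booleans, so it must be called eagerly on concrete gradients. Both helpers and the toy_grads dictionary are hypothetical illustrations, not part of the notebook.

```python
import jax
import jax.numpy as jnp

def grad_norms(grads):
    """Per-leaf L2 norm of a gradient pytree; safe to call inside jit/vmap/scan."""
    return jax.tree_util.tree_map(lambda g: jnp.sqrt(jnp.sum(g ** 2)), grads)

def zero_grad_leaves(grads):
    """Paths of leaves whose gradient is exactly zero everywhere (eager use only)."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(grads)
    return [jax.tree_util.keystr(path) for path, g in leaves if not bool(jnp.any(g != 0))]

# Toy gradient pytree mimicking a Flax params dict; Dense_1 receives no gradient.
toy_grads = {"params": {
    "Dense_0": {"kernel": jnp.ones((3, 4)), "bias": jnp.ones(4)},
    "Dense_1": {"kernel": jnp.zeros((4, 2)), "bias": jnp.zeros(2)},
}}

@jax.jit
def inspect_grads(grads):
    # inside traced code, jax.debug.print shows runtime values (plain print shows tracers)
    jax.debug.print("grad norms: {}", grad_norms(grads))
    return grad_norms(grads)

norms = inspect_grads(toy_grads)
print(zero_grad_leaves(toy_grads))
```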

In [None]:
config_test = {
    "periods_per_epis": 2,      # periods per episode
    "epis_per_step": 2,         # episodes per step
    "steps_per_epoch": 2,       # steps per epoch
    "batch_size": 2,
    "n_epochs": 1,              # number of epochs
    "mc_draws": 2,             # only applies if shock is continuous
    "simul_vol_scale": 1,        # scale of volatility while simul
    "proxy_sampler": True,
    "proxy_mcsampler": True,
    "proxy_futurepol": True,
    "config_eval": {
      "periods_per_epis": 8,      # periods per episode
      "mc_draws": 8,         # number of mc draws
      "simul_vol_scale": 1,        # scale of volatility while simul
      "eval_n_epis": 8,           # episodes to sample for eval
      "proxy_sampler": True,
      "proxy_mcsampler": True,
      "proxy_futurepol": True,
    }
}

config_test["periods_per_step"] =config_test["periods_per_epis"]*config_test["epis_per_step"]
config_test["n_batches"] = config_test["periods_per_step"]//config_test["batch_size"]
env_test = Model()
epoch_update_test = get_epoch_train_fn(env_test, config_test)

#CREATE TRAIN_STATE AND RNG
nn_test = NeuralNet([32,32] + [env_test.n_actions])
rng_test_init = random.PRNGKey(1)
train_state_test = TrainState.create(apply_fn=nn_test.apply, params=nn_test.init(rng_test_init, env_test.initial_obs(rng_test_init)), tx=optax.adam(0.001))

# RUN UPDATE FUNCTION
new_train_state, new_rng_test, metrics_test = epoch_update_test(train_state_test, rng_test_init)
print(metrics_test)
print(len(metrics_test[0]))

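# lax.scan stacks the per-step metrics: metrics_test[k][-1] is metric k at the last step
print("test epoch (last step): \n",
      "Mean_loss:", metrics_test[0][-1], "\n",
      "Mean_accuracy:", metrics_test[1][-1], "\n",
      "Min_accuracy:", metrics_test[2][-1],)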


# Create Evaluation function

In [None]:
def get_eval_fn(env, config):
  config = config["config_eval"]
  episode_simul_fn = create_episode_simul_fn(env, config)
  batch_loss_fn = create_batch_loss_fn(env, config)

  def episode_eval_fn(train_state, epis_rng):
    epis_rng, loss_rng = random.split(epis_rng, 2)
    epis_obs = episode_simul_fn(train_state, epis_rng)
    _, epis_metrics = batch_loss_fn(train_state.params, train_state, epis_obs, loss_rng)
    return epis_metrics

  def eval_fn(train_state, step_rng):
    epis_rng = random.split(step_rng, config["eval_n_epis"])
    losses, mean_accuracies, min_accuracies, mean_accs_focs, min_accs_focs = jax.vmap(episode_eval_fn, in_axes=(None,0))(train_state, jnp.stack(epis_rng))
    loss = jnp.mean(losses)
    mean_accuracy = jnp.mean(mean_accuracies)
    min_accuracy = jnp.min(min_accuracies)
    mean_accs_focs = jnp.mean(mean_accs_focs, axis=0)
    min_accs_focs = jnp.min(min_accs_focs, axis=0)
    return loss, mean_accuracy, min_accuracy, mean_accs_focs, min_accs_focs

  return eval_fn

## Test Evaluation function



In [None]:
config_test = {
    "periods_per_epis": 2,      # periods per episode
    "epis_per_step": 2,         # episodes per step
    "steps_per_epoch": 2,       # steps per epoch
    "batch_size": 2,
    "n_epochs": 1,              # number of epochs
    "mc_draws": 2,             # only applies if shock is continuous
    "simul_vol_scale": 1,        # scale of volatility while simul
    "proxy_sampler": True,
    "proxy_mcsampler": True,
    "proxy_futurepol": True,
    "config_eval": {
      "periods_per_epis": 8,      # periods per episode
      "mc_draws": 8,         # number of mc draws
      "simul_vol_scale": 1,        # scale of volatility while simul
      "eval_n_epis": 8,           # episodes to sample for eval
      "proxy_sampler": True,
      "proxy_mcsampler": True,
      "proxy_futurepol": True,
    }
}


config_test["periods_per_step"] =config_test["periods_per_epis"]*config_test["epis_per_step"]
config_test["n_batches"] = config_test["periods_per_step"]//config_test["batch_size"]
env_test = Model()
epoch_train_fn_test = get_epoch_train_fn(env_test, config_test)
eval_fn_test = get_eval_fn(env_test, config_test)

#CREATE TRAIN_STATE AND RNG
nn_test = NeuralNet([2,2] + [env_test.n_actions])
rng_test_init = random.PRNGKey(1)
train_state_test = TrainState.create(apply_fn=nn_test.apply, params=nn_test.init(rng_test_init, env_test.initial_obs(rng_test_init)), tx=optax.adam(0.05))

# RUN UPDATE FUNCTION
new_train_state, new_rng_test, train_metrics_test = epoch_train_fn_test(train_state_test, rng_test_init)
print("Training Metrics:", train_metrics_test)
eval_metrics_test = eval_fn_test(new_train_state, rng_test_init)
print("Evaluation Metrics:", eval_metrics_test)

# Configure experiment

In [None]:
'''Config dictionary'''

# CREATE CONFIG DICT
config = {
    # general
    "model_name": model_name,
    "exper_name": "baseline_nostateaug_finetune",
    "date": "Sep18_2025",
    "save_dir": working_dir + "Experiments/",
    "restore": True,                                                            # True if start from restored checkpoint
    "restore_exper_name": "baseline_nostateaug",
    "seed": 4,
    # defining who drive simuls and future policies
    "proxy_sampler": False,    # use a proxy (e.g. loglinear policy) to control simulation
    "proxy_mcsampler": False, # use a proxy to control monte-carlo sampling
    "proxy_futurepol": False,  # use a proxy to control future policy inside monte-carlo sampling
    "comment": "This is an experiment where we use the original state.",
    "comment_at_end": False,

    # neural net
    "layers": [512,512],              # layers of the NN
    # "layers": [6*37,9*32,11*37+5],              # layers of the NN

    # learning rate schedule
    "lr_sch_values": [0.00001,0.00001],                                        # constant LR per phase; the last value is cosine-decayed to lr_end_value
    "lr_sch_transitions": [1000],                                              # steps (NN updates) at which each phase switches
    "warmup_steps": 0,
    "lr_end_value": 0.0000001,

    # simulation
    "periods_per_epis": 128,
    "simul_vol_scale": 1,        # scale of volatility while simul
    "init_range": 2,

    # loss calculation
    "mc_draws": 256,               # monte-carlo draws

    # training
    "epis_per_step": 64,         # episodes per step
    "steps_per_epoch": 100,       # steps per epoch
    "n_epochs": 100,               # number of epochs
    "batch_size": 16,             # size of batch of obs to calculate grad


    "config_eval": {
      "periods_per_epis": 64,      # periods per episode
      "mc_draws": 128,         # number of mc draws
      "simul_vol_scale": 1,        # scale of volatility while simul
      "eval_n_epis": 64,           # episodes to sample for eval
      "init_range": 0,      # range (in pp) around steady state for initialization
      "proxy_sampler": False,    # use a proxy (e.g. loglinear policy) to control simulation
      "proxy_mcsampler": False, # use a proxy to control monte-carlo sampling
      "proxy_futurepol": False,  # use a proxy to control future policy inside monte-carlo sampling

    }
}

#create auxiliary config variables for readability
config["periods_per_step"] =config["periods_per_epis"]*config["epis_per_step"]
config["n_batches"] = config["periods_per_step"]//config["batch_size"]

# PRINT AND PLOT KEY CONFIGS
print("Number of parameters:")
print(NeuralNet(config["layers"] + [Model().n_actions],C,policies_sd).tabulate(
    random.PRNGKey(0),
    Model().initial_obs(random.PRNGKey(0))
    ))

print("Number of steps (NN updates):", config["steps_per_epoch"]*config["n_epochs"], "\n")


Number of parameters:

                               NeuralNet Summary                                
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module    ┃ inputs         ┃ outputs        ┃ params               ┃
┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│         │ NeuralNet │ float64[74]    │ float64[412]   │                      │
├─────────┼───────────┼────────────────┼────────────────┼──────────────────────┤
│ Dense_0 │ Dense     │ float64[1,74]  │ float64[1,512] │ bias: float64[512]   │
│         │           │                │                │ kernel:              │
│         │           │                │                │ float64[74,512]      │
│         │           │                │                │               
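The auxiliary config variables follow from simple arithmetic, and step_train_fn reshapes the simulated periods into (n_batches, batch_size, n_obs), which only works when batch_size divides periods_per_step (the floor division // would otherwise silently drop the remainder and the reshape would fail). The helper derived_sizes below is a hypothetical sketch, not part of the notebook: it recomputes the derived sizes for the baseline values above and raises early if that condition is violated.

```python
def derived_sizes(cfg):
    """Recompute the auxiliary training sizes and check the batch layout (hypothetical helper)."""
    periods_per_step = cfg["periods_per_epis"] * cfg["epis_per_step"]
    if periods_per_step % cfg["batch_size"] != 0:
        raise ValueError("batch_size must divide periods_per_epis * epis_per_step")
    return {
        "periods_per_step": periods_per_step,                       # observations per NN update
        "n_batches": periods_per_step // cfg["batch_size"],         # batches whose gradients are averaged
        "total_steps": cfg["steps_per_epoch"] * cfg["n_epochs"],    # NN updates over the experiment
        "periods_per_epoch": cfg["steps_per_epoch"] * periods_per_step,
    }

baseline = {"periods_per_epis": 128, "epis_per_step": 64, "batch_size": 16,
            "steps_per_epoch": 100, "n_epochs": 100}
print(derived_sizes(baseline))
# {'periods_per_step': 8192, 'n_batches': 512, 'total_steps': 10000, 'periods_per_epoch': 819200}
```

For instance, dividing the 819,200 simulated periods per epoch by the measured epoch time of about 30.74 seconds reproduces the "Steps per second" figure printed by run_experiment (about 26,648), so that figure counts simulated periods rather than NN updates.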

# Create experiment
Now we define the entire experiment workflow as a function to call later.

In [None]:
def run_experiment(env, config):
  """Runs experiment."""

  n_cores = len(jax.devices())

  # CREATE NN, RNGS, TRAIN_STATE AND EPOCH UPDATE
  neural_net = NeuralNet(config["layers"] + [env.n_actions],C,policies_sd)
  rng, rng_pol, rng_env, rng_epoch, rng_eval = random.split(random.PRNGKey(config["seed"]), num=5)  # random number generator

  # CREATE LR SCHEDULE
  lr_schedule = optax.join_schedules(
    schedules= [optax.constant_schedule(i) for i in config["lr_sch_values"][:-1]]
                  + [optax.warmup_cosine_decay_schedule(
                    init_value=config["lr_sch_values"][-1],
                    peak_value=config["lr_sch_values"][-1],
                    warmup_steps=config["warmup_steps"],
                    decay_steps=config["n_epochs"]*config["steps_per_epoch"]-config["lr_sch_transitions"][-1],
                    end_value=config["lr_end_value"],)],
      boundaries=config["lr_sch_transitions"] # optimizer steps (NN updates) at which to switch schedules
      )

  # lr_schedule = optax.join_schedules(
  #   schedules= [optax.constant_schedule(i) for i in config["lr_sch_values"][:-1]]
  #                 + [optax.warmup_cosine_decay_schedule(
  #                   init_value=config["lr_sch_values"][-1],
  #                   peak_value=config["lr_sch_values"][-1],
  #                   warmup_steps=0,
  #                   decay_steps=config["n_epochs"]*config["steps_per_epoch"]-config["lr_sch_transitions"][-1],
  #                   end_value=0.00000005,)],
  #     boundaries=config["lr_sch_transitions"] # the number of episodes at which to switch
  #     )

  # INITIALIZE OR RESTORE FULL NN TRAIN STATE
  if not config["restore"]:
    params=neural_net.init(rng_pol, jnp.zeros_like(env.initial_obs(rng_env)))
    train_state = TrainState.create(apply_fn=neural_net.apply, params=params, tx=optax.adam(lr_schedule))

  else:
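    # Only the parameters are restored. The optimizer state (Adam moments and schedule count) is freshly
    # initialized, so the LR schedule restarts at step 0. To also resume the optimizer, convert the restored
    # dict with flax.serialization.from_state_dict(train_state.opt_state, train_state_restored["opt_state"]).
    train_state_restored = checkpoints.restore_checkpoint(ckpt_dir=config["save_dir"]+config["restore_exper_name"], target = None)
    params = train_state_restored["params"]
    train_state = TrainState.create(apply_fn=neural_net.apply, params=params, tx=optax.adam(lr_schedule))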

  # GET TRAIN AND EVAL FUNCTIONS
  train_epoch_fn  = jax.jit(get_epoch_train_fn(env, config))
  eval_fn  = jax.jit(get_eval_fn(env, config))

  # COMPILE CODE (JAX dispatches asynchronously, so block until results are ready before reading the clock)
  time_start = time()
  jax.block_until_ready(train_epoch_fn(train_state, rng_epoch))  # compiles and runs once
  jax.block_until_ready(eval_fn(train_state, rng_epoch))         # compiles and runs once
  time_compilation = time() - time_start
  print("Time Elapsed for Compilation:", time_compilation, "seconds")

  # RUN AN EPOCH TO GET TIME STATS
  time_start = time()
  jax.block_until_ready(train_epoch_fn(train_state, rng_epoch)) # run one epoch
  time_epoch = time() - time_start
  print("Time Elapsed for epoch:", time_epoch, "seconds")

  time_start = time()
  jax.block_until_ready(eval_fn(train_state, rng_epoch)) # run one evaluation
  time_eval = time() - time_start
  print("Time Elapsed for eval:", time_eval, "seconds")

  time_experiment = (time_epoch + time_eval)*config["n_epochs"]/60
  print("Estimated time for full experiment", time_experiment, "minutes")

  # counts simulated periods (observations) processed per second, not NN updates
  steps_per_second = config["steps_per_epoch"]*config["periods_per_step"]/time_epoch
  print("Steps per second:", steps_per_second, "st/s")

  # CREATE LISTS TO STORE METRICS
  mean_losses, mean_accuracy, min_accuracy = [], [], []

  # RUN ALL THE EPOCHS
  time_start = time()
  for i in range(1,config["n_epochs"]+1):

    # eval
    eval_metrics = eval_fn(train_state, rng_eval)
    print('EVALUATION:\n',
      'Iteration:', train_state.step,
      "Mean_loss:", eval_metrics[0],
      ", Mean Acc:", eval_metrics[1],
      ", Min Acc:", eval_metrics[2], "\n"
      ", Mean Accs Foc", eval_metrics[3], "\n"
      ", Min Accs Foc:", eval_metrics[4],
      "\n")

    # run epoch
    train_state, rng_epoch, epoch_metrics = train_epoch_fn(train_state, rng_epoch)
    print('TRAINING:\n',
          'Iteration:', train_state.step,
          ", Mean_loss:", jnp.mean(epoch_metrics[0]),
          ", Mean_accuracy:", jnp.mean(epoch_metrics[1]),
          ", Min_accuracy:", jnp.min(epoch_metrics[2]),
          # ", Mean_acc_Kfoc:", jnp.mean(epoch_metrics[:,3]),
          # ", Min_acc_Kfoc:", jnp.min(epoch_metrics[:,4]),
          ", Learning rate:", lr_schedule(train_state.step),
          "\n"
          )

    # checkpoint
    if train_state.step>=1000 and train_state.step%1000==0:
      checkpoints.save_checkpoint(ckpt_dir=config['save_dir']+config['exper_name'], target=train_state, step=train_state.step)

      # store results
      mean_losses.append(float(eval_metrics[0]))
      mean_accuracy.append(float(eval_metrics[1]))
      min_accuracy.append(float(eval_metrics[2]))

    #end of inner loop


  # PRINT RESULTS
  print("Minimum loss attained in evaluation:", min(mean_losses))
  print("Maximum mean accuracy attained in evaluation:", max(mean_accuracy))
  print("Maximum min accuracy attained in evaluation:", max(min_accuracy))
  time_fullexp = (time() - time_start)/60
  print("Time Elapsed for Full Experiment:", time_fullexp, "minutes")

  # ASK FOR A COMMENT:
  # comment = input("Enter a comment for the researcher: ")
  if config["comment_at_end"]:
    comment_result = input("Enter a comment for the researcher: ")
  else:
    comment_result = ""

  # STORE RESULTS
  results = {
    "exper_name": config["exper_name"],
    "comment_preexp": config["comment"],
    "comment_result": comment_result,
    "min_loss":  min(mean_losses),
    "max_mean_acc": max(mean_accuracy),
    "max_min_acc": max(min_accuracy),
    "Time for Full Experiment (m)": time_fullexp,
    "Time for epoch (s)": time_epoch,
    "Time for Compilation (s)": time_compilation,
    "Steps per second": steps_per_second,
    "config": config,
    "Losses_list": mean_losses,
    "mean_accuracy_list": mean_accuracy,
    "min_accuracy_list": min_accuracy,
  }

  # store to json
  os.makedirs(config['save_dir']+config['exper_name'], exist_ok=True)
  with open(config['save_dir']+config['exper_name']+"/results.json", "w") as write_file:
    json.dump(results, write_file)

  # store to experiments csv (create it on the first run)
  csv_filename = os.path.join(working_dir, 'experiment_results.csv')
  df = pd.DataFrame([results])
  if os.path.exists(csv_filename):
    df = pd.concat([pd.read_csv(csv_filename), df], ignore_index=True)
  df.to_csv(csv_filename, index=False)

  # PLOT LEARNING

  # Mean Losses
  plt.plot([(i+1)*1000 for i in range(len(mean_losses))], mean_losses)
  plt.xlabel('Steps (NN updates)')
  plt.ylabel('Mean Losses')
  plt.savefig(config['save_dir']+config['exper_name']+'/mean_losses.jpg')
  plt.close()

  # Mean Accuracy
  plt.plot([(i+1)*1000 for i in range(len(mean_accuracy))], mean_accuracy)
  plt.xlabel('Steps (NN updates)')
  plt.ylabel('Mean Accuracy (%)')
  plt.savefig(config['save_dir']+config['exper_name']+'/mean_accuracy.jpg')
  plt.close()

  # Min Accuracy
  plt.plot([(i+1)*1000 for i in range(len(min_accuracy))], min_accuracy)
  plt.xlabel('Steps (NN updates)')
  plt.ylabel('Minimum Accuracy (%)')
  plt.savefig(config['save_dir']+config['exper_name']+'/min_accuracy.jpg')
  plt.close()

  # Learning rate schedule
  plt.plot([(i+1)*1000 for i in range(len(mean_losses))], [lr_schedule((i+1)*1000) for i in range(len(mean_losses))])
  plt.xlabel('Steps (NN updates)')
  plt.ylabel('Learning Rate')
  plt.savefig(config['save_dir']+config['exper_name']+'/learning_rate.jpg')
  plt.close()

  return train_state

# Run experiment
*Finally*, we run the experiment and get the trained parameters plus useful info.

In [None]:
final_train_state = run_experiment(Model(), config)

# DISCONNECT SESSION: releases the Colab runtime when the run finishes (comment out the next 2 lines for short interactive runs)
from google.colab import runtime
runtime.unassign()



Time Elapsed for Compilation: 96.08792734146118 seconds
Time Elapsed for epoch: 30.74180006980896 seconds
Time Elapsed for eval: 0.09713602066040039 seconds
Estimated time for full experiment 51.398226817448936 minutes
Steps per second: 26647.75641438523 st/s
EVALUATION:
 Iteration: 0 Mean_loss: 9.225874777917741e-07 , Mean Acc: 0.9994172325290587 , Min Acc: 0.9683098374862518 
, Mean Accs Foc [0.99978942 0.99936081 0.9991298  0.99942101 0.99977917 0.99988794
 0.99897378 0.99947289 0.99979318 0.99948837 0.99849257 0.99889272
 0.99906906 0.99980508 0.9994019  0.99994034] 
, Min Accs Foc: [0.99318577 0.99105117 0.9888942  0.98038409 0.9933519  0.99687089
 0.96830984 0.98630983 0.99547693 0.97229663 0.98085633 0.99444021
 0.99536036 0.99740343 0.9956177  0.9994219 ] 

TRAINING:
 Iteration: 100 , Mean_loss: 1.3965336315845857e-06 , Mean_accuracy: 0.9992718802467316 , Min_accuracy: 0.9389780938541722 , Learning rate: 1e-05 

EVALUATION:
 Iteration: 100 Mean_loss: 9.093505587935597e-07 , Mea