<a href="https://colab.research.google.com/github/MatiasCovarrubias/jaxecon/blob/main/RbcProdNet_pretrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pretrain Solution of an RBC Model with Production Networks

This notebook trains a neural net to reproduce the optimal policies of a log-linearized version of an RBC model with production networks, computed with Dynare.

In [None]:
# Install dependencies
! pip install optax -q
! pip install flax -q
!pip install --upgrade jax jaxlib 

# Switch for the Colab TPU setup below; when it is not True the setup is skipped.
TPU = True # set True if using TPU runtime
if TPU == True:
    import jax.tools.colab_tpu
    # Previously used driver version, kept for reference:
    # jax.tools.colab_tpu.setup_tpu(tpu_driver_version='tpu_driver-0.1dev20220801')
    # NOTE(review): the driver version string is pinned while jax/jaxlib are installed
    # with an unpinned `--upgrade` above -- confirm the two are still compatible.
    jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')

# Imports
import jax
from jax import numpy as jnp
from jax import random
from jax import lax
import flax.linen as nn
from flax.training import checkpoints
import optax

import numpy as np
import timeit
from typing import Sequence
import matplotlib.pyplot as plt
print(jax.devices())
import scipy.io as sio
import json
from time import time as time
# import jax.profiler
from jax import pmap,value_and_grad
# from PIL import Image
import os

#Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

[?25l[K     |██▏                             | 10 kB 38.7 MB/s eta 0:00:01[K     |████▎                           | 20 kB 6.9 MB/s eta 0:00:01[K     |██████▍                         | 30 kB 9.8 MB/s eta 0:00:01[K     |████████▌                       | 40 kB 4.1 MB/s eta 0:00:01[K     |██████████▋                     | 51 kB 4.3 MB/s eta 0:00:01[K     |████████████▊                   | 61 kB 5.1 MB/s eta 0:00:01[K     |██████████████▉                 | 71 kB 5.4 MB/s eta 0:00:01[K     |█████████████████               | 81 kB 6.1 MB/s eta 0:00:01[K     |███████████████████             | 92 kB 4.8 MB/s eta 0:00:01[K     |█████████████████████▏          | 102 kB 5.1 MB/s eta 0:00:01[K     |███████████████████████▎        | 112 kB 5.1 MB/s eta 0:00:01[K     |█████████████████████████▍      | 122 kB 5.1 MB/s eta 0:00:01[K     |███████████████████████████▌    | 133 kB 5.1 MB/s eta 0:00:01[K     |█████████████████████████████▋  | 143 kB 5.1 MB/s eta 0:00:01[K    

## 0. Import model structure and create policies

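To start, we import the model parameters (including the variance-covariance matrix of the shocks), the steady-state values, and the Dynare state-space matrices together with the standard deviations used to normalize states, shocks and policies.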

In [None]:
# Load the model parameters and the Dynare solution from Google Drive.

# --- Model info: steady state and structural parameters ---
model_data = sio.loadmat("/content/drive/MyDrive/Jaxecon/RbcProdNet_adj_ModData.mat", simplify_cells=True)
modparams = model_data["StStval"]["parameters"]
k_ss, policies_ss = (
    jnp.array(model_data["StStval"][field]) for field in ("k_ss", "policies_ss")
)  # policies_ss holds the policies in logs
print(policies_ss.shape)
# Sigma_A is the variance-covariance matrix of the TFP shocks.
Sigma_A, delta = (jnp.array(modparams[field]) for field in ("Sigma_A", "delta"))
n_sectors = modparams["n_sectors"]
del model_data

# --- Dynare policies and the standard deviations of the variables under those policies ---
dynare_pols = sio.loadmat("/content/drive/MyDrive/Jaxecon/RbcProdNet_adj_dynare_Nov21.mat", simplify_cells=True)
states_sd, shocks_sd, policies_sd = (
    jnp.array(dynare_pols[field]) for field in ("states_sd", "shocks_sd", "policies_sd")
)
print(policies_sd.shape)
# State-space matrices; RbcProdNet_pretrain.step applies A, B to the state and C, D to the policies.
A, B, C, D = (jnp.array(dynare_pols[field]) for field in "ABCD")
print(A.shape, B.shape, C.shape, D.shape)
del dynare_pols
print("done importing parameteres and dynare state space representation")


(185,)
(185,)
(74, 74) (74, 37) (185, 74) (185, 37)
done importing parameteres and dynare state space representation


## 1. Pre-train experiment

We will start by creating a class that initializes our economic model and steps it forward.

In [None]:
# Environment
class RbcProdNet_pretrain():
  """A JAX implementation of an RBC model with Production Networks.

  The environment iterates the linear state-space system given by the matrices
  A, B, C, D (loaded from the Dynare solution in this notebook):

      state'      = A @ state + B @ shock
      policy_devs = C @ state + D @ shock

  States handed to and returned by `step` are normalized (divided by
  `states_sd`). Each step also emits an (observation, target policy) pair used
  as supervised training data for the neural net.
  """

  def __init__(self,
               n_sectors=n_sectors, Sigma_A=Sigma_A, delta=delta, k_ss=k_ss, policies_ss=policies_ss,
               states_sd=states_sd, shocks_sd=shocks_sd, A=A, B=B, C=C, D=D, mc_draws = 128):
    """Store the model primitives.

    Args:
      n_sectors: number of sectors; also the dimension of the shock vector.
      Sigma_A: variance-covariance matrix of the shocks.
      delta: stored on the instance; not used by any active method of this class.
      k_ss: steady-state capital, used to build `states_ss`.
      policies_ss: steady-state policies; only its length is used here, to set
        `num_actions`.
      states_sd: scale used to normalize states.
      shocks_sd: scale used to normalize shocks inside the observation.
      A, B: state-transition matrices.
      C, D: policy matrices.
      mc_draws: accepted for interface compatibility; currently unused.
    """
    self.n_sectors = n_sectors
    self.Sigma_A = Sigma_A
    self.delta=delta
    # k_ss followed by n_sectors zeros
    self.states_ss = jnp.concatenate([k_ss, jnp.zeros(shape=(self.n_sectors,))])
    self.policies_ss = policies_ss
    self.states_sd = states_sd
    self.shocks_sd = shocks_sd
    self.A = A
    self.B = B
    self.C = C
    self.D = D
    self.num_actions = len(policies_ss)


  def initial_state(self, rng):
    """Return a normalized initial state: one shock draw mapped through B.

    The result is wrapped in `stop_gradient`, so no gradient flows into it.
    """
    e = self.sample_shock(rng)[0]  # sample_shock returns (n_draws, n_sectors); take the single draw
    state_init = jnp.dot(self.B,e)/self.states_sd
    return lax.stop_gradient(state_init)

  def step(self, state, shock):
    """Advance the system one period and emit a training pair.

    Args:
      state: normalized state (state divided by `states_sd`).
      shock: shock realization, not normalized.

    Returns:
      (new_state, (obs, policy)) where `new_state` is the normalized next
      state, `obs` is the current normalized state concatenated with the
      normalized shock, and `policy` is exp(C @ state + D @ shock).
      The (carry, output) layout lets the method be used directly as the
      body of `lax.scan`.
    """
    shock_norm = shock/self.shocks_sd
    obs = jnp.concatenate([state, shock_norm])
    state_notnorm = state*self.states_sd  # undo the normalization before applying A and C
    new_state_notnorm = jnp.dot(self.A,state_notnorm)+jnp.dot(self.B,shock)
    new_state = new_state_notnorm/self.states_sd
    policy_devs = jnp.dot(self.C,state_notnorm)+jnp.dot(self.D,shock)
    # Target is exp(policy_devs) and is NOT shifted by policies_ss.
    # NOTE(review): presumably policy_devs are log-deviations from steady state, making the
    # target a ratio to steady state -- confirm against the Dynare export.
    policy = jnp.exp(policy_devs)
    train_pair = (obs, policy)
    return new_state, train_pair

  def sample_shock(self, rng, n_draws=1):
    """Sample `n_draws` shocks from N(0, Sigma_A); returns shape (n_draws, n_sectors)."""
    return jax.random.multivariate_normal(rng, jnp.zeros((self.n_sectors,)), self.Sigma_A, shape=(n_draws,))

  def mc_shocks(self, rng=random.PRNGKey(1), mc_draws=8):
    """Sample `mc_draws` shocks from N(0, Sigma_A) using `rng`.

    The key passed by the caller is honoured; the default key reproduces the
    draws obtained when no key is given.
    """
    return jax.random.multivariate_normal(rng, jnp.zeros((self.n_sectors,)), self.Sigma_A, shape=(mc_draws,))
  


## 2. Create Neural Net policy

First, we use Flax to create the neural net. Notice that we activate the last layer with softplus to guarantee that we get positive outputs.

In [None]:
class MLP_softplus(nn.Module):
  """Dense network with SELU hidden layers and a softplus output layer.

  Attributes:
    features: layer widths; every entry but the last is a hidden layer and
      the last entry is the output size.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    hidden_widths, output_width = self.features[:-1], self.features[-1]
    for width in hidden_widths:
      x = nn.Dense(width)(x)
      x = nn.selu(x)
    # softplus on the output layer keeps every output positive
    return nn.softplus(nn.Dense(output_width)(x))

## 3. Creating computational workflow for our experiment

Now we define the entire workflow for one epoque (epoch).

In [None]:
def get_epoque_learner_fn(
    env, nn_forward, opt_update, batch_size, epoque_iters):
  """Build the function that performs one epoque, i.e. `epoque_iters` gradient updates.

  Args:
    env: environment providing `sample_shock(rng, n_draws=...)` and `step(state, shock)`.
    nn_forward: callable `(params, obs) -> policy` for the network.
    opt_update: optimizer update function `(grads, opt_state) -> (updates, new_opt_state)`.
    batch_size: number of periods simulated per trajectory (one loss evaluation).
    epoque_iters: number of updates executed inside one call.

  Returns:
    `epoque_learner_fn(params, opt_state, rngs, env_states, mean_loss, mean_abs_loss)`,
    which returns the same six items after `epoque_iters` updates. It averages
    gradients over the axis names 'j' (its own vmap) and 'i', so the caller must
    run it under a mapped axis named 'i' (e.g. `jax.pmap(..., axis_name='i')`).
  """

  def epis_loss_fn(nn_params, loss_rng, env_state):
    """Relative-error loss of the network along one simulated trajectory."""
    trajectory_shocks = env.sample_shock(loss_rng, n_draws=batch_size)
    # scan applies env.step once per row of shocks, carrying the state forward.
    last_state, (obs_batch, target_policies) = lax.scan(env.step, env_state, trajectory_shocks)
    predicted_policies = nn_forward(nn_params, obs_batch)
    # target / prediction - 1
    rel_error = jnp.divide(target_policies, predicted_policies) - jnp.ones_like(target_policies)
    epis_abs_loss = jnp.mean(jnp.abs(rel_error))
    epis_loss = jnp.mean(jnp.square(rel_error))
    # aux output: final state to resume from, plus both losses as length-1 arrays
    aux = (last_state, jnp.array([epis_loss]), jnp.array([epis_abs_loss]))
    return epis_loss, aux

  def update_fn(params, opt_state, epoque_rng, env_state, mean_loss, mean_abs_loss):
    """One optimizer step from one trajectory; incoming loss values are replaced."""
    next_rng, loss_rng = random.split(epoque_rng)
    grad_fn = jax.grad(epis_loss_fn, has_aux=True)
    grads, (next_env_state, mean_loss, mean_abs_loss) = grad_fn(params, loss_rng, env_state)
    # 'j' is the vmap axis defined in epoque_learner_fn; 'i' is the outer mapped axis.
    for axis in ('j', 'i'):
      grads = lax.pmean(grads, axis_name=axis)
    updates, next_opt_state = opt_update(grads, opt_state)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, next_rng, next_env_state, mean_loss, mean_abs_loss

  def epoque_learner_fn(params, opt_state, rngs, env_states, mean_loss, mean_abs_loss):
    """Vectorize `update_fn` over the leading axis and repeat it `epoque_iters` times."""
    batched_update_fn = jax.vmap(update_fn, axis_name='j')

    def iterate_fn(_, carry):
      # fori_loop keeps the whole epoque inside the compiled computation.
      return batched_update_fn(*carry)

    initial_carry = (params, opt_state, rngs, env_states, mean_loss, mean_abs_loss)
    return lax.fori_loop(0, epoque_iters, iterate_fn, initial_carry)

  return epoque_learner_fn

## 4. Configure and run experiment

First we define the function that runs the whole experiment; the learning rate schedule and the config dictionary follow in the next cell.

In [None]:
def run_experiment(env, config):
  """Train the softplus MLP on the environment's target policies and save the results.

  Args:
    env: environment exposing `num_actions`, `n_sectors` and `initial_state(rng)`,
      plus the methods required by `get_epoque_learner_fn`.
    config: dict with the keys "layers", "learning_rate" (an optax schedule or a
      float), "seed", "batch_size", "n_batches", "epoque_iters", "n_epoques",
      "reset_env_nepoques", "working_dir", "run_name" and "date".

  Returns:
    `(params, optim, nn_policy, mean_losses, mean_accuracy)`. `params` keeps the
    leading `(devices, n_batches)` replica axes. The two lists hold one entry
    per epoque starting from the second one; the first epoque is only timed.

  Side effects:
    Writes a checkpoint, `results.json` and three `.jpg` plots under
    `config['working_dir'] + config['run_name']`.
  """

  cores_count = len(jax.devices())  # get available devices.
  n_replicas = cores_count * config["n_batches"]
  out_dir = config['working_dir'] + config['run_name']
  nn_policy = MLP_softplus(config["layers"] + [env.num_actions])
  optim = optax.adam(config["learning_rate"])  # define optimiser.

  # optax.adam takes a schedule or a constant; support both when logging/plotting the rate.
  learning_rate = config["learning_rate"]
  lr_at = learning_rate if callable(learning_rate) else (lambda _step: learning_rate)

  def wait_for(tree):
    """Block until every array in `tree` is computed (JAX dispatches asynchronously)."""
    jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), tree)

  rng, rng_e, rng_p = random.split(random.PRNGKey(config["seed"]), num=3)  # prng keys.
  dummy_obs = jnp.concatenate([env.initial_state(rng_e),jnp.zeros(shape=(env.n_sectors,))])  # dummy for net init.
  params = nn_policy.init(rng_p, dummy_obs)  # initialise params.
  nn_forward = nn_policy.apply
  mean_loss = jnp.array([0.0]) # initialize loss
  mean_abs_loss = jnp.array([0.0]) # initialize loss
  opt_state = optim.init(params)  # initialise optimiser stats.
  learn = get_epoque_learner_fn(env, nn_forward, optim.update, config["batch_size"], config["epoque_iters"])
  learn = jax.pmap(learn, axis_name='i')  # replicate over multiple cores.

  # Replicate everything over (devices, n_batches).
  # jax.tree_util.tree_map is used because the top-level jax.tree_map alias was removed from JAX.
  broadcast = lambda x: jnp.broadcast_to(x, (cores_count, config["n_batches"]) + x.shape)
  params = jax.tree_util.tree_map(broadcast, params)  # broadcast to cores and batch.
  opt_state = jax.tree_util.tree_map(broadcast, opt_state)  # broadcast to cores and batch
  mean_loss = broadcast(mean_loss)
  mean_abs_loss = broadcast(mean_abs_loss)

  rng, *env_rngs = jax.random.split(rng, n_replicas + 1)
  env_states = jax.vmap(env.initial_state)(jnp.stack(env_rngs))  # init envs.
  rng, *step_rngs = jax.random.split(rng, n_replicas + 1)
  # Advance `rng` once more without keeping the sub-keys: they are not used, but skipping
  # this split would change the keys every later epoque receives for a given seed.
  rng = jax.random.split(rng, n_replicas + 1)[0]

  reshape = lambda x: x.reshape((cores_count, config["n_batches"]) + x.shape[1:])
  step_rngs = reshape(jnp.stack(step_rngs))  # add dimension to pmap over.
  env_states = reshape(env_states)  # add dimension to pmap over.

  mean_losses = []
  mean_accuracy = []
  logged_iters = []  # update count at which each entry of the two lists above was recorded
  num_steps = cores_count * config["epoque_iters"] * config["batch_size"] * config["n_batches"]

  time_start = time()

  warmup_out = learn(params, opt_state, step_rngs, env_states, mean_loss, mean_abs_loss)  # compiles
  time_compilation = time() - time_start
  print("Time Elapsed for Compilation:", time_compilation, "seconds")
  # Drain the warm-up run (result discarded) so it is not billed to the timed epoque below.
  wait_for(warmup_out)
  del warmup_out

  #First run, we calculate periods per second
  time_start = time()

  params, opt_state, step_rngs, env_states, mean_loss, mean_abs_loss = learn(
      params, opt_state, step_rngs, env_states, mean_loss, mean_abs_loss)
  # Without this the timer would stop when the work is dispatched, not when it finishes.
  wait_for((params, opt_state, step_rngs, env_states, mean_loss, mean_abs_loss))
  time_epoque = time() - time_start
  print("Time Elapsed for Epoque:", time_epoque, "seconds")
  steps_persec = num_steps/time_epoque
  print("Steps per second:", steps_persec, "st/s")

  #Rest of the runs
  time_start = time()
  for i in range(2,config["n_epoques"]+1):
    rng, *step_rngs = jax.random.split(rng, n_replicas + 1)
    step_rngs = reshape(jnp.stack(step_rngs))
    params, opt_state, step_rngs, env_states, mean_loss, mean_abs_loss = learn(
        params, opt_state, step_rngs, env_states, mean_loss, mean_abs_loss)

    iteration = i*config["epoque_iters"]
    logged_iters.append(iteration)
    mean_losses.append(float(jnp.mean(mean_loss)))
    mean_accuracy.append(float((1- jnp.mean(mean_abs_loss))*100))

    print('Iteration:', iteration,
          ", Mean_loss:", jnp.mean(mean_loss),
          ", Learning rate:", lr_at(iteration),
          ", Mean accuracy (%):", (1- jnp.mean(mean_abs_loss))*100)

    if i%config["reset_env_nepoques"]==0:
      env_states = jnp.zeros_like(env_states)
      print("ENV RESET")

  # --- Save params and results ---

  # Params
  checkpoints.save_checkpoint(ckpt_dir=out_dir, target=params, step=config["n_epoques"]*config["epoque_iters"])

  # Results (None when fewer than two epoques ran, since nothing was logged)
  max_acc = max(mean_accuracy) if mean_accuracy else None
  min_loss = min(mean_losses) if mean_losses else None
  print("Maximum accuracy attained in training:", max_acc)

  # time elapsed: covers the loop over epoques 2..n_epoques and the checkpoint save
  time_fullexp = (time() - time_start)/60
  print("Time Elapsed for Full Experiment:", time_fullexp, "minutes")

  results = {
      "max_accuracy": max_acc,
      "min_loss": min_loss,
      "Losses_list": mean_losses,
      "Accuracy_list":mean_accuracy,
      "Time for Full Experiment (m)": time_fullexp,
      "Time for Epoque (s)": time_epoque,
      "Time for Compilation (s)": time_compilation,
      "Steps per second": steps_persec,
      "Number of devices": cores_count,
      "Batches per device": config["n_batches"],
      "Batch Size": config["batch_size"],
      "NN updates per epoque": config["epoque_iters"],
      "Number of Epoques": config["n_epoques"],
      "date": config["date"],
      "seed": config["seed"]
  }

  # makedirs also creates missing parents and tolerates an existing directory.
  os.makedirs(out_dir, exist_ok=True)
  with open(out_dir+"/results.json", "w") as write_file:
    json.dump(results, write_file)

  # --- Plots ---

  # Losses and accuracy are plotted against the update count at which they were logged.
  _save_line_plot(logged_iters, mean_losses,
                  'Episodes (NN updates)', 'Mean Losses', out_dir+'/mean_losses.jpg')
  _save_line_plot(logged_iters, mean_accuracy,
                  'Episodes (NN updates)', 'Mean Accuracy (%)', out_dir+'/mean_accuracy.jpg')

  # Learning rate schedule over the whole run
  schedule_iters = [i*config["epoque_iters"] for i in range(config["n_epoques"])]
  _save_line_plot(schedule_iters, [lr_at(it) for it in schedule_iters],
                  'Episodes (NN updates)', 'Learning Rate', out_dir+'/learning_rate.jpg')

  return params, optim, nn_policy, mean_losses, mean_accuracy


def _save_line_plot(x_values, y_values, xlabel, ylabel, path):
  """Draw a single line plot, save it to `path` and close the figure."""
  plt.plot(x_values, y_values)
  plt.xlabel(xlabel)
  plt.ylabel(ylabel)
  plt.savefig(path)
  plt.close()

Now we configure our experiment.

In [None]:
# Config dictionary

# Piecewise-constant learning rate: one constant per stage, switching at the
# update counts listed in `boundaries` (8 stages, 7 boundaries).
# NOTE(review): the last stage drops from 1e-7 to 5e-9, a larger jump than the
# earlier stages -- confirm that 5e-9 (and not 5e-8) is intended.
lr_schedule = optax.join_schedules(
    schedules=[
        optax.constant_schedule(rate)
        for rate in (1e-4, 5e-5, 1e-5, 5e-6, 1e-6, 5e-7, 1e-7, 5e-9)
    ],
    boundaries=[200000, 300000, 400000, 500000, 600000, 700000, 900000],
)

# Experiment configuration consumed by run_experiment.
config = {
    # number of minibatches per device (if TPU, you have 8 devices)
    "n_batches": 64,
    # size of each minibatch
    "batch_size": 64,
    # hidden layers of the NN
    "layers": [512, 512],
    # NN updates per epoque; the mean loss is printed once per epoque
    "epoque_iters": 1000,
    # number of epoques; total updates = epoque_iters * n_epoques
    "n_epoques": 1000,
    "learning_rate": lr_schedule,
    # random seed, set to whatever int.
    "seed": 261,
    # environment states are reset every this many epoques
    "reset_env_nepoques": 1,
    "run_name": "run_Dec17_2x512_nb64ns64",
    "date": "December_17",
    "working_dir": "/content/drive/MyDrive/Jaxecon/Pretraining/",
}

# Print some key statistics
cores_count = len(jax.devices())

# periods simulated across all devices and minibatches for a single update
num_steps_perepisode = cores_count * config["batch_size"] * config["n_batches"]
# the same, accumulated over all updates of one epoque
num_steps_percycle = num_steps_perepisode * config["epoque_iters"]

print("periods per episode:", num_steps_perepisode)
print("periods per epoque:", num_steps_percycle)


periods per episode: 32768
periods per epoque: 32768000


### Run the experiment

Next, we run the experiment.

In [None]:
# Run Experiment
# Trains on a freshly constructed environment (all constructor defaults) with the config above.
params, optim, nn_policy, mean_losses, mean_accuracy = run_experiment(RbcProdNet_pretrain(), config)
# Close the session
# NOTE(review): presumably this releases the Colab runtime as soon as training ends; if so,
# `params` and the other in-memory results are lost and any later cell needs a fresh
# session -- confirm this is intended before running the cells below.
from google.colab import runtime
runtime.unassign()

Time Elapsed for Compilation: 9.044264554977417 seconds
Time Elapsed for Epoque: 5.8184380531311035 seconds
Steps per second: 5631751.975492186 st/s
Iteration: 2000 , Mean_loss: 0.013260207 , Learning rate: 1e-04 , Mean accuracy (%): 90.933495
ENV RESET
Iteration: 3000 , Mean_loss: 0.008732321 , Learning rate: 1e-04 , Mean accuracy (%): 92.66965
ENV RESET
Iteration: 4000 , Mean_loss: 0.005757059 , Learning rate: 1e-04 , Mean accuracy (%): 94.06194
ENV RESET
Iteration: 5000 , Mean_loss: 0.0038084714 , Learning rate: 1e-04 , Mean accuracy (%): 95.18812
ENV RESET
Iteration: 6000 , Mean_loss: 0.0024606916 , Learning rate: 1e-04 , Mean accuracy (%): 96.14712
ENV RESET
Iteration: 7000 , Mean_loss: 0.0015897825 , Learning rate: 1e-04 , Mean accuracy (%): 96.9191
ENV RESET
Iteration: 8000 , Mean_loss: 0.0010255241 , Learning rate: 1e-04 , Mean accuracy (%): 97.54248
ENV RESET
Iteration: 9000 , Mean_loss: 0.0006667292 , Learning rate: 1e-04 , Mean accuracy (%): 98.03607
ENV RESET
Iteration: 100

## 5. Testing

First, we are going to test the environment.

In [None]:
""" Test the environment """

env = RbcProdNet_pretrain()
rng_test = random.PRNGKey(80)
# Use one independent key per random draw; reusing a single key for several draws
# makes them correlated (here: initial state, single step, long simulation).
rng_init, rng_step, rng_sim = random.split(rng_test, 3)

# check init function
state_init = env.initial_state(rng_init)
print('inital state shape', state_init.shape)

# check step function
shock = jax.random.multivariate_normal(rng_step, jnp.zeros((env.n_sectors,)), env.Sigma_A)
state_final, train_pair = env.step(state_init,shock)
obs, policy = train_pair
print("shape of state after step", state_final.shape)
print("shape of obs after step", obs.shape)
print("shape of policy after step", policy.shape)

# Simulate env
# A long simulation so that the sample moments below are close to their population values.
n_periods = 500000
shocks = jax.random.multivariate_normal(rng_sim, jnp.zeros((env.n_sectors,)), env.Sigma_A, shape=(n_periods,))
state_final, obs_policy_pair = lax.scan(env.step,state_init,shocks)
obs, policy = obs_policy_pair
# `obs` stacks the normalized states and the normalized shocks, so the two checks below cover both.
print('mean for each state (should be ~ 0)')
print(jnp.mean(obs,axis=0))
print('std for each state (should be ~ 1)')
print(jnp.std(obs, axis=0))
print('mean for each policy:')
print(jnp.mean(policy, axis=0))
print('s.d. for each policy:')
print(jnp.std(policy, axis=0))
print('Max policy:', jnp.max(policy, axis=0))
print('Min policy:', jnp.min(policy, axis=0))


inital state shape (74,)
shape of state after step (74,)
shape of obs after step (111,)
shape of policy after step (185,)
mean for each state (should be ~ 0)
[ 5.2836444e-03 -2.7969519e-03  1.2257562e-03  1.7797967e-03
 -4.3940647e-03 -3.2156992e-03  4.3182825e-03  3.0537480e-03
  5.7349345e-03  3.1981235e-03  2.7504975e-03  2.4309992e-03
  2.9521785e-03  5.4126186e-03  8.0746310e-03  2.2752767e-03
  8.5855210e-03  7.6383785e-03  1.3050292e-02 -9.6182721e-03
  5.7564424e-03  6.2226369e-03  5.5473782e-03  6.9925101e-03
  7.1222517e-03  3.9179055e-03  6.0606515e-03  8.4376149e-03
  5.0477851e-03 -4.3817805e-04  8.2724001e-03  2.9298270e-03
  1.1224650e-02  9.4486205e-03  5.3252513e-03  7.2411294e-03
  2.1112761e-03  1.5854822e-02 -1.9006297e-03  6.6957874e-03
  2.1775956e-03  6.9863633e-03  1.9836547e-03 -1.0845886e-03
 -2.1149639e-04 -3.2887834e-03  9.1188686e-04  2.4856426e-04
  2.4689634e-03  3.4088970e-03 -4.4880924e-03 -3.9758155e-04
  4.4349264e-03 -2.4671915e-03 -5.5301952e-04 -5.

In [None]:
# Close the session
# NOTE(review): presumably releases the Colab runtime, discarding all in-memory state;
# cells after this one would then need a fresh session -- confirm.
from google.colab import runtime
runtime.unassign()

## 6. Analysis

First, we test the performance of our learned policy against the log-linear policy.