<a href="https://colab.research.google.com/github/MatiasCovarrubias/jaxecon/blob/main/jaxDEQN_rbc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DEQN Solver in Jax: Prelims

This notebook trains a neural network to approximate the optimal policy of a nonlinear RBC (real business cycle) model.



In [1]:
# BACKEND RELATED: the default backend is CPU (change in Edit -> Notebook settings)
GPU = True # set True when a GPU runtime is selected; it only controls whether the GPU status is printed below
if GPU:
  # Shell escape: prints driver/GPU information for the current runtime.
  !nvidia-smi

# precision
# NOTE: `config` below is jax's configuration object. A later cell rebinds the name
# `config` to the experiment dictionary, so this cell must run before that one.
from jax import numpy as jnp, lax, random, config
double_precision = True
if double_precision:
  # Turn on JAX's 64-bit mode before any arrays are created.
  config.update("jax_enable_x64", True)
  precision = jnp.float64
else:
  precision = jnp.float32
# NOTE(review): `precision` is not referenced again in the cells visible here - confirm it is used downstream.

# Imports
import matplotlib.pyplot as plt, flax.linen as nn, pandas as pd, jax, flax, optax, os, json
from flax.training.train_state import TrainState  # Useful dataclass to keep train state
from flax.training import checkpoints
from time import time
from typing import Sequence
# JAX debugging flag for NaNs (see the JAX docs on `jax_debug_nans`).
config.update("jax_debug_nans", True)

# Fetch the jaxecon repository and put it on the import path so the DEQN package can be imported.
# The path below assumes the clone lands in /content (the Colab working directory).
! git clone https://github.com/MatiasCovarrubias/jaxecon
import sys
sys.path.insert(0,'/content/jaxecon')
from DEQN.neural_nets.neural_nets import NeuralNet
from DEQN.econ_models.rbc import Rbc
from DEQN.algorithm.simulation import create_episode_simul_fn
from DEQN.algorithm.loss import create_batch_loss_fn
from DEQN.algorithm.epoch_train import create_epoch_train_fn
from DEQN.algorithm.eval import create_eval_fn

# Mount Google Drive to store results (a pop up will appear, follow instructions)
from google.colab import drive
drive.mount('/content/drive')
# Root folder (on the mounted Drive) under which every experiment of this model is stored.
# The trailing slash matters: other cells build paths from this string.
working_dir = "/content/drive/MyDrive/Jaxecon/Rbc/Experiments/"


Mon Mar 11 08:45:50 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

## Create csv file for all experiments related to this econ model

In [2]:

# Column schema of the shared results table (one row per experiment).
columns = [
  "exper_name",
  "min_mean_loss",
  "min_max_loss",
  "max_mean_acc",
  "max_min_acc",
  "time_full_exp_minutes",
  "time_epoque_seconds",
  "time_compilation_seconds",
  "steps_per_second",
  "config",
  "mean_losses_list",
  "max_losses_list",
  "mean_acc_list",
  "min_acc_list",
]

# Create an empty DataFrame with the specified columns
df = pd.DataFrame(columns=columns)

# Make sure the results folder exists: on a fresh Drive it does not, and
# `to_csv` raises when asked to write into a missing directory.
os.makedirs(working_dir, exist_ok=True)

# Specify the file path for the CSV file in your working directory
csv_filename = os.path.join(working_dir, 'experiment_results.csv')
# Save the DataFrame to a CSV file only if it doesn't exist yet, so that
# re-running this cell never wipes previously recorded experiments.
if not os.path.isfile(csv_filename):
  df.to_csv(csv_filename, index=False)
  print(f"New experiments csv saved as {csv_filename}")
else:
  print("The experiments csv already exists")

The experiments csv already exists


# Configure experiment

In [65]:
'''Config dictionary'''

# CREATE CONFIG DICT
# NOTE: this rebinds the name `config`, which the setup cell imported from jax.
# Re-run the setup cell before calling `config.update(...)` on the jax object again.
# Each key appears exactly once: duplicated keys in a dict literal silently keep
# only the last value, so the values below are the ones that were in effect.
config = {
    # general
    "seed": 5,
    "exper_name": "rbc_smallNN_Mar10_24",
    "save_dir": working_dir,
    "restore": False,                                                            # True if start from restored checkpoint
    "restore_run_name": None,
    "restore_exper_name": "",

    # neural net
    "layers": [8,8],              # layers of the NN

    # learning rate schedule
    "lr_sch_values": [0.001,0.001],                                        # values (from the last, we do cosine decay to lr_end_value)
    "lr_sch_transitions": [2000],                                          # steps at which the schedule switches to the next value
    "lr_end_value": 1e-7,

    # simulation
    "periods_per_epis": 32,
    "simul_vol_scale": 1.5,        # scale of volatility while simul
    "init_range": 0,

    # loss calculation
    "mc_draws": 128,               # monte-carlo draws

    # training
    "epis_per_step": 256,         # episodes per step
    "steps_per_epoch": 100,       # steps per epoch
    "n_epochs": 100,               # number of epochs
    "batch_size": 16,             # size of batch of obs to calculate grads
    "checkpoint_frequency": 1000,

    "config_eval": {
      "periods_per_epis": 64,      # periods per episode
      "mc_draws": 128,         # number of mc draws
      "simul_vol_scale": 1,        # scale of volatility while simul
      "eval_n_epis": 128,           # episodes to sample for eval
      "init_range": 0
    }
}

#create auxiliary config variables for readability
config["periods_per_step"] =config["periods_per_epis"]*config["epis_per_step"]
config["n_batches"] = config["periods_per_step"]//config["batch_size"]

# PRINT AND PLOT KEY CONFIGS
print("Number of parameters:")
print(NeuralNet(config["layers"] + [Rbc().n_actions]).tabulate(
    random.PRNGKey(0),
    Rbc().initial_obs(random.PRNGKey(0))
    ))

# The quantity printed is a count of NN updates, so it is labelled in steps.
print("TOTAL Number of steps (NN updates):", config["steps_per_epoch"]*config["n_epochs"], "steps \n")


Number of parameters:

[3m                           NeuralNet Summary                            [0m
┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule   [0m[1m [0m┃[1m [0m[1minputs    [0m[1m [0m┃[1m [0m[1moutputs   [0m[1m [0m┃[1m [0m[1mparams              [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│         │ NeuralNet │ [2mfloat32[0m[2] │ [2mfloat32[0m[1] │                      │
├─────────┼───────────┼────────────┼────────────┼──────────────────────┤
│ Dense_0 │ Dense     │ [2mfloat32[0m[2] │ [2mfloat32[0m[8] │ bias: [2mfloat32[0m[8]     │
│         │           │            │            │ kernel: [2mfloat32[0m[2,8] │
│         │           │            │            │                      │
│         │           │            │            │ [1m24 [0m[1;2m(96 B)[0m            │
├─────────┼───────────┼────────────┼────────────┼────

# Create experiment
Now we create code for the entire experiment as a function to call later.

In [8]:
def _save_metric_plot(steps, values, ylabel, path):
  """Plots `values` against `steps`, saves the figure to `path` and closes it."""
  plt.plot(steps, values)
  plt.xlabel('Steps (NN updates)')
  plt.ylabel(ylabel)
  plt.savefig(path)
  plt.close()


def run_experiment(econ_model, config):
  """Trains the policy network on `econ_model` and stores the results.

  Args:
    econ_model: model object; this function reads `n_actions` and calls
      `initial_obs(rng)` on it, and passes it to the DEQN factory functions.
    config: experiment dictionary. Keys read here: "layers", "seed",
      "lr_sch_values", "lr_sch_transitions", "lr_end_value", "n_epochs",
      "steps_per_epoch", "periods_per_step", "restore", "restore_run_name",
      "save_dir", "exper_name" and "checkpoint_frequency". The whole dict is
      also handed to the DEQN factory functions and dumped into results.json.

  Returns:
    The train state after the last epoch.

  Side effects:
    Prints progress, saves checkpoints every "checkpoint_frequency" steps, and
    writes results.json plus five .jpg plots into <save_dir>/<exper_name>.
  """
  # Folder that holds every artefact of this experiment.
  exper_dir = os.path.join(config["save_dir"], config["exper_name"])
  total_steps = config["n_epochs"]*config["steps_per_epoch"]

  # CREATE NN AND RNGS
  # (named `policy_net` so the module alias `nn` for flax.linen is not shadowed)
  policy_net = NeuralNet(config["layers"] + [econ_model.n_actions])
  # Five keys are split (the first one is unused) so the random streams stay the same.
  _, rng_pol, rng_econ_model, rng_epoch, rng_eval = random.split(random.PRNGKey(config["seed"]), num=5)

  # CREATE LR SCHEDULE: constant pieces, then a cosine decay from the last value to lr_end_value
  lr_values = config["lr_sch_values"]
  lr_transitions = config["lr_sch_transitions"]
  # With a single value there are no transitions; the decay then spans the whole run.
  last_transition = lr_transitions[-1] if lr_transitions else 0
  lr_schedule = optax.join_schedules(
    schedules=[optax.constant_schedule(value) for value in lr_values[:-1]]
              + [optax.warmup_cosine_decay_schedule(
                  init_value=lr_values[-1],
                  peak_value=lr_values[-1],
                  warmup_steps=0,
                  decay_steps=total_steps - last_transition,
                  end_value=config["lr_end_value"],)],
    boundaries=lr_transitions # the number of steps at which to switch
    )

  # INITIALIZE OR RESTORE FULL NN TRAIN STATE
  params = policy_net.init(rng_pol, jnp.zeros_like(econ_model.initial_obs(rng_econ_model)))
  train_state = TrainState.create(apply_fn=policy_net.apply, params=params, tx=optax.adam(lr_schedule))
  if config["restore"]:
    # Restore against the freshly built state so params and optimizer state come
    # back with the structure optax expects (target=None returns raw dicts).
    restored = checkpoints.restore_checkpoint(
      ckpt_dir=os.path.join(config["save_dir"], config["restore_run_name"]),
      target=train_state)
    # `replace` returns a new object, so the result must be re-assigned.
    # Only params and optimizer state are taken; the step counter starts at 0.
    train_state = train_state.replace(params=restored.params, opt_state=restored.opt_state)

  # GET TRAIN AND EVAL FUNCTIONS
  simul_fn = jax.jit(create_episode_simul_fn(econ_model,config))
  loss_fn = jax.jit(create_batch_loss_fn(econ_model,config))
  train_epoch_fn  = jax.jit(create_epoch_train_fn(econ_model, config, simul_fn, loss_fn))
  eval_fn  = jax.jit(create_eval_fn(config, simul_fn, loss_fn))

  # COMPILE CODE
  # JAX dispatches work asynchronously, so every timed call is wrapped in
  # block_until_ready; otherwise only the dispatch time would be measured.
  time_start = time()
  jax.block_until_ready(train_epoch_fn(train_state, rng_epoch))  # compiles
  jax.block_until_ready(eval_fn(train_state, rng_epoch)) # compiles
  time_compilation = time() - time_start
  print("Time Elapsed for Compilation:", time_compilation, "seconds")

  # RUN AN EPOCH TO GET TIME STATS (results are discarded)
  time_start = time()
  jax.block_until_ready(train_epoch_fn(train_state, rng_epoch)) # run one epoch
  time_epoch = time() - time_start
  print("Time Elapsed for epoch:", time_epoch, "seconds")

  time_start = time()
  jax.block_until_ready(eval_fn(train_state, rng_epoch)) # run one evaluation
  time_eval = time() - time_start
  print("Time Elapsed for eval:", time_eval, "seconds")

  time_experiment = (time_epoch + time_eval)*config["n_epochs"]/60
  print("Estimated time for full experiment", time_experiment, "minutes")

  steps_per_second = config["steps_per_epoch"]*config["periods_per_step"]/time_epoch
  print("Steps per second:", steps_per_second, "st/s")

  # CREATE LISTS TO STORE METRICS
  mean_losses, max_losses, mean_accuracy, min_accuracy = [], [], [], []

  # RUN ALL THE EPOCHS
  time_start = time()
  for _ in range(config["n_epochs"]):

    # eval (always with the same key, so evaluations are comparable across epochs)
    eval_metrics = eval_fn(train_state, rng_eval)
    print('EVALUATION:\n',
      'Iteration:', train_state.step,
      "Mean Loss:", eval_metrics[0],
      ", Max Loss:", eval_metrics[1],
      ", Mean Acc:", eval_metrics[2],
      ", Min Acc:", eval_metrics[3], "\n"
      ", Mean Accs Foc", eval_metrics[4], "\n"
      ", Min Accs Foc:", eval_metrics[5],
      "\n")

    # run epoch
    train_state, rng_epoch, epoch_metrics = train_epoch_fn(train_state, rng_epoch)
    # NOTE(review): "Max Loss" below averages epoch_metrics[1] while "Min Acc" takes the
    # minimum of epoch_metrics[3]; confirm whether jnp.max was intended.
    print('TRAINING:\n',
          'Iteration:', train_state.step,
          ", Mean Loss:", jnp.mean(epoch_metrics[0]),
          ", Max Loss:", jnp.mean(epoch_metrics[1]),
          ", Mean Acc:", jnp.mean(epoch_metrics[2]),
          ", Min Acc:", jnp.min(epoch_metrics[3]),
          ", Learning rate:", lr_schedule(train_state.step),
          "\n"
          )

    # checkpoint
    step = int(train_state.step)
    if step>=config["checkpoint_frequency"] and step%config["checkpoint_frequency"]==0:
      checkpoints.save_checkpoint(ckpt_dir=exper_dir, target=train_state, step=step)

      # store results (one entry per checkpoint).
      # NOTE: these are the metrics of the evaluation run at the top of this
      # iteration, i.e. before this epoch's training.
      mean_losses.append(float(eval_metrics[0]))
      max_losses.append(float(eval_metrics[1]))
      mean_accuracy.append(float(eval_metrics[2]))
      min_accuracy.append(float(eval_metrics[3]))

  # SUMMARY STATS
  # `default=None` keeps a run that never reached a checkpoint step from
  # crashing here after all the training work is done.
  min_mean_loss = min(mean_losses, default=None)
  min_max_loss = min(max_losses, default=None)
  max_mean_acc = max(mean_accuracy, default=None)
  max_min_acc = max(min_accuracy, default=None)

  # PRINT RESULTS
  print("Minimum mean loss attained in evaluation:", min_mean_loss)
  print("Minimum max loss attained in evaluation:", min_max_loss)
  print("Maximum mean accuracy attained in evaluation:", max_mean_acc)
  print("Maximum min accuracy attained in evaluation:", max_min_acc)
  time_fullexp = (time() - time_start)/60
  print("Time Elapsed for Full Experiment:", time_fullexp, "minutes")

  # STORE RESULTS
  results = {
    "exper_name": config["exper_name"],
    "min_mean_loss":  min_mean_loss,
    "min_max_loss": min_max_loss,  # minimum of the max losses, matching the printout above
    "max_mean_acc": max_mean_acc,
    "max_min_acc": max_min_acc,
    "time_full_exp_minutes": time_fullexp,
    "time_epoque_seconds": time_epoch,
    "time_compilation_seconds": time_compilation,
    "steps_per_second": steps_per_second,
    "config": config,
    "mean_losses_list": mean_losses,
    "max_losses_list": max_losses,
    "mean_acc_list": mean_accuracy,
    "min_acc_list": min_accuracy,
  }

  # store to json (makedirs also creates missing parent folders)
  os.makedirs(exper_dir, exist_ok=True)
  with open(os.path.join(exper_dir, "results.json"), "w") as write_file:
    json.dump(results, write_file)

  # PLOT LEARNING
  # Metrics are recorded once per checkpoint, hence the x axis in multiples of checkpoint_frequency.
  plot_steps = [(i+1)*config["checkpoint_frequency"] for i in range(len(mean_losses))]

  _save_metric_plot(plot_steps, mean_losses, 'Mean Losses', os.path.join(exper_dir, 'mean_losses.jpg'))
  # Saved under its own name so it no longer overwrites the mean-loss figure.
  _save_metric_plot(plot_steps, max_losses, 'Max Losses', os.path.join(exper_dir, 'max_losses.jpg'))
  _save_metric_plot(plot_steps, mean_accuracy, 'Mean Accuracy (%)', os.path.join(exper_dir, 'mean_accuracy.jpg'))
  _save_metric_plot(plot_steps, min_accuracy, 'Minimum Accuracy (%)', os.path.join(exper_dir, 'min_accuracy.jpg'))

  # Learning rate schedule
  _save_metric_plot(plot_steps, [float(lr_schedule(s)) for s in plot_steps], 'Learning Rate', os.path.join(exper_dir, 'learning_rate.jpg'))

  return train_state

# Run experiment
*Finally*, we run the experiment and get the trained parameters plus useful info.

In [66]:
# Build the model once, then train the policy network on it with the config defined above.
econ_model = Rbc()
final_train_state = run_experiment(econ_model=econ_model, config=config)

# DISCONNECT SESSION (uncomment next 2 lines if you do large run)
# from google.colab import runtime
# runtime.unassign()

Time Elapsed for Compilation: 6.7579779624938965 seconds
Time Elapsed for epoch: 0.6061084270477295 seconds
Time Elapsed for eval: 0.004477977752685547 seconds
Estimated time for full experiment 1.0176440080006917 minutes
Steps per second: 1351573.3546062214 st/s
EVALUATION:
 Iteration: 0 Mean Loss: 0.0033413086 , Max Loss: 0.032399423 , Mean Acc: 0.94753325 , Min Acc: 0.8200016 
, Mean Accs Foc [0.94753325] 
, Min Accs Foc: [0.8200016] 

TRAINING:
 Iteration: 100 , Mean Loss: 0.0016367657 , Max Loss: 0.0051183635 , Mean Acc: 0.96434903 , Min Acc: 0.8352863 , Learning rate: 0.001 

EVALUATION:
 Iteration: 100 Mean Loss: 0.000878623 , Max Loss: 0.005064266 , Mean Acc: 0.9751673 , Min Acc: 0.92883635 
, Mean Accs Foc [0.9751673] 
, Min Accs Foc: [0.92883635] 

TRAINING:
 Iteration: 200 , Mean Loss: 0.0005626963 , Max Loss: 0.0024158438 , Mean Acc: 0.9809046 , Min Acc: 0.868834 , Learning rate: 0.001 

EVALUATION:
 Iteration: 200 Mean Loss: 0.00036315792 , Max Loss: 0.008054408 , Mean Acc