<a href="https://colab.research.google.com/github/MatiasCovarrubias/jaxecon/blob/main/apg_RbcMS_Sep12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# APG algorithm in Jax/Flax.

This notebook trains a neural net to output the optimal policy of a nonlinear Rbc model.



In [None]:
# BACKEND RELATED: the default backend is CPU (change in Edit -> Notebook settings)
TPU = False # set True if using TPU
GPU = True # set True if using GPU (only to see GPU)
if TPU:
  !pip install --upgrade jax jaxlib
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')
elif GPU:
  !nvidia-smi

# CHECK VERSIONS
!pip show jax | grep Version && pip show jaxlib | grep Version && pip show flax | grep Version && pip show optax | grep Version

# IMPORTS
import matplotlib.pyplot as plt, jax, flax, optax, os, json
from jax import numpy as jnp, lax, random, config
from flax import linen as nn
from flax.training import train_state as ts_class, checkpoints
from time import time
from typing import Sequence, NamedTuple, Any, Optional, Tuple, Union, Callable
from functools import partial
from dataclasses import dataclass
from google.colab import drive

print(jax.devices())

# MOUNT DRIVE: to store results (a pop up will appear, follow instructions)
from google.colab import drive
drive.mount('/content/drive')


Sun Sep 17 09:27:06 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    23W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

Create dataclasses that will alow us to give labels to vectors, so the code does not contain hard to read indexation of arrays.

In [None]:
class Transition(NamedTuple):
  done: jax.Array
  action: jax.Array
  value: jax.Array
  reward: jax.Array
  obs: jax.Array
  info: jax.Array

class Metrics(NamedTuple):
  mean_loss: jax.Array
  mean_actor_loss: jax.Array
  mean_value_loss: jax.Array

## 0. Create Neural Net Policy

First, we use Flax to create the Neural Net, Notice that we activate the last layer using Softplus to guarantee that we get possitive outputs.

See https://flax.readthedocs.io/en/latest/getting_started.html

In [None]:
class ActorCritic(nn.Module):
  actions_dim: Sequence[int]
  hidden_dims_actor: Sequence[int]
  hidden_dims_critic: Sequence[int]
  activations: Callable[[jax.Array], jax.Array] = nn.tanh
  activation_final_actor: Optional[Callable[[jax.Array], jax.Array]] = None
  activation_final_critic: Optional[Callable[[jax.Array], jax.Array]] = None


  @nn.compact
  def __call__(self, x: jax.Array, training: bool = False) -> jax.Array:
    # Actor Action
    action = x
    for size in self.hidden_dims_actor:
      action = self.activations(nn.Dense(size)(action))
    action = nn.Dense(self.actions_dim)(action)
    if self.activation_final_actor:
        action = self.activation_final_actor(action)

    # Critic Value
    value = x
    for size in self.hidden_dims_actor:
        value = self.activations(nn.Dense(size)(value))
    value = nn.Dense(1)(value)
    if self.activation_final_critic:
        value = self.activation_final_critic(value)
    value = jnp.squeeze(value, axis=-1)

    return action, value


## 1. Create Environment

In [None]:
class RbcMultiSector:
    """A JAX implementation of a multi-sector RBC model."""

    def __init__(self,
                 N=2,
                 beta=0.96,
                 alpha_values=0.3,
                 delta_values=0.1,
                 rho_values=0.9,
                 shock_sd=0.1,
                 xi_values=None,
                 sigma_c=0.5,
                 discount_rate = 0.9):

        self.N = N
        self.beta = beta
        self.alpha = jnp.ones(N) * alpha_values  # Convert scalar to vector
        self.delta = jnp.ones(N) * delta_values  # Convert scalar to vector
        self.rho = jnp.ones(N) * rho_values  # Convert scalar to vector
        self.discount_rate = discount_rate
        self.shock_sd = jnp.ones(N)* shock_sd
        self.sigma_c = sigma_c
        self.xi = jnp.ones(N) / N if xi_values is None else xi_values

        # Calculate steady state values
        self.k_ss = jnp.log((self.alpha / (1 / self.beta - 1 + self.delta)) ** (1 / (1 - self.alpha)))
        self.a_ss = jnp.zeros(N)
        self.obs_ss = jnp.concatenate([self.k_ss, self.a_ss])
        self.obs_sd = jnp.ones(2 * N)
        self.policy_ss = jnp.log(self.delta * jnp.exp(self.k_ss))

        # Steady state rewards and value
        self.I_ss = jnp.exp(self.policy_ss)
        self.K_ss = jnp.exp(self.k_ss)
        self.A_ss = jnp.exp(self.a_ss)
        self.Y_ss = self.A_ss * self.K_ss ** self.alpha
        self.C_ss = self.Y_ss - jnp.exp(self.policy_ss)
        self.Cagg_ss = jnp.sum(self.xi**(self.sigma_c**(-1)) * self.C_ss**(1 -self.sigma_c**(-1))) ** (1/(1-self.sigma_c**(-1)))
        self.reward_ss = jnp.log(self.Cagg_ss)
        self.value_ss = self.reward_ss/(1-self.beta)

        # Utility variables
        self.obs_dim = 2 * N
        self.state_dim = 2 * N
        self.action_dim = N


    def reset(self, rng):
        """ Get initial obs given first shock """
        K = random.uniform(rng, shape=(self.N,), minval=0.95 * jnp.exp(self.k_ss), maxval=1.05 * jnp.exp(self.k_ss))
        A = random.uniform(rng, shape=(self.N,), minval=0.95 * jnp.exp(self.a_ss), maxval=1.05 * jnp.exp(self.a_ss))

        obs_init_notnorm = jnp.concatenate([jnp.log(K), jnp.log(A)])
        obs_init = (obs_init_notnorm - self.obs_ss) / self.obs_sd  # normalize
        state_init = obs_init
        return obs_init, state_init

    def step(self, rng, state, action):
        """ A period step of the model, given current obs and policy """

        # Process observation
        obs = state
        policy = action
        obs_notnorm = obs * self.obs_sd + self.obs_ss  # denormalize
        K = jnp.exp(obs_notnorm[:self.N])
        a = obs_notnorm[self.N:]

        # Evolution of state
        a_tplus1 = self.rho * a + self.shock_sd*random.normal(rng, (self.N,))
        Inv = action * self.I_ss
        K_tplus1 = (1 - self.delta) * K + Inv

        # New observation and state
        new_obs_notnorm = jnp.concatenate([jnp.log(K_tplus1), a_tplus1])
        new_obs = (new_obs_notnorm - self.obs_ss) / self.obs_sd  # normalize
        new_state = new_obs

        # Reward
        A = jnp.exp(a)
        Y = A * K ** self.alpha
        C = Y - Inv
        Cagg = jnp.sum(self.xi**(self.sigma_c**(-1)) * C**(1-self.sigma_c**(-1))) ** (1/(1-self.sigma_c**(-1)))
        reward = jnp.log(Cagg)

        # Done, Info
        done = jnp.array(False)
        info = jnp.array([0.0])

        return new_obs, new_state, reward, done, info



Test environment

In [None]:
env = RbcMultiSector(N=8)
rng_test = random.PRNGKey(2)
# print(env.N, env.beta,env.alpha,env.k_ss, env.shock_var)
obs_init, state_init = env.reset(rng_test)
# print(obs_init, state_init)
nn_test = ActorCritic(actions_dim=env.action_dim,hidden_dims_actor=[16,8],hidden_dims_critic=[16,8], activation_final_actor=nn.softmax)
params_test = nn_test.init(rng_test, obs_init) # we initialize random params
action_test, value_test = nn_test.apply(params_test, obs_init)
new_obs, new_state, reward, done, info = env.step(rng_test, state_init, action_test)
print(new_obs, new_state, reward, done, info)

[-0.11759251 -0.11058551 -0.13201296 -0.14158124 -0.10777473 -0.08997279
 -0.04937065 -0.09906226 -0.09349994 -0.06809372 -0.16883685 -0.3341882
 -0.05597031  0.00554939  0.19597729 -0.02602552] [-0.11759251 -0.11058551 -0.13201296 -0.14158124 -0.10777473 -0.08997279
 -0.04937065 -0.09906226 -0.09349994 -0.06809372 -0.16883685 -0.3341882
 -0.05597031  0.00554939  0.19597729 -0.02602552] 2.353922 False [0.]


## 3. Create Simulation function

In [None]:
def create_simul_episode_fn(env, config, periods_per_epis):

  def simul_episode(params, train_state, epis_rng):
    obs, env_state = env.reset(epis_rng)
    period_rngs = random.split(epis_rng, periods_per_epis)
    # epis_rng, *period_rngs = random.split(epis_rng, periods_per_epis + 1)
    runner_state = params, env_state, obs , 0, 1

    def period_step(runner_state, period_rng):
      params, env_state, obs, returns, discount = runner_state
      # SELECT ACTION
      action, value_notnorm = train_state.apply_fn(params, obs)
      value = value_notnorm*env.value_ss

      # STEP ENV
      obs, env_state, reward, done, info = env.step(period_rng, env_state, action)
      transition = Transition(done, action, value, reward, obs, info)
      returns = returns + discount*reward
      discount = env.discount_rate*discount
      runner_state = (params, env_state, obs, returns, discount)
      return runner_state, transition

    # GET TRAJECTORIES
    runner_state, trajectory = jax.lax.scan(period_step, runner_state, jnp.stack(period_rngs))

    # CALCULATE DISCOUNTED RETURN AND LAST VALUE
    _, _, last_obs, returns, discount = runner_state
    _ , last_val_notnorm = train_state.apply_fn(train_state.params, last_obs)
    last_val = last_val_notnorm*env.value_ss
    returns = returns + discount * last_val

    return returns, trajectory, last_val

  return simul_episode

## Create Loss function

First, we define a function that gives as targets for value function estimation using Generalized Advantage Estimation. It uses the reward information of a trajectory, plus the current value function, to calculate new value targets for the visited states.

Then, we create the loss function for an episode. We will use the create_simul_episode_fn that we specified before.

In [None]:
def create_episode_loss_fn(env, config, periods_per_epis):

  simul_episode = create_simul_episode_fn(env, config, periods_per_epis)

  # Define function that gives targets for value updates
  def get_targets(trajectory, last_val):

    def get_advantages(gae_and_next_value, transition):
      gae, next_value = gae_and_next_value
      done, value, reward = (transition.done,transition.value,transition.reward,)
      delta = reward + env.discount_rate * next_value * (1 - done) - value
      gae = delta + env.discount_rate * config["gae_lambda"] * (1 - done) * gae
      return (gae, value), gae

    _, advantages = jax.lax.scan(get_advantages,(jnp.zeros_like(last_val), last_val),trajectory,
                                reverse=True,unroll=1,)
    targets = advantages + trajectory.value
    return targets

  def episode_loss_fn(params, train_state, epis_rng):
    returns, trajectory, last_val = simul_episode(params, train_state, jnp.stack(epis_rng))
    values = trajectory.value
    targets = get_targets(trajectory, last_val)
    actor_loss = -returns
    value_loss = jnp.mean(jnp.square(values - targets))
    value_loss_perc = jnp.mean((values-targets)/targets)
    return actor_loss + value_loss, (actor_loss, value_loss, value_loss_perc)

  return episode_loss_fn

## 4. Create Training function

No we define a function specifying an entire epoch of learning. This is the minimal unit of computation that we wil compile and pass to all the devices (devices are synchronized to average gradients).

In [None]:
def get_apg_train_fn(env, config):
  episode_loss_fn = create_episode_loss_fn(env, config, config["periods_per_epis"])

  def episode_train_fn(train_state, epis_rng):
    grad_fn = jax.value_and_grad(episode_loss_fn, has_aux=True)
    loss_metrics, grads = grad_fn(train_state.params, train_state, epis_rng)
    grads = jax.lax.pmean(grads, axis_name="episodes")
    train_state = train_state.apply_gradients(grads=grads)
    grad_mean = jnp.mean(jnp.array(jax.tree_util.tree_leaves(jax.tree_map(jnp.mean, grads))))
    grad_max = jnp.max(jnp.array(jax.tree_util.tree_leaves(jax.tree_map(lambda x: jnp.max(jnp.abs(x)), grads))))
    grad_metrics = (grad_mean,grad_max)
    episode_metrics = (loss_metrics, grad_metrics)
    return train_state, episode_metrics

  def step_train_fn(train_state, step_rng):
    step_rng, *epis_rng = random.split(step_rng, config["epis_per_step"] + 1)
    train_state, batch_metrics = jax.vmap(episode_train_fn, in_axes=(None,0), out_axes=(None,0), axis_name="episodes")(train_state, jnp.stack(epis_rng))
    return train_state, batch_metrics

  def epoch_train_fn(train_state, epoch_rng):
    """Vectorise and repeat the update to complete an epoch, made aout of steps_per_epoch episodes."""
    epoch_rng, *step_rngs = random.split(epoch_rng, config["steps_per_epoch"] + 1)
    train_state, epoch_metrics = lax.scan(step_train_fn, train_state, jnp.stack(step_rngs))
    return train_state, epoch_rng, epoch_metrics

  return epoch_train_fn


## Create Evaluation function

In [None]:
def get_eval_fn(env, config):
  episode_loss_fn = create_episode_loss_fn(env, config, config["eval_periods_per_epis"])

  def episode_grads_and_metrics(train_state, epis_rng):
    grad_fn = jax.value_and_grad(episode_loss_fn, has_aux=True)
    loss_metrics, grads = grad_fn(train_state.params, train_state, epis_rng)
    grads = jax.lax.pmean(grads, axis_name="episodes")
    grad_mean = jnp.mean(jnp.array(jax.tree_util.tree_leaves(jax.tree_map(jnp.mean, grads))))
    grad_max = jnp.max(jnp.array(jax.tree_util.tree_leaves(jax.tree_map(lambda x: jnp.max(jnp.abs(x)), grads))))
    grad_metrics = (grad_mean,grad_max)
    episode_metrics = (loss_metrics, grad_metrics)
    return episode_metrics

  def eval_fn(train_state, eval_rng):
    epis_rng = random.split(eval_rng, config["eval_n_epis"])
    loss_metrics, grad_metrics = jax.vmap(episode_grads_and_metrics, in_axes=(None,0), out_axes=(0), axis_name="episodes")(train_state, jnp.stack(epis_rng))
    eval_metrics = jnp.mean(loss_metrics[0]),jnp.mean(loss_metrics[1][0]),jnp.mean(loss_metrics[1][1]),(1-jnp.abs(jnp.mean(loss_metrics[1][2])))*100, jnp.mean(grad_metrics[0]),jnp.max(grad_metrics[1])
    return eval_metrics

  return eval_fn

## 4. Configure experiment

In [None]:
# Configuration dictionary

# CREATE LEARNING RATE SCHEDULE (https://github.com/google-deepmind/optax/blob/master/optax/_src/schedule.py)
lr_schedule = optax.join_schedules(
      schedules= [optax.linear_schedule(0,0.01,100),
                  optax.constant_schedule(0.01),
                  optax.constant_schedule(0.001),
                  optax.constant_schedule(0.0001),
                  optax.cosine_decay_schedule(0.0001,1000)
                  ],
      boundaries=[300,1000,1500,2000] # the number of episodes at which to switch


      )

config_apg = {
    "learning_rate": lr_schedule,
    "n_epochs": 60,
    "steps_per_epoch": 50,
    "epis_per_step": 1024*8,
    "periods_per_epis": 32,
    "eval_n_epis": 1024*32,
    "eval_periods_per_epis": 32,
    "gae_lambda": 0.95,
    "max_grad_norm": None,
    "layers_actor": [16,8],
    "layers_critic": [8,4],
    "seed": 42,
    "fp64_precision": False, # GPUs use fp32 by default. Restart notebook if you run it wiht True and change it to False
    "run_name": "apg_RbcMS_GPUT4_Sep17_fp16",
    "date": "Sep_13",
    "working_dir": "/content/drive/MyDrive/Jaxecon/APG/" # replace with your folder
}


## 5. Create experiment


In [None]:
def run_experiment(env, config):
  """Runs experiment."""

  print("Staring experiment... \n")
  if config["fp64_precision"]:
    from jax.config import config as config_jax
    config_jax.update("jax_enable_x64", True)

  n_cores = len(jax.devices())  # get available TPU cores.

  # CREATE NN, RNGS, TRAIN_STATE AND EPOQUE UPDATE
  nn_policy = ActorCritic(actions_dim=env.action_dim,hidden_dims_actor=config["layers_actor"],hidden_dims_critic=config["layers_critic"], activation_final_actor=nn.softmax)
  if config["max_grad_norm"]:
    optim = optax.chain(optax.clip_by_global_norm(config["max_grad_norm"]), optax.adam(config["learning_rate"]))
  else:
    optim = optax.chain(optax.adam(config["learning_rate"]))

  print("Neural Net and Optimizer Created... \n")
  # INIATILZE ENV AND ALGO STATES
  rng, rng_pol, rng_env, rng_epoch, rng_eval = random.split(random.PRNGKey(config["seed"]), num=5)  # random number generator

  obs, env_state = env.reset(rng_env)
  train_state = ts_class.TrainState.create(apply_fn=nn_policy.apply, params=nn_policy.init(rng_pol, obs), tx=optim)

  # GET EPOQUE TRAIN AND EVAL FUNCTIONS
  epoch_update = jax.jit(get_apg_train_fn(env, config))
  eval = jax.jit(get_eval_fn(env, config))

  # COMPILE CODE

  print("Starting compilation... \n")
  time_start = time()
  epoch_update(train_state, rng_epoch)  # compiles
  eval(train_state, rng_eval)
  time_compilation = time() - time_start
  print("Time Elapsed for Compilation:", time_compilation, "seconds")

  print("Compilation completed. Proceeding to run an epoque and calculate performance statistics ... \n")

  # RUN AN EPOCH TO GET TIME STATS
  time_start = time()
  epoch_update(train_state, rng_epoch) # run one epoque
  time_epoch = time() - time_start
  print("Time Elapsed for Epoch:", time_epoch, "seconds")
  print("Steps per second:", n_cores*config["steps_per_epoch"]*config["epis_per_step"]*config["periods_per_epis"]/time_epoch, "st/s")

  # RUN AN EVAL TO GET TIME STATS
  time_start = time()
  eval(train_state, rng_eval) # run one epoque
  time_eval = time() - time_start
  print("Time Elapsed for Eval:", time_eval, "seconds")

  print("Estimated time for full experiment",
        (time_epoch+time_eval)*config["n_epochs"]/60, "minutes \n")

  print("Proceding to run all epoques... \n")
  # CREATE LISTS TO STORE METRICS
  mean_losses, mean_actor_losses, mean_critic_losses, mean_critic_accs, mean_grads, max_grads = [], [], [], [], [], []

  # RUN ALL THE EPOCHS
  time_start = time()
  for i in range(1,config["n_epochs"]+1):
    train_state, rng_epoch, epoch_metrics = epoch_update(train_state, rng_epoch)
    eval_metrics = eval(train_state, rng_eval)
    mean_losses.append(float(jnp.mean(epoch_metrics[0][0])))
    mean_actor_losses.append(float(jnp.mean(epoch_metrics[0][1][0])))
    mean_critic_losses.append(float(jnp.mean(epoch_metrics[0][1][1])))
    mean_critic_accs.append(float((1-jnp.abs(jnp.mean(epoch_metrics[0][1][2])))*100))
    mean_grads.append(float(jnp.mean(epoch_metrics[1][0])))
    max_grads.append(float(jnp.mean(jnp.max(epoch_metrics[1][1]))))

    print('Iteration:', i*config["steps_per_epoch"],
          ", Mean_loss:", jnp.mean(epoch_metrics[0][0]),
          ", Mean_actor_loss:", jnp.mean(epoch_metrics[0][1][0]),
          ", Mean_critic_loss:", jnp.mean(epoch_metrics[0][1][1]),
          ", Mean_critic_acc:", (1-jnp.abs(jnp.mean(epoch_metrics[0][1][2])))*100,
          ", Mean_grads:", jnp.mean(epoch_metrics[1][0]),
          ", Max_grads:", jnp.max(epoch_metrics[1][1]),
          ", Learning rate:", config["learning_rate"](i*config["steps_per_epoch"]),
          "\n"
          )

    print('Evaluation:     ',
          ", Mean_loss:", eval_metrics[0],
          ", Mean_actor_loss:", eval_metrics[1],
          ", Mean_critic_loss:", eval_metrics[2],
          ", Mean_critic_acc:", eval_metrics[3],
          ", Mean_grads:", eval_metrics[4],
          ", Max_grads:", eval_metrics[5],
          "\n"
          )

  # STORE RESULTS
  print("Minimum loss attained in training:", min(mean_losses))


  time_fullexp = (time() - time_start)/60
  print("Time Elapsed for Full Experiment:", time_fullexp, "minutes")

  results = {
      "min_loss":  min(mean_losses),
      "min_actor_loss":  min(mean_actor_losses),
      "min_critic_loss":  min(mean_critic_losses),
      "last_critic_accs": mean_critic_accs[-1],
      "Time for Full Experiment (m)": time_fullexp,
      "Time for epoch (s)": time_epoch,
      "Time for Compilation (s)": time_compilation,
      "Steps per second": n_cores * config["steps_per_epoch"] * config["periods_per_epis"]/time_epoch,
      "n_cores": n_cores,
      "periods_per_epis": config["periods_per_epis"],
      "epis_per_step": config["epis_per_step"],
      "steps_per_epoch": config["steps_per_epoch"],
      "n_epochs": config["n_epochs"],
      "layers_actor": config["layers_actor"],
      "layers_critic": config["layers_critic"],
      "date": config["date"],
      "seed": config["seed"],
      "Losses_list": mean_losses,
      "Actor_losses_list": mean_actor_losses,
      "Critic_losses_list": mean_critic_losses,
      "Critic_accs_list": mean_critic_accs,
      "Mean_grads_list": mean_grads,
      "Max_grads_list": max_grads,
      # "last_eval_metrics": eval_metrics
  }
  print(results)
  if not os.path.exists(config['working_dir']+config['run_name']):
    os.mkdir(config['working_dir']+config['run_name'])
  with open(config['working_dir']+config['run_name']+"/results.json", "w") as write_file:
    json.dump(results, write_file)

  # store checkpoint
  checkpoints.save_checkpoint(ckpt_dir=config['working_dir']+config['run_name'], target=train_state, step=config["n_epochs"]*config["steps_per_epoch"])

  # PLOT LEARNING

  # Mean Losses
  plt.plot([(i)*config["steps_per_epoch"] for i in range(len(mean_losses))], mean_losses)
  plt.xlabel('Episodes (NN updates)')
  plt.ylabel('Mean Losses')
  plt.savefig(config['working_dir']+config['run_name']+'/mean_losses.jpg')
  plt.close()

  # Mean Actor Losses
  plt.plot([(i)*config["steps_per_epoch"] for i in range(len(mean_actor_losses))], mean_actor_losses)
  plt.xlabel('Episodes (NN updates)')
  plt.ylabel('Mean Actor Losses')
  plt.savefig(config['working_dir']+config['run_name']+'/mean_actor_losses.jpg')
  plt.close()

  # Mean Actoor Losses
  plt.plot([(i)*config["steps_per_epoch"] for i in range(len(mean_critic_losses))], mean_critic_losses)
  plt.xlabel('Episodes (NN updates)')
  plt.ylabel('Mean Critic Losses')
  plt.savefig(config['working_dir']+config['run_name']+'/mean_critic_losses.jpg')
  plt.close()

  # Mean Accuracy
  plt.plot([(i)*config["steps_per_epoch"] for i in range(len(mean_critic_accs))], mean_critic_accs)
  plt.xlabel('Episodes (NN updates)')
  plt.ylabel('Mean Accuracy (%)')
  plt.savefig(config['working_dir']+config['run_name']+'/mean_critic_accuracy.jpg')
  plt.close()

  # Learning rate schedule
  plt.plot([i*config["steps_per_epoch"] for i in range(len(mean_losses))], [config["learning_rate"](i*config["steps_per_epoch"]) for i in range(len(mean_losses))])
  plt.xlabel('Episodes (NN updates)')
  plt.ylabel('Learning Rate')
  plt.savefig(config['working_dir']+config['run_name']+'/learning_rate.jpg')
  plt.close()

  return train_state

## 6. Run Experiment

In [None]:
final_train_state = run_experiment(RbcMultiSector(N=8), config_apg)

Staring experiment... 



AttributeError: ignored

## 7. Test Results

In [None]:
# Test policies in steady state
env_test = RbcMultiSector(N=8)
params_test = final_train_state.params # Choose parameteres to test
# params_test = params_pretrained
nn_policy_test = final_train_state.apply_fn
rng_test = random.PRNGKey(1)

# test steady state policies
obs_init, _ = env.reset(rng_test)
obs_ss = jnp.zeros_like(obs_init, dtype=jnp.float32)
# obs_ss = obs_init
policy_ss = nn_policy_test(params_test, obs_ss) # nn_policy is the nn policy that comes from the training experiment
print("Pretrain Policy in ss (should be ~ 1):", policy_ss)





Pretrain Policy in ss (should be ~ 1): (Array([0.10403576, 0.14393546, 0.19268866, 0.11237198, 0.09305013,
       0.09898397, 0.09659538, 0.15833871], dtype=float32), Array(0.39711103, dtype=float32))


In [None]:
final_train_state.__dict__

{'step': Array(10000, dtype=int32, weak_type=True),
 'apply_fn': <bound method Module.apply of ActorCritic(
     # attributes
     actions_dim = 8
     hidden_dims_actor = [16, 8]
     hidden_dims_critic = [8, 4]
     activations = tanh
     activation_final_actor = softmax
     activation_final_critic = None
 )>,
 'params': {'params': {'Dense_0': {'bias': Array([ 0.13972686, -0.04131378, -0.15165542, -0.2570419 , -0.05978048,
            0.0825581 ,  0.0706772 , -0.0233331 , -0.05665956, -0.03936341,
           -0.0098231 , -0.03846536, -0.18649097,  0.2756047 , -0.31761026,
           -0.09461976], dtype=float32),
    'kernel': Array([[-3.19528699e-01, -3.05196345e-01,  6.01582408e-01,
             6.99338764e-02, -3.41943726e-02,  3.97388458e-01,
             3.77708197e-01,  6.27666665e-03,  4.70901996e-01,
             4.51826543e-01,  1.96015462e-02,  2.00139150e-01,
            -2.08836317e-01,  3.54485333e-01, -1.76357955e-01,
            -3.26189846e-01],
           [ 3.474910

## Deprecated

In [None]:
def get_apg_train_fn(env, config):

  def calculate_gae(traj_batch, last_val):
    def get_advantages(gae_and_next_value, transition):
      gae, next_value = gae_and_next_value
      done, value, reward = (transition.done,transition.value,transition.reward,)
      delta = reward + env.discount_rate * next_value * (1 - done) - value
      gae = delta + env.discount_rate * config["gae_lambda"] * (1 - done) * gae
      return (gae, value), gae

    _, advantages = jax.lax.scan(get_advantages,(jnp.zeros_like(last_val), last_val),traj_batch,reverse=True,unroll=16,)
    return advantages, advantages + traj_batch.value

  def simul_episode(params, train_state, epis_rng):
    obs, env_state = env.reset(epis_rng)
    epis_rng, *period_rngs = random.split(epis_rng, config["periods_per_epis"] + 1)
    runner_state = params, env_state, obs , 0, 1

    def period_step(runner_state, period_rng):

      params, env_state, obs, returns, discount= runner_state
      # SELECT ACTION
      action, value_notnorm = train_state.apply_fn(params, obs)
      value = value_notnorm*env.value_ss

      # STEP ENV
      obs, env_state, reward, done, info = env.step(period_rng, env_state, action)
      transition = Transition(done, action, value, reward, obs, info)
      returns = returns + discount*reward
      discount = env.discount_rate*discount
      runner_state = (params, env_state, obs, returns, discount)
      return runner_state, transition

    # GET TRAJECTORIES
    runner_state, traj_batch = jax.lax.scan(period_step, runner_state, jnp.stack(period_rngs))

    # CALCULATE TARGETS FOR VALUE UPDATE
    _, _, last_obs, returns, discount = runner_state
    _ , last_val_notnorm = train_state.apply_fn(train_state.params, last_obs)
    last_val = last_val_notnorm*env.value_ss
    returns = returns+discount*last_val
    _, targets = calculate_gae(traj_batch, last_val)

    return returns, traj_batch.value, targets, last_obs

  def episode_loss_fn(params, train_state, epis_rng):
    returns, values, targets, last_obs = simul_episode(params, train_state, jnp.stack(epis_rng))
    actor_loss = -returns
    value_loss = jnp.mean(jnp.square(values - targets))
    value_loss_perc = jnp.mean((values-targets)/targets)
    return actor_loss+value_loss, (actor_loss, value_loss, value_loss_perc)

  def episode_train_fn(train_state, epis_rng):
    grad_fn = jax.value_and_grad(episode_loss_fn, has_aux=True)
    loss_metrics, grads = grad_fn(train_state.params, train_state, epis_rng)
    grads = jax.lax.pmean(grads, axis_name="episodes")
    train_state = train_state.apply_gradients(grads=grads)
    grad_mean = jnp.mean(jnp.array(jax.tree_util.tree_leaves(jax.tree_map(jnp.mean, grads))))
    grad_max = jnp.max(jnp.array(jax.tree_util.tree_leaves(jax.tree_map(jnp.max, grads))))
    grad_metrics = (grad_mean,grad_max)
    episode_metrics = (loss_metrics, grad_metrics)
    return train_state, episode_metrics

  # UPDATE NETWORK
  def step_train_fn(train_state, step_rng):
    step_rng, *epis_rng = random.split(step_rng, config["epis_per_step"] + 1)
    train_state, batch_metrics = jax.vmap(episode_train_fn, in_axes=(None,0), out_axes=(None,0), axis_name="episodes")(train_state, jnp.stack(epis_rng))
    return train_state, batch_metrics

  def epoch_train_fn(train_state, epoch_rng):
    """Vectorise and repeat the update to complete an epoch, made aout of steps_per_epoch episodes."""
    epoch_rng, *step_rngs = random.split(epoch_rng, config["steps_per_epoch"] + 1)
    train_state, epoch_metrics = lax.scan(step_train_fn, train_state, jnp.stack(step_rngs))
    return train_state, epoch_rng, epoch_metrics

  return epoch_train_fn
