# Setup

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal, variance_scaling, normal
from typing import Sequence, NamedTuple, Any
from functools import partial
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper
import matplotlib.pyplot as plt
import warnings
import pickle
from pathlib import Path
import tqdm
import pickle

def subsample_with_end(x, subsampling, axis=0):
    """Take every ``subsampling``-th entry of ``x`` along ``axis``, always keeping the last one.

    The start offset is chosen so that the final index ``n - 1`` lands on the
    stride, i.e. the selected indices are ``n-1, n-1-subsampling, ...`` in
    ascending order.

    Args:
        x: array-like object exposing ``shape`` and supporting tuple indexing
            (e.g. a NumPy or JAX array).
        subsampling: positive stride between kept entries.
        axis: axis to subsample along (negative values count from the end).

    Returns:
        The subsampled array; all other axes are left untouched.
    """
    # (n - 1) % subsampling is the smallest index congruent to the last index.
    # The previous `n % subsampling - 1` became -1 whenever n was a multiple of
    # the stride, which silently returned only the final element.
    start = (x.shape[axis] - 1) % subsampling
    index = [slice(None)] * len(x.shape)
    index[axis] = slice(start, None, subsampling)
    return x[tuple(index)]

def process_out(out, subsampling=1):
    """Reduce the raw output of a batch of training runs to compact curves.

    Every leaf of ``out['metrics']`` is averaged over its last two axes and
    every leaf of ``out['loss_info']`` over its last axis. Axis 1 of each
    reduced leaf is then split into consecutive bins of ``subsampling``
    entries which are averaged, so axis 1 must be divisible by ``subsampling``.
    ``out['runner_state']`` is passed through unchanged.

    NOTE(review): with the ``make_train`` defined in this file and an outer
    ``vmap`` over seeds, the metric leaves look like
    (runs, updates, num_steps, num_envs) - confirm against the caller.
    """
    tree_map = jax.tree_util.tree_map

    def _bin_average(leaf):
        # (d0, d1, ...) -> (d0, d1 // subsampling, subsampling, ...) -> mean over each bin
        binned_shape = (
            leaf.shape[0],
            leaf.shape[1] // subsampling,
            subsampling,
        ) + tuple(leaf.shape[2:])
        return leaf.reshape(binned_shape).mean(axis=2)

    reduced = {
        "metrics": tree_map(lambda leaf: leaf.mean(axis=(-1, -2)), out["metrics"]),
        "loss_info": tree_map(lambda leaf: leaf.mean(axis=-1), out["loss_info"]),
    }

    result = tree_map(_bin_average, reduced)
    result["runner_state"] = out["runner_state"]
    return result

def tree_size(tree):
    """Return the total number of bytes held by the array leaves of ``tree``.

    Every leaf must expose an ``nbytes`` attribute.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(leaf.nbytes for leaf in leaves)

def ss_plot(out, ax=None, quantiles=True, individual=False, color='blue', subsampled=1, label=None, sem=False, smoothing=None, skip_first=False, **kwargs):
    """Plot the across-run mean of the episode returns with an uncertainty band.

    Args:
        out: result dict; ``out["metrics"]["returned_episode_returns"]`` must
            have the runs on axis 0.
        ax: matplotlib axes to draw on; a new figure is created when ``None``.
        quantiles: if true, shade between the second-smallest and the
            second-largest run at every step; otherwise shade mean +/- std
            (or the standard error when ``sem`` is true).
        individual: additionally draw every run as a dotted line.
        color: colour used for all artists.
        subsampled: spacing of the x-axis ticks (one data point every
            ``subsampled`` steps).
        label: legend label of the mean curve.
        sem: scale the std band by ``1 / sqrt(n_runs)``.
        smoothing: optional callable applied to the mean and std curves
            (not to the quantile band or the individual runs).
        skip_first: if truthy, drop the first data point from every curve.
        **kwargs: forwarded to ``ax.plot`` for the mean curve.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    episode_returns = out["metrics"]["returned_episode_returns"]
    n_runs = episode_returns.shape[0]

    mean = episode_returns.mean(axis=0)
    std = episode_returns.std(axis=0)
    # Named `sorted_returns` so the builtin `sorted` is not shadowed.
    sorted_returns = jnp.sort(episode_returns, axis=0)

    if smoothing is not None:
        mean = smoothing(mean)
        std = smoothing(std)

    start = int(skip_first)
    xs = range(0, mean.shape[0] * subsampled, subsampled)[start:]

    ax.plot(xs, mean[start:], color=color, label=label, **kwargs)
    if quantiles:
        ax.fill_between(xs, sorted_returns[1, start:], sorted_returns[-2, start:], alpha=0.2, color=color)
    else:
        factor = 1 / np.sqrt(n_runs) if sem else 1
        ax.fill_between(xs, mean[start:] - factor * std[start:], mean[start:] + factor * std[start:], alpha=0.2, color=color)

    if individual:
        per_run = episode_returns.reshape(n_runs, -1).T
        # Use the same x scaling and skip as the mean curve; previously the runs
        # were drawn against raw indices and did not line up when subsampled != 1
        # or skip_first was set. The axis is derived from the raw data length so
        # a length-changing `smoothing` cannot cause a shape mismatch here.
        run_xs = range(0, per_run.shape[0] * subsampled, subsampled)[start:]
        ax.plot(run_xs, per_run[start:], linestyle=":", color=color)

def save_results(path, results, configs, force=False):
    """Pickle ``(results, configs)`` to ``path``.

    Args:
        path: destination file (``str`` or ``Path``); missing parent
            directories are created.
        results: object stored as the first element of the pickled tuple.
        configs: object stored as the second element of the pickled tuple.
        force: when true, an existing file at ``path`` is overwritten. When
            false, an existing file is preserved and the data is written to
            the first free name of the form ``<name>_0``, ``<name>_1``, ...
            in the same directory.
    """
    path = Path(path)
    if not force:
        # `Path` does not support `+=` with a string (the old code raised
        # TypeError on any collision), so build the sibling name explicitly.
        candidate = path
        counter = 0
        while candidate.is_file():
            candidate = path.with_name(f"{path.name}_{counter}")
            counter += 1
        path = candidate
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump((results, configs), f)

def drop_runner_state(results):
    """Return a copy of ``results`` keeping only ``loss_info`` and ``metrics`` per experiment.

    The input mapping and its entries are not modified; everything else stored
    per experiment (e.g. the runner state) is left out of the copy.
    """
    kept_fields = ("loss_info", "metrics")
    return {
        experiment: {field: entry[field] for field in kept_fields}
        for experiment, entry in results.items()
    }


# Ensemble-ECPPO + Baseline

In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal, variance_scaling, normal
from typing import Sequence, NamedTuple, Any
from functools import partial
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper
import matplotlib.pyplot as plt
import warnings
import pickle
# warnings.simplefilter(action='ignore', category=FutureWarning)
# warnings.simplefilter(action='ignore', category=UserWarning)


# Dense layer lifted over an ensemble axis with flax's `nn.vmap`:
#   * in_axes / out_axes = -1: the members live on the LAST axis of the input
#     and of the output,
#   * variable_axes={'params': 0}: each member has its own parameters, stacked
#     along axis 0 of the 'params' collection,
#   * split_rngs={'params': True}: the 'params' RNG is split per member, so the
#     members are initialised independently.
BatchDense = nn.vmap(
    nn.Dense,
    in_axes=-1, out_axes=-1,
    variable_axes={'params': 0,},
    split_rngs={'params': True},
)

# Bias initialiser (flax `normal(0.1)`) used by the prior-critic layers of
# ActorEpistemicCritic below.
prior_bias_init = normal(0.1)

class ActorEpistemicCritic(nn.Module):
    """Categorical actor plus an ensemble critic with an additive prior network.

    ``__call__`` returns ``(pi, value)``:
      * ``pi`` is a ``distrax.Categorical`` over ``action_dim`` logits produced
        by a 64-64 MLP,
      * ``value`` is ``critic + beta * stop_gradient(prior_critic)`` where both
        critics are 64-64-1 MLPs built from ``BatchDense``; the size-1 output
        feature axis is squeezed, leaving the ensemble members on the last axis.

    NOTE: the submodules are created in a fixed order inside ``nn.compact``;
    reordering them would change the auto-generated parameter names.
    """
    # Number of logits of the final actor layer. Annotated as Sequence[int];
    # make_train passes `env.action_space(env_params).n` - presumably an int.
    action_dim: Sequence[int]
    # Number of copies of the observation stacked on a new last axis for the critics.
    ensemble_size: int
    # "relu" selects nn.relu; any other value falls back to nn.tanh.
    activation: str = "tanh"
    # Scale of the prior-critic output added to the critic output.
    beta: float = 0.

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        # --- actor head ---
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)
        
        # Append a trailing axis and repeat the observation once per ensemble
        # member: (..., obs_dim) -> (..., obs_dim, ensemble_size).
        batched_x = x.reshape(*x.shape, 1).repeat(self.ensemble_size, axis=-1)

        # --- ensemble critic (zero-initialised biases) ---
        critic = BatchDense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(batched_x)
        critic = activation(critic)
        critic = BatchDense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = BatchDense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        # --- prior critic: same architecture, biases drawn from prior_bias_init ---
        prior_critic = BatchDense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=prior_bias_init
        )(batched_x)
        prior_critic = activation(prior_critic)
        prior_critic = BatchDense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=prior_bias_init
        )(prior_critic)
        prior_critic = activation(prior_critic)
        prior_critic = BatchDense(1, kernel_init=orthogonal(1.0), bias_init=prior_bias_init)(
            prior_critic
        )
        # stop_gradient keeps the prior fixed: no gradient reaches its parameters
        # through this output. axis=-2 is the size-1 feature axis of the last layer.
        return pi, jnp.squeeze(critic, axis=-2) + self.beta * jnp.squeeze(jax.lax.stop_gradient(prior_critic), axis=-2)
    


class Transition(NamedTuple):
    """One environment step for all parallel environments, as stored by ``_env_step``."""
    # Episode-termination flag returned by env.step.
    done: jnp.ndarray
    # Action sampled from the policy.
    action: jnp.ndarray
    # Critic output for the observation the action was taken in; in this
    # section the network returns one value per ensemble member on the last axis.
    value: jnp.ndarray
    # Reward returned by env.step.
    reward: jnp.ndarray
    # Log-probability of `action` under the policy that sampled it.
    log_prob: jnp.ndarray
    # Observation the action was taken in (the pre-step observation).
    obs: jnp.ndarray
    # `info` returned by the wrapped env.step; annotated as an array, but
    # presumably a dict of arrays from LogWrapper - TODO confirm.
    info: jnp.ndarray

def make_train(config):
    """Build a PPO training function for the ensemble-critic agent.

    Creates the gymnax environment named by ``config["ENV_NAME"]`` (wrapped in
    ``FlattenObservationWrapper`` and ``LogWrapper``) and returns ``train(rng)``.

    ``config`` is MUTATED: ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` are derived
    and stored here, and missing ``UNCERTAINTY_SHIFT`` / ``UNCERTAINTY_MULTIPLIER``
    / ``UNCERTAINTY_BASE`` / ``UNCERTAINTY_SCALE`` entries are filled with
    0.3 / 15. / 0.5 / 1.5 when ``train`` runs.

    Keys read: TOTAL_TIMESTEPS, NUM_STEPS, NUM_ENVS, NUM_MINIBATCHES,
    UPDATE_EPOCHS, ENV_NAME, LR, ANNEAL_LR, MAX_GRAD_NORM, ACTIVATION,
    ENSEMBLE_SIZE, BETA, GAMMA, GAE_LAMBDA, CLIP_EPS, VF_COEF, ENT_COEF,
    UNCERTAINTY, SAVE_STATE_FREQ.

    ``train(rng)`` returns a dict with
      * ``"metrics"``: ``traj_batch.info`` of every update,
      * ``"loss_info"``: the ``(total_loss, aux_dict)`` pair of every minibatch
        step of every update,
      * ``"runner_state"``: the runner state stacked once per outer scan
        iteration (see SAVE_STATE_FREQ below).
    """
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )

    env_kwargs = {}
    # DeepSea is the only environment created with a non-default argument.
    if config["ENV_NAME"] == "DeepSea-bsuite":
        env_kwargs["size"] = 15
    env, env_params = gymnax.make(config["ENV_NAME"], **env_kwargs)
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        # `count` is supplied by optax (presumably the optimizer step count);
        # dividing by minibatches * epochs converts it to a PPO-update index,
        # so the learning rate decays linearly from LR over NUM_UPDATES updates.
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):

        # INIT NETWORK
        network = ActorEpistemicCritic(env.action_space(env_params).n, activation=config["ACTIVATION"], ensemble_size=config["ENSEMBLE_SIZE"], beta=config["BETA"])
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)

        # print(jax.tree_util.tree_map(lambda x: x.shape, network_params))

        # Gradient clipping by global norm followed by Adam; only the learning
        # rate (schedule vs. constant) differs between the two branches.
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # Defaults for the uncertainty squashing used in the actor loss.
        # NOTE: these write into the caller's config dict.
        if "UNCERTAINTY_SHIFT" not in config.keys():
            config["UNCERTAINTY_SHIFT"] = 0.3
        
        if "UNCERTAINTY_MULTIPLIER" not in config.keys():
            config["UNCERTAINTY_MULTIPLIER"] = 15.

        if "UNCERTAINTY_BASE" not in config.keys():
            config["UNCERTAINTY_BASE"] = 0.5

        if "UNCERTAINTY_SCALE" not in config.keys():
            config["UNCERTAINTY_SCALE"] = 1.5

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng,  env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # One PPO update: collect NUM_STEPS transitions from every env,
            # compute GAE, then run UPDATE_EPOCHS epochs of minibatch updates.
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0,None))(
                    rng_step, env_state, action, env_params
                )
                # The transition stores the observation the action was taken in
                # (last_obs), not the one returned by the step.
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Bootstrap value for the observation following the last collected step.
            _, last_val = network.apply(train_state.params, last_obs)
            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # reward/done get a trailing axis so they broadcast against
                    # the ensemble axis of the values: one TD error per member.
                    delta = reward[..., None] + config["GAMMA"] * next_value * (1 - done[..., None]) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done[..., None]) * gae
                    )
                    return (gae, value), gae

                # Scan backwards in time, starting from zero advantage and the
                # bootstrap value.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                # Value targets are advantage + value, per ensemble member.
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        # Clipped value loss; the mean runs over all samples and
                        # all ensemble members.
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)

                        # Normalise with statistics over the whole minibatch,
                        # ensemble members included.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)

                        # Ensemble disagreement: std of the normalised advantage
                        # across the ensemble (last) axis.
                        gae_uncertainty = gae.std(axis=-1)
                        
                        if config["UNCERTAINTY"]:
                            # Per-sample multiplier for the clip range in
                            # (BASE, BASE + SCALE): the sigmoid decreases with the
                            # uncertainty, so high disagreement gives a multiplier
                            # near BASE and low disagreement one near BASE + SCALE.
                            squashed_uncertainty = config["UNCERTAINTY_BASE"] + config["UNCERTAINTY_SCALE"] * jax.nn.sigmoid( config["UNCERTAINTY_MULTIPLIER"] * (-gae_uncertainty + config["UNCERTAINTY_SHIFT"] ))
                        else:
                            # Plain PPO: the clip range is CLIP_EPS everywhere.
                            squashed_uncertainty = jnp.array(1.)

                        loss_actor1 = ratio * jnp.mean(gae, axis=-1) # average loss over all values
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"] * squashed_uncertainty,
                                1.0 + config["CLIP_EPS"] * squashed_uncertainty,
                            )
                            * jnp.mean(gae, axis=-1)
                        )

                        # Diagnostic: fraction of samples whose ratio lies outside
                        # the (uncertainty-scaled) clip range.
                        percentage_clipped = (jnp.abs(ratio - 1) > config["CLIP_EPS"] * squashed_uncertainty).mean()

                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)

                        # Diagnostic: fraction of samples where the clipped
                        # objective is the smaller (selected) one.
                        loss_argmin = (loss_actor1 > loss_actor2).mean()
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, dict(value_loss=value_loss, 
                                                loss_actor=loss_actor, 
                                                entropy=entropy, 
                                                mean_uncertainty=gae_uncertainty.mean(), 
                                                min_uncertainty=gae_uncertainty.min(), 
                                                max_uncertainty=gae_uncertainty.max(),
                                                min_gae = gae.min(),
                                                max_gae = gae.max(),
                                                percentage_clipped=percentage_clipped,
                                                loss_argmin=loss_argmin,
                                                mean_squashed_uncertainty=squashed_uncertainty.mean(),
                                                max_squashed_uncertainty=squashed_uncertainty.max(),
                                                min_squashed_uncertainty=squashed_uncertainty.min())

                    # has_aux=True: `total_loss` below is the pair
                    # (scalar loss, diagnostics dict) returned by _loss_fn.
                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)

                # Flatten the two leading (step, env) axes into one batch axis,
                # shuffle it, then split it into NUM_MINIBATCHES equal parts.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

            runner_state = (train_state, env_state, last_obs, rng)
            info = dict(
                metrics=metric,
                loss_info=loss_info
            )
            return runner_state, info
        
        # With SAVE_STATE_FREQ set, the updates are split into an outer scan of
        # `outer_updates` iterations, each running `inner_updates` updates and
        # emitting the runner state; updates that do not fit are run first.
        if config["SAVE_STATE_FREQ"]:
            outer_updates = config["NUM_UPDATES"] // config["SAVE_STATE_FREQ"]
            inner_updates = config["SAVE_STATE_FREQ"]
            left_over = config["NUM_UPDATES"] - inner_updates * outer_updates 
            # NOTE(review): plain Python print; under jit it presumably runs
            # while tracing only, not on every execution.
            print(inner_updates, outer_updates, left_over)
        else:
            inner_updates = config["NUM_UPDATES"]
            outer_updates = 1
            left_over = 0

        # start with the left_over steps:
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, initial_info = jax.lax.scan(
            _update_step, runner_state, None, left_over
        )

        def _outer_update_step(update_state, unused):
            runner_state, info = jax.lax.scan(
                _update_step, update_state, None, inner_updates
            )
            # The runner state is returned both as carry and as output so that
            # it is stacked once per outer iteration.
            return runner_state, (runner_state, info)

        # then do the remaining updates
        runner_state, (states, info) = jax.lax.scan(
            _outer_update_step, runner_state, None, outer_updates
        )

        # Merge the (outer, inner) axes into a single update axis and prepend
        # the info of the left-over updates.
        info = jax.tree_util.tree_map(lambda x: x.reshape(-1, *x.shape[2:]), info)
        info = jax.tree_util.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), initial_info, info)

        info["runner_state"] = states
              
        return info
    return train

In [3]:
import pickle
# Plain PPO baseline: a single critic, no prior network and the standard
# (uncertainty-independent) clip range.
baseline_config = dict(
    LR=5e-3,
    NUM_ENVS=64,
    NUM_STEPS=128,
    TOTAL_TIMESTEPS=1e7,
    UPDATE_EPOCHS=4,
    NUM_MINIBATCHES=8,
    GAMMA=0.99,
    GAE_LAMBDA=0.95,
    CLIP_EPS=0.2,
    ENT_COEF=0.01,
    VF_COEF=0.5,
    MAX_GRAD_NORM=0.5,
    ACTIVATION="relu",
    ANNEAL_LR=True,
    ENSEMBLE_SIZE=1,
    UNCERTAINTY=False,
    BETA=0.,
    SAVE_STATE_FREQ=False,
)

# Ensemble-ECPPO: identical to the baseline except for the five-member critic
# ensemble, the prior network (BETA) and the uncertainty-scaled clip range.
ecppo_config = {
    **baseline_config,
    "ENSEMBLE_SIZE": 5,
    "UNCERTAINTY": True,
    "BETA": 1.,
    "UNCERTAINTY_BASE": 0.5,
    "UNCERTAINTY_SCALE": 1.5,
    "UNCERTAINTY_SHIFT": 0.3,
    "UNCERTAINTY_MULTIPLIER": 15,
}

# Baseline variants that differ only in the clip range.
baseline_high_eps = dict(baseline_config, CLIP_EPS=0.4)
baseline_low_eps = dict(baseline_config, CLIP_EPS=0.1)

# Experiment name -> configuration; the key becomes the prefix of the result name.
configs = {
    'baseline': baseline_config,
    'baseline_high_eps': baseline_high_eps,
    'baseline_low_eps': baseline_low_eps,
    'ens-ecppo': ecppo_config,
}

# Set to True to start from previously pickled results instead of an empty dict.
load = False

if not load:
    results = {}
else:
    # NOTE: unpickling executes arbitrary code - only load files you created yourself.
    with open('results.pkl', 'rb') as f:
        results, _ = pickle.load(f)

In [4]:
# Select every registered gymnax environment whose default action space is
# discrete (the agents in this notebook use a categorical policy).
envs = [
    env_id
    for env_id in gymnax.registered_envs
    if isinstance(
        gymnax.make(env_id)[0].action_space(),
        gymnax.environments.spaces.Discrete,
    )
]

print(envs)

# Number of PPO updates in a full-length run: 1e7 timesteps / (64 envs * 128 steps).
base = 1e7 // (64 * 128)

# Shortened training length for selected environments, expressed in the same
# unit as `base`; the experiment loop rescales the timestep budget with it.
short_axis = {
    "CartPole-v1": 200,
    "Pong-misc": 400,
    "UmbrellaChain-bsuite": 50,
    "FourRooms-misc": 100,
    "DiscountingChain-bsuite": 200,
    "GaussianBandit-misc": 50,
    "Catch-bsuite": 50,
}

['CartPole-v1', 'Acrobot-v1', 'MountainCar-v0', 'Asterix-MinAtar', 'Breakout-MinAtar', 'Freeway-MinAtar', 'SpaceInvaders-MinAtar', 'Catch-bsuite', 'DeepSea-bsuite', 'MemoryChain-bsuite', 'UmbrellaChain-bsuite', 'DiscountingChain-bsuite', 'MNISTBandit-bsuite', 'SimpleBandit-bsuite', 'FourRooms-misc', 'MetaMaze-misc', 'BernoulliBandit-misc', 'GaussianBandit-misc', 'Pong-misc']


In [5]:
# Run every configuration on every environment (20 seeds each) and store the
# processed curves in `results` under "<config name>/<env name>".
#
# `results` is deliberately NOT reset here: it is created by the configuration
# cell above (empty, or loaded from disk when `load` is True). Resetting it
# discarded loaded results and made the "already computed" check below dead code.
pbar = tqdm.tqdm(envs)
for env in pbar:
    # Environments listed in `short_axis` get a proportionally smaller budget.
    # Both branches yield an int so the derived NUM_UPDATES is an int as well.
    if env in short_axis:
        total_timesteps = int(1e7 * short_axis[env] / base)
    else:
        total_timesteps = int(1e7)

    for name, base_config in configs.items():
        experiment_name = name + "/" + env

        if experiment_name in results:
            continue  # already available - skip the expensive training run

        pbar.set_postfix_str(experiment_name)

        # make_train mutates its config, so work on a copy.
        config = dict(base_config)
        config["ENV_NAME"] = env
        config["TOTAL_TIMESTEPS"] = total_timesteps

        rng = jax.random.PRNGKey(42) # Each agent and environment uses the same random key
        train_jit = jax.jit(make_train(config))
        out = jax.vmap(train_jit)(jax.random.split(rng, 20)) # The random key is split into 20 keys for 20 runs per agent
        results[experiment_name] = process_out(out)

  return jnp.array(cars, dtype=jnp.int_)
  return jnp.array(cars, dtype=jnp.int_)
  return jnp.array(cars, dtype=jnp.int_)
  return jnp.array(cars, dtype=jnp.int_)
  return jnp.array(cars, dtype=jnp.int_)
  return jnp.array(cars, dtype=jnp.int_)
  return jnp.array(cars, dtype=jnp.int_)
  return jnp.array(cars, dtype=jnp.int_)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
  return lax_numpy.astype(arr, dtype)
100%|██████████| 19/19 [49:07<00:00, 155.11s/it, ens-ecppo/Pong-misc]


In [6]:
# Persist the curves without the (large) runner states; force=True overwrites
# an existing file of the same name.
save_results('ens-eccpo-results.pkl', drop_runner_state(results), configs=configs, force=True)

# Laplace-ECPPO

In [7]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal, variance_scaling, normal
from typing import Sequence, NamedTuple, Any
from functools import partial
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper
import matplotlib.pyplot as plt
import warnings
import pickle
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
from flax import core, struct
from operator import mul
from functools import reduce, partial
from typing import Tuple

# warnings.simplefilter(action='ignore', category=FutureWarning)
# warnings.simplefilter(action='ignore', category=UserWarning)

def prod(tuple_):
    """Multiply all entries of ``tuple_`` together, returning a float.

    The empty tuple ``()`` is special-cased to ``0`` (not the mathematical
    empty product 1).
    """
    if tuple_ == ():
        return 0
    product = 1.0
    for factor in tuple_:
        product = product * factor
    return product


def random_split_like_tree(rng_key, target=None, treedef=None):
    """Split ``rng_key`` into one key per leaf, arranged in a given tree structure.

    The structure comes from ``treedef`` when provided, otherwise from
    ``target``. Returns a pytree of that structure whose leaves are the keys
    produced by ``jax.random.split``.
    """
    structure = jax.tree_util.tree_structure(target) if treedef is None else treedef
    leaf_keys = jax.random.split(rng_key, structure.num_leaves)
    return jax.tree_util.tree_unflatten(structure, leaf_keys)


def tree_random_normal_like(rng_key, target):
    """Return a pytree shaped like ``target`` filled with standard-normal samples.

    Every leaf is drawn with its own key (split from ``rng_key``) and takes the
    shape and dtype of the matching leaf of ``target``.
    """
    per_leaf_keys = random_split_like_tree(rng_key, target)

    def _sample_like(leaf, key):
        return jax.random.normal(key, leaf.shape, leaf.dtype)

    return jax.tree_util.tree_map(_sample_like, target, per_leaf_keys)

def variance_scaling(w, base_scale):
    """Return an array shaped like ``w`` filled with ``base_scale / sqrt(max(1, fan_in))``.

    ``fan_in`` is the product of all but the last dimension of ``w``; for 1-D
    arrays ``prod`` yields 0 and the divisor is clamped to 1.

    NOTE: this definition shadows the ``variance_scaling`` initializer imported
    from ``flax.linen.initializers`` at the top of this section.
    """
    fan_in = prod(w.shape[:-1])
    per_weight_variance = base_scale**2 / max(1.0, fan_in)
    return jnp.sqrt(per_weight_variance) * jnp.ones(w.shape)

class LaplaceState(NamedTuple):
    """State consumed by ``laplace_mapping`` to perturb parameters."""
    # Pytree matching the network parameters; divided by `count` in
    # laplace_mapping and then used as 1/sqrt(.) scale of the noise.
    fisher: core.FrozenDict[str, Any]
    # Normaliser applied to `fisher` in laplace_mapping.
    count: jnp.float32

def laplace_mapping(params, noise, laplace_state: LaplaceState):
    """Perturb ``params`` with ``noise`` scaled by the normalised Fisher estimate.

    For every leaf: ``param + noise / sqrt(fisher / count)``. ``params``,
    ``noise`` and ``laplace_state.fisher`` must share the same tree structure.
    """
    count = laplace_state.count

    def _perturb(param, eps, fisher_leaf):
        precision = fisher_leaf / count
        return param + eps / jnp.sqrt(precision)

    return jax.tree_util.tree_map(_perturb, params, noise, laplace_state.fisher)

# Bias initialiser `normal(0.1)`; NOTE(review): not referenced by the
# ActorEpistemicCritic defined directly below - possibly a leftover from the
# ensemble section, confirm before removing.
prior_bias_init = normal(0.1)

class ActorEpistemicCritic(nn.Module):
    """Categorical actor and a single scalar critic, both 64-64 MLPs.

    ``__call__`` returns ``(pi, value)`` where ``pi`` is a
    ``distrax.Categorical`` over ``action_dim`` logits and ``value`` is the
    critic output with its size-1 last axis squeezed away.

    NOTE: the submodules are created in a fixed order inside ``nn.compact``;
    reordering them would change the auto-generated parameter names.
    """
    # Number of logits of the final actor layer. Annotated as Sequence[int];
    # make_train passes `env.action_space(env_params).n` - presumably an int.
    action_dim: Sequence[int]
    # "relu" -> nn.relu, "leaky_relu" -> nn.leaky_relu, anything else -> nn.tanh.
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        elif self.activation == "leaky_relu":
            activation = nn.leaky_relu
        else:
            activation = nn.tanh
        # --- actor head ---
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        # --- critic head (separate parameters, same input) ---
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    """Per-step rollout record for the Laplace variant.

    NOTE(review): the code that fills these fields is not visible in this
    part of the file; the field comments below are inferred from the names.
    """
    # Presumably the episode-termination flag.
    done: jnp.ndarray
    # Presumably the action taken.
    action: jnp.ndarray
    # Presumably the value predicted with the unperturbed parameters - TODO confirm.
    mean_value: jnp.ndarray
    # Presumably the values predicted by the noise-perturbed parameter samples - TODO confirm.
    ensemble_value: jnp.ndarray
    # Presumably the reward received.
    reward: jnp.ndarray
    # Presumably the log-probability of `action` under the sampling policy.
    log_prob: jnp.ndarray
    # Presumably the observation the action was taken in.
    obs: jnp.ndarray
    # Presumably the info structure returned by the environment step.
    info: jnp.ndarray

def make_train(config):
    """Build the training function for the Laplace-ensemble PPO variant.

    Derives ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` from ``config``, creates the
    gymnax environment (flattened observations, wrapped in ``LogWrapper``) and
    returns ``train(rng)``.

    ``config`` is mutated in place: the two derived keys are written here, and
    defaults for the ``UNCERTAINTY_*`` / ``ADAPTIVE_SIGMA`` keys are filled in
    when ``train`` runs.

    Args:
        config: dict of hyper-parameters (see the keys read below).

    Returns:
        ``train(rng)``, which runs the full training loop and returns a dict
        with keys ``metrics``, ``loss_info``, ``runner_state`` and
        ``laplace_state``.
    """
    # Number of rollout/update iterations; each one consumes
    # NUM_STEPS * NUM_ENVS transitions.
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )

    env_kwargs = {}
    # DeepSea is the only environment that gets an extra constructor argument.
    if config["ENV_NAME"] == "DeepSea-bsuite":
        env_kwargs["size"] = 15
    env, env_params = gymnax.make(config["ENV_NAME"], **env_kwargs)
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        # `count` is the optimizer step. Every update performs
        # NUM_MINIBATCHES * UPDATE_EPOCHS optimizer steps, so the integer
        # division recovers the index of the current update; the learning rate
        # then decays linearly from LR with that index.
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):
        """Run the whole training loop for one seed and return the logged info."""

        # INIT NETWORK
        network = ActorEpistemicCritic(env.action_space(env_params).n, activation=config["ACTIVATION"])
        rng, _rng, ensemble_rng = jax.random.split(rng, 3)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        # One noise pytree per ensemble member, drawn once and reused for the
        # whole run (it is never resampled below).
        # NOTE(review): `tree_random_normal_like` is defined outside this
        # section; presumably it draws normal noise shaped like the params.
        ensemble_noise = jax.vmap(tree_random_normal_like, in_axes=(0, None))(jax.random.split(ensemble_rng, config["ENSEMBLE_SIZE"]), network_params)

        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # Per-parameter prior scale; its inverse square seeds the Fisher
        # accumulator with an initial count of 1.
        # NOTE(review): `variance_scaling` is called here as
        # f(leaf, base_scale=...), which does not match the flax initializer
        # imported at the top of the file — presumably a file-local helper of
        # the same name defined outside this section; confirm.
        prior_scale = jax.tree_util.tree_map(partial(variance_scaling, base_scale=config["PRIOR_SCALE"]), network_params)
        init_laplace = jax.tree_util.tree_map(lambda x: 1 / x**2, prior_scale)
        laplace_state = LaplaceState(init_laplace, 1.)

        # Fill in defaults for optional keys (this writes into `config`).
        if "UNCERTAINTY_SHIFT" not in config.keys():
            config["UNCERTAINTY_SHIFT"] = 0.3

        if "UNCERTAINTY_MULTIPLIER" not in config.keys():
            config["UNCERTAINTY_MULTIPLIER"] = 15

        if "UNCERTAINTY_BASE" not in config.keys():
            config["UNCERTAINTY_BASE"] = 0.5

        if "UNCERTAINTY_SCALE" not in config.keys():
            config["UNCERTAINTY_SCALE"] = 1.5

        if "ADAPTIVE_SIGMA" not in config.keys():
            config["ADAPTIVE_SIGMA"] = True

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng,  env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # One full iteration: collect NUM_STEPS of experience, compute
            # advantages, then run UPDATE_EPOCHS epochs of minibatch updates.

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, laplace_state, last_obs, rng = runner_state

                # CREATE THE ENSEMBLE:
                # Each member is the current params perturbed by its fixed
                # noise sample, scaled through the current Laplace state.
                ensemble = jax.vmap(laplace_mapping, in_axes=(None, 0, None))(train_state.params, ensemble_noise, laplace_state)
                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, mean_value = network.apply(train_state.params, last_obs)
                # Only the value output of the ensemble members is kept; the
                # member axis is placed last (out_axes=-1).
                _, ensemble_value = jax.vmap(network.apply, in_axes=(0, None), out_axes=-1)(ensemble, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0,None))(
                    rng_step, env_state, action, env_params
                )
                transition = Transition(
                    done, action, mean_value, ensemble_value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, laplace_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, laplace_state, last_obs, rng = runner_state

            # Bootstrap values for the observation following the last step,
            # for both the unperturbed critic and every ensemble member.
            ensemble = jax.vmap(laplace_mapping, in_axes=(None, 0, None))(train_state.params, ensemble_noise, laplace_state)
            _, last_mean_val = network.apply(train_state.params, last_obs)
            _, last_ens_val = jax.vmap(network.apply, in_axes=(0, None), out_axes=-1)(ensemble, last_obs)

            def _calculate_gae(traj_batch, last_mean_val, last_ens_val):
                # Computes GAE twice in one reverse scan: once with the mean
                # critic values and once per ensemble member (extra last axis).
                # Returns ((mean_adv, ens_adv), (mean_targets, ens_targets)).
                def _get_advantages(gaes_and_next_values, transition):
                    mean_gae, ens_gae, next_mean_value, next_ens_value = gaes_and_next_values
                    done, mean_value, ens_value, reward = (
                        transition.done,
                        transition.mean_value,
                        transition.ensemble_value,
                        transition.reward,
                    )
                    mean_delta = reward + config["GAMMA"] * next_mean_value * (1 - done) - mean_value
                    mean_gae = (
                        mean_delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * mean_gae
                    )

                    # reward / done get a trailing axis so they broadcast over
                    # the ensemble-member axis.
                    ens_delta = reward[..., None] + config["GAMMA"] * next_ens_value * (1 - done[..., None]) - ens_value
                    ens_gae = (
                        ens_delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done[..., None]) * ens_gae
                    )

                    # The current values become the "next" values for the
                    # preceding time step (the scan runs in reverse).
                    return (mean_gae, ens_gae, mean_value, ens_value), (mean_gae, ens_gae)

                _, (mean_advantages, ens_advantages) = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_mean_val), jnp.zeros_like(last_ens_val), last_mean_val, last_ens_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                # Targets are advantages plus the values recorded at rollout time.
                return (mean_advantages, ens_advantages), (mean_advantages + traj_batch.mean_value, ens_advantages + traj_batch.ensemble_value)
            advantages, targets = _calculate_gae(traj_batch, last_mean_val, last_ens_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                # One pass over the reshuffled rollout, split into
                # NUM_MINIBATCHES minibatches.
                def _update_minbatch(carry: Tuple[TrainState, LaplaceState], batch_info):
                    # One gradient step on a minibatch, followed by a Laplace
                    # (Fisher) update evaluated at the post-step parameters.

                    train_state, laplace_state = carry
                    traj_batch, advantages, targets = batch_info
                    # `train_state.step` counts optimizer steps, so this grows
                    # by MINIBATCH_SIZE per minibatch (including repeat epochs).
                    n_transistions_seen = train_state.step * config["MINIBATCH_SIZE"]

                    def update_laplace(params, laplace_state, traj_batch):
                        # Folds a fresh Fisher estimate into the running sum
                        # with decay (1 - FISHER_LR); the count is decayed the
                        # same way so fisher / count stays a weighted average.

                        def grad_norm_estimate(params, traj_batch):
                            # Scalar critic output for a single observation.
                            value_net = lambda params, obs: network.apply(params, obs)[1].squeeze()

                            obs_shape = traj_batch.obs.shape[1:]

                            # Per-sample gradients of the critic output w.r.t.
                            # the parameters: each observation is reshaped to a
                            # batch of one and vmapped over.
                            grads = jax.vmap(jax.grad(value_net), in_axes=(None, 0))(params, traj_batch.obs.reshape(-1, 1, *obs_shape))

                            # Mean squared critic output over the minibatch;
                            # scales sigma when ADAPTIVE_SIGMA is set.
                            mean_output = jnp.mean(value_net(params, traj_batch.obs)**2)
                            if config["ADAPTIVE_SIGMA"]:
                                sigma = config["TARGET_SIGMA"] * mean_output
                            else:
                                sigma = config["TARGET_SIGMA"]

                            # Prior precision plus the batch-mean squared
                            # gradient (divided by sigma^2), scaled by the
                            # number of transitions seen so far.
                            return jax.tree_util.tree_map(
                                lambda l, g: 1 / l**2 + n_transistions_seen * jnp.mean(g * g / sigma**2, axis=0), 
                                prior_scale, grads
                            )

                        fisher, m = laplace_state
                        fisher = jax.tree_util.tree_map(lambda o, n: (1 - config["FISHER_LR"]) * o + n, fisher, grad_norm_estimate(params, traj_batch))
                        m = (1 - config["FISHER_LR"]) * m + 1
                        return LaplaceState(fisher, m)

                    def _loss_fn(params, traj_batch, gaes, targets):
                        # PPO loss whose policy clip range is scaled per sample
                        # by the ensemble disagreement on the advantage.
                        # Returns (total_loss, dict of diagnostics).

                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        m_gae, e_gae = gaes
                        # Only the mean-critic targets enter the loss;
                        # `e_targets` is unpacked but not used below.
                        m_targets, e_targets = targets

                        # CALCULATE VALUE LOSS
                        mean_value = traj_batch.mean_value
                        value_pred_clipped = mean_value + (
                            value - mean_value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])


                        value_losses = jnp.square(value - m_targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - m_targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)

                        # Both advantage sets are standardised over the whole
                        # minibatch (the ensemble one over all members jointly).
                        m_gae = (m_gae - m_gae.mean()) / (m_gae.std() + 1e-8)
                        e_gae = (e_gae - e_gae.mean()) / (e_gae.std() + 1e-8)

                        # Per-sample spread of the advantage across ensemble
                        # members (last axis).
                        gae_uncertainty = e_gae.std(axis=-1)


                        if config["UNCERTAINTY"]:
                            # Sigmoid squashing: the multiplier lies between
                            # UNCERTAINTY_BASE and BASE + UNCERTAINTY_SCALE and
                            # decreases as the uncertainty grows past
                            # UNCERTAINTY_SHIFT (for positive SCALE/MULTIPLIER).
                            squashed_uncertainty = config["UNCERTAINTY_BASE"] + config["UNCERTAINTY_SCALE"] * jax.nn.sigmoid( config["UNCERTAINTY_MULTIPLIER"] * (-gae_uncertainty + config["UNCERTAINTY_SHIFT"] ))
                        else:
                            # Plain PPO clipping.
                            squashed_uncertainty = jnp.array(1.)

                        loss_actor1 = ratio * m_gae # average loss over all values
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"] * squashed_uncertainty,
                                1.0 + config["CLIP_EPS"] * squashed_uncertainty,
                            )
                            * m_gae
                        )

                        # Fraction of samples whose ratio lies outside the
                        # (scaled) clip range.
                        percentage_clipped = (jnp.abs(ratio - 1) > config["CLIP_EPS"] * squashed_uncertainty).mean()

                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)

                        # Fraction of samples where the clipped surrogate is
                        # the smaller (i.e. selected) term.
                        loss_argmin = (loss_actor1 > loss_actor2).mean()
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, dict(value_loss=value_loss, 
                                                loss_actor=loss_actor, 
                                                entropy=entropy, 
                                                mean_uncertainty=gae_uncertainty.mean(), 
                                                min_uncertainty=gae_uncertainty.min(), 
                                                max_uncertainty=gae_uncertainty.max(),
                                                min_gae = m_gae.min(),
                                                max_gae = m_gae.max(),
                                                percentage_clipped=percentage_clipped,
                                                loss_argmin=loss_argmin,
                                                mean_squashed_uncertainty=squashed_uncertainty.mean(),
                                                max_squashed_uncertainty=squashed_uncertainty.max(),
                                                min_squashed_uncertainty=squashed_uncertainty.min())

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # With has_aux=True, `total_loss` is the pair
                    # (loss value, diagnostics dict).
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)

                    # The Fisher estimate uses the parameters after the step.
                    laplace_state = update_laplace(train_state.params, laplace_state, traj_batch)

                    return (train_state, laplace_state), total_loss

                train_state, traj_batch, advantages, laplace_state, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)

                # Merge the two leading axes (steps, envs) into one batch axis.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                # Split the shuffled batch into NUM_MINIBATCHES equal chunks.
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                (train_state, laplace_state), total_loss = jax.lax.scan(
                    _update_minbatch, (train_state, laplace_state), minibatches
                )
                update_state = (train_state, traj_batch, advantages, laplace_state, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, laplace_state, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            # Positions follow the tuple layout of `update_state` above.
            train_state = update_state[0]
            laplace_state = update_state[3]
            metric = traj_batch.info
            rng = update_state[-1]

            runner_state = (train_state, env_state, laplace_state, last_obs, rng)
            info = dict(
                metrics=metric,
                loss_info=loss_info
            )
            return runner_state, info

        # Split the NUM_UPDATES iterations into `outer_updates` groups of
        # `inner_updates` so the runner state can be recorded after each group;
        # the remainder (`left_over`) is run first.
        if config["SAVE_STATE_FREQ"]:
            outer_updates = config["NUM_UPDATES"] // config["SAVE_STATE_FREQ"]
            inner_updates = config["SAVE_STATE_FREQ"]
            left_over = config["NUM_UPDATES"] - inner_updates * outer_updates 
            print(inner_updates, outer_updates, left_over)
        else:
            # No periodic snapshots: a single group containing every update.
            inner_updates = config["NUM_UPDATES"]
            outer_updates = 1
            left_over = 0

        # start with the left_over steps:
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, laplace_state, obsv, _rng)
        runner_state, initial_info = jax.lax.scan(
            _update_step, runner_state, None, left_over
        )

        def _outer_update_step(update_state, unused):
            # Runs one group of `inner_updates` iterations and also emits the
            # resulting runner state so it gets stacked across groups.
            runner_state, info = jax.lax.scan(
                _update_step, update_state, None, inner_updates
            )
            return runner_state, (runner_state, info)

        # then do the remaining updates
        runner_state, (states, info) = jax.lax.scan(
            _outer_update_step, runner_state, None, outer_updates
        )

        # Merge the (outer, inner) axes into a single update axis and prepend
        # the info gathered during the left-over iterations.
        info = jax.tree_util.tree_map(lambda x: x.reshape(-1, *x.shape[2:]), info)
        info = jax.tree_util.tree_map(lambda x, y: jnp.concatenate([x, y], axis=0), initial_info, info)

        # `states` holds the runner state after every outer group;
        # `laplace_state` is the final one only.
        info["runner_state"] = states
        info['laplace_state'] = runner_state[2]

        return info
    return train

In [8]:
# Hyper-parameters for the Laplace-ensemble PPO agent, as read by `make_train`.
# "UNCERTAINTY_BASE" and "UNCERTAINTY_SCALE" are not set here, so `make_train`
# falls back to its defaults (0.5 and 1.5).
lp_ecppo_config = {
    "LR": 5e-3,  # Adam learning rate (start value when ANNEAL_LR is on)
    "NUM_ENVS": 64,  # environments stepped in parallel
    "NUM_STEPS": 128,  # rollout length per update
    "TOTAL_TIMESTEPS": 1e7,
    "UPDATE_EPOCHS": 4,  # passes over each rollout
    "NUM_MINIBATCHES": 8,  # minibatches per pass
    "GAMMA": 0.99,  # discount factor
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,  # clip range for the policy ratio and the value prediction
    "ENT_COEF": 0.01,  # entropy bonus weight
    "VF_COEF": .5,  # value-loss weight
    "MAX_GRAD_NORM": 0.5,  # global-norm gradient clipping
    "ACTIVATION": "leaky_relu",
    "ENV_NAME": "FourRooms-misc",
    "ANNEAL_LR": True,  # use the linear learning-rate schedule
    "ENSEMBLE_SIZE": 5,  # number of perturbed critics
    "UNCERTAINTY": True,  # scale the policy clip range by ensemble disagreement
    "SAVE_STATE_FREQ": False,  # falsy: run all updates as a single group
    "PRIOR_SCALE": 2.0,  # base_scale used to build the per-parameter prior scale
    "FISHER_LR": 1e-2,  # decay rate of the running Fisher sum
    "TARGET_SIGMA": 0.1,  # noise scale dividing the squared gradients
    "ADAPTIVE_SIGMA": True,  # multiply TARGET_SIGMA by the mean squared value
    "UNCERTAINTY_SHIFT": 0.3,  # sigmoid offset for the uncertainty squashing
    "UNCERTAINTY_MULTIPLIER": 15,  # sigmoid steepness for the uncertainty squashing
}

rng = jax.random.PRNGKey(42)
# Note: make_train writes NUM_UPDATES and MINIBATCH_SIZE into lp_ecppo_config.
train_jit = jax.jit(make_train(lp_ecppo_config))

In [9]:

# Agent configurations to evaluate, keyed by agent name.
configs = {'laplace-ecppo': lp_ecppo_config}

# Set to True to start from previously pickled results instead of an empty dict.
load = False

if not load:
    results = {}
else:
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open('results.pkl', 'rb') as f:
        results, _ = pickle.load(f)

In [10]:
import tqdm
import pickle
from pathlib import Path

# Keep only the registered gymnax environments whose action space is Discrete.
# Note that this instantiates every registered environment once.
envs = [env for env in gymnax.registered_envs if isinstance(gymnax.make(env)[0].action_space(), gymnax.environments.spaces.Discrete)]


print(envs)

# Number of updates in a 1e7-step run with 64 envs x 128 steps per update.
base = 1e7 // (64 * 128)
# Environments that get a shortened run. The experiment loop scales the 1e7
# step budget by value / base, so each value is approximately the number of
# updates to run for that environment.
short_axis = {
    "CartPole-v1": 200,
    "Pong-misc": 400,
    "UmbrellaChain-bsuite": 50,
    "FourRooms-misc": 100,
    "DiscountingChain-bsuite": 200,
    "GaussianBandit-misc": 50,
    "Catch-bsuite": 50,
}

['CartPole-v1', 'Acrobot-v1', 'MountainCar-v0', 'Asterix-MinAtar', 'Breakout-MinAtar', 'Freeway-MinAtar', 'SpaceInvaders-MinAtar', 'Catch-bsuite', 'DeepSea-bsuite', 'MemoryChain-bsuite', 'UmbrellaChain-bsuite', 'DiscountingChain-bsuite', 'MNISTBandit-bsuite', 'SimpleBandit-bsuite', 'FourRooms-misc', 'MetaMaze-misc', 'BernoulliBandit-misc', 'GaussianBandit-misc', 'Pong-misc']


In [11]:

print(f"All envs: {gymnax.registered_envs}")
print(f"Used envs: {envs}")

n_envs = len(envs)

# An experiment that already has a stored result is re-run only when its agent
# is listed in force_agent AND its environment is listed in force_envs.
force_agent = []
force_envs = []

pbar = tqdm.tqdm(envs)
for env in pbar:
    # Environments in short_axis get a reduced step budget; all others run 1e7 steps.
    total_timesteps = int(1e7 * short_axis[env] / base) if env in short_axis else 1e7

    for name in configs:
        experiment_name = f"{name}/{env}"

        is_forced = (name in force_agent) and (env in force_envs)
        if (experiment_name in results) and not is_forced:
            continue

        pbar.set_postfix_str(experiment_name)

        # Shallow copy so per-experiment overrides do not leak into `configs`.
        config = dict(configs[name])
        config["ENV_NAME"] = env
        config["TOTAL_TIMESTEPS"] = total_timesteps

        # Same base seed for every experiment; 20 runs are trained in parallel.
        rng = jax.random.PRNGKey(42)
        train_jit = jax.jit(make_train(config))
        out = jax.vmap(train_jit)(jax.random.split(rng, 20))

        results[experiment_name] = process_out(out)

All envs: ['CartPole-v1', 'Pendulum-v1', 'Acrobot-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Asterix-MinAtar', 'Breakout-MinAtar', 'Freeway-MinAtar', 'SpaceInvaders-MinAtar', 'Catch-bsuite', 'DeepSea-bsuite', 'MemoryChain-bsuite', 'UmbrellaChain-bsuite', 'DiscountingChain-bsuite', 'MNISTBandit-bsuite', 'SimpleBandit-bsuite', 'FourRooms-misc', 'MetaMaze-misc', 'PointRobot-misc', 'BernoulliBandit-misc', 'GaussianBandit-misc', 'Reacher-misc', 'Swimmer-misc', 'Pong-misc']
Used envs: ['CartPole-v1', 'Acrobot-v1', 'MountainCar-v0', 'Asterix-MinAtar', 'Breakout-MinAtar', 'Freeway-MinAtar', 'SpaceInvaders-MinAtar', 'Catch-bsuite', 'DeepSea-bsuite', 'MemoryChain-bsuite', 'UmbrellaChain-bsuite', 'DiscountingChain-bsuite', 'MNISTBandit-bsuite', 'SimpleBandit-bsuite', 'FourRooms-misc', 'MetaMaze-misc', 'BernoulliBandit-misc', 'GaussianBandit-misc', 'Pong-misc']


100%|██████████| 19/19 [16:06<00:00, 50.86s/it, laplace-ecppo/Pong-misc]              


In [12]:
# Persist the collected results together with the configs that produced them.
# NOTE(review): `save_results` and `drop_runner_state` are defined outside this
# section; presumably the runner states are stripped before saving and
# force=True overwrites an existing file — confirm against their definitions.
save_results('laplace-ecppo-results.pkl', drop_runner_state(results), configs=configs, force=True)