In [1]:
# Taken from:
# https://github.com/corl-team/xland-minigrid/blob/main/training/train_single_task.py

import os
import shutil
import time
from typing import Sequence
from tqdm import tqdm

import distrax
import flax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import jax.tree_util
import numpy as np
import optax
import wandb
from flax.jax_utils import replicate, unreplicate
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState

from MetaLearnCuriosity.agents.nn import (
    AtariBYOLPredictor,
    BYOLTarget,
    CloseScannedRNN,
    OpenScannedRNN,
    TemporalRewardCombiner
)
from MetaLearnCuriosity.pmapped_open_es import OpenES
from MetaLearnCuriosity.checkpoints import Save
from MetaLearnCuriosity.logger import WBLogger
from MetaLearnCuriosity.utils import BYOLRewardNorm
from MetaLearnCuriosity.utils import RCBYOLTransition as Transition
from MetaLearnCuriosity.utils import (
    byol_normalize_prior_int_rewards,
    process_output_general,
    update_target_state_with_ema,
)
from MetaLearnCuriosity.wrappers import FlattenObservationWrapper, LogWrapper, VecEnv

environments = [
    "Asterix-MinAtar",
    # "Breakout-MinAtar",
    # "Freeway-MinAtar",
    # "SpaceInvaders-MinAtar",
]


class PPOActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(
            actor_mean
        )
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        critic = activation(critic)
        critic = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)

        return pi, jnp.squeeze(critic, axis=-1)


config = {
    "RUN_NAME": "minatar_byol_ppo",
    "SEED": 42,
    "NUM_SEEDS": 8,
    "LR": 5e-3,
    "NUM_ENVS": 64,
    "NUM_STEPS": 128,
    "TOTAL_TIMESTEPS": 1e7,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 8,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "relu",
    "ANNEAL_LR": True,
    "ANNEAL_PRED_LR": False,
    "DEBUG": False,
    "PRED_LR": 0.001,
    "REW_NORM_PARAMETER": 0.99,
    "EMA_PARAMETER": 0.99,
    "POP_SIZE":8,
    "ES_SEED": 7,
    "RC_SEED": 23,
    "NUM_GENERATIONS":2,
    # "INT_LAMBDA": 0.001,
}

environments = [
    "Asterix-MinAtar",
    # "Breakout-MinAtar",
    # "Freeway-MinAtar",
    # "SpaceInvaders-MinAtar",
]


def make_config_env(config, env_name):
    config["ENV_NAME"] = env_name
    num_devices = jax.local_device_count()
    assert config["NUM_ENVS"] % num_devices == 0
    config["NUM_ENVS_PER_DEVICE"] = config["NUM_ENVS"] // num_devices
    config["TOTAL_TIMESTEPS_PER_DEVICE"] = config["TOTAL_TIMESTEPS"] // num_devices
    # config["EVAL_EPISODES_PER_DEVICE"] = config["EVAL_EPISODES"] // num_devices
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS_PER_DEVICE"] // config["NUM_STEPS"] // config["NUM_ENVS_PER_DEVICE"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS_PER_DEVICE"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    config["TRAINING_HORIZON"] = (
        config["TOTAL_TIMESTEPS_PER_DEVICE"] // config["NUM_ENVS_PER_DEVICE"]
    )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)
    env = VecEnv(env)

    return config, env, env_params


def ppo_make_train(rng):
    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def pred_linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["PRED_LR"] * frac

    # INIT NETWORK
    network = PPOActorCritic(env.action_space(env_params).n, activation=config["ACTIVATION"])
    target = BYOLTarget(128)
    pred = AtariBYOLPredictor(128, env.action_space(env_params).n)

    # KEYS
    rng, _rng = jax.random.split(rng)
    rng, _tar_rng = jax.random.split(rng)
    # rng, _en_rng = jax.random.split(rng)
    rng, _pred_rng = jax.random.split(rng)

    # INIT INPUT
    init_x = jnp.zeros((1, config["NUM_ENVS_PER_DEVICE"], *env.observation_space(env_params).shape))
    init_action = jnp.zeros((config["NUM_ENVS_PER_DEVICE"],), dtype=jnp.int32)
    close_init_hstate = CloseScannedRNN.initialize_carry(config["NUM_ENVS_PER_DEVICE"], 128)
    open_init_hstate = OpenScannedRNN.initialize_carry(config["NUM_ENVS_PER_DEVICE"], 128)
    init_bt = jnp.zeros((1, config["NUM_ENVS_PER_DEVICE"], 128))

    init_pred_input = (init_bt, init_x, init_action[np.newaxis, :])

    network_params = network.init(_rng, init_x)
    pred_params = pred.init(_pred_rng, close_init_hstate, open_init_hstate, init_pred_input)
    target_params = target.init(_tar_rng, init_x)

    if config["ANNEAL_LR"]:
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=linear_schedule, eps=1e-5),
        )
    else:
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(config["LR"], eps=1e-5),
        )

    if config["ANNEAL_PRED_LR"]:
        pred_tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(learning_rate=pred_linear_schedule, eps=1e-5),
        )
    else:
        pred_tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(config["PRED_LR"], eps=1e-5),
        )

    train_state = TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
    )
    pred_state = TrainState.create(
        apply_fn=pred.apply,
        params=pred_params,
        tx=pred_tx,
    )

    target_state = TrainState.create(
        apply_fn=target.apply,
        params=target_params,
        tx=pred_tx,
    )

    rng = jax.random.split(rng, jax.local_device_count())

    return (
        rng,
        train_state,
        pred_state,
        target_state,
        init_bt,
        close_init_hstate,
        open_init_hstate,
        init_action,
    )


def train(
    rng,
    rc_params,
    train_state,
    pred_state,
    target_state,
    init_bt,
    close_init_hstate,
    open_init_hstate,
    init_action,
):
    # REWARD COMBINER
    rc_network=TemporalRewardCombiner()
    # INIT STUFF FOR OPTIMIZATION AND NORMALIZATION
    update_target_counter = 0
    byol_reward_norm_params = BYOLRewardNorm(0, 0, 1, 0)

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, config["NUM_ENVS_PER_DEVICE"])
    obsv, env_state = env.reset(reset_rng, env_params)

    # TRAIN LOOP
    def _update_step(runner_state, unused):
        # COLLECT TRAJECTORIES
        def _env_step(runner_state, unused):
            (
                train_state,
                pred_state,
                target_state,
                bt,
                close_hstate,
                open_hstate,
                last_act,
                env_state,
                last_obs,
                byol_reward_norm_params,
                update_target_counter,
                rng,
            ) = runner_state

            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            pi, value = train_state.apply_fn(train_state.params, last_obs[np.newaxis, :])
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            rng_step = jax.random.split(_rng, config["NUM_ENVS_PER_DEVICE"])
            obsv, env_state, reward, done, info = env.step(
                rng_step, env_state, action.squeeze(0), env_params
            )

            # TIME STEP

            norm_time_step = info["timestep"]/config["TRAINING_HORIZON"]

            # INT REWARD
            tar_obs = target_state.apply_fn(target_state.params, obsv[np.newaxis, :])
            pred_input = (bt, last_obs[np.newaxis, :], last_act[np.newaxis, :])
            pred_obs, new_bt, new_close_hstate, new_open_hstate = pred_state.apply_fn(
                pred_state.params, close_hstate, open_hstate, pred_input
            )
            pred_norm = (pred_obs.squeeze(0)) / (
                jnp.linalg.norm(pred_obs.squeeze(0), axis=-1, keepdims=True)
            )
            tar_norm = jax.lax.stop_gradient(
                (tar_obs.squeeze(0)) / (jnp.linalg.norm(tar_obs.squeeze(0), axis=-1, keepdims=True))
            )
            int_reward = jnp.square(jnp.linalg.norm((pred_norm - tar_norm), axis=-1)) * (1 - done)
            value, action, log_prob = (value.squeeze(0), action.squeeze(0), log_prob.squeeze(0))
            transition = Transition(
                done,
                last_act,
                action,
                value,
                reward,
                int_reward,
                log_prob,
                last_obs,
                obsv,
                bt,
                norm_time_step,
                info,
            )
            runner_state = (
                train_state,
                pred_state,
                target_state,
                new_bt,
                new_close_hstate,
                new_open_hstate,
                action,
                env_state,
                obsv,
                byol_reward_norm_params,
                update_target_counter,
                rng,
            )
            return runner_state, transition

        close_initial_hstate, open_initial_hstate = runner_state[4:6]
        runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["NUM_STEPS"])

        # CALCULATE ADVANTAGE
        (
            train_state,
            pred_state,
            target_state,
            bt,
            close_hstate,
            open_hstate,
            last_act,
            env_state,
            last_obs,
            byol_reward_norm_params,
            update_target_counter,
            rng,
        ) = runner_state

        # update_target_counter+=1
        _, last_val = train_state.apply_fn(train_state.params, last_obs[np.newaxis, :])

        def _calculate_gae(traj_batch, last_val, byol_reward_norm_params):
            norm_int_reward, byol_reward_norm_params = byol_normalize_prior_int_rewards(
                traj_batch.int_reward, byol_reward_norm_params, config["REW_NORM_PARAMETER"]
            )
            norm_traj_batch = Transition(
                traj_batch.done,
                traj_batch.prev_action,
                traj_batch.action,
                traj_batch.value,
                traj_batch.reward,
                norm_int_reward,
                traj_batch.log_prob,
                traj_batch.obs,
                traj_batch.next_obs,
                traj_batch.bt,
                traj_batch.norm_time_step,
                traj_batch.info,
            )

            def _get_advantages(gae_and_next_value, transition):
                gae, next_value = gae_and_next_value
                done, value, reward, int_reward, norm_time_step = (
                    transition.done,
                    transition.value,
                    transition.reward,
                    transition.int_reward,
                    transition.norm_time_step,
                )
                rc_input = jnp.concatenate(
                    (reward[:, None], int_reward[:, None], norm_time_step[:, None]), axis=-1
                )              
                int_lambda = rc_network.apply(rc_params,rc_input)
                delta = (
                    (reward + (int_reward * int_lambda))
                    + config["GAMMA"] * next_value * (1 - done)
                    - value
                )
                gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                return (gae, value), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                norm_traj_batch,
                reverse=True,
                unroll=16,
            )
            return (
                advantages,
                advantages + traj_batch.value,
                norm_int_reward,
                byol_reward_norm_params,
            )

        advantages, targets, norm_int_reward, byol_reward_norm_params = _calculate_gae(
            traj_batch, last_val.squeeze(0), byol_reward_norm_params
        )

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(train_states, batch_info):
                traj_batch, advantages, targets, init_close_hstate, init_open_hstate = batch_info
                train_state, pred_state, target_state, update_target_counter = train_states

                def pred_loss(
                    pred_params, target_params, traj_batch, init_close_hstate, init_open_hstate
                ):
                    tar_obs = target_state.apply_fn(target_params, traj_batch.next_obs)
                    pred_input = (traj_batch.bt, traj_batch.obs, traj_batch.prev_action)
                    pred_obs, _, _, _ = pred_state.apply_fn(
                        pred_params, init_close_hstate[0], init_open_hstate[0], pred_input
                    )
                    pred_norm = (pred_obs) / (jnp.linalg.norm(pred_obs, axis=-1, keepdims=True))
                    tar_norm = jax.lax.stop_gradient(
                        (tar_obs) / (jnp.linalg.norm(tar_obs, axis=-1, keepdims=True))
                    )
                    loss = jnp.square(jnp.linalg.norm((pred_norm - tar_norm), axis=-1)) * (
                        1 - traj_batch.done
                    )
                    return loss.mean()

                def _loss_fn(params, traj_batch, gae, targets):
                    # RERUN NETWORK
                    pi, value = train_state.apply_fn(params, traj_batch.obs)
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                        -config["CLIP_EPS"], config["CLIP_EPS"]
                    )
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

                    # CALCULATE ACTOR LOSS
                    ratio = jnp.exp(log_prob - traj_batch.log_prob)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - config["CLIP_EPS"],
                            1.0 + config["CLIP_EPS"],
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = pi.entropy().mean()

                    total_loss = (
                        loss_actor + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy)

                (loss, (vloss, aloss, entropy)), grads = jax.value_and_grad(_loss_fn, has_aux=True)(
                    train_state.params, traj_batch, advantages, targets
                )
                pred_losses, pred_grads = jax.value_and_grad(pred_loss)(
                    pred_state.params,
                    target_state.params,
                    traj_batch,
                    init_close_hstate,
                    init_open_hstate,
                )
                # (loss, vloss, aloss, entropy, pred_losses, grads, pred_grads) = jax.lax.pmean(
                #     (loss, vloss, aloss, entropy, pred_losses, grads, pred_grads),
                #     axis_name="devices",
                # )

                def update_target(
                    target_state, pred_state, update_target_counter=update_target_counter
                ):
                    def true_fun(_):
                        # Perform the EMA update
                        return update_target_state_with_ema(
                            predictor_state=pred_state,
                            target_state=target_state,
                            ema_param=config["EMA_PARAMETER"],
                        )

                    def false_fun(_):
                        # Return the old target_params unchanged
                        return target_state

                    # Conditionally update every 10 steps
                    return jax.lax.cond(
                        update_target_counter % 320 == 0,
                        true_fun,
                        false_fun,
                        None,  # The argument passed to true_fun and false_fun, `_` in this case is unused
                    )

                update_target_counter += 1
                train_state = train_state.apply_gradients(grads=grads)
                pred_state = pred_state.apply_gradients(grads=pred_grads)
                target_state = update_target(target_state, pred_state, update_target_counter)

                return (train_state, pred_state, target_state, update_target_counter), (
                    loss,
                    (vloss, aloss, entropy),
                    pred_losses,
                )

            (
                train_state,
                pred_state,
                target_state,
                update_target_counter,
                init_close_hstate,
                init_open_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            ) = update_state
            rng, _rng = jax.random.split(rng)
            permutation = jax.random.permutation(_rng, config["NUM_ENVS_PER_DEVICE"])
            batch = (traj_batch, advantages, targets, init_close_hstate, init_open_hstate)

            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=1), batch
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.swapaxes(
                    jnp.reshape(
                        x,
                        [x.shape[0], config["NUM_MINIBATCHES"], -1] + list(x.shape[2:]),
                    ),
                    1,
                    0,
                ),
                shuffled_batch,
            )
            (
                train_state,
                pred_state,
                target_state,
                update_target_counter,
            ), total_loss = jax.lax.scan(
                _update_minbatch,
                (train_state, pred_state, target_state, update_target_counter),
                minibatches,
            )
            update_state = (
                train_state,
                pred_state,
                target_state,
                update_target_counter,
                init_close_hstate,
                init_open_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            )
            return update_state, total_loss

        traj_batch = Transition(
            traj_batch.done,
            traj_batch.prev_action,
            traj_batch.action,
            traj_batch.value,
            traj_batch.reward,
            traj_batch.int_reward,
            traj_batch.log_prob,
            traj_batch.obs,
            traj_batch.next_obs,
            traj_batch.bt.squeeze(1),
            traj_batch.norm_time_step,
            traj_batch.info,
        )

        update_state = (
            train_state,
            pred_state,
            target_state,
            update_target_counter,
            open_initial_hstate[np.newaxis, :],
            close_initial_hstate[np.newaxis, :],
            traj_batch,
            advantages,
            targets,
            rng,
        )
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
        )
        train_state, pred_state, target_state, update_target_counter = update_state[:4]
        metric = traj_batch.info
        rng = update_state[-1]
        if config.get("DEBUG"):

            def callback(info):
                return_values = info["returned_episode_returns"][info["returned_episode"]]
                timesteps = (
                    info["timestep"][info["returned_episode"]] * config["NUM_ENVS_PER_DEVICE"]
                )
                for t in range(len(timesteps)):
                    print(f"global step={timesteps[t]}, episodic return={return_values[t]}")

            jax.debug.callback(callback, metric)

        runner_state = (
            train_state,
            pred_state,
            target_state,
            bt,
            close_hstate,
            open_hstate,
            last_act,
            env_state,
            last_obs,
            byol_reward_norm_params,
            update_target_counter,
            rng,
        )
        return runner_state, (metric, loss_info, norm_int_reward, traj_batch.int_reward)

    rng, _rng = jax.random.split(rng)
    runner_state = (
        train_state,
        pred_state,
        target_state,
        init_bt,
        close_init_hstate,
        open_init_hstate,
        init_action,
        env_state,
        obsv,
        byol_reward_norm_params,
        update_target_counter,
        _rng,
    )
    runner_state, extra_info = jax.lax.scan(_update_step, runner_state, None, config["NUM_UPDATES"])
    metric, rl_total_loss, int_reward, norm_int_reward = extra_info
    return {
        "train_state": runner_state[0],
        "metrics": metric,
        "rl_total_loss": rl_total_loss[0],
        "rl_value_loss": rl_total_loss[1][0],
        "rl_actor_loss": rl_total_loss[1][1],
        "rl_entrophy_loss": rl_total_loss[1][2],
        "pred_loss": rl_total_loss[2],
        "int_reward": int_reward,
        "norm_int_reward": norm_int_reward,
        "rng": runner_state[-1]
    }




def es_step(
    train_fn,rng, train_state, pred_state, target_state, init_bt, close_init_hstate, open_init_hstate, init_action
):
    group="reward_combiners"
    tags=["meta-learner", config["ENV_NAME"]]
    name=f'{config["RUN_NAME"]}_{config["ENV_NAME"]}'
    # fit_log=wandb.init(
    #             project="MetaLearnCuriosity",
    #             config=config,
    #             group=group,
    #             tags=tags,
    #             name=f"{name}_fitness",
    #         )
    reward_combiner_network = TemporalRewardCombiner()
    rc_params_pholder = reward_combiner_network.init(
        jax.random.PRNGKey(config["RC_SEED"]), jnp.zeros((1, 3))
    )
    es_rng = jax.random.PRNGKey(config["ES_SEED"])
    strategy = OpenES(
        popsize=config["POP_SIZE"],
        pholder_params=rc_params_pholder,
        opt_name="adam",
        lrate_init=2e-4,
    )

    es_rng, es_rng_init = jax.random.split(es_rng)
    es_params = strategy.default_params
    es_state = strategy.initialize(es_rng_init, es_params)

    for _ in tqdm(range(config["NUM_GENERATIONS"]), desc="Processing Generations"):
        t = time.time()
        es_rng, es_rng_ask = jax.random.split(es_rng)
        x, es_state = strategy.ask(es_rng_ask, es_state, es_params)
        output = jax.block_until_ready(
            train_fn(
                rng,
                x,
                train_state,
                pred_state,
                target_state,
                init_bt,
                close_init_hstate,
                open_init_hstate,
                init_action,
            )
        )
        rewards = output["metrics"]["sum_of_rewards"]
        rng_ = unreplicate(output["rng"])
        # (4, 2, 8, 1220, 128, 16)
        fitness = rewards.mean(-1).mean(2).reshape(rewards.shape[0], rewards.shape[1], -1).sum(-1)
        es_state = strategy.tell(x, fitness, es_state, es_params)
        elapsed_time = time.time() - t
        print(f"Done in {elapsed_time / 60:.2f}min")
    #     fit_log.log({f"{name}_mean_fitness":fitness.mean(),
    #                  f"{name}_best_fitness":jnp.max(fitness)})

    # fit_log.finish()
    # logger = WBLogger(
    #     config=config,
    #     group=group,
    #     tags=tags,
    #     name=name,
    # )
    # # Get the absolute path of the directory
    # checkpoint_directory = f'MLC_logs/flax_ckpt/{config["ENV_NAME"]}/{config["RUN_NAME"]}'
    params=strategy.param_reshaper.reshape_single(es_state.mean[0])
    print(params)
    # path = os.path.abspath(checkpoint_directory)
    # Save(path, params)
    # logger.save_artifact(path)
    # shutil.rmtree(path)

    return es_rng, es_state, fitness,output["rng"]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
for env_name in environments:
    rng = jax.random.PRNGKey(config["SEED"])
    config, env, env_params = make_config_env(config, env_name)
    print(f"Training in {config['ENV_NAME']}")

    rng = jax.random.split(rng, config["NUM_SEEDS"])
    (
        rng,
        train_state,
        pred_state,
        target_state,
        init_bt,
        close_init_hstate,
        open_init_hstate,
        init_action,
    ) = jax.jit(jax.vmap(ppo_make_train, out_axes=(1, 0, 0, 0, 0, 0, 0, 0)))(rng)
    open_init_hstate = replicate(open_init_hstate, jax.local_devices())
    close_init_hstate = replicate(close_init_hstate, jax.local_devices())
    train_state = replicate(train_state, jax.local_devices())
    pred_state = replicate(pred_state, jax.local_devices())
    target_state = replicate(target_state, jax.local_devices())
    init_bt = replicate(init_bt, jax.local_devices())
    init_action = replicate(init_action, jax.local_devices())
    train_fn = jax.vmap(train, in_axes=(0, None,0,0,0,0,0,0,0))
    train_fn = jax.vmap(train_fn, in_axes=(None, 0,None,None,None,None,None,None,None))
    train_fn = jax.pmap(train_fn, axis_name="devices")
    # Set up fitness function

f,j,k,m=es_step(
        train_fn, rng, train_state, pred_state, target_state, init_bt, close_init_hstate, open_init_hstate, init_action
        )

Training in Asterix-MinAtar
ParameterReshaper: 4 devices detected. Please make sure that the ES population size divides evenly across the number of devices to pmap/parallelize over.
ParameterReshaper: 4481 parameters detected for optimization.


Processing Generations:  50%|█████     | 1/2 [12:36<12:36, 756.95s/it]

Done in 12.62min


Processing Generations: 100%|██████████| 2/2 [24:37<00:00, 738.91s/it]

Done in 12.01min
{'params': {'Dense_0': {'bias': Array([-1.6313215 ,  1.9180655 , -1.2958858 ,  1.5257062 , -1.6895043 ,
       -0.3413339 , -1.9160405 ,  0.20504959,  0.30581132, -0.7549223 ,
       -1.9819707 ,  1.9887234 , -1.920914  ,  1.3097045 ,  0.52386355,
       -0.02650033,  0.83873504,  1.1383698 , -1.3117822 ,  0.6856074 ,
       -0.9332975 ,  0.19743367,  1.4761066 , -1.7068906 ,  1.3362262 ,
       -0.69882005,  0.18876372, -0.6923141 ,  0.6432712 ,  1.2847308 ,
       -0.32504177,  0.01683885, -1.6880344 ,  1.7459917 ,  1.0336772 ,
       -1.7787343 , -1.5043781 , -0.3447996 ,  1.7889901 ,  0.09623729,
       -1.4789462 ,  0.38523203,  1.3221478 , -0.24969758,  0.7906623 ,
        0.8117293 ,  0.50637513,  1.4963506 , -0.22705078,  0.9964611 ,
        0.30855957, -1.3571128 ,  1.2790744 , -1.9557002 ,  0.7683513 ,
       -0.67727995,  1.4581444 , -1.4872427 ,  1.9181707 , -1.683825  ,
       -0.9087271 , -1.2258589 , -0.31389362,  1.736255  ], dtype=float32), 'kernel': A




In [3]:
print(rng.shape)

(4, 8, 2)
