## PPO

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
import time
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Dict, Optional, Union, Tuple
from flax.training.train_state import TrainState
import distrax
import gymnax
import functools
from functools import partial
from gymnax.environments import environment, spaces
import chex
import wandb
from flax import struct


class GymnaxWrapper(object):
    """Base class for Gymnax wrappers.

    Attribute lookups that fail on the wrapper are forwarded to the wrapped
    environment, so subclasses only need to override what they change.
    """

    def __init__(self, env):
        self._env = env

    # provide proxy access to regular attributes of wrapped object
    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup fails. If ``_env`` itself is
        # missing (e.g. a bare instance being rebuilt by copy/pickle before its
        # __dict__ is restored), forwarding would call __getattr__("_env") again
        # and recurse until RecursionError; report a plain AttributeError instead.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

@struct.dataclass
class LogEnvState:
    """State of ``LogWrapper``: the wrapped env state plus episode statistics.

    ``LogWrapper.reset`` constructs this positionally, so field order matters.
    """
    env_state: environment.EnvState  # state of the wrapped environment
    episode_returns: float  # return accumulated so far in the current episode (zeroed when done)
    episode_lengths: int  # steps taken so far in the current episode (zeroed when done)
    returned_episode_returns: float  # return of the most recently finished episode
    returned_episode_lengths: int  # length of the most recently finished episode
    timestep: int  # wrapper steps since the last LogWrapper.reset (not zeroed on done)

class LogWrapper(GymnaxWrapper):
    """Log the episode returns and lengths.

    The wrapped environment's state lives in ``LogEnvState.env_state``; the
    remaining fields keep the running return/length of the current episode and
    latch the values of the most recently completed one.
    """

    def __init__(self, env: environment.Environment):
        super().__init__(env)

    @partial(jax.jit, static_argnums=(0,))
    def reset(
        self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None
    ) -> Tuple[chex.Array, environment.EnvState]:
        """Reset the wrapped env and zero all episode statistics."""
        obs, inner_state = self._env.reset(key, params)
        return obs, LogEnvState(inner_state, 0, 0, 0, 0, 0)

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: environment.EnvState,
        action: Union[int, float],
        params: Optional[environment.EnvParams] = None,
    ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        Adds ``info["returned_episode"]`` (the done flag) and
        ``info["return_info"]`` (stacked ``[timestep, returned_episode_returns]``
        of the updated state).
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action, params
        )
        running_return = state.episode_returns + reward
        running_length = state.episode_lengths + 1
        keep = 1 - done  # 0 on the step an episode finishes, 1 otherwise
        new_state = LogEnvState(
            env_state=inner_state,
            episode_returns=running_return * keep,
            episode_lengths=running_length * keep,
            returned_episode_returns=state.returned_episode_returns * keep + running_return * done,
            returned_episode_lengths=state.returned_episode_lengths * keep + running_length * done,
            timestep=state.timestep + 1,
        )
        info["returned_episode"] = done
        info["return_info"] = jnp.stack(
            [new_state.timestep, new_state.returned_episode_returns]
        )
        return obs, new_state, reward, done, info

class ScannedRNN(nn.Module):
    """GRU cell scanned over the leading (time) axis, with per-step carry resets.

    ``__call__`` is wrapped in ``nn.scan`` with ``in_axes=0``/``out_axes=0``, so
    ``x = (ins, resets)`` is consumed one slice along axis 0 at a time, and the
    GRU parameters are shared across steps (``variable_broadcast='params'``).
    """

    @functools.partial(
        nn.scan,
        variable_broadcast='params',
        in_axes=0,
        out_axes=0,
        split_rngs={'params': False})
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module for one step.

        Args:
          carry: GRU hidden state, ``(batch, hidden_size)`` as produced by
            ``initialize_carry``.
          x: tuple ``(ins, resets)``; ``ins`` is ``(batch, in_features)`` and
            ``resets`` is a per-batch flag selecting rows whose carry is
            replaced by a fresh one before the GRU update.

        Returns:
          ``(new_carry, y)``: the updated carry and this step's GRU output.
        """
        features = carry[0].shape[-1]
        rnn_state = carry
        ins, resets = x
        # The fresh carry must have the GRU's hidden width (``features``), not
        # the input width; the two only coincide when embedding size == hidden size.
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(ins.shape[0], features),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(features)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Return an initial GRU carry of shape ``(batch_size, hidden_size)``."""
        return nn.GRUCell(hidden_size, parent=None).initialize_carry(
            jax.random.PRNGKey(0), (batch_size, hidden_size))

class ActorCriticRNN(nn.Module):
    """Recurrent actor-critic: MLP encoder -> scanned GRU -> separate actor/critic MLP heads.

    Attributes:
      action_dim: number of discrete actions, or the size of a continuous
        (Box) action vector.
      config: dict; ``config["CONTINUOUS"]`` selects a diagonal Gaussian policy
        (True) or a categorical one (False).
    """
    action_dim: int
    config: Dict

    @nn.compact
    def __call__(self, hidden, x):
        """Run the network over a time-major sequence.

        Args:
          hidden: GRU carry for the first step.
          x: tuple ``(obs, dones)`` with a leading time axis; ``dones`` flags the
            steps at which the GRU carry is reset before it is applied.

        Returns:
          ``(hidden, pi, value)``: final GRU carry, a distrax policy distribution,
          and the critic output with its trailing unit axis squeezed.
        """
        obs, dones = x
        # NOTE: the order in which submodules/params are created below determines
        # their auto-generated names (Dense_0, Dense_1, ...). Do not reorder, or
        # previously trained params will no longer match this module.
        embedding = nn.Dense(128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(obs)
        embedding = nn.leaky_relu(embedding)
        embedding = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(embedding)
        embedding = nn.leaky_relu(embedding)

        rnn_in = (embedding, dones)
        hidden, embedding = ScannedRNN()(hidden, rnn_in)

        actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(embedding)
        actor_mean = nn.leaky_relu(actor_mean)
        actor_mean = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(actor_mean)
        actor_mean = nn.leaky_relu(actor_mean)
        # Small-gain init keeps the initial policy close to uniform / zero-mean.
        actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        if self.config["CONTINUOUS"]:
            # State-independent log standard deviation, one per action dimension.
            actor_logstd = self.param('log_std', nn.initializers.zeros, (self.action_dim,))
            pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logstd))
        else:
            pi = distrax.Categorical(logits=actor_mean)

        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(embedding)
        critic = nn.leaky_relu(critic)
        critic = nn.Dense(128, kernel_init=orthogonal(2), bias_init=constant(0.0))(critic)
        critic = nn.leaky_relu(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)

        return hidden, pi, jnp.squeeze(critic, axis=-1)

    @staticmethod
    def initialize_carry(batch_size, hidden_size):
        """Return an initial GRU carry of shape ``(batch_size, hidden_size)``."""
        return nn.GRUCell(hidden_size, parent=None).initialize_carry(
            jax.random.PRNGKey(0), (batch_size, hidden_size))

class Transition(NamedTuple):
    """One PPO rollout step, as emitted by the rollout ``jax.lax.scan`` in ``make_train``.

    Constructed positionally, so field order matters.
    """
    done: jnp.ndarray  # done flag that preceded ``obs`` (the rollout's ``last_done``); resets the GRU
    action: jnp.ndarray  # action sampled from the behaviour policy
    value: jnp.ndarray  # critic estimate for ``obs``
    reward: jnp.ndarray  # reward returned by the environment for ``action``
    log_prob: jnp.ndarray  # log-probability of ``action`` under the behaviour policy
    obs: jnp.ndarray  # observation the action was chosen from
    info: Dict[str, jnp.ndarray]  # env info dict (e.g. LogWrapper's "returned_episode", "return_info")

def make_train(config):
    """Build a PPO training function for a recurrent (GRU) actor-critic.

    The returned ``train(rng)`` is pure JAX and is meant to be wrapped in
    ``jax.jit``. It returns ``(runner_state, metric)``, where ``metric`` has one
    entry per update: the mean return of the episodes that finished during that
    update's rollout.

    ``config`` is mutated in place: ``NUM_UPDATES``, ``MINIBATCH_SIZE`` and
    ``CONTINUOUS`` are (re)written here.

    Keys read: ``TOTAL_TIMESTEPS``, ``NUM_STEPS``, ``NUM_ENVS``,
    ``NUM_MINIBATCHES`` (must divide ``NUM_ENVS`` for the minibatch reshape),
    ``UPDATE_EPOCHS``, ``LR``, ``ANNEAL_LR``, ``MAX_GRAD_NORM``, ``GAMMA``,
    ``GAE_LAMBDA``, ``CLIP_EPS``, ``VF_COEF``, ``ENT_COEF``, ``DEBUG``,
    ``UPDATE_VIS_PROB``, and either ``ENV_NAME`` (a gymnax id) or
    ``ENV`` + ``ENV_PARAMS``. Optional: ``MOD_ENV_PARAMS`` (overrides applied via
    ``env_params.replace``) and ``LOG`` (with ``DEBUG``, also log to wandb).
    """
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    if "ENV_NAME" in config:
        env, env_params = gymnax.make(config["ENV_NAME"])
    else:
        env, env_params = config["ENV"], config["ENV_PARAMS"]
    env = LogWrapper(env)

    if "MOD_ENV_PARAMS" in config:
        env_params = env_params.replace(**config["MOD_ENV_PARAMS"])

    config["CONTINUOUS"] = isinstance(env.action_space(env_params), spaces.Box)

    # GRU width used for the rollout/initial carries.
    hidden_size = 256

    def linear_schedule(count):
        # ``count`` counts optimizer steps; each update performs
        # NUM_MINIBATCHES * UPDATE_EPOCHS of them.
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):

        # INIT NETWORK
        if config["CONTINUOUS"]:
            network = ActorCriticRNN(env.action_space(env_params).shape[0], config=config)
        else:
            network = ActorCriticRNN(env.action_space(env_params).n, config=config)
        rng, _rng = jax.random.split(rng)
        init_x = (jnp.zeros((1, config["NUM_ENVS"], *env.observation_space(env_params).shape)), jnp.zeros((1, config["NUM_ENVS"])))
        init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], hidden_size)
        network_params = network.init(_rng, init_hstate, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), optax.adam(config["LR"], eps=1e-5))
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
        # ``init_hstate`` from above is reused as the initial rollout carry.

        # TRAIN LOOP
        def _update_step(runner_state, unused):

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, env_params, last_obs, last_done, hstate, rng = runner_state
                if config["UPDATE_VIS_PROB"]:
                    # Only valid for env params that define ``vis_prob``.
                    env_params = env_params.replace(
                        vis_prob=env_params.vis_prob - 1.0 / (config["TOTAL_TIMESTEPS"] / config["NUM_ENVS"])
                    )
                rng, _rng = jax.random.split(rng)

                # SELECT ACTION
                ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
                hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                value, action, log_prob = value.squeeze(0), action.squeeze(0), log_prob.squeeze(0)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0,0,0,None))(
                    rng_step, env_state, action, env_params
                )
                transition = Transition(last_done, action, value, reward, log_prob, last_obs, info)
                runner_state = (train_state, env_state, env_params, obsv, done, hstate, rng)
                return runner_state, transition

            # Carry at the start of the rollout; needed to replay it during the update.
            initial_hstate = runner_state[-2]
            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["NUM_STEPS"])

            # CALCULATE ADVANTAGE
            train_state, env_state, env_params, last_obs, last_done, hstate, rng = runner_state
            ac_in = (last_obs[np.newaxis, :], last_done[np.newaxis, :])
            _, _, last_val = network.apply(train_state.params, hstate, ac_in)
            last_val = last_val.squeeze(0)
            def _calculate_gae(traj_batch, last_val, last_done):
                def _get_advantages(carry, transition):
                    gae, next_value, next_done = carry
                    done, value, reward = transition.done, transition.value, transition.reward 
                    delta = reward + config["GAMMA"] * next_value * (1 - next_done) - value
                    gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - next_done) * gae
                    return (gae, value, done), gae
                _, advantages = jax.lax.scan(_get_advantages, (jnp.zeros_like(last_val), last_val, last_done), traj_batch, reverse=True, unroll=16)
                return advantages, advantages + traj_batch.value
            advantages, targets = _calculate_gae(traj_batch, last_val, last_done)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    init_hstate, traj_batch,  advantages, targets = batch_info
                    def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        # RERUN NETWORK
                        _, pi, value = network.apply(params, init_hstate[0], (traj_batch.obs, traj_batch.done))
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = jnp.clip(ratio, 1.0 - config["CLIP_EPS"], 1.0 + config["CLIP_EPS"]) * gae
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = loss_actor + config["VF_COEF"] * value_loss - config["ENT_COEF"] * entropy
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, init_hstate, traj_batch, advantages, targets, rng = update_state

                # Shuffle whole environments (axis 1), never time steps, so each
                # minibatch keeps complete sequences for the recurrent replay.
                rng, _rng = jax.random.split(rng)
                permutation = jax.random.permutation(_rng, config["NUM_ENVS"])
                batch = (init_hstate, traj_batch, advantages, targets)

                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=1), batch
                )

                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.swapaxes(jnp.reshape(
                        x, [x.shape[0], config["NUM_MINIBATCHES"], -1] + list(x.shape[2:])
                    ), 1, 0),
                    shuffled_batch,
                )

                train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
                update_state = (train_state, init_hstate, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            # Add a leading axis so the carry is split along the env axis (1)
            # exactly like the time-major trajectory arrays.
            init_hstate = initial_hstate[None,:]
            update_state = (train_state, init_hstate, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["UPDATE_EPOCHS"])
            train_state = update_state[0]
            rng = update_state[-1]

            # Mean return of the episodes that finished during this rollout
            # (0/0 -> NaN if no episode finished).
            returned = traj_batch.info["returned_episode"]
            metric = (traj_batch.info["return_info"][...,1] * returned).sum() / returned.sum()

            if config["DEBUG"]:
                if config.get("LOG"):
                    def callback(metric):
                        print(metric)
                        wandb.log({"metric": metric})
                else:
                    def callback(metric):
                        print(metric)
                jax.debug.callback(callback, metric)

            runner_state = (train_state, env_state, env_params, last_obs, last_done, hstate, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, env_params, obsv, jnp.zeros((config["NUM_ENVS"]), dtype=bool), init_hstate, _rng) 
        runner_state, metric = jax.lax.scan(_update_step, runner_state, None, config["NUM_UPDATES"])
        return runner_state, metric
    
    return train

## TSP

In [None]:
import jax
import jax.numpy as jnp
from jax import lax
from gymnax.environments import environment, spaces
from typing import Tuple, Optional
import chex
from flax import struct

@struct.dataclass
class EnvState:
    """Per-environment state of ``MetaTSP``.

    ``MetaTSP.step_env`` constructs this positionally, so field order matters.
    """
    timestep: int  # steps taken since the last reset
    cur_city : int  # index of the city the agent is currently at (0 at the start of each trial)
    city_pos: jnp.ndarray  # city coordinates, (num_cities, 2)
    visited_cities: jnp.ndarray  # int8 0/1 mask of cities visited in the current trial
    trial_num: int  # index of the current trial within the meta-episode
    trial_timestep: int  # step index within the current trial
    # OTHER NETWORK IN-CONTEXT
    other_state: Optional[jnp.ndarray] = None  # other agents' recurrent carries (network-driven mode only)
    other_act: Optional[jnp.ndarray] = None  # other agents' latest city choices, one per agent
    # NO NETWORK OTHERS
    other_temps: Optional[jnp.ndarray] = None  # softmax temperatures of the heuristic other agents

@struct.dataclass
class EnvParams:
    """Parameters of ``MetaTSP``, passed to ``reset``/``step``."""
    # OTHER NETWORK IN-CONTEXT
    other_params: Optional[Any] = None  # params pytree for the env's ``other_network`` (network mode only)
    other_init_state: Optional[jnp.ndarray] = None  # per-agent carry restored at every trial start
    other_init_last_obs: Optional[jnp.ndarray] = None  # per-agent observation fed to the other network at trial start
    temp_scale: Optional[float] = 1.0  # multiplies the uniformly sampled heuristic-agent temperatures

class MetaTSP(environment.Environment):
    """Meta-RL travelling-salesman environment (gymnax API).

    A meta-episode consists of ``num_trials`` trials on the same set of
    ``num_cities`` uniformly sampled 2-D city positions; every trial starts at
    city 0. Each step the agent picks the next city. Alongside its own last
    action it observes the next choices of ``num_agents`` other agents, which are
    shown with a probability that falls linearly from 1 (first trial) to 0
    (last trial). The other agents are either temperature-scaled softmax
    heuristics (``other_network is None``) or copies of a recurrent policy.

    Observation layout: ``[own action one-hot (num_cities), trial-ended flag,
    reward, other agents' action one-hots (num_agents * num_cities, zeroed when
    hidden), visibility flag]``.

    Args:
      num_cities: number of cities per instance.
      num_trials: trials per meta-episode.
      other_network: flax module driving the other agents, or ``None`` for the
        heuristic agents.
      init_rng: if given, city positions are drawn from this fixed key, so every
        reset yields the same instance.
      num_agents: number of other agents.
      sort_best: sort heuristic-agent temperatures in ascending order.
      other_temps: fixed heuristic-agent temperatures; sampled at reset if ``None``.
      reset_on_mistake: end the trial as soon as the agent revisits a city.
    """

    def __init__(
        self,
        num_cities=6,
        num_trials=4,
        other_network=None,
        init_rng=None,
        num_agents=1,
        sort_best=True,
        other_temps=None,
        reset_on_mistake=True,
    ):
        super().__init__()
        self.num_cities = num_cities
        self.num_trials = num_trials
        self.other_network = other_network
        self.init_rng = init_rng
        self.num_agents = num_agents
        self.sort_best = sort_best
        self.other_temps = other_temps
        self.reset_on_mistake = reset_on_mistake

    @property
    def default_params(self) -> EnvParams:
        return EnvParams()

    def step_env(
        self, key: chex.PRNGKey, state: EnvState, action: int, params: EnvParams
    ) -> Tuple[chex.Array, EnvState, float, bool, dict]:
        """Move to city ``action``; return ``(obs, new_state, reward, done, info)``.

        Reward is ``(sqrt(2) - leg length) / sqrt(2)``. On the step that ends a
        trial it also includes ``(sqrt(2) - distance back to city 0) / sqrt(2)``.
        Revisiting a city yields ``-1`` instead. A trial ends on a revisit (if
        ``reset_on_mistake``), once every city has been visited, or after
        ``num_cities`` steps; the meta-episode ends after ``num_trials`` trials.
        """
        visited_after = state.visited_cities.at[action].set(1)

        reward = jnp.sqrt(2) - jnp.linalg.norm(state.city_pos[state.cur_city] - state.city_pos[action], axis=-1)
        trial_terminated = jnp.logical_and(state.visited_cities[action] == 1, self.reset_on_mistake)
        # The tour is complete as soon as the last unvisited city is reached. Waiting
        # one more step would make the final action a guaranteed revisit, whose -1
        # penalty would overwrite the return-to-depot reward computed below.
        trial_terminated = jnp.logical_or(trial_terminated, jnp.all(visited_after == 1))
        # Hard cap on trial length (matters when reset_on_mistake is False).
        trial_terminated = jnp.logical_or(
            trial_terminated,
            state.trial_timestep == self.num_cities - 1,
        )
        reward = jnp.where(trial_terminated, reward + jnp.sqrt(2) - jnp.linalg.norm(
            state.city_pos[action] - state.city_pos[0], axis=-1
        ), reward)  # ON THE LAST STEP OF A TRIAL, ADD THE RETURN-TO-DEPOT REWARD
        reward = reward / jnp.sqrt(2)  # NORMALIZE REWARD
        reward = jnp.where(state.visited_cities[action] == 0, reward, -1.0)

        next_trial_timestep = jnp.where(trial_terminated, 0, state.trial_timestep + 1)
        next_trial_num = jnp.where(trial_terminated, state.trial_num + 1, state.trial_num)
        next_city = jnp.where(trial_terminated, 0, action)
        reset_visited = jnp.zeros((self.num_cities,), dtype=jnp.int8)
        reset_visited = reset_visited.at[0].set(1)
        next_visited = jnp.where(trial_terminated, reset_visited, visited_after)
        terminated = jnp.logical_and(
            trial_terminated,
            next_trial_num == self.num_trials,
        )

        if self.other_network is None:
            # Heuristic agents: softmax over (-distance from city 0 - penalty for
            # visited cities), divided by each agent's temperature.
            # NOTE(review): distances are measured from city 0, not from the
            # agent's current city; confirm this is intended.
            key, key_act = jax.random.split(key)
            dists = jnp.linalg.norm(state.city_pos[0, None, :] - state.city_pos, axis=-1)
            logits = -dists - next_visited * jnp.sqrt(2)
            logits = logits[None,...] / state.other_temps[:,None]
            key_acts = jax.random.split(key_act, self.num_agents)
            other_act = jax.vmap(jax.random.categorical)(key_acts, logits)
            other_state = None
        else:
            # Network agents observe this agent's action/reward (with zeroed
            # "other agents" slots); on trial end they restart from the saved
            # initial carry and observation.
            key, key_other = jax.random.split(key)
            other_state = jnp.where(trial_terminated, params.other_init_state, state.other_state)
            other_obs = jnp.concatenate([
                jax.nn.one_hot(action, self.num_cities),
                trial_terminated[None,],
                reward[None,],
                jnp.zeros(self.num_agents*self.num_cities+1,),
            ], axis=-1)[None,:]
            other_obs = jnp.tile(other_obs, (self.num_agents, 1))
            other_obs = jnp.where(trial_terminated, params.other_init_last_obs, other_obs)

            ac_in = [other_obs[:,None,None,:], jnp.zeros((self.num_agents, 1, 1))]
            other_state, pi, _ = jax.vmap(self.other_network.apply, in_axes=(None,0,0))(params.other_params, other_state, ac_in)
            other_act = pi.sample(seed=key_other).squeeze(1).squeeze(1)

        own_act = jax.nn.one_hot(action, self.num_cities)

        new_state = EnvState(
            state.timestep + 1,
            next_city,
            state.city_pos,
            next_visited,
            next_trial_num,
            next_trial_timestep,
            other_state,
            other_act,
            state.other_temps,
        )

        # Other agents' actions are visible with probability that decreases
        # linearly with the trial index.
        key, vis_key = jax.random.split(key)
        other_vis = jax.random.uniform(vis_key) < (self.num_trials - state.trial_num - 1) / (self.num_trials-1)  # TODO: PLR?
        other_act = jax.nn.one_hot(other_act, self.num_cities)
        other_act = jnp.where(other_vis, other_act, jnp.zeros_like(other_act)).reshape((-1,))
        obs = jnp.concatenate([own_act, trial_terminated[None,], reward[None,], other_act, other_vis[None,]], axis=-1)
        return obs, new_state, reward, terminated, {}

    def reset_env(
        self, key: chex.PRNGKey, params: EnvParams
    ) -> Tuple[chex.Array, EnvState]:
        """Performs resetting of environment.

        Samples city positions (from ``init_rng`` when set), marks city 0 as
        visited, and draws the other agents' first actions.
        """
        if self.init_rng is None:
            key, key_seq = jax.random.split(key)
            city_pos = jax.random.uniform(key_seq, (self.num_cities, 2))
        else:
            city_pos = jax.random.uniform(self.init_rng, (self.num_cities, 2))

        dists = jnp.linalg.norm(city_pos[0, None, :] - city_pos, axis=-1)
        visited_cities = jnp.zeros((self.num_cities,), dtype=jnp.int8)
        visited_cities = visited_cities.at[0].set(1)

        if self.other_network is None:
            if self.other_temps is not None:
                other_temps = self.other_temps
            else:
                key, key_prob = jax.random.split(key)
                key_probs = jax.random.split(key_prob, self.num_agents)
                other_temps = jax.vmap(jax.random.uniform)(key_probs) * params.temp_scale
            if self.sort_best:
                other_temps = jnp.sort(other_temps)
            key, key_act = jax.random.split(key)
            logits = -dists - visited_cities * jnp.sqrt(2)
            logits = logits[None,...] / other_temps[:,None]
            key_acts = jax.random.split(key_act, self.num_agents)
            other_act = jax.vmap(jax.random.categorical)(key_acts, logits)

            # CHECK SHAPES AND OUTPUTS
            other_state = None
        else:
            key, key_other = jax.random.split(key)
            other_obs = params.other_init_last_obs
            ac_in = [other_obs[:,None,None,:], jnp.zeros((self.num_agents, 1, 1))]
            other_state, pi, _ = jax.vmap(self.other_network.apply, in_axes=(None, 0, 0))(params.other_params, params.other_init_state, ac_in)
            other_act = pi.sample(seed=key_other).squeeze(1).squeeze(1)
            other_temps = None

        state = EnvState(
            timestep=0,
            cur_city=0,
            city_pos=city_pos,
            visited_cities=visited_cities,
            trial_num=0,
            trial_timestep=0,
            other_state=other_state,
            other_act=other_act,
            other_temps=other_temps,
        )

        key, vis_key = jax.random.split(key)
        other_vis = jax.random.uniform(vis_key) < (self.num_trials - state.trial_num - 1) / (self.num_trials - 1)
        other_act = jax.nn.one_hot(other_act, self.num_cities)
        other_act = jnp.where(other_vis, other_act, jnp.zeros_like(other_act)).reshape((-1,))
        obs = jnp.concatenate([jnp.zeros(self.num_cities,), jnp.zeros((2,)), other_act, other_vis[None,]], axis=-1)
        return obs, state

    def action_space(
        self, params: Optional[EnvParams] = None
    ) -> spaces.Discrete:
        """Action space of the environment: the index of the next city."""
        return spaces.Discrete(self.num_cities)

    def observation_space(self, params: EnvParams) -> spaces.Box:
        """Observation space of the environment (see class docstring for layout)."""
        return spaces.Box(jnp.zeros((self.num_cities*(self.num_agents+1)+3,)), jnp.ones((self.num_cities*(self.num_agents+1)+3,)), (self.num_cities*(self.num_agents+1)+3,), dtype=jnp.float32)



In [None]:
# PPO hyperparameters for meta-training on MetaTSP.
config = dict(
    LR=2.5e-4,
    NUM_ENVS=16,
    NUM_STEPS=128,
    TOTAL_TIMESTEPS=8e6,
    UPDATE_EPOCHS=2,
    NUM_MINIBATCHES=8,
    GAMMA=0.99,
    GAE_LAMBDA=0.95,
    CLIP_EPS=0.2,
    ENT_COEF=0.05,
    VF_COEF=0.5,
    MAX_GRAD_NORM=0.5,
    ANNEAL_LR=True,
    DEBUG=True,
    CONTINUOUS=False,
    UPDATE_VIS_PROB=False,
)

# Problem size: cities per instance, trials per meta-episode, number of other agents.
num_cities, num_trials, popsize = 6, 8, 3

env = MetaTSP(num_cities=num_cities, num_trials=num_trials, num_agents=popsize)
config["ENV"] = env
config["ENV_PARAMS"] = EnvParams(temp_scale=0.5)

jit_train = jax.jit(make_train(config))
rng = jax.random.PRNGKey(64)
rng, _rng = jax.random.split(rng)
outs = jit_train(_rng)


In [None]:
# ICL GENERATION EVAL
# Both evaluation envs share one fixed city layout (init_rng).
init_rng = jax.random.PRNGKey(420)
_shared_env_kwargs = dict(
    num_cities=num_cities,
    num_trials=num_trials,
    num_agents=popsize,
    init_rng=init_rng,
)

# Seeding env: heuristic other agents with fixed temperature 10.
init_env = MetaTSP(**_shared_env_kwargs, other_temps=jnp.ones((popsize,)) * 10)

# Same architecture as the trained policy; also drives the other agents below.
network = ActorCriticRNN(num_cities, config=config)

# Generation env: the other agents are copies of ``network``.
env = MetaTSP(**_shared_env_kwargs, other_network=network)

init_env_params = EnvParams()


In [None]:
def init_eval_params_and_return(rng, params, env_params):
    """Roll out ``init_env`` (heuristic other agents) with policy ``params``.

    Runs ``init_env.num_cities * init_env.num_trials + 1`` steps and returns the
    stacked per-step tuple
    ``(env_state, action, reward, replace_hstate, running_r, running_hstate, running_obs)``:

    - ``running_r``: per-trial sum of positive rewards, accumulated only until
      the first meta-episode ends;
    - ``running_hstate`` / ``running_obs``: the policy carry / observation as
      last updated while the pre-step state was in trial ``num_trials - 2`` of
      the first meta-episode;
    - ``replace_hstate``: whether that update happened at this step.
    """
    # Sizes come from the env that is actually rolled out.
    num_trials = init_env.num_trials
    num_steps = init_env.num_cities * num_trials + 1

    rng, _rng = jax.random.split(rng)
    last_obs, env_state = init_env.reset(_rng, env_params)

    init_state = ScannedRNN.initialize_carry(1, 256)

    # COLLECT TRAJECTORIES
    def _env_step(runner_state, unused):
        prev_env_state, env_params, last_obs, last_done, prev_hstate, rng, ever_done, running_r, running_hstate, running_obs = runner_state
        rng, _rng = jax.random.split(rng)

        # SELECT ACTION
        ac_in = (last_obs[None, None, :], last_done[None, None])
        hstate, pi, value = network.apply(params, prev_hstate, ac_in)
        action = pi.sample(seed=_rng).squeeze(0).squeeze(0)

        # STEP ENV
        rng, _rng = jax.random.split(rng)
        obsv, env_state, reward, done, info = init_env.step(
            _rng, prev_env_state, action, env_params
        )

        # ONLY GET REWARDS FROM FIRST META-EPISODE (negative rewards are ignored)
        temp_r = running_r[prev_env_state.trial_num] + (nn.relu(reward)) * (1.0 - ever_done)
        running_r = running_r.at[prev_env_state.trial_num].set(temp_r)

        # TRACK HIDDEN STATE AND OBS THROUGH THE SECOND-TO-LAST TRIAL
        replace_hstate = jnp.logical_and((prev_env_state.trial_num == num_trials - 2), ~ever_done)
        running_hstate = running_hstate * (1.0 - replace_hstate) + hstate * replace_hstate
        running_obs = running_obs * (1.0 - replace_hstate) + obsv * replace_hstate

        ever_done = jnp.logical_or(done, ever_done)
        transition = (env_state, action, reward, replace_hstate, running_r, running_hstate, running_obs)
        runner_state = (env_state, env_params, obsv, done, hstate, rng, ever_done, running_r, running_hstate, running_obs)
        return runner_state, transition

    runner_state = (env_state, env_params, last_obs, False, init_state, rng, False, jnp.zeros((num_trials,)), init_state, last_obs)
    runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, num_steps)
    return traj_batch

def eval_params_and_return(rng, params, env_params):
    """Roll out ``env`` (network-driven other agents) with policy ``params``.

    Runs ``env.num_cities * env.num_trials + 1`` steps and returns the stacked
    per-step tuple
    ``(env_state, action, reward, running_r, running_hstate, running_obs)``:
    ``running_r`` is the per-trial sum of positive rewards of the first
    meta-episode; ``running_hstate`` / ``running_obs`` are the policy carry /
    observation as last updated while the pre-step state was in trial
    ``env.num_trials - 2`` of that meta-episode.
    """
    rng, reset_key = jax.random.split(rng)
    first_obs, first_state = env.reset(reset_key, env_params)
    zero_carry = ScannedRNN.initialize_carry(1, 256)
    capture_trial = env.num_trials - 2

    def _env_step(carry, _):
        (state, e_params, obs, done_prev, hidden, key,
         finished, trial_returns, kept_hidden, kept_obs) = carry

        # Policy step on a (time=1, batch=1) input.
        key, act_key = jax.random.split(key)
        net_in = (obs[None, None, :], done_prev[None, None])
        hidden_next, policy, _value = network.apply(params, hidden, net_in)
        act = policy.sample(seed=act_key).squeeze(0).squeeze(0)

        # Environment step.
        key, step_key = jax.random.split(key)
        obs_next, state_next, rew, done, _info = env.step(step_key, state, act, e_params)

        # Sum positive rewards per trial, frozen once the first meta-episode ends.
        trial = state.trial_num
        trial_returns = trial_returns.at[trial].set(
            trial_returns[trial] + nn.relu(rew) * (1.0 - finished)
        )

        # Follow the carry/obs through the second-to-last trial of the first meta-episode.
        capture = jnp.logical_and(trial == capture_trial, ~finished)
        kept_hidden = kept_hidden * (1.0 - capture) + hidden_next * capture
        kept_obs = kept_obs * (1.0 - capture) + obs_next * capture

        finished = jnp.logical_or(done, finished)
        outputs = (state_next, act, rew, trial_returns, kept_hidden, kept_obs)
        next_carry = (state_next, e_params, obs_next, done, hidden_next, key,
                      finished, trial_returns, kept_hidden, kept_obs)
        return next_carry, outputs

    start = (first_state, env_params, first_obs, False, zero_carry, rng, False,
             jnp.zeros((env.num_trials,)), zero_carry, first_obs)
    _, stacked = jax.lax.scan(_env_step, start, None, env.num_cities * env.num_trials + 1)
    return stacked



In [None]:
# Seed generation 0: one rollout per population member against the heuristic agents.
rng = jax.random.PRNGKey(12)
_batched_init_eval = jax.vmap(init_eval_params_and_return, in_axes=(0, None, None))
test = jax.jit(_batched_init_eval)
_seed_keys = jax.random.split(rng, popsize)
_trained_params = outs[0][0].params  # params of the trained TrainState
infos = test(_seed_keys, _trained_params, init_env_params)

In [None]:
# Iterated in-context learning: each generation's other agents are seeded with the
# previous generation's carries/observations, ordered best-first.
num_generations = 8
trial_scores = []
trained_params = outs[0][0].params
# Build the jitted evaluator once: re-wrapping jax.vmap inside the loop creates a
# new function object each iteration and forces a recompile every generation.
test = jax.jit(jax.vmap(eval_params_and_return, in_axes=(0, None, None)))
for i in range(num_generations):
    # infos[-3]: per-trial returns; infos[-2]: tracked carries; infos[-1]: tracked obs.
    last_trial_scores = infos[-3][:, -1, -1]
    # Descending order (best last-trial score first), presumably mirroring
    # sort_best's lowest-temperature-first ordering used in pre-training.
    last_trial_scores_idx = jnp.argsort(-last_trial_scores)
    saved_scores = infos[-3][:, -1, :][last_trial_scores_idx]
    saved_states = infos[-2][:, -1, :][last_trial_scores_idx]
    saved_obs = infos[-1][:, -1, :][last_trial_scores_idx]
    trial_scores.append(saved_scores.mean(0))
    print(saved_scores.mean(0))

    env_params = EnvParams(
        other_params=trained_params,
        other_init_state=saved_states,
        other_init_last_obs=saved_obs,
    )

    rng, _rng = jax.random.split(rng)
    infos = test(jax.random.split(_rng, popsize), trained_params, env_params)

# The loop scores the rollout from the *previous* iteration, so score the final
# generation's rollout here (the mean over agents does not depend on ordering).
trial_scores.append(infos[-3][:, -1, :].mean(0))

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# One viridis colour per recorded curve. Sizing this by num_generations made
# zip() silently drop any extra entries of trial_scores.
num_curves = len(trial_scores)
colors = cm.viridis([i / max(num_curves - 1, 1) for i in range(num_curves)])

for i, (m, color) in enumerate(zip(trial_scores, colors)):
    # Generation i occupies x positions [i * num_trials, (i + 1) * num_trials).
    plt.plot(jnp.arange(i * num_trials, (i + 1) * num_trials), m, label=f"Generation {i}", color=color, alpha=0.5)

plt.legend()
plt.ylabel("Return")
plt.xlabel("Trial")
plt.title("MetaTSP")
plt.show()
