Dependencies

In [1]:
!pip install gymnax
!pip install distrax
!pip install ogbench
!pip install git+https://github.com/riiswa/pointax.git
!pip install ml_collections

Collecting gymnax
  Downloading gymnax-0.0.9-py3-none-any.whl.metadata (19 kB)
Downloading gymnax-0.0.9-py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.6/86.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gymnax
Successfully installed gymnax-0.0.9
Collecting distrax
  Downloading distrax-0.1.5-py3-none-any.whl.metadata (13 kB)
Downloading distrax-0.1.5-py3-none-any.whl (319 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.7/319.7 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: distrax
Successfully installed distrax-0.1.5
Collecting ogbench
  Downloading ogbench-1.1.5-py3-none-any.whl.metadata (946 bytes)
Collecting mujoco>=3.1.6 (from ogbench)
  Downloading mujoco-3.3.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.3 MB/s[0m eta [36m0:

Wrapper

In [2]:
from gymnax.environments import environment, spaces
from pointax.types import EnvState, EnvParams
import jax

class PMwrapper(environment.Environment[EnvState, EnvParams]):
    """Gymnax-style wrapper around a pointax PointMaze environment that hides the goal.

    The wrapped environment's observation is cut to its first ``_OBS_DIM``
    entries; on ``step_env`` the remaining entries are exposed through
    ``info["goal_position"]`` instead.

    NOTE(review): ``default_params``, ``name`` and ``num_actions`` are plain
    methods here (the notebook calls ``env.default_params()``), so they must be
    invoked with parentheses.
    """

    # Number of leading observation entries kept by the wrapper.
    _OBS_DIM = 4

    def __init__(self, pm_env):
        """Store the pointax environment that every call is delegated to."""
        super().__init__()
        self.pm_env = pm_env

    def default_params(self):
        """Return the wrapped environment's ``default_params``."""
        return self.pm_env.default_params

    def step_env(self, key, state, action, params):
        """Step the wrapped env; return the truncated observation and put the rest in ``info``."""
        obs, state, reward, done, info = self.pm_env.step_env(key, state, action, params)
        # Everything after the kept prefix is reported as the goal position.
        info["goal_position"] = obs[self._OBS_DIM:]
        return obs[:self._OBS_DIM], state, reward, done, info

    def reset_env(self, key, params):
        """Reset the wrapped env and return ``(truncated_obs, state)``."""
        obs, state = self.pm_env.reset_env(key, params)
        return obs[:self._OBS_DIM], state

    def get_obs(self, state, params=None, key=None):
        """Return the truncated observation for ``state`` (``key`` is accepted but unused)."""
        return self.pm_env.get_obs(state, params)[:self._OBS_DIM]

    def name(self):
        """Return an identifier built from the wrapped env's maze id and reward type."""
        suffix = "Dense" if self.pm_env.reward_type_str == "dense" else ""
        return f"pointax/PointMaze_{self.pm_env.maze_id}{suffix}"

    def num_actions(self):
        """Return the action dimensionality (hard-coded to 2)."""
        return 2

    def action_space(self, params=None):
        """Delegate to the wrapped environment's action space."""
        return self.pm_env.action_space(params)

    def observation_space(self, params):
        """Return an unbounded float32 box matching the truncated observation.

        ``jax.numpy`` is referenced through ``jax`` because the ``jnp`` alias is
        only imported in a later cell of the notebook.
        """
        return spaces.Box(
            low=-jax.numpy.inf,
            high=jax.numpy.inf,
            shape=(self._OBS_DIM,),
            dtype=jax.numpy.float32,
        )


In [3]:
# Smoke test: build the wrapped U-maze, reset it, take one step and print the result.
import pointax
env = PMwrapper(pointax.make_umaze(reward_type="sparse"))
# PMwrapper defines default_params as a method, hence the call.
params = env.default_params()

# Reset and step
key = jax.random.PRNGKey(42)
obs, state = env.reset_env(key, params)

action = jax.numpy.array([0.5, 0.0])  # Move right
# NOTE(review): the same `key` is passed to both reset_env and step_env; JAX
# guidance is to split a key before reusing it -- confirm this is acceptable here.
obs, state, reward, done, info = env.step_env(key, state, action, params)

# The wrapper truncates observations, so the printed shape is (4,).
print(obs, obs.shape)
print(f"Reward: {reward}, Success: {info['is_success']}")

[1.0201304  0.90668106 0.5        0.        ] (4,)
Reward: 0.0, Success: False


# Exploration

In [4]:
from gymnax.experimental import RolloutWrapper
# action = self.model_forward(policy_params, obs, rng_net)
import functools
import gymnax
from typing import Union,Optional,Any
import abc

import jax
import jax.numpy as jnp
from flax import nnx
import pointax

class UnsupervisedExplorer(nnx.Module):
    """Interface for the exploration policies used by the rollout wrappers in this notebook.

    NOTE(review): the methods are marked with ``abc.abstractmethod`` but the
    class does not visibly use ``abc.ABCMeta``; whether incomplete subclasses
    are actually rejected depends on ``nnx.Module``'s metaclass -- confirm.
    """

    @abc.abstractmethod
    def update(self,obs,actions,next_obs,dones,info):
        """Update the explorer's non-gradient state from a batch of transitions.

        NOTE(review): every subclass in this notebook, and the caller
        ``UnsupervisedRolloutWrapper.batch_update``, use the signature
        ``update(rng, obs, actions, next_obs, dones, info)`` with a leading
        ``rng``; the signature declared here omits it.

        The base implementation returns ``None``; subclasses return a dict of
        metrics (possibly empty).
        """
        # update variable parameters
        return #{"kl":KL} MI = E KL

    @abc.abstractmethod
    def __call__(self,observations,rng):
        """Choose actions for ``observations`` using the PRNG key ``rng``.

        The rollout wrapper unpacks the result as ``(actions, info_dict)``.
        The base implementation returns ``None``.
        """

        return #actions, {"mi":mi_matrix}

from gymnax.environments.environment import Environment
class CustomRolloutWrapper:
    """Runs (batched) rollouts of an exploration model in a gymnax-style environment."""

    def __init__(
        self,
        env_or_name: Union[str,Environment] = "Pendulum-v1",
        num_env_steps: Optional[int] = None,
        env_kwargs: Any | None = None,
        env_params: Any | None = None,
    ):
        """Set up the environment and its parameters.

        Args:
            env_or_name: an ``Environment`` instance to use directly. Any other
                value (including the default string) is ignored and a sparse
                U-maze ``PMwrapper`` is created instead.
            num_env_steps: stored as ``self.num_env_steps``; defaults to the
                environment parameters' ``max_steps_in_episode``.
            env_kwargs: accepted for API compatibility; currently unused.
            env_params: mapping of parameter overrides applied with
                ``params.replace(**env_params)``.
        """
        if env_params is None:
            env_params = {}
        if isinstance(env_or_name, Environment):
            self.env = env_or_name
        else:
            # Umaze
            self.env = PMwrapper(pointax.make_umaze(reward_type="sparse"))

        # `default_params` is a property on stock gymnax environments but a plain
        # method on PMwrapper; support both so a PMwrapper instance can be passed in.
        default_params = self.env.default_params
        if callable(default_params):
            default_params = default_params()
        self.env_params = default_params.replace(**env_params)

        if num_env_steps is None:
            self.num_env_steps = self.env_params.max_steps_in_episode
        else:
            self.num_env_steps = num_env_steps

    def batch_reset(self,rng_input):
        """Reset one environment per key in ``rng_input`` and return the batched states."""
        return jax.vmap(self.single_reset_state)(rng_input)

    def single_reset_state(self,rng_input):
        """Reset the environment with a key derived from ``rng_input`` and return its state."""
        # Split (rather than using rng_input directly) to keep the same key
        # derivation as single_rollout.
        rng_reset, _ = jax.random.split(rng_input)
        _, state = self.env.reset(rng_reset, self.env_params)
        return state

    def batch_rollout(self, rng_eval, model:UnsupervisedExplorer,
                      env_state=None,num_steps=1):
        """Run ``single_rollout`` vmapped over the leading axis of ``rng_eval`` and ``env_state``.

        ``model`` and ``num_steps`` are shared by all rollouts.
        """
        batched = jax.vmap(self.single_rollout, in_axes=(0, None, 0, None))
        return batched(rng_eval, model, env_state, num_steps)

    def single_rollout(self, rng_input, model:UnsupervisedExplorer,
                       env_state=None,num_steps=1):
        """Roll the environment forward ``num_steps`` steps with ``jax.lax.scan``.

        Args:
            rng_input: PRNG key for the rollout.
            model: callable ``model(obs, rng) -> (action, info)``; when ``None``
                actions are sampled from the environment's action space.
            env_state: state to continue from; when ``None`` the environment is reset.
            num_steps: number of scan iterations.

        Returns:
            ``(obs, action, reward, next_obs, done, state, info, cum_return)``
            where all but the last are stacked over the step axis, ``info``
            additionally holds the final environment state under
            ``"last_state"``, and ``cum_return`` is the reward accumulated
            until the first ``done``.
        """
        rng_reset, rng_episode = jax.random.split(rng_input)

        if env_state is None:
            obs, env_state = self.env.reset(rng_reset, self.env_params)
        else:
            obs = self.env.get_obs(env_state)

        def policy_step(state_input, _):
            """lax.scan compatible step transition in jax env."""
            obs, state, rng, cum_reward, valid_mask = state_input
            rng, rng_step, rng_net = jax.random.split(rng, 3)
            if model is not None:
                action, info = model(obs, rng_net)
            else:
                action = self.env.action_space(self.env_params).sample(rng_net)
                info = {}
            next_obs, next_state, reward, done, step_info = self.env.step(
                rng_step, state, action, self.env_params
            )
            # Merge into a fresh dict so the dict returned by the model is not mutated.
            info = {**info, **step_info}
            # Rewards stop counting once an episode has terminated.
            new_cum_reward = cum_reward + reward * valid_mask
            new_valid_mask = valid_mask * (1 - done)
            carry = [
                next_obs,
                next_state,
                rng,
                new_cum_reward,
                new_valid_mask,
            ]
            y = [obs, action, reward, next_obs, done, state, info]
            return carry, y

        # Scan over episode step loop
        carry_out, scan_out = jax.lax.scan(
            policy_step,
            [
                obs,
                env_state,
                rng_episode,
                jnp.array([0.0]),
                jnp.array([1.0]),
            ],
            (),
            num_steps,
        )
        obs, action, reward, next_obs, done, state, info = scan_out
        cum_return = carry_out[-2]
        # Expose the final state so a later rollout can continue from it.
        info["last_state"] = carry_out[1]
        return obs, action, reward, next_obs, done,state, info, cum_return

class UnsupervisedRolloutWrapper(CustomRolloutWrapper):
    """Rollout wrapper that can also forward collected transitions to the explorer."""

    def batch_update(self, rng_update,model, obs, action,next_obs,done,info):
        """Forward a batch of transitions to ``model.update``.

        Returns ``model.update(...)``'s result, or an empty dict when no model
        is given (random-policy rollouts).
        """
        if model is not None:
            return model.update(rng_update, obs, action, next_obs, done, info)
        return {}


In [5]:
import jax
import jax.numpy as jnp
import jax.nn as nn
from flax import nnx

jnp.set_printoptions(precision=3,suppress=True)
from flax.training import train_state
from jax.scipy.special import gamma,digamma, gammaln, kl_div

def batch_random_split(batch_key,num=2):
    """Split every key of a batch into ``num`` keys.

    Returns a Python list of length ``num``; entry ``i`` holds the ``i``-th
    sub-key of every key in ``batch_key``.
    """
    split_batched = jax.vmap(jax.random.split, in_axes=(0, None))
    per_key = split_batched(batch_key, num)
    out = []
    for idx in range(num):
        out.append(per_key[:, idx])
    return out
@jax.jit
def compute_info_gain_normal(mean,prec,l_prec, next_obs):
    """Per-sample information-gain term for a diagonal Gaussian belief update.

    Model described by the original notes: prior ``N(u; mean, prec^-1)`` and
    likelihood ``N(next_obs; u, l_prec^-1)``, all element-wise.

    Args:
        mean: (batch, obs_dim) prior mean.
        prec: (batch, obs_dim) prior precision; floored at 1e-6.
        l_prec: (batch, obs_dim) likelihood precision.
        next_obs: (batch, obs_dim) observed next state.

    Returns:
        ``(kl, delta_mean)`` where ``kl`` sums the per-dimension terms over the
        last axis and ``delta_mean = next_obs - posterior_mean``.

    NOTE(review): the quadratic term uses ``next_obs - posterior_mean``; a
    textbook KL(posterior || prior) would use ``posterior_mean - mean``.
    Confirm this is intended.
    """
    prior_prec = jnp.maximum(prec, 1e-6)
    post_prec = prior_prec + l_prec

    # Precision-weighted combination of the prior mean and the observation.
    post_mean = (prior_prec * mean + next_obs * l_prec) / post_prec
    delta_mean = next_obs - post_mean

    ratio = prior_prec / post_prec
    per_dim = delta_mean * delta_mean * prior_prec
    per_dim = per_dim + ratio - jnp.log(ratio) - 1
    return 0.5 * jnp.sum(per_dim, axis=-1), delta_mean

@jax.jit
def compute_expected_info_gain_normal(prec,l_prec):
    """Expected information gain ``0.5 * sum(log(1 + l_prec / prec))`` over the last axis.

    Args:
        prec: (batch, obs_dim) prior precision; floored at 1e-6.
        l_prec: (batch, obs_dim) likelihood precision.

    Returns:
        Array of shape (batch).
    """
    safe_prec = jnp.maximum(prec, 1e-6)
    gain_per_dim = jnp.log(1 + l_prec / safe_prec)
    return 0.5 * jnp.sum(gain_per_dim, axis=-1)


class JointEncoder(nnx.Module):
    """Noisy residual encoder over a ``hidden_dims``-wide embedding."""

    def __init__(self, hidden_dims: int, rngs: nnx.Rngs):
        """Create two square linear layers followed by four layer norms."""
        self.linear1 = nnx.Linear(hidden_dims, hidden_dims, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dims, hidden_dims, rngs=rngs)
        # layer_norm0 .. layer_norm3, created in index order.
        for idx in range(4):
            setattr(self, f"layer_norm{idx}", nnx.LayerNorm(hidden_dims, rngs=rngs))

    def __call__(self, x: jax.Array,rng):
        """Add Gaussian noise (scale 0.1) to ``x`` and encode it.

        ``rng`` seeds the noise sample.
        """
        noise_scale = 1e-1 * jnp.ones_like(x)
        noisy = distrax.MultivariateNormalDiag(x, noise_scale).sample(seed=rng, sample_shape=())

        pre1 = self.linear1(self.layer_norm0(noisy))
        # First residual: normalised activation plus the pre-activation.
        hidden = self.layer_norm1(nn.relu(pre1)) + pre1
        pre2 = self.linear2(hidden)
        # Second residual: normalised input of linear2 plus its output.
        hidden = self.layer_norm2(hidden) + pre2
        return self.layer_norm3(hidden)

class Encoder(nnx.Module):
    """Linear projection from ``input_dim`` to ``hidden_dims`` followed by a layer norm."""

    def __init__(self, input_dim: int, hidden_dims: int, rngs: nnx.Rngs):
        """Create the projection and its normalisation layer."""
        self.linear = nnx.Linear(input_dim, hidden_dims, rngs=rngs)
        self.layer_norm0 = nnx.LayerNorm(hidden_dims, rngs=rngs)

    def __call__(self, x: jax.Array):
        """Return the layer-normalised projection of ``x``."""
        return self.layer_norm0(self.linear(x))

class ActionEncoder(nnx.Module):
    """Embedding table over ``num_actions`` indices followed by a layer norm."""

    def __init__(self, num_actions: int, hidden_dims: int, rngs: nnx.Rngs):
        """Create the embedding table and its normalisation layer."""
        self.embed = nnx.Embed(num_actions, hidden_dims, rngs=rngs)
        self.layer_norm0 = nnx.LayerNorm(hidden_dims, rngs=rngs)

    def __call__(self, x: jax.Array):
        """Look up the embeddings for ``x`` and layer-normalise them."""
        embedded = self.embed(x)
        return self.layer_norm0(embedded)

from jax import lax
import distrax


class MLP(nnx.Module):
    """Two-layer perceptron with a ReLU between the layers."""

    def __init__(self, input_dim, hidden_dim, output_dim, rngs):
        """Create the input->hidden and hidden->output linear layers."""
        self.linear1 = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dim, output_dim, rngs=rngs)

    def __call__(self, x):
        """Return ``linear2(relu(linear1(x)))``."""
        hidden = self.linear1(x)
        return self.linear2(jax.nn.relu(hidden))

class Actor(nnx.Module):
    """Gaussian policy head: a shared linear layer feeding mean and log-std heads.

    No non-linearity is applied between the shared layer and the heads.
    """

    # Bounds applied to the predicted log standard deviation.
    log_std_min: float = -4
    log_std_max: float = 2

    def __init__(self, obs_dim, action_dim,hidden_dim, rngs: nnx.Rngs):
        """Create the shared layer and the two ``action_dim``-wide heads."""
        self.linear = nnx.Linear(obs_dim, hidden_dim, rngs=rngs)
        self.mean = nnx.Linear(hidden_dim, action_dim, rngs=rngs)
        self.log_std = nnx.Linear(hidden_dim, action_dim, rngs=rngs)

    def __call__(self, x: jnp.ndarray):
        """Return a diagonal Gaussian action distribution for ``x``."""
        features = self.linear(x)
        loc = self.mean(features)
        bounded_log_std = jnp.clip(
            self.log_std(features), self.log_std_min, self.log_std_max
        )
        return distrax.MultivariateNormalDiag(loc, jnp.exp(bounded_log_std))

class Critic(nnx.Module):
  """Single-output critic: linear layer, ReLU, then a scalar head.

  The original constructor called ``nnx.relu(hidden_dim, 1, rngs=rngs)``, i.e.
  it invoked the activation function as if it were a layer constructor, which
  raises as soon as the class is instantiated. The arguments indicate a
  ``hidden_dim -> 1`` layer was intended, so that layer is created here and the
  ReLU is applied in ``__call__``.

  NOTE(review): the first layer expects ``hidden_dim`` input features (as in
  the original); ``obs_dim`` and ``action_dim`` are accepted but unused --
  confirm the intended input width.
  """

  def __init__(self, obs_dim, action_dim,hidden_dim, rngs: nnx.Rngs):
    """Create the hidden layer and the scalar output head."""
    self.ln = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.out = nnx.Linear(hidden_dim, 1, rngs=rngs)

  def __call__(self, x: jnp.array):
    """Return the critic value with a trailing axis of size 1."""
    hidden = nnx.relu(self.ln(x))
    return self.out(hidden)

class TwinCritic(nnx.Module):
    """Two Q-heads on top of a shared two-layer ReLU trunk."""

    def __init__(self, input_dim, hidden_dim, rngs: nnx.Rngs):
        """Create the shared trunk and the two scalar heads."""
        self.trunk1 = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
        self.trunk2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        self.q1 = nnx.Linear(hidden_dim, 1, rngs=rngs)
        self.q2 = nnx.Linear(hidden_dim, 1, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(q1, q2)``, each with a trailing axis of size 1."""
        features = x
        for trunk_layer in (self.trunk1, self.trunk2):
            features = jax.nn.relu(trunk_layer(features))
        return self.q1(features), self.q2(features)


class Likelihood_Prec(nnx.Module):
    """Maps a ``hidden_dim`` embedding to a positive per-dimension likelihood precision."""

    # Bounds applied to the predicted log standard deviation.
    log_std_min: float = -2
    log_std_max: float = 2

    def __init__(self, obs_dim, hidden_dim, rngs: nnx.Rngs):
        """Create the single ``hidden_dim -> obs_dim`` linear layer."""
        self.linear = nnx.Linear(hidden_dim, obs_dim, rngs=rngs)

    def __call__(self, x: jnp.ndarray):
        """Return ``exp(-clip(linear(x)))``, clipped to the class bounds."""
        bounded_log_std = jnp.clip(self.linear(x), self.log_std_min, self.log_std_max)
        return jnp.exp(-bounded_log_std)


def show_variable(model,text):
    """Print ``text`` followed by the model's non-``Param`` ``nnx.Variable`` state."""
    # Only the third group of the split (variables that are not Params) is shown.
    _, _, variables, _ = nnx.split(model, nnx.Param, nnx.Variable, ...)
    print(text, variables)



ddpg

In [6]:
class DDPGExplorer(UnsupervisedExplorer):
  """Actor / twin-critic explorer with a separate target critic."""

  def __init__(self, obs_dim, action_dim, hidden_dim, rngs) -> None:
    """Create the actor, the twin critic and the target twin critic."""
    self.trainable_actor = Actor(obs_dim, action_dim, hidden_dim, rngs)
    self.trainable_critic = TwinCritic(obs_dim + action_dim, hidden_dim, rngs)
    self.trainable_critic_target = TwinCritic(obs_dim + action_dim, hidden_dim, rngs)

  def __call__(self, observations, rng):
    """Sample an action from the actor.

    Returns ``(action, info)`` with an empty ``info`` dict, which is the pair
    the rollout wrapper unpacks (the original returned the bare action).
    """
    pi = self.trainable_actor(observations)
    return pi.sample(seed=rng), {}

  def update(self,rng,obs,actions,next_obs,dones,info):
    """No non-gradient state to update; returns an empty metrics dict."""
    return {}

  def batch_critic_loss(self, rng, obs, actions, next_obs, dones, info):
    """Mean squared TD error of both critic heads.

    ``info`` must provide ``"reward"`` and ``"discount"``.

    NOTE(review): the critic outputs have a trailing axis of size 1; rewards
    and discounts are presumably shaped to broadcast against that -- confirm.
    ``dones`` is not used; termination is presumably encoded in
    ``info["discount"]`` -- confirm.
    """
    _, rng_act = jax.random.split(rng)

    # The bootstrap action must come from the policy evaluated at the NEXT
    # observation (the original sampled it from the current observation).
    next_pi = self.trainable_actor(next_obs)
    next_actions = next_pi.sample(seed=rng_act)

    reward = info["reward"]
    discount = info["discount"]

    tq1, tq2 = self.trainable_critic_target(
        jnp.concatenate([next_obs, next_actions], axis=-1)
    )
    target_v = jnp.minimum(tq1, tq2)
    # The TD target is a regression label: block gradients through it so the
    # actor and the target critic are not trained by the critic loss.
    target_q = jax.lax.stop_gradient(reward + discount * target_v)

    q1, q2 = self.trainable_critic(jnp.concatenate([obs, actions], axis=-1))
    td_err1 = (q1 - target_q) ** 2
    td_err2 = (q2 - target_q) ** 2
    return jnp.mean(td_err1 + td_err2)

  def batch_actor_loss(self, rng, obs, actions, next_obs, dones, info):
    """Negative mean of ``min(q1, q2)`` at actions freshly sampled from the actor.

    The ``actions`` argument (replay actions) is ignored on purpose: the
    gradient has to flow through the actor's own samples.
    """
    _, rng_act = jax.random.split(rng)
    pi = self.trainable_actor(obs)
    policy_actions = pi.sample(seed=rng_act)
    q1, q2 = self.trainable_critic(jnp.concatenate([obs, policy_actions], axis=-1))
    return -jnp.mean(jnp.minimum(q1, q2))



class Disagreement(nnx.Module):
  """Ensemble of forward models predicting ``next_obs`` from ``[obs, action]``."""

  def __init__(self, obs_dim, action_dim, hidden_dim, rngs, n_models=5):
    """Build ``n_models`` MLPs, member ``i`` being initialised from ``nnx.Rngs(i)``.

    The ``rngs`` argument is accepted but not used for initialisation.
    """
    self.n_models = n_models
    self.obs_dim = obs_dim

    self.ensemble = []
    for seed in range(n_models):
      member = MLP(obs_dim + action_dim, hidden_dim, obs_dim, nnx.Rngs(seed))
      self.ensemble.append(member)

  def __call__(self, obs, actions, next_obs):
    """Return each member's prediction-error norm, concatenated along axis 1."""
    inputs = jnp.concatenate([obs, actions], axis=-1)
    per_member_err = [
        jnp.linalg.norm(next_obs - member(inputs), axis=-1, keepdims=True)
        for member in self.ensemble
    ]
    return jnp.concatenate(per_member_err, axis=1)

  def get_disagreement(self, obs, actions, next_obs):
    """Return the across-member prediction variance, averaged over the last axis.

    ``next_obs`` is accepted but not used.
    """
    inputs = jnp.concatenate([obs, actions], axis=-1)
    stacked = jnp.stack([member(inputs) for member in self.ensemble], axis=0)
    return jnp.var(stacked, axis=0).mean(axis=-1)



class DisagreementExplorer(UnsupervisedExplorer):
  """Actor / twin-critic explorer rewarded by forward-model ensemble disagreement."""

  def __init__(self, obs_dim, action_dim, hidden_dim, rngs):
    """Create the actor, the critics and the forward-model ensemble."""
    self.trainable_actor = Actor(obs_dim, action_dim, hidden_dim, rngs)
    self.trainable_critic = TwinCritic(obs_dim + action_dim, hidden_dim, rngs)
    self.trainable_critic_target = TwinCritic(obs_dim + action_dim, hidden_dim, rngs)
    self.trainable_disagreement = Disagreement(obs_dim, action_dim, hidden_dim, rngs)

    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.hidden_dim = hidden_dim

  def __call__(self, observations, rng):
    """Sample an action from the actor; returns ``(action, {})`` as the rollout wrapper expects."""
    pi = self.trainable_actor(observations)
    return pi.sample(seed=rng), {}

  def update(self, rng, obs, actions, next_obs, dones, info):
    """No non-gradient state to update; returns an empty metrics dict."""
    return {}

  def dissagreement_loss(self, obs, actions, next_obs):
    """Mean prediction-error norm of the ensemble (method name spelling kept for callers)."""
    error = self.trainable_disagreement(obs, actions, next_obs)
    return jnp.mean(error)

  def compute_intr_reward(self, obs, actions, next_obs):
    """Intrinsic reward: ensemble prediction variance, with a trailing axis of size 1."""
    # The ensemble is stored as `trainable_disagreement`; the original read a
    # non-existent `self.disagreement` attribute.
    reward = self.trainable_disagreement.get_disagreement(obs, actions, next_obs)
    return jnp.expand_dims(reward, axis=-1)

  def batch_critic_loss_with_intrinsic(self, rng, obs, actions, next_obs, dones, info):
    """Mean squared TD error of both critic heads using the intrinsic reward.

    ``info`` must provide ``"discount"``.
    """
    reward = self.compute_intr_reward(obs, actions, next_obs)
    discount = info["discount"]

    _, rng_act = jax.random.split(rng)
    # Bootstrap with the policy evaluated at the NEXT observation (the
    # original used the current observation).
    next_pi = self.trainable_actor(next_obs)
    next_actions = next_pi.sample(seed=rng_act)

    tq1, tq2 = self.trainable_critic_target(
        jnp.concatenate([next_obs, next_actions], axis=-1)
    )
    target_v = jnp.minimum(tq1, tq2)
    # Treat the target (including the ensemble-based reward) as a constant label.
    target_q = jax.lax.stop_gradient(reward + discount * target_v)

    q1, q2 = self.trainable_critic(jnp.concatenate([obs, actions], axis=-1))
    td_err1 = (q1 - target_q) ** 2
    td_err2 = (q2 - target_q) ** 2
    return jnp.mean(td_err1 + td_err2)

  def batch_actor_loss_with_intrinsic(self, rng, obs, actions, next_obs, dones, info):
    """Negative mean of ``min(q1, q2)`` at actions freshly sampled from the actor.

    The original sampled actions but then evaluated the critic on the replay
    ``actions``, so the loss did not depend on the actor at all.
    """
    _, rng_act = jax.random.split(rng)
    pi = self.trainable_actor(obs)
    policy_actions = pi.sample(seed=rng_act)

    q1, q2 = self.trainable_critic(jnp.concatenate([obs, policy_actions], axis=-1))
    return -jnp.mean(jnp.minimum(q1, q2))


random & deepbayesian

In [24]:
class RandomExplorer(UnsupervisedExplorer):
    """Explorer that ignores observations and samples uniform actions in [-1, 1)."""

    def __init__(self, action_dim):
        """Remember the action dimensionality."""
        self.action_dim = action_dim

    def update(self,rng,obs,action,next_obs,done,info):
        """Nothing to learn; returns an empty metrics dict."""
        return {}

    def __call__(self,observations,rng):
        """Return ``(actions, {})``.

        A 1-D observation yields a single action of shape ``(action_dim,)``;
        otherwise one action is drawn per entry along the first axis.
        """
        if observations.ndim == 1:
            sample_shape = (self.action_dim,)
        else:
            sample_shape = (observations.shape[0], self.action_dim)
        sampled = jax.random.uniform(rng, shape=sample_shape, minval=-1.0, maxval=1.0)
        return sampled, {}

class DBActor(nnx.Module):
    """Policy head mapping a ``hidden_dim`` embedding to a Gaussian mean and log-std.

    ``obs_dim`` is accepted but not used: the input is an embedding, not a raw
    observation.
    """

    # Bounds applied to the predicted log standard deviation.
    log_std_min: float = -4
    log_std_max: float = 2

    def __init__(self, obs_dim, action_dim,hidden_dim, rngs: nnx.Rngs):
        """Create the two ``hidden_dim -> action_dim`` heads."""
        self.mean = nnx.Linear(hidden_dim, action_dim, rngs=rngs)
        self.log_std = nnx.Linear(hidden_dim, action_dim, rngs=rngs)

    def __call__(self, x: jnp.ndarray):
        """Return ``(mean, log_std)`` with ``log_std`` clipped to the class bounds.

        The shape-debugging ``print`` calls of the original were removed; they
        ran on every (traced) call.
        """
        mean = self.mean(x)
        log_std = jnp.clip(self.log_std(x), self.log_std_min, self.log_std_max)
        return mean, log_std

class DBJointEncoder(nnx.Module):
    """Residual encoder over a ``hidden_dims``-wide embedding (no noise injection)."""

    def __init__(self, hidden_dims: int, rngs: nnx.Rngs):
        """Create two square linear layers followed by four layer norms."""
        self.linear1 = nnx.Linear(hidden_dims, hidden_dims, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dims, hidden_dims, rngs=rngs)
        # layer_norm0 .. layer_norm3, created in index order.
        for idx in range(4):
            setattr(self, f"layer_norm{idx}", nnx.LayerNorm(hidden_dims, rngs=rngs))

    def __call__(self, x: jax.Array):
        """Encode ``x``; the skip connection adds the normalised input back before the final norm."""
        normed = self.layer_norm0(x)
        branch = self.layer_norm1(nn.relu(self.linear1(normed)))
        branch = self.layer_norm2(self.linear2(branch))
        return self.layer_norm3(nn.relu(branch + normed))


class DeepBayesianExplorer(UnsupervisedExplorer):
    """Explorer that maintains a Gaussian belief over next observations.

    The belief's mean and precision are linear read-outs (``mean_w``,
    ``prec_w``) of a joint observation/action embedding ``T``. The read-out
    weights are plain ``nnx.Variable``s updated in closed form by ``update``;
    the actor and the likelihood-precision head are trained through ``loss``.

    Compared with the original, the dead computations in ``update`` (results
    that were immediately overwritten or discarded), the duplicated nested
    likelihood helper and the debugging ``print`` calls were removed; the
    values returned by every method are unchanged.
    """

    def __init__(self, obs_dim, action_dim,hidden_dim, rngs: nnx.Rngs
                ,l_prec=1.0,weight_decay=1e-2,ent_lambda=1e-3,depth=2):
        """Create the read-out weights, heads and encoders.

        ``l_prec`` and ``depth`` are accepted but currently unused.
        """
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        # Closed-form read-out weights, both initialised to zero.
        self.prec_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)))
        self.mean_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)))

        self.trainable_likelihood_prec = Likelihood_Prec(obs_dim,hidden_dim,rngs)

        self.trainable_actor = DBActor(obs_dim, action_dim,hidden_dim, rngs=rngs)

        self.weight_decay = weight_decay
        self.obs_embeds = Encoder(obs_dim,hidden_dim,rngs)
        self.action_embeds = Encoder(action_dim,hidden_dim,rngs)
        self.joint_embeds =DBJointEncoder(hidden_dim,rngs)

        self.ent_lambda = ent_lambda

    def _predictive_nll(self, T, mean, prec, next_obs):
        """Negative log-likelihood of ``next_obs`` under the predictive Gaussian.

        The predictive variance is the sum of the likelihood variance
        (``1 / l_prec``) and the belief variance (``1 / prec``).

        Returns ``(nll, l_prec)``.
        """
        l_prec = self.trainable_likelihood_prec(T)
        sigma = jnp.sqrt(1 / l_prec + 1 / prec)
        predictive = distrax.MultivariateNormalDiag(mean, sigma)
        return -predictive.log_prob(next_obs), l_prec

    def _exploration_loss(self, rng, obs_embed):
        """Negative expected information gain minus the entropy bonus for freshly sampled actions."""
        mean, log_std = self.trainable_actor(obs_embed)
        policy = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))
        actions = policy.sample(seed=rng, sample_shape=())

        T = self.joint_embeds(self.action_embeds(actions) + obs_embed)
        prec = T @ self.prec_w
        l_prec = self.trainable_likelihood_prec(T)

        MI = compute_expected_info_gain_normal(prec, l_prec)
        return -MI - self.ent_lambda * policy.entropy()

    def update(self,rng,obs,action,next_obs,done,info):
        """Closed-form update of ``prec_w`` and ``mean_w`` from a batch of transitions.

        ``info`` must contain the ``"T"``, ``"mean"`` and ``"prec"`` entries
        produced by ``__call__``. ``rng``, ``obs``, ``action`` and ``done`` are
        not used.

        Returns a dict with the per-sample ``"kl"`` and ``"predictive_loss"``.
        """
        mean = info["mean"]
        prec = info["prec"]

        predictive_loss, l_prec = self._predictive_nll(info["T"], mean, prec, next_obs)
        deepkl, delta_mean = compute_info_gain_normal(mean, prec, l_prec, next_obs)

        # Flatten every leading axis: batch x hidden_dim / batch x obs_dim.
        T = info["T"].reshape(-1, self.hidden_dim)
        l_prec = l_prec.reshape(-1, self.obs_dim)
        delta_mean = delta_mean.reshape(-1, self.obs_dim)

        # T^T (T T^T)^+ maps per-sample targets back onto the read-out weights.
        T_T = jnp.transpose(T)
        T_Map = T_T @ jnp.linalg.pinv(T @ T_T)

        delta_precW = T_Map @ l_prec
        self.prec_w.value = (self.prec_w.value + delta_precW) * (1-self.weight_decay)

        delta_meanW = T_Map @ delta_mean
        self.mean_w.value = (self.mean_w.value + delta_meanW) * (1-self.weight_decay)

        return {"kl":deepkl,"predictive_loss":predictive_loss}

    def loss(self,rng, obs,action,next_obs,done,info):
        """Training loss: exploration objective plus predictive negative log-likelihood.

        ``info`` must contain ``"obs_embed"``, ``"T"``, ``"mean"`` and
        ``"prec"``. ``obs``, ``action`` and ``done`` are not used.
        """
        # Same key derivation as before: the first half of the split drives sampling.
        rng_sac, _ = jax.random.split(rng, 2)
        sac_loss = self._exploration_loss(rng_sac, info["obs_embed"])
        likelihood_loss, _ = self._predictive_nll(
            info["T"], info["mean"], info["prec"], next_obs
        )
        return sac_loss + likelihood_loss

    def batch_loss(self,rng, obs,action,next_obs,done,info):
        """``loss`` vmapped over the leading axis of every argument."""
        return jax.vmap(self.loss)(rng, obs, action, next_obs, done, info)

    def __call__(self,observations,rng):
        """Sample actions and return them with the belief statistics for the chosen actions.

        Returns ``(actions, info)`` where ``info`` holds the expected
        information gain ``"mi"``, the joint embedding ``"T"``, the observation
        embedding ``"obs_embed"``, the likelihood precision ``"l_prec"`` and
        the belief ``"prec"`` / ``"mean"``.

        NOTE(review): actions are sampled around zero (the actor mean is
        multiplied by 0), whereas ``loss`` samples around the actor mean --
        confirm this asymmetry is intended.
        """
        obs_embed = self.obs_embeds(observations)

        policy_mean, log_std = self.trainable_actor(obs_embed)
        policy = distrax.MultivariateNormalDiag(policy_mean*0, jnp.exp(log_std))
        actions = policy.sample(seed=rng, sample_shape=())

        T = self.joint_embeds(self.action_embeds(actions) + obs_embed)

        # Keep the belief precision strictly positive.
        prec = jnp.maximum(T @ self.prec_w ,1e-3)
        belief_mean = T @ self.mean_w
        l_prec = self.trainable_likelihood_prec(T)

        MI = compute_expected_info_gain_normal(prec, l_prec)

        return actions, {"mi":MI,"T":T,"obs_embed":obs_embed,"l_prec":l_prec,
                         "prec":prec,"mean":belief_mean}







ppo

In [9]:
import jax
import jax.numpy as jnp
from flax import nnx
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal

class ActorCritic(nnx.Module):
    """Shared tanh MLP trunk with a Gaussian policy head and a scalar value head."""

    def __init__(self, obs_dim, action_dim, hidden_dim, depth, rngs: nnx.Rngs):
        """Build a trunk of ``depth`` linear+tanh pairs followed by the heads."""
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Shared feature extractor
        self.feature_extractor = [
            nnx.Linear(obs_dim, hidden_dim, rngs=rngs),
            nnx.tanh
        ]
        for _ in range(depth - 1):
            self.feature_extractor.extend(
                (nnx.Linear(hidden_dim, hidden_dim, rngs=rngs), nnx.tanh)
            )

        # Actor head (policy)
        self.actor_mean = nnx.Linear(hidden_dim, action_dim, rngs=rngs)
        self.actor_logstd = nnx.Linear(hidden_dim, action_dim, rngs=rngs)

        # Critic head (value function)
        self.critic_head = nnx.Linear(hidden_dim, 1, rngs=rngs)

    def __call__(self, x):
        """Return ``(pi, value)``: a diagonal Gaussian policy and the value with its last axis squeezed."""
        features = x
        for layer in self.feature_extractor:
            features = layer(features)

        # The log-std is clipped to [-20, 2] before exponentiation.
        std = jnp.exp(jnp.clip(self.actor_logstd(features), -20, 2))
        pi = distrax.MultivariateNormalDiag(self.actor_mean(features), std)

        return pi, jnp.squeeze(self.critic_head(features), axis=-1)



class PPOExplorer(UnsupervisedExplorer):
  """PPO-style explorer built on a shared actor-critic network."""

  def __init__(self, obs_dim,
                     action_dim,
                     hidden_dim,
                     rngs: nnx.Rngs,
                     depth:int = 2,
                     gamma: float = 0.99,
                     gae_lambda: float = 0.95,
                     clip_eps: float = 0.2,
                     ent_coef: float = 0.01,
                     vf_coef: float = 0.5,
                     max_grad_norm: float = 0.5,
                     num_steps: int = 128,
                     num_envs: int = 4,
                     lr: float = 2.5e-4):
    """Store the hyper-parameters and build the network (``lr`` is accepted but not stored)."""
    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.hidden_dim = hidden_dim
    self.depth = depth
    self.gamma = gamma
    self.gae_lambda = gae_lambda
    self.clip_eps = clip_eps
    self.ent_coef = ent_coef
    self.vf_coef = vf_coef
    self.max_grad_norm = max_grad_norm
    self.num_steps = num_steps
    self.num_envs = num_envs

    self.trainable_network = ActorCritic(obs_dim, action_dim,hidden_dim, depth, rngs)

  def __call__(self, observations, rng):
    """Sample an action; ``info`` records its log-probability and the value estimate."""
    pi, value = self.trainable_network(observations)

    action = pi.sample(seed=rng)
    log_prob = pi.log_prob(action)

    return action, {"log_prob": log_prob, "value": value}

  def update(self,rng,obs,actions,next_obs,dones,info):
    """No non-gradient state to update; returns an empty metrics dict."""
    return {}

  def batch_loss(self, rng, obs, actions, next_obs, dones, info):
    """Clipped PPO loss (policy + value - entropy bonus) for one trajectory batch.

    ``info`` must provide ``"reward"`` plus the rollout-time ``"value"`` and
    ``"log_prob"`` recorded by ``__call__``. ``rng`` is not used.

    NOTE(review): assumes the leading axis of every array is time and that
    ``reward``, ``dones`` and ``info["value"]`` share one shape -- confirm
    against the caller.

    Fixes relative to the original:
      * the GAE recursion now carries the per-step rollout value (it carried
        the closed-over full-batch value array) and bootstraps from the value
        of the observation following the last transition;
      * the policy and value are re-evaluated on ``obs`` (not ``next_obs``), so
        the probability ratio and value clipping compare like with like.
    """
    reward = info["reward"]
    old_value = info["value"]
    old_log_prob = info["log_prob"]

    # Re-run the network on the states in which the actions were taken.
    pi, value = self.trainable_network(obs)
    log_prob = pi.log_prob(actions)

    # Bootstrap value for the state reached after the final transition; it is
    # a target, so no gradient flows through it.
    _, last_value = self.trainable_network(next_obs[-1])
    last_value = jax.lax.stop_gradient(last_value)

    def _gae_step(carry, step):
      """One backward step of generalised advantage estimation."""
      gae, next_value = carry
      step_reward, step_done, step_value = step
      not_done = 1.0 - step_done
      delta = step_reward + self.gamma * next_value * not_done - step_value
      gae = delta + self.gamma * self.gae_lambda * not_done * gae
      return (gae, step_value), gae

    _, advantages = jax.lax.scan(
        _gae_step,
        (jnp.zeros_like(last_value), last_value),
        (reward, dones, old_value),
        reverse=True,
        unroll=16,
    )
    targets = advantages + old_value

    # Value loss: clip the new prediction around the rollout-time value.
    value_pred_clipped = old_value + (value - old_value).clip(-self.clip_eps, self.clip_eps)
    value_losses = jnp.square(value - targets)
    value_losses_clipped = jnp.square(value_pred_clipped - targets)
    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

    # Actor loss: clipped surrogate objective on normalised advantages.
    ratio = jnp.exp(log_prob - old_log_prob)
    gae = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    loss_actor1 = ratio * gae
    loss_actor2 = jnp.clip(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * gae
    loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean()
    entropy = pi.entropy().mean()

    return loss_actor + self.vf_coef * value_loss - self.ent_coef * entropy


Algorithm

In [12]:
import jax
import jax.numpy as jnp
jnp.set_printoptions(precision=2,suppress=True)
from jax.scipy.special import digamma, gammaln, kl_div
import flax.linen as nn
import numpy as np
import optax
import time
import flax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Dict
import distrax
import gymnax
import functools
from gymnax.environments import spaces
from gymnax.wrappers import FlattenObservationWrapper, LogWrapper
import matplotlib.pyplot as plt

import matplotlib.pyplot as plt

import optax
from flax.nnx.helpers import TrainState

class MyTrainState(TrainState):
    """Train state that additionally carries the model's non-trainable state.

    Besides the fields inherited from ``TrainState`` it stores the two extra
    partitions produced when the model is split in ``train_state_from_model``,
    so the full model can later be rebuilt with ``nnx.merge``.
    """

    # Partition selected by the ``nnx.Variable`` filter of the split.
    vars: nnx.Variable
    # Partition collected by the catch-all (``...``) filter of the split.
    others: nnx.State

    @property
    def need_train(self):
        """Whether ``params`` holds at least one entry, i.e. there is something to optimise."""
        return len(self.params) != 0

def is_trainable(path, node):
    """Filter predicate selecting ``nnx.Param`` nodes located under a 'trainable' path.

    Args:
        path: Iterable of path elements leading to ``node``; only ``str`` elements are inspected.
        node: Object exposing a ``type`` attribute.

    Returns:
        True when ``node.type`` equals ``nnx.Param`` and at least one string
        path element contains the substring ``'trainable'``; False otherwise.
    """
    if node.type != nnx.Param:
        return False
    return any('trainable' in part for part in path if isinstance(part, str))

def train_state_from_model(model,tx=optax.adam(0.02)):
    """Split ``model`` and wrap the resulting pieces in a ``MyTrainState``.

    Args:
        model: Model to split with ``nnx.split``.
        tx: Optax gradient transformation passed on to ``MyTrainState.create``.

    Returns:
        The created ``MyTrainState``; its ``params`` is the partition selected
        by ``is_trainable``.
    """
    # Three filters -> three state partitions: trainable params, the remaining
    # nnx.Variable entries, and everything else.
    graphdef, trainable_params, variables, remainder = nnx.split(model, is_trainable, nnx.Variable, ...)
    print(trainable_params)

    return MyTrainState.create(
        tx=tx,
        graphdef=graphdef,
        params=trainable_params,
        vars=variables,
        others=remainder,
    )

def train_state_update_model(model,state):
    """Copy the non-trainable partitions of ``model`` into ``state``.

    The model is split with the same filters used when the train state was
    created; only the ``vars`` and ``others`` partitions are written back, the
    graph definition and trainable parameters from the split are discarded.

    Args:
        model: Model whose current non-trainable state should be captured.
        state: Train state to update.

    Returns:
        A copy of ``state`` with ``vars`` and ``others`` replaced.
    """
    _, _, variables, remainder = nnx.split(model, is_trainable, nnx.Variable, ...)
    return state.replace(vars=variables, others=remainder)

def model_from_train_state(state):
    """Rebuild the model by merging the graph definition with the three partitions stored in ``state``."""
    partitions = (state.params, state.vars, state.others)
    return nnx.merge(state.graphdef, *partitions)
# prompt: draw heatmap given sequence of states for MountainCar
#state.position, state.velocity
import matplotlib.pyplot as plt

def reshape(arr):
    """Swap the first two axes of ``arr`` and fold the old first axis into the third.

    An input of shape ``(n, b, c, *rest)`` becomes ``(b, n * c, *rest)``.

    Args:
        arr: Array with at least three dimensions.

    Returns:
        The rearranged array.

    Raises:
        ValueError: If ``arr`` has fewer than three dimensions.
    """
    if arr.ndim < 3:
        raise ValueError("Input array must have at least 3 dimensions (n, b, c, ...).")

    n, b, c, *rest = arr.shape

    # (n, b, c, ...) -> (b, n, c, ...): exchanging axes 0 and 1 leaves the
    # trailing axes where they are.
    swapped = jnp.swapaxes(arr, 0, 1)

    # Collapse the (n, c) pair into a single axis of length n * c.
    return jnp.reshape(swapped, (b, n * c, *rest))

from typing import List, Any

# Define a type alias for PyTree for better readability
PyTree = Any
def unpack_pytree_by_first_index(pytree: PyTree) -> List[PyTree]:
    """
    Unpacks a PyTree of JAX arrays along their first dimension (id).

    Every leaf is indexed at position ``i`` along its first axis, for each
    ``i`` in ``range(num_ids)``, where ``num_ids`` is the first-axis length of
    the first leaf. All leaves are assumed to share that length; this is not
    checked.

    Args:
        pytree: A JAX PyTree where the leaves are JAX arrays
                with a leading 'id' dimension.

    Returns:
        A list of PyTrees, where each PyTree corresponds to a single
        'id' from the original PyTree. A PyTree without any leaves yields
        an empty list.
    """
    leaves = jax.tree_util.tree_leaves(pytree)
    # Nothing to index into: without this guard `leaves[0]` raises a bare
    # IndexError that gives no hint about the actual problem.
    if not leaves:
        return []

    num_ids = leaves[0].shape[0]

    def _take(index):
        # Slice every leaf at the same id.
        return jax.tree_util.tree_map(lambda leaf: leaf[index], pytree)

    return [_take(i) for i in range(num_ids)]
def unpack_states(pytree):
    """Apply ``reshape`` to every leaf of ``pytree`` and split the result along the new first axis.

    Returns:
        A list of PyTrees, one per index of the leading axis after reshaping.
    """
    rearranged = jax.tree.map(reshape, pytree)
    return unpack_pytree_by_first_index(rearranged)
def draw_mountain_car_heatmap(state,config = {}):
    """
    Scatter two state variables against each other, coloured by time step,
    then save the figure as a PDF named after the title and show it.

    Args:
        state: Environment state object. Which attributes are read depends on
               config["ENV_NAME"]: position/velocity ("MountainCar-v0"),
               x/theta ("CartPole-v1") or joint_angle1/joint_angle2
               ("Acrobot-v1"); ``state.time`` supplies the number of points.
        config: Experiment configuration. "ENV_NAME" and "MODEL_NAME" are
                required; "TOTAL_TIMESTEPS", "DEPTH" and "NUM_HIDDEN" are
                appended to the title when present.

    Returns:
        The ``matplotlib.pyplot`` module.
    """
    # env name -> (x attribute, y attribute, x-axis label, y-axis label)
    axes_by_env = {
        "MountainCar-v0": ("position", "velocity", 'Position', 'Velocity'),
        "CartPole-v1": ("x", "theta", 'x', 'theta'),
        "Acrobot-v1": ("joint_angle1", "joint_angle2", 'Angle1', 'Angle2'),
    }

    title = config["ENV_NAME"] +' MountainCar Heatmap ' +config["MODEL_NAME"]

    plt.figure(figsize=(10, 6))
    spec = axes_by_env.get(config["ENV_NAME"])
    # Unknown environments still get an (empty) titled figure saved below.
    if spec is not None:
        x_attr, y_attr, x_label, y_label = spec
        xs = getattr(state, x_attr)
        ys = getattr(state, y_attr)
        plt.scatter(xs, ys, c=range(len(state.time )), cmap='viridis', s=10)
        plt.colorbar(label='Time Steps')
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.grid(True)

    # Optional hyper-parameters become "_<KEY>_<value>" suffixes, in this order.
    for key in ("TOTAL_TIMESTEPS", "DEPTH", "NUM_HIDDEN"):
        if key in config:
            title += "_" + key + "_" + str(config[key])

    plt.title(title)
    plt.savefig(title.replace(" ","_")+'.pdf', format='pdf', dpi=300, bbox_inches='tight')
    plt.show()
    return plt


# NUM_UPDATES x NUM_ENVS x NUM_STEPS
class Transition(NamedTuple):
    """One batch of rollout data together with auxiliary per-step information.

    Fields:
        obs: Observations.
        action: Actions taken.
        reward: Rewards received.
        next_obs: Observations after the step.
        done: Episode-termination flags.
        info: Dictionary of additional arrays keyed by name.
    """

    obs: jnp.ndarray
    action: jnp.ndarray
    reward: jnp.ndarray
    next_obs: jnp.ndarray
    done: jnp.ndarray
    # Was annotated with the dict *instance* `{}`, which is not a type and is
    # rejected by typing.NamedTuple on Python <= 3.10.
    info: Dict[str, Any]




def make_train(config):
    """Build the model, the rollout manager and the training function for one experiment.

    Side effect: stores ``config["NUM_UPDATES"]`` computed as
    ``TOTAL_TIMESTEPS // NUM_STEPS // NUM_ENVS``.

    Args:
        config: Experiment configuration dict. Keys read here: "TOTAL_TIMESTEPS",
            "NUM_STEPS", "NUM_ENVS", "SEED", "ENV_NAME", "MODEL_NAME",
            "NUM_HIDDEN", "OPT_STEPS", "TX", "LR" and, for
            "DeepBayesianExplorer", "WD" and "DEPTH".

    Returns:
        Tuple ``(train, model, manager, rng_batch)``. ``train(rng_batch, model,
        manager)`` runs ``NUM_UPDATES`` rollout/update iterations and returns a
        dict with keys "runner_state", "transitions" and "states".

    Raises:
        ValueError: If ``config["MODEL_NAME"]`` is not a supported explorer.
            ``train`` raises the same error type for an unsupported
            ``config["TX"]``.
    """

    config["NUM_UPDATES"] = (config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"]// config["NUM_ENVS"])

    rng = jax.random.PRNGKey(config["SEED"])
    # One PRNG key per parallel environment.
    rng_batch = jax.random.split(rng, config["NUM_ENVS"])

    manager = UnsupervisedRolloutWrapper(config["ENV_NAME"])
    # it serves as action_dim which is 2 now
    num_actions = manager.env.num_actions()
    obs_dim = manager.env.observation_space(manager.env_params).shape[0]


    low = manager.env.observation_space(manager.env_params).low
    high = manager.env.observation_space(manager.env_params).high

    print ("low",low)
    print ("high",high )

    # Exactly one explorer is built; an unknown name used to leave `model`
    # unbound and only failed at the final `return` with an UnboundLocalError.
    model_name = config["MODEL_NAME"]
    if model_name == "DeepBayesianExplorer":
        model = DeepBayesianExplorer(obs_dim, num_actions,config["NUM_HIDDEN"],
                                    nnx.Rngs(config["SEED"]),weight_decay=config["WD"],depth=config["DEPTH"])
    elif model_name == "PPOExplorer":
        model = PPOExplorer(obs_dim,num_actions,config["NUM_HIDDEN"],nnx.Rngs(config["SEED"]),)
    elif model_name == "RandomExplorer":
        model = RandomExplorer(num_actions)
    else:
        raise ValueError(
            f'Unknown MODEL_NAME {model_name!r}; expected "DeepBayesianExplorer", '
            f'"PPOExplorer" or "RandomExplorer"'
        )

    @nnx.jit
    def _train_step(state:MyTrainState, rng_loss, obs, action,next_obs,done,info):
      # Runs OPT_STEPS gradient steps on the same batch.

      def loss_fn(graphdef,params,vars,others):
        model = nnx.merge(graphdef, params, vars,others)
        return model.batch_loss(rng_loss,obs, action,next_obs,done,info).mean()

      def opt_step(state,unused):
        # Differentiate w.r.t. argument 1 (`params`) only.
        grads = jax.grad(loss_fn,1)(state.graphdef, state.params, state.vars,state.others)
        return state.apply_gradients(grads=grads),None
      state, _ = jax.lax.scan(opt_step, state, None, config["OPT_STEPS"])

      return state
    @nnx.jit
    def _rollout_and_update_step(runner_state, unused):
        # we have to use train_state for jax.lax.scan
        train_state,  rng_batch,last_state = runner_state

        model = model_from_train_state(train_state)
        rng_batch, rng_step,rng_update,rng_loss = batch_random_split(rng_batch,4)

        rollout_results = manager.batch_rollout( rng_batch,model,env_state=last_state,num_steps =  config["NUM_STEPS"])
        obs, action, reward, next_obs, done,state,info, cum_ret = rollout_results

        # obs: num_envs x
        transition = Transition(obs, action, reward, next_obs, done,info)

        last_state = info["last_state"]
        # NOTE(review): `transition.info` is this same dict object, so the
        # entries added below also end up in the returned transition —
        # presumably intended; confirm.
        info["reward"] = reward
        update_info = manager.batch_update(rng_update, model,obs, action,next_obs,done,info)
        info.update(update_info)
        # Capture whatever batch_update changed in the model's non-trainable state.
        train_state = train_state_update_model(model,train_state)

        # Plain Python branch, evaluated while tracing: skipped entirely when
        # the model has no trainable parameters.
        if train_state.need_train:
            train_state = _train_step(train_state, rng_loss, obs, action,next_obs,done,info)

        #works for tensors
        runner_state = (train_state, rng_batch,last_state)
        return runner_state, (transition, state)

    def train(rng_batch,model,manager):
        # training loop

        rng_batch,  rng_reset = batch_random_split(rng_batch, 2)
        start_state = manager.batch_reset(rng_reset)

        if config["TX"] == "adamw":
            tx = optax.adamw(config["LR"])
        elif config["TX"] == "sgd":
            tx = optax.sgd(config["LR"])
        else:
            # Explicit exception instead of `assert False`: asserts are
            # stripped under `python -O`, which would let tx=None through.
            raise ValueError(f'Optimizer {config["TX"]!r} is not available; expected "adamw" or "sgd"')
        train_state = train_state_from_model(model,tx)
      #  rng, _rng = jax.random.split(rng)
        runner_state = (train_state,  rng_batch,start_state)
        runner_state, output= jax.lax.scan(_rollout_and_update_step, runner_state, None, config["NUM_UPDATES"])

        transitions,states = output
        return {"runner_state": runner_state, "transitions": transitions,"states":states}
        # return {"runner_state": runner_state, "collect_data": collect_data, "max_mi_history": max_mi_history}

    return train,model, manager,rng_batch

def experiment(config):
    """Build the training function for ``config``, run it once and return its output.

    Prints the configuration before training and the array shapes of the
    collected transitions afterwards.

    Args:
        config: Experiment configuration dict, forwarded to ``make_train``
            (which also writes ``config["NUM_UPDATES"]``).

    Returns:
        The dict returned by the training function, with keys "runner_state",
        "transitions" and "states".
    """
    print(config)
    train_fn,model, manager,rng_batch = make_train(config)

    # train_fn is called directly, exactly as before; the `nnx.jit(train_fn)`
    # wrapper that used to be created here was never called and has been removed.
    out = jax.block_until_ready(train_fn(rng_batch,model,manager))
    #data shape: rollout groups = [TOTAL_TIMESTEPS//NUM_ENVS //NUM_STEPS] x NUM_ENVS x NUM_STEPS
    print("data shape:", jax.tree_util.tree_map(lambda x: x.shape, out["transitions"]))

    # The commented-out diagnostic plots (information gain, precision, state
    # heatmap) and the unused model rebuild were removed as dead code.
    return out
# Disabled multi-model / multi-seed sweep, kept as a bare string literal so it
# does not execute (as the last expression of the cell it is only echoed).
# NOTE(review): it unpacks three values from experiment(config), but
# experiment returns a single `out`; fix that before re-enabling.
'''

result = {}
pdfs = []
#for i in [8,16,32,64,128,256]:
  #  result[i] = {}
for MODEL_NAME in ["BayesianConjugate-v1","DeepBayesianConjugate-v1","DynamicSACBayesianExplorer-v1",
                   "DeepSACBayesianConjugate-v1","DeepRandomBayesianConjugate-v1"]:
    config["MODEL_NAME"] = MODEL_NAME
    result[MODEL_NAME] =[]
    for seed in range(5):
        config["SEED"] = 423+seed
        out , big ,pdf = experiment(config)
        result[MODEL_NAME].append(big)
        pdfs.append(pdf)
'''


'\n\nresult = {}\npdfs = []\n#for i in [8,16,32,64,128,256]:\n  #  result[i] = {}\nfor MODEL_NAME in ["BayesianConjugate-v1","DeepBayesianConjugate-v1","DynamicSACBayesianExplorer-v1",\n                   "DeepSACBayesianConjugate-v1","DeepRandomBayesianConjugate-v1"]:\n    config["MODEL_NAME"] = MODEL_NAME\n    result[MODEL_NAME] =[]\n    for seed in range(5):\n        config["SEED"] = 423+seed\n        out , big ,pdf = experiment(config)\n        result[MODEL_NAME].append(big)\n        pdfs.append(pdf)\n'

In [None]:
# NUM_UPDATES x NUM_ENVS x NUM_STEPS
class Transition(NamedTuple):
    """One batch of rollout data together with auxiliary per-step information.

    This cell re-declares the ``Transition`` tuple that the Algorithm cell
    also defines; keep the two definitions in sync.

    Fields:
        obs: Observations.
        action: Actions taken.
        reward: Rewards received.
        next_obs: Observations after the step.
        done: Episode-termination flags.
        info: Dictionary of additional arrays keyed by name.
    """

    obs: jnp.ndarray
    action: jnp.ndarray
    reward: jnp.ndarray
    next_obs: jnp.ndarray
    done: jnp.ndarray
    # Was annotated with the dict *instance* `{}`, which is not a type and is
    # rejected by typing.NamedTuple on Python <= 3.10.
    info: Dict[str, Any]

# Experiment settings. The `@param` trailers on these assignments are form
# annotations; keep each one on the same line as its assignment.
env_name = 'Umaze'  # @param ["Umaze"] {"type":"raw"}
NUM_ENVS = 4 # @param [1,2,4,8,16,32] {"type":"raw"}
TOTAL_TIMESTEPS = 2048 # @param [2048,16384,131072,1048576] {"type":"raw"}
DEPTH = 1 # @param [1,2,4] {"type":"raw"}
NUM_STEPS = 16 # @param [1,2,4,8,16] {"type":"raw"}
NUM_HIDDEN = 128 # @param [32,64,128,256] {"type":"raw"}
WD = 0.1 # @param [0,0.1,0.01,0.001] {"type":"raw"}
MODEL_NAME = "DeepBayesianExplorer"  #@param ["DeepBayesianExplorer","RandomExplorer","PPOExplorer"]
# make_train derives NUM_UPDATES = TOTAL_TIMESTEPS // NUM_STEPS // NUM_ENVS
# from these values and writes it back into `config`.
config = {
    "NUM_ENVS": NUM_ENVS,    # number of PRNG keys make_train splits off (one per environment)
    "WD": WD,                # passed as weight_decay to DeepBayesianExplorer
    "NUM_STEPS": NUM_STEPS,   #steps of roll out between update
    # NOTE(review): SAC_D_STEPS and SAC_STEP_SIZE are not read by
    # make_train/experiment as shown in this notebook section; they may be
    # consumed elsewhere.
    "SAC_D_STEPS": 4,
    "ENV_NAME":env_name,
    "SAC_STEP_SIZE": 1.0,
    "SEED": 423,         #highly stochastic
    "TOTAL_TIMESTEPS": TOTAL_TIMESTEPS,   #total steps for all envs
    "NUM_HIDDEN":NUM_HIDDEN,
    "TX":"adamw",        # optimizer name; make_train's train() handles "adamw" and "sgd"
    "DEPTH":DEPTH,
    "LR":2e-4,           # learning rate handed to the optax optimizer
    "OPT_STEPS":8,       # gradient steps per rollout batch (length of the scan in _train_step)
    "MODEL_NAME": MODEL_NAME,
    # NOTE(review): DEBUG is not read by make_train/experiment as shown here.
    "DEBUG": False,
}


# Runs the full training loop for the configuration above.
out = experiment(config)

## SAC

### utils

In [None]:
import functools
import glob
import os
import pickle
from typing import Any, Dict, Mapping, Sequence

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

nonpytree_field = functools.partial(flax.struct.field, pytree_node=False)


class ModuleDict(nn.Module):
    """Container module that holds several named sub-modules and dispatches calls to them.

    Attributes:
        modules: Mapping from name to sub-module.
    """

    modules: Dict[str, nn.Module]

    @nn.compact
    def __call__(self, *args, name=None, **kwargs):
        """Call one sub-module, or all of them at once.

        With ``name=<module_name>`` the positional and keyword arguments are
        forwarded to that single module and its output is returned.

        With ``name=None`` (intended for initialization) ``kwargs`` must have
        exactly the same keys as ``modules``; each value supplies the
        arguments for the module of the same name and a dict of outputs is
        returned. Positional ``args`` are not used in this mode.

        Raises:
            ValueError: If ``name`` is None and the keys of ``kwargs`` differ
                from the keys of ``modules``.
        """
        # Single-module dispatch.
        if name is not None:
            return self.modules[name](*args, **kwargs)

        if kwargs.keys() != self.modules.keys():
            raise ValueError(
                f'When `name` is not specified, kwargs must contain the arguments for each module. '
                f'Got kwargs keys {kwargs.keys()} but module keys {self.modules.keys()}'
            )

        def _invoke(module, call_args):
            # Mapping -> keyword arguments, Sequence -> positional arguments,
            # anything else -> a single positional argument.
            if isinstance(call_args, Mapping):
                return module(**call_args)
            if isinstance(call_args, Sequence):
                return module(*call_args)
            return module(call_args)

        return {key: _invoke(self.modules[key], value) for key, value in kwargs.items()}


class TrainState(flax.struct.PyTreeNode):
    """Immutable bundle of a model definition, its parameters and the optimizer state.

    Attributes:
        step: Update counter; ``create`` starts it at 1 and every
            ``apply_gradients`` call adds 1.
        apply_fn: The model's ``apply`` function (declared as a non-pytree field).
        model_def: Model definition (declared as a non-pytree field).
        params: Parameters of the model.
        tx: optax optimizer, or None (declared as a non-pytree field).
        opt_state: Optimizer state, or None when no optimizer was given.
    """

    step: int
    apply_fn: Any = nonpytree_field()
    model_def: Any = nonpytree_field()
    params: Any
    tx: Any = nonpytree_field()
    opt_state: Any

    @classmethod
    def create(cls, model_def, params, tx=None, **kwargs):
        """Build a fresh state; the optimizer state is initialised only when ``tx`` is given."""
        opt_state = None if tx is None else tx.init(params)

        return cls(
            step=1,
            apply_fn=model_def.apply,
            model_def=model_def,
            params=params,
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )

    def __call__(self, *args, params=None, method=None, **kwargs):
        """Run the model's apply function.

        When ``params`` is omitted the parameters stored in this state are
        used. A loss function that is differentiated with respect to its own
        ``params`` argument must therefore pass that argument here for the
        gradient to depend on this call; with the stored parameters the call
        is constant with respect to it.

        Args:
            *args: Positional arguments for the model.
            params: Parameters to use instead of the stored ones (optional).
            method: Name of the ``model_def`` attribute to run; None selects
                the default of ``apply``.
            **kwargs: Keyword arguments for the model.
        """
        variables = {'params': self.params if params is None else params}
        # `method` arrives as a string and is resolved on the model definition.
        resolved_method = None if method is None else getattr(self.model_def, method)

        return self.apply_fn(variables, *args, method=resolved_method, **kwargs)

    def select(self, name):
        """Return this state as a callable with ``name`` pre-bound (for use with a ``ModuleDict``)."""
        return functools.partial(self, name=name)

    def apply_gradients(self, grads, **kwargs):
        """Take one optimizer step and return the resulting state.

        Extra ``kwargs`` are forwarded to ``replace`` so callers can update
        further fields in the same call.
        """
        param_updates, next_opt_state = self.tx.update(grads, self.opt_state, self.params)

        return self.replace(
            step=self.step + 1,
            params=optax.apply_updates(self.params, param_updates),
            opt_state=next_opt_state,
            **kwargs,
        )

    def apply_loss_fn(self, loss_fn):
        """Differentiate ``loss_fn``, apply the gradients and report gradient statistics.

        ``loss_fn`` takes the parameters and returns ``(loss, info)``. The
        returned ``info`` dict is extended in place with:
            'grad/max':  largest gradient entry over all leaves,
            'grad/min':  smallest gradient entry over all leaves,
            'grad/norm': sum (1-norm) of the per-leaf ``jnp.linalg.norm`` values.

        Returns:
            Tuple ``(new_state, info)``.
        """
        grads, info = jax.grad(loss_fn, has_aux=True)(self.params)

        def _per_leaf(reducer):
            # Reduce every gradient leaf to a scalar and stack the scalars
            # into one flat vector.
            reduced = jax.tree_util.tree_map(reducer, grads)
            return jnp.concatenate(
                [jnp.reshape(leaf, -1) for leaf in jax.tree_util.tree_leaves(reduced)], axis=0
            )

        info.update(
            {
                'grad/max': jnp.max(_per_leaf(jnp.max)),
                'grad/min': jnp.min(_per_leaf(jnp.min)),
                'grad/norm': jnp.linalg.norm(_per_leaf(jnp.linalg.norm), ord=1),
            }
        )

        return self.apply_gradients(grads=grads), info


def save_agent(agent, save_dir, epoch):
    """Save the agent to ``<save_dir>/params_<epoch>.pkl``.

    The agent is converted with ``flax.serialization.to_state_dict`` and
    pickled under the key ``'agent'``. ``save_dir`` is created if it does not
    exist yet.

    Args:
        agent: Agent.
        save_dir: Directory to save the agent.
        epoch: Epoch number.
    """
    save_dict = dict(
        agent=flax.serialization.to_state_dict(agent),
    )

    # Without this, a missing directory makes open() fail with
    # FileNotFoundError after the state dict has already been built.
    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(save_dir, f'params_{epoch}.pkl')
    with open(save_path, 'wb') as f:
        pickle.dump(save_dict, f)

    print(f'Saved to {save_path}')


def restore_agent(agent, restore_path, restore_epoch):
    """Restore the agent from ``<dir>/params_<restore_epoch>.pkl``.

    Args:
        agent: Agent used as the template for ``flax.serialization.from_state_dict``.
        restore_path: Glob pattern that must match exactly one directory
            containing the saved agent.
        restore_epoch: Epoch number.

    Returns:
        The agent with the restored state.

    Raises:
        AssertionError: If ``restore_path`` does not match exactly one path.
    """
    candidates = glob.glob(restore_path)

    # Raised explicitly (same exception type and message as the former
    # `assert`) so the check also runs under `python -O`.
    if len(candidates) != 1:
        raise AssertionError(f'Found {len(candidates)} candidates: {candidates}')

    # os.path.join mirrors how save_agent builds the file name.
    restore_path = os.path.join(candidates[0], f'params_{restore_epoch}.pkl')

    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only restore checkpoints from a trusted source.
    with open(restore_path, 'rb') as f:
        load_dict = pickle.load(f)

    agent = flax.serialization.from_state_dict(agent, load_dict['agent'])

    print(f'Restored from {restore_path}')

    return agent




# network
from typing import Any, Optional, Sequence

import distrax
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp


def default_init(scale=1.0):
    """Return the kernel initializer used by default in this file.

    It is a variance-scaling initializer in 'fan_avg' mode drawing from a
    'uniform' distribution, with the given ``scale``.
    """
    mode, distribution = 'fan_avg', 'uniform'
    return nn.initializers.variance_scaling(scale, mode, distribution)


def ensemblize(cls, num_qs, out_axes=0, **kwargs):
    """Wrap module class ``cls`` with ``nn.vmap`` to obtain an ensemble of ``num_qs`` members.

    Each member gets its own slice of the 'params' collection (axis 0) and
    its own 'params' RNG; inputs are not mapped (``in_axes=None``), so every
    member receives the same inputs.

    Args:
        cls: Module class to ensemble.
        num_qs: Size of the mapped axis, i.e. the number of members.
        out_axes: Output axis along which member outputs are placed.
        **kwargs: Further keyword arguments forwarded to ``nn.vmap``.
    """
    per_member_params = {'params': 0}
    independent_init = {'params': True}
    return nn.vmap(
        cls,
        variable_axes=per_member_params,
        split_rngs=independent_init,
        in_axes=None,
        out_axes=out_axes,
        axis_size=num_qs,
        **kwargs,
    )


class Identity(nn.Module):
    """Pass-through layer: calling it returns its input unchanged."""

    def __call__(self, x):
        """Return ``x`` as is."""
        return x


class MLP(nn.Module):
    """Stack of dense layers with optional activation and layer normalization.

    Attributes:
        hidden_dims: Output width of each dense layer, in order.
        activations: Activation function applied after a dense layer.
        activate_final: Whether the last dense layer is also followed by the
            activation (and, if enabled, layer normalization).
        kernel_init: Kernel initializer for every dense layer.
        layer_norm: Whether to apply layer normalization after each activation.
    """

    hidden_dims: Sequence[int]
    activations: Any = nn.gelu
    activate_final: bool = False
    kernel_init: Any = default_init()
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x):
        """Apply the layers to ``x`` and return the result."""
        last_index = len(self.hidden_dims) - 1
        for index, width in enumerate(self.hidden_dims):
            x = nn.Dense(width, kernel_init=self.kernel_init)(x)
            # The final layer stays linear unless `activate_final` is set.
            if index == last_index and not self.activate_final:
                continue
            x = self.activations(x)
            if self.layer_norm:
                x = nn.LayerNorm()(x)
        return x


class LengthNormalize(nn.Module):
    """Rescale vectors along the last axis so that their norm equals sqrt(dim).

    ``dim`` is the size of the last axis. No epsilon is added, so an all-zero
    vector divides by zero.
    """

    @nn.compact
    def __call__(self, x):
        """Return ``x`` divided by its last-axis norm and multiplied by sqrt(dim)."""
        lengths = jnp.linalg.norm(x, axis=-1, keepdims=True)
        target_length = jnp.sqrt(x.shape[-1])
        return x / lengths * target_length


class Param(nn.Module):
    """Module holding a single scalar parameter named 'value'.

    Attributes:
        init_value: Value the parameter is initialised to.
    """

    init_value: float = 0.0

    @nn.compact
    def __call__(self):
        """Return the scalar parameter."""

        def _init(key):
            # The RNG key is not needed: initialisation is deterministic.
            return jnp.full((), self.init_value)

        return self.param('value', init_fn=_init)


class LogParam(nn.Module):
    """Module holding a scalar parameter stored as its logarithm.

    The stored parameter is named 'log_value'; calling the module returns
    its exponential.

    Attributes:
        init_value: Initial value of the returned quantity, i.e. the
            parameter starts at ``log(init_value)``.
    """

    init_value: float = 1.0

    @nn.compact
    def __call__(self):
        """Return ``exp(log_value)``."""

        def _init(key):
            # The RNG key is not needed: initialisation is deterministic.
            return jnp.full((), jnp.log(self.init_value))

        stored = self.param('log_value', init_fn=_init)
        return jnp.exp(stored)


class TransformedWithMode(distrax.Transformed):
    """``distrax.Transformed`` that additionally provides ``mode()``."""

    def mode(self):
        """Return the base distribution's mode pushed through the bijector's forward map."""
        base_mode = self.distribution.mode()
        return self.bijector.forward(base_mode)


class RunningMeanStd(flax.struct.PyTreeNode):
    """Running estimate of mean and variance, with helpers to (un)normalize data.

    Attributes:
        eps: Added to the variance before taking the square root.
        mean: Running mean.
        var: Running variance.
        clip_max: Normalized values are clipped to ``[-clip_max, clip_max]``.
        count: Number of samples accumulated so far.
    """

    eps: Any = 1e-6
    mean: Any = 1.0
    var: Any = 1.0
    clip_max: Any = 10.0
    count: int = 0

    def normalize(self, batch):
        """Standardize ``batch`` with the running statistics and clip the result."""
        scale = jnp.sqrt(self.var + self.eps)
        standardized = (batch - self.mean) / scale
        return jnp.clip(standardized, -self.clip_max, self.clip_max)

    def unnormalize(self, batch):
        """Undo the standardization of ``normalize`` (the clipping is not undone)."""
        scale = jnp.sqrt(self.var + self.eps)
        return batch * scale + self.mean

    def update(self, batch):
        """Fold ``batch`` (samples along axis 0) into the statistics and return the new object."""
        n_new = len(batch)
        mean_new = jnp.mean(batch, axis=0)
        var_new = jnp.var(batch, axis=0)

        n_total = self.count + n_new
        shift = mean_new - self.mean

        merged_mean = self.mean + shift * n_new / n_total
        # Sum of squared deviations of the union: both groups' own sums plus
        # a correction for the distance between the two means.
        sum_sq = self.var * self.count + var_new * n_new + shift**2 * self.count * n_new / n_total
        merged_var = sum_sq / n_total

        return self.replace(mean=merged_mean, var=merged_var, count=n_total)


class GCActor(nn.Module):
    """Goal-conditioned Gaussian policy.

    Attributes:
        hidden_dims: Hidden layer dimensions of the trunk MLP.
        action_dim: Action dimension.
        log_std_min: Lower clip bound for the log standard deviation.
        log_std_max: Upper clip bound for the log standard deviation.
        tanh_squash: Whether to wrap the distribution in a tanh transform.
        state_dependent_std: Whether the log standard deviation is predicted
            from the trunk output.
        const_std: When the standard deviation is not state dependent: use a
            fixed log standard deviation of zero (True) or a learned
            per-dimension parameter (False).
        final_fc_init_scale: Initializer scale of the output dense layers.
        gc_encoder: Optional encoder module applied to (observations, goals).
    """

    hidden_dims: Sequence[int]
    action_dim: int
    log_std_min: Optional[float] = -5
    log_std_max: Optional[float] = 2
    tanh_squash: bool = False
    state_dependent_std: bool = False
    const_std: bool = True
    final_fc_init_scale: float = 1e-2
    gc_encoder: nn.Module = None

    def setup(self):
        """Create the trunk, the mean head and, depending on the flags, the std head or parameter."""
        self.actor_net = MLP(self.hidden_dims, activate_final=True)
        self.mean_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))
        if self.state_dependent_std:
            self.log_std_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))
        elif not self.const_std:
            self.log_stds = self.param('log_stds', nn.initializers.zeros, (self.action_dim,))

    def __call__(
        self,
        observations,
        goals=None,
        goal_encoded=False,
        temperature=1.0,
    ):
        """Return the action distribution.

        Args:
            observations: Observations.
            goals: Goals (optional).
            goal_encoded: Forwarded to ``gc_encoder`` when one is set.
            temperature: Multiplier applied to the standard deviation.
        """
        if self.gc_encoder is None:
            # No encoder: feed observations, followed by goals when given.
            pieces = [observations] if goals is None else [observations, goals]
            inputs = jnp.concatenate(pieces, axis=-1)
        else:
            inputs = self.gc_encoder(observations, goals, goal_encoded=goal_encoded)
        features = self.actor_net(inputs)

        means = self.mean_net(features)
        if self.state_dependent_std:
            log_stds = self.log_std_net(features)
        elif self.const_std:
            log_stds = jnp.zeros_like(means)
        else:
            log_stds = self.log_stds

        log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)

        scale = jnp.exp(log_stds) * temperature
        distribution = distrax.MultivariateNormalDiag(loc=means, scale_diag=scale)
        if self.tanh_squash:
            squash = distrax.Block(distrax.Tanh(), ndims=1)
            distribution = TransformedWithMode(distribution, squash)

        return distribution


class GCDiscreteActor(nn.Module):
    """Goal-conditioned categorical policy for discrete actions.

    Attributes:
        hidden_dims: Hidden layer dimensions of the trunk MLP.
        action_dim: Number of logits produced.
        final_fc_init_scale: Initializer scale of the logit layer.
        gc_encoder: Optional encoder module applied to (observations, goals).
    """

    hidden_dims: Sequence[int]
    action_dim: int
    final_fc_init_scale: float = 1e-2
    gc_encoder: nn.Module = None

    def setup(self):
        """Create the trunk MLP and the logit head."""
        self.actor_net = MLP(self.hidden_dims, activate_final=True)
        self.logit_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))

    def __call__(
        self,
        observations,
        goals=None,
        goal_encoded=False,
        temperature=1.0,
    ):
        """Return the action distribution.

        Args:
            observations: Observations.
            goals: Goals (optional).
            goal_encoded: Forwarded to ``gc_encoder`` when one is set.
            temperature: Divisor applied to the logits; it is floored at 1e-6,
                so passing 0 makes the distribution very sharply peaked.
        """
        if self.gc_encoder is None:
            # No encoder: feed observations, followed by goals when given.
            pieces = [observations] if goals is None else [observations, goals]
            inputs = jnp.concatenate(pieces, axis=-1)
        else:
            inputs = self.gc_encoder(observations, goals, goal_encoded=goal_encoded)

        logits = self.logit_net(self.actor_net(inputs))
        safe_temperature = jnp.maximum(1e-6, temperature)

        return distrax.Categorical(logits=logits / safe_temperature)


class GCValue(nn.Module):
    """Goal-conditioned value or critic network.

    Without ``actions`` it acts as V(s, g); with ``actions`` as Q(s, a, g).

    Attributes:
        hidden_dims: Hidden layer dimensions; a final layer of width 1 is appended.
        layer_norm: Whether the MLP applies layer normalization.
        ensemble: Whether to use an ensemble of two MLPs built with ``ensemblize``.
        gc_encoder: Optional encoder module applied to (observations, goals).
    """

    hidden_dims: Sequence[int]
    layer_norm: bool = True
    ensemble: bool = True
    gc_encoder: nn.Module = None

    def setup(self):
        """Create the (optionally ensembled) value MLP."""
        net_cls = ensemblize(MLP, 2) if self.ensemble else MLP
        self.value_net = net_cls((*self.hidden_dims, 1), activate_final=False, layer_norm=self.layer_norm)

    def __call__(self, observations, goals=None, actions=None):
        """Return the value/critic output with the trailing size-1 axis removed.

        Args:
            observations: Observations.
            goals: Goals (optional).
            actions: Actions (optional); appended last to the network input.
        """
        if self.gc_encoder is None:
            pieces = [observations] if goals is None else [observations, goals]
        else:
            pieces = [self.gc_encoder(observations, goals)]
        if actions is not None:
            pieces.append(actions)

        stacked = jnp.concatenate(pieces, axis=-1)
        return self.value_net(stacked).squeeze(-1)


class GCDiscreteCritic(GCValue):
    """Goal-conditioned critic for discrete actions.

    Attributes:
        action_dim: Number of discrete actions, i.e. the width of the one-hot encoding.
    """

    action_dim: int = None

    def __call__(self, observations, goals=None, actions=None):
        """One-hot encode ``actions`` via rows of an identity matrix and delegate to ``GCValue``."""
        one_hot_actions = jnp.eye(self.action_dim)[actions]
        return super().__call__(observations, goals, one_hot_actions)


class GCBilinearValue(nn.Module):
    """Goal-conditioned bilinear value/critic function.

    The value is V(s, g) = phi(s)^T psi(g) / sqrt(d); with actions it becomes
    Q(s, a, g) = phi(s, a)^T psi(g) / sqrt(d). Both phi and psi output d-dimensional vectors,
    where d is `latent_dim`.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        latent_dim: Latent dimension.
        layer_norm: Whether to apply layer normalization.
        ensemble: Whether to ensemble the value function.
        value_exp: Whether to exponentiate the value. Useful for contrastive learning.
        state_encoder: Optional state encoder.
        goal_encoder: Optional goal encoder.
    """

    hidden_dims: Sequence[int]
    latent_dim: int
    layer_norm: bool = True
    ensemble: bool = True
    value_exp: bool = False
    state_encoder: nn.Module = None
    goal_encoder: nn.Module = None

    def setup(self):
        net_cls = ensemblize(MLP, 2) if self.ensemble else MLP
        layer_sizes = (*self.hidden_dims, self.latent_dim)

        # phi embeds the state (optionally with the action); psi embeds the goal.
        self.phi = net_cls(layer_sizes, activate_final=False, layer_norm=self.layer_norm)
        self.psi = net_cls(layer_sizes, activate_final=False, layer_norm=self.layer_norm)

    def __call__(self, observations, goals, actions=None, info=False):
        """Evaluate the bilinear value/critic.

        Args:
            observations: Observations.
            goals: Goals.
            actions: Actions (optional).
            info: Whether to additionally return the representations phi and psi.

        Returns:
            `v`, or the tuple `(v, phi, psi)` when `info` is true.
        """
        if self.state_encoder is not None:
            observations = self.state_encoder(observations)
        if self.goal_encoder is not None:
            goals = self.goal_encoder(goals)

        state_inputs = observations if actions is None else jnp.concatenate([observations, actions], axis=-1)

        phi = self.phi(state_inputs)
        psi = self.psi(goals)

        # Scaled inner product over the latent axis.
        v = (phi * psi / jnp.sqrt(self.latent_dim)).sum(axis=-1)
        if self.value_exp:
            v = jnp.exp(v)

        return (v, phi, psi) if info else v


class GCDiscreteBilinearCritic(GCBilinearValue):
    """Goal-conditioned bilinear critic for discrete actions.

    Integer actions are turned into one-hot rows before being handed to `GCBilinearValue.__call__`.
    """

    action_dim: int = None

    def __call__(self, observations, goals=None, actions=None, info=False):
        # Row i of the identity matrix is the one-hot encoding of action i.
        one_hot_table = jnp.eye(self.action_dim)
        encoded_actions = one_hot_table[actions]
        return super().__call__(observations, goals, encoded_actions, info)


class GCMRNValue(nn.Module):
    """Metric residual network (MRN) value function.

    The output is the sum of a symmetric Euclidean distance, computed on the first half of the latent
    vector, and an asymmetric max-based term, computed on the second half.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        latent_dim: Latent dimension.
        layer_norm: Whether to apply layer normalization.
        encoder: Optional state/goal encoder.
    """

    hidden_dims: Sequence[int]
    latent_dim: int
    layer_norm: bool = True
    encoder: nn.Module = None

    def setup(self):
        layer_sizes = (*self.hidden_dims, self.latent_dim)
        self.phi = MLP(layer_sizes, activate_final=False, layer_norm=self.layer_norm)

    def __call__(self, observations, goals, is_phi=False, info=False):
        """Evaluate the MRN value.

        Args:
            observations: Observations.
            goals: Goals.
            is_phi: Whether the inputs are already encoded by phi.
            info: Whether to additionally return the representations phi_s and phi_g.

        Returns:
            `v`, or the tuple `(v, phi_s, phi_g)` when `info` is true.
        """
        if not is_phi:
            if self.encoder is not None:
                observations = self.encoder(observations)
                goals = self.encoder(goals)
            observations = self.phi(observations)
            goals = self.phi(goals)
        phi_s, phi_g = observations, goals

        half = self.latent_dim // 2
        sym_gap = phi_s[..., :half] - phi_g[..., :half]
        asym_gap = phi_s[..., half:] - phi_g[..., half:]

        # The 1e-12 floor keeps the square root away from exactly zero.
        euclidean = jnp.sqrt(jnp.maximum((sym_gap ** 2).sum(axis=-1), 1e-12))
        quasi = jax.nn.relu(asym_gap.max(axis=-1))
        v = euclidean + quasi

        return (v, phi_s, phi_g) if info else v


class GCIQEValue(nn.Module):
    """Interval quasimetric embedding (IQE) value function.

    This module computes the value function as an IQE-based quasimetric. The latent vectors of the
    observation and the goal are split into groups of `dim_per_component` entries; each group yields
    one scalar component, and the components are blended by a mean/max mixture.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        latent_dim: Latent dimension.
        dim_per_component: Dimension of each component in IQE (i.e., number of intervals in each group).
        layer_norm: Whether to apply layer normalization.
        encoder: Optional state/goal encoder.
    """

    hidden_dims: Sequence[int]
    latent_dim: int
    dim_per_component: int
    layer_norm: bool = True
    encoder: nn.Module = None

    def setup(self):
        self.phi = MLP((*self.hidden_dims, self.latent_dim), activate_final=False, layer_norm=self.layer_norm)
        # `Param` is defined elsewhere in the notebook; presumably a learnable scalar - TODO confirm.
        self.alpha = Param()

    def __call__(self, observations, goals, is_phi=False, info=False):
        """Return the IQE value function.

        Args:
            observations: Observations.
            goals: Goals.
            is_phi: Whether the inputs are already encoded by phi.
            info: Whether to additionally return the representations phi_s and phi_g.

        Returns:
            `v`, or the tuple `(v, phi_s, phi_g)` when `info` is true.
        """
        # Sigmoid keeps the mean/max mixing weight strictly between 0 and 1.
        alpha = jax.nn.sigmoid(self.alpha())
        if is_phi:
            phi_s = observations
            phi_g = goals
        else:
            if self.encoder is not None:
                observations = self.encoder(observations)
                goals = self.encoder(goals)
            phi_s = self.phi(observations)
            phi_g = self.phi(goals)

        # Split the last axis into groups of `dim_per_component` entries.
        x = jnp.reshape(phi_s, (*phi_s.shape[:-1], -1, self.dim_per_component))
        y = jnp.reshape(phi_g, (*phi_g.shape[:-1], -1, self.dim_per_component))
        # Only pairs with x < y contribute below (presumably the non-empty intervals [x, y]).
        valid = x < y
        # Put the x entries and y entries of each group side by side and sort them jointly.
        xy = jnp.concatenate(jnp.broadcast_arrays(x, y), axis=-1)
        ixy = xy.argsort(axis=-1)
        sxy = jnp.take_along_axis(xy, ixy, axis=-1)
        # For every sorted entry: look up the `valid` flag of the pair it came from
        # (`ixy % dim_per_component`) and sign it -1 if it came from the x half, +1 if from the y half.
        neg_inc_copies = jnp.take_along_axis(valid, ixy % self.dim_per_component, axis=-1) * jnp.where(
            ixy < self.dim_per_component, -1, 1
        )
        # Running sum of the signed flags along the sorted order.
        neg_inp_copies = jnp.cumsum(neg_inc_copies, axis=-1)
        # -1 wherever the running sum is negative, 0 elsewhere.
        neg_f = -1.0 * (neg_inp_copies < 0)
        # First difference of `neg_f` along the sorted order (first entry kept as is).
        neg_incf = jnp.concatenate([neg_f[..., :1], neg_f[..., 1:] - neg_f[..., :-1]], axis=-1)
        # One scalar per group: sorted values weighted by those differences.
        components = (sxy * neg_incf).sum(axis=-1)
        # Convex combination of the mean and the max over groups.
        v = alpha * components.mean(axis=-1) + (1 - alpha) * components.max(axis=-1)

        if info:
            return v, phi_s, phi_g
        else:
            return v

### agent

In [None]:
import copy
from typing import Any

import flax
import jax
import jax.numpy as jnp
import ml_collections
import optax

# from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
# from utils.networks import GCActor, GCValue, LogParam


class SACAgent(flax.struct.PyTreeNode):
    """Soft actor-critic (SAC) agent.

    The agent bundles a PRNG key, a network container and a configuration mapping. `update` returns
    a new agent rather than modifying this one (apart from the target-parameter write performed by
    `target_update` on the freshly created network).
    """

    # PRNG key; `update` splits it and stores one half on the returned agent.
    rng: Any
    # Network container; must provide `select(name)`, `apply_loss_fn(loss_fn=...)` and `params`.
    network: Any
    # Configuration mapping, declared through `nonpytree_field()` (defined elsewhere in the notebook).
    config: Any = nonpytree_field()

    def critic_loss(self, batch, grad_params, rng):
        """Compute the SAC critic loss.

        Args:
            batch: Mapping with 'observations', 'actions', 'rewards', 'masks' and 'next_observations'.
            grad_params: Parameters passed to the critic so gradients flow through them.
            rng: PRNG key used to sample the next actions.

        Returns:
            A `(loss, info)` pair, where `info` holds the loss and Q statistics.
        """
        next_dist = self.network.select('actor')(batch['next_observations'])
        next_actions, next_log_probs = next_dist.sample_and_log_prob(seed=rng)

        next_qs = self.network.select('target_critic')(batch['next_observations'], actions=next_actions)
        # Reduce over axis 0 (presumably the ensemble axis of the critic - TODO confirm).
        if self.config['min_q']:
            next_q = jnp.min(next_qs, axis=0)
        else:
            next_q = jnp.mean(next_qs, axis=0)

        # Bootstrapped target; `masks` zeroes the bootstrap term where it is 0.
        target_q = batch['rewards'] + self.config['discount'] * batch['masks'] * next_q
        # Entropy term: subtract alpha * log pi(a'|s'), discounted and masked like the bootstrap term.
        target_q = target_q - self.config['discount'] * batch['masks'] * next_log_probs * self.network.select('alpha')()

        q = self.network.select('critic')(batch['observations'], actions=batch['actions'], params=grad_params)
        critic_loss = jnp.square(q - target_q).mean()

        return critic_loss, {
            'critic_loss': critic_loss,
            'q_mean': q.mean(),
            'q_max': q.max(),
            'q_min': q.min(),
        }

    def actor_loss(self, batch, grad_params, rng):
        """Compute the SAC actor loss together with the temperature (alpha) loss.

        Args:
            batch: Mapping with at least 'observations'.
            grad_params: Parameters passed to the actor and to alpha so gradients flow through them.
            rng: PRNG key used to sample actions.

        Returns:
            A `(loss, info)` pair; the loss is `actor_loss + alpha_loss`.
        """
        # Actor loss.
        dist = self.network.select('actor')(batch['observations'], params=grad_params)
        actions, log_probs = dist.sample_and_log_prob(seed=rng)

        # The critic is called without `grad_params`, i.e. with the network's stored parameters.
        qs = self.network.select('critic')(batch['observations'], actions=actions)
        if self.config['min_q']:
            q = jnp.min(qs, axis=0)
        else:
            q = jnp.mean(qs, axis=0)

        actor_loss = (log_probs * self.network.select('alpha')() - q).mean()

        # Entropy loss.
        alpha = self.network.select('alpha')(params=grad_params)
        # stop_gradient: the alpha loss must not push gradients into the actor.
        entropy = -jax.lax.stop_gradient(log_probs).mean()
        alpha_loss = (alpha * (entropy - self.config['target_entropy'])).mean()

        total_loss = actor_loss + alpha_loss

        # With tanh squashing the std is read from the wrapped `_distribution`
        # (presumably the pre-squash base distribution - TODO confirm).
        if self.config['tanh_squash']:
            action_std = dist._distribution.stddev()
        else:
            action_std = dist.stddev().mean()

        return total_loss, {
            'total_loss': total_loss,
            'actor_loss': actor_loss,
            'alpha_loss': alpha_loss,
            'alpha': alpha,
            'entropy': -log_probs.mean(),
            'std': action_std.mean(),
        }

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Compute the total loss.

        Args:
            batch: Training batch forwarded to both loss functions.
            grad_params: Parameters to differentiate with respect to.
            rng: Optional PRNG key; falls back to `self.rng` when omitted.

        Returns:
            `(critic_loss + actor_loss, info)`; info keys are prefixed with 'critic/' and 'actor/'.
        """
        info = {}
        rng = rng if rng is not None else self.rng

        # Independent keys for the actor and critic sampling.
        rng, actor_rng, critic_rng = jax.random.split(rng, 3)

        critic_loss, critic_info = self.critic_loss(batch, grad_params, critic_rng)
        for k, v in critic_info.items():
            info[f'critic/{k}'] = v

        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
        for k, v in actor_info.items():
            info[f'actor/{k}'] = v

        loss = critic_loss + actor_loss
        return loss, info

    def target_update(self, network, module_name):
        """Update the target network.

        Computes `tau * p + (1 - tau) * tp` leaf by leaf. Both operands are read from
        `self.network.params`; the result is written into `network.params` in place.

        Args:
            network: Network container whose target parameters are overwritten.
            module_name: Module name; the 'modules_<name>' and 'modules_target_<name>' entries are used.
        """
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * self.config['tau'] + tp * (1 - self.config['tau']),
            self.network.params[f'modules_{module_name}'],
            self.network.params[f'modules_target_{module_name}'],
        )
        network.params[f'modules_target_{module_name}'] = new_target_params

    @jax.jit
    def update(self, batch):
        """Update the agent and return a new agent with information dictionary."""
        # One half of the split is kept for the next update, the other drives this update's sampling.
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
        self.target_update(new_network, 'critic')

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        goals=None,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions from the actor.

        Args:
            observations: Observations passed to the actor.
            goals: Optional goals passed to the actor.
            seed: PRNG key passed to the distribution's `sample`.
            temperature: Sampling temperature forwarded to the actor.

        Returns:
            Sampled actions clipped to [-1, 1].
        """
        dist = self.network.select('actor')(observations, goals, temperature=temperature)
        actions = dist.sample(seed=seed)
        actions = jnp.clip(actions, -1, 1)
        return actions

    @classmethod
    def create(
        cls,
        seed,
        ex_observations,
        ex_actions,
        config,
    ):
        """Create a new agent.

        Note that `config['target_entropy']` is filled in on the passed mapping when it is None.

        Args:
            seed: Random seed.
            ex_observations: Example batch of observations.
            ex_actions: Example batch of actions.
            config: Configuration dictionary.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        action_dim = ex_actions.shape[-1]

        # Default target entropy: -multiplier * dim(A).
        if config['target_entropy'] is None:
            config['target_entropy'] = -config['target_entropy_multiplier'] * action_dim

        # Define critic and actor networks.
        critic_def = GCValue(
            hidden_dims=config['value_hidden_dims'],
            layer_norm=config['layer_norm'],
            ensemble=True,
        )

        actor_def = GCActor(
            hidden_dims=config['actor_hidden_dims'],
            action_dim=action_dim,
            log_std_min=-5,
            tanh_squash=config['tanh_squash'],
            state_dependent_std=config['state_dependent_std'],
            const_std=False,
            final_fc_init_scale=config['actor_fc_scale'],
        )

        # Define the dual alpha variable.
        alpha_def = LogParam()

        # Each entry pairs a module with the example arguments used to initialise it;
        # goals are passed as None everywhere.
        network_info = dict(
            critic=(critic_def, (ex_observations, None, ex_actions)),
            target_critic=(copy.deepcopy(critic_def), (ex_observations, None, ex_actions)),
            actor=(actor_def, (ex_observations, None)),
            alpha=(alpha_def, ()),
        )
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        network_tx = optax.adam(learning_rate=config['lr'])
        network_params = network_def.init(init_rng, **network_args)['params']
        network = TrainState.create(model_def = network_def, params = network_params, tx=network_tx)

        # Start the target critic from the same parameters as the online critic.
        params = network.params
        params['modules_target_critic'] = params['modules_critic']

        return cls(rng, network=network, config=flax.core.FrozenDict(**config))


def get_config():
    """Return the default SAC hyperparameters as an `ml_collections.ConfigDict`."""
    defaults = dict(
        agent_name='sac',  # Agent name.
        lr=1e-4,  # Learning rate.
        batch_size=256,  # Batch size.
        actor_hidden_dims=(256, 256),  # Actor network hidden dimensions.
        value_hidden_dims=(256, 256),  # Value network hidden dimensions.
        layer_norm=False,  # Whether to use layer normalization.
        discount=0.99,  # Discount factor.
        tau=0.005,  # Target network update rate.
        target_entropy=ml_collections.config_dict.placeholder(float),  # Target entropy (None for automatic tuning).
        target_entropy_multiplier=0.5,  # Multiplier to dim(A) for target entropy.
        tanh_squash=True,  # Whether to squash actions with tanh.
        state_dependent_std=True,  # Whether to use state-dependent standard deviations for actor.
        actor_fc_scale=0.01,  # Final layer initialization scale for actor.
        min_q=True,  # Whether to use min Q (True) or mean Q (False).
    )
    return ml_collections.ConfigDict(defaults)

## Integrated

In [11]:
# NUM_UPDATES x NUM_ENVS x NUM_STEPS
class Transition(NamedTuple):
    """One rollout record: observation, action, reward, next observation, done flag and extras."""

    obs: jnp.ndarray
    action: jnp.ndarray
    reward: jnp.ndarray
    next_obs: jnp.ndarray
    done: jnp.ndarray
    # Extra per-step data; the annotation is an empty dict literal rather than a type.
    info: {}

# Colab form fields: the `@param` markers below are read by the notebook's form UI.
env_name = 'Umaze'  # @param ["Umaze"] {"type":"raw"}
NUM_ENVS = 4 # @param [1,2,4,8,16,32] {"type":"raw"}
TOTAL_TIMESTEPS = 2048 # @param [2048,16384,131072,1048576] {"type":"raw"}
DEPTH = 1 # @param [1,2,4] {"type":"raw"}
NUM_STEPS = 16 # @param [1,2,4,8,16] {"type":"raw"}
NUM_HIDDEN = 128 # @param [32,64,128,256] {"type":"raw"}
WD = 0.1 # @param [0,0.1,0.01,0.001] {"type":"raw"}
MODEL_NAME = "DeepBayesianExplorer"  #@param ["DeepBayesianExplorer","RandomExplorer","PPOExplorer"]
# Experiment configuration handed to `experiment` (defined elsewhere in the notebook).
config = {
    "NUM_ENVS": NUM_ENVS,    # number of environments (form field above)
    "WD": WD,
    "NUM_STEPS": NUM_STEPS,   # rollout steps between updates
    "SAC_D_STEPS": 4,
    "ENV_NAME":env_name,
    "SAC_STEP_SIZE": 1.0,
    "SEED": 423,         # results are highly stochastic with respect to the seed
    "TOTAL_TIMESTEPS": TOTAL_TIMESTEPS,   # total steps summed over all environments
    "NUM_HIDDEN":NUM_HIDDEN,
    "TX":"adamw",
    "DEPTH":DEPTH,
    "LR":2e-4,
    "OPT_STEPS":8,
    "MODEL_NAME": MODEL_NAME,
    "DEBUG": False,
}


# NOTE(review): the recorded output of this cell ends with
# "UnboundLocalError: cannot access local variable 'model'", so `out` is not produced by this run.
out = experiment(config)

{'NUM_ENVS': 4, 'WD': 0.1, 'NUM_STEPS': 16, 'SAC_D_STEPS': 4, 'ENV_NAME': 'Umaze', 'SAC_STEP_SIZE': 1.0, 'SEED': 423, 'TOTAL_TIMESTEPS': 2048, 'NUM_HIDDEN': 128, 'TX': 'adamw', 'DEPTH': 1, 'LR': 0.0002, 'OPT_STEPS': 8, 'MODEL_NAME': 'DeepBayesianExplorer', 'DEBUG': False}
low -inf
high inf


UnboundLocalError: cannot access local variable 'model' where it is not associated with a value

In [None]:
# Show the shape of every array stored in the collected transitions.
print(jax.tree_util.tree_map(lambda x: x.shape, out["transitions"]))

Transition(obs=(64, 4, 8, 4), action=(64, 4, 8, 2), reward=(64, 4, 8), next_obs=(64, 4, 8, 4), done=(64, 4, 8), info={'discount': (64, 4, 8), 'goal_position': (64, 4, 8, 2), 'is_success': (64, 4, 8), 'last_state': EnvState(time=(64, 4), position=(64, 4, 2), velocity=(64, 4, 2), desired_goal=(64, 4, 2)), 'log_prob': (64, 4, 8), 'reward': (64, 4, 8), 'value': (64, 4, 8)})


In [None]:
# Dump the recorded actions (prints the whole array).
print(out["transitions"].action)

In [None]:
import jax
import jax.numpy as jnp
from typing import Any, Dict


def _swap_trailing_goal(array, goal_dim, her_mask, new_goals):
    """Replace the last `goal_dim` features of `array` with `new_goals` wherever `her_mask` is set."""
    content = array[..., :-goal_dim]
    stored_goals = array[..., -goal_dim:]
    relabeled = jnp.where(her_mask[..., None], new_goals, stored_goals)
    return jnp.concatenate([content, relabeled], axis=-1)


def apply_her_with_trajectory(
    batch: Dict[str, jnp.ndarray],
    her_ratio=0.8,
    rng=None,
    success_threshold: float = 0.05,
) -> Dict[str, jnp.ndarray]:
    """Relabel goals with hindsight experience replay using later steps of the same sequence.

    For a random subset of entries (each chosen with probability `her_ratio`) the goal is replaced by
    the first `goal_dim` features of `next_observations` at a later time index of the same
    (batch, env) sequence, and the reward is recomputed as a sparse success indicator.

    Args:
        batch: Mapping with 'rewards' of shape (batch, envs, seq), 'goals' of shape
            (batch, envs, seq, goal_dim), and 'observations' / 'next_observations' whose last
            `goal_dim` features hold the goal and whose first `goal_dim` features are treated as
            the achieved goal.
        her_ratio: Probability of relabelling each entry.
        rng: PRNG key; defaults to `jax.random.PRNGKey(0)`, so repeated calls without a key draw the
            same relabelling.
        success_threshold: Distance below which a relabelled entry receives reward 1.0.

    Returns:
        A copy of `batch` with 'goals', 'observations', 'next_observations' and 'rewards' replaced.
        All other entries (e.g. 'masks') are passed through unchanged.

    NOTE(review): the future index is sampled without looking at episode boundaries, so a relabelled
    goal can come from after a `done` inside the same sequence - confirm this is intended.
    """
    batch_size, num_envs, seq_len = batch['rewards'].shape
    goal_dim = batch['goals'].shape[-1]

    if rng is None:
        rng = jax.random.PRNGKey(0)

    rng, mask_rng = jax.random.split(rng)
    her_mask = jax.random.uniform(mask_rng, (batch_size, num_envs, seq_len)) < her_ratio

    # Number of later steps available at each time index: seq_len-1, ..., 1, 0.
    max_offset = jnp.arange(seq_len - 1, -1, -1)

    rng, offset_rng = jax.random.split(rng)
    rand = jax.random.uniform(offset_rng, (batch_size, num_envs, seq_len))

    # Offset in [1, max_offset] where a later step exists; 0 at the final step.
    offset = jnp.where(max_offset > 0, (rand * max_offset).astype(jnp.int32) + 1, 0)

    batch_idx = jnp.arange(batch_size)[:, None, None]
    env_idx = jnp.arange(num_envs)[None, :, None]
    time_idx = jnp.arange(seq_len)[None, None, :]
    future_time_idx = jnp.minimum(time_idx + offset, seq_len - 1)

    # Achieved goal at the sampled later step.
    new_goals = batch['next_observations'][batch_idx, env_idx, future_time_idx][..., :goal_dim]

    goals = jnp.where(her_mask[..., None], new_goals, batch['goals'])
    observations = _swap_trailing_goal(batch['observations'], goal_dim, her_mask, new_goals)
    next_observations = _swap_trailing_goal(batch['next_observations'], goal_dim, her_mask, new_goals)

    # Sparse reward for relabelled entries: 1 when the achieved goal is within the threshold.
    achieved_goal = next_observations[..., :goal_dim]
    dist_to_goal = jnp.linalg.norm(achieved_goal - goals, axis=-1)
    new_rewards = (dist_to_goal < success_threshold).astype(jnp.float32)
    rewards = jnp.where(her_mask, new_rewards, batch['rewards'])

    return {
        **batch,
        'goals': goals,
        'observations': observations,
        'next_observations': next_observations,
        'rewards': rewards
    }

def prepare_dataset_for_sac(transition: Any, use_her=True, her_ratio=0.8, rng=None) -> Dict[str, jnp.ndarray]:
    """Convert collected rollouts into a flat replay buffer for SAC.

    Args:
        transition: Object with `obs`, `next_obs`, `action`, `reward`, `done` arrays whose three
            leading axes are shared, and an `info` mapping. When `info` contains 'goal_position',
            the goal is appended to both observation arrays and stored under 'goals'.
        use_her: Whether to relabel goals with `apply_her_with_trajectory` (only when goals exist).
        her_ratio: Relabelling probability forwarded to `apply_her_with_trajectory`.
        rng: PRNG key forwarded to `apply_her_with_trajectory`.

    Returns:
        Mapping with 'observations', 'actions', 'rewards', 'next_observations', 'masks' (and 'goals'
        when available); the three leading axes are merged into one.
    """
    batch_size, num_envs, seq_len = transition.obs.shape[:3]
    lead = (batch_size, num_envs, seq_len)
    obs_dim = transition.obs.shape[-1]

    observations = transition.obs.reshape(*lead, obs_dim)
    next_observations = transition.next_obs.reshape(*lead, obs_dim)
    # Keep the real action width: scalar actions become width 1 and vector actions keep their size.
    # The previous hard-coded width of 1 failed for multi-dimensional actions.
    actions = transition.action.reshape(*lead, -1)
    action_dim = actions.shape[-1]
    rewards = transition.reward.reshape(*lead)
    done = transition.done.reshape(*lead)

    if 'goal_position' in transition.info:
        goal_dim = transition.info['goal_position'].shape[-1]
        goals = transition.info['goal_position'].reshape(*lead, goal_dim)
    else:
        goal_dim = 0
        goals = None

    if goals is not None:
        observations = jnp.concatenate([observations, goals], axis=-1)
        next_observations = jnp.concatenate([next_observations, goals], axis=-1)

    batch = {
        'observations': observations,
        'actions': actions,
        'rewards': rewards,
        'next_observations': next_observations,
        'dones': done,
        # Explicit cast so boolean done flags yield a float mask.
        'masks': 1.0 - done.astype(jnp.float32)
    }

    if goals is not None:
        batch['goals'] = goals

    if use_her and goals is not None:
        print("using her")
        batch = apply_her_with_trajectory(batch, her_ratio, rng)

    flat_batch = {
        'observations': batch['observations'].reshape(-1, batch['observations'].shape[-1]),
        'actions': batch['actions'].reshape(-1, action_dim),
        'rewards': batch['rewards'].reshape(-1),
        'next_observations': batch['next_observations'].reshape(-1, batch['next_observations'].shape[-1]),
        'masks': batch['masks'].reshape(-1)
    }

    if goals is not None:
        flat_batch['goals'] = batch['goals'].reshape(-1, goal_dim)

    return flat_batch

In [None]:
# Flatten the collected rollouts into the replay buffer (HER relabelling is on by default),
# then show the shape of each entry.
buffer = prepare_dataset_for_sac(out["transitions"])
print(jax.tree_util.tree_map(lambda x: x.shape, buffer))

In [None]:
import os
import time
import numpy as np
import jax
import jax.numpy as jnp
import tqdm
from typing import Dict, Any, Optional
import wandb
from dataclasses import dataclass

@dataclass
class SACConfig:
    """Settings for the offline SAC run: training loop, logging, evaluation and SAC hyperparameters."""

    # Training parameters
    seed: int = 42
    num_epochs: int = 100
    batch_size: int = 256
    log_freq: int = 1
    eval_freq: int = 10
    eval_episodes: int = 1

    # wandb parameters
    run_group: str = "offline_sac"
    project: str = "OfflineSAC"

    # Evaluation parameters
    eval_temperature: float = 1.0
    eval_on_cpu: bool = False

    # SAC algorithm parameters
    lr: float = 1e-4
    actor_hidden_dims: tuple = (256, 256)
    value_hidden_dims: tuple = (256, 256)
    layer_norm: bool = False
    discount: float = 0.99
    tau: float = 0.005
    target_entropy: Optional[float] = None
    target_entropy_multiplier: float = 0.5
    tanh_squash: bool = True
    state_dependent_std: bool = True
    actor_fc_scale: float = 0.01
    min_q: bool = True
    her_ratio: float = 0.8

def get_exp_name(seed):
    """Build the run name for a given seed, e.g. ``offline_sac_seed_42``."""
    return "offline_sac_seed_{}".format(seed)

def setup_wandb(project, group, name):
    """Start a wandb run under the given project, group and run name."""
    init_kwargs = dict(project=project, group=group, name=name)
    wandb.init(**init_kwargs)

def sample_batch_from_data(data, batch_size, rng):
    """Draw `batch_size` rows, uniformly with replacement, from every non-None entry of `data`.

    The number of rows is taken from `data['observations']`; entries whose value is None are
    left out of the returned batch.
    """
    num_rows = data['observations'].shape[0]
    picks = jax.random.randint(rng, (batch_size,), 0, num_rows)
    return {name: jnp.array(values[picks]) for name, values in data.items() if values is not None}

def evaluate_agent(agent, env, env_params, key, config, num_eval_episodes=10, eval_temperature=1.0, max_steps=10):
    """Roll out the agent for a number of episodes and summarise reward and length.

    The underlying `env.pm_env` is used for reset and step, so the agent receives the observation
    exactly as that environment returns it.

    Args:
        agent: Object with `sample_actions(observations, goals, seed, temperature)`.
        env: Wrapper exposing the environment as `env.pm_env`.
        env_params: Parameters passed to `reset_env` / `step_env`.
        key: PRNG key from which all evaluation randomness is derived.
        config: Unused; kept so existing callers keep working.
        num_eval_episodes: Number of episodes to run.
        eval_temperature: Sampling temperature passed to the agent.
        max_steps: Maximum number of steps per episode.

    Returns:
        Dict with 'reward_mean', 'reward_std', 'length_mean' and 'length_std'.
    """
    eval_rewards = []
    eval_lengths = []

    key, eval_key = jax.random.split(key)

    for _ in range(num_eval_episodes):
        eval_key, reset_key, episode_key = jax.random.split(eval_key, 3)

        observations, state = env.pm_env.reset_env(reset_key, env_params)

        episode_reward = 0.0
        episode_length = 0
        done = False

        while not done and episode_length < max_steps:
            # Fresh keys every step: reusing one key would repeat the same action noise and the
            # same environment randomness at every step of the episode.
            episode_key, sample_key, step_key = jax.random.split(episode_key, 3)

            actions = agent.sample_actions(
                observations=observations,
                goals=None,
                seed=sample_key,
                temperature=eval_temperature
            )

            observations, state, rewards, dones, _ = env.pm_env.step_env(step_key, state, actions, env_params)
            episode_reward += float(rewards.sum())
            episode_length += 1
            done = bool(dones.all())

        eval_rewards.append(episode_reward)
        eval_lengths.append(episode_length)

    return {
        'reward_mean': np.mean(eval_rewards),
        'reward_std': np.std(eval_rewards),
        'length_mean': np.mean(eval_lengths),
        'length_std': np.std(eval_lengths),
    }

def offline_training_with_wandb(
    agent: SACAgent,
    replay_buffer: Dict[str, np.ndarray],
    config: SACConfig,
    eval_env_params,
    key,
    eval_env=None,
):
    """Train the agent offline on a fixed replay buffer, with periodic logging and evaluation.

    Args:
        agent: Agent to train; `agent.update(batch)` must return `(new_agent, info)`.
        replay_buffer: Flat buffer; its size is read from `replay_buffer['observations']`.
        config: Run settings (`num_epochs`, `batch_size`, `her_ratio`, `log_freq`, `eval_freq`, ...).
        eval_env_params: Environment parameters forwarded to `evaluate_agent`.
        key: PRNG key; split into a training stream and an evaluation stream.
        eval_env: Optional evaluation environment; evaluation is skipped when None.

    Returns:
        The trained agent.
    """
    buffer_size = replay_buffer['observations'].shape[0]
    steps_per_epoch = buffer_size // config.batch_size

    print(f"Buffer size: {buffer_size}")
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Total training steps: {config.num_epochs * steps_per_epoch}")

    # wandb.config.update({
    #     'buffer_size': buffer_size,
    #     'steps_per_epoch': steps_per_epoch,
    #     'total_steps': config.num_epochs * steps_per_epoch,
    #     'config': config.__dict__
    # })

    start_time = time.time()
    key, train_key, eval_key = jax.random.split(key, 3)

    for epoch in tqdm.tqdm(range(config.num_epochs), smoothing=0.1, dynamic_ncols=True):
        epoch_start_time = time.time()
        epoch_losses = []

        for _ in range(steps_per_epoch):
            # Dedicated keys per step: `agent.rng` is split inside `agent.update`, so it must not
            # also drive batch sampling and relabelling.
            train_key, sample_key, her_key = jax.random.split(train_key, 3)

            batch = sample_batch_from_data(
                replay_buffer,
                config.batch_size,
                sample_key
            )

            # NOTE(review): `apply_her` is not defined in the visible part of the notebook (only
            # `apply_her_with_trajectory` is) - confirm it exists before running.
            batch = apply_her(
                batch,
                her_ratio=config.her_ratio,
                rng=her_key
            )

            agent, info = agent.update(batch)
            epoch_losses.append(info)

        epoch_time = time.time() - epoch_start_time

        # Always defined, so an epoch with zero steps (buffer smaller than one batch) still logs.
        avg_losses = {}
        if epoch_losses:
            for metric_name in epoch_losses[0].keys():
                values = [loss[metric_name] for loss in epoch_losses if metric_name in loss]
                if values:
                    avg_losses[metric_name] = np.mean(values)

        # A non-positive frequency disables the feature instead of raising on the modulo.
        if config.log_freq > 0 and epoch % config.log_freq == 0:
            train_metrics = {
                'epoch': epoch,
                'time/epoch_time': epoch_time,
                'time/total_time': time.time() - start_time,
                'time/avg_step_time': epoch_time / max(steps_per_epoch, 1),
            }

            for metric_name, value in avg_losses.items():
                train_metrics[f'training/{metric_name}'] = value

            train_metrics['training/learning_rate'] = config.lr

            # wandb.log(train_metrics, step=epoch)

            print(f"Epoch {epoch}/{config.num_epochs}: "
                  f"Time: {epoch_time:.2f}s, "
                  f"Critic Loss: {avg_losses.get('critic/critic_loss', 0):.4f}, "
                  f"Actor Loss: {avg_losses.get('actor/actor_loss', 0):.4f}, "
                  f"Alpha: {avg_losses.get('actor/alpha', 0):.4f}")

        if config.eval_freq > 0 and epoch % config.eval_freq == 0 and eval_env is not None:
            if config.eval_on_cpu:
                eval_agent = jax.device_put(agent, device=jax.devices('cpu')[0])
            else:
                eval_agent = agent

            # New key per evaluation so successive evaluations do not replay identical episodes.
            eval_key, episode_key = jax.random.split(eval_key)

            eval_info = evaluate_agent(
                agent=eval_agent,
                env=eval_env,
                config=config,
                num_eval_episodes=config.eval_episodes,
                eval_temperature=config.eval_temperature,
                env_params=eval_env_params,
                key=episode_key
            )

            eval_metrics = {f'evaluation/{k}': v for k, v in eval_info.items()}
            eval_metrics['evaluation/epoch'] = epoch

            # wandb.log(eval_metrics, step=epoch)

            print(f"Evaluation at epoch {epoch}: "
                  f"Reward = {eval_info['reward_mean']:.2f} ± {eval_info['reward_std']:.2f}")

    return agent

def main():
    """Build the config, evaluation environment and agent, then train offline on `buffer`.

    Reads the module-level `buffer` produced by `prepare_dataset_for_sac`.

    Returns:
        The trained agent.
    """
    config = SACConfig(
        seed=42,
        num_epochs=100,
        batch_size=256,
        project="YourProject",
        run_group="experiment_group"
    )

    # Set up wandb (currently disabled, so `exp_name` is only computed).
    exp_name = get_exp_name(config.seed)
    # setup_wandb(config.project, config.run_group, exp_name)

    # Both stay None when evaluation is disabled, so the training call below never sees an
    # unbound name.
    eval_env = None
    eval_env_params = None
    if config.eval_freq > 0:
        eval_env = PMwrapper(pointax.make_umaze(reward_type="sparse"))
        eval_env_params = eval_env.default_params()

    replay_buffer = buffer

    example_transition = dict(
        observations=replay_buffer['observations'][0],
        actions=replay_buffer['actions'][0],
        rewards=replay_buffer['rewards'][0],
        masks=replay_buffer['masks'][0],
        next_observations=replay_buffer['next_observations'][0],
    )
    # 'goals' is only present for goal-conditioned buffers; `.get` avoids a KeyError otherwise.
    if replay_buffer.get('goals') is not None:
        example_transition['goals'] = replay_buffer['goals'][0]

    sac_config = {
        'agent_name': 'sac',
        'lr': config.lr,
        'batch_size': config.batch_size,
        'actor_hidden_dims': config.actor_hidden_dims,
        'value_hidden_dims': config.value_hidden_dims,
        'layer_norm': config.layer_norm,
        'discount': config.discount,
        'tau': config.tau,
        'target_entropy': config.target_entropy,
        'target_entropy_multiplier': config.target_entropy_multiplier,
        'tanh_squash': config.tanh_squash,
        'state_dependent_std': config.state_dependent_std,
        'actor_fc_scale': config.actor_fc_scale,
        'min_q': config.min_q,
        'her_ratio': config.her_ratio
    }

    agent = SACAgent.create(
        config.seed,
        example_transition['observations'],
        example_transition['actions'],
        sac_config,
    )

    trained_agent = offline_training_with_wandb(
        agent=agent,
        replay_buffer=replay_buffer,
        config=config,
        eval_env=eval_env,
        eval_env_params=eval_env_params,
        key=jax.random.PRNGKey(config.seed)
    )

    # Hand the result back instead of discarding it.
    return trained_agent


# Usage example
if __name__ == "__main__":
    main()