In [None]:
%pip install jax
%pip install numpy
%pip install matplotlib
%pip install xminigrid
%pip install gymnax
%pip install distrax

In [2]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import distrax

import timeit
import imageio
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm

from flax import nnx
import xminigrid

In [3]:
# class TimeStep(struct.PyTreeNode):
#     # hidden environment state, such as grid, agent, goal, etc
#     state: State

#     # similar to the dm_env enterface
#     step_type: StepType
#     reward: jax.Array
#     discount: jax.Array
#     observation: jax.Array

## Utils

### Wrapper

In [4]:
from gymnax.environments.environment import Environment
import abc
from typing import Any, Generic, Optional, TypeVar

import chex
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from gymnax.environments import environment, spaces
import xminigrid

@struct.dataclass
class EnvState:
    time: int


@struct.dataclass
class EnvParams:
    max_steps_in_episode: int = 1

class XMiniGridGymnaxWrapper(Environment):
    """
    将 xminigrid 环境包装为 gymnax 风格环境
    """
    def __init__(self, xminigrid_env):
        super().__init__()
        self.xminigrid_env = xminigrid_env

    @property
    def default_params(self):
        # return EnvParams()
        return self.xminigrid_env.default_params()

    def step_env(self, key, state, action, params):
        timestep = self.xminigrid_env.step(params, state, action)
        obs = timestep.observation
        reward = timestep.reward
        done = (timestep.step_type == 2)  # StepType.LAST == 2
        info = {"discount": timestep.discount}
        state = timestep
        return obs, state, reward, done, info

    def reset_env(self, key, params):
        timestep = self.xminigrid_env.reset(params, key)
        obs = timestep.observation
        return obs, timestep

    def get_obs(self, state, params=None, key=None):
        return state.observation

    def is_terminal(self, state, params):
        return state.step_type == 2

    # useless here
    @property
    def name(self):
      return "xminigrid"

    @property
    # potential issue here
    #implementation of xland:
      #def num_actions(self, params: EnvParamsT) -> int:
      # return int(NUM_ACTIONS)
    def num_actions(self):
        return 6

    def action_space(self, params):
        return spaces.Discrete(6)

    def observation_shape(self, params):
        return self.xminigrid_env.observation_shape(params)

    def observation_space(self, params):
        shape = self.observation_shape(params)
        return spaces.Box(low=0, high=255, shape=shape, dtype=jnp.float32)

### Encoders

In [5]:
import jax.nn as nn


class Encoder(nnx.Module):
  def __init__(self, input_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array):
    h = self.linear(x)
    return self.layer_norm0(h)

class ActionEncoder(nnx.Module):
  def __init__(self, input_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.embed = nnx.Embed(input_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array):
    h = self.embed(x)
    return self.layer_norm0(h)

class JointEncoder(nnx.Module):
  def __init__(self, hidden_dim: int, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm1 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm2 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm3 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array, rng):
    dist_distrax = distrax.MultivariateNormalDiag(loc=x, scale_diag=1e-1*jnp.ones_like(x))
    # potential shape issue
    x = dist_distrax.sample(seed=rng, sample_shape=(1,))
    x = self.layer_norm0(x)
    h0 = self.linear1(x)
    h = nn.relu(h0)
    h = self.layer_norm1(h) + h0
    h0 = self.linear2(h)
    h = self.layer_norm2(h) + h0
    return self.layer_norm3(h)

### Actor

In [6]:
from jax import lax
import distrax

class Actor(nnx.Module):
  # environment related ???
  log_std_min: float = -4
  log_std_max: float = 2

  def __init__(self, obs_dim, action_dim, hidden_dim, rngs: nnx.Rngs):
    self.mean = nnx.Linear(hidden_dim, action_dim, rngs=rngs)
    self.log_std = nnx.Linear(hidden_dim, action_dim, rngs=rngs)

  def __call__(self, x: jnp.ndarray):
    mean = self.mean(x)
    log_std = jnp.clip(self.log_std(x), self.log_std_min, self.log_std_max)
    return mean, log_std

### Functions

#### computaion

In [7]:
def compute_info_gain_normal(mean, prec, l_prec, next_obs):
  prec = jnp.maximum(prec, 1e-6)
  posterior_prec = prec + l_prec
  prec_ratio = prec / posterior_prec

  posterior_mean = (prec * mean + l_prec * next_obs) / posterior_prec

  delta_mean = next_obs - posterior_mean
  kl = delta_mean * delta_mean * prec
  kl = kl + prec_ratio - jnp.log(prec_ratio) - 1
  kl = 0.5 * jnp.sum(kl, axis=-1)
  return kl, delta_mean

@jax.jit
def compute_expected_info_gain_normal(prec, l_prec):
  prec = jnp.maximum(prec, 1e-6)
  prec_ratio = l_prec / prec
  mi_matrix = 0.5 * jnp.sum(jnp.log(1+prec_ratio), axis=-1)
  return mi_matrix

jnp.set_printoptions(precision=3,suppress=True)
from flax.training import train_state
from jax.scipy.special import gamma,digamma, gammaln, kl_div

def batch_random_split(batch_key,num=2):
    split_keys = jax.vmap(jax.random.split,in_axes=(0,None))(batch_key,num)
    return [split_keys[:, i]  for i in range(num) ]


#### shape manipulation

In [10]:
import matplotlib.pyplot as plt

def reshape(arr):
  if arr.dim < 3:
    raise ValueError("Input array must have at least 3 dimensions (n, b, c, ...).")

  n, b, c, *x_dims = arr.shapes
  # Transpose the first two axes (n, b) to (b, n)
  # We construct the axes tuple dynamically for flexibility
  transpose_axes = (1, 0) + tuple(range(2, arr.ndim))
  transposed_arr = jnp.transpose(arr, axes=transpose_axes)
  # Reshape into (b, n*c, x0, x1, ...)
  new_shape = (b, n * c, *x_dims)
  reshaped_arr = jnp.reshape(transposed_arr, new_shape)

  return reshaped_arr

from typing import List, Any

# Define a type alias for PyTree for better readability
PyTree = Any
from typing import List, Any

# Define a type alias for PyTree for better readability
PyTree = Any
def unpack_pytree_by_first_index(pytree: PyTree) -> List[PyTree]:
    """
    Unpacks a PyTree of JAX arrays along their first dimension (id).

    This function assumes that all JAX arrays within the PyTree
    have a consistent first dimension (the 'id' dimension) and that
    you want to create a separate PyTree for each 'id'.

    Args:
        pytree: A JAX PyTree where the leaves are JAX arrays
                with a leading 'id' dimension.

    Returns:
        A list of PyTrees, where each PyTree corresponds to a single
        'id' from the original PyTree.
    """
    # Get the size of the first dimension from any leaf array
    # We assume all arrays have the same first dimension size.
    first_leaf = jax.tree_util.tree_leaves(pytree)[0]
    num_ids = first_leaf.shape[0]

    # Create a list to store the unpacked PyTrees
    unpacked_pytrees = []

    # Iterate through each ID
    for i in range(num_ids):
        # Use tree_map to slice each array in the PyTree at the current ID
        sliced_pytree = jax.tree_util.tree_map(lambda x: x[i], pytree)
        unpacked_pytrees.append(sliced_pytree)

    return unpacked_pytrees

def unpack_states(pytree):
    return unpack_pytree_by_first_index(jax.tree.map(reshape, pytree))

### Others

In [8]:
class Likelihood_Prec(nnx.Module):
  log_std_min: float = -2
  log_std_max: float = 2

  def __init__(self, obs_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(hidden_dim, obs_dim, rngs=rngs)

  def __call__(self, x: jnp.ndarray):
    log_std = jnp.clip(self.linear(x), self.log_std_min, self.log_std_max)
    return jnp.exp(-log_std)

### Computation

In [56]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

jnp.set_printoptions(precision=2,suppress=True)
from flax.training import train_state
from jax.scipy.special import gamma,digamma, gammaln, kl_div
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

@nnx.jit
def compute_info_gain_dirichlet(alpha,next_obs):
    """
    计算互信息矩阵
    alpha: (batch, num_states)
    next_obs: (batch)  index
    输出形状: (batch)
    """

    alpha = jnp.maximum(alpha, 1e-6)
    sum_alpha = jnp.sum(alpha, axis=-1,keepdims=False)

    next_obs = jnp.expand_dims(next_obs, -1).astype("int")
    post_alpha = jnp.take_along_axis(alpha,next_obs,-1)
    post_alpha = post_alpha.squeeze(-1)

    probs = post_alpha / sum_alpha
    log_probs = jnp.log(probs)

    # ( batch)
    entropy = - log_probs

    # ( batch)
    posterior_digamma = digamma(post_alpha+1)

    # ( batch)
    sum_digamma = digamma(sum_alpha+1)

    # ( batch)
    negative_posterior_entropy = posterior_digamma - sum_digamma

    # (batch)
    posterior_kl = entropy + negative_posterior_entropy
    return posterior_kl

@nnx.jit
def compute_mi_dirichlet(alpha):
    """
    计算互信息矩阵
    输入形状: (num_actions, num_states)
    输出形状: (num_actions)
    """

    alpha = jnp.maximum(alpha, 1e-6)
    num_states = alpha.shape[-1]
    sum_alpha = jnp.sum(alpha, axis=-1,keepdims=True)

    probs = alpha / sum_alpha

    log_probs = jnp.log(probs)

    # ( num_actions)
    entropy = - jnp.sum(probs * log_probs,axis=-1)

    # ( num_actions, num_states)
    posterior_digamma = digamma(alpha+1)

    # ( num_actions)
    sum_digamma = digamma(sum_alpha+1).squeeze(-1)

    # ( num_actions)
    negative_posterior_entropy = (probs * posterior_digamma).sum(axis=-1) - sum_digamma

    # (num_actions)
    mi_matrix = entropy + negative_posterior_entropy
    return mi_matrix

@nnx.jit
def optimal_action_and_MI_from_alpha(alphas,rng):

    # 计算互信息矩阵
    # 4x8x8x2
    mi_matrix = compute_mi_dirichlet(alphas)  # ( num_actions)

    rng, _rng = jax.random.split(rng)
    random_perturb = 1e-4*jax.random.normal(_rng,mi_matrix.shape)
    mi_matrix = mi_matrix + random_perturb

    # sum over
    mi_matrix_sum = mi_matrix.sum(axis=(1,2,3))

    optimal_actions = jnp.argmax(mi_matrix_sum, axis=-1)
    return optimal_actions, mi_matrix

## Unsupervised Explorer

In [57]:
from gymnax.experimental import RolloutWrapper
# action = self.model_forward(policy_params, obs, rng_net)
import functools
import gymnax
from typing import Union,Optional,Any
import abc

In [58]:
class UnsupervisedExplorer(nnx.Module):

  @abc.abstractmethod
  def update(self, obs, actions, next_obs, dones, info):
    # update variable parameters
    return #{'kl':KL} MI= E[KL]

  @abc.abstractmethod
  def __call__(self, observations, rng):
    return #actions, {"mi":mi_matrix}

class RandomExplorer(UnsupervisedExplorer):

  def __init__(self, num_actions):
    self.num_actions = num_actions

  def update(self, rng, obs, actions, next_obs, dones, info):
    return {}

  def __call__(self, observations, rng):
    if observations.ndim == 1:
      # possible shape issue here
      actions = jax.random.randint(rng, shape=(1,), minval=0, maxval=self.num_actions)
      return actions, {}
    actions = jax.random.randint(rng, shape=(observations.shape[0],), minval=0, maxval=self.num_actions)
    return actions, {}

class BayesianExplorer(UnsupervisedExplorer):

    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions
        self.num_states = num_states
        self.alphas = nnx.Variable(jnp.ones((num_states, num_actions, num_states))/2)

    def update(self,obs,action,next_obs,done,info):

        prior_alphas = self.alphas[obs, action]
        kl=compute_info_gain_dirichlet(prior_alphas,next_obs)
        self.alphas.value = self.alphas.value.at[obs, action,next_obs].add(1)
        return {"kl":kl}
        #big =
       # return {"big":}


    def __call__(self,observations,rng):

      #  alpha = jnp.take(self.alphas,observations.astype(jnp.int32),axis=0)
        alpha = self.alphas[observations.astype(jnp.int32)]
        actions, mi_matrix = optimal_action_and_MI_from_alpha(alpha,rng)
        MI = mi_matrix[actions]
        return actions, {"mi":MI}

In [59]:
class obs_embedder(nnx.Module):
  def __init__(self, num_type: int, embed_dim: int, rngs: nnx.Rngs):
    self.embed = nnx.Embed(num_type, embed_dim, rngs=rngs)

  def __call__(self, obs):
    t_ids = obs[..., 0]
    c_ids = obs[..., 1]

    t_embed = self.embed(t_ids)
    c_embed = self.embed(c_ids)

    return jnp.concatenate([t_embed, c_embed], axis=-1)

class Joint_MLP(nnx.Module):
  def __init__(self, hidden_dim: int, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.linear2 = nnx.Linear(hidden_dim, 2*hidden_dim, rngs=rngs)

  def __call__(self, x: jnp.ndarray):
    h = self.linear1(x)
    h = nn.relu(h)
    out = self.linear2(h)
    return out

In [62]:
class DeepBayesianExplorer(BayesianExplorer):

    def __init__(self, num_states, num_actions,num_hidden):
        super().__init__(num_states, num_actions)
        self.embed_size = num_states + num_actions
        self.num_hidden = num_hidden
        obs_dim = num_states
        self.obs_embeds = obs_embedder(num_states,num_hidden,rngs=nnx.Rngs(0))
        self.action_embeds = nnx.Embed(num_embeddings=num_actions, features=8*8*2*num_hidden, rngs=nnx.Rngs(0))
        self.joint_embeds = Joint_MLP(8*8*2*num_hidden,rngs=nnx.Rngs(0))
    #    self.linear = nnx.Linear(self.embed_size,num_hidden,rngs=nnx.Rngs(0))
        self.w = nnx.Variable(jnp.zeros((8*8*4*num_hidden, 8*8*2*num_states)))
        self.b = nnx.Variable(jnp.ones((8, 8, 2, num_states)) / 2)

    def update(self,obs,action,next_obs,done,info):
        #not necessary just here to log the kl if we were classical bayesian
        kl = super().update(obs,action,next_obs,done,info)["kl"]
        alpha = info["alpha"]
        # alpha 8 x 8 x 2 x 13
        # next_obs 8 x 8 x 2
        deepkl=compute_info_gain_dirichlet(alpha,next_obs)
        #batch x  num_hidden
        T = info["T"].reshape(-1,self.num_hidden)
        ones = jnp.ones_like(T[:,:1])

        #batch x (num_hidden+1)
        T = jnp.concatenate([T,ones],axis=-1)
        #batch x num_states
        #batch x
        #w 512 x 1664
        #so it should acts seperately on each block
        y = jax.nn.one_hot(next_obs.astype(jnp.int32),self.num_states)
        y = y.reshape(-1,self.num_states)
        # jax.debug.print("{}", T_theta)
        T_T = jnp.transpose(T)

        covariance = T @ T_T
        inv_covariance = jnp.linalg.pinv(covariance)
        delta =  T_T @ inv_covariance @ y

        delta_W = delta[:-1]
        delta_b = delta[-1]
        self.w.value = self.w.value + delta_W
        self.b.value = self.b.value + delta_b
        return {"kl":kl,"deepkl":deepkl}
  #  @functools.partial(nnx.jit, static_argnums=(0,))
    def embed_joint(self,action_embed,obs_embed):
        join = lambda x,y :  jnp.concatenate([x,y],axis=-1)
        vmapped = jax.vmap(join,in_axes=(0,None))
        return vmapped(action_embed,obs_embed)

    def __call__(self,observations,rng):
        #( num_actions x num_hidden )

        # 8x8x2
        observations = observations.astype(jnp.int32)

      #  print ("observations",observations.shape)
        # observations: 8 x 8 x 2
        # obs_embed: 8 x 8 x 4 use nn.Embedding
        # flattent to 256

        # 8x8x4
        obs_embed = self.obs_embeds(observations)
        # 256
        obs_embed_flat = obs_embed.reshape(-1)

        # 4 x 256
        action_embed = self.action_embeds(jnp.arange(self.num_actions))

        #num_actions x embed_size    4 x 256
        embed = action_embed+jnp.expand_dims(obs_embed_flat,0)

        # num_actions x num_hidden
        # 4 x 512 using mlp
        T = self.joint_embeds(embed)

        #  num_actions x num_states
       # print ("self.b ",self.b .shape)

        # self.b  8 x 8 x 2 x 13
        # alpha  4 x  8 x 8 x 2 x 13
        # self.w 512 x 1664
        # T @ self.w needs reshape

        # T @ self.w : 4 x 1664
        Tw = (T @ self.w).reshape(4,8,8,2,13)
        alpha = Tw + self.b

        #the current code will give
        # MI the shape of 4 x  8 x 8 x 2, you need to sum over 8 x 8 x 2
        #then take the argmax over action
        actions, mi_matrix = optimal_action_and_MI_from_alpha(alpha,rng)
        # 512
        T = T[actions]
        # 8 x 8 x 2 x 13
        alpha = alpha[actions]
        # 1
        mi_matrix = mi_matrix[actions]
      #  mi_matrix = jnp.take_along_axis(mi_matrix, actions.reshape((-1,1)), axis=1)
      #  actions = actions.squeeze()
        return actions, {"mi":mi_matrix,"T":T,"alpha":alpha}


In [61]:
import jax
import jax.numpy as jnp

# 假设 obs_embedder, JointEncoder, compute_info_gain_dirichlet, optimal_action_and_MI_from_alpha 已经定义

def test_deep_bayesian_explorer():
    # 假设状态空间为 13（如13种类型），动作空间为 4，隐藏层为 8
    num_states = 13
    num_actions = 4
    num_hidden = 2

    # 构造 explorer
    explorer = DeepBayesianExplorer(num_states, num_actions, num_hidden)

    # 构造一个假的观测 (8, 8, 2)，值在 [0, num_states-1] 之间
    rng = jax.random.PRNGKey(0)
    obs = jax.random.randint(rng, (8, 8, 2), 0, num_states)

    # 调用 __call__ 方法
    actions, info = explorer(obs, rng)

    print("actions:", actions)
    print("info keys:", info.keys())
    print("mi shape:", info["mi"].shape)
    print("T shape:", info["T"].shape)
    print("alpha shape:", info["alpha"].shape)

# 运行测试
test_deep_bayesian_explorer()

actions: 1
info keys: dict_keys(['mi', 'T', 'alpha'])
mi shape: (8, 8, 2)
T shape: (512,)
alpha shape: (8, 8, 2, 13)


In [13]:



class DeepSACBayesianExplorer(UnsupervisedExplorer):
  # ent?
  def __init__(self, obs_dim, num_actions, hidden_dim, rngs: nnx.Rngs,
               l_prec=1.0, wd=1e-2, ent_lambda=1e-3, depth=2):

    self.obs_dim = obs_dim
    self.num_actions = num_actions
    self.hidden_dim = hidden_dim
    self.prec_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)), name='prec_w')
    self.mean_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)), name='mean_w')
    # what is trainable here
    self.trainable_likelihood_prec = Likelihood_Prec(obs_dim, hidden_dim, rngs)
    self.weight_decay = wd
    self.obs_embeds = Encoder(obs_dim, hidden_dim, rngs)
    self.action_embeds = ActionEncoder(num_actions, hidden_dim, rngs)
    self.joint_embeds = JointEncoder(hidden_dim, rngs)
    self.depth = depth
    self.ent_lambda = ent_lambda

  def __call__(self,observations,rng):
      return self.recursive_mi(observations,rng,self.depth)

  def update(self, rng, obs, action, next_obs, done, info):
    mean = info["mean"]
    prec = info["prec"]

    def _likelihood_loss(rng, T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)
      mu = mean
      # model var + inherent var
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
      log_prob = dist_distrax.log_prob(next_obs)
      return -log_prob, l_prec

    # jit here
    predictive_loss, l_prec = _likelihood_loss(rng, info["T"], mean, prec, next_obs)
    # originally jnp.sum
    mean_error = jnp.mean((mean - next_obs)**2)
    deepkl, delta_mean = compute_info_gain_normal(mean, prec, l_prec, next_obs)
    # batch x num_hidden
    T = info["T"].reshape(-1, self.hidden_dim)

    # batch x obs_dim
    l_prec = l_prec.reshape(-1, self.obs_dim)
    delta_mean = delta_mean.reshape(-1, self.obs_dim)

    T_T = jnp.transpose(T)
    covariance = T @ T_T
    inv_covariance = jnp.linalg.pinv(covariance)

    T_Map = T_T @ inv_covariance

    delta_precW = T_Map @ l_prec
    self.prec_w.value = (self.prec_w.value + delta_precW) * (1-self.weight_decay)
    delta_meanW = T_Map @ delta_mean
    self.mean_w.value = (self.mean_w.value + delta_meanW) * (1-self.weight_decay)

    return {"kl":deepkl,  "predictive_loss": predictive_loss, "mean_error":mean_error}

  # jitable
  def loss(self, rng, obs, action, next_obs, done, info):
    def _likelihood_loss(T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)

      mu = mean
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)

      log_prob = dist_distrax.log_prob(next_obs)
      return -log_prob

    T, mean, prec = info["T"], info["mean"], info["prec"]
    likelihood_loss = _likelihood_loss(T, mean, prec, next_obs)
    return likelihood_loss

  def batch_loss(self, rng, obs, actions, next_obs, dones, info):
    vmapped = jax.vmap(self.loss)
    return vmapped(rng, obs, actions, next_obs, dones, info)

  def recursive_mi(self, observations, rng, depth):
    obs_embed = self.obs_embed(observations)
    action_embed = self.action_embed(jnp.arange(self.num_actions))
    # possible shape issue
    embed = action_embed + jnp.expand_dims(obs_embed, axis=0)

    # num_actions x embed_size
    T = self.joint_embeds(embed, rng)
    prec = jnp.maximum(T @ self.prec_w, 1e-3)
    # num_actions x obs_dim
    mean = T @ self.mean_w
    l_prec = self.trainable_likelihood_prec(T)

    MI = compute_expected_info_gain_normal(prec, l_prec)

    if depth > 0:
      vmapped = jax.vmap(self.recursive_mi, in_axes=(0,None,None))
      # num_actions x 1
      actions, info = vmapped(mean, rng, depth-1)
      MI = MI + info["mi"]

    actions = jnp.argmax(MI, axis=0)
    T = T[actions]
    MI = MI[actions]
    l_prec = l_prec[actions]
    prec = prec[actions]
    mean = mean[actions]
    return actions, {"mi":MI,"T":T,"obs_embed":obs_embed,"l_prec":l_prec,
                        "prec":prec,"mean":mean}

def show_variable(model, text):

    graphdef, params, vars,others = nnx.split(model, nnx.Param, nnx.Variable,...)

    print(text,vars)


In [None]:
def test():

    rng = jax.random.PRNGKey(0)
    alphas = jnp.zeros((16, 4, 16))
    # Define rollout manager for pendulum env
    manager = CustomRolloutWrapper(env_or_name="MiniGrid-EmptyRandom-8x8", num_env_steps=3)

    # Simple single episode rollout for policy
    obs, action, reward, next_obs, done, timestep, info, cum_ret = manager.single_rollout(rng,None)

    print ("single action",action)
    print ("obs",obs)
    print ("next_obs",next_obs)
    # Multiple rollouts for same network (different rng, e.g. eval)
    rng_batch = jax.random.split(rng, 2)
    print ("reset_state",manager.batch_reset(rng_batch))

    obs, action, reward, next_obs, done,timestep, info, cum_ret = manager.batch_rollout(
        rng_batch,None
    )

    print ("batch action",action)
    print ("obs",obs)
    print ("next_obs",next_obs)
    print ("info",info)

    # next_state = info["next_state"]
    print ("next_state",next_state)
    print ("next_state.time[:,-1]",next_state.time[:,-1])
    last_state = EnvState(
        time=next_state.time[:,-1],  # Becomes shape (2,)
        state=next_state.state[:,-1]  # Becomes shape (2,)
    )
    obs, action, reward, next_obs, done,info, cum_ret = manager.batch_rollout(
        rng_batch, None, timestep
    )
    # Multiple rollouts for different networks + rng (e.g. for ES)
    batch_params = jax.tree_map(  # Stack parameters or use different
        lambda x: jnp.tile(x, (2, 1)).reshape(2, *x.shape), alphas
    )


test()

## Wrapper

In [None]:
from xminigrid.environment import Environment
from typing import Union,Optional,Any
import abc

class CustomRolloutWrapper:
    """Wrapper to define batch evaluation for generation parameters."""

    def __init__(
        self,
        env_or_name: Union[str,Environment] = "Pendulum-v1",
        num_env_steps: Optional[int] = None,
        env_kwargs: Any | None = None,
        env_params: Any | None = None,
    ):
        """Wrapper to define batch evaluation for generation parameters."""
        # Define the RL environment & network forward function
        if env_kwargs is None:
            env_kwargs = {}
        if env_params is None:
            env_params = {}
        if isinstance(env_or_name,Environment):
            self.env = env_or_name
            self.env_params = env_or_name.default_params
        else:
            self.env, self.env_params = xminigrid.make(env_or_name, **env_kwargs)
        self.env_params = self.env_params.replace(**env_params)

        if num_env_steps is None:
            self.num_env_steps = self.env_params.max_steps
        else:
            self.num_env_steps = num_env_steps

    def batch_reset(self, rng_input):
        batch_rest = jax.vmap(self.single_reset_state)
        return batch_rest(rng_input)

    # state vs. timestep, potential issue here
    def single_reset_state(self, rng_input):
        rng_reset, rng_episode = jax.random.split(rng_input)
        timestep = self.env.reset(self.env_params, rng_reset)
        return timestep

    def batch_rollout(self, rng_eval, model:UnsupervisedExplorer, timestep=None, num_steps=1):
        batch_rollout = jax.vmap(self.single_rollout, in_axes=(0,None,None,None))
        return batch_rollout(rng_eval, model, timestep, num_steps)

    def single_rollout(self, rng_eval, model:UnsupervisedExplorer, timestep=None, num_steps=1):
        rng_reset, rng_episode = jax.random.split(rng_eval)

        if timestep is None:
          timestep = self.env.reset(self.env_params, rng_reset)

        obs = timestep.observation

        def policy_step(state_input, _):
          obs, timestep, rng, cum_reward, valid_mask = state_input
          rng, rng_step, rng_net = jax.random.split(rng, 3)
          if model is not None:
              temp,info = model( obs, rng_net)
              action = self.env.action_space(self.env_params).sample(rng_net)
          else:
            # not action space?
            action = action = jax.random.randint(rng_step, shape=(), minval=0, maxval=self.env.num_actions(env_params))
            info = {}

          next_timestep = self.env.step(self.env_params, timestep, action)
          next_obs = next_timestep.observation
          reward = next_timestep.reward
          done = next_timestep.step_type == 2

          info.update({"discount": next_timestep.discount})
          new_cum_reward = cum_reward + reward * valid_mask
          new_valid_mask = valid_mask * (1- done)
          carry = [next_obs, next_timestep, rng, new_cum_reward, new_valid_mask]
          y = [obs, action, reward, next_obs, done, timestep, info]

          return carry, y

        carry_out, scan_out = jax.lax.scan(policy_step, [obs, timestep, rng_episode, jnp.array([0.0]), jnp.array([1.0])], (), num_steps)
        obs, action, reward, next_obs, done, timestep, info = scan_out
        cum_return = carry_out[-2]
        info["last_timestep"] = carry_out[1]

        return obs, action, reward, next_obs, done, timestep, info, cum_return



In [None]:
class UnsupervisedRolloutWrapper(CustomRolloutWrapper):
  def batch_update(self, rng_update, model, obs, action, next_obs, done, info):
    if model is None: return {}
    return model.update(rng_update, obs, action, next_obs, done, info)

## Exploration

In [None]:
import jax
import jax.numpy as jnp
jnp.set_printoptions(precision=2,suppress=True)
from jax.scipy.special import digamma, gammaln, kl_div
import flax.linen as nn
import numpy as np
import optax
import time
import flax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Dict
import distrax
import gymnax
import functools
from gymnax.environments import spaces
from gymnax.wrappers import FlattenObservationWrapper, LogWrapper
import matplotlib.pyplot as plt

import matplotlib.pyplot as plt

import optax
from flax.nnx.helpers import TrainState


### Preparation

In [None]:
class MyTrainState(TrainState):
    vars: nnx.Variable
    others: nnx.State

    @property
    def need_train(self):
        return len(self.params) > 0

is_trainable = lambda path, node: ( node.type == nnx.Param and
    True in [ 'trainable' in t for t in path] )

In [None]:
def train_state_from_model(model, tx=optax.adam(0.02)):
  graphdef, trainable_params, vars, others = nnx.split(model, is_trainable, nnx.Variable,...)
  return MyTrainState.create(params=trainable_params, tx=tx, vars=vars, others=others, graphdef=graphdef)

def train_state_update_model(model,state):
    graphdef, trainable_params, vars, others = nnx.split(model,is_trainable, nnx.Variable,...)
    return state.replace(vars=vars,others=others)

def model_from_train_state(state):
    return nnx.merge(state.graphdef, state.params, state.vars,state.others)

In [None]:
# NUM_UPDATES x NUM_ENVS x NUM_STEPS
class Transition(NamedTuple):
    obs: jnp.ndarray
    action: jnp.ndarray
    reward: jnp.ndarray
    next_obs: jnp.ndarray
    done: jnp.ndarray
    info: {}

### Training

In [None]:
def make_train(config):
  config["NUM_UPDATES"] = (config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"]// config["NUM_ENVS"])

  rng = jax.random.PRNGKey(config["SEED"])
  rng_batch = jax.random.split(rng, config["NUM_ENVS"])

  manager = UnsupervisedRolloutWrapper(env_or_name=config["ENV_NAME"])
  num_actions = manager.env.num_actions(manager.env_params)
  obs_dim = manager.env.observation_shape(manager.env_params)

  # model
  if config["MODEL_NAME"] == "XlandDeepSACBayesianExplorer":
    NUM_TITLE_TYPES = 13
    NUM_COLORS = 12
    NUM_CLASSES = NUM_TITLE_TYPES * NUM_COLORS
    model = XlandDeepSACBayesianExplorer(obs_raw_shape=obs_dim,
                                         num_actions=num_actions,
                                         hidden_dim=config["NUM_HIDDEN"],
                                         rngs=nnx.Rngs(config["SEED"]),
                                         wd=config["WD"],
                                         depth=config["DEPTH"])
  else:
    model = XlandRandomExplorer(num_actions)

  @nnx.jit
  def _train_step(state:MyTrainState, rng_loss, obs, action,next_obs,done,info):

    def loss_fn(graphdef, params, vars, others):
      model = nnx.merge(graphdef, params, vars, others)
      return model.batch_loss(rng_loss,obs, action,next_obs,done,info).mean()

    def opt_step(state:MyTrainState, unused):
      grads = jax.grad(loss_fn, 1)(state.graphdef, state.params, state.vars, state.others)
      return state.apply_gradients(grads=grads), None

    state, _ = jax.lax.scan(opt_step, state, None, config["OPT_STEPS"])
    return state

  @nnx.jit
  def _rollout_and_update_step(runner_state, unused):
    train_state, rng_batch, last_timestep= runner_state

    model = model_from_train_state(train_state)
    rng_batch, rng_step, rng_update, rng_loss = batch_random_split(rng_batch, 4)

    rollout_results = manager.batch_rollout(rng_batch, model, timestep=last_timestep, num_steps=config["NUM_STEPS"])
    obs, action, reward, next_obs, done, timestep, info, cum_return = rollout_results

    transition = Transition(obs, action, reward, next_obs, done, info)
    last_timestep = info["last_timestep"]

    update_info = manager.batch_update(rng_update, model, obs, action, next_obs, done, info)
    info.update(update_info)
    train_state = train_state_update_model(model, train_state)

    if train_state.need_train:
      train_state = _train_step(train_state, rng_loss, obs, action, next_obs, done, info)

    runner_state = (train_state, rng_batch, last_timestep)
    return runner_state, (transition, timestep)

  def train(rng_batch, model, manager):

    rng_batch, rng_reset = batch_random_split(rng_batch, 2)
    start_timestep = manager.batch_reset(rng_reset)

    if config["TX"] == "adamw":
      tx = optax.adamw(config["LR"])
    elif config["TX"] == "sgd":
      tx = optax.sgd(config["LR"])
    else:
      tx = None
      assert False, config["TX"] + "is not available"
    train_state = train_state_from_model(model, tx)
    runner_state = (train_state, rng_batch, start_timestep)
    runner_state, output = jax.lax.scan(_rollout_and_update_step, runner_state, None, config["NUM_UPDATES"])

    transitions, timesteps = output
    return {"runner_state": runner_state, "transitions": transitions, "timesteps": timesteps}

  return train, model, manager, rng_batch


In [None]:
def experiment(config):
  print(config)
  train_fn, model, manager, rng_batch = make_train(config)
  train_jit = nnx.jit(train_fn)

  out = jax.block_until_ready(train_fn(rng_batch, model, manager))
  print("data shape:", jax.tree_util.tree_map(lambda x: x.shape, out["transitions"]))

  train_state, rng_batch, last_timestep = out["runner_state"]

  model = model_from_train_state(train_state)

  if "mi" in out["transitions"].info:
    # Create figure and axis
        plt.figure(figsize=(10, 6))
        # Sample JAX NumPy arrays (replace these with your actual arrays)
        #  print (out["transitions"].info)
        eig_array = out["transitions"].info["mi"].reshape(-1)
        big_array = out["transitions"].info["kl"].reshape(-1)
        # Plot both arrays
        plt.plot(eig_array, label='EIG', marker='o', linestyle='-', color='blue')
        plt.plot(big_array, label='BIG', marker='s', linestyle='-', color='red')

        if "smi" in out["transitions"].info:
          smi_array = out["transitions"].info["smi"].reshape(-1)
          plt.plot(smi_array, label='SMI', marker='^', linestyle='-', color='green')

        # add labels and title
        plt.xlabel('Num of Updates')
        plt.ylabel('Information Gain')
        Title = "InfoGains for" + config["MODEL_NAME"]
        Title = Title + "Total InfoGains" + "{:10.4f}".format(big_array.sum().item())
        Title = Title +  " with Seed" +str(config["SEED"])
        plt.title(Title)

        # add grid and legend
        plt.grid(alpha=0.3)
        plt.legend()

        plt.tight_layout()
        plt.savefig(Title.replace(" ","_")+'.pdf', format='pdf', dpi=300, bbox_inches='tight')
        plt.show()

  if "l_prec" in out["transitions"].info:
      l_prec_mean = out["transitions"].info["l_prec"].mean(axis=(1,2,3), keepdims=False)
      mean_error = out["transitions"].info["mean_error"].mean(axis=(1,2), keepdims=False)

      plt.figure(figsize=(10, 6))
      plt.plot(l_prec_mean, label='L_prec', marker='o', linestyle='-', color='blue')
      plt.plot(mean_error, label='Mean Error', marker='s', linestyle='-', color='yellow')

      plt.xlabel('Num of Updates')
      plt.ylabel('Mean Precision')
      Title = "Comparison of Mean Precisions"

      plt.title(Title)

      plt.grid(alpha=0.3)
      plt.legend()

      plt.tight_layout()
      plt.savefig(Title.replace(" ","_")+'.pdf', format='pdf', dpi=300, bbox_inches='tight')
      plt.show()

  return out

In [None]:
env_name = "MiniGrid-EmptyRandom-8x8"
NUM_ENVS = 1 # @param[1,2,4,8,16,32]
TOTAL_TIMESTEPS = 16384 # @param [2048,16384,131072,1048576] {"type":"raw"}
DEPTH = 1 # @param [1,2,4] {"type":"raw"}
NUM_STEPS = 8 # @param [1,2,4,8,16] {"type":"raw"}
NUM_HIDDEN = 128 # @param [32,64,128,256] {"type":"raw"}
WD = 0.1 # @param [0,0.1,0.01,0.001] {"type":"raw"}
MODEL_NAME = "XlandDeepSACBayesianExplorer"  #@param ["DeepSACBayesianExplorer","RandomExplorer","XlandDeepSACBayesianExplorer"]
config = {
    "NUM_ENVS": NUM_ENVS,    #
    "WD": WD,
    "NUM_STEPS": NUM_STEPS,   #steps of roll out between update
    "NUM_OOF": NUM_HIDDEN, # num hidden for now
    "SAC_D_STEPS": 4,
    "ENV_NAME":env_name,
    "SAC_STEP_SIZE": 1.0,
    "SEED": 423,         #highly stochastic
    "TOTAL_TIMESTEPS": TOTAL_TIMESTEPS,   #total steps for all envs
    "NUM_HIDDEN":NUM_HIDDEN,
    "TX":"adamw",
    "DEPTH":DEPTH,
    "LR":2e-4,
    "OPT_STEPS":8,
    "MODEL_NAME": MODEL_NAME,
    "DEBUG": False,
}

In [None]:
out = experiment(config)

## IQL

In [None]:
import os
import time
from functools import partial
from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple

import distrax
import flax
import flax.linen as nn

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm
import wandb
from flax.training.train_state import TrainState
from omegaconf import OmegaConf
from pydantic import BaseModel

os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"


### Config

In [None]:
class IQLConfig(BaseModel):
    # GENERAL
    algo: str = "IQL"
    project: str = "train-IQL"
    env_name: str = "MiniGrid-EmptyRandom-6x6"
    seed: int = 42
    eval_episodes: int = 5
    log_interval: int = 100
    eval_interval: int = 100000
    batch_size: int = 256
    max_steps: int = int(1e6)
    n_jitted_updates: int = 8
    # DATASET
    data_size: int = int(1e6)
    normalize_state: bool = False
    normalize_reward: bool = True
    # NETWORK
    hidden_dims: Tuple[int, int] = (256, 256)
    actor_lr: float = 3e-4
    value_lr: float = 3e-4
    critic_lr: float = 3e-4
    layer_norm: bool = True
    opt_decay_schedule: bool = True
    # IQL SPECIFIC
    expectile: float = (
        0.7  # FYI: for Hopper-me, 0.5 produce better result. (antmaze: expectile=0.9)
    )
    beta: float = (
        3.0  # FYI: for Hopper-me, 6.0 produce better result. (antmaze: beta=10.0)
    )
    tau: float = 0.005
    discount: float = 0.99

    def __hash__(
        self,
    ):  # make config hashable to be specified as static_argnums in jax.jit.
        return hash(self.__repr__())


conf_dict = OmegaConf.from_cli()
config = IQLConfig(**conf_dict)

### Networks

In [None]:
def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init()
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for i, hidden_dims in enumerate(self.hidden_dims):
            x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                if self.layer_norm:  # Add layer norm after activation
                    x = nn.LayerNorm()(x)
                x = self.activations(x)
        return x


class Critic(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        batch_size = observations.shape[0]
        actions = jax.nn.one_hot(actions, num_classes=4) #one-hot encoding
        flat_observations = observations.reshape(batch_size, -1)
        inputs = jnp.concatenate([flat_observations, actions], axis=-1)
        critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs)
        return jnp.squeeze(critic, -1)


def ensemblize(cls, num_qs, out_axes=0, **kwargs):
    split_rngs = kwargs.pop("split_rngs", {})
    return nn.vmap(
        cls,
        variable_axes={"params": 0},
        split_rngs={**split_rngs, "params": True},
        in_axes=None,
        out_axes=out_axes,
        axis_size=num_qs,
        **kwargs,
    )


class ValueCritic(nn.Module):
    hidden_dims: Sequence[int]
    layer_norm: bool = False

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        batch_size = observations.shape[0]
        obs_flat = observations.reshape(batch_size, -1)
        critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm)(obs_flat)
        return jnp.squeeze(critic, -1)


class GaussianPolicy(nn.Module):
    hidden_dims: Sequence[int]
    action_dim: int
    log_std_min: Optional[float] = -5.0
    log_std_max: Optional[float] = 2

    @nn.compact
    def __call__(
        self, observations: jnp.ndarray, temperature: float = 1.0
    ) -> distrax.Distribution:
        outputs = MLP(
            self.hidden_dims,
            activate_final=True,
        )(observations)

        means = nn.Dense(
            self.action_dim, kernel_init=default_init()
        )(outputs)
        log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,))
        log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)

        distribution = distrax.MultivariateNormalDiag(
            loc=means, scale_diag=jnp.exp(log_stds) * temperature
        )
        return distribution

class CatPolicy(nn.Module):
  hidden_dims : Sequence[int]
  action_dim: int

  @nn.compact
  def __call__(self, observations: jnp.ndarray, temperature: float = 1.0) -> distrax.Distribution:
    x = observations.reshape(observations.shape[0], -1) # flatten
    outputs = MLP(self.hidden_dims, activate_final=True)(x)
    logits = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
    distribution = distrax.Categorical(logits=logits)
    return distribution


### Utils

In [None]:
print(jtu.tree_map(jnp.shape, replay_buffer))
print(type(replay_buffer))
print(replay_buffer["dones"])

In [None]:
class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray
    dones_float: jnp.ndarray

In [None]:
def get_normalization(dataset: Transition) -> float:
    # into numpy.ndarray
    dataset = jax.tree_util.tree_map(lambda x: np.array(x), dataset)
    returns = []
    ret = 0
    for r, term in zip(dataset.rewards, dataset.dones_float):
        ret += r
        if term:
            returns.append(ret)
            ret = 0
    return (max(returns) - min(returns)) / 1000

In [None]:
def preprocess_dataset(
     dataset: dict, config: IQLConfig, clip_to_eps: bool = True, eps: float = 1e-5
) -> Transition:

    if clip_to_eps:
        lim = 1 - eps
        dataset["actions"] = jnp.clip(dataset["actions"], -lim, lim)

    # dones_float = np.zeros_like(dataset['dones'])

    # # for i in range(len(dones_float) - 1):
    # #     print(i)
    # #     if np.linalg.norm(dataset['observations'][i + 1] -
    # #                         dataset['next_observations'][i]
    # #                         ) > 1e-6 or dataset['dones'][i] == True:
    # #         dones_float[i] = 1
    # #     else:
    # #         dones_float[i] = 0
    # dones_float[-1] = 1

    obs = dataset['observations']         # shape: (N, 7, 7, 2)
    obs = dataset['observations']         # shape: (N, 7, 7, 2)
    next_obs = dataset['next_observations']  # shape: (N, 7, 7, 2)
    dones = dataset['dones']              # shape: (N,)

    # 展平每个 observation
    obs_flat = obs[1:].reshape((obs.shape[0] - 1, -1))           # shape: (N-1, 98)
    next_obs_flat = next_obs[:-1].reshape((next_obs.shape[0] - 1, -1))  # shape: (N-1, 98)

    # 对每个样本求 L2 范数
    obs_diff = jnp.linalg.norm(obs_flat - next_obs_flat, axis=1)   # shape: (N-1,)
    obs_flag = obs_diff > 1e-6
    done_flag = dones[:-1] == True

    dones_float = jnp.zeros_like(dones, dtype=jnp.float32)
    dones_float = dones_float.at[:-1].set(jnp.logical_or(obs_flag, done_flag).astype(jnp.float32))
    dones_float = dones_float.at[-1].set(1.0)

    dataset = Transition(
        observations=jnp.array(dataset["observations"], dtype=jnp.float32),
        actions=jnp.array(dataset["actions"], dtype=jnp.float32),
        rewards=jnp.array(dataset["rewards"], dtype=jnp.float32),
        next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32),
        dones=jnp.array(dataset["dones"], dtype=jnp.float32),
        dones_float=jnp.array(dones_float, dtype=jnp.float32),
    )

    # normalize states
    # obs_mean, obs_std = 0, 1
    # if config.normalize_state:
    #     obs_mean = dataset.observations.mean(0)
    #     obs_std = dataset.observations.std(0)
    #     dataset = dataset._replace(
    #         observations=(dataset.observations - obs_mean) / (obs_std + 1e-5),
    #         next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5),
    #     )
    # # normalize rewards
    # if config.normalize_reward:
    #     normalizing_factor = get_normalization(dataset)
    #     dataset = dataset._replace(rewards=dataset.rewards / normalizing_factor)

    # shuffle data and select the first data_size samples
    # data_size = min(config.data_size, len(dataset.observations))
    # rng = jax.random.PRNGKey(config.seed)
    # rng, rng_permute, rng_select = jax.random.split(rng, 3)
    # perm = jax.random.permutation(rng_permute, len(dataset.observations))
    # dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset)
    # assert len(dataset.observations) >= data_size
    # dataset = jax.tree_util.tree_map(lambda x: x[:data_size], dataset)
    return dataset

In [None]:
def expectile_loss(diff, expectile=0.8) -> jnp.ndarray:
    weight = jnp.where(diff > 0, expectile, (1 - expectile))
    return weight * (diff**2)

def target_update(
    model: TrainState, target_model: TrainState, tau: float
) -> TrainState:
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params
    )
    return target_model.replace(params=new_target_params)


def update_by_loss_grad(
    train_state: TrainState, loss_fn: Callable
) -> Tuple[TrainState, jnp.ndarray]:
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(train_state.params)
    new_train_state = train_state.apply_gradients(grads=grad)
    return new_train_state, loss

### Model

In [None]:
class IQLTrainState(NamedTuple):
    rng: jax.random.PRNGKey
    critic: TrainState
    target_critic: TrainState
    value: TrainState
    actor: TrainState

class IQL(object):

    @classmethod
    def update_critic(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        next_v = train_state.value.apply_fn(
            train_state.value.params, batch.next_observations
        )
        target_q = batch.rewards + config.discount * (1 - batch.dones) * next_v

        def critic_loss_fn(
            critic_params: flax.core.FrozenDict[str, Any]
        ) -> jnp.ndarray:
            q1, q2 = train_state.critic.apply_fn(
                critic_params, batch.observations, batch.actions
            )
            critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()
            return critic_loss

        new_critic, critic_loss = update_by_loss_grad(
            train_state.critic, critic_loss_fn
        )
        return train_state._replace(critic=new_critic), critic_loss

    @classmethod
    def update_value(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        q1, q2 = train_state.target_critic.apply_fn(
            train_state.target_critic.params, batch.observations, batch.actions
        )
        q = jax.lax.stop_gradient(jnp.minimum(q1, q2))
        def value_loss_fn(value_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            v = train_state.value.apply_fn(value_params, batch.observations)
            value_loss = expectile_loss(q - v, config.expectile).mean()
            return value_loss

        new_value, value_loss = update_by_loss_grad(train_state.value, value_loss_fn)
        return train_state._replace(value=new_value), value_loss

    @classmethod
    def update_actor(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        v = train_state.value.apply_fn(train_state.value.params, batch.observations)
        q1, q2 = train_state.critic.apply_fn(
            train_state.target_critic.params, batch.observations, batch.actions
        )
        q = jnp.minimum(q1, q2)
        exp_a = jnp.exp((q - v) * config.beta)
        exp_a = jnp.minimum(exp_a, 100.0)
        def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            dist = train_state.actor.apply_fn(actor_params, batch.observations)
            log_probs = dist.log_prob(batch.actions.astype(jnp.int32))
            actor_loss = -(exp_a * log_probs).mean()
            return actor_loss

        new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn)
        return train_state._replace(actor=new_actor), actor_loss

    @classmethod
    def update_n_times(
        self,
        train_state: IQLTrainState,
        dataset: Transition,
        rng: jax.random.PRNGKey,
        config: IQLConfig,
    ) -> Tuple["IQLTrainState", Dict]:
        for _ in range(config.n_jitted_updates):
            rng, subkey = jax.random.split(rng)
            batch_indices = jax.random.randint(
                subkey, (config.batch_size,), 0, len(dataset.observations)
            )
            batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)

            train_state, value_loss = self.update_value(train_state, batch, config)
            train_state, actor_loss = self.update_actor(train_state, batch, config)
            train_state, critic_loss = self.update_critic(train_state, batch, config)
            new_target_critic = target_update(
                train_state.critic, train_state.target_critic, config.tau
            )
            train_state = train_state._replace(target_critic=new_target_critic)
        return train_state, {
            "value_loss": value_loss,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss,
        }

    @classmethod
    def get_action(
        self,
        train_state: IQLTrainState,
        observations: np.ndarray,
        seed: jax.random.PRNGKey,
        temperature: float = 1.0,
        max_action: float = 1.0,
    ) -> jnp.ndarray:

        # modified for discrete actions
        dist = train_state.actor.apply_fn(
            train_state.actor.params, observations, temperature=temperature
        )
        actions = jnp.argmax(dist.logits, axis=-1)
        return actions

### Train & Evaluate

In [None]:
def create_iql_train_state(
    rng: jax.random.PRNGKey,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    config: IQLConfig,
) -> IQLTrainState:
    rng, actor_rng, critic_rng, value_rng = jax.random.split(rng, 4)
    # initialize actor
    action_dim = 4

    # Gaussian Model
    # actor_model = GaussianPolicy(
    #     config.hidden_dims,
    #     action_dim=action_dim,
    #     log_std_min=-5.0,
    # )

    # Cat Model
    actor_model = CatPolicy(
        config.hidden_dims,
        action_dim = action_dim
    )

    if config.opt_decay_schedule:
        schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps)
        actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))
    else:
        actor_tx = optax.adam(learning_rate=config.actor_lr)
    actor = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, observations),
        tx=actor_tx,
    )
    # initialize critic
    critic_model = ensemblize(Critic, num_qs=2)(config.hidden_dims)
    critic = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(learning_rate=config.critic_lr),
    )
    target_critic = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(learning_rate=config.critic_lr),
    )
    # initialize value
    value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm)
    value = TrainState.create(
        apply_fn=value_model.apply,
        params=value_model.init(value_rng, observations),
        tx=optax.adam(learning_rate=config.value_lr),
    )
    return IQLTrainState(
        rng,
        critic=critic,
        target_critic=target_critic,
        value=value,
        actor=actor,
    )

In [None]:
def evaluate(
    policy_fn, env, env_params, num_episodes: int, rng
) -> float:
    print("evaluation started")
    episode_returns = []

    for i in range(num_episodes):
      rng, _rng = jax.random.split(rng)
      episode_return = 0

      timestep = env.reset(env_params, _rng)
      done = timestep.step_type == 2
      observation = timestep.observation

      while not done:
          # potential case issue
          obs = observation[None, ...]
          action = policy_fn(observations=obs)

          if isinstance(action, (jnp.ndarray, np.ndarray)) and action.shape == (1,):
            action = int(action[0])

          timestep = env.step(env_params, timestep, action)
          reward = timestep.reward
          done = timestep.step_type == 2
          observation = timestep.observation

          episode_return += reward
      episode_returns.append(episode_return)
    return float(jnp.mean(jnp.array(episode_returns)))

In [None]:
if __name__ == "__main__":
    wandb.init(config=config, project=config.project)

    rng = jax.random.PRNGKey(config.seed)
    rng, _rng = jax.random.split(rng)

    env, env_params = xminigrid.make("MiniGrid-EmptyRandom-6x6")
    env = GymAutoResetWrapper(env)

    dataset= preprocess_dataset(replay_buffer, config)

    # create train_state
    example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset)
    train_state: IQLTrainState = create_iql_train_state(
        _rng,
        example_batch.observations[None, ...],
        example_batch.actions[None, ...],
        config,
    )

    algo = IQL()
    update_fn = jax.jit(algo.update_n_times, static_argnums=(3,))
    act_fn = jax.jit(algo.get_action)
    num_steps = config.max_steps // config.n_jitted_updates
    eval_interval = config.eval_interval // config.n_jitted_updates
    for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True):
        rng, subkey = jax.random.split(rng)
        train_state, update_info = update_fn(train_state, dataset, subkey, config)

        if i % config.log_interval == 0:
            train_metrics = {f"training/{k}": v for k, v in update_info.items()}
            wandb.log(train_metrics, step=i)

        # if i % eval_interval == 0:
        #     policy_fn = partial(
        #         act_fn,
        #         temperature=0.0,
        #         seed=jax.random.PRNGKey(0),
        #         train_state=train_state,
        #     )
        #     normalized_score = evaluate(
        #         policy_fn,
        #         env,
        #         env_params,
        #         rng = _rng,
        #         num_episodes=config.eval_episodes,
        #     )
        #     print(i, normalized_score)
        #     eval_metrics = {f"{config.env_name}/normalized_score": normalized_score}
        #     wandb.log(eval_metrics, step=i)
    # final evaluation
    policy_fn = partial(
        act_fn,
        temperature=0.0,
        seed=jax.random.PRNGKey(0),
        train_state=train_state,
    )
    normalized_score = evaluate(
        policy_fn,
        env,
        env_params,
        rng = _rng,
        num_episodes=config.eval_episodes,
    )
    print("Final Evaluation", normalized_score)
    wandb.log({f"{config.env_name}/final_normalized_score": normalized_score})
    wandb.finish()

## Collect Rollouts

In [None]:
from xminigrid.wrappers import GymAutoResetWrapper

def build_rollout(env, env_params, num_steps):
  def rollout(rng):
    def _step_fn(carry, _):
      rng, timestep = carry
      rng, _rng = jax.random.split(rng)
      action = jax.random.randint(_rng, shape=(), minval=0, maxval=env.num_actions(env_params))

      timestep = env.step(env_params, timestep, action)

      return (rng, timestep), (timestep,action)

    rng, _rng = jax.random.split(rng)
    timestep = env.reset(env_params, _rng)
    rng, (transitions, actions) = jax.lax.scan(_step_fn, (rng, timestep), None, length=num_steps)

    return transitions, actions
  return rollout

In [None]:
env, env_params = xminigrid.make("MiniGrid-EmptyRandom-8x8")
env = GymAutoResetWrapper(env)

rollout_fn = jax.jit(build_rollout(env, env_params, num_steps=1e6))

transitions, actions = rollout_fn(jax.random.key(0))

In [None]:
obs_dim = env.observation_shape(env_params)

In [None]:
print(obs_dim)

(7, 7, 2)


In [None]:
print("Transitions shapes: \n", jtu.tree_map(jnp.shape, transitions))
print("Actions shape:", actions.shape)
print(type(actions))

Transitions shapes: 
 TimeStep(state=State(key=(1000000,), step_num=(1000000,), grid=(1000000, 8, 8, 2), agent=AgentState(position=(1000000, 2), direction=(1000000,), pocket=(1000000, 2)), goal_encoding=(1000000, 5), rule_encoding=(1000000, 1, 7), carry=EnvCarry()), step_type=(1000000,), reward=(1000000,), discount=(1000000,), observation=(1000000, 7, 7, 2))
Actions shape: (1000000,)
<class 'jaxlib.xla_extension.ArrayImpl'>


In [None]:
def create_replay_buffer(transitions, actions):

  observations = transitions.observation # (T, 7, 7, 2)
  rewards = transitions.reward # (T,)
  dones = transitions.step_type == 2 # (T,)
  next_observations = jnp.concatenate([observations[1:], observations[-1:]], axis=0) #(T, 7, 7, 2)
  actions = jnp.array(actions, dtype=jnp.int32) #(T,)

  replay_buffer = {'observations': observations,
                   'actions': actions,
                   'rewards': rewards,
                   'next_observations': next_observations,
                   'dones': dones}

  print("=== Replay Buffer 构建完成 ===")
  print(f"数据点数量: {len(observations)}")
  print(f"平均奖励: {jnp.mean(rewards):.4f}")
  print(f"Episode结束次数: {jnp.sum(dones)}")
  print(f"动作分布: {jnp.bincount(actions)}")
  return replay_buffer

Potential issue with sparse reward

In [None]:
replay_buffer = create_replay_buffer(transitions, actions)

=== Replay Buffer 构建完成 ===
数据点数量: 1000000
平均奖励: 0.0028
Episode结束次数: 9572
动作分布: [167326 166610 166812 166592 166378 166282]


In [None]:
def create_batches(replay_buffer, batch_size=32, num_batches=None):
  data_size = len(replay_buffer['observations'])

  if num_batches is None:
    num_batches = max(1, data_size // batch_size)

  batches = []

  rng = jax.random.PRNGKey(0)

