In [1]:
%pip install jax
%pip install numpy
%pip install matplotlib
%pip install xminigrid
%pip install gymnax
%pip install distrax

Collecting xminigrid
  Downloading xminigrid-0.9.1-py3-none-any.whl.metadata (32 kB)
Downloading xminigrid-0.9.1-py3-none-any.whl (58 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.6/58.6 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xminigrid
Successfully installed xminigrid-0.9.1
Collecting gymnax
  Downloading gymnax-0.0.9-py3-none-any.whl.metadata (19 kB)
Downloading gymnax-0.0.9-py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.6/86.6 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gymnax
Successfully installed gymnax-0.0.9
Collecting distrax
  Downloading distrax-0.1.5-py3-none-any.whl.metadata (13 kB)
Downloading distrax-0.1.5-py3-none-any.whl (319 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.7/319.7 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: distrax
Successfully installed distrax-0.

In [2]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import distrax

import timeit
import imageio
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm

from flax import nnx
import xminigrid

In [None]:
# class TimeStep(struct.PyTreeNode):
#     # hidden environment state, such as grid, agent, goal, etc
#     state: State

#     # similar to the dm_env enterface
#     step_type: StepType
#     reward: jax.Array
#     discount: jax.Array
#     observation: jax.Array

## Utils

### Encoders

In [4]:
import jax.nn as nn


class Encoder(nnx.Module):
  def __init__(self, input_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array):
    h = self.linear(x)
    return self.layer_norm0(h)

class ActionEncoder(nnx.Module):
  def __init__(self, input_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.embed = nnx.Embed(input_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array):
    h = self.embed(x)
    return self.layer_norm0(h)

class JointEncoder(nnx.Module):
  def __init__(self, hidden_dim: int, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm1 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm2 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm3 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array, rng):
    dist_distrax = distrax.MultivariateNormalDiag(loc=x, scale_diag=1e-1*jnp.ones_like(x))
    # potential shape issue
    x = dist_distrax.sample(seed=rng, sample_shape=(1,))
    x = self.layer_norm0(x)
    h0 = self.linear1(x)
    h = nn.relu(h0)
    h = self.layer_norm1(h) + h0
    h0 = self.linear2(h)
    h = self.layer_norm2(h) + h0
    return self.layer_norm3(h)

In [5]:
import jax
import jax.numpy as jnp
from jax.nn import one_hot

class ImageOneHotEncoder(nnx.Module):
    def __init__(self,
                 output_dim: int,
                 num_title_types: int,  # New: Number of categories for the first channel
                 num_colors: int,
                 rngs: nnx.Rngs,
                 hidden_dim=32,):

        self.num_title_types = num_title_types
        self.num_colors = num_colors

        self.input_channels_after_one_hot = self.num_title_types + self.num_colors

        self.conv1 = nnx.Conv(
            self.input_channels_after_one_hot, hidden_dim, kernel_size=(3, 3), strides=(1, 1), rngs=rngs)
        self.conv2 = nnx.Conv(hidden_dim, hidden_dim*2, kernel_size=(3, 3), strides=(1, 1), rngs=rngs)
        self.conv3 = nnx.Conv(hidden_dim*2, hidden_dim*4, kernel_size=(3, 3), strides=(1, 1), rngs=rngs)

        self.linear = nnx.Linear(1, output_dim, rngs=rngs)

    def __call__(self, obs_raw: jnp.ndarray) -> jnp.ndarray:

        original_ndim = obs_raw.ndim
        # ensure batch dimension
        if original_ndim == 3:
            obs_raw = jnp.expand_dims(obs_raw, axis=0)

        title_ids = obs_raw[..., 0]
        color_ids = obs_raw[..., 1]

        titles_onehot = one_hot(title_ids, num_classes=self.num_title_types)
        colors_onehot = one_hot(color_ids, num_classes=self.num_colors)

        obs_processed = jnp.concatenate(
            [titles_onehot, colors_onehot], axis=-1
        )

        x = nnx.relu(self.conv1(obs_processed))
        x = nnx.relu(self.conv2(x))
        x = nnx.relu(self.conv3(x))

        flattened_features = x.reshape(x.shape[0], -1)

        output_features = self.linear(flattened_features)

        # Remove the batch dimension if the original input didn't have one
        if original_ndim == 3:
            return jnp.squeeze(output_features, axis=0)
        else:
            return output_features

In [6]:
import jax
import jax.numpy as jnp
from jax.nn import one_hot


NUM_TITLE_TYPES = 13
NUM_COLORS = 12


class ObservationActionEncoder(nnx.Module):
    def __init__(self,
                 num_actions: int,
                 hidden_dim: int,
                 action_embedding_dim: int,
                 output_channels: int,
                 num_title_types: int,
                 num_colors: int,
                 rngs: nnx.Rngs):


        self.num_actions = num_actions
        self.action_embedding_dim = action_embedding_dim
        self.num_title_types = num_title_types
        self.num_colors = num_colors
        self.output_channels = output_channels

        self.action_linear_embed = nnx.Linear(num_actions, action_embedding_dim, rngs=rngs)
        self.action_layer_norm = nnx.LayerNorm(action_embedding_dim, rngs=rngs)

        self.obs_input_channels_after_one_hot = self.num_title_types + self.num_colors

        self.cnn_input_channels = self.obs_input_channels_after_one_hot + self.action_embedding_dim

        self.conv1 = nnx.Conv(self.cnn_input_channels, 64, kernel_size=(3, 3), strides=(1, 1), padding='SAME', rngs=rngs)
        self.conv2 = nnx.Conv(hidden_dim, hidden_dim*2, kernel_size=(3, 3), strides=(1, 1), padding='SAME', rngs=rngs)
        self.conv_final = nnx.Conv(hidden_dim*2, self.output_channels, kernel_size=(1, 1), strides=(1, 1), padding='SAME', rngs=rngs)


    def __call__(self, obs_raw: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:

        #output: (Batch, H_out, W_out, output_channels)

        original_obs_ndim = obs_raw.ndim
        if original_obs_ndim == 3:
            obs_raw = jnp.expand_dims(obs_raw, axis=0)

        title_ids = obs_raw[..., 0]
        color_ids = obs_raw[..., 1]

        titles_onehot = one_hot(title_ids, num_classes=self.num_title_types)
        colors_onehot = one_hot(color_ids, num_classes=self.num_colors)

        # obs_processed : (Batch, H, W, self.obs_input_channels_after_one_hot)
        obs_processed = jnp.concatenate(
            [titles_onehot, colors_onehot], axis=-1
        )

        action = action.astype(jnp.int32)

        action_one_hot = one_hot(action, num_classes=self.num_actions)

        action_embedding = self.action_linear_embed(action_one_hot)
        action_embedding = self.action_layer_norm(action_embedding) # (Batch, action_embedding_dim)|(action_embedding_dim,)


        H, W = obs_processed.shape[1], obs_processed.shape[2]

        if action_embedding.ndim == 1:
            action_embedding = jnp.expand_dims(action_embedding, axis=0) # (1, action_embedding_dim)

        #  action_embedding  -> (Batch, 1, 1, action_embedding_dim)
        #  (Batch, H, W, action_embedding_dim)
        action_spatial = jnp.expand_dims(jnp.expand_dims(action_embedding, axis=1), axis=1) # (Batch, 1, 1, D)
        action_spatial = jnp.tile(action_spatial, (1, H, W, 1)) # (Batch, H, W, D)

        #  (Batch, H, W, self.cnn_input_channels)
        cnn_input = jnp.concatenate([obs_processed, action_spatial], axis=-1)

        x = nnx.relu(self.conv1(cnn_input))
        x = nnx.relu(self.conv2(x))
        output_feature_map = self.conv_final(x)

        if original_obs_ndim == 3:
            return jnp.squeeze(output_feature_map, axis=0)
        else:
            return output_feature_map

### Actor

In [7]:
from jax import lax
import distrax

class Actor(nnx.Module):
  # environment related ???
  log_std_min: float = -4
  log_std_max: float = 2

  def __init__(self, obs_dim, action_dim, hidden_dim, rngs: nnx.Rngs):
    self.mean = nnx.Linear(hidden_dim, action_dim, rngs=rngs)
    self.log_std = nnx.Linear(hidden_dim, action_dim, rngs=rngs)

  def __call__(self, x: jnp.ndarray):
    mean = self.mean(x)
    log_std = jnp.clip(self.log_std(x), self.log_std_min, self.log_std_max)
    return mean, log_std

### Functions

#### computaion

In [8]:
def compute_info_gain_normal(mean, prec, l_prec, next_obs):
  prec = jnp.maximum(prec, 1e-6)
  posterior_prec = prec + l_prec
  prec_ratio = prec / posterior_prec

  posterior_mean = (prec * mean + l_prec * next_obs) / posterior_prec

  delta_mean = next_obs - posterior_mean
  kl = delta_mean * delta_mean * prec
  kl = kl + prec_ratio - jnp.log(prec_ratio) - 1
  kl = 0.5 * jnp.sum(kl, axis=-1)
  return kl, delta_mean

@jax.jit
def compute_expected_info_gain_normal(prec, l_prec):
  prec = jnp.maximum(prec, 1e-6)
  prec_ratio = l_prec / prec
  mi_matrix = 0.5 * jnp.sum(jnp.log(1+prec_ratio), axis=-1)
  return mi_matrix

jnp.set_printoptions(precision=3,suppress=True)
from flax.training import train_state
from jax.scipy.special import gamma,digamma, gammaln, kl_div

def batch_random_split(batch_key,num=2):
    split_keys = jax.vmap(jax.random.split,in_axes=(0,None))(batch_key,num)
    return [split_keys[:, i]  for i in range(num) ]


#### shape manipulation

In [9]:
import matplotlib.pyplot as plt

def reshape(arr):
  if arr.dim < 3:
    raise ValueError("Input array must have at least 3 dimensions (n, b, c, ...).")

  n, b, c, *x_dims = arr.shapes
  # Transpose the first two axes (n, b) to (b, n)
  # We construct the axes tuple dynamically for flexibility
  transpose_axes = (1, 0) + tuple(range(2, arr.ndim))
  transposed_arr = jnp.transpose(arr, axes=transpose_axes)
  # Reshape into (b, n*c, x0, x1, ...)
  new_shape = (b, n * c, *x_dims)
  reshaped_arr = jnp.reshape(transposed_arr, new_shape)

  return reshaped_arr

from typing import List, Any

# Define a type alias for PyTree for better readability
PyTree = Any
from typing import List, Any

# Define a type alias for PyTree for better readability
PyTree = Any
def unpack_pytree_by_first_index(pytree: PyTree) -> List[PyTree]:
    """
    Unpacks a PyTree of JAX arrays along their first dimension (id).

    This function assumes that all JAX arrays within the PyTree
    have a consistent first dimension (the 'id' dimension) and that
    you want to create a separate PyTree for each 'id'.

    Args:
        pytree: A JAX PyTree where the leaves are JAX arrays
                with a leading 'id' dimension.

    Returns:
        A list of PyTrees, where each PyTree corresponds to a single
        'id' from the original PyTree.
    """
    # Get the size of the first dimension from any leaf array
    # We assume all arrays have the same first dimension size.
    first_leaf = jax.tree_util.tree_leaves(pytree)[0]
    num_ids = first_leaf.shape[0]

    # Create a list to store the unpacked PyTrees
    unpacked_pytrees = []

    # Iterate through each ID
    for i in range(num_ids):
        # Use tree_map to slice each array in the PyTree at the current ID
        sliced_pytree = jax.tree_util.tree_map(lambda x: x[i], pytree)
        unpacked_pytrees.append(sliced_pytree)

    return unpacked_pytrees

def unpack_states(pytree):
    return unpack_pytree_by_first_index(jax.tree.map(reshape, pytree))

#### drawing

In [10]:
def draw_mountain_car_heatmap(state,config = {}):
    """
    Draws a heatmap representing the trajectory of the MountainCar environment.

    Args:
        state_sequence: A sequence of JAX arrays representing the states
                        of the MountainCar environment. Each state is expected
                        to be a 2-element array [position, velocity].
                        ['CartPole-v1',"MountainCar-v0","Acrobot-v1"]
    """
    title = config["ENV_NAME"] +' MountainCar Heatmap ' +config["MODEL_NAME"]

    plt.figure(figsize=(10, 6))
    if config["ENV_NAME"] == "MountainCar-v0":

        positions = state.position
        velocities = state.velocity

        plt.scatter(positions, velocities, c=range(len(state.time )), cmap='viridis', s=10)
        plt.colorbar(label='Time Steps')
        plt.xlabel('Position')
        plt.ylabel('Velocity')
        plt.grid(True)
    elif config["ENV_NAME"] == "CartPole-v1":
        x = state.x
        theta = state.theta
        plt.scatter(x, theta, c=range(len(state.time )), cmap='viridis', s=10)
        plt.colorbar(label='Time Steps')
        plt.xlabel('x')
        plt.ylabel('theta')
        plt.grid(True)
    elif config["ENV_NAME"] == "Acrobot-v1":
        joint_angle1 = state.joint_angle1
        joint_angle2 = state.joint_angle2
        plt.scatter(joint_angle1, joint_angle2, c=range(len(state.time )), cmap='viridis', s=10)
        plt.colorbar(label='Time Steps')
        plt.xlabel('Angle1')
        plt.ylabel('Angle2')
        plt.grid(True)
    if "TOTAL_TIMESTEPS" in config:
        title += "_TOTAL_TIMESTEPS_"+str(config["TOTAL_TIMESTEPS"])
    if "DEPTH" in config:
        title += "_DEPTH_"+str(config["DEPTH"])
    if "NUM_HIDDEN" in config:
        title += "_NUM_HIDDEN_"+str(config["NUM_HIDDEN"])
    plt.title(title)
    plt.savefig(title.replace(" ","_")+'.pdf', format='pdf', dpi=300, bbox_inches='tight')
    plt.show()
    return plt

### Others

In [11]:
class Likelihood_Prec(nnx.Module):
  log_std_min: float = -2
  log_std_max: float = 2

  def __init__(self, obs_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(hidden_dim, obs_dim, rngs)

  def __call__(self, x: jnp.ndarray):
    log_std = jnp.clip(self.linear(x), self.log_std_min, self.log_std_max)
    return jnp.exp(-log_std)

## Unsupervised Explorer

In [13]:
from gymnax.experimental import RolloutWrapper
# action = self.model_forward(policy_params, obs, rng_net)
import functools
import gymnax
from typing import Union,Optional,Any
import abc

In [38]:
class UnsupervisedExplorer(nnx.Module):

  @abc.abstractmethod
  def update(self, obs, actions, next_obs, dones, info):
    # update variable parameters
    return #{'kl':KL} MI= E[KL]

  @abc.abstractmethod
  def __call__(self, observations, rng):
    return #actions, {"mi":mi_matrix}

class RandomExplorer(UnsupervisedExplorer):

  def __init__(self, num_actions):
    self.num_actions = num_actions

  def update(self, rng, obs, actions, next_obs, dones, info):
    return {}

  def __call__(self, observations, rng):
    if observations.ndim == 1:
      # possible shape issue here
      actions = jax.random.randint(rng, shape=(1,), minval=0, maxval=self.num_actions)
      return actions, {}
    actions = jax.random.randint(rng, shape=(observations.shape[0],), minval=0, maxval=self.num_actions)
    return actions, {}


class DeepSACBayesianExplorer(UnsupervisedExplorer):
  # ent?
  def __init__(self, obs_dim, num_actions, hidden_dim, rngs: nnx.Rngs,
               l_prec=1.0, wd=1e-2, ent_lambda=1e-3, depth=2):
    self.obs_dim = obs_dim
    self.num_actions = num_actions
    self.hidden_dim = hidden_dim
    self.prec_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)), name='prec_w')
    self.mean_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)), name='mean_w')
    # what is trainable here
    self.trainable_likelihood_prec = Likelihood_Prec(obs_dim, hidden_dim, rngs)
    self.weight_decay = wd
    self.obs_embeds = Encoder(obs_dim, hidden_dim, rngs)
    self.action_embeds = ActionEncoder(num_actions, hidden_dim, rngs)
    self.joint_embeds = JointEncoder(hidden_dim, rngs)
    self.depth = depth
    self.ent_lambda = ent_lambda

  def __call__(self,observations,rng):
      return self.recursive_mi(observations,rng,self.depth)

  def update(self, rng, obs, action, next_obs, done, info):
    mean = info["mean"]
    prec = info["prec"]

    def _likelihood_loss(rng, T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)
      mu = mean
      # model var + inherent var
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
      log_prob = dist_distrax.log_prob(next_obs)
      return -log_prob, l_prec

    # jit here
    predictive_loss, l_prec = _likelihood_loss(rng, info["T"], mean, prec, next_obs)
    # originally jnp.sum
    mean_error = jnp.mean((mean - next_obs)**2)
    deepkl, delta_mean = compute_info_gain_normal(mean, prec, l_prec, next_obs)
    # batch x num_hidden
    T = info["T"].reshape(-1, self.hidden_dim)

    # batch x obs_dim
    l_prec = l_prec.reshape(-1, self.obs_dim)
    delta_mean = delta_mean.reshape(-1, self.obs_dim)

    T_T = jnp.transpose(T)
    covariance = T @ T_T
    inv_covariance = jnp.linalg.pinv(covariance)

    T_Map = T_T @ inv_covariance

    delta_precW = T_Map @ l_prec
    self.prec_w.value = (self.prec_w.value + delta_precW) * (1-self.weight_decay)
    delta_meanW = T_Map @ delta_mean
    self.mean_w.value = (self.mean_w.value + delta_meanW) * (1-self.weight_decay)

    return {"kl":deepkl,  "predictive_loss": predictive_loss, "mean_error":mean_error}

  # jitable
  def loss(self, rng, obs, action, next_obs, done, info):
    def _likelihood_loss(T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)

      mu = mean
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)

      log_prob = dist_distrax.log_prob(next_obs)
      return -log_prob

    T, mean, prec = info["T"], info["mean"], info["prec"]
    likelihood_loss = _likelihood_loss(T, mean, prec, next_obs)
    return likelihood_loss

  def batch_loss(self, rng, obs, actions, next_obs, dones, info):
    vmapped = jax.vmap(self.loss)
    return vmapped(rng, obs, actions, next_obs, dones, info)

  def recursive_mi(self, observations, rng, depth):
    obs_embed = self.obs_embed(observations)
    action_embed = self.action_embed(jnp.arange(self.num_actions))
    # possible shape issue
    embed = action_embed + jnp.expand_dims(obs_embed, axis=0)

    # num_actions x embed_size
    T = self.joint_embeds(embed, rng)
    prec = jnp.maximum(T @ self.prec_w, 1e-3)
    # num_actions x obs_dim
    mean = T @ self.mean_w
    l_prec = self.trainable_likelihood_prec(T)

    MI = compute_expected_info_gain_normal(prec, l_prec)

    if depth > 0:
      vmapped = jax.vmap(self.recursive_mi, in_axes=(0,None,None))
      # num_actions x 1
      actions, info = vmapped(mean, rng, depth-1)
      MI = MI + info["mi"]

    actions = jnp.argmax(MI, axis=0)
    T = T[actions]
    MI = MI[actions]
    l_prec = l_prec[actions]
    prec = prec[actions]
    mean = mean[actions]
    return actions, {"mi":MI,"T":T,"obs_embed":obs_embed,"l_prec":l_prec,
                        "prec":prec,"mean":mean}

def show_variable(model, text):

    graphdef, params, vars,others = nnx.split(model, nnx.Param, nnx.Variable,...)

    print(text,vars)


Temporarily requires oof_dim = hidden_dim

In [15]:
class XlandRandomExplorer(UnsupervisedExplorer):

  def __init__(self, num_actions):
    self.num_actions = num_actions

  def update(self, rng, obs, actions, next_obs, dones, info):
    return {}

  def __call__(self, observations, rng):
      batch_size = 1
      if observations.ndim == 4:
          batch_size = observations.shape[0]
      elif observations.ndim == 3:
          batch_size = 1
      else:
          if observations.ndim == 1: # Single 1D feature vector
                batch_size = 1
          elif observations.ndim == 2: # Batch of 1D feature vectors
                batch_size = observations.shape[0]
          else:
              raise ValueError(f"Unsupported observation dimension: {observations.ndim}. Expected 3 (HWC) or 4 (BHWC) or 1 (flat) or 2 (batch_flat).")

      actions = jax.random.randint(rng, shape=(batch_size,), minval=0, maxval=self.num_actions)
      return actions, {}


In [40]:
class XlandDeepSACBayesianExplorer(UnsupervisedExplorer):
  # ent?
  def __init__(self, obs_dim_raw_shape, oof_dim, num_actions, hidden_dim, rngs: nnx.Rngs,
               l_prec=1.0, wd=1e-2, ent_lambda=1e-3, depth=2, num_title_types: int=13, num_colors: int=12):

    self.obs_dim_raw_shape = obs_dim_raw_shape
    self.num_actions = num_actions
    self.hidden_dim = hidden_dim
    self.output_obs_feature_dim = oof_dim


    self.prec_w = nnx.Variable(jnp.zeros((hidden_dim, self.output_obs_feature_dim)), name='prec_w')
    self.mean_w = nnx.Variable(jnp.zeros((hidden_dim, self.output_obs_feature_dim)), name='mean_w')

    self.obs_embed = ImageOneHotEncoder(self.output_obs_feature_dim, num_title_types, num_colors, rngs, hidden_dim)
    # what is trainable here
    self.trainable_likelihood_prec = Likelihood_Prec(self.output_obs_feature_dim, hidden_dim, rngs)
    self.weight_decay = wd
    self.action_embeds = ActionEncoder(num_actions, hidden_dim, rngs)
    self.joint_embeds = JointEncoder(hidden_dim, rngs)
    self.depth = depth
    self.ent_lambda = ent_lambda

  def __call__(self,observations,rng):
    return self.recursive_mi(observations,rng,self.depth)

  def update(self, rng, obs, action, next_obs, done, info):
    mean = info["mean"]
    prec = info["prec"]

    def _likelihood_loss(rng, T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)
      mu = mean
      # model var + inherent var
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
      next_obs_features = self.obs_embed(next_obs)
      log_prob = dist_distrax.log_prob(next_obs_features)
      return -log_prob, l_prec

    # jit here
    predictive_loss, l_prec = _likelihood_loss(rng, info["T"], mean, prec, next_obs)
    # originally jnp.sum
    mean_error = jnp.mean((mean - next_obs)**2)
    deepkl, delta_mean = compute_info_gain_normal(mean, prec, l_prec, next_obs)
    # batch x num_hidden
    T = info["T"].reshape(-1, self.hidden_dim)

    # batch x obs_dim
    l_prec = l_prec.reshape(-1, self.obs_dim)
    delta_mean = delta_mean.reshape(-1, self.obs_dim)

    T_T = jnp.transpose(T)
    covariance = T @ T_T
    inv_covariance = jnp.linalg.pinv(covariance)

    T_Map = T_T @ inv_covariance

    delta_precW = T_Map @ l_prec
    self.prec_w.value = (self.prec_w.value + delta_precW) * (1-self.weight_decay)
    delta_meanW = T_Map @ delta_mean
    self.mean_w.value = (self.mean_w.value + delta_meanW) * (1-self.weight_decay)

    return {"kl":deepkl,  "predictive_loss": predictive_loss, "mean_error":mean_error}

  # jitable
  def loss(self, rng, obs, action, next_obs, done, info):
    def _likelihood_loss(T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)

      mu = mean
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
      next_obs_features = self.obs_embed(next_obs)
      log_prob = dist_distrax.log_prob(next_obs_features)
      return -log_prob

    T, mean, prec = info["T"], info["mean"], info["prec"]
    likelihood_loss = _likelihood_loss(T, mean, prec, next_obs)
    return likelihood_loss

  def batch_loss(self, rng, obs, actions, next_obs, dones, info):
    vmapped = jax.vmap(self.loss)
    return vmapped(rng, obs, actions, next_obs, dones, info)

  def recursive_mi(self, observations, rng, depth):
    obs_embed = self.obs_embed(observations)
    action_embed = self.action_embed(jnp.arange(self.num_actions))
    # possible shape issue
    embed = action_embed + jnp.expand_dims(obs_embed, axis=0)

    # num_actions x embed_size
    T = self.joint_embeds(embed, rng)
    prec = jnp.maximum(T @ self.prec_w, 1e-3)
    # num_actions x obs_dim
    mean = T @ self.mean_w
    l_prec = self.trainable_likelihood_prec(T)

    MI = compute_expected_info_gain_normal(prec, l_prec)

    if depth > 0:
      vmapped = jax.vmap(self.recursive_mi, in_axes=(0,None,None))
      # num_actions x 1
      actions, info = vmapped(mean, rng, depth-1)
      MI = MI + info["mi"]

    actions = jnp.argmax(MI, axis=0)
    T = T[actions]
    MI = MI[actions]
    l_prec = l_prec[actions]
    prec = prec[actions]
    mean = mean[actions]
    return actions, {"mi":MI,"T":T,"obs_embed":obs_embed,"l_prec":l_prec,
                        "prec":prec,"mean":mean}

## Wrapper

In [17]:
from xminigrid.environment import Environment
from typing import Union,Optional,Any
import abc

class CustomRolloutWrapper:
    """Wrapper to define batch evaluation for generation parameters."""

    def __init__(
        self,
        env_or_name: Union[str,Environment] = "Pendulum-v1",
        num_env_steps: Optional[int] = None,
        env_kwargs: Any | None = None,
        env_params: Any | None = None,
    ):
        """Wrapper to define batch evaluation for generation parameters."""
        # Define the RL environment & network forward function
        if env_kwargs is None:
            env_kwargs = {}
        if env_params is None:
            env_params = {}
        if isinstance(env_or_name,Environment):
            self.env = env_or_name
            self.env_params = env_or_name.default_params
        else:
            self.env, self.env_params = xminigrid.make(env_or_name, **env_kwargs)
        self.env_params = self.env_params.replace(**env_params)

        if num_env_steps is None:
            self.num_env_steps = self.env_params.max_steps
        else:
            self.num_env_steps = num_env_steps

    def batch_reset(self, rng_input):
        batch_rest = jax.vmap(self.single_reset_state)
        return batch_rest(rng_input)

    # state vs. timestep, potential issue here
    def single_reset_state(self, rng_input):
        rng_reset, rng_episode = jax.random.split(rng_input)
        timestep = self.env.reset(self.env_params, rng_reset)
        return timestep

    def batch_rollout(self, rng_eval, model:UnsupervisedExplorer, timestep=None, num_steps=1):
        batch_rollout = jax.vmap(self.single_rollout, in_axes=(0,None,None,None))
        return batch_rollout(rng_eval, model, timestep, num_steps)

    def single_rollout(self, rng_eval, model:UnsupervisedExplorer, timestep=None, num_steps=1):
        rng_reset, rng_episode = jax.random.split(rng_eval)

        if timestep is None:
          timestep = self.env.reset(rng_reset, self.env_params)
        else:
          obs = timestep.observation

        def policy_step(state_input, _):
          obs, timestep, rng, cum_reward, valid_mask = state_input
          rng, rng_step, rng_net = jax.random.split(rng, 3)
          if model is not None:
            action, info = model(obs, rng_net)
          else:
            # not action space?
            action = action = jax.random.randint(rng_step, shape=(), minval=0, maxval=self.env.num_actions(env_params))
            info = {}

          next_timestep = self.env.step(self.env_params, timestep, action)
          next_obs = next_timestep.observation
          reward = next_timestep.reward
          done = next_timestep.step_type == 2

          info.update({"discount": next_timestep.discount})
          new_cum_reward = cum_reward + reward * valid_mask
          new_valid_mask = valid_mask * (1- done)
          carry = [next_obs, next_timestep, rng, new_cum_reward, new_valid_mask]
          y = [obs, action, reward, next_obs, done, timestep, info]

          return carry, y

        carry_out, scan_out = jax.lax.scan(policy_step, [obs, timestep, rng_episode, jnp.array([0.0]), jnp.array([1.0])], (), num_steps)
        obs, action, reward, next_obs, done, timestep, info = scan_out
        cum_return = carry_out[-2]
        info["last_timestep"] = carry_out[1]

        return obs, action, reward, next_obs, done, timestep, info, cum_return



In [18]:
class UnsupervisedRolloutWrapper(CustomRolloutWrapper):
  def batch_update(self, rng_update, model, obs, action, next_obs, done, info):
    if model is None: return {}
    return model.update(rng_update, obs, action, next_obs, done, info)

## Exploration

In [19]:
import jax
import jax.numpy as jnp
jnp.set_printoptions(precision=2,suppress=True)
from jax.scipy.special import digamma, gammaln, kl_div
import flax.linen as nn
import numpy as np
import optax
import time
import flax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Dict
import distrax
import gymnax
import functools
from gymnax.environments import spaces
from gymnax.wrappers import FlattenObservationWrapper, LogWrapper
import matplotlib.pyplot as plt

import matplotlib.pyplot as plt

import optax
from flax.nnx.helpers import TrainState


### Preparation

In [20]:
class MyTrainState(TrainState):
    vars: nnx.Variable
    others: nnx.State

    @property
    def need_train(self):
        return len(self.params) > 0

is_trainable = lambda path, node: ( node.type == nnx.Param and
    True in [ 'trainable' in t for t in path] )

In [21]:
def train_state_from_model(model, tx=optax.adam(0.02)):
  graphdef, trainable_params, vars, others = nnx.split(model, is_trainable, nnx.Variable,...)
  return MyTrainState.create(params=trainable_params, tx=tx, vars=vars, others=others, graphdef=graphdef)

def train_state_update_model(model,state):
    graphdef, trainable_params, vars, others = nnx.split(model,is_trainable, nnx.Variable,...)
    return state.replace(vars=vars,others=others)

def model_from_train_state(state):
    return nnx.merge(state.graphdef, state.params, state.vars,state.others)

In [22]:
# NUM_UPDATES x NUM_ENVS x NUM_STEPS
class Transition(NamedTuple):
    obs: jnp.ndarray
    action: jnp.ndarray
    reward: jnp.ndarray
    next_obs: jnp.ndarray
    done: jnp.ndarray
    info: {}

### Training

In [23]:
def make_train(config):
  config["NUM_UPDATES"] = (config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"]// config["NUM_ENVS"])

  rng = jax.random.PRNGKey(config["SEED"])
  rng_batch = jax.random.split(rng, config["NUM_ENVS"])

  manager = UnsupervisedRolloutWrapper(env_or_name=config["ENV_NAME"])
  num_actions = manager.env.num_actions(manager.env_params)
  obs_dim = manager.env.observation_shape(manager.env_params)

  # model
  if config["MODEL_NAME"] == "XlandDeepSACBayesianExplorer":
    model = XlandDeepSACBayesianExplorer(obs_dim,
                                         oof_dim=config["NUM_OOF"],
                                         num_actions=num_actions,
                                         hidden_dim=config["NUM_HIDDEN"],
                                         rngs=nnx.Rngs(config["SEED"]),
                                         wd=config["WD"],
                                         depth=config["DEPTH"])
  else:
    model = XlandRandomExplorer(num_actions)

  @nnx.jit
  def _train_step(state:MyTrainState, rng_loss, obs, action,next_obs,done,info):

    def loss_fn(graphdef, params, vars, others):
      model = nnx.merge(graphdef, params, vars, others)
      return model.batch_loss(rng_loss,obs, action,next_obs,done,info).mean()

    def opt_step(state:MyTrainState, unused):
      grads = jax.grad(loss_fn, 1)(state.graphdef, state.params, state.vars, state.others)
      return state.apply_gradients(grads=grads), None

    state, _ = jax.lax.scan(opt_step, state, None, config["OPT_STEPS"])
    return state

  @nnx.jit
  def _rollout_and_update_step(runner_state, unused):
    train_state, rng_batch, last_timestep= runner_state

    model = model_from_train_state(train_state)
    rng_batch, rng_step, rng_update, rng_loss = batch_random_split(rng_batch, 4)

    rollout_results = manager.batch_rollout(rng_batch, model, timestep=last_timestep, num_steps=config["NUM_STEPS"])
    obs, action, reward, next_obs, done, timestep, info, cum_return = rollout_results

    transition = Transition(obs, action, reward, next_obs, done, info)
    last_timestep = info["last_timestep"]

    update_info = manager.batch_update(rng_update, model, obs, action, next_obs, done, info)
    info.update(update_info)
    train_state = train_state_update_model(model, train_state)

    if train_state.need_train:
      train_state = _train_step(train_state, rng_loss, obs, action, next_obs, done, info)

    runner_state = (train_state, rng_batch, last_timestep)
    return runner_state, (transition, timestep)

  def train(rng_batch, model, manager):

    rng_batch, rng_reset = batch_random_split(rng_batch, 2)
    start_timestep = manager.batch_reset(rng_reset)

    if config["TX"] == "adamw":
      tx = optax.adamw(config["LR"])
    elif config["TX"] == "sgd":
      tx = optax.sgd(config["LR"])
    else:
      tx = None
      assert False, config["TX"] + "is not available"
    train_state = train_state_from_model(model, tx)
    runner_state = (train_state, rng_batch, start_timestep)
    runner_state, output = jax.lax.scan(_rollout_and_update_step, runner_state, None, config["NUM_UPDATES"])

    transitions, timesteps = output
    return {"runner_state": runner_state, "transitions": transitions, "timesteps": timesteps}

  return train, model, manager, rng_batch


In [28]:
def experiment(config):
  print(config)
  train_fn, model, manager, rng_batch = make_train(config)
  train_jit = nnx.jit(train_fn)

  out = jax.block_until_ready(train_fn(rng_batch, model, manager))
  print("data shape:", jax.tree_util.tree_map(lambda x: x.shape, out["transitions"]))

  train_state, rng_batch, last_timestep = out["runner_state"]

  model = model_from_train_state(train_state)

  if "mi" in out["transitions"].info:
    # Create figure and axis
        plt.figure(figsize=(10, 6))
        # Sample JAX NumPy arrays (replace these with your actual arrays)
        #  print (out["transitions"].info)
        eig_array = out["transitions"].info["mi"].reshape(-1)
        big_array = out["transitions"].info["kl"].reshape(-1)
        # Plot both arrays
        plt.plot(eig_array, label='EIG', marker='o', linestyle='-', color='blue')
        plt.plot(big_array, label='BIG', marker='s', linestyle='-', color='red')

        if "smi" in out["transitions"].info:
          smi_array = out["transitions"].info["smi"].reshape(-1)
          plt.plot(smi_array, label='SMI', marker='^', linestyle='-', color='green')

        # add labels and title
        plt.xlabel('Num of Updates')
        plt.ylabel('Information Gain')
        Title = "InfoGains for" + config["MODEL_NAME"]
        Title = Title + "Total InfoGains" + "{:10.4f}".format(big_array.sum().item())
        Title = Title +  " with Seed" +str(config["SEED"])
        plt.title(Title)

        # add grid and legend
        plt.grid(alpha=0.3)
        plt.legend()

        plt.tight_layout()
        plt.savefig(Title.replace(" ","_")+'.pdf', format='pdf', dpi=300, bbox_inches='tight')
        plt.show()

  if "l_prec" in out["transitions"].info:
      l_prec_mean = out["transitions"].info["l_prec"].mean(axis=(1,2,3), keepdims=False)
      mean_error = out["transitions"].info["mean_error"].mean(axis=(1,2), keepdims=False)

      plt.figure(figsize=(10, 6))
      plt.plot(l_prec_mean, label='L_prec', marker='o', linestyle='-', color='blue')
      plt.plot(mean_error, label='Mean Error', marker='s', linestyle='-', color='yellow')

      plt.xlabel('Num of Updates')
      plt.ylabel('Mean Precision')
      Title = "Comparison of Mean Precisions"

      plt.title(Title)

      plt.grid(alpha=0.3)
      plt.legend()

      plt.tight_layout()
      plt.savefig(Title.replace(" ","_")+'.pdf', format='pdf', dpi=300, bbox_inches='tight')
      plt.show()

  return out

In [35]:
env_name = "MiniGrid-EmptyRandom-8x8"
NUM_ENVS = 1 # @param[1,2,4,8,16,32]
TOTAL_TIMESTEPS = 16384 # @param [2048,16384,131072,1048576] {"type":"raw"}
DEPTH = 1 # @param [1,2,4] {"type":"raw"}
NUM_STEPS = 8 # @param [1,2,4,8,16] {"type":"raw"}
NUM_HIDDEN = 128 # @param [32,64,128,256] {"type":"raw"}
WD = 0.1 # @param [0,0.1,0.01,0.001] {"type":"raw"}
MODEL_NAME = "XlandDeepSACBayesianExplorer"  #@param ["DeepSACBayesianExplorer","RandomExplorer","XlandDeepSACBayesianExplorer"]
config = {
    "NUM_ENVS": NUM_ENVS,    #
    "WD": WD,
    "NUM_STEPS": NUM_STEPS,   #steps of roll out between update
    "NUM_OOF": NUM_HIDDEN, # num hidden for now
    "SAC_D_STEPS": 4,
    "ENV_NAME":env_name,
    "SAC_STEP_SIZE": 1.0,
    "SEED": 423,         #highly stochastic
    "TOTAL_TIMESTEPS": TOTAL_TIMESTEPS,   #total steps for all envs
    "NUM_HIDDEN":NUM_HIDDEN,
    "TX":"adamw",
    "DEPTH":DEPTH,
    "LR":2e-4,
    "OPT_STEPS":8,
    "MODEL_NAME": MODEL_NAME,
    "DEBUG": False,
}

In [41]:
out = experiment(config)

{'NUM_ENVS': 1, 'WD': 0.1, 'NUM_STEPS': 8, 'NUM_OOF': 128, 'SAC_D_STEPS': 4, 'ENV_NAME': 'MiniGrid-EmptyRandom-8x8', 'SAC_STEP_SIZE': 1.0, 'SEED': 423, 'TOTAL_TIMESTEPS': 16384, 'NUM_HIDDEN': 128, 'TX': 'adamw', 'DEPTH': 1, 'LR': 0.0002, 'OPT_STEPS': 8, 'MODEL_NAME': 'XlandDeepSACBayesianExplorer', 'DEBUG': False, 'NUM_UPDATES': 2048}


TypeError: Linear.__init__() takes 3 positional arguments but 4 were given

## IQL

In [None]:
import os
import time
from functools import partial
from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple

import distrax
import flax
import flax.linen as nn

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm
import wandb
from flax.training.train_state import TrainState
from omegaconf import OmegaConf
from pydantic import BaseModel

os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"


### Config

In [None]:
class IQLConfig(BaseModel):
    # GENERAL
    algo: str = "IQL"
    project: str = "train-IQL"
    env_name: str = "MiniGrid-EmptyRandom-6x6"
    seed: int = 42
    eval_episodes: int = 5
    log_interval: int = 100
    eval_interval: int = 100000
    batch_size: int = 256
    max_steps: int = int(1e6)
    n_jitted_updates: int = 8
    # DATASET
    data_size: int = int(1e6)
    normalize_state: bool = False
    normalize_reward: bool = True
    # NETWORK
    hidden_dims: Tuple[int, int] = (256, 256)
    actor_lr: float = 3e-4
    value_lr: float = 3e-4
    critic_lr: float = 3e-4
    layer_norm: bool = True
    opt_decay_schedule: bool = True
    # IQL SPECIFIC
    expectile: float = (
        0.7  # FYI: for Hopper-me, 0.5 produce better result. (antmaze: expectile=0.9)
    )
    beta: float = (
        3.0  # FYI: for Hopper-me, 6.0 produce better result. (antmaze: beta=10.0)
    )
    tau: float = 0.005
    discount: float = 0.99

    def __hash__(
        self,
    ):  # make config hashable to be specified as static_argnums in jax.jit.
        return hash(self.__repr__())


conf_dict = OmegaConf.from_cli()
config = IQLConfig(**conf_dict)

### Networks

In [None]:
def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init()
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for i, hidden_dims in enumerate(self.hidden_dims):
            x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                if self.layer_norm:  # Add layer norm after activation
                    x = nn.LayerNorm()(x)
                x = self.activations(x)
        return x


class Critic(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        batch_size = observations.shape[0]
        actions = jax.nn.one_hot(actions, num_classes=4) #one-hot encoding
        flat_observations = observations.reshape(batch_size, -1)
        inputs = jnp.concatenate([flat_observations, actions], axis=-1)
        critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs)
        return jnp.squeeze(critic, -1)


def ensemblize(cls, num_qs, out_axes=0, **kwargs):
    split_rngs = kwargs.pop("split_rngs", {})
    return nn.vmap(
        cls,
        variable_axes={"params": 0},
        split_rngs={**split_rngs, "params": True},
        in_axes=None,
        out_axes=out_axes,
        axis_size=num_qs,
        **kwargs,
    )


class ValueCritic(nn.Module):
    hidden_dims: Sequence[int]
    layer_norm: bool = False

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        batch_size = observations.shape[0]
        obs_flat = observations.reshape(batch_size, -1)
        critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm)(obs_flat)
        return jnp.squeeze(critic, -1)


class GaussianPolicy(nn.Module):
    hidden_dims: Sequence[int]
    action_dim: int
    log_std_min: Optional[float] = -5.0
    log_std_max: Optional[float] = 2

    @nn.compact
    def __call__(
        self, observations: jnp.ndarray, temperature: float = 1.0
    ) -> distrax.Distribution:
        outputs = MLP(
            self.hidden_dims,
            activate_final=True,
        )(observations)

        means = nn.Dense(
            self.action_dim, kernel_init=default_init()
        )(outputs)
        log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,))
        log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)

        distribution = distrax.MultivariateNormalDiag(
            loc=means, scale_diag=jnp.exp(log_stds) * temperature
        )
        return distribution

class CatPolicy(nn.Module):
  hidden_dims : Sequence[int]
  action_dim: int

  @nn.compact
  def __call__(self, observations: jnp.ndarray, temperature: float = 1.0) -> distrax.Distribution:
    x = observations.reshape(observations.shape[0], -1) # flatten
    outputs = MLP(self.hidden_dims, activate_final=True)(x)
    logits = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
    distribution = distrax.Categorical(logits=logits)
    return distribution


### Utils

In [None]:
print(jtu.tree_map(jnp.shape, replay_buffer))
print(type(replay_buffer))
print(replay_buffer["dones"])

In [None]:
class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray
    dones_float: jnp.ndarray

In [None]:
def get_normalization(dataset: Transition) -> float:
    # into numpy.ndarray
    dataset = jax.tree_util.tree_map(lambda x: np.array(x), dataset)
    returns = []
    ret = 0
    for r, term in zip(dataset.rewards, dataset.dones_float):
        ret += r
        if term:
            returns.append(ret)
            ret = 0
    return (max(returns) - min(returns)) / 1000

In [None]:
def preprocess_dataset(
     dataset: dict, config: IQLConfig, clip_to_eps: bool = True, eps: float = 1e-5
) -> Transition:

    if clip_to_eps:
        lim = 1 - eps
        dataset["actions"] = jnp.clip(dataset["actions"], -lim, lim)

    # dones_float = np.zeros_like(dataset['dones'])

    # # for i in range(len(dones_float) - 1):
    # #     print(i)
    # #     if np.linalg.norm(dataset['observations'][i + 1] -
    # #                         dataset['next_observations'][i]
    # #                         ) > 1e-6 or dataset['dones'][i] == True:
    # #         dones_float[i] = 1
    # #     else:
    # #         dones_float[i] = 0
    # dones_float[-1] = 1

    obs = dataset['observations']         # shape: (N, 7, 7, 2)
    obs = dataset['observations']         # shape: (N, 7, 7, 2)
    next_obs = dataset['next_observations']  # shape: (N, 7, 7, 2)
    dones = dataset['dones']              # shape: (N,)

    # 展平每个 observation
    obs_flat = obs[1:].reshape((obs.shape[0] - 1, -1))           # shape: (N-1, 98)
    next_obs_flat = next_obs[:-1].reshape((next_obs.shape[0] - 1, -1))  # shape: (N-1, 98)

    # 对每个样本求 L2 范数
    obs_diff = jnp.linalg.norm(obs_flat - next_obs_flat, axis=1)   # shape: (N-1,)
    obs_flag = obs_diff > 1e-6
    done_flag = dones[:-1] == True

    dones_float = jnp.zeros_like(dones, dtype=jnp.float32)
    dones_float = dones_float.at[:-1].set(jnp.logical_or(obs_flag, done_flag).astype(jnp.float32))
    dones_float = dones_float.at[-1].set(1.0)

    dataset = Transition(
        observations=jnp.array(dataset["observations"], dtype=jnp.float32),
        actions=jnp.array(dataset["actions"], dtype=jnp.float32),
        rewards=jnp.array(dataset["rewards"], dtype=jnp.float32),
        next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32),
        dones=jnp.array(dataset["dones"], dtype=jnp.float32),
        dones_float=jnp.array(dones_float, dtype=jnp.float32),
    )

    # normalize states
    # obs_mean, obs_std = 0, 1
    # if config.normalize_state:
    #     obs_mean = dataset.observations.mean(0)
    #     obs_std = dataset.observations.std(0)
    #     dataset = dataset._replace(
    #         observations=(dataset.observations - obs_mean) / (obs_std + 1e-5),
    #         next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5),
    #     )
    # # normalize rewards
    # if config.normalize_reward:
    #     normalizing_factor = get_normalization(dataset)
    #     dataset = dataset._replace(rewards=dataset.rewards / normalizing_factor)

    # shuffle data and select the first data_size samples
    # data_size = min(config.data_size, len(dataset.observations))
    # rng = jax.random.PRNGKey(config.seed)
    # rng, rng_permute, rng_select = jax.random.split(rng, 3)
    # perm = jax.random.permutation(rng_permute, len(dataset.observations))
    # dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset)
    # assert len(dataset.observations) >= data_size
    # dataset = jax.tree_util.tree_map(lambda x: x[:data_size], dataset)
    return dataset

In [None]:
def expectile_loss(diff, expectile=0.8) -> jnp.ndarray:
    weight = jnp.where(diff > 0, expectile, (1 - expectile))
    return weight * (diff**2)

def target_update(
    model: TrainState, target_model: TrainState, tau: float
) -> TrainState:
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params
    )
    return target_model.replace(params=new_target_params)


def update_by_loss_grad(
    train_state: TrainState, loss_fn: Callable
) -> Tuple[TrainState, jnp.ndarray]:
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(train_state.params)
    new_train_state = train_state.apply_gradients(grads=grad)
    return new_train_state, loss

### Model

In [None]:
class IQLTrainState(NamedTuple):
    rng: jax.random.PRNGKey
    critic: TrainState
    target_critic: TrainState
    value: TrainState
    actor: TrainState

class IQL(object):

    @classmethod
    def update_critic(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        next_v = train_state.value.apply_fn(
            train_state.value.params, batch.next_observations
        )
        target_q = batch.rewards + config.discount * (1 - batch.dones) * next_v

        def critic_loss_fn(
            critic_params: flax.core.FrozenDict[str, Any]
        ) -> jnp.ndarray:
            q1, q2 = train_state.critic.apply_fn(
                critic_params, batch.observations, batch.actions
            )
            critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()
            return critic_loss

        new_critic, critic_loss = update_by_loss_grad(
            train_state.critic, critic_loss_fn
        )
        return train_state._replace(critic=new_critic), critic_loss

    @classmethod
    def update_value(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        q1, q2 = train_state.target_critic.apply_fn(
            train_state.target_critic.params, batch.observations, batch.actions
        )
        q = jax.lax.stop_gradient(jnp.minimum(q1, q2))
        def value_loss_fn(value_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            v = train_state.value.apply_fn(value_params, batch.observations)
            value_loss = expectile_loss(q - v, config.expectile).mean()
            return value_loss

        new_value, value_loss = update_by_loss_grad(train_state.value, value_loss_fn)
        return train_state._replace(value=new_value), value_loss

    @classmethod
    def update_actor(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        v = train_state.value.apply_fn(train_state.value.params, batch.observations)
        q1, q2 = train_state.critic.apply_fn(
            train_state.target_critic.params, batch.observations, batch.actions
        )
        q = jnp.minimum(q1, q2)
        exp_a = jnp.exp((q - v) * config.beta)
        exp_a = jnp.minimum(exp_a, 100.0)
        def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            dist = train_state.actor.apply_fn(actor_params, batch.observations)
            log_probs = dist.log_prob(batch.actions.astype(jnp.int32))
            actor_loss = -(exp_a * log_probs).mean()
            return actor_loss

        new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn)
        return train_state._replace(actor=new_actor), actor_loss

    @classmethod
    def update_n_times(
        self,
        train_state: IQLTrainState,
        dataset: Transition,
        rng: jax.random.PRNGKey,
        config: IQLConfig,
    ) -> Tuple["IQLTrainState", Dict]:
        for _ in range(config.n_jitted_updates):
            rng, subkey = jax.random.split(rng)
            batch_indices = jax.random.randint(
                subkey, (config.batch_size,), 0, len(dataset.observations)
            )
            batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)

            train_state, value_loss = self.update_value(train_state, batch, config)
            train_state, actor_loss = self.update_actor(train_state, batch, config)
            train_state, critic_loss = self.update_critic(train_state, batch, config)
            new_target_critic = target_update(
                train_state.critic, train_state.target_critic, config.tau
            )
            train_state = train_state._replace(target_critic=new_target_critic)
        return train_state, {
            "value_loss": value_loss,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss,
        }

    @classmethod
    def get_action(
        self,
        train_state: IQLTrainState,
        observations: np.ndarray,
        seed: jax.random.PRNGKey,
        temperature: float = 1.0,
        max_action: float = 1.0,
    ) -> jnp.ndarray:

        # modified for discrete actions
        dist = train_state.actor.apply_fn(
            train_state.actor.params, observations, temperature=temperature
        )
        actions = jnp.argmax(dist.logits, axis=-1)
        return actions

### Train & Evaluate

In [None]:
def create_iql_train_state(
    rng: jax.random.PRNGKey,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    config: IQLConfig,
) -> IQLTrainState:
    rng, actor_rng, critic_rng, value_rng = jax.random.split(rng, 4)
    # initialize actor
    action_dim = 4

    # Gaussian Model
    # actor_model = GaussianPolicy(
    #     config.hidden_dims,
    #     action_dim=action_dim,
    #     log_std_min=-5.0,
    # )

    # Cat Model
    actor_model = CatPolicy(
        config.hidden_dims,
        action_dim = action_dim
    )

    if config.opt_decay_schedule:
        schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps)
        actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))
    else:
        actor_tx = optax.adam(learning_rate=config.actor_lr)
    actor = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, observations),
        tx=actor_tx,
    )
    # initialize critic
    critic_model = ensemblize(Critic, num_qs=2)(config.hidden_dims)
    critic = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(learning_rate=config.critic_lr),
    )
    target_critic = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(learning_rate=config.critic_lr),
    )
    # initialize value
    value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm)
    value = TrainState.create(
        apply_fn=value_model.apply,
        params=value_model.init(value_rng, observations),
        tx=optax.adam(learning_rate=config.value_lr),
    )
    return IQLTrainState(
        rng,
        critic=critic,
        target_critic=target_critic,
        value=value,
        actor=actor,
    )

In [None]:
def evaluate(
    policy_fn, env, env_params, num_episodes: int, rng
) -> float:
    print("evaluation started")
    episode_returns = []

    for i in range(num_episodes):
      rng, _rng = jax.random.split(rng)
      episode_return = 0

      timestep = env.reset(env_params, _rng)
      done = timestep.step_type == 2
      observation = timestep.observation

      while not done:
          # potential case issue
          obs = observation[None, ...]
          action = policy_fn(observations=obs)

          if isinstance(action, (jnp.ndarray, np.ndarray)) and action.shape == (1,):
            action = int(action[0])

          timestep = env.step(env_params, timestep, action)
          reward = timestep.reward
          done = timestep.step_type == 2
          observation = timestep.observation

          episode_return += reward
      episode_returns.append(episode_return)
    return float(jnp.mean(jnp.array(episode_returns)))

In [None]:
if __name__ == "__main__":
    wandb.init(config=config, project=config.project)

    rng = jax.random.PRNGKey(config.seed)
    rng, _rng = jax.random.split(rng)

    env, env_params = xminigrid.make("MiniGrid-EmptyRandom-6x6")
    env = GymAutoResetWrapper(env)

    dataset= preprocess_dataset(replay_buffer, config)

    # create train_state
    example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset)
    train_state: IQLTrainState = create_iql_train_state(
        _rng,
        example_batch.observations[None, ...],
        example_batch.actions[None, ...],
        config,
    )

    algo = IQL()
    update_fn = jax.jit(algo.update_n_times, static_argnums=(3,))
    act_fn = jax.jit(algo.get_action)
    num_steps = config.max_steps // config.n_jitted_updates
    eval_interval = config.eval_interval // config.n_jitted_updates
    for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True):
        rng, subkey = jax.random.split(rng)
        train_state, update_info = update_fn(train_state, dataset, subkey, config)

        if i % config.log_interval == 0:
            train_metrics = {f"training/{k}": v for k, v in update_info.items()}
            wandb.log(train_metrics, step=i)

        # if i % eval_interval == 0:
        #     policy_fn = partial(
        #         act_fn,
        #         temperature=0.0,
        #         seed=jax.random.PRNGKey(0),
        #         train_state=train_state,
        #     )
        #     normalized_score = evaluate(
        #         policy_fn,
        #         env,
        #         env_params,
        #         rng = _rng,
        #         num_episodes=config.eval_episodes,
        #     )
        #     print(i, normalized_score)
        #     eval_metrics = {f"{config.env_name}/normalized_score": normalized_score}
        #     wandb.log(eval_metrics, step=i)
    # final evaluation
    policy_fn = partial(
        act_fn,
        temperature=0.0,
        seed=jax.random.PRNGKey(0),
        train_state=train_state,
    )
    normalized_score = evaluate(
        policy_fn,
        env,
        env_params,
        rng = _rng,
        num_episodes=config.eval_episodes,
    )
    print("Final Evaluation", normalized_score)
    wandb.log({f"{config.env_name}/final_normalized_score": normalized_score})
    wandb.finish()

## Collect Rollouts

In [None]:
from xminigrid.wrappers import GymAutoResetWrapper

def build_rollout(env, env_params, num_steps):
  def rollout(rng):
    def _step_fn(carry, _):
      rng, timestep = carry
      rng, _rng = jax.random.split(rng)
      action = jax.random.randint(_rng, shape=(), minval=0, maxval=env.num_actions(env_params))

      timestep = env.step(env_params, timestep, action)

      return (rng, timestep), (timestep,action)

    rng, _rng = jax.random.split(rng)
    timestep = env.reset(env_params, _rng)
    rng, (transitions, actions) = jax.lax.scan(_step_fn, (rng, timestep), None, length=num_steps)

    return transitions, actions
  return rollout

In [None]:
env, env_params = xminigrid.make("MiniGrid-EmptyRandom-8x8")
env = GymAutoResetWrapper(env)

rollout_fn = jax.jit(build_rollout(env, env_params, num_steps=1e6))

transitions, actions = rollout_fn(jax.random.key(0))

In [None]:
obs_dim = env.observation_shape(env_params)

In [None]:
print(obs_dim)

(7, 7, 2)


In [None]:
print("Transitions shapes: \n", jtu.tree_map(jnp.shape, transitions))
print("Actions shape:", actions.shape)
print(type(actions))

Transitions shapes: 
 TimeStep(state=State(key=(1000000,), step_num=(1000000,), grid=(1000000, 8, 8, 2), agent=AgentState(position=(1000000, 2), direction=(1000000,), pocket=(1000000, 2)), goal_encoding=(1000000, 5), rule_encoding=(1000000, 1, 7), carry=EnvCarry()), step_type=(1000000,), reward=(1000000,), discount=(1000000,), observation=(1000000, 7, 7, 2))
Actions shape: (1000000,)
<class 'jaxlib.xla_extension.ArrayImpl'>


In [None]:
def create_replay_buffer(transitions, actions):

  observations = transitions.observation # (T, 7, 7, 2)
  rewards = transitions.reward # (T,)
  dones = transitions.step_type == 2 # (T,)
  next_observations = jnp.concatenate([observations[1:], observations[-1:]], axis=0) #(T, 7, 7, 2)
  actions = jnp.array(actions, dtype=jnp.int32) #(T,)

  replay_buffer = {'observations': observations,
                   'actions': actions,
                   'rewards': rewards,
                   'next_observations': next_observations,
                   'dones': dones}

  print("=== Replay Buffer 构建完成 ===")
  print(f"数据点数量: {len(observations)}")
  print(f"平均奖励: {jnp.mean(rewards):.4f}")
  print(f"Episode结束次数: {jnp.sum(dones)}")
  print(f"动作分布: {jnp.bincount(actions)}")
  return replay_buffer

Potential issue with sparse reward

In [None]:
replay_buffer = create_replay_buffer(transitions, actions)

=== Replay Buffer 构建完成 ===
数据点数量: 1000000
平均奖励: 0.0028
Episode结束次数: 9572
动作分布: [167326 166610 166812 166592 166378 166282]


In [None]:
def create_batches(replay_buffer, batch_size=32, num_batches=None):
  data_size = len(replay_buffer['observations'])

  if num_batches is None:
    num_batches = max(1, data_size // batch_size)

  batches = []

  rng = jax.random.PRNGKey(0)



## TD3BC

In [None]:
import os
import time
from functools import partial
from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm
import wandb
from flax.training.train_state import TrainState
from omegaconf import OmegaConf
from pydantic import BaseModel

Functions

In [None]:
def target_update(
    model: TrainState, target_model: TrainState, tau: float
) -> TrainState:
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params
    )
    return target_model.replace(params=new_target_params)


def update_by_loss_grad(
    train_state: TrainState, loss_fn: Callable
) -> Tuple[TrainState, jnp.ndarray]:
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(train_state.params)
    new_train_state = train_state.apply_gradients(grads=grad)
    return new_train_state, loss

In [None]:
class TD3BCConfig(BaseModel):
    # GENERAL
    algo: str = "TD3-BC"
    project: str = "train-TD3-BC"
    env_name: str = "MiniGrid-Empty-8x8"
    seed: int = 42
    eval_episodes: int = 5
    log_interval: int = 100000
    eval_interval: int = 100000
    batch_size: int = 256
    max_steps: int = int(1e6)
    n_jitted_updates: int = 8
    # DATASET
    data_size: int = int(1e6)
    normalize_state: bool = True
    # NETWORK
    hidden_dims: Sequence[int] = (256, 256)
    critic_lr: float = 1e-3
    actor_lr: float = 1e-3
    # TD3-BC SPECIFIC
    policy_freq: int = 2  # update actor every policy_freq updates
    alpha: float = 2.5  # BC loss weight
    policy_noise_std: float = 0.2  # std of policy noise
    policy_noise_clip: float = 0.5  # clip policy noise
    tau: float = 0.005  # target network update rate
    discount: float = 0.99  # discount factor

    def __hash__(
        self,
    ):  # make config hashable to be specified as static_argnums in jax.jit.
        return hash(self.__repr__())

conf_dict = OmegaConf.from_cli() # CLI Input
config = TD3BCConfig(**conf_dict)

def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init()
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for i, hidden_dims in enumerate(self.hidden_dims):
            x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                if self.layer_norm:  # Add layer norm after activation
                    if i + 1 < len(self.hidden_dims):
                        x = nn.LayerNorm()(x)
                x = self.activations(x)
        return x

class DoubleCritic(nn.Module):
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(
        self, observation: jnp.ndarray, action: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        x = jnp.concatenate([observation, action], axis=-1)
        q1 = MLP((*self.hidden_dims, 1), layer_norm=True)(x)
        q2 = MLP((*self.hidden_dims, 1), layer_norm=True)(x)
        return q1, q2


class TD3Actor(nn.Module):
    hidden_dims: Sequence[int]
    action_dim: int
    max_action: float = 1.0  # In D4RL, action is scaled to [-1, 1]

    @nn.compact
    def __call__(self, observation: jnp.ndarray) -> jnp.ndarray:
        action = MLP((*self.hidden_dims, self.action_dim))(observation)
        action = self.max_action * jnp.tanh(
            action
        )  # scale to [-max_action, max_action]
        return action

class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray

class TD3BCTrainState(NamedTuple):
    actor: TrainState
    critic: TrainState
    target_actor: TrainState
    target_critic: TrainState
    max_action: float = 1.0



TD3BC Object

In [None]:
class TD3BC(object):
    @classmethod
    def update_actor(
        self,
        train_state: TD3BCTrainState,
        batch: Transition,
        rng: jax.random.PRNGKey,
        config: TD3BCConfig,
    ) -> Tuple["TD3BCTrainState", jnp.ndarray]:
        def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            predicted_action = train_state.actor.apply_fn(
                actor_params, batch.observations
            )
            critic_params = jax.lax.stop_gradient(train_state.critic.params)
            q_value, _ = train_state.critic.apply_fn(
                critic_params, batch.observations, predicted_action
            )

            mean_abs_q = jax.lax.stop_gradient(jnp.abs(q_value).mean())
            loss_lambda = config.alpha / mean_abs_q

            bc_loss = jnp.square(predicted_action - batch.actions).mean()
            loss_actor = -1.0 * q_value.mean() * loss_lambda + bc_loss
            return loss_actor

        new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn)
        return train_state._replace(actor=new_actor), actor_loss

    @classmethod
    def update_critic(
        self,
        train_state: TD3BCTrainState,
        batch: Transition,
        rng: jax.random.PRNGKey,
        config: TD3BCConfig,
    ) -> Tuple["TD3BCTrainState", jnp.ndarray]:
        def critic_loss_fn(
            critic_params: flax.core.FrozenDict[str, Any]
        ) -> jnp.ndarray:
            q_pred_1, q_pred_2 = train_state.critic.apply_fn(
                critic_params, batch.observations, batch.actions
            )
            target_next_action = train_state.target_actor.apply_fn(
                train_state.target_actor.params, batch.next_observations
            )
            policy_noise = (
                config.policy_noise_std
                * train_state.max_action
                * jax.random.normal(rng, batch.actions.shape)
            )
            target_next_action = target_next_action + policy_noise.clip(
                -config.policy_noise_clip, config.policy_noise_clip
            )
            target_next_action = target_next_action.clip(
                -train_state.max_action, train_state.max_action
            )
            q_next_1, q_next_2 = train_state.target_critic.apply_fn(
                train_state.target_critic.params,
                batch.next_observations,
                target_next_action,
            )
            target = batch.rewards[..., None] + config.discount * jnp.minimum(
                q_next_1, q_next_2
            ) * (1 - batch.dones[..., None])
            target = jax.lax.stop_gradient(target)  # stop gradient for target
            value_loss_1 = jnp.square(q_pred_1 - target)
            value_loss_2 = jnp.square(q_pred_2 - target)
            value_loss = (value_loss_1 + value_loss_2).mean()
            return value_loss

        new_critic, critic_loss = update_by_loss_grad(
            train_state.critic, critic_loss_fn
        )
        return train_state._replace(critic=new_critic), critic_loss

    @classmethod
    def update_n_times(
        self,
        train_state: TD3BCTrainState,
        data: Transition,
        rng: jax.random.PRNGKey,
        config: TD3BCConfig,
    ) -> Tuple["TD3BCTrainState", Dict]:
        for _ in range(
            config.n_jitted_updates
        ):  # we can jit for roop for static unroll
            rng, batch_rng = jax.random.split(rng, 2)
            batch_idx = jax.random.randint(
                batch_rng, (config.batch_size,), 0, len(data.observations)
            )
            batch: Transition = jax.tree_util.tree_map(lambda x: x[batch_idx], data)
            rng, critic_rng, actor_rng = jax.random.split(rng, 3)
            train_state, critic_loss = self.update_critic(
                train_state, batch, critic_rng, config
            )
            if _ % config.policy_freq == 0:
                train_state, actor_loss = self.update_actor(
                    train_state, batch, actor_rng, config
                )
                new_target_critic = target_update(
                    train_state.critic, train_state.target_critic, config.tau
                )
                new_target_actor = target_update(
                    train_state.actor, train_state.target_actor, config.tau
                )
                train_state = train_state._replace(
                    target_critic=new_target_critic,
                    target_actor=new_target_actor,
                )
        return train_state, {
            "critic_loss": critic_loss,
            "actor_loss": actor_loss,
        }

    @classmethod
    def get_action(
        self,
        train_state: TD3BCTrainState,
        obs: jnp.ndarray,
        max_action: float = 1.0,  # In D4RL, action is scaled to [-1, 1]
    ) -> jnp.ndarray:
        action = train_state.actor.apply_fn(train_state.actor.params, obs)
        action = action.clip(-max_action, max_action)
        return action


Create TrainState

In [None]:
def create_td3bc_train_state(
    rng: jax.random.PRNGKey,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    config: TD3BCConfig,
) -> TD3BCTrainState:
    critic_model = DoubleCritic(
        hidden_dims=config.hidden_dims,
    )
    action_dim = actions.shape[-1]
    actor_model = TD3Actor(
        action_dim=action_dim,
        hidden_dims=config.hidden_dims,
    )
    rng, critic_rng, actor_rng = jax.random.split(rng, 3)
    # initialize critic
    critic_train_state: TrainState = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(config.critic_lr),
    )
    target_critic_train_state: TrainState = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(config.critic_lr),
    )
    # initialize actor
    actor_train_state: TrainState = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, observations),
        tx=optax.adam(config.actor_lr),
    )
    target_actor_train_state: TrainState = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, observations),
        tx=optax.adam(config.actor_lr),
    )
    return TD3BCTrainState(
        actor=actor_train_state,
        critic=critic_train_state,
        target_actor=target_actor_train_state,
        target_critic=target_critic_train_state,
    )

Evaluation

In [None]:
def evaluate(
    policy_fn: Callable[[jnp.ndarray], jnp.ndarray],
    env_name: str,
    num_episodes: int,
    obs_mean,
    obs_std,
    max_steps_per_episode: int = 100,
) -> float:
    """
    评估策略

    Args:
        policy_fn: 策略函数
        env_name: 环境名称
        num_episodes: episode数量
        obs_mean: observation均值
        obs_std: observation标准差
        max_steps_per_episode: 每个episode的最大步数

    Returns:
        平均episode回报
    """
    # 创建环境
    env, env_params = xminigrid.make(env_name)

    episode_returns = []

    for episode in range(num_episodes):
        episode_return = 0
        timestep = env.reset(env_params, jax.random.PRNGKey(episode))

        for step in range(max_steps_per_episode):
            # 处理observation - xminigrid的observation是直接的JAX数组
            obs_array = timestep.observation
            obs_numpy = np.array(obs_array)

            # xminigrid的observation形状是(7, 7, 2)
            if obs_numpy.shape == (7, 7, 2):
                # 将(7, 7, 2)转换为(7, 7, 3)的RGB图像
                object_types = obs_numpy[:, :, 0]
                colors = obs_numpy[:, :, 1]

                rgb_image = np.zeros((7, 7, 3), dtype=np.uint8)
                rgb_image[:, :, 0] = colors
                rgb_image[:, :, 1] = object_types
                rgb_image[:, :, 2] = 0

                # 上采样到22x22
                from scipy.ndimage import zoom
                try:
                    rgb_image = zoom(rgb_image, (22/7, 22/7, 1), order=0)
                except ImportError:
                    rgb_image = np.repeat(np.repeat(rgb_image, 3, axis=0), 3, axis=1)
                    rgb_image = rgb_image[:22, :22, :]

                direction = np.array([0.0])
            else:
                rgb_image = obs_numpy
                direction = np.array([0.0])

            obs_dict = {
                'image': rgb_image,
                'direction': direction
            }
            processed_obs = processor.process_observation(obs_dict)

            # 归一化observation
            if obs_mean is not None and obs_std is not None:
                processed_obs = (processed_obs - obs_mean) / obs_std

            # 获取动作
            action = policy_fn(obs=processed_obs)

            # 执行动作
            timestep = env.step(env_params, timestep, action)
            episode_return += timestep.reward

            if timestep.is_done():
                break

        episode_returns.append(episode_return)

    return np.mean(episode_returns)


In [None]:
if __name__ == "__main__":
    # wandb.init(project=config.project, config=config)

    rng = jax.random.PRNGKey(config.seed)
    # dataset, obs_mean, obs_std = get_dataset(config)

    # create train_state
    rng, subkey = jax.random.split(rng)
    # example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset)
    train_state = create_td3bc_train_state(
        subkey, example_batch.observations, example_batch.actions, config
    )
    algo = TD3BC()
    update_fn = jax.jit(algo.update_n_times, static_argnums=(3,))
    act_fn = jax.jit(algo.get_action)

    num_steps = config.max_steps // config.n_jitted_updates
    eval_interval = config.eval_interval // config.n_jitted_updates
    for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True):
        rng, update_rng = jax.random.split(rng)
        train_state, update_info = update_fn(
            train_state,
            dataset,
            update_rng,
            config,
        )  # update parameters
        if i % config.log_interval == 0:
            train_metrics = {f"training/{k}": v for k, v in update_info.items()}
            # wandb.log(train_metrics, step=i)

        if i % eval_interval == 0:
            policy_fn = partial(act_fn, train_state=train_state)
            normalized_score = evaluate(
                policy_fn,
                config.env_name,
                num_episodes=config.eval_episodes,
                obs_mean=obs_mean,
                obs_std=obs_std,
            )
            print(i, normalized_score)
            eval_metrics = {f"{config.env_name}/episode_return": normalized_score}
            # wandb.log(eval_metrics, step=i)

    # # final evaluation
    # policy_fn = partial(act_fn, train_state=train_state)
    # normalized_score = evaluate(
    #     policy_fn,
    #     config.env_name,
    #     num_episodes=config.eval_episodes,
    #     obs_mean=obs_mean,
    #     obs_std=obs_std,
    # )
    # print("Final Evaluation Score:", normalized_score)
    # wandb.log({f"{config.env_name}/final_episode_return": normalized_score})
    # wandb.finish()

NameError: name 'create_td3bc_train_state' is not defined