In [None]:
%pip install jax
%pip install numpy
%pip install matplotlib
%pip install xminigrid
%pip install gymnax
%pip install distrax

In [2]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import distrax

import timeit
import imageio
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm

from flax import nnx
import xminigrid

In [3]:
# class TimeStep(struct.PyTreeNode):
#     # hidden environment state, such as grid, agent, goal, etc
#     state: State

#     # similar to the dm_env enterface
#     step_type: StepType
#     reward: jax.Array
#     discount: jax.Array
#     observation: jax.Array

## Wrapper

## Utils

### Encoder

In [8]:
class Encoder(nnx.Module):
  def __init__(self, input_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array):
    h = self.linear(x)
    return self.layer_norm0(h)

class ActionEncoder(nnx.Module):
  def __init__(self, input_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.embed = nnx.Embed(input_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array):
    h = self.embed(x)
    return self.layer_norm0(h)

class JointEncoder(nnx.Module):
  def __init__(self, hidden_dim: int, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
    self.layer_norm0 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm1 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm2 = nnx.LayerNorm(hidden_dim, rngs=rngs)
    self.layer_norm3 = nnx.LayerNorm(hidden_dim, rngs=rngs)

  def __call__(self, x: jax.Array, rng):
    dist_distrax = distrax.MultivariateNormalDiag(loc=x, scale_diag=1e-1*jnp.ones_like(x))
    # potential shape issue
    x = dist_distrax.sample(seed=rng, sample_shape=(1,))
    x = self.layer_norm0(x)
    h0 = self.linear1(x)
    h = nn.relu(h0)
    h = self.layer_norm1(h) + h0
    h0 = self.linear2(h)
    h = self.layer_norm2(h) + h0
    return self.layer_norm3(h)

### Actor

In [5]:
from jax import lax
import distrax

class Actor(nnx.Module):
  # environment related ???
  log_std_min: float = -4
  log_std_max: float = 2

  def __init__(self, obs_dim, action_dim, hidden_dim, rngs: nnx.Rngs):
    self.mean = nnx.Linear(hidden_dim, action_dim, rngs=rngs)
    self.log_std = nnx.Linear(hidden_dim, action_dim, rngs=rngs)

  def __call__(self, x: jnp.ndarray):
    mean = self.mean(x)
    log_std = jnp.clip(self.log_std(x), self.log_std_min, self.log_std_max)
    return mean, log_std

### Functions

In [13]:
def compute_info_gain_normal(mean, prec, l_prec, next_obs):
  prec = jnp.maximum(prec, 1e-6)
  posterior_prec = prec + l_prec
  prec_ratio = prec / posterior_prec

  posterior_mean = (prec * mean + l_prec * next_obs) / posterior_prec

  delta_mean = next_obs - posterior_mean
  kl = delta_mean * delta_mean * prec
  kl = kl + prec_ratio - jnp.log(prec_ratio) - 1
  kl = 0.5 * jnp.sum(kl, axis=-1)
  return kl, delta_mean

@jax.jit
def compute_expected_info_gain_normal(prec, l_prec):
  prec = jnp.maximum(prec, 1e-6)
  prec_ratio = l_prec / prec
  mi_matrix = 0.5 * jnp.sum(jnp.log(1+prec_ratio), axis=-1)
  return mi_matrix

### Others

In [6]:
class Likelihood_Prec(nnx.Module):
  log_std_min: float = -2
  log_std_max: float = 2

  def __init__(self, obs_dim: int, hidden_dim: int, rngs: nnx.Rngs):
    self.linear = nnx.Linear(hidden_dim, obs_dim, rngs)

  def __call__(self, x: jnp.ndarray):
    log_std = jnp.clip(self.linear(x), self.log_std_min, self.log_std_max)
    return jnp.exp(-log_std)

## Unsupervised Explorer

In [7]:
from gymnax.experimental import RolloutWrapper
# action = self.model_forward(policy_params, obs, rng_net)
import functools
import gymnax
from typing import Union,Optional,Any
import abc

In [17]:
class UnsupervisedExplorer(nnx.Module):

  @abc.abstractmethod
  def update(self, obs, actions, next_obs, dones, info):
    # update variable parameters
    return #{'kl':KL} MI= E[KL]

  @abc.abstractmethod
  def __call__(self, observations, rng):
    return #actions, {"mi":mi_matrix}

class RandomExplorer(UnsupervisedExplorer):

  def __init__(self, num_actions):
    self.num_actions = num_actions

  def update(self, rng, obs, actions, next_obs, dones, info):
    return {}

  def __call__(self, observations, rng):
    if observations.ndim == 1:
      # possible shape issue here
      actions = jax.random.randint(rng, shape=(1,), minval=0, maxval=self.num_actions)
      return actions, {}
    actions = jax.random.randint(rng, shape=(observations.shape[0],), minval=0, maxval=self.num_actions)
    return actions, {}

class DeepSACBayesianExplorer(UnsupervisedExplorer):
  # ent?
  def __init__(self, obs_dim, num_actions, hidden_dim, rngs: nnx.Rngs,
               l_prec=1.0, wd=1e-2, ent_lambda=1e-3, depth=2):
    self.obs_dim = obs_dim
    self.num_actions = num_actions
    self.hidden_dim = hidden_dim
    self.prec_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)), name='prec_w')
    self.mean_w = nnx.Variable(jnp.zeros((hidden_dim, obs_dim)), name='mean_w')
    # what is trainable here
    self.trainable_likelihood_prec = Likelihood_Prec(obs_dim, hidden_dim, rngs)
    self.weight_decay = wd
    self.obs_embeds = Encoder(obs_dim, hidden_dim, rngs)
    self.action_embeds = ActionEncoder(num_actions, hidden_dim, rngs)
    self.joint_embeds = JointEncoder(hidden_dim, rngs)
    self.depth = depth
    self.ent_lambda = ent_lambda

  def update(self, rng, obs, action, next_obs, done, info):
    mean = info["mean"]
    prec = info["prec"]

    def _likelihood_loss(rng, T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)
      mu = mean
      # model var + inherent var
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)
      log_prob = dist_distrax.log_prob(next_obs)
      return -log_prob, l_prec

    # jit here
    predictive_loss, l_prec = _likelihood_loss(rng, info["T"], mean, prec, next_obs)
    # originally jnp.sum
    mean_error = jnp.mean((mean - next_obs)**2)
    deepkl, delta_mean = compute_info_gain_normal(mean, prec, l_prec, next_obs)
    # batch x num_hidden
    T = info["T"].reshape(-1, self.hidden_dim)

    # batch x obs_dim
    l_prec = l_prec.reshape(-1, self.obs_dim)
    delta_mean = delta_mean.reshape(-1, self.obs_dim)

    T_T = jnp.transpose(T)
    covariance = T @ T_T
    inv_covariance = jnp.linalg.pinv(covariance)

    T_Map = T_T @ inv_covariance

    delta_precW = T_Map @ l_prec
    self.prec_w.value = (self.prec_w.value + delta_precW) * (1-self.weight_decay)
    delta_meanW = T_Map @ delta_mean
    self.mean_w.value = (self.mean_w.value + delta_meanW) * (1-self.weight_decay)

    return {"kl":deepkl,  "predictive_loss": predictive_loss, "mean_error":mean_error}

  # jitable
  def loss(self, rng, obs, action, next_obs, done, info):
    def _likelihood_loss(T, mean, prec, next_obs):
      l_prec = self.trainable_likelihood_prec(T)

      mu = mean
      sigma = jnp.sqrt(1 / l_prec + 1 / prec)
      dist_distrax = distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma)

      log_prob = dist_distrax.log_prob(next_obs)
      return -log_prob

    T, mean, prec = info["T"], info["mean"], info["prec"]
    likelihood_loss = _likelihood_loss(T, mean, prec, next_obs)
    return likelihood_loss

  def batch_loss(self, rng, obs, actions, next_obs, dones, info):
    vmapped = jax.vmap(self.loss)
    return vmapped(rng, obs, actions, next_obs, dones, info)

  def recursive_mi(self, observations, rng, depth):
    obs_embed = self.obs_embed(observations)
    action_embed = self.action_embed(jnp.arange(self.num_actions))
    # possible shape issue
    embed = action_embed + jnp.expand_dims(obs_embed, axis=0)

    # num_actions x embed_size
    T = self.joint_embeds(embed, rng)
    prec = jnp.maximum(T @ self.prec_w, 1e-3)
    # num_actions x obs_dim
    mean = T @ self.mean_w
    l_prec = self.trainable_likelihood_prec(T)

    MI = compute_expected_info_gain_normal(prec, l_prec)

    if depth > 0:
      vmapped = jax.vmap(self.recursive_mi, in_axes=(0,None,None))
      # num_actions x 1
      actions, info = vmapped(mean, rng, depth-1)
      MI = MI + info["mi"]

    actions = jnp.argmax(MI, axis=0)
    T = T[actions]
    MI = MI[actions]
    l_prec = l_prec[actions]
    prec = prec[actions]
    mean = mean[actions]
    return actions, {"mi":MI,"T":T,"obs_embed":obs_embed,"l_prec":l_prec,
                        "prec":prec,"mean":mean}

def show_variable(model, text):

    graphdef, params, vars,others = nnx.split(model, nnx.Param, nnx.Variable,...)

    print(text,vars)


## Collect Rollouts

In [None]:
from xminigrid.wrappers import GymAutoResetWrapper

def build_rollout(env, env_params, num_steps):
  def rollout(rng):
    def _step_fn(carry, _):
      rng, timestep = carry
      rng, _rng = jax.random.split(rng)
      action = jax.random.randint(_rng, shape=(), minval=0, maxval=env.num_actions(env_params))

      timestep = env.step(env_params, timestep, action)

      return (rng, timestep), (timestep,action)

    rng, _rng = jax.random.split(rng)
    timestep = env.reset(env_params, _rng)
    rng, (transitions, actions) = jax.lax.scan(_step_fn, (rng, timestep), None, length=num_steps)

    return transitions, actions
  return rollout

In [None]:
env, env_params = xminigrid.make("MiniGrid-EmptyRandom-6x6")
env = GymAutoResetWrapper(env)

rollout_fn = jax.jit(build_rollout(env, env_params, num_steps=1e6))

transitions, actions = rollout_fn(jax.random.key(0))

In [None]:
print("Transitions shapes: \n", jtu.tree_map(jnp.shape, transitions))
print("Actions shape:", actions.shape)
print(type(actions))

Transitions shapes: 
 TimeStep(state=State(key=(1000000,), step_num=(1000000,), grid=(1000000, 6, 6, 2), agent=AgentState(position=(1000000, 2), direction=(1000000,), pocket=(1000000, 2)), goal_encoding=(1000000, 5), rule_encoding=(1000000, 1, 7), carry=EnvCarry()), step_type=(1000000,), reward=(1000000,), discount=(1000000,), observation=(1000000, 7, 7, 2))
Actions shape: (1000000,)
<class 'jaxlib.xla_extension.ArrayImpl'>


In [None]:
def create_replay_buffer(transitions, actions):

  observations = transitions.observation # (T, 7, 7, 2)
  rewards = transitions.reward # (T,)
  dones = transitions.step_type == 2 # (T,)
  next_observations = jnp.concatenate([observations[1:], observations[-1:]], axis=0) #(T, 7, 7, 2)
  actions = jnp.array(actions, dtype=jnp.int32) #(T,)

  replay_buffer = {'observations': observations,
                   'actions': actions,
                   'rewards': rewards,
                   'next_observations': next_observations,
                   'dones': dones}

  print("=== Replay Buffer 构建完成 ===")
  print(f"数据点数量: {len(observations)}")
  print(f"平均奖励: {jnp.mean(rewards):.4f}")
  print(f"Episode结束次数: {jnp.sum(dones)}")
  print(f"动作分布: {jnp.bincount(actions)}")
  return replay_buffer

Potential issue with sparse reward

In [None]:
replay_buffer = create_replay_buffer(transitions, actions)

=== Replay Buffer 构建完成 ===
数据点数量: 1000000
平均奖励: 0.0028
Episode结束次数: 9572
动作分布: [167326 166610 166812 166592 166378 166282]


In [None]:
def create_batches(replay_buffer, batch_size=32, num_batches=None):
  data_size = len(replay_buffer['observations'])

  if num_batches is None:
    num_batches = max(1, data_size // batch_size)

  batches = []

  rng = jax.random.PRNGKey(0)



## IQL

In [None]:
import os
import time
from functools import partial
from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple

import distrax
import flax
import flax.linen as nn

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm
import wandb
from flax.training.train_state import TrainState
from omegaconf import OmegaConf
from pydantic import BaseModel

os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"


### Config

In [None]:
class IQLConfig(BaseModel):
    # GENERAL
    algo: str = "IQL"
    project: str = "train-IQL"
    env_name: str = "MiniGrid-EmptyRandom-6x6"
    seed: int = 42
    eval_episodes: int = 5
    log_interval: int = 100
    eval_interval: int = 100000
    batch_size: int = 256
    max_steps: int = int(1e6)
    n_jitted_updates: int = 8
    # DATASET
    data_size: int = int(1e6)
    normalize_state: bool = False
    normalize_reward: bool = True
    # NETWORK
    hidden_dims: Tuple[int, int] = (256, 256)
    actor_lr: float = 3e-4
    value_lr: float = 3e-4
    critic_lr: float = 3e-4
    layer_norm: bool = True
    opt_decay_schedule: bool = True
    # IQL SPECIFIC
    expectile: float = (
        0.7  # FYI: for Hopper-me, 0.5 produce better result. (antmaze: expectile=0.9)
    )
    beta: float = (
        3.0  # FYI: for Hopper-me, 6.0 produce better result. (antmaze: beta=10.0)
    )
    tau: float = 0.005
    discount: float = 0.99

    def __hash__(
        self,
    ):  # make config hashable to be specified as static_argnums in jax.jit.
        return hash(self.__repr__())


conf_dict = OmegaConf.from_cli()
config = IQLConfig(**conf_dict)

### Networks

In [None]:
def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init()
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for i, hidden_dims in enumerate(self.hidden_dims):
            x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                if self.layer_norm:  # Add layer norm after activation
                    x = nn.LayerNorm()(x)
                x = self.activations(x)
        return x


class Critic(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        batch_size = observations.shape[0]
        actions = jax.nn.one_hot(actions, num_classes=4) #one-hot encoding
        flat_observations = observations.reshape(batch_size, -1)
        inputs = jnp.concatenate([flat_observations, actions], axis=-1)
        critic = MLP((*self.hidden_dims, 1), activations=self.activations)(inputs)
        return jnp.squeeze(critic, -1)


def ensemblize(cls, num_qs, out_axes=0, **kwargs):
    split_rngs = kwargs.pop("split_rngs", {})
    return nn.vmap(
        cls,
        variable_axes={"params": 0},
        split_rngs={**split_rngs, "params": True},
        in_axes=None,
        out_axes=out_axes,
        axis_size=num_qs,
        **kwargs,
    )


class ValueCritic(nn.Module):
    hidden_dims: Sequence[int]
    layer_norm: bool = False

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        batch_size = observations.shape[0]
        obs_flat = observations.reshape(batch_size, -1)
        critic = MLP((*self.hidden_dims, 1), layer_norm=self.layer_norm)(obs_flat)
        return jnp.squeeze(critic, -1)


class GaussianPolicy(nn.Module):
    hidden_dims: Sequence[int]
    action_dim: int
    log_std_min: Optional[float] = -5.0
    log_std_max: Optional[float] = 2

    @nn.compact
    def __call__(
        self, observations: jnp.ndarray, temperature: float = 1.0
    ) -> distrax.Distribution:
        outputs = MLP(
            self.hidden_dims,
            activate_final=True,
        )(observations)

        means = nn.Dense(
            self.action_dim, kernel_init=default_init()
        )(outputs)
        log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,))
        log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)

        distribution = distrax.MultivariateNormalDiag(
            loc=means, scale_diag=jnp.exp(log_stds) * temperature
        )
        return distribution

class CatPolicy(nn.Module):
  hidden_dims : Sequence[int]
  action_dim: int

  @nn.compact
  def __call__(self, observations: jnp.ndarray, temperature: float = 1.0) -> distrax.Distribution:
    x = observations.reshape(observations.shape[0], -1) # flatten
    outputs = MLP(self.hidden_dims, activate_final=True)(x)
    logits = nn.Dense(self.action_dim, kernel_init=default_init())(outputs)
    distribution = distrax.Categorical(logits=logits)
    return distribution


### Utils

In [None]:
print(jtu.tree_map(jnp.shape, replay_buffer))
print(type(replay_buffer))
print(replay_buffer["dones"])

In [None]:
class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray
    dones_float: jnp.ndarray

In [None]:
def get_normalization(dataset: Transition) -> float:
    # into numpy.ndarray
    dataset = jax.tree_util.tree_map(lambda x: np.array(x), dataset)
    returns = []
    ret = 0
    for r, term in zip(dataset.rewards, dataset.dones_float):
        ret += r
        if term:
            returns.append(ret)
            ret = 0
    return (max(returns) - min(returns)) / 1000

In [None]:
def preprocess_dataset(
     dataset: dict, config: IQLConfig, clip_to_eps: bool = True, eps: float = 1e-5
) -> Transition:

    if clip_to_eps:
        lim = 1 - eps
        dataset["actions"] = jnp.clip(dataset["actions"], -lim, lim)

    # dones_float = np.zeros_like(dataset['dones'])

    # # for i in range(len(dones_float) - 1):
    # #     print(i)
    # #     if np.linalg.norm(dataset['observations'][i + 1] -
    # #                         dataset['next_observations'][i]
    # #                         ) > 1e-6 or dataset['dones'][i] == True:
    # #         dones_float[i] = 1
    # #     else:
    # #         dones_float[i] = 0
    # dones_float[-1] = 1

    obs = dataset['observations']         # shape: (N, 7, 7, 2)
    obs = dataset['observations']         # shape: (N, 7, 7, 2)
    next_obs = dataset['next_observations']  # shape: (N, 7, 7, 2)
    dones = dataset['dones']              # shape: (N,)

    # 展平每个 observation
    obs_flat = obs[1:].reshape((obs.shape[0] - 1, -1))           # shape: (N-1, 98)
    next_obs_flat = next_obs[:-1].reshape((next_obs.shape[0] - 1, -1))  # shape: (N-1, 98)

    # 对每个样本求 L2 范数
    obs_diff = jnp.linalg.norm(obs_flat - next_obs_flat, axis=1)   # shape: (N-1,)
    obs_flag = obs_diff > 1e-6
    done_flag = dones[:-1] == True

    dones_float = jnp.zeros_like(dones, dtype=jnp.float32)
    dones_float = dones_float.at[:-1].set(jnp.logical_or(obs_flag, done_flag).astype(jnp.float32))
    dones_float = dones_float.at[-1].set(1.0)

    dataset = Transition(
        observations=jnp.array(dataset["observations"], dtype=jnp.float32),
        actions=jnp.array(dataset["actions"], dtype=jnp.float32),
        rewards=jnp.array(dataset["rewards"], dtype=jnp.float32),
        next_observations=jnp.array(dataset["next_observations"], dtype=jnp.float32),
        dones=jnp.array(dataset["dones"], dtype=jnp.float32),
        dones_float=jnp.array(dones_float, dtype=jnp.float32),
    )

    # normalize states
    # obs_mean, obs_std = 0, 1
    # if config.normalize_state:
    #     obs_mean = dataset.observations.mean(0)
    #     obs_std = dataset.observations.std(0)
    #     dataset = dataset._replace(
    #         observations=(dataset.observations - obs_mean) / (obs_std + 1e-5),
    #         next_observations=(dataset.next_observations - obs_mean) / (obs_std + 1e-5),
    #     )
    # # normalize rewards
    # if config.normalize_reward:
    #     normalizing_factor = get_normalization(dataset)
    #     dataset = dataset._replace(rewards=dataset.rewards / normalizing_factor)

    # shuffle data and select the first data_size samples
    # data_size = min(config.data_size, len(dataset.observations))
    # rng = jax.random.PRNGKey(config.seed)
    # rng, rng_permute, rng_select = jax.random.split(rng, 3)
    # perm = jax.random.permutation(rng_permute, len(dataset.observations))
    # dataset = jax.tree_util.tree_map(lambda x: x[perm], dataset)
    # assert len(dataset.observations) >= data_size
    # dataset = jax.tree_util.tree_map(lambda x: x[:data_size], dataset)
    return dataset

In [None]:
def expectile_loss(diff, expectile=0.8) -> jnp.ndarray:
    weight = jnp.where(diff > 0, expectile, (1 - expectile))
    return weight * (diff**2)

def target_update(
    model: TrainState, target_model: TrainState, tau: float
) -> TrainState:
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params
    )
    return target_model.replace(params=new_target_params)


def update_by_loss_grad(
    train_state: TrainState, loss_fn: Callable
) -> Tuple[TrainState, jnp.ndarray]:
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(train_state.params)
    new_train_state = train_state.apply_gradients(grads=grad)
    return new_train_state, loss

### Model

In [None]:
class IQLTrainState(NamedTuple):
    rng: jax.random.PRNGKey
    critic: TrainState
    target_critic: TrainState
    value: TrainState
    actor: TrainState

class IQL(object):

    @classmethod
    def update_critic(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        next_v = train_state.value.apply_fn(
            train_state.value.params, batch.next_observations
        )
        target_q = batch.rewards + config.discount * (1 - batch.dones) * next_v

        def critic_loss_fn(
            critic_params: flax.core.FrozenDict[str, Any]
        ) -> jnp.ndarray:
            q1, q2 = train_state.critic.apply_fn(
                critic_params, batch.observations, batch.actions
            )
            critic_loss = ((q1 - target_q) ** 2 + (q2 - target_q) ** 2).mean()
            return critic_loss

        new_critic, critic_loss = update_by_loss_grad(
            train_state.critic, critic_loss_fn
        )
        return train_state._replace(critic=new_critic), critic_loss

    @classmethod
    def update_value(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        q1, q2 = train_state.target_critic.apply_fn(
            train_state.target_critic.params, batch.observations, batch.actions
        )
        q = jax.lax.stop_gradient(jnp.minimum(q1, q2))
        def value_loss_fn(value_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            v = train_state.value.apply_fn(value_params, batch.observations)
            value_loss = expectile_loss(q - v, config.expectile).mean()
            return value_loss

        new_value, value_loss = update_by_loss_grad(train_state.value, value_loss_fn)
        return train_state._replace(value=new_value), value_loss

    @classmethod
    def update_actor(
        self, train_state: IQLTrainState, batch: Transition, config: IQLConfig
    ) -> Tuple["IQLTrainState", Dict]:
        v = train_state.value.apply_fn(train_state.value.params, batch.observations)
        q1, q2 = train_state.critic.apply_fn(
            train_state.target_critic.params, batch.observations, batch.actions
        )
        q = jnp.minimum(q1, q2)
        exp_a = jnp.exp((q - v) * config.beta)
        exp_a = jnp.minimum(exp_a, 100.0)
        def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            dist = train_state.actor.apply_fn(actor_params, batch.observations)
            log_probs = dist.log_prob(batch.actions.astype(jnp.int32))
            actor_loss = -(exp_a * log_probs).mean()
            return actor_loss

        new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn)
        return train_state._replace(actor=new_actor), actor_loss

    @classmethod
    def update_n_times(
        self,
        train_state: IQLTrainState,
        dataset: Transition,
        rng: jax.random.PRNGKey,
        config: IQLConfig,
    ) -> Tuple["IQLTrainState", Dict]:
        for _ in range(config.n_jitted_updates):
            rng, subkey = jax.random.split(rng)
            batch_indices = jax.random.randint(
                subkey, (config.batch_size,), 0, len(dataset.observations)
            )
            batch = jax.tree_util.tree_map(lambda x: x[batch_indices], dataset)

            train_state, value_loss = self.update_value(train_state, batch, config)
            train_state, actor_loss = self.update_actor(train_state, batch, config)
            train_state, critic_loss = self.update_critic(train_state, batch, config)
            new_target_critic = target_update(
                train_state.critic, train_state.target_critic, config.tau
            )
            train_state = train_state._replace(target_critic=new_target_critic)
        return train_state, {
            "value_loss": value_loss,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss,
        }

    @classmethod
    def get_action(
        self,
        train_state: IQLTrainState,
        observations: np.ndarray,
        seed: jax.random.PRNGKey,
        temperature: float = 1.0,
        max_action: float = 1.0,
    ) -> jnp.ndarray:

        # modified for discrete actions
        dist = train_state.actor.apply_fn(
            train_state.actor.params, observations, temperature=temperature
        )
        actions = jnp.argmax(dist.logits, axis=-1)
        return actions

### Train & Evaluate

In [None]:
def create_iql_train_state(
    rng: jax.random.PRNGKey,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    config: IQLConfig,
) -> IQLTrainState:
    rng, actor_rng, critic_rng, value_rng = jax.random.split(rng, 4)
    # initialize actor
    action_dim = 4

    # Gaussian Model
    # actor_model = GaussianPolicy(
    #     config.hidden_dims,
    #     action_dim=action_dim,
    #     log_std_min=-5.0,
    # )

    # Cat Model
    actor_model = CatPolicy(
        config.hidden_dims,
        action_dim = action_dim
    )

    if config.opt_decay_schedule:
        schedule_fn = optax.cosine_decay_schedule(-config.actor_lr, config.max_steps)
        actor_tx = optax.chain(optax.scale_by_adam(), optax.scale_by_schedule(schedule_fn))
    else:
        actor_tx = optax.adam(learning_rate=config.actor_lr)
    actor = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, observations),
        tx=actor_tx,
    )
    # initialize critic
    critic_model = ensemblize(Critic, num_qs=2)(config.hidden_dims)
    critic = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(learning_rate=config.critic_lr),
    )
    target_critic = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(learning_rate=config.critic_lr),
    )
    # initialize value
    value_model = ValueCritic(config.hidden_dims, layer_norm=config.layer_norm)
    value = TrainState.create(
        apply_fn=value_model.apply,
        params=value_model.init(value_rng, observations),
        tx=optax.adam(learning_rate=config.value_lr),
    )
    return IQLTrainState(
        rng,
        critic=critic,
        target_critic=target_critic,
        value=value,
        actor=actor,
    )

In [None]:
def evaluate(
    policy_fn, env, env_params, num_episodes: int, rng
) -> float:
    print("evaluation started")
    episode_returns = []

    for i in range(num_episodes):
      rng, _rng = jax.random.split(rng)
      episode_return = 0

      timestep = env.reset(env_params, _rng)
      done = timestep.step_type == 2
      observation = timestep.observation

      while not done:
          # potential case issue
          obs = observation[None, ...]
          action = policy_fn(observations=obs)

          if isinstance(action, (jnp.ndarray, np.ndarray)) and action.shape == (1,):
            action = int(action[0])

          timestep = env.step(env_params, timestep, action)
          reward = timestep.reward
          done = timestep.step_type == 2
          observation = timestep.observation

          episode_return += reward
      episode_returns.append(episode_return)
    return float(jnp.mean(jnp.array(episode_returns)))

In [None]:
if __name__ == "__main__":
    wandb.init(config=config, project=config.project)

    rng = jax.random.PRNGKey(config.seed)
    rng, _rng = jax.random.split(rng)

    env, env_params = xminigrid.make("MiniGrid-EmptyRandom-6x6")
    env = GymAutoResetWrapper(env)

    dataset= preprocess_dataset(replay_buffer, config)

    # create train_state
    example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset)
    train_state: IQLTrainState = create_iql_train_state(
        _rng,
        example_batch.observations[None, ...],
        example_batch.actions[None, ...],
        config,
    )

    algo = IQL()
    update_fn = jax.jit(algo.update_n_times, static_argnums=(3,))
    act_fn = jax.jit(algo.get_action)
    num_steps = config.max_steps // config.n_jitted_updates
    eval_interval = config.eval_interval // config.n_jitted_updates
    for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True):
        rng, subkey = jax.random.split(rng)
        train_state, update_info = update_fn(train_state, dataset, subkey, config)

        if i % config.log_interval == 0:
            train_metrics = {f"training/{k}": v for k, v in update_info.items()}
            wandb.log(train_metrics, step=i)

        # if i % eval_interval == 0:
        #     policy_fn = partial(
        #         act_fn,
        #         temperature=0.0,
        #         seed=jax.random.PRNGKey(0),
        #         train_state=train_state,
        #     )
        #     normalized_score = evaluate(
        #         policy_fn,
        #         env,
        #         env_params,
        #         rng = _rng,
        #         num_episodes=config.eval_episodes,
        #     )
        #     print(i, normalized_score)
        #     eval_metrics = {f"{config.env_name}/normalized_score": normalized_score}
        #     wandb.log(eval_metrics, step=i)
    # final evaluation
    policy_fn = partial(
        act_fn,
        temperature=0.0,
        seed=jax.random.PRNGKey(0),
        train_state=train_state,
    )
    normalized_score = evaluate(
        policy_fn,
        env,
        env_params,
        rng = _rng,
        num_episodes=config.eval_episodes,
    )
    print("Final Evaluation", normalized_score)
    wandb.log({f"{config.env_name}/final_normalized_score": normalized_score})
    wandb.finish()

## TD3BC

In [None]:
import os
import time
from functools import partial
from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm
import wandb
from flax.training.train_state import TrainState
from omegaconf import OmegaConf
from pydantic import BaseModel

Functions

In [None]:
def target_update(
    model: TrainState, target_model: TrainState, tau: float
) -> TrainState:
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params
    )
    return target_model.replace(params=new_target_params)


def update_by_loss_grad(
    train_state: TrainState, loss_fn: Callable
) -> Tuple[TrainState, jnp.ndarray]:
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(train_state.params)
    new_train_state = train_state.apply_gradients(grads=grad)
    return new_train_state, loss

In [None]:
class TD3BCConfig(BaseModel):
    # GENERAL
    algo: str = "TD3-BC"
    project: str = "train-TD3-BC"
    env_name: str = "MiniGrid-Empty-8x8"
    seed: int = 42
    eval_episodes: int = 5
    log_interval: int = 100000
    eval_interval: int = 100000
    batch_size: int = 256
    max_steps: int = int(1e6)
    n_jitted_updates: int = 8
    # DATASET
    data_size: int = int(1e6)
    normalize_state: bool = True
    # NETWORK
    hidden_dims: Sequence[int] = (256, 256)
    critic_lr: float = 1e-3
    actor_lr: float = 1e-3
    # TD3-BC SPECIFIC
    policy_freq: int = 2  # update actor every policy_freq updates
    alpha: float = 2.5  # BC loss weight
    policy_noise_std: float = 0.2  # std of policy noise
    policy_noise_clip: float = 0.5  # clip policy noise
    tau: float = 0.005  # target network update rate
    discount: float = 0.99  # discount factor

    def __hash__(
        self,
    ):  # make config hashable to be specified as static_argnums in jax.jit.
        return hash(self.__repr__())

conf_dict = OmegaConf.from_cli() # CLI Input
config = TD3BCConfig(**conf_dict)

def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: bool = False
    kernel_init: Callable[[Any, Sequence[int], Any], jnp.ndarray] = default_init()
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for i, hidden_dims in enumerate(self.hidden_dims):
            x = nn.Dense(hidden_dims, kernel_init=self.kernel_init)(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                if self.layer_norm:  # Add layer norm after activation
                    if i + 1 < len(self.hidden_dims):
                        x = nn.LayerNorm()(x)
                x = self.activations(x)
        return x

class DoubleCritic(nn.Module):
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(
        self, observation: jnp.ndarray, action: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        x = jnp.concatenate([observation, action], axis=-1)
        q1 = MLP((*self.hidden_dims, 1), layer_norm=True)(x)
        q2 = MLP((*self.hidden_dims, 1), layer_norm=True)(x)
        return q1, q2


class TD3Actor(nn.Module):
    hidden_dims: Sequence[int]
    action_dim: int
    max_action: float = 1.0  # In D4RL, action is scaled to [-1, 1]

    @nn.compact
    def __call__(self, observation: jnp.ndarray) -> jnp.ndarray:
        action = MLP((*self.hidden_dims, self.action_dim))(observation)
        action = self.max_action * jnp.tanh(
            action
        )  # scale to [-max_action, max_action]
        return action

class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray

class TD3BCTrainState(NamedTuple):
    actor: TrainState
    critic: TrainState
    target_actor: TrainState
    target_critic: TrainState
    max_action: float = 1.0



TD3BC Object

In [None]:
class TD3BC(object):
    @classmethod
    def update_actor(
        self,
        train_state: TD3BCTrainState,
        batch: Transition,
        rng: jax.random.PRNGKey,
        config: TD3BCConfig,
    ) -> Tuple["TD3BCTrainState", jnp.ndarray]:
        def actor_loss_fn(actor_params: flax.core.FrozenDict[str, Any]) -> jnp.ndarray:
            predicted_action = train_state.actor.apply_fn(
                actor_params, batch.observations
            )
            critic_params = jax.lax.stop_gradient(train_state.critic.params)
            q_value, _ = train_state.critic.apply_fn(
                critic_params, batch.observations, predicted_action
            )

            mean_abs_q = jax.lax.stop_gradient(jnp.abs(q_value).mean())
            loss_lambda = config.alpha / mean_abs_q

            bc_loss = jnp.square(predicted_action - batch.actions).mean()
            loss_actor = -1.0 * q_value.mean() * loss_lambda + bc_loss
            return loss_actor

        new_actor, actor_loss = update_by_loss_grad(train_state.actor, actor_loss_fn)
        return train_state._replace(actor=new_actor), actor_loss

    @classmethod
    def update_critic(
        self,
        train_state: TD3BCTrainState,
        batch: Transition,
        rng: jax.random.PRNGKey,
        config: TD3BCConfig,
    ) -> Tuple["TD3BCTrainState", jnp.ndarray]:
        def critic_loss_fn(
            critic_params: flax.core.FrozenDict[str, Any]
        ) -> jnp.ndarray:
            q_pred_1, q_pred_2 = train_state.critic.apply_fn(
                critic_params, batch.observations, batch.actions
            )
            target_next_action = train_state.target_actor.apply_fn(
                train_state.target_actor.params, batch.next_observations
            )
            policy_noise = (
                config.policy_noise_std
                * train_state.max_action
                * jax.random.normal(rng, batch.actions.shape)
            )
            target_next_action = target_next_action + policy_noise.clip(
                -config.policy_noise_clip, config.policy_noise_clip
            )
            target_next_action = target_next_action.clip(
                -train_state.max_action, train_state.max_action
            )
            q_next_1, q_next_2 = train_state.target_critic.apply_fn(
                train_state.target_critic.params,
                batch.next_observations,
                target_next_action,
            )
            target = batch.rewards[..., None] + config.discount * jnp.minimum(
                q_next_1, q_next_2
            ) * (1 - batch.dones[..., None])
            target = jax.lax.stop_gradient(target)  # stop gradient for target
            value_loss_1 = jnp.square(q_pred_1 - target)
            value_loss_2 = jnp.square(q_pred_2 - target)
            value_loss = (value_loss_1 + value_loss_2).mean()
            return value_loss

        new_critic, critic_loss = update_by_loss_grad(
            train_state.critic, critic_loss_fn
        )
        return train_state._replace(critic=new_critic), critic_loss

    @classmethod
    def update_n_times(
        self,
        train_state: TD3BCTrainState,
        data: Transition,
        rng: jax.random.PRNGKey,
        config: TD3BCConfig,
    ) -> Tuple["TD3BCTrainState", Dict]:
        for _ in range(
            config.n_jitted_updates
        ):  # we can jit for roop for static unroll
            rng, batch_rng = jax.random.split(rng, 2)
            batch_idx = jax.random.randint(
                batch_rng, (config.batch_size,), 0, len(data.observations)
            )
            batch: Transition = jax.tree_util.tree_map(lambda x: x[batch_idx], data)
            rng, critic_rng, actor_rng = jax.random.split(rng, 3)
            train_state, critic_loss = self.update_critic(
                train_state, batch, critic_rng, config
            )
            if _ % config.policy_freq == 0:
                train_state, actor_loss = self.update_actor(
                    train_state, batch, actor_rng, config
                )
                new_target_critic = target_update(
                    train_state.critic, train_state.target_critic, config.tau
                )
                new_target_actor = target_update(
                    train_state.actor, train_state.target_actor, config.tau
                )
                train_state = train_state._replace(
                    target_critic=new_target_critic,
                    target_actor=new_target_actor,
                )
        return train_state, {
            "critic_loss": critic_loss,
            "actor_loss": actor_loss,
        }

    @classmethod
    def get_action(
        self,
        train_state: TD3BCTrainState,
        obs: jnp.ndarray,
        max_action: float = 1.0,  # In D4RL, action is scaled to [-1, 1]
    ) -> jnp.ndarray:
        action = train_state.actor.apply_fn(train_state.actor.params, obs)
        action = action.clip(-max_action, max_action)
        return action


Create TrainState

In [None]:
def create_td3bc_train_state(
    rng: jax.random.PRNGKey,
    observations: jnp.ndarray,
    actions: jnp.ndarray,
    config: TD3BCConfig,
) -> TD3BCTrainState:
    critic_model = DoubleCritic(
        hidden_dims=config.hidden_dims,
    )
    action_dim = actions.shape[-1]
    actor_model = TD3Actor(
        action_dim=action_dim,
        hidden_dims=config.hidden_dims,
    )
    rng, critic_rng, actor_rng = jax.random.split(rng, 3)
    # initialize critic
    critic_train_state: TrainState = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(config.critic_lr),
    )
    target_critic_train_state: TrainState = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, observations, actions),
        tx=optax.adam(config.critic_lr),
    )
    # initialize actor
    actor_train_state: TrainState = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, observations),
        tx=optax.adam(config.actor_lr),
    )
    target_actor_train_state: TrainState = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, observations),
        tx=optax.adam(config.actor_lr),
    )
    return TD3BCTrainState(
        actor=actor_train_state,
        critic=critic_train_state,
        target_actor=target_actor_train_state,
        target_critic=target_critic_train_state,
    )

Evaluation

In [None]:
def evaluate(
    policy_fn: Callable[[jnp.ndarray], jnp.ndarray],
    env_name: str,
    num_episodes: int,
    obs_mean,
    obs_std,
    max_steps_per_episode: int = 100,
) -> float:
    """
    评估策略

    Args:
        policy_fn: 策略函数
        env_name: 环境名称
        num_episodes: episode数量
        obs_mean: observation均值
        obs_std: observation标准差
        max_steps_per_episode: 每个episode的最大步数

    Returns:
        平均episode回报
    """
    # 创建环境
    env, env_params = xminigrid.make(env_name)

    episode_returns = []

    for episode in range(num_episodes):
        episode_return = 0
        timestep = env.reset(env_params, jax.random.PRNGKey(episode))

        for step in range(max_steps_per_episode):
            # 处理observation - xminigrid的observation是直接的JAX数组
            obs_array = timestep.observation
            obs_numpy = np.array(obs_array)

            # xminigrid的observation形状是(7, 7, 2)
            if obs_numpy.shape == (7, 7, 2):
                # 将(7, 7, 2)转换为(7, 7, 3)的RGB图像
                object_types = obs_numpy[:, :, 0]
                colors = obs_numpy[:, :, 1]

                rgb_image = np.zeros((7, 7, 3), dtype=np.uint8)
                rgb_image[:, :, 0] = colors
                rgb_image[:, :, 1] = object_types
                rgb_image[:, :, 2] = 0

                # 上采样到22x22
                from scipy.ndimage import zoom
                try:
                    rgb_image = zoom(rgb_image, (22/7, 22/7, 1), order=0)
                except ImportError:
                    rgb_image = np.repeat(np.repeat(rgb_image, 3, axis=0), 3, axis=1)
                    rgb_image = rgb_image[:22, :22, :]

                direction = np.array([0.0])
            else:
                rgb_image = obs_numpy
                direction = np.array([0.0])

            obs_dict = {
                'image': rgb_image,
                'direction': direction
            }
            processed_obs = processor.process_observation(obs_dict)

            # 归一化observation
            if obs_mean is not None and obs_std is not None:
                processed_obs = (processed_obs - obs_mean) / obs_std

            # 获取动作
            action = policy_fn(obs=processed_obs)

            # 执行动作
            timestep = env.step(env_params, timestep, action)
            episode_return += timestep.reward

            if timestep.is_done():
                break

        episode_returns.append(episode_return)

    return np.mean(episode_returns)


In [None]:
if __name__ == "__main__":
    # wandb.init(project=config.project, config=config)

    rng = jax.random.PRNGKey(config.seed)
    # dataset, obs_mean, obs_std = get_dataset(config)

    # create train_state
    rng, subkey = jax.random.split(rng)
    # example_batch: Transition = jax.tree_util.tree_map(lambda x: x[0], dataset)
    train_state = create_td3bc_train_state(
        subkey, example_batch.observations, example_batch.actions, config
    )
    algo = TD3BC()
    update_fn = jax.jit(algo.update_n_times, static_argnums=(3,))
    act_fn = jax.jit(algo.get_action)

    num_steps = config.max_steps // config.n_jitted_updates
    eval_interval = config.eval_interval // config.n_jitted_updates
    for i in tqdm.tqdm(range(1, num_steps + 1), smoothing=0.1, dynamic_ncols=True):
        rng, update_rng = jax.random.split(rng)
        train_state, update_info = update_fn(
            train_state,
            dataset,
            update_rng,
            config,
        )  # update parameters
        if i % config.log_interval == 0:
            train_metrics = {f"training/{k}": v for k, v in update_info.items()}
            # wandb.log(train_metrics, step=i)

        if i % eval_interval == 0:
            policy_fn = partial(act_fn, train_state=train_state)
            normalized_score = evaluate(
                policy_fn,
                config.env_name,
                num_episodes=config.eval_episodes,
                obs_mean=obs_mean,
                obs_std=obs_std,
            )
            print(i, normalized_score)
            eval_metrics = {f"{config.env_name}/episode_return": normalized_score}
            # wandb.log(eval_metrics, step=i)

    # # final evaluation
    # policy_fn = partial(act_fn, train_state=train_state)
    # normalized_score = evaluate(
    #     policy_fn,
    #     config.env_name,
    #     num_episodes=config.eval_episodes,
    #     obs_mean=obs_mean,
    #     obs_std=obs_std,
    # )
    # print("Final Evaluation Score:", normalized_score)
    # wandb.log({f"{config.env_name}/final_episode_return": normalized_score})
    # wandb.finish()

NameError: name 'create_td3bc_train_state' is not defined