In [None]:
%pip install jax
%pip install numpy
%pip install matplotlib
%pip install xminigrid

In [2]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

import timeit
import imageio
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm

import xminigrid

In [None]:
# Reference copy of the TimeStep container returned by env.reset / env.step
# (its fields match the shapes printed further down). The imports below make
# this cell runnable on a fresh kernel instead of failing with NameError.
# NOTE: this is a *separate* class from xminigrid's own TimeStep, so
# isinstance checks against it will not match objects produced by the env.
from flax import struct
from xminigrid.types import State, StepType


class TimeStep(struct.PyTreeNode):
    # hidden environment state, such as grid, agent, goal, etc.
    state: State

    # similar to the dm_env interface
    step_type: StepType
    reward: jax.Array
    discount: jax.Array
    observation: jax.Array

In [22]:
from xminigrid.wrappers import GymAutoResetWrapper

def build_rollout(env, env_params, num_steps):
  """Build a function that rolls out a uniformly random policy for `num_steps` steps.

  The returned `rollout(rng)` resets `env` once and then steps it inside
  `jax.lax.scan`, so it can be wrapped in `jax.jit`.

  Returns:
    rollout(rng) -> (transitions, actions): `transitions` is the timestep pytree
    stacked along a new leading axis of length `num_steps`, and `actions[t]` is
    the action whose `env.step` produced `transitions[t]`. The timestep returned
    by `env.reset` is only used as the initial carry and is not included.
  """
  def rollout(rng):
    rng, reset_key = jax.random.split(rng)
    first_ts = env.reset(env_params, reset_key)
    n_actions = env.num_actions(env_params)

    def one_step(carry, _unused):
      key, ts = carry
      key, action_key = jax.random.split(key)
      act = jax.random.randint(action_key, shape=(), minval=0, maxval=n_actions)
      ts = env.step(env_params, ts, act)
      # Carry the new state forward and also emit it (with its action) as scan output.
      return (key, ts), (ts, act)

    _final_carry, (transitions, actions) = jax.lax.scan(
        one_step, (rng, first_ts), None, length=num_steps
    )
    return transitions, actions

  return rollout

In [24]:
# 8x8 "EmptyRandom" MiniGrid task; the auto-reset wrapper lets a single scan
# continue across episode boundaries.
env, env_params = xminigrid.make("MiniGrid-EmptyRandom-8x8")
env = GymAutoResetWrapper(env)

# Trace and compile the 1000-step random rollout, then run it from a fixed seed.
rollout_fn = jax.jit(build_rollout(env=env, env_params=env_params, num_steps=1000))
transitions, actions = rollout_fn(jax.random.key(0))

In [25]:
# Every leaf of the scanned TimeStep gains a leading axis of length num_steps.
transition_shapes = jax.tree_util.tree_map(jnp.shape, transitions)
print("Transitions shapes: \n", transition_shapes)
print("Actions shape:", actions.shape)
print(type(actions))

Transitions shapes: 
 TimeStep(state=State(key=(1000,), step_num=(1000,), grid=(1000, 8, 8, 2), agent=AgentState(position=(1000, 2), direction=(1000,), pocket=(1000, 2)), goal_encoding=(1000, 5), rule_encoding=(1000, 1, 7), carry=EnvCarry()), step_type=(1000,), reward=(1000,), discount=(1000,), observation=(1000, 7, 7, 2))
Actions shape: (1000,)
<class 'jaxlib.xla_extension.ArrayImpl'>


In [40]:
def create_replay_buffer(transitions, actions, verbose=False):
  """Turn a scanned rollout into aligned (s, a, r, s', done) transition arrays.

  `transitions` is the stacked timestep pytree produced by `build_rollout`:
  entry t is the timestep returned by `env.step` *after* applying `actions[t]`.
  The observation the agent acted on for `actions[t]` is therefore entry t-1's
  observation. The very first pre-action observation (from `env.reset`) is not
  recorded, so the first action cannot be paired and T-1 transitions remain.

  Args:
    transitions: stacked timesteps exposing `.observation` (T, ...),
      `.reward` (T,) and `.step_type` (T,).
    actions: (T,) integer actions; `actions[t]` produced `transitions[t]`.
    verbose: if True, print a short summary of the buffer.

  Returns:
    dict with keys 'observations', 'actions', 'rewards', 'next_observations'
    and 'dones', each with leading dimension max(T - 1, 0).
  """
  # StepType.LAST; xminigrid step types mirror dm_env (FIRST=0, MID=1, LAST=2).
  last_step_type = 2

  all_observations = transitions.observation       # (T, 7, 7, 2) in this notebook
  all_actions = jnp.asarray(actions, dtype=jnp.int32)  # (T,)

  # Transition t (t = 1..T-1): act on obs[t-1] with actions[t], land in obs[t].
  observations = all_observations[:-1]
  next_observations = all_observations[1:]
  step_actions = all_actions[1:]
  rewards = transitions.reward[1:]
  # Episode boundary after this transition. With an auto-reset wrapper the
  # "next observation" of a LAST step is presumably already the next episode's
  # first observation, so consumers should mask bootstrapping with `dones`.
  # LAST may also cover time-limit truncation; check `transitions.discount`
  # if that distinction matters.
  dones = transitions.step_type[1:] == last_step_type

  replay_buffer = {'observations': observations,
                   'actions': step_actions,
                   'rewards': rewards,
                   'next_observations': next_observations,
                   'dones': dones}

  if verbose:
    num_transitions = int(observations.shape[0])
    print("=== Replay buffer built ===")
    print(f"Transitions: {num_transitions}")
    if num_transitions > 0:  # mean/bincount are meaningless on an empty buffer
      print(f"Mean reward: {float(jnp.mean(rewards)):.4f}")
      print(f"Episode ends: {int(jnp.sum(dones))}")
      print(f"Action counts: {jnp.bincount(step_actions)}")

  return replay_buffer

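**Potential issue: sparse rewards.** In goal-reaching MiniGrid tasks such as `MiniGrid-EmptyRandom-8x8`, the reward is typically non-zero only on the step where the agent reaches the goal. A buffer collected with a uniformly random policy may therefore contain very few rewarding transitions. Check how many there are (e.g. `int((replay_buffer['rewards'] > 0).sum())`) before relying on the buffer for learning.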

In [41]:
# Pack the random-policy rollout into a dict of per-transition arrays.
replay_buffer = create_replay_buffer(transitions=transitions, actions=actions)

In [42]:
def create_batches(replay_buffer, batch_size=32, num_batches=None):
  data_size = len(replay_buffer['observations'])

  if num_batches is None:
    num_batches = max(1, data_size // batch_size)
