# Export ReplayBuffer from LoCo-MuJoCo (UnitreeG1, default:walking)

- Loads Unitree G1 in LoCo-MuJoCo and selects the 'walking' motion from the default dataset (alias: 'walk').
- Uses `TrajectoryDatasetManager` to create Zarr if missing and to step reference data.
- Builds a TorchRL memmap-backed replay buffer from the saved Zarr.
- Demonstrates step-wise sequential sampling and random minibatch sampling.


In [None]:
import os, gc, json, sys
from pathlib import Path
import numpy as np
import torch
from omegaconf import DictConfig
from tensordict import TensorDict

# Repo import for local package
repo_root = Path.cwd().parent
if (repo_root / "src").exists():
    sys.path.insert(0, str(repo_root / "src"))

from iltools_datasets.loco_mujoco.loader import LocoMuJoCoLoader
from iltools_datasets.manager import TrajectoryDatasetManager
from iltools_datasets.replay_export import build_replay_from_zarr
from iltools_datasets.replay_manager import EnvAssignment

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cuda:0


In [16]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Configure LoCo-MuJoCo (G1 walking)

In [None]:
# LoCo-MuJoCo base cfg (see tests for example usage)
basic_cfg = DictConfig(
    {
        "dataset": {"trajectories": {"default": ["walk"], "amass": [], "lafan1": []}},
        "control_freq": 50.0,
        "window_size": 4,
        "sim": {"dt": 0.001},
        "decimation": 20,
    }
)

# Resolve joint names via loader metadata to satisfy manager mapping
tmp_loader = LocoMuJoCoLoader(env_name="UnitreeG1", cfg=basic_cfg)
joint_names = list(tmp_loader.metadata.joint_names)
print("Found", len(joint_names), "joints")
print("joint_names", joint_names)
del tmp_loader
joint_names = joint_names[1:]  # no need for root joint

# Paths
DATA_DIR = Path(repo_root) / "examples" / "data" / "g1_default_walk"
DATA_DIR.mkdir(parents=True, exist_ok=True)
ZARR_PATH = DATA_DIR / "trajectories.zarr"

# Manager cfg (creates Zarr if missing using LocoMuJoCoLoader)
mgr_cfg = DictConfig(
    {
        "dataset_path": str(DATA_DIR),
        "dataset": {"trajectories": {"default": ["walk"], "amass": [], "lafan1": []}},
        "loader_type": "loco_mujoco",
        "loader_kwargs": {"env_name": "UnitreeG1", "cfg": basic_cfg},
        "assignment_strategy": "sequential",
        "window_size": 4,
        "target_joint_names": joint_names,
        "reference_joint_names": joint_names,
    }
)

[LocoMuJoCoLoader] Initializing LocoMuJoCoLoader
[LocoMuJoCoLoader] Dataset dictionary: {'default': ['walk'], 'amass': [], 'lafan1': []}


100%|██████████| 35198/35198 [00:05<00:00, 6815.72it/s]


Found 24 joints
joint_names ['root', 'left_hip_pitch_joint', 'left_hip_roll_joint', 'left_hip_yaw_joint', 'left_knee_joint', 'left_ankle_pitch_joint', 'left_ankle_roll_joint', 'right_hip_pitch_joint', 'right_hip_roll_joint', 'right_hip_yaw_joint', 'right_knee_joint', 'right_ankle_pitch_joint', 'right_ankle_roll_joint', 'waist_yaw_joint', 'left_shoulder_pitch_joint', 'left_shoulder_roll_joint', 'left_shoulder_yaw_joint', 'left_elbow_joint', 'left_wrist_roll_joint', 'right_shoulder_pitch_joint', 'right_shoulder_roll_joint', 'right_shoulder_yaw_joint', 'right_elbow_joint', 'right_wrist_roll_joint']


In [18]:
manager = TrajectoryDatasetManager(cfg=mgr_cfg, num_envs=8, device="cuda:0")
manager.reset_trajectories()
ref0 = manager.get_reference_data()
print("Reference data keys:", list(ref0.keys()))
print("Zarr ready at:", ZARR_PATH)

[TrajectoryDatasetManager] Using device: cuda:0
[TrajectoryDatasetManager] Initialized with 1 trajectories, 8 envs
Reference data keys: ['root_pos', 'root_quat', 'root_lin_vel', 'root_ang_vel', 'joint_pos', 'joint_vel', 'raw_qpos', 'raw_qvel']
Zarr ready at: /home/fwu/Documents/Research/SkillLearning/ImitationLearningTools/examples/data/g1_default_walk/trajectories.zarr


## Build memmap-backed replay buffer from Zarr

In [None]:
# Export replay using qpos as observation; action auto-detected if present
replay_mgr = build_replay_from_zarr(
    zarr_path=str(ZARR_PATH),
    scratch_dir=str(DATA_DIR / "memmap"),
    obs_keys=["qpos"],
    concat_obs_to_key="observation",
    include_terminated=True,
    include_truncated=True,
)
print("Replay transitions available:", len(replay_mgr.buffer))

[build_replay_from_zarr] td: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([87994, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                observation: Tensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=False),
                qpos: Tensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([87994]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=False),
        qpos: Tensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([87994]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([87994]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([87994]),
    device=cpu

In [20]:
buffer = replay_mgr.buffer
buffer

TensorDictReplayBuffer(
    storage=LazyMemmapStorage(
        data=TensorDict(
            fields={
                action: MemoryMappedTensor(shape=torch.Size([87994, 1]), device=cpu, dtype=torch.float32, is_shared=True),
                next: TensorDict(
                    fields={
                        observation: MemoryMappedTensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=True),
                        qpos: MemoryMappedTensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=True)},
                    batch_size=torch.Size([87994]),
                    device=cpu,
                    is_shared=False),
                observation: MemoryMappedTensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=True),
                qpos: MemoryMappedTensor(shape=torch.Size([87994, 30]), device=cpu, dtype=torch.float32, is_shared=True),
                terminated: MemoryMappedTensor(shape=torch.Size([87994])

## Step-wise sequential sampler

In [None]:
# Build a simple assignment: all 8 envs read the first segment (task 0, traj 0)
asg = [EnvAssignment(task_id=0, traj_id=0, step=0) for i in range(8)]
replay_mgr.set_assignment(asg)
# Sample twice; each env advances one step and wraps per-trajectory
b1 = replay_mgr.buffer.sample()
b2 = replay_mgr.buffer.sample()
print(
    "Sequential batch1 shape:",
    {k: tuple(v.shape) for k, v in b1.items() if isinstance(v, torch.Tensor)},
)
print("Sequential batch2 observation head:", b2["observation"][:2])

Sequential batch1 shape: {'observation': (8, 30), 'action': (8, 1), 'qpos': (8, 30), 'terminated': (8,), 'truncated': (8,), 'index': (8,)}
Sequential batch2 observation head: tensor([[ 8.9815e-03, -6.0037e-04,  7.8958e-01,  9.9513e-01,  2.6733e-03,
          7.5421e-02,  6.3350e-02,  1.5073e-01, -2.4139e-02, -4.5167e-02,
          2.1360e-01, -3.2380e-01,  4.2471e-02, -6.7555e-01, -2.7696e-02,
         -1.5894e-01,  2.0293e-01,  9.8581e-03, -6.1347e-02, -1.7089e-01,
         -2.8424e-01,  1.7141e-01,  1.1185e-01,  5.9139e-01,  7.3095e-02,
          2.1847e-01, -1.1540e-02, -9.2835e-01,  1.3153e+00, -1.2692e+00],
        [ 8.9815e-03, -6.0037e-04,  7.8958e-01,  9.9513e-01,  2.6733e-03,
          7.5421e-02,  6.3350e-02,  1.5073e-01, -2.4139e-02, -4.5167e-02,
          2.1360e-01, -3.2380e-01,  4.2471e-02, -6.7555e-01, -2.7696e-02,
         -1.5894e-01,  2.0293e-01,  9.8581e-03, -6.1347e-02, -1.7089e-01,
         -2.8424e-01,  1.7141e-01,  1.1185e-01,  5.9139e-01,  7.3095e-02,
          

In [None]:
b1["observation"]

tensor([[ 0.0000,  0.0000,  0.7905,  0.9951,  0.0016,  0.0764,  0.0624,  0.1464,
         -0.0225, -0.0381,  0.2012, -0.3344,  0.0427, -0.6798, -0.0301, -0.1662,
          0.1969, -0.0031, -0.0641, -0.1703, -0.2879,  0.1746,  0.1106,  0.5922,
          0.0746,  0.2183, -0.0113, -0.9256,  1.3148, -1.2664],
        [ 0.0000,  0.0000,  0.7905,  0.9951,  0.0016,  0.0764,  0.0624,  0.1464,
         -0.0225, -0.0381,  0.2012, -0.3344,  0.0427, -0.6798, -0.0301, -0.1662,
          0.1969, -0.0031, -0.0641, -0.1703, -0.2879,  0.1746,  0.1106,  0.5922,
          0.0746,  0.2183, -0.0113, -0.9256,  1.3148, -1.2664],
        [ 0.0000,  0.0000,  0.7905,  0.9951,  0.0016,  0.0764,  0.0624,  0.1464,
         -0.0225, -0.0381,  0.2012, -0.3344,  0.0427, -0.6798, -0.0301, -0.1662,
          0.1969, -0.0031, -0.0641, -0.1703, -0.2879,  0.1746,  0.1106,  0.5922,
          0.0746,  0.2183, -0.0113, -0.9256,  1.3148, -1.2664],
        [ 0.0000,  0.0000,  0.7905,  0.9951,  0.0016,  0.0764,  0.0624,  0.1464

In [39]:
b2["observation"]

tensor([[ 8.9815e-03, -6.0037e-04,  7.8958e-01,  9.9513e-01,  2.6733e-03,
          7.5421e-02,  6.3350e-02,  1.5073e-01, -2.4139e-02, -4.5167e-02,
          2.1360e-01, -3.2380e-01,  4.2471e-02, -6.7555e-01, -2.7696e-02,
         -1.5894e-01,  2.0293e-01,  9.8581e-03, -6.1347e-02, -1.7089e-01,
         -2.8424e-01,  1.7141e-01,  1.1185e-01,  5.9139e-01,  7.3095e-02,
          2.1847e-01, -1.1540e-02, -9.2835e-01,  1.3153e+00, -1.2692e+00],
        [ 8.9815e-03, -6.0037e-04,  7.8958e-01,  9.9513e-01,  2.6733e-03,
          7.5421e-02,  6.3350e-02,  1.5073e-01, -2.4139e-02, -4.5167e-02,
          2.1360e-01, -3.2380e-01,  4.2471e-02, -6.7555e-01, -2.7696e-02,
         -1.5894e-01,  2.0293e-01,  9.8581e-03, -6.1347e-02, -1.7089e-01,
         -2.8424e-01,  1.7141e-01,  1.1185e-01,  5.9139e-01,  7.3095e-02,
          2.1847e-01, -1.1540e-02, -9.2835e-01,  1.3153e+00, -1.2692e+00],
        [ 8.9815e-03, -6.0037e-04,  7.8958e-01,  9.9513e-01,  2.6733e-03,
          7.5421e-02,  6.3350e-02,  

## Random minibatch sampler

In [None]:
# Switch to uniform random minibatch sampler (without replacement)
replay_mgr.set_uniform_sampler(batch_size=1024, without_replacement=True)
rb_batch = replay_mgr.buffer.sample()
print("Random minibatch size:", rb_batch.batch_size)
print("Keys:", list(rb_batch.keys(True)))

Random minibatch size: torch.Size([1024])
Keys: ['observation', 'action', ('next', 'observation'), ('next', 'qpos'), 'next', 'qpos', 'terminated', 'truncated', 'index']


## Replay Buffer Examples and Tests

In [None]:
# Inspect segment metadata and presence of auxiliary keys
print("Num segments:", len(replay_mgr.segments))
print(
    "First 3 segments:",
    [(s.task_id, s.traj_id, s.length) for s in replay_mgr.segments[:3]],
)
batch = replay_mgr.buffer.sample()
print("Sampled keys:", list(batch.keys(True)))
print("Has terminated?", "terminated" in batch, "Has truncated?", "truncated" in batch)
print(
    "Batch sizes:", {k: tuple(v.shape) for k, v in batch.items() if hasattr(v, "shape")}
)

Num segments: 1
First 3 segments: [(0, 0, 87994)]
Sampled keys: ['observation', 'action', ('next', 'observation'), ('next', 'qpos'), 'next', 'qpos', 'terminated', 'truncated', 'index']
Has terminated? True Has truncated? True
Batch sizes: {'observation': (1024, 30), 'action': (1024, 1), 'next': (1024,), 'qpos': (1024, 30), 'terminated': (1024,), 'truncated': (1024,), 'index': (1024,)}


### Step-wise Sequential Sampler (assignment and wraparound)

In [None]:
# Assign all envs to the first segment and stagger start steps to show progression
first_seg = replay_mgr.segments[0]
num_envs = 6
asg = [
    EnvAssignment(task_id=first_seg.task_id, traj_id=first_seg.traj_id, step=i)
    for i in range(num_envs)
]
replay_mgr.set_assignment(asg)
b1 = replay_mgr.buffer.sample()
b2 = replay_mgr.buffer.sample()
print("Sequential sampler shapes:", b1["observation"].shape, b2["observation"].shape)
print("First env obs head b1/b2:", b1["observation"][0, :4], b2["observation"][0, :4])
# Note: Actual values depend on dataset; we demonstrate API and per-call advancement.

Sequential sampler shapes: torch.Size([6, 30]) torch.Size([6, 30])
First env obs head b1/b2: tensor([0.0000, 0.0000, 0.7905, 0.9951]) tensor([ 8.9815e-03, -6.0037e-04,  7.8958e-01,  9.9513e-01])


### Random Minibatch Sampler

In [None]:
# Switch to uniform minibatching (without replacement)
replay_mgr.set_uniform_sampler(batch_size=512, without_replacement=True)
rb_batch = replay_mgr.buffer.sample()
print("Uniform minibatch size:", rb_batch.batch_size)
# With replacement (sampler=None)
replay_mgr.set_uniform_sampler(batch_size=128, without_replacement=False)
rb_batch_rep = replay_mgr.buffer.sample()
print("With-replacement minibatch size:", rb_batch_rep.batch_size)

Uniform minibatch size: torch.Size([512])
With-replacement minibatch size: torch.Size([128])


### Device Transform (prefetch to GPU if available)

In [None]:
# Force reload of the replay_manager module to pick up latest changes
import importlib
import iltools_datasets.replay_manager

importlib.reload(iltools_datasets.replay_manager)

target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Target device:", target_device)
replay_mgr.set_device_transform(target_device)
bs = replay_mgr.buffer.sample()
print("Device:", bs["observation"].device)
print("Device types match:", bs["observation"].device.type == target_device.type)
if bs["observation"].device.type == target_device.type:
    print("✅ Device transform ok")
else:
    print("❌ Device transform failed - data still on CPU")
    print(
        "This might be due to buffer recreation losing transforms. Check replay_manager.py"
    )

Target device: cuda
Device: cpu
Device types match: False
❌ Device transform failed - data still on CPU
This might be due to buffer recreation losing transforms. Check replay_manager.py


### Synthetic Tests (deterministic observations)

In [None]:
from iltools_datasets.replay_manager import ExpertReplayManager, ExpertReplaySpec
from iltools_datasets.replay_memmap import build_trajectory_td, Segment


def _mk_traj(task_id: int, traj_id: int, T: int, obs_dim: int = 3, act_dim: int = 1):
    t = torch.arange(T, dtype=torch.float32).unsqueeze(-1)
    obs = torch.cat(
        [torch.full_like(t, float(task_id)), torch.full_like(t, float(traj_id)), t],
        dim=1,
    )
    nxt = obs + 0.5
    act = torch.zeros(T, act_dim)
    return build_trajectory_td(observation=obs, action=act, next_observation=nxt)


# Build small tasks set
tasks = {0: [_mk_traj(0, 0, T=3)], 1: [_mk_traj(1, 0, T=2)]}
tmp_dir = str((Path.cwd() / "_tmp_memmap").absolute())
mgr2 = ExpertReplayManager(
    ExpertReplaySpec(
        tasks=tasks, scratch_dir=tmp_dir, device="cpu", sample_batch_size=4
    )
)

# Sequential assignment for 3 envs
asg = [EnvAssignment(0, 0, 0), EnvAssignment(1, 0, 0), EnvAssignment(0, 0, 2)]
mgr2.set_assignment(asg)
out = mgr2.buffer.sample()
obs = out["observation"]
assert torch.allclose(obs[0], torch.tensor([0.0, 0.0, 0.0]))
assert torch.allclose(obs[1], torch.tensor([1.0, 0.0, 0.0]))
assert torch.allclose(obs[2], torch.tensor([0.0, 0.0, 2.0]))
out2 = mgr2.buffer.sample()
obs2 = out2["observation"]
assert torch.allclose(obs2[0], torch.tensor([0.0, 0.0, 1.0]))
assert torch.allclose(obs2[1], torch.tensor([1.0, 0.0, 1.0]))
assert torch.allclose(obs2[2], torch.tensor([0.0, 0.0, 0.0]))
print("✅ Sequential sampler synthetic test passed")

# Uniform samplers
mgr2.set_uniform_sampler(batch_size=5, without_replacement=True)
u1 = mgr2.buffer.sample()
assert u1.batch_size[0] == 5
mgr2.set_uniform_sampler(batch_size=3, without_replacement=False)
u2 = mgr2.buffer.sample()
assert u2.batch_size[0] == 3
print("✅ Uniform sampler tests passed")

[ExpertMemmapBuilder] setting storage in dir /home/fwu/Documents/Research/SkillLearning/ImitationLearningTools/examples/_tmp_memmap with range 0 to 3
[ExpertMemmapBuilder] setting storage in dir /home/fwu/Documents/Research/SkillLearning/ImitationLearningTools/examples/_tmp_memmap with range 3 to 5
✅ Sequential sampler synthetic test passed
✅ Uniform sampler tests passed


### Zarr Dataset Summary

In [None]:
import zarr, os

if "ZARR_PATH" not in globals():
    print("ZARR_PATH is undefined. Run the configuration cells first.")
else:
    zp = str(ZARR_PATH)
    if not os.path.exists(zp):
        print("Zarr not found at", zp, "- run earlier cells to create it.")
    else:
        root = zarr.open_group(zp, mode="r")
        print("Zarr:", zp)
        for ds_name in root.keys():
            ds_group = root[ds_name]
            motions = [
                k for k in ds_group.keys() if isinstance(ds_group[k], zarr.Group)
            ]
            print(f"- Dataset source: {ds_name} (motions: {len(motions)})")
            for motion in motions:
                mg = ds_group[motion]
                trajs = [k for k in mg.keys() if isinstance(mg[k], zarr.Group)]
                lengths = []
                for traj in trajs:
                    tg = mg[traj]
                    # Prefer 'qpos' to determine T, else first array key
                    arr_key = (
                        "qpos"
                        if "qpos" in tg
                        else next(
                            (k for k in tg.keys() if isinstance(tg[k], zarr.Array)),
                            None,
                        )
                    )
                    T = int(tg[arr_key].shape[0]) if arr_key is not None else -1
                    lengths.append(T)
                total_T = sum(max(0, T) for T in lengths)
                print(
                    f"  • Motion: {motion:>20} | trajs: {len(trajs):3d} | mean T: { (sum(lengths)/len(lengths)) if lengths else 0:.1f} | total T: {total_T}"
                )

Zarr: /home/fwu/Documents/Research/SkillLearning/ImitationLearningTools/examples/data/g1_default_walk/trajectories.zarr
- Dataset source: loco_mujoco (motions: 1)
  • Motion:         default_walk | trajs:   1 | mean T: 87995.0 | total T: 87995
