In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Check dataset action dim

In [13]:
import numpy as np
from lerobot.common.datasets import LeRobotDataset
ds = LeRobotDataset("lerobot/xarm_lift_medium")

ds.download_episodes()

# Print dataset info
print("Dataset info:")
print(f"Number of episodes: {ds.num_episodes}")
print(f"Features: {ds.features}")

Dataset info:
Number of episodes: 800
Features: {'observation.image': {'dtype': 'video', 'shape': (84, 84, 3), 'names': ['height', 'width', 'channel'], 'video_info': {'video.fps': 15.0, 'video.codec': 'av1', 'video.pix_fmt': 'yuv420p', 'video.is_depth_map': False, 'has_audio': False}}, 'observation.state': {'dtype': 'float32', 'shape': (4,), 'names': {'motors': ['motor_0', 'motor_1', 'motor_2', 'motor_3']}}, 'action': {'dtype': 'float32', 'shape': (4,), 'names': {'motors': ['motor_0', 'motor_1', 'motor_2', 'motor_3']}}, 'episode_index': {'dtype': 'int64', 'shape': (1,), 'names': None}, 'frame_index': {'dtype': 'int64', 'shape': (1,), 'names': None}, 'timestamp': {'dtype': 'float32', 'shape': (1,), 'names': None}, 'next.reward': {'dtype': 'float32', 'shape': (1,), 'names': None}, 'next.done': {'dtype': 'bool', 'shape': (1,), 'names': None}, 'index': {'dtype': 'int64', 'shape': (1,), 'names': None}, 'task_index': {'dtype': 'int64', 'shape': (1,), 'names': None}}


In [14]:
# Try to access a single episode first
try:
    # Get the first item from the dataset
    first_item = ds[0]
    print("\nFirst item keys:", first_item.keys())
    
    # Now try to access episodes
    print("\nTrying to access episodes...")
    all_actions = []
    for i in range(min(5, ds.num_episodes)):  # Try first 5 episodes
        item = ds[i]
        if "action" in item:
            all_actions.append(item["action"])
    
    print(f"Successfully collected {len(all_actions)} actions")
except Exception as e:
    print("Error:", str(e))


First item keys: dict_keys(['observation.image', 'observation.state', 'action', 'episode_index', 'frame_index', 'timestamp', 'next.reward', 'next.done', 'index', 'task_index', 'task'])

Trying to access episodes...
Successfully collected 5 actions


objc[57261]: Class AVFFrameReceiver is implemented in both /Users/OAA/miniforge3/envs/robotics/lib/python3.10/site-packages/av/.dylibs/libavdevice.61.3.100.dylib (0x11bd083a8) and /Users/OAA/miniforge3/envs/robotics/lib/libavdevice.61.3.100.dylib (0x16bb4c848). One of the two will be used. Which one is undefined.
objc[57261]: Class AVFAudioReceiver is implemented in both /Users/OAA/miniforge3/envs/robotics/lib/python3.10/site-packages/av/.dylibs/libavdevice.61.3.100.dylib (0x11bd083f8) and /Users/OAA/miniforge3/envs/robotics/lib/libavdevice.61.3.100.dylib (0x16bb4c898). One of the two will be used. Which one is undefined.


In [18]:
all_actions = []
for episode in ds:
    # print(episode)
    all_actions.append(episode["action"])
    # for step in episode:
    #     print(step)
    #     all_actions.append(step["action"])

actions = np.array(all_actions)

print("Action Mean:", actions.mean(axis=0))
print("Action Std:", actions.std(axis=0))
print("Action Min:", actions.min(axis=0))
print("Action Max:", actions.max(axis=0))

Action Mean: [ 0.27324945 -0.14783773 -0.15354335 -0.23991735]
Action Std: [0.631202   0.6673078  0.6527433  0.65370834]
Action Min: [-1. -1. -1. -1.]
Action Max: [1. 1. 1. 1.]


# Testing VAE

In [44]:
import torch
import numpy as np
from pathlib import Path
from lerobot.common.policies.action_vae import ActionVAE
from operator import itemgetter

In [27]:
device = torch.device("mps")

vae = ActionVAE(
    input_dim=2, 
    latent_dim=2,
    hidden_dims=[256, 128, 64],
    ).to(device)

path = Path("checkpoints/trained_beta_vae.pth")
state = torch.load(path, map_location=device)

vae.load_state_dict(
        state['model_state_dict'] if 'model_state_dict' in state  # -- saved with model_state_dict key
        else state['model'] if 'model' in state         # -- diffusion-policy style
        else state['state_dict'] if 'state_dict' in state
        else state                                 # plain torch.save(vae.state_dict())
)
vae.eval()
print(vae)

ActionVAE(
  (encoder_layers): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=2, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.1, inplace=False)
    )
    (1): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.1, inplace=False)
    )
    (2): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
      (3): Dropout(p=0.1, inplace=False)
    )
  )
  (fc_mu): Linear(in_features=64, out_features=2, bias=True)
  (fc_var): Linear(in_features=64, out_features=2, bias=True)
  (decoder_layers): ModuleList(
    (0)

In [28]:
# Load lerobot dataset
from datasets import load_dataset
ds = load_dataset("lerobot/pusht", split="train")

In [29]:
def fetch_episode(ds, ep_idx: int, max_len: int = None):
    """Return a list of actions for the requested episode_index."""
    # ➊ keep only rows that match the episode
    ep_rows = ds.filter(lambda r: r["episode_index"] == ep_idx)
    # ➋ make sure they’re sorted by frame_index
    ep_rows = sorted(ep_rows, key=itemgetter("frame_index"))
    # ➌ pull the 'action' field and (optionally) truncate to max_len
    actions = [row["action"] for row in ep_rows]
    if max_len is not None:
        actions = actions[:max_len]
    return torch.tensor(actions, dtype=torch.float32, device=device)  # (T,2)｜

In [39]:
ep_idx = 9
actions = fetch_episode(ds, ep_idx=ep_idx, max_len=32)
print(actions.shape)

Filter: 100%|██████████| 25650/25650 [00:00<00:00, 127010.99 examples/s]

torch.Size([32, 2])





In [40]:
with torch.no_grad():
    recon, _, mu, log_var = vae(actions)

mse = torch.mean((actions - recon) ** 2).mean()
print(f"MSE per element: {mse:.6e}")

MSE per element: 6.913333e+04


In [None]:
import gym, imageio
env = gym.make('gym_pusht/PushT-v0')
frames = []
obs, _ = env.reset(seed=0)

for a in actions.cpu().numpy():
    obs, _, done, _ = env.step(a)
    frames.append(env.render(mode='rgb_array'))
    if done:
        break

imageio.mimsave('orig.gif', frames, fps=20)

# Test Diffusion config with VAE

In [1]:
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.configs.types import FeatureType

import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
# creating the policy:
#   - input/output shapes: to properly size the policy
#   - dataset stats: for normalization and denormalization of input/outputs
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}


In [3]:
cfg = DiffusionConfig(input_features=input_features,output_features=output_features)

# Check action dimension
print("Action feature shape:", cfg.action_feature.shape)

policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)



Action feature shape: (2,)
Saved VAE input_dim: 2
Current action_dim: 2
Loaded VAE checkpoint from checkpoints/trained_beta_vae.pth
VAE device: mps:0
VAE input_dim: 2


In [4]:
device = torch.device("mps")
B = 4
Tobs = cfg.n_obs_steps
Th = cfg.horizon
A = 2

batch = {
    # proprio state: (B, Tobs, state_dim)
    "observation.state": torch.randn(B, Tobs, cfg.robot_state_feature.shape[0], device=device),

    # a dummy RGB stream so image encoder path is exercised
    "observation.images": torch.rand(
        B, Tobs, 1,                     # one “camera”
        3, *cfg.crop_shape, device=device
    ),
    # action trajectory + padding mask
    "action":         torch.randn(B, Th, A, device=device),
    "action_is_pad":  torch.zeros (B, Th,     dtype=torch.bool, device=device)
}

In [7]:
actions_flat = batch["action"].reshape(-1, A)

mu, log_var = policy.vae.encode(actions_flat)          # (B, Th, latent_dim)
z           = policy.vae.reparameterize(mu, log_var)
recon       = policy.vae.decode(z)

print("VAE | action -> latent -> action:", batch["action"].shape, "→", z.shape, "→", recon.shape)

mse = torch.mean((actions_flat - recon) ** 2).item()
print(f"VAE reconstruction MSE  : {mse:.3e}")
print(f"KL per-dim (mean)       : {(-0.5*(1+log_var-mu**2-log_var.exp())).mean():.3f}")


VAE | action -> latent -> action: torch.Size([4, 16, 2]) → torch.Size([64, 2]) → torch.Size([64, 2])
VAE reconstruction MSE  : 1.373e-01
KL per-dim (mean)       : 12.016


In [None]:
def fetch_episode(ds, ep_idx: int, max_len: int = None):
    """Return a list of actions for the requested episode_index."""
    # ➊ keep only rows that match the episode
    ep_rows = ds.filter(lambda r: r["episode_index"] == ep_idx)
    # ➋ make sure they’re sorted by frame_index
    ep_rows = sorted(ep_rows, key=itemgetter("frame_index"))
    # ➌ pull the 'action' field and (optionally) truncate to max_len
    actions = [row["action"] for row in ep_rows]
    if max_len is not None:
        actions = actions[:max_len]
    return torch.tensor(actions, dtype=torch.float32, device=device)  # (T,2)｜

In [None]:
def fetch_episode(ds, ep_idx: int, max_len: int = None):
    """Return a list of actions for the requested episode_index."""
    # ➊ keep only rows that match the episode
    ep_rows = ds.filter(lambda r: r["episode_index"] == ep_idx)
    # ➋ make sure they’re sorted by frame_index
    ep_rows = sorted(ep_rows, key=itemgetter("frame_index"))
    # ➌ pull the 'action' field and (optionally) truncate to max_len
    actions = [row["action"] for row in ep_rows]
    if max_len is not None:
        actions = actions[:max_len]
    return torch.tensor(actions, dtype=torch.float32, device=device)  # (T,2)｜

In [None]:
def fetch_episode(ds, ep_idx: int, max_len: int = None):
    """Return a list of actions for the requested episode_index."""
    # ➊ keep only rows that match the episode
    ep_rows = ds.filter(lambda r: r["episode_index"] == ep_idx)
    # ➋ make sure they’re sorted by frame_index
    ep_rows = sorted(ep_rows, key=itemgetter("frame_index"))
    # ➌ pull the 'action' field and (optionally) truncate to max_len
    actions = [row["action"] for row in ep_rows]
    if max_len is not None:
        actions = actions[:max_len]
    return torch.tensor(actions, dtype=torch.float32, device=device)  # (T,2)｜

In [None]:
def fetch_episode(ds, ep_idx: int, max_len: int = None):
    """Return a list of actions for the requested episode_index."""
    # ➊ keep only rows that match the episode
    ep_rows = ds.filter(lambda r: r["episode_index"] == ep_idx)
    # ➋ make sure they’re sorted by frame_index
    ep_rows = sorted(ep_rows, key=itemgetter("frame_index"))
    # ➌ pull the 'action' field and (optionally) truncate to max_len
    actions = [row["action"] for row in ep_rows]
    if max_len is not None:
        actions = actions[:max_len]
    return torch.tensor(actions, dtype=torch.float32, device=device)  # (T,2)｜

In [62]:
# ---------- 2. VAE round-trip alone ----------
mu, log_var = policy.vae.encode(batch["action"])          # (B, Th, latent_dim)
z           = policy.vae.reparameterize(mu, log_var)
recon       = policy.vae.decode(z)

print("VAE | action -> latent -> action:", batch["action"].shape, "→", z.shape, "→", recon.shape)

RuntimeError: Tensor for argument weight is on cpu but expected on mps

In [None]:
def fetch_episode(ds, ep_idx: int, max_len: int = None):
    """Return a list of actions for the requested episode_index."""
    # ➊ keep only rows that match the episode
    ep_rows = ds.filter(lambda r: r["episode_index"] == ep_idx)
    # ➋ make sure they’re sorted by frame_index
    ep_rows = sorted(ep_rows, key=itemgetter("frame_index"))
    # ➌ pull the 'action' field and (optionally) truncate to max_len
    actions = [row["action"] for row in ep_rows]
    if max_len is not None:
        actions = actions[:max_len]
    return torch.tensor(actions, dtype=torch.float32, device=device)  # (T,2)｜