# Mini Dataset Episode Sampler

该 notebook 用于从 `demo_data_language` 数据集中按不同随机种子抽取小规模 episode 子集（每个随机种子生成一个独立的子集），便于快速实验 SmolVLA 在小批量数据下的训练表现。运行之前请确认原始数据集完整存在，即 `demo_data_language` 下包含 `data/chunk-000`、`images` 与 `meta` 三个目录。

In [1]:
"""采样配置"""
from pathlib import Path
import json
import random
import shutil
from copy import deepcopy
from datetime import datetime

# Number of episodes contained in each sampled subset
EPISODES_PER_SUBSET = 15

# Random seeds to use (one subset is sampled per seed)
RANDOM_SEEDS = [0, 1, 2]

# Dataset path configuration: where episodes are read from and where the
# per-seed subsets are written to
SOURCE_ROOT = Path("demo_data_language")
TARGET_ROOT = Path("demo_data_language_subsets_nbr15")

# Whether an already-existing target subset directory may be overwritten
OVERWRITE_EXISTING = True

# Whether to renumber the sampled episodes instead of keeping their original
# indices (keeping the originals makes it easier to cross-reference the source).
# NOTE(review): the previous comment described "keep original numbering" as the
# default, but the value configured here is True, i.e. renumbering is enabled.
RENUMBER_EPISODES = True


In [2]:
"""生成小规模 episode 子集"""

import json
import random
import shutil
from copy import deepcopy
from datetime import datetime
from math import sqrt
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq


def load_jsonl(path: Path):
    """Parse a JSON Lines file into a list of decoded objects.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, every other line is decoded with
    ``json.loads``.

    Args:
        path: Location of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        return [json.loads(text) for text in stripped_lines if text]


def write_jsonl(path: Path, entries):
    """Serialize ``entries`` to ``path`` as JSON Lines.

    The file is opened for writing as UTF-8 (replacing any previous content)
    and receives one JSON document per entry, each followed by a newline.
    Non-ASCII characters are written verbatim rather than escaped.

    Args:
        path: Destination file.
        entries: Iterable of JSON-serializable values.
    """
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            f"{json.dumps(record, ensure_ascii=False)}\n" for record in entries
        )


def ensure_dir(path: Path):
    """Make sure ``path`` exists as a directory, creating parents as needed.

    Calling this on an existing directory is a no-op.

    Args:
        path: Directory that must exist after the call.

    Raises:
        FileExistsError: If ``path`` already exists but is not a directory.
    """
    # mkdir(exist_ok=True) already tolerates an existing directory, so no
    # separate exists() probe is needed; skipping it also avoids the window
    # between check and creation, and surfaces the case where the path is
    # occupied by a regular file instead of silently accepting it.
    path.mkdir(parents=True, exist_ok=True)


# Locations inside the source dataset that every subset is assembled from.
SOURCE_DATA_DIR = SOURCE_ROOT.joinpath("data", "chunk-000")
SOURCE_IMAGE_DIRS = {
    camera_key: SOURCE_ROOT / "images" / camera_key
    for camera_key in ("observation.image", "observation.wrist_image")
}
SOURCE_META_DIR = SOURCE_ROOT.joinpath("meta")
BLOCK_POSE_LOG_PATH = SOURCE_ROOT.joinpath(
    "remove_red_block_from_plate_UR5_smolvla_mujoco", "block_pose_log.json"
)

# Fail early when the source dataset directory is absent.
if not SOURCE_ROOT.exists():
    raise FileNotFoundError(f"未找到原始数据集目录: {SOURCE_ROOT}")

# Episode indices are parsed from the parquet file names (episode_<n>.parquet).
all_episode_indices = [
    int(parquet_path.stem.split("_")[1])
    for parquet_path in SOURCE_DATA_DIR.glob("episode_*.parquet")
]
all_episode_indices.sort()
if EPISODES_PER_SUBSET > len(all_episode_indices):
    raise ValueError("原始数据集中的 episode 数量不足以完成采样，请检查 EPISODES_PER_SUBSET 配置。")

# Per-episode metadata and statistics, keyed by episode index for lookup.
episodes_meta = load_jsonl(SOURCE_META_DIR / "episodes.jsonl")
episodes_meta_map = {record["episode_index"]: record for record in episodes_meta}

episodes_stats = load_jsonl(SOURCE_META_DIR / "episodes_stats.jsonl")
episodes_stats_map = {record["episode_index"]: record for record in episodes_stats}

# The block pose log is optional; the lookup stays empty when the file is absent.
block_pose_lookup = {}
if BLOCK_POSE_LOG_PATH.exists():
    raw_block_pose = json.loads(BLOCK_POSE_LOG_PATH.read_text(encoding="utf-8"))
    for item in raw_block_pose:
        block_pose_lookup[int(item["episode"])] = item

# info.json is parsed so it can be adjusted per subset; tasks.jsonl is kept as
# raw text because it is written back unchanged.
info_template = json.loads(SOURCE_META_DIR.joinpath("info.json").read_text(encoding="utf-8"))
tasks_payload = SOURCE_META_DIR.joinpath("tasks.jsonl").read_text(encoding="utf-8")

# One summary record per generated subset; evaluated as the last expression of
# the cell so the notebook displays it.
summary = []

for seed in RANDOM_SEEDS:
    # A private generator produces the same draw as seeding the module-level
    # RNG with this seed, without overwriting process-wide random state.
    rng = random.Random(seed)
    chosen = sorted(rng.sample(all_episode_indices, EPISODES_PER_SUBSET))

    subset_name = f"seed_{seed:03d}"
    dest_root = TARGET_ROOT / subset_name

    dest_exists = dest_root.exists()
    if dest_exists and not OVERWRITE_EXISTING:
        raise FileExistsError(f"目标目录已存在: {dest_root}，若需覆盖请将 OVERWRITE_EXISTING 设为 True")

    # Check every input of the sampled episodes BEFORE deleting or writing
    # anything, so an incomplete source dataset cannot destroy a previously
    # generated subset and leave a half-written one behind.
    missing_meta = [idx for idx in chosen if idx not in episodes_meta_map]
    if missing_meta:
        raise KeyError(f"episodes.jsonl 中缺少以下 episode 的元数据: {missing_meta}")
    missing_stats = [idx for idx in chosen if idx not in episodes_stats_map]
    if missing_stats:
        raise KeyError(f"episodes_stats.jsonl 中缺少以下 episode 的统计信息: {missing_stats}")
    missing_image_dirs = [
        str(src_root / f"episode_{idx:06d}")
        for src_root in SOURCE_IMAGE_DIRS.values()
        for idx in chosen
        if not (src_root / f"episode_{idx:06d}").exists()
    ]
    if missing_image_dirs:
        raise FileNotFoundError(f"缺少以下 episode 图像目录: {missing_image_dirs}")

    if dest_exists:
        shutil.rmtree(dest_root)

    dest_chunk_dir = dest_root / "data" / "chunk-000"
    dest_meta_dir = dest_root / "meta"
    # Destination image directories mirror the keys of SOURCE_IMAGE_DIRS so the
    # two can never drift apart.
    dest_image_dirs = {key: dest_root / "images" / key for key in SOURCE_IMAGE_DIRS}

    for path in [dest_chunk_dir, dest_meta_dir, *dest_image_dirs.values()]:
        ensure_dir(path)

    new_episodes_entries = []
    new_stats_entries = []
    new_block_pose_entries = []
    total_frames = 0

    # Running count of frames already placed in this subset; it is the value
    # assigned to the first row of the next episode's "index" column.
    global_frame_index = 0

    for new_idx, orig_idx in enumerate(chosen):
        target_idx = new_idx if RENUMBER_EPISODES else orig_idx
        frame_start = global_frame_index

        # --- frame data (parquet) ---
        src_data_file = SOURCE_DATA_DIR / f"episode_{orig_idx:06d}.parquet"
        dst_data_file = dest_chunk_dir / f"episode_{target_idx:06d}.parquet"
        shutil.copy2(src_data_file, dst_data_file)

        table = pq.read_table(dst_data_file)
        num_rows = table.num_rows
        if RENUMBER_EPISODES:
            # Rewrite the two index columns (when present) so they match the
            # new episode number and the subset-wide frame numbering; the
            # original column types are kept.
            if "episode_index" in table.schema.names:
                idx = table.schema.get_field_index("episode_index")
                table = table.set_column(
                    idx,
                    "episode_index",
                    pa.array([target_idx] * num_rows, type=table.column(idx).type),
                )
            if "index" in table.schema.names:
                idx = table.schema.get_field_index("index")
                new_index = pa.array(
                    range(frame_start, frame_start + num_rows),
                    type=table.column(idx).type,
                )
                table = table.set_column(idx, "index", new_index)
            pq.write_table(table, dst_data_file)
        global_frame_index += num_rows

        # --- images ---
        for key, src_root in SOURCE_IMAGE_DIRS.items():
            src_dir = src_root / f"episode_{orig_idx:06d}"
            dst_dir = dest_image_dirs[key] / f"episode_{target_idx:06d}"
            shutil.copytree(src_dir, dst_dir)

        # --- episodes.jsonl entry ---
        episode_entry = deepcopy(episodes_meta_map[orig_idx])
        # Prefer the recorded episode length; fall back to the parquet row count.
        total_frames += int(episode_entry.get("length", num_rows))
        episode_entry["episode_index"] = target_idx
        new_episodes_entries.append(episode_entry)

        # --- episodes_stats.jsonl entry ---
        stats_entry = deepcopy(episodes_stats_map[orig_idx])
        stats_entry["episode_index"] = target_idx
        stats_payload = stats_entry.get("stats", {})
        episode_stats = stats_payload.get("episode_index")
        if episode_stats:
            # Every row carries the same episode number, hence zero spread.
            count = episode_stats.get("count", [num_rows])
            stats_payload["episode_index"] = {
                "min": [target_idx],
                "max": [target_idx],
                "mean": [float(target_idx)],
                "std": [0.0],
                "count": count,
            }
        index_stats = stats_payload.get("index")
        if RENUMBER_EPISODES and index_stats:
            # The new "index" values are the consecutive integers
            # frame_start .. frame_start + num_rows - 1, whose mean is the
            # midpoint and whose population variance is (n^2 - 1) / 12.
            start = frame_start
            end = start + num_rows - 1
            mean = start + (num_rows - 1) / 2 if num_rows > 0 else start
            std = sqrt((num_rows ** 2 - 1) / 12) if num_rows > 1 else 0.0
            stats_payload["index"] = {
                "min": [start],
                "max": [end],
                "mean": [mean],
                "std": [std],
                "count": index_stats.get("count", [num_rows]),
            }
        new_stats_entries.append(stats_entry)

        # --- optional block pose log entry ---
        pose_entry = block_pose_lookup.get(orig_idx)
        if pose_entry:
            pose_copy = deepcopy(pose_entry)
            pose_copy["episode"] = target_idx if RENUMBER_EPISODES else pose_entry["episode"]
            # Keep a pointer back to the episode number in the source dataset.
            pose_copy["source_episode"] = pose_entry["episode"]
            new_block_pose_entries.append(pose_copy)

    write_jsonl(dest_meta_dir / "episodes.jsonl", new_episodes_entries)
    write_jsonl(dest_meta_dir / "episodes_stats.jsonl", new_stats_entries)
    (dest_meta_dir / "tasks.jsonl").write_text(tasks_payload, encoding="utf-8")

    # info.json: copy of the source file with the totals adjusted and
    # provenance fields describing how the subset was produced.
    info = deepcopy(info_template)
    info["total_episodes"] = len(chosen)
    info["total_frames"] = total_frames
    info["splits"] = {"train": f"0:{len(chosen)}"}
    info["subset_seed"] = seed
    info["subset_source_episodes"] = chosen
    info["generated_at"] = datetime.now().isoformat(timespec="seconds")
    info["renumbered"] = RENUMBER_EPISODES
    (dest_meta_dir / "info.json").write_text(json.dumps(info, ensure_ascii=False, indent=2), encoding="utf-8")

    if new_block_pose_entries:
        (dest_root / "block_pose_log.json").write_text(
            json.dumps(new_block_pose_entries, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

    summary.append({"subset": subset_name, "episodes": chosen, "total_frames": total_frames})

summary



[{'subset': 'seed_000',
  'episodes': [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 15, 17, 19],
  'total_frames': 3193},
 {'subset': 'seed_001',
  'episodes': [0, 1, 2, 3, 4, 6, 7, 8, 10, 12, 13, 14, 15, 16, 18],
  'total_frames': 3136},
 {'subset': 'seed_002',
  'episodes': [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 18],
  'total_frames': 3203}]