In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os, json
import numpy as np
import pyarrow.parquet as pq
from tqdm import tqdm

# -------- 6D (first two rows) <-> R (row-orthonormal) --------
def r6d_rows_to_mat(r6: np.ndarray) -> np.ndarray:
    """Turn 6D rotation vectors into 3x3 matrices via Gram-Schmidt.

    The last axis of ``r6`` is read as two 3-vectors (components 0:3 and 3:6).
    The first is normalised, the second is made orthogonal to it and then
    normalised, and the third row is their cross product.

    Args:
        r6: array whose last axis has at least 6 entries; any leading axes
            are treated as batch dimensions.

    Returns:
        Array of shape ``r6.shape[:-1] + (3, 3)`` whose rows are the three
        resulting vectors.
    """
    def _unit(vec):
        # Clamp the norm from below so a (near-)zero vector never divides by 0.
        length = np.linalg.norm(vec, axis=-1, keepdims=True)
        return vec / np.clip(length, 1e-9, None)

    first, second = r6[..., 0:3], r6[..., 3:6]
    row0 = _unit(first)
    # Component of the second vector that lies along row0.
    along_row0 = np.sum(second * row0, axis=-1, keepdims=True) * row0
    row1 = _unit(second - along_row0)
    row2 = np.cross(row0, row1, axis=-1)
    return np.stack((row0, row1, row2), axis=-2)

def mat_to_r6d_rows(R: np.ndarray) -> np.ndarray:
    """Flatten the first two rows of each matrix into one vector.

    Args:
        R: array whose last two axes index (row, column).

    Returns:
        Array with row 0 followed by row 1 along the last axis.
    """
    top_row = R[..., 0, :]
    second_row = R[..., 1, :]
    return np.concatenate((top_row, second_row), axis=-1)

# -------- Read the column and keep the first 29 dims --------
def _stack_column_to_ndarray(col):
    """Convert a column object into a 2D float64 array, skipping ``None`` rows.

    Args:
        col: object exposing ``to_numpy(zero_copy_only=False)`` that yields
            one entry per row.

    Returns:
        ``None`` when no non-``None`` rows remain; otherwise a 2D float64
        array. Rows that are lists/arrays are stacked along axis 0; scalar
        rows become a single-column array.

    Raises:
        ValueError: if the stacked result is not 2-dimensional.
    """
    rows = [row for row in col.to_numpy(zero_copy_only=False) if row is not None]
    if len(rows) == 0:
        return None

    # The first surviving row decides whether entries are vectors or scalars.
    if isinstance(rows[0], (list, np.ndarray)):
        stacked = np.stack([np.asarray(row, dtype=np.float64) for row in rows], axis=0)
    else:
        stacked = np.asarray(rows, dtype=np.float64)[:, None]

    if stacked.ndim != 2:
        raise ValueError(f"Expect 2D array, got shape {stacked.shape}")
    return stacked

def read_state_29(path):
    """Load the configured column from a parquet file as an ``[T, 29]`` array.

    Only the column named by the module-level ``COLUMN`` is read. The first
    29 components of every row are kept, and any row containing a non-finite
    value is dropped (so the returned rows are not guaranteed to be
    consecutive time steps of the file).

    Args:
        path: parquet file to read.

    Returns:
        2D float64 array with 29 columns and only finite values.

    Raises:
        ValueError: if the column is empty or has fewer than 29 components.
    """
    column = pq.read_table(path, columns=[COLUMN])[COLUMN]
    states = _stack_column_to_ndarray(column)
    if states is None or states.size == 0:
        raise ValueError("empty column")

    width = states.shape[1]
    if width < 29:
        raise ValueError(f"dim={width} < 29")

    head = states[:, :29]  # always keep exactly the first 29 components
    finite_rows = np.isfinite(head).all(axis=1)
    return head[finite_rows]

# -------- Action construction (relative xyz and relative rotation + absolute gripper) --------
def actions_from_sequence(seq29: np.ndarray, horizon: int = 40) -> np.ndarray:
    """Build per-step action vectors from a ``[T, 29]`` state sequence.

    For every start index ``t`` in ``0 .. T-horizon-1`` and every offset
    ``k`` in ``1 .. horizon`` one 29-dim action is produced from the pair
    (state[t], state[t+k]):

    * xyz slices (0:3, 9:12, 19:22): ``future - now``.
    * 6D rotation slices (3:9, 12:18, 22:28): both are expanded with
      ``r6d_rows_to_mat``, combined as ``R_future @ R_now^T`` and folded back
      with ``mat_to_r6d_rows``.
    * columns 18 and 28 (labelled gripper in the original comments): the
      future value is copied unchanged.

    Args:
        seq29: array of shape ``[T, 29]``.
        horizon: number of future offsets per start index.

    Returns:
        float array of shape ``[(T - horizon) * horizon, 29]``; an empty
        ``(0, 29)`` float64 array when ``T <= horizon``.
    """
    T, D = seq29.shape
    assert D == 29
    n_windows = max(T - horizon, 0)
    if n_windows == 0:
        return np.empty((0, D), dtype=np.float64)

    starts = np.arange(n_windows, dtype=np.int64)
    offsets = np.arange(1, horizon + 1, dtype=np.int64)

    current = seq29[starts[:, None]]                      # [W,1,29]
    future = seq29[starts[:, None] + offsets[None, :]]    # [W,H,29]
    n_future = future.shape[1]

    # (xyz start, rot6d start, rot6d end, trailing-scalar end or None) for the
    # three consecutive groups of the 29-dim vector.
    layout = (
        (0, 3, 9, None),
        (9, 12, 18, 19),
        (19, 22, 28, 29),
    )

    pieces = []
    for p0, r0, r1, g1 in layout:
        # Position: plain difference, broadcast over the horizon axis.
        pieces.append(future[..., p0:r0] - current[..., p0:r0])

        # Rotation: relative matrix future * now^T, re-encoded as 6D rows.
        rot_now = r6d_rows_to_mat(current[:, 0, r0:r1])                     # [W,3,3]
        rot_fut = r6d_rows_to_mat(future[..., r0:r1].reshape(-1, 6))
        rot_fut = rot_fut.reshape(n_windows, n_future, 3, 3)                # [W,H,3,3]
        rot_now_T = np.transpose(rot_now, (0, 2, 1))[:, None, :, :]
        pieces.append(mat_to_r6d_rows(rot_fut @ rot_now_T))

        # Trailing scalar: absolute future value.
        if g1 is not None:
            pieces.append(future[..., r1:g1])

    actions = np.concatenate(pieces, axis=-1)  # [W,H,29]
    return actions.reshape(-1, 29)

# -------- Merge mean/std/min/max --------
def merge_basic(stats_list):
    """Pool per-group summary statistics into statistics of the union.

    Args:
        stats_list: non-empty list of dicts with keys ``"N"`` (sample count),
            ``"mean"``, ``"std"`` (used as a ``ddof=1`` standard deviation),
            ``"min"`` and ``"max"``; the array-valued entries must share one
            shape.

    Returns:
        Dict with ``"N_total"`` (int) and the pooled ``"mean"``, ``"std"``,
        ``"min"``, ``"max"`` arrays.
    """
    counts = np.array([entry["N"] for entry in stats_list], dtype=np.int64)
    stacked = {
        key: np.stack([entry[key] for entry in stats_list], 0)
        for key in ("mean", "std", "min", "max")
    }
    group_means = stacked["mean"]

    total = int(counts.sum())
    weights = (counts / max(total, 1))[:, None]
    pooled_mean = (weights * group_means).sum(0)

    # Total sum of squares = within-group part + between-group part.
    ss_within = ((counts - 1)[:, None] * (stacked["std"] ** 2)).sum(0)
    ss_between = (counts[:, None] * (group_means - pooled_mean[None, :]) ** 2).sum(0)
    pooled_var = (ss_within + ss_between) / max(total - 1, 1)
    # Guard against tiny negative values caused by rounding.
    pooled_std = np.sqrt(np.maximum(pooled_var, 0.0))

    return {
        "N_total": total,
        "mean": pooled_mean,
        "std":  pooled_std,
        "min":  stacked["min"].min(0),
        "max":  stacked["max"].max(0),
    }

def quantiles_from_hist_1d(counts, edges, q):
    """Estimate the ``q``-quantile from a 1D histogram by linear interpolation.

    Args:
        counts: per-bin counts (length ``B``).
        edges: bin edges (length ``B + 1``).
        q: quantile as a fraction, e.g. ``0.01``.

    Returns:
        Position inside the first bin whose cumulative count reaches
        ``q * total``, interpolated linearly between that bin's edges. If no
        bin reaches the target (e.g. an all-zero histogram) bin 0 is used.
    """
    cumulative = np.cumsum(counts)
    rank = q * max(int(cumulative[-1]), 1)

    reached = np.flatnonzero(cumulative >= rank)
    b = reached[0] if reached.size else 0

    below = cumulative[b - 1] if b > 0 else 0
    # Avoid dividing by zero when the selected bin is empty.
    in_bin = max(cumulative[b] - below, 1)
    lo, hi = edges[b], edges[b + 1]
    return lo + (hi - lo) * ((rank - below) / in_bin)

In [6]:
# ---------------- Main flow: single-file output (state + action) ----------------
def main():
    """Compute state/action normalisation statistics and write them as JSON.

    Reads the module-level settings ``DATA_DIR``, ``FUTURE_HORIZON``,
    ``USE_HIST_FOR_QUANTILES``, ``HIST_BINS`` and ``OUT_JSON``.

    Steps:
        1. Recursively collect every ``*.parquet`` file under ``DATA_DIR``
           (sorted so the accumulation order is reproducible).
        2. Pass 1: per file, compute mean/std/min/max of the state array and
           of the actions derived from it, then pool them with
           ``merge_basic``.
        3. Pass 2: re-read the files and accumulate per-dimension histograms
           between the pooled min and max to estimate the 1% and 99%
           quantiles.
        4. Write everything to ``OUT_JSON`` under the ``"norm_stats"`` key.

    A file that raises while being processed is reported with a ``[WARN]``
    line and skipped; it does not abort the run.

    Raises:
        NotImplementedError: if ``USE_HIST_FOR_QUANTILES`` is false.
    """
    files = []
    for root, _, names in os.walk(DATA_DIR):
        files.extend(os.path.join(root, name) for name in names if name.endswith(".parquet"))
    # os.walk order depends on the filesystem; sort for reproducible sums.
    files.sort()
    print(f"Found {len(files)} parquet files")
    if not files:
        return

    def summarize(X):
        """Per-file mean/std/min/max of a 2D array, in merge_basic's format."""
        n = X.shape[0]
        return {
            "N": n,
            "mean": X.mean(0),
            # The sample std is undefined for a single row; use zeros instead.
            "std": X.std(0, ddof=1) if n > 1 else np.zeros(X.shape[1]),
            "min": X.min(0),
            "max": X.max(0),
        }

    # Pass 1: per-file basic statistics for state and action.
    state_stats, action_stats = [], []
    for f in tqdm(files, desc="Pass1: basic stats (state & action)"):
        try:
            S = read_state_29(f)  # [T,29]
            if S.size == 0:
                continue
            state_stats.append(summarize(S))

            if len(S) > FUTURE_HORIZON:
                A = actions_from_sequence(S, FUTURE_HORIZON)  # [N,29]
                if A.size > 0:
                    action_stats.append(summarize(A))
        except Exception as e:
            # Best effort: report the broken file and keep going.
            print(f"[WARN] {f}: {e}")

    if not state_stats:
        print("No state stats computed."); return
    if not action_stats:
        print("No action stats computed."); return

    merged_state  = merge_basic(state_stats)
    merged_action = merge_basic(action_stats)

    # Pass 2: histogram-based q01/q99 (state and action).
    def run_hist_quantiles(files, is_action: bool, lo, hi):
        """Accumulate per-dimension histograms over all files; return (q01, q99)."""
        D = len(lo)
        all_counts = np.zeros((D, HIST_BINS), dtype=np.int64)
        # Pad the range slightly so the extreme values fall inside the bins.
        edges = [np.linspace(lo[d]-1e-9, hi[d]+1e-9, HIST_BINS+1) for d in range(D)]
        desc = "Pass2: hist (action)" if is_action else "Pass2: hist (state)"
        for f in tqdm(files, desc=desc):
            try:
                S = read_state_29(f)
                if S.size == 0:
                    continue
                X = actions_from_sequence(S, FUTURE_HORIZON) if is_action else S
                if X.size == 0:
                    continue
                for d in range(D):
                    c, _ = np.histogram(X[:, d], bins=edges[d])
                    all_counts[d] += c
            except Exception as e:
                print(f"[WARN] {f}: {e}")
        q01 = np.empty(D); q99 = np.empty(D)
        for d in range(D):
            q01[d] = quantiles_from_hist_1d(all_counts[d], edges[d], 0.01)
            q99[d] = quantiles_from_hist_1d(all_counts[d], edges[d], 0.99)
        return q01, q99

    if USE_HIST_FOR_QUANTILES:
        q01_s, q99_s = run_hist_quantiles(files, False, merged_state["min"], merged_state["max"])
        q01_a, q99_a = run_hist_quantiles(files, True,  merged_action["min"], merged_action["max"])
    else:
        raise NotImplementedError("Enable USE_HIST_FOR_QUANTILES for accurate q01/q99.")

    def pack(merged, q01, q99):
        """JSON-serialisable view of one statistics group."""
        return {
            "mean": merged["mean"].tolist(),
            "std":  merged["std"].tolist(),
            "min":  merged["min"].tolist(),
            "max":  merged["max"].tolist(),
            "q01":  q01.tolist(),
            "q99":  q99.tolist(),
        }

    # Single output file. dirname() is "" for a bare file name, and
    # os.makedirs("") raises, so only create the directory when there is one.
    out_dir = os.path.dirname(OUT_JSON)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    payload = {
        "norm_stats": {
            "state": pack(merged_state, q01_s, q99_s),
            "action": pack(merged_action, q01_a, q99_a),
        }
    }
    with open(OUT_JSON, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"[DONE] Both stats -> {OUT_JSON}")

In [7]:

#================= Configuration =================
# Root directory that main() walks recursively looking for *.parquet files.
DATA_DIR = "/kpfs-regular/share_space/data/lerobot_data_aliyun/s1_data/diversity_partial"
# Name of the parquet column loaded by read_state_29().
COLUMN   = "cartesian_so3_dict.cartesian_pose_state"
# COLUMN = ['cartesian_so3_dict.cartesian_pose_state', 'cartesian_so3_dict.cartesian_pose_command']

# Path of the JSON file main() writes the merged statistics to.
OUT_JSON = "/kpfs-cognition/waikei/codes/openpi-uncle/assets/diversity_partial/norm_stats_both.json"

# Number of future steps per start index passed to actions_from_sequence().
FUTURE_HORIZON = 40
# Must stay True: main() raises NotImplementedError when this is False.
USE_HIST_FOR_QUANTILES = True
# Number of histogram bins per dimension used for the q01/q99 estimate.
HIST_BINS = 2048
#=======================================

# Run the whole pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main()


Found 3151 parquet files


Pass1: basic stats (state & action): 100%|██████████| 3151/3151 [02:11<00:00, 24.03it/s] 
Pass2: hist (state): 100%|██████████| 3151/3151 [00:09<00:00, 316.84it/s]
Pass2: hist (action): 100%|██████████| 3151/3151 [02:21<00:00, 22.35it/s] 


[DONE] Both stats -> /kpfs-cognition/waikei/codes/openpi-uncle/assets/diversity_partial/norm_stats_both.json
