In [7]:
# dataset_v2_sharded.py
from __future__ import annotations
import os, json, glob
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from torch.utils.data import Dataset

# -------------------------
# Helpers
# -------------------------

def _load_json(path: str) -> Dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _first_existing(paths: List[str]) -> Optional[str]:
    for p in paths:
        if os.path.exists(p):
            return p
    return None

def _np_dtype(dtype_str: str):
    # metadataに "float32" / "int32" / "uint16" などが来る想定
    return getattr(np, dtype_str) if hasattr(np, dtype_str) else np.dtype(dtype_str)

# -------------------------
# Dataset
# -------------------------

@dataclass
class V2Spec:
    token_grid: Tuple[int, int, int] = (6, 32, 32)  # (T_total, H, W) for one sample
    state_grid: Tuple[int, int] = (64, 25)          # (T_state, state_dim)
    token_dtype: np.dtype = np.int32
    state_dtype: np.dtype = np.float32
    segment_dtype: np.dtype = np.int64              # segment_idx dtype (often int64)
    # If segment_idx stores [start,end) offsets into flattened token stream
    segment_is_pairs: bool = True

class ShardedCompressionV2Dataset(Dataset):
    """
    Train/Val v2.0 Sharded format.

    Expected layout (root/split):
      video_{shard}.bin
      segment_idx_{shard}.bin
      states_{shard}.bin
      metadata.json or metadata_{shard}.json (optional but recommended)

    Returns per sample:
      z: LongTensor (6, 32, 32)  tokens
      s: FloatTensor (64, 25)    states
    """

    def __init__(
        self,
        root_dir: str,
        split: str = "train",  # "train" / "val" / "test_v2.0" etc.
        spec: Optional[V2Spec] = None,
        use_memmap: bool = True,
    ):
        super().__init__()
        self.root = os.path.join(root_dir, split)
        self.spec = spec or V2Spec()
        self.use_memmap = use_memmap

        # ---- metadata (overall + per-shard) ----
        meta_overall_path = _first_existing([
            os.path.join(self.root, "metadata.json"),
            os.path.join(self.root, "metadata_overall.json"),
        ])
        self.meta_overall = _load_json(meta_overall_path) if meta_overall_path else None

        # shard files
        self.video_files = sorted(glob.glob(os.path.join(self.root, "segment_indices/videos/video_*.bin")))
        self.seg_files   = sorted(glob.glob(os.path.join(self.root, "segment_indices/segment_idx_*.bin")))
        self.state_files = sorted(glob.glob(os.path.join(self.root, "robot_states/states_*.bin")))

        if len(self.video_files) == 0:
            raise FileNotFoundError(f"No video_*.bin under {self.root}")

        # test_v2.0 may omit segment_idx_*.bin (as your description says "similar structure")
        self.has_segments = len(self.seg_files) == len(self.video_files) and len(self.seg_files) > 0
        if self.has_segments:
            assert len(self.state_files) == len(self.video_files), "states shard count mismatch"
        else:
            # If no segment_idx, we assume each shard already corresponds to one sample or fixed-length packing.
            # We'll support a fixed-length packing fallback.
            assert len(self.state_files) == len(self.video_files), "states shard count mismatch"

        # ---- Build global index: list of (shard_id, local_sample_id) ----
        self.shard_sample_counts: List[int] = []
        self.segment_arrays: List[np.ndarray] = []
        self.video_mmaps: List[np.memmap | np.ndarray] = []
        self.state_mmaps: List[np.memmap | np.ndarray] = []

        T, H, W = self.spec.token_grid
        tok_per_sample = T * H * W
        St, Sd = self.spec.state_grid
        state_per_sample = St * Sd

        for sid in range(len(self.video_files)):
            # per-shard metadata (optional)
            meta_shard_path = _first_existing([
                os.path.join(self.root, f"metadata_{sid}.json"),
                os.path.join(self.root, f"metadata_{os.path.basename(self.video_files[sid]).split('_')[1].split('.')[0]}.json"),
            ])
            meta_shard = _load_json(meta_shard_path) if meta_shard_path else None

            # dtype overrides from metadata if present
            token_dtype = self.spec.token_dtype
            state_dtype = self.spec.state_dtype
            seg_dtype   = self.spec.segment_dtype

            if meta_shard:
                # Best-effort keys (dataset提供側でキー名が違う可能性があるため)
                if "video_dtype" in meta_shard:
                    token_dtype = _np_dtype(meta_shard["video_dtype"])
                if "states_dtype" in meta_shard:
                    state_dtype = _np_dtype(meta_shard["states_dtype"])
                if "segment_idx_dtype" in meta_shard:
                    seg_dtype = _np_dtype(meta_shard["segment_idx_dtype"])

            # memmap open
            vpath = self.video_files[sid]
            spath = self.state_files[sid]
            video = np.memmap(vpath, mode="r", dtype=token_dtype) if use_memmap else np.fromfile(vpath, dtype=token_dtype)
            states = np.memmap(spath, mode="r", dtype=state_dtype) if use_memmap else np.fromfile(spath, dtype=state_dtype)

            self.video_mmaps.append(video)
            self.state_mmaps.append(states)

            if self.has_segments:
                segpath = self.seg_files[sid]
                seg = np.memmap(segpath, mode="r", dtype=seg_dtype) if use_memmap else np.fromfile(segpath, dtype=seg_dtype)
                seg = np.asarray(seg)  # safe view
                # We expect pairs (start,end) -> shape (N,2)
                if seg.ndim == 1:
                    if seg.size % 2 == 0:
                        seg = seg.reshape(-1, 2)
                    else:
                        # could be boundary list [b0,b1,...,bN], interpret as start boundaries
                        # We'll convert to pairs
                        boundaries = seg
                        starts = boundaries[:-1]
                        ends = boundaries[1:]
                        seg = np.stack([starts, ends], axis=1)
                self.segment_arrays.append(seg)
                n = seg.shape[0]
            else:
                # fallback: infer sample count by fixed packing
                n = int(video.size // tok_per_sample)
                # optional sanity check on states
                n_states = int(states.size // state_per_sample)
                n = min(n, n_states)

            self.shard_sample_counts.append(n)

        # global prefix sums for O(logN) indexing
        self._prefix = np.cumsum([0] + self.shard_sample_counts).tolist()

        # cache for shapes
        self._tok_per_sample = tok_per_sample
        self._state_per_sample = state_per_sample
        self._T, self._H, self._W = self.spec.token_grid
        self._St, self._Sd = self.spec.state_grid

    def __len__(self):
        return self._prefix[-1]

    def _locate(self, idx: int) -> Tuple[int, int]:
        # binary search over prefix sums
        # find sid such that prefix[sid] <= idx < prefix[sid+1]
        lo, hi = 0, len(self.shard_sample_counts)
        while lo + 1 < hi:
            mid = (lo + hi) // 2
            if self._prefix[mid] <= idx:
                lo = mid
            else:
                hi = mid
        sid = lo
        local = idx - self._prefix[sid]
        return sid, local

    def __getitem__(self, idx: int):
        sid, local = self._locate(idx)

        video = self.video_mmaps[sid]
        states = self.state_mmaps[sid]

        # ---- tokens ----
        if self.has_segments:
            seg = self.segment_arrays[sid]
            start, end = int(seg[local, 0]), int(seg[local, 1])
            flat = video[start:end]
            # Many datasets store exactly one sample worth of tokens per segment
            # Ensure length matches expected (tok_per_sample)
            if flat.size != self._tok_per_sample:
                # If segment spans more, you may need to crop/choose window.
                # For Revontuli compression: expect exactly 6*32*32 tokens.
                flat = flat[: self._tok_per_sample]
            z = flat.reshape(self._T, self._H, self._W)
        else:
            # fixed packing fallback
            off = local * self._tok_per_sample
            z = np.asarray(video[off: off + self._tok_per_sample]).reshape(self._T, self._H, self._W)

        z = torch.from_numpy(np.asarray(z)).long()

        # ---- states ----
        off_s = local * self._state_per_sample
        s = np.asarray(states[off_s: off_s + self._state_per_sample]).reshape(self._St, self._Sd)
        s = torch.from_numpy(s).float()

        return z, s


# --------------- Optional: State index definition for clarity ---------------
STATE_INDEX = {
    0: "HIP_YAW",
    1: "HIP_ROLL",
    2: "HIP_PITCH",
    3: "KNEE_PITCH",
    4: "ANKLE_ROLL",
    5: "ANKLE_PITCH",
    6: "LEFT_SHOULDER_PITCH",
    7: "LEFT_SHOULDER_ROLL",
    8: "LEFT_SHOULDER_YAW",
    9: "LEFT_ELBOW_PITCH",
    10: "LEFT_ELBOW_YAW",
    11: "LEFT_WRIST_PITCH",
    12: "LEFT_WRIST_ROLL",
    13: "RIGHT_SHOULDER_PITCH",
    14: "RIGHT_SHOULDER_ROLL",
    15: "RIGHT_SHOULDER_YAW",
    16: "RIGHT_ELBOW_PITCH",
    17: "RIGHT_ELBOW_YAW",
    18: "RIGHT_WRIST_PITCH",
    19: "RIGHT_WRIST_ROLL",
    20: "NECK_PITCH",
    21: "LEFT_HAND_CLOSURE",
    22: "RIGHT_HAND_CLOSURE",
    23: "LINEAR_VELOCITY",
    24: "ANGULAR_VELOCITY",
}


In [33]:
spec = V2Spec(
    token_grid=(8, 8, 8),
    state_grid=(64, 25),
    token_dtype=np.uint32,
    state_dtype=np.float32,
    segment_dtype=np.int16,
)


In [34]:
ds = ShardedCompressionV2Dataset(
    root_dir="/root/work/data/raw/",
    split="train_v2.0",          # train / val / test_v2.0
    spec=spec,
    use_memmap=True,
)

In [35]:
import os
import numpy as np

split_dir = "/root/work/data/raw/train_v2.0"  # あなたの split フォルダに合わせて
seg_path  = os.path.join(split_dir, "segment_indices/segment_idx_0.bin")
vid_path  = os.path.join(split_dir, "segment_indices/videos/video_0.bin")

print("seg bytes:", os.path.getsize(seg_path))
print("vid bytes:", os.path.getsize(vid_path))

# 先頭 64 byte を生で見る（形式推定に最強）
head = np.fromfile(seg_path, dtype=np.uint8, count=64)
print("seg head bytes:", head.tolist())


seg bytes: 450168
vid bytes: 81358848
seg head bytes: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [22]:
import numpy as np, os

def infer_memmap_dtype_and_offset(path: str):
    size = os.path.getsize(path)
    dtype_candidates = [np.int32, np.uint32, np.int64, np.uint64, np.int16, np.uint16]
    offset_candidates = [0, 4, 8, 16, 32]
    feasible = []
    for off in offset_candidates:
        if size <= off:
            continue
        rem = size - off
        for dt in dtype_candidates:
            if rem % np.dtype(dt).itemsize == 0:
                feasible.append((dt, off, rem // np.dtype(dt).itemsize))
    if not feasible:
        raise ValueError(f"Cannot infer dtype/offset: {path}, size={size}")
    feasible_sorted = sorted(feasible, key=lambda x: (x[2] % 2 != 0, np.dtype(x[0]).itemsize))
    return feasible_sorted[0]

dt, off, n = infer_memmap_dtype_and_offset(seg_path)
print("inferred seg dtype:", dt, "offset:", off, "n_elems:", n)

seg = np.memmap(seg_path, mode="r", dtype=dt, offset=off)
seg = np.asarray(seg)

print("seg first 20 elems:", seg[:20].tolist())

# pairs / boundary の両方を試して見る
if seg.size % 2 == 0:
    pairs = seg.reshape(-1, 2)
    print("pairs[0:5]:\n", pairs[:5])
    lens = pairs[:5,1] - pairs[:5,0]
    print("lens[0:5]:", lens.tolist())
else:
    b = seg
    pairs = np.stack([b[:-1], b[1:]], axis=1)
    print("boundary->pairs[0:5]:\n", pairs[:5])
    lens = pairs[:5,1] - pairs[:5,0]
    print("lens[0:5]:", lens.tolist())

# video 側の要素数（dtypeは今のspecに合わせて）
video = np.memmap(vid_path, mode="r", dtype=np.int32)  # token_dtype が int32 の想定
print("video n_elems:", video.size)
print("max end (first 100):", int(pairs[:100,1].max()))


inferred seg dtype: <class 'numpy.int16'> offset: 0 n_elems: 225084
seg first 20 elems: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pairs[0:5]:
 [[0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]]
lens[0:5]: [0, 0, 0, 0, 0]
video n_elems: 20339712
max end (first 100): 0


In [28]:
import os, numpy as np

seg_path = "/root/work/data/raw/train_v2.0/segment_indices/segment_idx_0.bin"

print("seg bytes:", os.path.getsize(seg_path))

for dt in [np.int32, np.uint32, np.int16, np.uint16]:
    arr = np.memmap(seg_path, mode="r", dtype=dt)
    arr = np.asarray(arr)
    nz = np.flatnonzero(arr)
    print(f"\n--- dtype={dt} ---")
    print("n_elems:", arr.size)
    print("nonzero count:", nz.size)
    if nz.size > 0:
        i = int(nz[0])
        print("first nonzero index:", i, "value:", int(arr[i]))
        print("head 20:", arr[:20].tolist())
        # pairs解釈（start,end）として最初の5つ
        if arr.size % 2 == 0:
            pairs = arr.reshape(-1, 2)
            print("pairs[0:5]:\n", pairs[:5])
            lens = pairs[:5,1] - pairs[:5,0]
            print("lens[0:5]:", lens.tolist())


seg bytes: 450168

--- dtype=<class 'numpy.int32'> ---
n_elems: 112542
nonzero count: 111736
first nonzero index: 806 value: 1
head 20: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pairs[0:5]:
 [[0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]]
lens[0:5]: [0, 0, 0, 0, 0]

--- dtype=<class 'numpy.uint32'> ---
n_elems: 112542
nonzero count: 111736
first nonzero index: 806 value: 1
head 20: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pairs[0:5]:
 [[0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]]
lens[0:5]: [0, 0, 0, 0, 0]

--- dtype=<class 'numpy.int16'> ---
n_elems: 225084
nonzero count: 111736
first nonzero index: 1612 value: 1
head 20: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pairs[0:5]:
 [[0 0]
 [0 0]
 [0 0]
 [0 0]
 [0 0]]
lens[0:5]: [0, 0, 0, 0, 0]

--- dtype=<class 'numpy.uint16'> ---
n_elems: 225084
nonzero count: 111736
first nonzero index: 1612 value: 1
head 20: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pairs[0:5]:
 [[0 0]
 [0 0]
 [0 0]


In [32]:
import os
import numpy as np

# video の dtype は spec に合わせて（仮に int32 とします）
video = np.memmap(vid_path, mode="r", dtype=np.int32)
V = video.size

TOK_PER_SAMPLE = 8 * 8 * 8  # 6144（想定）
print("video elems:", V)

dtype_cands = [np.int64, np.uint64, np.int32, np.uint32, np.int16, np.uint16]
offset_cands = [0, 4, 8, 12, 16, 24, 32, 64, 128, 256, 512]

def score_as_pairs(arr):
    """arr を (N,2) の start/end として解釈した時の “それっぽさ” をスコア化"""
    if arr.size < 10 or arr.size % 2 != 0:
        return -1e9, None
    pairs = arr.reshape(-1, 2).astype(np.int64)
    s = pairs[:,0]
    e = pairs[:,1]
    # 基本条件
    if np.any(e < s):
        return -1e9, pairs
    # 0長は減点
    lens = e - s
    nz = np.mean(lens > 0)
    # 範囲内率
    inb = np.mean((s >= 0) & (e <= V))
    # 単調性（だいたい増えるはず）
    mono = np.mean(np.diff(s) >= 0) if len(s) > 1 else 0.0
    # 期待長 6144 に近い割合
    near = np.mean(np.abs(lens - TOK_PER_SAMPLE) <= 8)  # 多少のズレ許容
    # スコア（重みは経験則）
    score = 3*inb + 2*mono + 2*near + 1*nz
    return score, pairs

def score_as_boundaries(arr):
    """arr を境界点 [b0,b1,...] として解釈して pairs にした時のスコア"""
    if arr.size < 10:
        return -1e9, None
    b = arr.astype(np.int64)
    # 境界は単調増加が自然
    mono = np.mean(np.diff(b) >= 0)
    if mono < 0.9:
        return -1e9, None
    s = b[:-1]
    e = b[1:]
    lens = e - s
    nz = np.mean(lens > 0)
    inb = np.mean((s >= 0) & (e <= V))
    near = np.mean(np.abs(lens - TOK_PER_SAMPLE) <= 8)
    score = 3*inb + 2*mono + 2*near + 1*nz
    pairs = np.stack([s, e], axis=1)
    return score, pairs

best = None

for off in offset_cands:
    for dt in dtype_cands:
        size = os.path.getsize(seg_path)
        if size <= off:
            continue
        rem = size - off
        if rem % np.dtype(dt).itemsize != 0:
            continue

        arr = np.memmap(seg_path, mode="r", dtype=dt, offset=off)
        arr = np.asarray(arr)

        # 先頭が全部同じ、みたいなのは早期スキップ
        if arr.size > 100 and np.all(arr[:100] == arr[0]):
            continue

        sp, pairs_p = score_as_pairs(arr)
        sb, pairs_b = score_as_boundaries(arr)

        for mode, sc, pairs in [("pairs", sp, pairs_p), ("boundaries", sb, pairs_b)]:
            if pairs is None:
                continue
            # 先頭5のレンズを見て「0長ばっか」は落とす
            lens0 = (pairs[:5,1] - pairs[:5,0])
            if np.mean(lens0 == 0) > 0.6:
                continue

            cand = (sc, mode, dt, off, pairs)
            if best is None or cand[0] > best[0]:
                best = cand

print("\n=== BEST CANDIDATE ===")
if best is None:
    print("No plausible interpretation found. (segment file might be compressed/encoded differently)")
else:
    sc, mode, dt, off, pairs = best
    print("score:", sc, "mode:", mode, "dtype:", dt, "offset:", off)
    print("pairs[0:5]:\n", pairs[:5])
    print("lens[0:5]:", (pairs[:5,1] - pairs[:5,0]).tolist())
    print("in-bounds ratio:", np.mean((pairs[:,0] >= 0) & (pairs[:,1] <= V)))
    print("near 6144 ratio:", np.mean(np.abs((pairs[:,1]-pairs[:,0]) - TOK_PER_SAMPLE) <= 8))


video elems: 20339712

=== BEST CANDIDATE ===
No plausible interpretation found. (segment file might be compressed/encoded differently)


In [4]:
import json
import pathlib
import subprocess

import numpy as np

dir_path = pathlib.Path("/root/work/data/raw/train_v2.0")
rank = 0

# load metadata.json
metadata = json.load(open(dir_path / "metadata.json"))
metadata_shard = json.load(open(dir_path / f"metadata_{rank}.json"))

total_frames = metadata_shard["shard_num_frames"]


maps = [
    ("segment_idx", np.int32, []),
    ("states", np.float32, [25]),
]

for m, dtype, shape in maps:
    filename = dir_path / f"{m}_{rank}.bin"
    print("Reading", filename, [total_frames] + shape)
    m_out = np.memmap(filename, dtype=dtype, mode="r", shape=tuple([total_frames] + shape))
    assert m_out.shape[0] == total_frames
    print(m, m_out[:100])

FileNotFoundError: [Errno 2] No such file or directory: '/root/work/data/raw/train_v2.0/metadata_0.json'

In [1]:
# dataset_v2_sharded_fixed.py
from __future__ import annotations
import os, json, glob
from dataclasses import dataclass
from typing import List, Tuple, Optional

import numpy as np
import torch
from torch.utils.data import Dataset

@dataclass
class V2Spec:
    # sample definition
    token_len: int = 8          # 3 past + 3 future
    state_len: int = 64         # Revontuli-style (can set to 6 if you want)
    H: int = 8
    W: int = 8

    # dtypes
    video_dtype = np.int32
    seg_dtype   = np.int32
    state_dtype = np.float32

class ShardedCompressionV2Dataset(Dataset):
    """
    v2.0 sharded format (segment_idx is per-frame segment ID).

    Files per shard:
      video_{rank}.bin       : (total_frames, H, W) int32 tokens
      segment_idx_{rank}.bin : (total_frames,) int32 segment id per frame
      states_{rank}.bin      : (total_frames, 25) float32
      metadata_{rank}.json   : contains shard_num_frames
    """

    def __init__(self, root_dir: str, split: str, spec: Optional[V2Spec] = None, use_memmap: bool = True):
        super().__init__()
        self.root = os.path.join(root_dir, split)
        self.spec = spec or V2Spec()
        self.use_memmap = use_memmap

        # shard file lists
        self.video_files = sorted(glob.glob(os.path.join(self.root, "video_*.bin")))
        self.seg_files   = sorted(glob.glob(os.path.join(self.root, "segment_idx_*.bin")))
        self.state_files = sorted(glob.glob(os.path.join(self.root, "states_*.bin")))
        self.meta_files  = sorted(glob.glob(os.path.join(self.root, "metadata_*.json")))

        assert len(self.video_files) > 0, f"No video_*.bin under {self.root}"
        assert len(self.seg_files) == len(self.video_files), "segment_idx shard count mismatch"
        assert len(self.state_files) == len(self.video_files), "states shard count mismatch"

        # load shard_num_frames per shard
        self.shard_frames: List[int] = []
        for sid in range(len(self.video_files)):
            meta_path = os.path.join(self.root, f"metadata_{sid}.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(f"Missing {meta_path}")
            meta = json.load(open(meta_path, "r"))
            self.shard_frames.append(int(meta["shard_num_frames"]))

        # open memmaps + build index of valid windows
        self.video_mmaps = []
        self.seg_mmaps = []
        self.state_mmaps = []

        self.index: List[Tuple[int, int]] = []  # list of (sid, t0)

        Ttok = self.spec.token_len
        Tst  = self.spec.state_len
        H, W = self.spec.H, self.spec.W

        for sid, total_frames in enumerate(self.shard_frames):
            vpath = self.video_files[sid]
            gpath = self.seg_files[sid]
            spath = self.state_files[sid]

            # memmap with known shapes (from metadata: total_frames)
            video = np.memmap(vpath, mode="r", dtype=self.spec.video_dtype,
                              shape=(total_frames, H, W))
            segid = np.memmap(gpath, mode="r", dtype=self.spec.seg_dtype,
                              shape=(total_frames,))
            states = np.memmap(spath, mode="r", dtype=self.spec.state_dtype,
                               shape=(total_frames, 25))

            self.video_mmaps.append(video)
            self.seg_mmaps.append(segid)
            self.state_mmaps.append(states)

            # ---- build segment boundaries from segid changes ----
            # segments are runs of equal segid
            segid_arr = np.asarray(segid)  # small enough to materialize for one shard
            # boundary indices: where segid changes
            change = np.nonzero(segid_arr[1:] != segid_arr[:-1])[0] + 1
            # segment start indices
            starts = np.concatenate([[0], change])
            # segment end indices (exclusive)
            ends = np.concatenate([change, [total_frames]])

            # For each segment, allow windows t0 such that:
            #   t0 + token_len <= seg_end
            #   t0 + state_len <= seg_end   (to avoid crossing segment boundary)
            # and also within total_frames.
            for a, b in zip(starts, ends):
                seg_len = b - a
                # need at least max(token_len, state_len) frames in segment
                need = max(Ttok, Tst)
                if seg_len < need:
                    continue
                # valid t0 range: [a, b-need]
                for t0 in range(a, b - need + 1):
                    self.index.append((sid, int(t0)))

        assert len(self.index) > 0, "No valid samples found. Maybe state_len too large or H/W wrong?"

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx: int):
        sid, t0 = self.index[idx]
        Ttok = self.spec.token_len
        Tst  = self.spec.state_len

        video = self.video_mmaps[sid]
        states = self.state_mmaps[sid]
        segid = self.seg_mmaps[sid]

        # Safety: ensure we are not crossing segment boundary (should already be guaranteed)
        seg0 = int(segid[t0])
        if int(segid[t0 + max(Ttok, Tst) - 1]) != seg0:
            raise ValueError(f"Window crosses segment boundary unexpectedly: sid={sid}, t0={t0}")

        z = np.asarray(video[t0:t0 + Ttok])         # (6,H,W)
        s = np.asarray(states[t0:t0 + Tst])         # (64,25)
        
        z = torch.from_numpy(np.array(z, copy=True)).long()
        s = torch.from_numpy(np.array(s, copy=True)).float()

        return z, s


In [3]:


spec = V2Spec(
    token_len=8,
    state_len=64,   # 6にしたければここを6へ
    H=8, W=8,
)

ds = ShardedCompressionV2Dataset(
    root_dir="/root/work/data/raw/",
    split="train_v2.0",
    spec=spec,
    use_memmap=True,
)

z, s = ds[0]
print("z:", z.shape, z.dtype, int(z.min()), int(z.max()))
print("s:", s.shape, s.dtype, float(s.mean()), float(s.std()))


AssertionError: No video_*.bin under /root/work/data/raw/train_v2.0

In [9]:
# dataset_v2_flexible.py
from __future__ import annotations
import os, json, glob, re
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
from torch.utils.data import Dataset


@dataclass
class V2SpecFlex:
    token_len: int = 8      # DV8×8×8
    state_len: int = 64
    H: int = 8
    W: int = 8
    state_dim: int = 25

    video_dtype = np.int32
    seg_dtype   = np.int32
    state_dtype = np.float32

    allow_no_segment_idx: bool = True


def _load_json(path: str) -> Dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _find_overall_metadata_json(split_dir: str) -> Optional[str]:
    p = os.path.join(split_dir, "metadata.json")
    return p if os.path.exists(p) else None


def _find_per_shard_metadata(split_dir: str) -> Dict[int, str]:
    cand = []
    cand += glob.glob(os.path.join(split_dir, "metadata", "metadata_*.json"))
    cand += glob.glob(os.path.join(split_dir, "metadata_*.json"))
    out = {}
    for p in cand:
        m = re.search(r"metadata_(\d+)\.json$", p)
        if m:
            out[int(m.group(1))] = p
    return dict(sorted(out.items(), key=lambda x: x[0]))


def _find_videos(split_dir: str) -> Dict[int, str]:
    patterns = [
        os.path.join(split_dir, "segment_indices", "videos", "video_*.bin"),
        os.path.join(split_dir, "videos", "video_*.bin"),
        os.path.join(split_dir, "video_*.bin"),
    ]
    files = []
    for pat in patterns:
        files += glob.glob(pat)

    out = {}
    for p in files:
        m = re.search(r"video_(\d+)\.bin$", p)
        if m:
            out[int(m.group(1))] = p
    return dict(sorted(out.items(), key=lambda x: x[0]))


def _find_states(split_dir: str) -> Dict[int, str]:
    patterns = [
        os.path.join(split_dir, "robot_states", "states_*.bin"),
        os.path.join(split_dir, "states_*.bin"),
    ]
    files = []
    for pat in patterns:
        files += glob.glob(pat)

    out = {}
    for p in files:
        m = re.search(r"states_(\d+)\.bin$", p)
        if m:
            out[int(m.group(1))] = p
    return dict(sorted(out.items(), key=lambda x: x[0]))


def _find_segment_idx(split_dir: str) -> Dict[int, str]:
    patterns = [
        os.path.join(split_dir, "segment_indices", "segment_idx_*.bin"),
        os.path.join(split_dir, "segment_idx_*.bin"),
        os.path.join(split_dir, "segment_idx_*.json"),
    ]
    files = []
    for pat in patterns:
        files += glob.glob(pat)

    out = {}
    for p in files:
        m = re.search(r"segment_idx_(\d+)\.(?:bin|json)$", p)
        if m:
            out[int(m.group(1))] = p
    return dict(sorted(out.items(), key=lambda x: x[0]))


def _load_segment_idx(path: str, total_frames: int, dtype=np.int32) -> np.ndarray:
    if path.endswith(".bin"):
        seg = np.memmap(path, mode="r", dtype=dtype, shape=(total_frames,))
        return np.asarray(seg)
    elif path.endswith(".json"):
        obj = _load_json(path)
        if isinstance(obj, list):
            seg = np.array(obj, dtype=dtype)
        elif isinstance(obj, dict) and "segment_idx" in obj:
            seg = np.array(obj["segment_idx"], dtype=dtype)
        else:
            raise ValueError(f"Unknown segment_idx json format: {path}")
        if seg.shape[0] != total_frames:
            raise ValueError("segment_idx length mismatch")
        return seg
    else:
        raise ValueError(f"Unsupported segment_idx file: {path}")

def _infer_total_frames_from_file_sizes(
    video_path: str,
    states_path: str,
    H: int,
    W: int,
    state_dim: int,
    video_dtype=np.int32,
    state_dtype=np.float32,
) -> int:
    v_bytes = os.path.getsize(video_path)
    s_bytes = os.path.getsize(states_path)

    v_item = np.dtype(video_dtype).itemsize
    s_item = np.dtype(state_dtype).itemsize

    # video: (T, H, W)
    denom_v = H * W * v_item
    if v_bytes % denom_v != 0:
        raise ValueError(f"video size not divisible: {video_path} bytes={v_bytes}, denom={denom_v}")
    Tv = v_bytes // denom_v

    # states: (T, state_dim)
    denom_s = state_dim * s_item
    if s_bytes % denom_s != 0:
        raise ValueError(f"states size not divisible: {states_path} bytes={s_bytes}, denom={denom_s}")
    Ts = s_bytes // denom_s

    if Tv != Ts:
        raise ValueError(f"frame count mismatch: video={Tv}, states={Ts} for shard files:\n{video_path}\n{states_path}")
    return int(Tv)

class ShardedCompressionV2Dataset(Dataset):
    """
    Flexible loader for train_v2.0 / val_v2.0 / test_v2.0 with different folder layouts.
    """
    def __init__(self, root_dir: str, split: str, spec: Optional[V2SpecFlex] = None, use_memmap: bool = True):
        super().__init__()
        self.split_dir = os.path.join(root_dir, split)
        self.spec = spec or V2SpecFlex()
        self.use_memmap = use_memmap

        if not os.path.isdir(self.split_dir):
            raise FileNotFoundError(f"Split dir not found: {self.split_dir}")

        # discover files
        meta_overall = _find_overall_metadata_json(self.split_dir)
        self.meta_overall = _load_json(meta_overall) if meta_overall else None

        self.meta_shard = _find_per_shard_metadata(self.split_dir)
        self.videos = _find_videos(self.split_dir)
        self.states = _find_states(self.split_dir)
        self.segment_idx = _find_segment_idx(self.split_dir)

        # determine shard ids we can load
        # shard ids we can load:
        # - train/val: require metadata_shard
        # - test: allow no metadata_shard (infer from file sizes)
        common_vs = sorted(set(self.videos.keys()) & set(self.states.keys()))

        if len(common_vs) == 0:
            raise ValueError("No matching shard ids between video and states files.")

        # Prefer shards that have metadata, but allow shards without it (test)
        self.shard_ids = common_vs

        # store total_frames per shard
        self.total_frames_by_shard = {}

        for sid in self.shard_ids:
            meta_path = self.meta_shard.get(sid, None)
            if meta_path is not None:
                meta = _load_json(meta_path)
                total_frames = int(meta.get("shard_num_frames", meta.get("num_frames", -1)))
                if total_frames <= 0:
                    raise ValueError(f"metadata_{sid}.json missing shard_num_frames: {meta_path}")
            else:
                # ★ metadata 無し（test想定）：ファイルサイズから推定
                total_frames = _infer_total_frames_from_file_sizes(
                    video_path=self.videos[sid],
                    states_path=self.states[sid],
                    H=self.spec.H,
                    W=self.spec.W,
                    state_dim=self.spec.state_dim,
                    video_dtype=self.spec.video_dtype,
                    state_dtype=self.spec.state_dtype,
                )

            self.total_frames_by_shard[sid] = total_frames

        # open memmaps + build index
        self.video_mmaps = {}
        self.state_mmaps = {}
        self.seg_arrays = {}  # may be missing in test
        self.index: List[Tuple[int, int]] = []  # (shard_id, t0)

        Ttok = self.spec.token_len
        Tst  = self.spec.state_len
        H, W = self.spec.H, self.spec.W
        Sd   = self.spec.state_dim

        for sid in self.shard_ids:
            meta = _load_json(self.meta_shard[sid])
            total_frames = self.total_frames_by_shard[sid]
            if total_frames <= 0:
                raise ValueError(f"metadata_{sid}.json missing shard_num_frames: {self.meta_shard[sid]}")

            # video
            vpath = self.videos[sid]
            video = np.memmap(vpath, mode="r", dtype=self.spec.video_dtype, shape=(total_frames, H, W))
            self.video_mmaps[sid] = video

            # states
            spath = self.states[sid]
            states = np.memmap(spath, mode="r", dtype=self.spec.state_dtype, shape=(total_frames, Sd))
            self.state_mmaps[sid] = states

            # segment_idx (optional in test)
            if sid in self.segment_idx:
                seg = _load_segment_idx(self.segment_idx[sid], total_frames=total_frames, dtype=self.spec.seg_dtype)
                self.seg_arrays[sid] = seg
                # build segment boundaries from seg id changes
                change = np.nonzero(seg[1:] != seg[:-1])[0] + 1
                starts = np.concatenate([[0], change])
                ends   = np.concatenate([change, [total_frames]])
                need = max(Ttok, Tst)
                for a, b in zip(starts, ends):
                    if (b - a) < need:
                        continue
                    for t0 in range(int(a), int(b - need + 1)):
                        self.index.append((sid, t0))
            else:
                if not self.spec.allow_no_segment_idx:
                    continue
                # fallback: treat whole shard as one segment
                need = max(Ttok, Tst)
                if total_frames < need:
                    continue
                for t0 in range(0, total_frames - need + 1):
                    self.index.append((sid, int(t0)))

        if len(self.index) == 0:
            raise ValueError("No valid samples built. Maybe token_len/state_len too large or shapes wrong.")

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx: int):
        sid, t0 = self.index[idx]
        Ttok = self.spec.token_len
        Tst  = self.spec.state_len

        video = self.video_mmaps[sid]
        states = self.state_mmaps[sid]

        # Safety: if segment_idx exists, ensure window stays in same segment
        if sid in self.seg_arrays:
            seg = self.seg_arrays[sid]
            seg0 = int(seg[t0])
            if int(seg[t0 + max(Ttok, Tst) - 1]) != seg0:
                raise ValueError(f"Crossed segment boundary: shard={sid} t0={t0}")

        z = np.asarray(video[t0:t0 + Ttok])          # (Ttok,H,W)
        s = np.asarray(states[t0:t0 + Tst])          # (Tst,25)

        # avoid non-writable memmap warning
        z = torch.from_numpy(z.copy()).long()
        s = torch.from_numpy(s.copy()).float()
        return z, s


In [18]:
spec = V2SpecFlex(
    token_len=8,  # DV8×8×8
    state_len=64,
    H=8, W=8,
)

train_ds = ShardedCompressionV2Dataset("/root/work/data/raw", "train_v2.0", spec=spec)
val_ds   = ShardedCompressionV2Dataset("/root/work/data/raw", "val_v2.0",   spec=spec)
test_ds  = ShardedCompressionV2Dataset("/root/work/data/raw", "test_v2.0",  spec=spec)  # segment_idx無しでも動く

z, s = train_ds[0]
print("train sample:", z.shape, s.shape)

ValueError: frame count mismatch: video=48, states=64 for shard files:
/root/work/data/raw/test_v2.0/videos/video_0.bin
/root/work/data/raw/test_v2.0/robot_states/states_0.bin

In [14]:
import os, numpy as np

vpath = "/root/work/data/raw/test_v2.0/videos/video_0.bin"
spath = "/root/work/data/raw/test_v2.0/robot_states/states_0.bin"

v_bytes = os.path.getsize(vpath)
s_bytes = os.path.getsize(spath)

print("video bytes:", v_bytes, "itemsize(int32):", np.dtype(np.int32).itemsize)
print("states bytes:", s_bytes, "itemsize(float32):", np.dtype(np.float32).itemsize)

print("video n_int32:", v_bytes // 4, "remainder:", v_bytes % 4)
print("states n_f32 :", s_bytes // 4, "remainder:", s_bytes % 4)

# 期待候補
print("video == 8*8*8 ?", (v_bytes//4) == (8*8*8))
print("states == 64*25 ?", (s_bytes//4) == (64*25))


video bytes: 12288 itemsize(int32): 4
states bytes: 6400 itemsize(float32): 4
video n_int32: 3072 remainder: 0
states n_f32 : 1600 remainder: 0
video == 8*8*8 ? False
states == 64*25 ? True


In [19]:
import os, numpy as np

vpath = "/root/work/data/raw/train_v2.0/segment_indices/videos/video_0.bin"
bytes = os.path.getsize(vpath)
print("bytes:", bytes)
print("int32 elems:", bytes // 4)

# 32*32*3 = 3072
print("div by 3072 ?", (bytes // 4) % 3072 == 0)


bytes: 81358848
int32 elems: 20339712
div by 3072 ? True


In [20]:
import os, json
import numpy as np

split_dir = "/root/work/data/raw/train_v2.0"
rank = 0

# train の構造に合わせてパスを調整（あなたの構造: segment_indices/videos/ と robot_states/ と metadata/）
meta_path = os.path.join(split_dir, "metadata", f"metadata_{rank}.json")
seg_path  = os.path.join(split_dir, "segment_indices", f"segment_idx_{rank}.bin")
vid_path  = os.path.join(split_dir, "segment_indices", "videos", f"video_{rank}.bin")
st_path   = os.path.join(split_dir, "robot_states", f"states_{rank}.bin")

meta = json.load(open(meta_path))
T = int(meta["shard_num_frames"])
print("total_frames:", T)

video = np.memmap(vid_path, mode="r", dtype=np.int32,   shape=(T, 32, 32))
segid = np.memmap(seg_path, mode="r", dtype=np.int32,   shape=(T,))
state = np.memmap(st_path,  mode="r", dtype=np.float32, shape=(T, 25))

print("video shape:", video.shape, "dtype:", video.dtype)
print("segid shape:", segid.shape, "dtype:", segid.dtype)
print("state shape:", state.shape, "dtype:", state.dtype)

# 先頭10の segid（変化しているか）
print("segid head:", np.asarray(segid[:20]).tolist())

# 値域ざっくり
v0 = np.asarray(video[0])
print("video[0] min/max:", int(v0.min()), int(v0.max()))
print("state[0] mean/std:", float(state[0].mean()), float(state[0].std()))


total_frames: 112542


ValueError: mmap length is greater than file size

In [21]:
import os, json, numpy as np

split_dir = "/root/work/data/raw/train_v2.0"
rank = 0

meta_path = os.path.join(split_dir, "metadata", f"metadata_{rank}.json")
seg_path  = os.path.join(split_dir, "segment_indices", f"segment_idx_{rank}.bin")
vid_path  = os.path.join(split_dir, "segment_indices", "videos", f"video_{rank}.bin")
st_path   = os.path.join(split_dir, "robot_states", f"states_{rank}.bin")

meta = json.load(open(meta_path))
T = int(meta["shard_num_frames"])
print("T(meta) =", T)

vid_bytes = os.path.getsize(vid_path)
seg_bytes = os.path.getsize(seg_path)
st_bytes  = os.path.getsize(st_path)

print("video bytes:", vid_bytes)
print("segment_idx bytes:", seg_bytes)
print("states bytes:", st_bytes)

# sanity: segment_idx and states are per-frame arrays in v2.0
print("seg elems int32:", seg_bytes // 4, "remainder:", seg_bytes % 4)
print("state elems f32 :", st_bytes // 4,  "remainder:", st_bytes % 4)

need_vid_32 = T * 32 * 32 * np.dtype(np.int32).itemsize
need_vid_8  = T * 8  * 8  * np.dtype(np.int32).itemsize
print("need video bytes if (T,32,32):", need_vid_32)
print("need video bytes if (T, 8, 8):", need_vid_8)


T(meta) = 112542
video bytes: 81358848
segment_idx bytes: 450168
states bytes: 11254200
seg elems int32: 112542 remainder: 0
state elems f32 : 2813550 remainder: 0
need video bytes if (T,32,32): 460972032
need video bytes if (T, 8, 8): 28810752


In [22]:
import numpy as np, os

vid_elems = os.path.getsize(vid_path) // np.dtype(np.int32).itemsize
print("video int32 elems:", vid_elems)
print("video elems / T(meta):", vid_elems / T)

# よくある候補を当てる
cands = {
    "32x32": 32*32,
    "3x32x32": 3*32*32,
    "8x8": 8*8,
    "8x8x8": 8*8*8,  # 512
    "6x32x32": 6*32*32,
}
for name, k in cands.items():
    ok = (vid_elems % k == 0)
    if ok:
        t_infer = vid_elems // k
        print(f"divisible by {name} ({k}) -> inferred T={t_infer}")


video int32 elems: 20339712
video elems / T(meta): 180.7299674788079
divisible by 32x32 (1024) -> inferred T=19863
divisible by 3x32x32 (3072) -> inferred T=6621
divisible by 8x8 (64) -> inferred T=317808
divisible by 8x8x8 (512) -> inferred T=39726


In [23]:
import os, numpy as np, json

split_dir = "/root/work/data/raw/train_v2.0"
rank = 0
meta = json.load(open(f"{split_dir}/metadata/metadata_{rank}.json"))
T = int(meta["shard_num_frames"])

vid_path = f"{split_dir}/segment_indices/videos/video_{rank}.bin"
seg_path = f"{split_dir}/segment_indices/segment_idx_{rank}.bin"

vid_elems = os.path.getsize(vid_path)//4
N = vid_elems//512
print("T(frames):", T)
print("video elems:", vid_elems, "=> N_cubes(if 8x8x8):", N, "remainder:", vid_elems % 512)

seg = np.memmap(seg_path, mode="r", dtype=np.int32, shape=(T,))
seg = np.asarray(seg)
print("seg min/max:", int(seg.min()), int(seg.max()))
print("unique seg count:", int(np.unique(seg).size))
print("seg max < N_cubes ?", int(seg.max()) < N)


T(frames): 112542
video elems: 20339712 => N_cubes(if 8x8x8): 39726 remainder: 0
seg min/max: 0 520
unique seg count: 521
seg max < N_cubes ? True
