Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions configs/examples/segments.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"0": [
[
5,
15
],
[
10,
20
]
],
"1": [
[
0,
30
]
]
}
55 changes: 38 additions & 17 deletions src/opentau/datasets/compute_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,40 +192,61 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
}


def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
def compute_episode_stats(
episode_data: dict[str, list[str] | np.ndarray],
features: dict,
skip_video_stats: bool = False,
) -> dict:
"""Compute statistics for a single episode.

For image/video features, samples and downsamples images before computing stats.
For image/video features, samples and downsamples images before computing stats
(unless skip_video_stats is True, in which case placeholder stats are used).
For other features, computes stats directly on the array data.

Args:
episode_data: Dictionary mapping feature names to their data (arrays or image paths).
features: Dictionary of feature specifications with 'dtype' keys.
skip_video_stats: If True, do not compute real stats for image/video features;
instead use placeholder stats (min=0, max=1, mean=0.5, std=0.5, count from data)
so the output format remains valid.

Returns:
Dictionary mapping feature names to their statistics (min, max, mean, std, count).
Image statistics are normalized to [0, 1] range.
Image statistics are normalized to [0, 1] range (or placeholders when skip_video_stats).
"""
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue # HACK: we should receive np.arrays of strings
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data) # data is a list of image paths
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
if skip_video_stats:
# Placeholder stats: shape (3, 1, 1) for min/max/mean/std, count from length
n_frames = len(data) if isinstance(data, list) else data.shape[0]
shape = features[key]["shape"]
# Expected shape for video is (C, H, W) e.g. (3, H, W)
c = shape[0] if len(shape) >= 3 else 3
ep_stats[key] = {
"min": np.zeros((c, 1, 1), dtype=np.float64),
"max": np.ones((c, 1, 1), dtype=np.float64),
"mean": np.full((c, 1, 1), 0.5, dtype=np.float64),
"std": np.full((c, 1, 1), 0.5, dtype=np.float64),
"count": np.array([n_frames]),
}
else:
image_paths = data.tolist() if isinstance(data, np.ndarray) else data
ep_ft_array = sample_images(image_paths) # image_paths is list[str]
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
# normalize and remove batch dim for images
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
else:
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array

ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)

# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
ep_ft_array = data if isinstance(data, np.ndarray) else np.asarray(data)
axes_to_reduce = (0,) # compute stats over the first axis
keepdims = ep_ft_array.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)

return ep_stats

Expand Down
12 changes: 9 additions & 3 deletions src/opentau/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,7 +1670,9 @@ def save_episode(self, episode_data: dict | None = None) -> None:

self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)
ep_stats = compute_episode_stats(episode_buffer, self.features)
ep_stats = compute_episode_stats(
episode_buffer, self.features, skip_video_stats=getattr(self, "skip_video_stats", False)
)

if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
Expand All @@ -1682,9 +1684,11 @@ def save_episode(self, episode_data: dict | None = None) -> None:

ep_data_index, _ = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
timestamps = np.asarray(episode_buffer["timestamp"]).reshape(-1)
episode_indices = np.full(episode_length, episode_index)
check_timestamps_sync(
episode_buffer["timestamp"],
episode_buffer["episode_index"],
timestamps,
episode_indices,
ep_data_index_np,
self.fps,
self.tolerance_s,
Expand Down Expand Up @@ -1870,6 +1874,7 @@ def create(
image_resample_strategy: str = "nearest",
vector_resample_strategy: str = "nearest",
standardize: bool = True,
skip_video_stats: bool = False,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls)
Expand Down Expand Up @@ -1903,5 +1908,6 @@ def create(
obj.image_resample_strategy = image_resample_strategy
obj.vector_resample_strategy = vector_resample_strategy
obj.standardize = standardize
obj.skip_video_stats = skip_video_stats
obj.episode_data_index, obj.epi2idx = get_episode_data_index(obj.meta.episodes, obj.episodes)
return obj
Loading
Loading