In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import zarr
from lcdp.data.dataset import RobotDataset

%matplotlib inline

## 1. 加载数据集

In [None]:
# Build the demonstration dataset used by every analysis cell below.
data_path = '../data/demonstrations.zarr'

# All RobotDataset options in one place so they are easy to tweak.
# NOTE(review): meaning of horizon vs action_horizon is inferred from the
# names only — confirm against RobotDataset.
dataset_options = {
    'horizon': 16,
    'action_horizon': 16,
    'image_size': (224, 224),
    'normalize_actions': True,
    'file_format': 'zarr',
}

dataset = RobotDataset(data_path=data_path, **dataset_options)

print(f"数据集大小: {len(dataset)} 个样本")

## 2. 可视化观测图像

In [None]:
# Show a 2x4 grid of observation images sampled evenly across the dataset.
# NOTE(review): assumes sample['image'] is a CHW tensor normalised to [-1, 1]
# (inferred from the transpose and the (x + 1) / 2 below) — confirm against RobotDataset.
assert len(dataset) > 0, "dataset is empty; nothing to visualise"

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

# Evenly spaced indices always stay in range; a fixed stride such as i * 10
# raises IndexError on datasets with fewer than 71 samples.
sample_indices = np.linspace(0, len(dataset) - 1, num=len(axes), dtype=int).tolist()

for ax, idx in zip(axes, sample_indices):
    sample = dataset[idx]
    image = sample['image'].numpy().transpose(1, 2, 0)  # CHW -> HWC
    # Map [-1, 1] back to [0, 1]; clip so round-off or out-of-range pixels
    # don't trigger imshow's clipping warning.
    image = np.clip((image + 1) / 2, 0.0, 1.0)

    ax.imshow(image)
    # Use the real dataset index so each panel can be traced back to its sample.
    ax.set_title(f"Sample {idx}: {sample['instruction']}", fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 3. 分析动作分布

In [None]:
# Collect the action chunks of (up to) the first MAX_ANALYSIS_SAMPLES samples;
# the distribution plot and the summary cell below read `all_actions`.
MAX_ANALYSIS_SAMPLES = 1000  # cap the pass: each dataset[i] returns a full sample (image included)

n_analysis = min(MAX_ANALYSIS_SAMPLES, len(dataset))
assert n_analysis > 0, "dataset is empty; nothing to analyse"

# Each chunk is assumed to be [action_dim, horizon] — TODO confirm against RobotDataset.
action_chunks = [dataset[i]['actions'].numpy() for i in range(n_analysis)]

# np.stack requires identical per-sample shapes and reports a mismatch
# directly, instead of np.array's ragged-array error / object array.
all_actions = np.stack(action_chunks)  # [N, action_dim, horizon]
# Downstream cells index all_actions[:, dim, 0], so the array must be 3-D.
assert all_actions.ndim == 3, f"unexpected action array shape {all_actions.shape}"

print(f"收集了 {len(all_actions)} 个动作序列")

In [None]:
# Histogram of every action dimension, using only the first time step of
# each collected action chunk in `all_actions`.
# NOTE(review): the label order (position, rotation, gripper) is assumed to
# match the dataset's action layout — confirm against the data spec.
action_labels = ['x', 'y', 'z', 'roll', 'pitch', 'yaw', 'gripper']

# Fail loudly instead of silently dropping or mis-indexing dimensions when
# the labels and the data disagree.
assert all_actions.shape[1] == len(action_labels), (
    f"expected {len(action_labels)} action dims, got {all_actions.shape[1]}"
)

# Size the grid from the number of labels rather than hard-coding 2x4.
n_cols = 4
n_rows = -(-len(action_labels) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for dim, (ax, label) in enumerate(zip(axes, action_labels)):
    # First time step of every chunk for this dimension.
    first_step = all_actions[:, dim, 0]

    ax.hist(first_step, bins=50, alpha=0.7, edgecolor='black')
    ax.set_xlabel(label, fontsize=12, fontweight='bold')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)
    ax.set_title(f'{label} Distribution')

# Hide any grid cells left over after the last dimension.
for ax in axes[len(action_labels):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## 4. 可视化动作轨迹

In [None]:
# Plot each action dimension of a single sample's action chunk over time.
sample_idx = 0
sample = dataset[sample_idx]
# Assumed layout: [action_dim, horizon] (e.g. [7, 16]) — TODO confirm against RobotDataset.
actions = sample['actions'].numpy()

# One row per labelled dimension instead of a hard-coded 7; squeeze=False
# keeps `axes` 2-D even for a single row, and the time axis is shared.
fig, axes = plt.subplots(len(action_labels), 1, figsize=(12, 10),
                         sharex=True, squeeze=False)
axes = axes[:, 0]

for ax, label, series in zip(axes, action_labels, actions):
    ax.plot(series, marker='o', linewidth=2, markersize=4)
    ax.set_ylabel(label, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)

axes[0].set_title(f'Action Sequence\nInstruction: "{sample["instruction"]}"',
                  fontsize=14, fontweight='bold')
# The x label belongs on the bottom panel, however many panels there are.
axes[-1].set_xlabel('Time Step', fontsize=12)

plt.tight_layout()
plt.show()

## 5. 语言指令分析

In [None]:
# Collect the distinct language instructions among (up to) the first 1000 samples.
# NOTE(review): each dataset[i] call returns a full sample (image included)
# just to read 'instruction'; if RobotDataset exposes instructions directly, use that.
instructions = {dataset[i]['instruction'] for i in range(min(1000, len(dataset)))}

print(f"唯一指令数量: {len(instructions)}\n")
print("指令样例:")
# Sort before slicing: set iteration order for strings depends on hash
# randomization (PYTHONHASHSEED), so an unsorted slice would show different
# examples after every kernel restart. key=str keeps sorting safe even if an
# instruction is not a string.
for inst in sorted(instructions, key=str)[:10]:
    print(f"  - {inst}")

## 6. 数据统计摘要

In [None]:
# Print a summary: dataset size and image size, the action array's dims, and
# per-dimension mean/std of the first action step across `all_actions`.
separator = "=" * 50

print(separator)
print("数据集统计摘要")
print(separator)
print(f"总样本数: {len(dataset)}")
print(f"图像尺寸: {dataset.image_size}")
print(f"动作维度: {all_actions.shape[1]}")
print(f"动作时域: {all_actions.shape[2]}")
print("\n动作统计 (归一化后):")
for dim, label in enumerate(action_labels):
    # First time step of every collected chunk for this dimension.
    first_step = all_actions[:, dim, 0]
    print(f"  {label:8s}: mean={first_step.mean():7.3f}, std={first_step.std():7.3f}")
print(separator)