In [24]:
import h5py
# Load the trajectory data and inspect what is stored for the first episode.
# NOTE(review): the path below is absolute and machine-specific; it only
# resolves on the machine where this notebook was written.
data_path = "/home/glt/Projects/HIRL/data/pusht_human_mouse_trajectories/trajectories_1episodes.h5"

with h5py.File(data_path, 'r') as h5_file:
    # Resolve the per-step group once, then print its field names followed by
    # the summary of the 'observations' dataset.
    episode_steps = h5_file['episode_0']['steps']
    print(episode_steps.keys())
    print(episode_steps['observations'])

<KeysViewHDF5 ['actions', 'is_human_action', 'observations', 'rewards', 'terminated', 'truncated']>
<HDF5 dataset "observations": shape (215, 2), type "<f8">


In [None]:
# Visualize key frames of the trajectory: first, middle and last step when the
# episode has more than two steps, otherwise every available step.
# Relies on `episode` and `plt` already existing in the kernel; neither is
# created in this cell.
n_steps = len(episode.steps)
if n_steps > 2:
    key_frames = [0, n_steps // 2, -1]
    frame_names = ['开始帧', '中间帧', '结束帧']
else:
    key_frames = list(range(n_steps))
    frame_names = [f'第{i}帧' for i in range(n_steps)]

# Upper bound of the agent coordinate range used for scaling and for the bar
# chart's y-axis (the original note states agent_pos lies in [0, 512]).
agent_coord_max = 512

if not key_frames:
    # An empty episode would ask plt.subplots for a grid with zero columns,
    # which raises; report it instead of crashing the cell.
    print('No steps in episode; nothing to plot')
else:
    # squeeze=False keeps `axes` two-dimensional (2 x n_columns) even when
    # there is a single key frame, so axes[row, col] indexing always works.
    fig, axes = plt.subplots(
        2, len(key_frames), figsize=(5 * len(key_frames), 10), squeeze=False
    )

    for i, (frame_idx, frame_name) in enumerate(zip(key_frames, frame_names)):
        # Turn a negative index (the "last frame" marker) into an absolute one.
        if frame_idx < 0:
            frame_idx = n_steps + frame_idx

        step = episode.steps[frame_idx]
        obs = step.observation

        if isinstance(obs, dict) and 'pixels' in obs and 'agent_pos' in obs:
            pixels = obs['pixels']
            agent_pos = obs['agent_pos']

            # Top row: the pixel observation itself.
            axes[0, i].imshow(pixels)
            axes[0, i].set_title(f'{frame_name}\nPixels观测 {pixels.shape}')
            axes[0, i].set_xlabel('Width (像素)')
            axes[0, i].set_ylabel('Height (像素)')

            # Mark the agent on the image: rescale from the agent coordinate
            # range to the image's own width (shape[1]) and height (shape[0]).
            agent_x_img = agent_pos[0] * pixels.shape[1] / agent_coord_max
            agent_y_img = agent_pos[1] * pixels.shape[0] / agent_coord_max

            axes[0, i].plot(agent_x_img, agent_y_img, 'ro', markersize=8,
                            label='Agent位置')
            axes[0, i].legend()

            # Bottom row: the raw agent coordinates as a bar chart.
            axes[1, i].bar(['X坐标', 'Y坐标'], agent_pos, color=['blue', 'red'], alpha=0.7)
            axes[1, i].set_title(f'{frame_name}\nAgent位置坐标')
            axes[1, i].set_ylabel('坐标值')
            axes[1, i].set_ylim(0, agent_coord_max)

            # Write each value just above its bar.
            for j, val in enumerate(agent_pos):
                axes[1, i].text(j, val + 10, f'{val:.1f}', ha='center', va='bottom')
        else:
            # Observation is not a dict carrying both 'pixels' and 'agent_pos'.
            axes[0, i].text(0.5, 0.5, 'No pixels data', ha='center', va='center')
            axes[1, i].text(0.5, 0.5, 'No agent_pos data', ha='center', va='center')

    plt.tight_layout()
    plt.show()


In [None]:
# Collect per-step data from the episode and print summary statistics.
# Relies on `episode` and `np` already existing in the kernel; neither is
# created in this cell.
agent_positions = []
pixels_shapes = []
rewards = []
actions = []

for step in episode.steps:
    obs = step.observation

    # Only dict observations are inspected for 'agent_pos' / 'pixels'.
    if isinstance(obs, dict):
        if 'agent_pos' in obs:
            agent_positions.append(obs['agent_pos'].copy())
        if 'pixels' in obs:
            pixels_shapes.append(obs['pixels'].shape)

    rewards.append(step.reward)
    actions.append(step.action.copy())

# np.array of an empty list is already an empty array, so no special case is
# needed for episodes without data.
agent_positions = np.array(agent_positions)
rewards = np.array(rewards)
actions = np.array(actions)

print("=== 轨迹数据统计 ===")
print(f"总步数: {len(episode.steps)}")

if len(agent_positions) > 0:
    print("\nAgent位置统计:")
    print(f"  - Agent位置数组shape: {agent_positions.shape}")
    print(f"  - X坐标范围: [{agent_positions[:, 0].min():.2f}, {agent_positions[:, 0].max():.2f}]")
    print(f"  - Y坐标范围: [{agent_positions[:, 1].min():.2f}, {agent_positions[:, 1].max():.2f}]")
    print(f"  - X坐标标准差: {agent_positions[:, 0].std():.2f}")
    print(f"  - Y坐标标准差: {agent_positions[:, 1].std():.2f}")

if len(pixels_shapes) > 0:
    print("\nPixels数据统计:")
    print(f"  - 所有pixels观测的shape: {set(pixels_shapes)}")
    print(f"  - Pixels观测一致性: {'一致' if len(set(pixels_shapes)) == 1 else '不一致'}")

# min()/max() raise on a zero-size array, so guard the reward section the same
# way as the position and action sections.
if len(rewards) > 0:
    print("\n奖励统计:")
    print(f"  - 奖励范围: [{rewards.min():.4f}, {rewards.max():.4f}]")
    print(f"  - 平均奖励: {rewards.mean():.4f}")
    print(f"  - 奖励标准差: {rewards.std():.4f}")

if len(actions) > 0:
    print("\n动作统计:")
    print(f"  - 动作数组shape: {actions.shape}")
    # Columns 0 and 1 are reported as the X and Y action components.
    print(f"  - X动作范围: [{actions[:, 0].min():.2f}, {actions[:, 0].max():.2f}]")
    print(f"  - Y动作范围: [{actions[:, 1].min():.2f}, {actions[:, 1].max():.2f}]")


In [25]:
# Read the pickled trajectories into `data`.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only load pickle files that come from a trusted source.
import pickle

# Use a context manager so the file handle is closed once loading finishes
# (the bare open(...) passed straight to pickle.load was never closed).
with open('data/pusht_human_mouse_trajectories/2_trajectories.pickle', 'rb') as pickle_file:
    data = pickle.load(pickle_file)

In [None]:
# Look at the observation of the first step of the first entry in `data`
# (the object loaded from the pickle file). As the cell's last expression,
# its value is what the notebook displays.
data[0].steps[0].observation