# In [None]:
"""
单臂机械臂 3D 重建脚本
基于 AirExo-2 的 URDF 正向运动学流程
"""

import os
import h5py
import json
import numpy as np
import open3d as o3d
from scipy.spatial.transform import Rotation

# 假设你已经将 AirExo-2 代码库添加到 Python 路径
import sys
# sys.path.append('/path/to/AirExo-2')  # 修改为实际路径

from airexo.helpers.urdf_robot import forward_kinematic_single
from airexo.helpers.constants import ROBOT_PREDEFINED_TRANSFORMATION, O3D_RENDER_TRANSFORMATION


class SingleRobotReconstructor:
    """Single-arm robot 3D reconstructor.

    Loads per-frame joint states from a ``lowdim.h5`` file and camera
    calibration from a JSON file, runs URDF forward kinematics through
    AirExo-2's ``forward_kinematic_single`` and builds Open3D meshes that can
    be visualized interactively, rendered from the calibrated camera, or
    exported to PLY files.
    """

    def __init__(
        self,
        lowdim_h5_path,
        calib_json_path,
        urdf_file_path="airexo/urdf_models/robot/left_robot_inhand.urdf",
        gripper_joint_idx=7  # index of the gripper joint in the data (if any)
    ):
        """Store the paths and eagerly load low-dimensional and calibration data.

        Args:
            lowdim_h5_path: Path to the ``lowdim.h5`` file.
            calib_json_path: Path to the calibration JSON file.
            urdf_file_path: Path to the URDF model file. Mesh paths found in
                the URDF visuals are resolved relative to this file's directory.
            gripper_joint_idx: Index of the gripper joint in the data. It is
                stored for reference only; no method of this class reads it.
        """
        self.lowdim_h5_path = lowdim_h5_path
        self.calib_json_path = calib_json_path
        self.urdf_file_path = urdf_file_path
        self.gripper_joint_idx = gripper_joint_idx

        # Offscreen renderer reused across render_to_image() calls.
        self._renderer = None
        self._renderer_size = None

        self._load_lowdim_data()
        self._load_calibration_data()

    def _load_lowdim_data(self):
        """Read joint positions, end-effector state, timestamps and TCP poses.

        The dataset keys carry a fixed ``062046`` suffix, presumably a
        device/serial identifier for this recording setup. TODO confirm before
        reusing with other data. The shapes noted below are the ones the
        original author observed for one scene.
        """
        print(f"Loading lowdim data from {self.lowdim_h5_path}")

        with h5py.File(self.lowdim_h5_path, 'r') as f:
            # Joint positions in radians, observed shape (13036, 7).
            self.joint_positions = f['joint_position_rad_062046'][:]

            # End-effector (gripper) state, observed shape (13036, 1).
            self.ee_states = f['ee_state_062046'][:]

            # Per-frame timestamps, observed shape (13036,).
            self.timestamps = f['timestamp'][:]

            # TCP pose, observed shape (13036, 7). Loaded for optional
            # validation only; nothing in this class reads it.
            self.tcp_poses = f['tcp_pose_062046'][:]

        print(f"Loaded {len(self.timestamps)} frames")
        print(f"Joint positions shape: {self.joint_positions.shape}")

    @staticmethod
    def _pose_to_matrix(pose):
        """Convert a ``[x, y, z, qx, qy, qz, qw]`` pose into a 4x4 float32 matrix.

        The quaternion is scalar-last, which is the default ordering of
        ``scipy.spatial.transform.Rotation.from_quat``.

        Raises:
            ValueError: if ``pose`` does not contain exactly 7 numbers.
        """
        pose = np.asarray(pose, dtype=np.float64).reshape(-1)
        if pose.shape != (7,):
            raise ValueError(
                f"Expected pose [x, y, z, qx, qy, qz, qw], got {pose.shape[0]} values"
            )
        matrix = np.eye(4, dtype=np.float32)
        matrix[:3, :3] = Rotation.from_quat(pose[3:]).as_matrix()
        matrix[:3, 3] = pose[:3]
        return matrix

    def _load_calibration_data(self):
        """Load camera intrinsics and the camera/robot extrinsic from the calib JSON.

        Sets ``self.intrinsic`` (float32 array of ``calib['intrinsics']``) and
        ``self.cam_to_base`` (float32 inverse of the 4x4 matrix built from
        ``calib['pose_in_link']``).

        NOTE(review): the original code treats ``pose_in_link`` as a
        world->camera transform and inverts it. The key name suggests a camera
        pose expressed in a link frame, which would already map camera
        coordinates into that link. Confirm the convention against the
        calibration tool before trusting renders.
        """
        print(f"Loading calibration data from {self.calib_json_path}")

        with open(self.calib_json_path, 'r', encoding='utf-8') as f:
            calib_data = json.load(f)

        self.intrinsic = np.array(calib_data['intrinsics'], dtype=np.float32)
        print(f"Camera intrinsics:\n{self.intrinsic}")

        pose_in_link = calib_data['pose_in_link']
        position = np.array(pose_in_link[:3])
        world_to_cam = self._pose_to_matrix(pose_in_link)

        # The renderer needs the inverse mapping, from the robot base into the
        # camera frame (see the NOTE in the docstring about the convention).
        self.cam_to_base = np.linalg.inv(world_to_cam).astype(np.float32)

        print(f"Camera position in world: {position}")
        print(f"cam_to_base transformation:\n{self.cam_to_base}")

    def get_joint_config(self, num_joints=7):
        """Build a minimal joint-configuration object for ``forward_kinematic_single``.

        Simplified stand-in: it only exposes ``num_robot_joints`` (arm joints)
        and ``num_joints`` (arm joints + 1 gripper). Adjust it if the real
        configuration needs more fields.
        """
        class JointConfig:
            def __init__(self, num_joints):
                self.num_robot_joints = num_joints
                self.num_joints = num_joints + 1  # + gripper

        return JointConfig(num_joints)

    @staticmethod
    def _visual_mesh_file(visual):
        """Return the mesh filename referenced by a URDF visual, or ``None``.

        ``visual.geom_param`` is either a filename or a sequence whose first
        item is the filename. Anything that does not resolve to a string
        (``None``, an empty sequence, or numeric parameters, presumably from
        primitive geometries) is not a loadable mesh and yields ``None``.
        """
        mesh_file = visual.geom_param
        if isinstance(mesh_file, (list, tuple)):
            mesh_file = mesh_file[0] if mesh_file else None
        return mesh_file if isinstance(mesh_file, str) else None

    def reconstruct_at_timestamp(self, timestamp_idx):
        """Rebuild the robot's 3D meshes for one frame.

        Args:
            timestamp_idx: Frame index into the loaded arrays, from ``0`` to
                ``len(self.timestamps) - 1``. NumPy indexing rules apply.

        Returns:
            transforms: Per-link transforms from ``forward_kinematic_single``.
            visuals_map: Per-link visuals from the same call.
            meshes: List of dicts with keys ``'link_name'``, ``'mesh_file'``,
                ``'mesh'`` (``o3d.geometry.TriangleMesh`` already moved by
                ``'transform'``) and ``'transform'`` (the 4x4 matrix applied).
        """
        print(f"\n=== Reconstructing at timestamp index {timestamp_idx} ===")
        print(f"Timestamp: {self.timestamps[timestamp_idx]}")

        joint_angles = self.joint_positions[timestamp_idx]
        # ravel() accepts both (N, 1)- and (N,)-shaped ee_state datasets.
        ee_state = np.ravel(self.ee_states[timestamp_idx])[0]

        print(f"Joint angles (rad): {joint_angles}")
        print(f"End-effector state: {ee_state}")

        # Arm joints followed by the gripper state.
        full_joint_state = np.concatenate([joint_angles, [ee_state]])
        joint_cfg = self.get_joint_config(num_joints=len(joint_angles))

        print("Performing forward kinematics...")
        transforms, visuals_map = forward_kinematic_single(
            joint=full_joint_state,
            joint_cfgs=joint_cfg,
            is_rad=True,  # the dataset stores radians
            urdf_file=self.urdf_file_path,
            with_visuals_map=True
        )

        print(f"Found {len(transforms)} links")

        meshes = []
        urdf_dir = os.path.dirname(self.urdf_file_path)

        for link_name, transform in transforms.items():
            for visual in visuals_map.get(link_name, []):
                mesh_file = self._visual_mesh_file(visual)
                if mesh_file is None:
                    continue

                mesh_path = os.path.join(urdf_dir, mesh_file)
                if not os.path.exists(mesh_path):
                    print(f"Warning: Mesh file not found: {mesh_path}")
                    continue

                # Applied right-to-left: visual offset -> link pose -> fixed
                # predefined robot transformation.
                tf = ROBOT_PREDEFINED_TRANSFORMATION @ \
                     transform.matrix() @ \
                     visual.offset.matrix()

                mesh = o3d.io.read_triangle_mesh(mesh_path)
                if mesh.is_empty():
                    print(f"Warning: Failed to load mesh: {mesh_path}")
                    continue

                mesh.transform(tf)
                mesh.compute_vertex_normals()
                mesh.paint_uniform_color([0.7, 0.7, 0.7])

                meshes.append({
                    'link_name': link_name,
                    'mesh_file': mesh_file,
                    'mesh': mesh,
                    'transform': tf
                })

                print(f"  Loaded mesh for {link_name}: {mesh_file}")

        return transforms, visuals_map, meshes

    def visualize_reconstruction(self, timestamp_idx):
        """Open an interactive Open3D window showing the robot at one frame.

        Args:
            timestamp_idx: Frame index (see ``reconstruct_at_timestamp``).
        """
        _, _, meshes = self.reconstruct_at_timestamp(timestamp_idx)

        print(f"\n=== Visualizing {len(meshes)} meshes ===")

        vis = o3d.visualization.Visualizer()
        vis.create_window(window_name=f"Robot 3D Reconstruction - Frame {timestamp_idx}",
                          width=1280, height=720)
        try:
            # Coordinate frame placed with the same predefined transformation
            # that is applied to the meshes.
            coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
                size=0.2, origin=[0, 0, 0]
            )
            coord_frame.transform(ROBOT_PREDEFINED_TRANSFORMATION)
            vis.add_geometry(coord_frame)

            for mesh_data in meshes:
                vis.add_geometry(mesh_data['mesh'])

            vis.run()
        finally:
            # Release the window even if adding geometry or run() fails.
            vis.destroy_window()

    def _get_renderer(self, width, height):
        """Return an offscreen renderer of the given size, reusing the cached one.

        A fresh ``OffscreenRenderer`` per call (as in a batch loop) is
        expensive and holds native rendering resources. One renderer is kept
        per instance and rebuilt only when the resolution changes.
        """
        size = (int(width), int(height))
        renderer = getattr(self, "_renderer", None)
        if renderer is None or getattr(self, "_renderer_size", None) != size:
            # Drop every reference to the old renderer first, so at most one
            # native renderer is alive at a time.
            self._renderer = None
            renderer = None
            renderer = o3d.visualization.rendering.OffscreenRenderer(*size)
            self._renderer = renderer
            self._renderer_size = size
        return renderer

    def render_to_image(self, timestamp_idx, width=1280, height=720):
        """Render the robot at one frame from the calibrated camera.

        Args:
            timestamp_idx: Frame index (see ``reconstruct_at_timestamp``).
            width, height: Render resolution in pixels. The intrinsics are
                used unchanged, so this should match the calibrated camera.

        Returns:
            rgb_image: ``uint8`` array from ``render_to_image``.
            depth_image: ``float32`` array from
                ``render_to_depth_image(z_in_view_space=True)``. Open3D reports
                pixels with no geometry as ``+inf``.
        """
        _, _, meshes = self.reconstruct_at_timestamp(timestamp_idx)

        print(f"\n=== Rendering image at {width}x{height} ===")

        renderer = self._get_renderer(width, height)
        # The renderer is reused, so remove geometry from a previous frame.
        renderer.scene.clear_geometry()

        material = o3d.visualization.rendering.MaterialRecord()
        material.shader = "defaultLit"

        # Same for every mesh: robot base -> camera -> Open3D render frame.
        base_to_render = O3D_RENDER_TRANSFORMATION @ self.cam_to_base

        for i, mesh_data in enumerate(meshes):
            mesh = mesh_data['mesh']
            # TriangleMesh.transform mutates the mesh in place (and returns it).
            mesh.transform(base_to_render)
            renderer.scene.add_geometry(f"mesh_{i}", mesh, material)

        # set_projection expects a 3x3 double-precision intrinsic matrix.
        intrinsic = np.asarray(self.intrinsic, dtype=np.float64)[:3, :3]
        renderer.scene.camera.set_projection(
            intrinsic,
            0.01,      # near_plane
            100.0,     # far_plane
            float(width),
            float(height)
        )

        rgb_image = np.asarray(renderer.render_to_image(), dtype=np.uint8)
        depth_image = np.asarray(
            renderer.render_to_depth_image(z_in_view_space=True),
            dtype=np.float32
        )

        print(f"Rendered image shape: {rgb_image.shape}")
        print(f"Depth image shape: {depth_image.shape}")

        return rgb_image, depth_image

    @staticmethod
    def _unique_link_stems(link_names):
        """Turn link names into unique, filesystem-safe file-name stems.

        ``/`` is replaced by ``_``. A link with several visuals appears several
        times. The first occurrence keeps its name and later ones get a
        ``_1``, ``_2``, ... suffix, so per-link exports do not overwrite each other.
        """
        seen = set()
        stems = []
        for name in link_names:
            base = name.replace('/', '_')
            stem = base
            counter = 1
            while stem in seen:
                stem = f"{base}_{counter}"
                counter += 1
            seen.add(stem)
            stems.append(stem)
        return stems

    def export_meshes(self, timestamp_idx, output_dir):
        """Export the robot meshes at one frame as PLY files.

        Writes one combined mesh ``robot_frame_{idx:06d}.ply`` and one file per
        visual, ``robot_frame_{idx:06d}_{link}.ply``, into ``output_dir``.
        The directory is created if needed.

        Args:
            timestamp_idx: Frame index (see ``reconstruct_at_timestamp``).
            output_dir: Output directory.
        """
        _, _, meshes = self.reconstruct_at_timestamp(timestamp_idx)

        os.makedirs(output_dir, exist_ok=True)

        print(f"\n=== Exporting meshes to {output_dir} ===")

        combined_mesh = o3d.geometry.TriangleMesh()
        for mesh_data in meshes:
            combined_mesh += mesh_data['mesh']

        output_path = os.path.join(output_dir, f"robot_frame_{timestamp_idx:06d}.ply")
        if o3d.io.write_triangle_mesh(output_path, combined_mesh):
            print(f"Saved combined mesh to {output_path}")
        else:
            print(f"Warning: Failed to write {output_path}")

        stems = self._unique_link_stems(m['link_name'] for m in meshes)
        for link_name, mesh_data in zip(stems, meshes):
            output_path = os.path.join(
                output_dir,
                f"robot_frame_{timestamp_idx:06d}_{link_name}.ply"
            )
            if o3d.io.write_triangle_mesh(output_path, mesh_data['mesh']):
                print(f"  Saved {link_name} to {output_path}")
            else:
                print(f"Warning: Failed to write {output_path}")


def main():
    """主函数示例"""
    
    # ===== 配置路径 =====
    LOWDIM_H5_PATH = "/data/haoxiang/data/FLIPPING_v3/train/scene_0001/lowdim/lowdim.h5"
    CALIB_JSON_PATH = "/data/haoxiang/data/FLIPPING_v3/train/scene_0001/calib.json"
    URDF_FILE_PATH = "airexo/urdf_models/robot/left_robot_inhand.urdf"
    
    # ===== 创建重建器 =====
    reconstructor = SingleRobotReconstructor(
        lowdim_h5_path=LOWDIM_H5_PATH,
        calib_json_path=CALIB_JSON_PATH,
        urdf_file_path=URDF_FILE_PATH
    )
    
    # ===== 示例1: 可视化某一帧 =====
    timestamp_idx = 100  # 选择第 100 帧
    reconstructor.visualize_reconstruction(timestamp_idx)
    
    # ===== 示例2: 渲染图像 =====
    rgb_image, depth_image = reconstructor.render_to_image(timestamp_idx)
    
    # 保存渲染结果
    import cv2
    cv2.imwrite(f"rendered_rgb_frame_{timestamp_idx}.png", rgb_image[:, :, ::-1])
    
    # 保存深度图（归一化到 0-255）
    depth_vis = (depth_image / depth_image.max() * 255).astype(np.uint8)
    cv2.imwrite(f"rendered_depth_frame_{timestamp_idx}.png", depth_vis)
    
    # ===== 示例3: 导出网格 =====
    output_dir = f"./reconstruction_output/frame_{timestamp_idx}"
    reconstructor.export_meshes(timestamp_idx, output_dir)
    
    # ===== 示例4: 批量处理多帧 =====
    print("\n=== Batch processing ===")
    for idx in range(0, len(reconstructor.timestamps), 100):  # 每隔100帧处理一次
        print(f"\nProcessing frame {idx}/{len(reconstructor.timestamps)}")
        rgb, depth = reconstructor.render_to_image(idx)
        cv2.imwrite(f"./batch_output/frame_{idx:06d}.png", rgb[:, :, ::-1])


if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not when imported.
    main()