In [None]:
import yaml
import torch
import argparse
import numpy as np
import open3d as o3d
import torchvision.transforms as T
from copy import deepcopy
import cv2
from easydict import EasyDict as edict
import glob
from IPython.display import clear_output

import plotly.graph_objects as go

import os
import pathlib

# o3d.visualization.webrtc_server.enable_webrtc()
# Open3D visualization inside a notebook needs the WebRTC server to be enabled.
# It appears to be set up automatically and must not be enabled manually a second time.

# Single-frame sample inputs. In the visible code they are only referenced by the
# commented-out single-frame example further down.
# test_color = "/home/haoxiang/RISE-2/saved_test_data/color_0000.png"
# test_depth = "/home/haoxiang/RISE-2/saved_test_data/depth_0000.png"
test_color = "/data/haoxiang/data/FLIPPING_v3/train/scene_0001/cam_104122060902/color/1767593840262.png"
test_depth = "/data/haoxiang/data/FLIPPING_v3/train/scene_0001/cam_104122060902/depth/1767593840262.png"

# YAML configuration file, opened relative to the current working directory.
config_path = "test.yaml"
# Alternative mask directory (disabled):
# mask_dir = "/data/haoxiang/propainter/masks_FLIPPING_v3_scene0001"

# Directories scanned by the frame loop: masks are looked up as "<index:05d>.png",
# color and depth frames are globbed as "*.png" and paired by sorted position.
mask_dir = "/data/haoxiang/propainter/masks_FLIPPING_v3_scene0001_arm_only"
color_dir = "/data/haoxiang/data/FLIPPING_v3/train/scene_0001/cam_104122060902/color"
depth_dir = "/data/haoxiang/data/FLIPPING_v3/train/scene_0001/cam_104122060902/depth"

# Hard-coded 3x3 pinhole intrinsic matrix. create_point_cloud reads fx, fy from
# [0, 0] and [1, 1], and cx, cy from [0, 2] and [1, 2].
# NOTE(review): presumably copied from the camera named in the paths above — confirm.
fake_intrinsics = np.array([
    [915.384521484375, 0.0, 633.3715209960938],
    [0.0, 914.9421997070312, 354.1505432128906],
    [0.0, 0.0, 1.0]
])


# Passed to Open3D as depth_scale (raw depth values are divided by it).
# NOTE(review): 1000.0 presumably means the depth PNGs store millimetres — confirm.
fake_depth_scale = 1000.0

def get_file_sorted_position(file_path):
    """
    Return the zero-based rank of a file among the files of its own directory.

    The sibling files (sub-directories are excluded) are ordered by plain,
    case-sensitive string comparison of their names.

    Args:
        file_path (str): Path to the file; absolute or relative.

    Returns:
        int: Zero-based position of the file in the sorted name list, or -1
        when the path does not exist (an error message is printed).

    Raises:
        ValueError: If the path exists but is not a regular file.
    """
    # Resolve to an absolute path so the parent directory is well defined.
    target = pathlib.Path(file_path).resolve()

    if not target.exists():
        print(f"错误：文件 {target} 不存在")
        return -1

    if not target.is_file():
        raise ValueError(f"传入的路径 {target} 是目录，不是文件！")

    folder = target.parent
    wanted = target.name

    # Names of regular files only, in case-sensitive lexicographic order.
    ordered_names = sorted(entry.name for entry in folder.iterdir() if entry.is_file())

    if wanted in ordered_names:
        return ordered_names.index(wanted)

    # Not expected in practice: the file was verified to exist just above.
    print(f"异常：文件 {wanted} 不在目录 {folder} 的文件列表中")
    return -1

def load_test_obs(color_path, depth_path):
    """
    Load one color/depth observation pair from disk.

    Args:
        color_path: Path of the color image.
        depth_path: Path of the depth image.

    Returns:
        tuple: (color_image, depth_image). The color image is converted from
        OpenCV's BGR channel order to RGB and returned as uint8; the depth
        image is returned as uint16.

    Raises:
        ValueError: If either image cannot be read by OpenCV.
    """
    # 1. Color image: OpenCV reads BGR, so swap the channels to RGB.
    bgr = cv2.imread(color_path)
    if bgr is None:
        raise ValueError(f"无法加载图片: {color_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.uint8)

    # 2. Depth image: IMREAD_UNCHANGED is required to keep the 16-bit depth values.
    raw_depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
    if raw_depth is None:
        raise ValueError(f"无法加载深度图: {depth_path}")

    # Force uint16. An 8-bit PNG would have to be converted back using its
    # quantization ratio first (the unit is usually millimetres).
    return rgb, raw_depth.astype(np.uint16)

def resize_image(image_list, image_size, interpolation = T.InterpolationMode.BILINEAR):
    """
    Resize image data with torchvision's ``Resize`` transform.

    Args:
        image_list: Input handed straight to the ``T.Resize`` transform.
        image_size: Target size forwarded to ``T.Resize``.
        interpolation: torchvision interpolation mode; bilinear by default.

    Returns:
        The output of the transform applied to ``image_list``.
    """
    return T.Resize(image_size, interpolation)(image_list)

def create_point_cloud(colors, depths, intrinsics, config, mask=None, depth_scale = 1000.0, rescale_factor = 1):
    """
    Build a cropped, voxel-downsampled Open3D point cloud from color, depth and mask.

    Args:
        colors: Color image array; transposed as (H, W, C) -> (C, H, W) when rescaling.
        depths: 2-D depth array (its shape is unpacked as (H, W)).
        intrinsics: Matrix indexed as [0, 0]=fx, [1, 1]=fy, [0, 2]=cx, [1, 2]=cy.
        config: Config object; reads ``config.deploy.workspace.min`` / ``.max``
            for the crop box and ``config.data.voxel_size`` for downsampling.
        mask: Optional array; every pixel where ``mask > 0`` has its depth set to 0.
        depth_scale: Forwarded to Open3D's RGBD construction.
        rescale_factor: Scale applied to the image size and to the intrinsics.

    Returns:
        The Open3D point cloud after cropping and voxel downsampling.
    """
    if rescale_factor != 1:
        src_h, src_w = depths.shape
        target_size = [int(src_h * rescale_factor), int(src_w * rescale_factor)]
        nearest = T.InterpolationMode.NEAREST

        # Color: channel-first float tensor, bilinear resize, back to contiguous HWC.
        color_chw = torch.from_numpy(colors.transpose([2, 0, 1]).astype(np.float32))
        color_resized = resize_image(color_chw, target_size).numpy()
        colors = np.ascontiguousarray(color_resized.transpose([1, 2, 0]))

        # Depth: nearest-neighbour resize on a tensor with a leading singleton axis.
        depth_1hw = torch.from_numpy(depths.astype(np.float32)[np.newaxis])
        depths = resize_image(depth_1hw, target_size, interpolation = nearest)[0].numpy()

        # Mask: resized the same way as the depth so both stay aligned.
        if mask is not None:
            mask_1hw = torch.from_numpy(mask.astype(np.float32))[np.newaxis]
            mask = resize_image(mask_1hw, target_size, interpolation = nearest)[0].numpy()

    # Apply the mask on a copy so the caller's depth array is left untouched.
    if mask is not None:
        depths = depths.copy()
        depths[mask > 0] = 0  # masked (white, 255) pixels get zero depth

    # Pinhole camera model at the (possibly rescaled) resolution.
    img_h, img_w = depths.shape
    pinhole = o3d.camera.PinholeCameraIntrinsic(
        width = img_w,
        height = img_h,
        fx = intrinsics[0, 0] * rescale_factor,
        fy = intrinsics[1, 1] * rescale_factor,
        cx = intrinsics[0, 2] * rescale_factor,
        cy = intrinsics[1, 2] * rescale_factor,
    )

    # Per the original author's note, Open3D does not emit a 3D point for
    # pixels whose depth is 0, which is what makes the masking above work.
    color_o3d = o3d.geometry.Image(colors.astype(np.uint8))
    depth_o3d = o3d.geometry.Image(depths.astype(np.float32))
    rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
        color_o3d, depth_o3d, depth_scale, convert_rgb_to_intensity = False
    )
    full_cloud = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, pinhole)

    # Keep only the configured workspace, then downsample to the voxel grid.
    workspace = o3d.geometry.AxisAlignedBoundingBox(
        config.deploy.workspace.min, config.deploy.workspace.max
    )
    return full_cloud.crop(workspace).voxel_down_sample(config.data.voxel_size)

def create_input(colors, depths, cam_intrinsics, config, mask=None, depth_scale = 1000.0, rescale_factor = 1):
    """
    Turn a color/depth pair into voxel coordinates, points and the point cloud.

    Args:
        colors, depths, cam_intrinsics, config, mask, depth_scale, rescale_factor:
            Forwarded unchanged to ``create_point_cloud``.

    Returns:
        tuple: ``(coords, points, cloud)`` where ``points`` is the array view of
        ``cloud.points``, ``coords`` is ``points / config.data.voxel_size``
        converted to a contiguous int32 array, and ``cloud`` is the point cloud
        returned by ``create_point_cloud``.
    """
    cloud = create_point_cloud(
        colors, depths, cam_intrinsics, config,
        mask = mask, depth_scale = depth_scale, rescale_factor = rescale_factor,
    )

    points = np.asarray(cloud.points)

    # Quantize by the voxel size to get the integer coordinates used for
    # sparse convolution.
    # NOTE(review): the float -> int32 conversion truncates toward zero rather
    # than flooring — confirm this matches the consumer's convention.
    voxel = config.data.voxel_size
    coords = np.ascontiguousarray(points / voxel, dtype = np.int32)

    return coords, points, cloud

def visualize_in_jupyter(cloud):
    """
    Show a point cloud as an interactive Plotly 3D scatter in the notebook output.

    Clouds with more than 100000 points are randomly subsampled (without
    replacement) to 100000 points before plotting. The current cell output is
    cleared right before the figure is shown.

    Args:
        cloud: Object exposing ``points`` and ``colors`` convertible with ``np.asarray``.
    """
    xyz = np.asarray(cloud.points)
    rgb = np.asarray(cloud.colors)

    # Cap the number of plotted points; per the original author, Plotly stays
    # reasonably smooth at around 100k points.
    if len(xyz) > 100000:
        keep = np.random.choice(len(xyz), 100000, replace=False)
        xyz, rgb = xyz[keep], rgb[keep]

    # Slightly enlarged, fully opaque markers so the cloud looks more solid.
    marker_style = {"size": 1.5, "color": rgb, "opacity": 1.0}
    scatter = go.Scatter3d(
        x=xyz[:, 0], y=xyz[:, 1], z=xyz[:, 2],
        mode='markers',
        marker=marker_style,
    )
    fig = go.Figure(data=[scatter])

    # Figure size and scene settings. A fixed width (e.g. width=1000) can be
    # added here; only the height in pixels is set.
    scene_cfg = {
        "aspectmode": 'data',            # keep the physical coordinate proportions
        "xaxis": {"visible": True},      # set to False to hide an axis
        "yaxis": {"visible": True},
        "zaxis": {"visible": True},
    }
    fig.update_layout(
        height=800,
        margin={"l": 0, "r": 0, "b": 0, "t": 0},  # no padding: the plot fills the window
        scene=scene_cfg,
    )

    clear_output(wait=True)
    fig.show()

# Load the YAML run configuration into an attribute-accessible EasyDict.
# The file is opened as UTF-8 explicitly so that parsing does not depend on
# the platform's default locale encoding.
# NOTE(review): yaml.load with FullLoader should only be used on trusted input.
# It is kept because config_path points at a local file here; switch to
# yaml.safe_load if the config can ever come from an untrusted source.
with open(config_path, "r", encoding = "utf-8") as f:
    config = edict(yaml.load(f, Loader = yaml.FullLoader))
# Store the translation normalization bounds as NumPy arrays instead of the
# raw values parsed from YAML.
config.data.normalization.trans_min = np.asarray(config.data.normalization.trans_min)
config.data.normalization.trans_max = np.asarray(config.data.normalization.trans_max)

# colors, depths = load_test_obs(test_color, test_depth)

# coords, points, cloud = create_input(
#     colors,
#     depths,
#     # cam_intrinsics = agent.intrinsics,
#     cam_intrinsics = fake_intrinsics,
#     config = config,
#     # depth_scale = agent.camera.depth_scale,
#     depth_scale = fake_depth_scale,
#     rescale_factor = 1.0
# )

# # o3d.visualization.draw([cloud])
# visualize_in_jupyter(cloud)

# Sorted lists of the raw color and depth frames. Frame i pairs
# color_paths[i] with depth_paths[i], so both lists must line up by position.
color_paths = sorted(glob.glob(os.path.join(color_dir, "*.png")))
depth_paths = sorted(glob.glob(os.path.join(depth_dir, "*.png")))

# --- 2. Determine the iteration range ---
num_frames = len(color_paths)
print(f"找到 {num_frames} 帧数据")

# Because frames are paired by index, a count mismatch would otherwise only
# surface as an IndexError on depth_paths[i] partway through the loop.
# Clamp to the shorter list and say so.
if len(depth_paths) != num_frames:
    usable_frames = min(num_frames, len(depth_paths))
    print(f"警告：彩色图数量 ({num_frames}) 与深度图数量 ({len(depth_paths)}) 不一致，仅处理前 {usable_frames} 帧")
    num_frames = usable_frames


# Only every 60th frame is processed to keep interactive testing quick.
for i in range(0, num_frames, 60):

    # File paths of the current frame.
    c_path = color_paths[i]
    d_path = depth_paths[i]
    print(f"color path: {c_path}")
    print(f"depth path: {d_path}")

    # Matching mask path; masks are named by frame index in 00000.png format.
    m_path = os.path.join(mask_dir, f"{i:05d}.png")
    print(f"mask path: {m_path}")

    if not os.path.exists(m_path):
        print(f"跳过第 {i} 帧：未找到 Mask {m_path}")
        continue

    # --- 3. Load the data ---
    # Color and depth images.
    colors, depths = load_test_obs(c_path, d_path)

    # Mask. cv2.imread returns None instead of raising when a file cannot be
    # decoded; a None mask would make create_input silently skip masking, so
    # such frames are skipped explicitly.
    mask_img = cv2.imread(m_path, cv2.IMREAD_GRAYSCALE)
    if mask_img is None:
        print(f"跳过第 {i} 帧：无法加载 Mask {m_path}")
        continue

    # --- 4. Build the point cloud ---
    coords, points, cloud = create_input(
        colors,
        depths,
        mask=mask_img,            # pixels with mask > 0 are removed from the cloud
        cam_intrinsics=fake_intrinsics,
        config=config,
        depth_scale=fake_depth_scale,
        rescale_factor=1.0
    )

    print(f"处理完成第 {i} 帧: {os.path.basename(c_path)} | 点数: {len(points)}")

    # --- 5. Visualize or save ---
    # Wait for Enter after each figure: showing many figures at once inside
    # the loop can freeze the browser.
    visualize_in_jupyter(cloud)
    input()

    # Alternatively, save the point cloud to a file for later inspection:
    # o3d.io.write_point_cloud(f"output_pc_{i:05d}.ply", cloud)