In [None]:
!pip3 install waymo-open-dataset-tf-2-12-0==1.6.7

# 下载deeplab2安装脚本并添加执行权限
!wget https://raw.githubusercontent.com/waymo-research/waymo-open-dataset/master/src/waymo_open_dataset/pip_pkg_scripts/install_deeplab2.sh
!chmod +x install_deeplab2.sh
!./install_deeplab2.sh

In [13]:
import os
from typing import Any, List, Sequence, Tuple, Optional, Iterator, Dict
import immutabledict
import matplotlib.pyplot as plt
import tensorflow as tf
import multiprocessing as mp
import numpy as np
import dask.dataframe as dd
from PIL import Image
from tqdm import tqdm
from rich.console import Console
from rich.rule import Rule

# process_waymo_tfrecord calls `.numpy()` on tf.data elements, which requires eager
# execution; turn it on only if it is not already active.
if not tf.executing_eagerly():
  tf.compat.v1.enable_eager_execution()

from waymo_open_dataset import v2
from waymo_open_dataset import dataset_pb2 as open_dataset
from waymo_open_dataset.protos import camera_segmentation_metrics_pb2 as metrics_pb2
from waymo_open_dataset.protos import camera_segmentation_submission_pb2 as submission_pb2
from waymo_open_dataset.wdl_limited.camera_segmentation import camera_segmentation_metrics
from waymo_open_dataset.utils import camera_segmentation_utils

In [18]:
def process_waymo_tfrecord(tfrecord_path: str, output_dir: str, save_visualization: bool = True):
    """Export per-camera semantic label PNGs from a single Waymo tfrecord file.

    Each record is parsed as an ``open_dataset.Frame``. For every image from one of the five
    cameras listed in ``camera_name_map`` that carries a camera-segmentation (panoptic) label,
    the label is decoded and split into semantic and instance parts. The semantic part is then
    written as a single-channel uint8 PNG to ``<output_dir>/labels/``. Images without a label
    are skipped. Files are named ``{frame.timestamp_micros}_{camera_code}.png``. Raw camera
    images are not exported.

    Args:
        tfrecord_path: Full path to the input ``.tfrecord`` file.
        output_dir: Root output directory. ``labels/`` (and ``visualizations/`` when
            ``save_visualization`` is True) are created under it if missing.
        save_visualization: If True, also write a colour panoptic visualization
            (semantic + instance) with the same file name to ``<output_dir>/visualizations/``.
            Defaults to True.

    Note:
        Progress messages are printed through the module-level ``console`` (a rich
        ``Console``), which must exist before this function is called.
    """
    # Camera name -> numeric suffix used in the output file names.
    camera_name_map = {
        'FRONT': '0',
        'FRONT_LEFT': '1',
        'FRONT_RIGHT': '2',
        'SIDE_LEFT': '3',
        'SIDE_RIGHT': '4',
    }

    filename = os.path.basename(tfrecord_path)
    console.print(Rule(f"[bold blue]正在处理: {filename}", style="blue"))

    semantic_labels_dir = os.path.join(output_dir, 'labels')
    os.makedirs(semantic_labels_dir, exist_ok=True)

    visualization_dir = None
    if save_visualization:
        visualization_dir = os.path.join(output_dir, 'visualizations')
        os.makedirs(visualization_dir, exist_ok=True)

    dataset = tf.data.TFRecordDataset(tfrecord_path, compression_type='')

    # Pre-count the records so tqdm can show a total; this costs one extra read of the file.
    # A tf.data.Dataset can be iterated repeatedly (each loop creates a fresh iterator),
    # so the same `dataset` object is reused for the main pass below.
    try:
        total_frames = sum(1 for _ in dataset)
    except Exception as e:
        console.print(f"[yellow]警告：无法计算总帧数 ({e})，进度条将不显示总进度。[/yellow]")
        total_frames = None  # fall back to a progress bar without a known total

    for data in tqdm(dataset, total=total_frames, desc="处理帧", unit="frame"):
        frame = open_dataset.Frame()
        frame.ParseFromString(data.numpy())
        timestamp = frame.timestamp_micros

        for image_proto in frame.images:
            camera_name_str = open_dataset.CameraName.Name.Name(image_proto.name)
            camera_code = camera_name_map.get(camera_name_str)
            if camera_code is None:
                continue  # camera is not part of the export set

            panoptic_label_proto = image_proto.camera_segmentation_label
            # Check each image's own label instead of only frame.images[0]. This avoids an
            # IndexError on frames without images, and it never tries to decode an empty label
            # for a camera that is unlabeled while the first camera is labeled.
            if not panoptic_label_proto.panoptic_label:
                continue

            base_filename = f"{timestamp}_{camera_code}"

            panoptic_label = camera_segmentation_utils.decode_single_panoptic_label_from_proto(
                panoptic_label_proto
            )
            semantic_label, instance_label = (
                camera_segmentation_utils.decode_semantic_and_instance_labels_from_panoptic_label(
                    panoptic_label,
                    panoptic_label_proto.panoptic_label_divisor,
                )
            )
            # squeeze(axis=-1) drops the size-1 trailing axis so PIL receives a 2-D array.
            semantic_label_2d = np.squeeze(semantic_label, axis=-1)

            # Single-channel semantic label map.
            # NOTE(review): the uint8 cast assumes every semantic class ID is <= 255.
            label_path = os.path.join(semantic_labels_dir, f"{base_filename}.png")
            Image.fromarray(semantic_label_2d.astype(np.uint8)).save(label_path)

            # Optional colour visualization of the panoptic (semantic + instance) label.
            if visualization_dir is not None:
                panoptic_label_rgb = camera_segmentation_utils.panoptic_label_to_rgb(
                    semantic_label, instance_label
                )
                vis_path = os.path.join(visualization_dir, f"{base_filename}.png")
                Image.fromarray(panoptic_label_rgb).save(vis_path)


In [28]:
console = Console()  # also used by process_waymo_tfrecord for its progress messages

# --- 1. Input parameters ---
# Root directory containing the raw .tfrecord files.
root_dir = "/home/datuwsl/Research/SYSU/data/Waymo_NOTR/data/waymo/raw/validation"

# Root directory for the processed output.
save_dir = "/home/datuwsl/Research/SYSU/data/Waymo_NOTR/data/waymo/processed/validation"

# Text file listing one segment name per line.
segment_file = "/home/datuwsl/Research/SYSU/data/Waymo_NOTR/data/waymo_valid_list.txt"

# 0-based line indices into `segment_file` of the scenes to process
# (the 20 validation scenes that have segmentation labels).
scene_ids = [1, 12, 30, 44, 46, 92, 98, 108, 140, 143, 150, 155, 165, 167, 174, 178, 180, 188, 195, 197]

# --- 2. Batch processing ---
console.print(Rule("[bold yellow]开始批量处理任务", style="yellow"))

# Read the segment list in its own try block. That way a FileNotFoundError raised while
# processing a scene is not misreported as a missing segment list.
all_segments = None
try:
    with open(segment_file, 'r') as f:
        # Strip trailing newlines and drop blank lines.
        all_segments = [line.strip() for line in f if line.strip()]
except FileNotFoundError:
    console.print(f"[bold red]错误: Segment列表文件未找到: {segment_file}[/bold red]")
except Exception as e:
    console.print(f"[bold red]处理过程中发生未知错误: {e}[/bold red]")

if all_segments is not None:
    console.print(f"成功从 '{os.path.basename(segment_file)}' 中读取 {len(all_segments)} 个segment名称。")
    console.print(f"计划处理 {len(scene_ids)} 个指定场景。")

    for scene_id in scene_ids:
        # Reject negative indices too; they would silently wrap around to the end of the list.
        if not 0 <= scene_id < len(all_segments):
            console.print(f"[bold red]警告: 索引 {scene_id} 超出文件列表范围 (共 {len(all_segments)} 行)，已跳过。[/bold red]")
            continue

        segment_name = all_segments[scene_id]
        input_tfrecord_path = os.path.join(root_dir, f"{segment_name}.tfrecord")
        # Output directory uses the zero-padded scene id, e.g. 6 -> "006", 26 -> "026".
        output_segment_dir = os.path.join(save_dir, f"{scene_id:03d}")

        if not os.path.exists(input_tfrecord_path):
            console.print(f"[bold red]警告: 输入文件不存在: {input_tfrecord_path}，已跳过。[/bold red]")
            console.print()
            continue

        # Isolate failures per scene so one bad file does not abort all remaining scenes.
        try:
            process_waymo_tfrecord(
                tfrecord_path=input_tfrecord_path,
                output_dir=output_segment_dir,
                save_visualization=True,  # set to False to skip the colour visualizations
            )
        except Exception as e:
            console.print(f"[bold red]处理过程中发生未知错误 (scene {scene_id:03d}, {segment_name}): {e}[/bold red]")

console.print(Rule("[bold yellow]所有指定任务处理完毕", style="yellow"))

处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.38frame/s]
处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.38frame/s]


处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.52frame/s]
处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.52frame/s]


处理帧: 100%|██████████| 198/198 [00:36<00:00,  5.49frame/s]
处理帧: 100%|██████████| 198/198 [00:36<00:00,  5.49frame/s]


处理帧: 100%|██████████| 199/199 [00:21<00:00,  9.24frame/s]
处理帧: 100%|██████████| 199/199 [00:21<00:00,  9.24frame/s]


处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.58frame/s]
处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.58frame/s]


处理帧: 100%|██████████| 192/192 [00:33<00:00,  5.68frame/s]
处理帧: 100%|██████████| 192/192 [00:33<00:00,  5.68frame/s]


处理帧: 100%|██████████| 198/198 [00:33<00:00,  5.92frame/s]
处理帧: 100%|██████████| 198/198 [00:33<00:00,  5.92frame/s]


处理帧: 100%|██████████| 199/199 [00:37<00:00,  5.35frame/s]
处理帧: 100%|██████████| 199/199 [00:37<00:00,  5.35frame/s]


处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.47frame/s]
处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.47frame/s]


处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.56frame/s]
处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.56frame/s]


处理帧: 100%|██████████| 198/198 [00:34<00:00,  5.80frame/s]
处理帧: 100%|██████████| 198/198 [00:34<00:00,  5.80frame/s]


处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.55frame/s]
处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.55frame/s]


处理帧: 100%|██████████| 198/198 [00:36<00:00,  5.37frame/s]
处理帧: 100%|██████████| 198/198 [00:36<00:00,  5.37frame/s]


处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.52frame/s]
处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.52frame/s]


处理帧: 100%|██████████| 198/198 [00:36<00:00,  5.46frame/s]
处理帧: 100%|██████████| 198/198 [00:36<00:00,  5.46frame/s]


处理帧: 100%|██████████| 199/199 [00:34<00:00,  5.72frame/s]
处理帧: 100%|██████████| 199/199 [00:34<00:00,  5.72frame/s]


处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.47frame/s]
处理帧: 100%|██████████| 199/199 [00:36<00:00,  5.47frame/s]


处理帧: 100%|██████████| 197/197 [00:34<00:00,  5.70frame/s]
处理帧: 100%|██████████| 197/197 [00:34<00:00,  5.70frame/s]


处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.52frame/s]
处理帧: 100%|██████████| 198/198 [00:35<00:00,  5.52frame/s]


处理帧: 100%|██████████| 198/198 [00:34<00:00,  5.75frame/s]
处理帧: 100%|██████████| 198/198 [00:34<00:00,  5.75frame/s]


In [26]:
from PIL import Image
import numpy as np

# Sanity-check one exported semantic label PNG: print its shape, dtype and the class IDs present.
# NOTE(review): this path lies under .../datasets/pvps/processed_data/validation, which is not
# the `save_dir` used by the batch-processing cell (.../data/waymo/processed/validation).
# Confirm which output tree is meant to be inspected.
label_path = "/home/datuwsl/Research/SYSU/data/Waymo_NOTR/datasets/pvps/processed_data/validation/001/labels/1553735853462203_0.png"

# Image.open is lazy and keeps the file handle open. The context manager closes it once the
# pixels have been copied into a numpy array.
with Image.open(label_path) as label_img:
    label_array = np.array(label_img)

print("图像 shape:", label_array.shape)
print("数据类型:", label_array.dtype)
print("唯一类别 ID:", np.unique(label_array))


图像 shape: (1280, 1920)
数据类型: uint8
唯一类别 ID: [ 2  3  4  5  9 14 15 17 18 19 20 21 22 23 24 25 26 27 28]
