# In [None]:
#!/usr/bin/env python3
"""Extract a single object point cloud from a ScanNet scene.

The script reads the down-sampled ScanNet point cloud with RGB + nyu40 labels
(`*_vh_clean_2.labels.ply`), the corresponding segmentation indices
(`*_vh_clean_2.0.010000.segs.json`), and the aggregation metadata
(`*_vh_clean.aggregation.json`). A target object can be specified either by the
majority nyu40 class id, the textual label in the aggregation file, or the
object id reported there. The resulting subset of points is written to an ASCII
PLY file that preserves xyz, rgb(a), and label information so it can be easily
visualised or reused for custom segmentation experiments.
"""
import argparse
import json
from collections import Counter, defaultdict
from pathlib import Path
from typing import Iterable, List, Optional

import numpy as np


def find_single_file(scene_dir: Path, pattern: str) -> Path:
    """Return the one file inside *scene_dir* that matches the glob *pattern*.

    Raises:
        FileNotFoundError: when no file matches.
        RuntimeError: when more than one file matches.
    """
    hits = list(scene_dir.glob(pattern))
    hits.sort()
    hit_count = len(hits)
    if hit_count == 1:
        return hits[0]
    if hit_count == 0:
        raise FileNotFoundError(f"No files matching {pattern!r} in {scene_dir}")
    raise RuntimeError(
        f"Expected one file matching {pattern!r} in {scene_dir}, found {hit_count}"
    )


def load_vertices(ply_path: Path) -> np.ndarray:
    """Read the vertex element of a binary PLY file into a structured array.

    The header is parsed for the byte order (``format`` line), the vertex
    count and the vertex property list, and the record dtype is built from
    those declarations instead of being assumed. For a header declaring float
    ``x, y, z``, uchar ``red, green, blue, alpha`` and ushort ``label`` (the
    fields the rest of this script reads) the result has float32, uint8 and
    uint16 fields respectively. Elements declared after the vertex element
    (e.g. faces) are not read.

    Args:
        ply_path: Path to a ``binary_little_endian`` or ``binary_big_endian``
            PLY file whose first element is ``vertex``.

    Returns:
        A 1-D structured array with one record per vertex.

    Raises:
        RuntimeError: if the header is not terminated by ``end_header``, the
            format is not binary, the vertex element is missing or not the
            first element, a vertex property is a list or has an unknown type,
            or the file holds fewer vertex records than the header declares.
    """
    # PLY scalar type names (classic and sized spellings) -> NumPy type codes.
    # The byte-order prefix is added once the "format" line is known.
    ply_scalar_types = {
        "char": "i1", "int8": "i1",
        "uchar": "u1", "uint8": "u1",
        "short": "i2", "int16": "i2",
        "ushort": "u2", "uint16": "u2",
        "int": "i4", "int32": "i4",
        "uint": "u4", "uint32": "u4",
        "float": "f4", "float32": "f4",
        "double": "f8", "float64": "f8",
    }
    with ply_path.open("rb") as f:
        header: List[str] = []
        while True:
            raw = f.readline()
            if not raw:
                # readline() returns b"" forever at EOF; without this check a
                # file lacking "end_header" would hang this loop.
                raise RuntimeError(f"PLY header in {ply_path} is not terminated by 'end_header'")
            # Only keywords are interpreted, so undecodable bytes in comment
            # lines must not abort the parse.
            line = raw.decode("ascii", errors="replace").strip()
            header.append(line)
            if line == "end_header":
                header_len = f.tell()
                break

        byte_order: Optional[str] = None
        vertex_count: Optional[int] = None
        fields: list[tuple[str, str]] = []
        current_element: Optional[str] = None
        for line in header:
            tokens = line.split()
            if not tokens:
                continue
            keyword = tokens[0]
            if keyword == "format":
                fmt = tokens[1] if len(tokens) > 1 else ""
                if fmt == "binary_little_endian":
                    byte_order = "<"
                elif fmt == "binary_big_endian":
                    byte_order = ">"
                else:
                    raise RuntimeError(
                        f"Unsupported PLY format {fmt!r} in {ply_path}; only binary PLY files are supported"
                    )
            elif keyword == "element" and len(tokens) >= 3:
                if tokens[1] == "vertex":
                    if current_element is not None:
                        # Vertex records are read straight after the header,
                        # which is only correct if no element precedes them.
                        raise RuntimeError(f"The vertex element must be the first element in {ply_path}")
                    vertex_count = int(tokens[2])
                current_element = tokens[1]
            elif keyword == "property" and current_element == "vertex":
                # Fixed-size records only: list properties have variable length.
                if len(tokens) != 3 or tokens[1] not in ply_scalar_types:
                    raise RuntimeError(f"Unsupported vertex property {line!r} in {ply_path}")
                fields.append((tokens[2], ply_scalar_types[tokens[1]]))

        if vertex_count is None:
            raise RuntimeError(f"Could not find vertex count in {ply_path}")
        if byte_order is None:
            raise RuntimeError(f"Missing 'format' line in PLY header of {ply_path}")
        if not fields:
            raise RuntimeError(f"No vertex properties declared in {ply_path}")
        dtype = np.dtype([(name, byte_order + code) for name, code in fields])
        f.seek(header_len)
        vertices = np.fromfile(f, dtype=dtype, count=vertex_count)
    # np.fromfile silently returns fewer records when the data runs out.
    if vertices.shape[0] != vertex_count:
        raise RuntimeError(
            f"{ply_path} is truncated: header declares {vertex_count} vertices "
            f"but only {vertices.shape[0]} could be read"
        )
    return vertices


def build_segment_index(seg_indices: Iterable[int]) -> dict[int, np.ndarray]:
    """Invert a per-vertex segment assignment into segment -> vertex indices.

    Position ``i`` of *seg_indices* holds the segment id of vertex ``i``. The
    returned dict maps each segment id (as a Python ``int``, in order of first
    appearance) to an ascending int64 array of the vertex positions carrying it.
    """
    buckets: dict[int, List[int]] = {}
    for vertex_pos, segment in enumerate(seg_indices):
        buckets.setdefault(int(segment), []).append(vertex_pos)
    return {
        segment: np.array(members, dtype=np.int64)
        for segment, members in buckets.items()
    }


def gather_indices(group_segments: Iterable[int], segment_to_indices: dict[int, np.ndarray]) -> np.ndarray:
    """Concatenate the vertex indices of every known segment of a group.

    Segment ids missing from *segment_to_indices* are skipped. The pieces are
    joined in the order of *group_segments*; when none is known, an empty
    int64 array is returned.
    """
    known_parts = [
        segment_to_indices[segment]
        for segment in group_segments
        if segment in segment_to_indices
    ]
    if known_parts:
        return np.concatenate(known_parts)
    return np.array([], dtype=np.int64)


def write_ascii_ply(path: Path, vertices: np.ndarray) -> None:
    """Write *vertices* to *path* as an ASCII PLY point cloud.

    The header declares float ``x/y/z``, uchar ``red/green/blue/alpha`` and a
    ushort ``label`` property. Each vertex becomes one space-separated line,
    with coordinates printed to six decimal places and the remaining fields
    printed as integers.

    Args:
        path: Destination file (created or overwritten).
        vertices: Structured array with the fields ``x, y, z, red, green,
            blue, alpha, label``.
    """
    property_lines = (
        [f"property float {axis}" for axis in ("x", "y", "z")]
        + [f"property uchar {channel}" for channel in ("red", "green", "blue", "alpha")]
        + ["property ushort label"]
    )
    with path.open("w", encoding="ascii") as out:
        header = [
            "ply",
            "format ascii 1.0",
            f"element vertex {len(vertices)}",
            *property_lines,
            "end_header",
        ]
        out.write("".join(entry + "\n" for entry in header))
        out.writelines(
            f"{row['x']:.6f} {row['y']:.6f} {row['z']:.6f} "
            f"{int(row['red'])} {int(row['green'])} {int(row['blue'])} "
            f"{int(row['alpha'])} {int(row['label'])}\n"
            for row in vertices
        )


def choose_group(groups: List[dict], target_nyu: Optional[int], target_label: Optional[str], target_object_id: Optional[int], segment_to_indices: dict[int, np.ndarray], vertex_labels: np.ndarray) -> dict:
    """Select the single aggregation group that matches the given criterion.

    Every group in *groups* is annotated in place with two private keys, which
    the caller can read from the returned group:

    * ``"_vertex_indices"``: the vertex indices covered by the group's
      ``"segments"`` (possibly empty);
    * ``"_nyu40"``: the most frequent value of *vertex_labels* over those
      vertices, or ``None`` when the group covers no vertices.

    A group matches when its majority label equals *target_nyu*, its
    ``"label"`` equals *target_label*, or its ``"objectId"`` equals
    *target_object_id*. Any one criterion suffices. Label and objectId
    matching does not require the group to cover any vertices, so the caller
    can report an empty selection explicitly.

    Args:
        groups: Aggregation group dicts; ``"segments"``, ``"label"`` and
            ``"objectId"`` are read from them.
        target_nyu: Majority per-vertex label to select, or ``None``.
        target_label: Exact ``"label"`` string to select, or ``None``.
        target_object_id: ``"objectId"`` to select, or ``None``.
        segment_to_indices: Segment id -> array of vertex indices.
        vertex_labels: Per-vertex label array indexed by vertex index.

    Returns:
        The single matching group dict.

    Raises:
        ValueError: if all three criteria are ``None``.
        RuntimeError: if no group matches, or more than one group matches.
    """
    # Fail fast before doing (and persisting) any per-group work.
    if target_nyu is None and target_label is None and target_object_id is None:
        raise ValueError("No selection criteria provided")

    candidates: List[dict] = []
    for group in groups:
        indices = gather_indices(group["segments"], segment_to_indices)
        group["_vertex_indices"] = indices
        if indices.size == 0:
            group["_nyu40"] = None
        else:
            counts = Counter(vertex_labels[indices].tolist())
            group["_nyu40"] = counts.most_common(1)[0][0]
        # Empty groups are still eligible: they can match by label/objectId
        # (but never by nyu40, since None != target_nyu).
        if target_nyu is not None and group["_nyu40"] == target_nyu:
            candidates.append(group)
        elif target_label is not None and group["label"] == target_label:
            candidates.append(group)
        elif target_object_id is not None and group["objectId"] == target_object_id:
            candidates.append(group)
    if not candidates:
        raise RuntimeError("No aggregation group matched the provided criteria")
    if len(candidates) > 1:
        labels = [(g["objectId"], g["label"], g.get("_nyu40")) for g in candidates]
        raise RuntimeError(f"Selection is ambiguous; matched groups: {labels}")
    return candidates[0]


def main(argv: Optional[List[str]] = None) -> None:
    """Parse CLI arguments, extract the selected object and write it as PLY.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]`` (used when
            ``None``). Passing an explicit list allows running the extraction
            from a notebook or another script, e.g.
            ``main(["--scene-dir", "scene0000_00", "--object-id", "3",
            "--output", "obj.ply"])``.

    Raises:
        FileNotFoundError: if a required scene file is missing.
        RuntimeError: if a scene file is duplicated, the PLY and segmentation
            vertex counts differ, or the selection does not resolve to exactly
            one object with vertices.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's paragraphs and line breaks in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--scene-dir", type=Path, required=True, help="Path to the ScanNet scene folder (e.g. scene0000_00)")
    selector = parser.add_mutually_exclusive_group(required=True)
    selector.add_argument("--nyu40-id", type=int, help="nyu40 class id to extract (e.g. 33 for toilet)")
    selector.add_argument("--object-label", type=str, help="Object label string in the aggregation file")
    selector.add_argument("--object-id", type=int, help="Object id from the aggregation file")
    parser.add_argument("--output", type=Path, required=True, help="Destination PLY file path")
    args = parser.parse_args(argv)

    scene_dir: Path = args.scene_dir
    ply_path = find_single_file(scene_dir, "*_vh_clean_2.labels.ply")
    segs_path = find_single_file(scene_dir, "*_vh_clean_2.0.010000.segs.json")
    agg_path = find_single_file(scene_dir, "*_vh_clean.aggregation.json")

    vertices = load_vertices(ply_path)
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # default locale encoding when reading it.
    with segs_path.open(encoding="utf-8") as f:
        seg_indices = np.array(json.load(f)["segIndices"], dtype=np.int64)
    with agg_path.open(encoding="utf-8") as f:
        agg = json.load(f)
    groups = agg["segGroups"]

    # Entry i of segIndices is used as the segment of vertex i, so both files
    # must describe the same number of vertices.
    if seg_indices.shape[0] != vertices.shape[0]:
        raise RuntimeError(
            "Mismatch between vertex count in PLY and segmentation indices: "
            f"{vertices.shape[0]} vs {seg_indices.shape[0]}"
        )

    segment_to_indices = build_segment_index(seg_indices)
    target_group = choose_group(
        groups,
        target_nyu=args.nyu40_id,
        target_label=args.object_label,
        target_object_id=args.object_id,
        segment_to_indices=segment_to_indices,
        vertex_labels=vertices["label"],
    )

    indices = target_group.get("_vertex_indices", np.array([], dtype=np.int64))
    if indices.size == 0:
        raise RuntimeError("Selected group contains no vertices after segmentation lookup")
    subset = vertices[indices]

    args.output.parent.mkdir(parents=True, exist_ok=True)
    write_ascii_ply(args.output, subset)

    print(
        f"Wrote {len(subset)} vertices for objectId {target_group['objectId']} "
        f"('{target_group['label']}'), nyu40={target_group.get('_nyu40')} to {args.output}"
    )


# Run the command-line interface only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
