# 09 — Interactive 3D Visualization (Plotly)

Interactive 3D visualization of LiDAR point clouds with predicted bounding boxes.

- Uses **Plotly** for browser-based 3D rendering (works in Colab)
- Runs PointNetSegV4 inference on a single frame, then clusters the predicted obstacle points (DBSCAN) and fits PCA-oriented bounding boxes
- Background points are randomly subsampled to at most ~15k for rendering performance; obstacle points are capped separately at ~10k, so they stay comparatively dense
- Bounding boxes are drawn as wireframes with class-specific colors
- Rotate / zoom / pan with your mouse

In [None]:
# Colab setup: mount Google Drive (the data and checkpoint paths below live under
# /content/drive) and install the extra packages this notebook imports.
from google.colab import drive
drive.mount('/content/drive')
# IPython shell escape — only valid inside a notebook, not in a plain .py script.
!pip install -q h5py scikit-learn plotly

In [None]:
import gc, os, time
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import DBSCAN
from sklearn.neighbors import BallTree
import plotly.graph_objects as go

DRIVE_BASE = "/content/drive/MyDrive/airbus_hackathon"
INPUT_DIR = DRIVE_BASE + "/data"

# Prefer the v5 checkpoint when it exists on Drive; otherwise fall back to v4.
CKPT_V5 = DRIVE_BASE + "/checkpoints_v5/best_model_v5.pt"
CKPT_V4 = DRIVE_BASE + "/checkpoints_v4/best_model_v4.pt"
if os.path.exists(CKPT_V5):
    CKPT_PATH = CKPT_V5
else:
    CKPT_PATH = CKPT_V4

# Scene file to open and the frame (0-indexed) to visualize.
SINGLE_SCENE = "scene_8"
FRAME_INDEX = 30

# Display settings: per-category caps on rendered points, and marker size.
MAX_BG_POINTS = 15000     # subsample background
MAX_OBS_POINTS = 10000    # keep more obstacle points
POINT_SIZE = 1.5

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")
print(f"Checkpoint: {CKPT_PATH}")
print(f"Scene: {SINGLE_SCENE}, frame: {FRAME_INDEX}")

# Config (v7.3)

In [None]:
NUM_CLASSES = 5
IN_CHANNELS = 5
CHUNK_SIZE = 65536

CLASS_NAMES = {
    0: "Background",
    1: "Antenna",
    2: "Cable",
    3: "Electric Pole",
    4: "Wind Turbine",
}
# Obstacle classes only (background excluded), in class-id order.
CLASS_LABELS_CSV = {cid: label for cid, label in CLASS_NAMES.items() if cid != 0}

# Per-class DBSCAN settings as (eps, min_samples); eps is in the point-cloud units.
DBSCAN_PARAMS = {
    cid: {"eps": eps, "min_samples": min_samples}
    for cid, (eps, min_samples) in {
        1: (2.0, 15),
        2: (5.0, 5),
        3: (2.0, 8),
        4: (5.0, 20),
    }.items()
}
# Cable fragments are merged when their principal axes are within this angle
# and the estimated gap between them is within this distance.
CABLE_MERGE_ANGLE_DEG = 15.0
CABLE_MERGE_GAP_M = 10.0

# Per-point softmax threshold below which an obstacle prediction becomes background.
CONFIDENCE_THRESHOLD_PER_CLASS = {1: 0.40, 2: 0.27, 3: 0.25, 4: 0.30}
CONFIDENCE_THRESHOLD_DEFAULT = 0.3
# Minimum mean point confidence for a box to be kept.
BOX_CONFIDENCE_THRESHOLD_PER_CLASS = {1: 0.70, 2: 0.55, 3: 0.45, 4: 0.60}
BOX_CONFIDENCE_THRESHOLD_DEFAULT = 0.6

# Box size sanity limits: minimum supporting points and maximum extent per class.
MIN_POINTS_PER_BOX = {1: 15, 2: 3, 3: 5, 4: 15}
MAX_DIM_PER_CLASS = {1: 200.0, 2: 400.0, 3: 100.0, 4: 250.0}
NMS_IOU_THRESHOLD = 0.3

# Point colors per class (Plotly RGB strings).
CLASS_COLORS_PLOTLY = {
    0: 'rgb(180, 180, 180)',    # background — grey
    1: 'rgb(38, 23, 180)',      # antenna — blue
    2: 'rgb(177, 132, 47)',     # cable — gold
    3: 'rgb(129, 81, 97)',      # electric pole — mauve
    4: 'rgb(66, 132, 9)',       # wind turbine — green
}
# Brighter variants used for the box wireframes.
BOX_COLORS_PLOTLY = {
    1: 'rgb(38, 23, 230)',      # antenna — bright blue
    2: 'rgb(220, 165, 60)',     # cable — bright gold
    3: 'rgb(180, 100, 130)',    # electric pole — bright mauve
    4: 'rgb(80, 180, 10)',      # wind turbine — bright green
}

print("Config loaded.")

# Model

In [None]:
class SharedMLP(nn.Module):
    """Pointwise layer: 1x1 ``Conv1d`` -> optional ``BatchNorm1d`` -> ReLU.

    Applied to (B, C, N) tensors, so the same linear map is applied to every
    point. The conv bias is disabled when BatchNorm follows, because BN's own
    shift makes it redundant.
    """
    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, 1, bias=not bn)
        self.bn = nn.BatchNorm1d(out_ch) if bn else None

    def forward(self, x):
        x = self.conv(x)
        # Explicit None check: Module truthiness is not a reliable "is present"
        # test (containers such as nn.Sequential define __len__ and can be falsy).
        if self.bn is not None:
            x = self.bn(x)
        return F.relu(x, inplace=True)


class PointNetSegV4(nn.Module):
    """PointNet v4 per-point segmentation network (~1.88M params with defaults).

    Five stacked SharedMLP encoders produce 64/128/256/512/1024 channels. The
    outputs of the first four are concatenated with a global feature (max over
    points of the fifth), giving 64+128+256+512+1024 = 1984 channels per point.
    Three SharedMLP layers with dropout and a 1x1 conv head then decode them.

    Input:  (B, N, in_channels).  Output: (B, N, num_classes) logits.
    """
    def __init__(self, in_channels=5, num_classes=5):
        super().__init__()
        # Attribute names are the state_dict keys — keep them stable for checkpoints.
        self.enc1 = SharedMLP(in_channels, 64)
        self.enc2 = SharedMLP(64, 128)
        self.enc3 = SharedMLP(128, 256)
        self.enc4 = SharedMLP(256, 512)
        self.enc5 = SharedMLP(512, 1024)
        self.seg1 = SharedMLP(64 + 128 + 256 + 512 + 1024, 512)
        self.seg2 = SharedMLP(512, 256)
        self.seg3 = SharedMLP(256, 128)
        self.dropout1 = nn.Dropout(0.4)
        self.dropout2 = nn.Dropout(0.3)
        self.head = nn.Conv1d(128, num_classes, 1)

    def forward(self, x):
        _, num_points, _ = x.shape
        feat = x.transpose(1, 2)  # (B, C, N) layout expected by Conv1d

        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            feat = encoder(feat)
            skips.append(feat)

        # Replace the deepest features by their max over points, broadcast back to every point.
        deepest = skips.pop()
        global_feat = torch.max(deepest, dim=2, keepdim=True).values
        skips.append(global_feat.expand(-1, -1, num_points))

        out = torch.cat(skips, dim=1)
        out = self.dropout1(self.seg1(out))
        out = self.dropout2(self.seg2(out))
        out = self.head(self.seg3(out))
        return out.transpose(1, 2)


# Load checkpoint
model = PointNetSegV4(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
model.to(device)  # nn.Module.to moves the parameters in place and returns the same module
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python objects,
# which can execute code. Only point CKPT_PATH at checkpoints you trust.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()  # disable dropout; BatchNorm switches to its running statistics
n_params = sum(map(torch.numel, model.parameters()))
print(f"Model loaded: {n_params:,} params")
if "val_obstacle_miou" in ckpt:
    print(f"Checkpoint epoch {ckpt.get('epoch', '?')}, val obstacle mIoU={ckpt['val_obstacle_miou']:.4f}")

# HDF5 Reader + Inference + Clustering

All pipeline functions are inlined from `scripts/inference.py`, so this notebook needs no project-local imports.

In [None]:
# ============================================================================
# HDF5 CHUNKED READER
# ============================================================================

def get_frame_boundaries(h5_path, dataset_name="lidar_points", chunk_size=2_000_000):
    """Split an HDF5 point dataset into frames by detecting ego-pose changes.

    A new frame starts at every row whose (ego_x, ego_y, ego_z, ego_yaw) differs
    from the previous row's. The dataset is scanned ``chunk_size`` rows at a time
    to bound memory. Changes inside a chunk are found with ``np.diff``. Changes
    at chunk seams are found by comparing a chunk's first pose with the previous
    chunk's last pose, using the same raw values so both checks agree.

    Returns:
        list of (start, end, ego_x, ego_y, ego_z, ego_yaw) tuples: rows
        [start, end) form one frame, and the pose values are cast to int.
        Empty list when the dataset has no rows.
    """
    pose_fields = ("ego_x", "ego_y", "ego_z", "ego_yaw")
    change_indices = []
    frames = []
    with h5py.File(h5_path, "r") as f:
        ds = f[dataset_name]
        n = ds.shape[0]
        if n == 0:
            return frames

        prev_last_pose = None
        for offset in range(0, n, chunk_size):
            end = min(offset + chunk_size, n)
            chunk = ds[offset:end]
            cols = [chunk[name] for name in pose_fields]

            # Seam check on raw (not int-truncated) values, matching np.diff below;
            # truncation would hide fractional pose changes that fall on a seam.
            cur_first_pose = tuple(col[0] for col in cols)
            if prev_last_pose is not None and cur_first_pose != prev_last_pose:
                change_indices.append(offset)

            changed = np.zeros(len(chunk) - 1, dtype=bool)
            for col in cols:
                changed |= np.diff(col) != 0
            # np.diff index k compares rows k and k+1, so the new frame starts at k+1.
            change_indices.extend((offset + 1 + np.flatnonzero(changed)).tolist())

            prev_last_pose = tuple(col[-1] for col in cols)
            del chunk, cols, changed
            gc.collect()

        starts = [0] + change_indices
        ends = change_indices + [n]
        for s, e in zip(starts, ends):
            row = ds[s]
            frames.append((s, e, int(row["ego_x"]), int(row["ego_y"]),
                           int(row["ego_z"]), int(row["ego_yaw"])))
    return frames


def read_frame_for_inference(h5_path, start, end, dataset_name="lidar_points"):
    """Load rows [start, end) of one frame and build the model's input features.

    Rows with ``distance_cm <= 0`` are discarded. Distance (cm) and the raw
    azimuth/elevation (divided by 100 to get degrees) are converted to
    Cartesian meters, with y negated: y = -r * cos(el) * sin(az).

    Returns:
        xyz_m: (N, 3) float32 — local cartesian coordinates in meters
        features: (N, 5) float32 — [x, y, z, reflectivity / 255, distance_m / 300]
    """
    with h5py.File(h5_path, "r") as h5:
        rows = h5[dataset_name][start:end]
    rows = rows[rows["distance_cm"] > 0]

    r = rows["distance_cm"].astype(np.float64) / 100.0
    azimuth = np.radians(rows["azimuth_raw"].astype(np.float64) / 100.0)
    elevation = np.radians(rows["elevation_raw"].astype(np.float64) / 100.0)

    horizontal = r * np.cos(elevation)
    xyz = np.empty((len(rows), 3), dtype=np.float32)
    xyz[:, 0] = horizontal * np.cos(azimuth)
    xyz[:, 1] = -horizontal * np.sin(azimuth)
    xyz[:, 2] = r * np.sin(elevation)

    reflectivity = rows["reflectivity"].astype(np.float32) / 255.0
    distance = r.astype(np.float32) / 300.0
    features = np.column_stack((xyz, reflectivity, distance))  # (N, 5)
    return xyz, features


# ============================================================================
# CHUNKED INFERENCE
# ============================================================================

@torch.no_grad()
def predict_frame(model, features_np, device, chunk_size=65536,
                  confidence_threshold=None):
    """Run per-point inference on a full frame, chunked to avoid OOM.

    Each chunk of up to ``chunk_size`` points goes through ``model`` as one
    (1, n, C) batch. Obstacle predictions (classes 1-4) whose softmax probability
    is below ``CONFIDENCE_THRESHOLD_PER_CLASS`` (fallback
    ``CONFIDENCE_THRESHOLD_DEFAULT``) are demoted to background (0). The reported
    confidence is left unchanged.

    Args:
        model: network mapping (1, n, C) float32 tensors to (1, n, num_classes)
            logits, expected to be in eval mode.
        features_np: (N, C) per-point features; converted to float32.
        device: torch device the model lives on.
        chunk_size: max points per forward pass. With PointNetSegV4 the global
            max-pool is taken per chunk, so results can depend on this value.
        confidence_threshold: unused; kept for call-site compatibility.

    Returns:
        predictions: (N,) numpy int64 — class IDs [0..4]
        confidences: (N,) numpy float32 — softmax probability of predicted class
    """
    features_np = np.asarray(features_np, dtype=np.float32)
    n = len(features_np)
    predictions = np.zeros(n, dtype=np.int64)
    confidences = np.zeros(n, dtype=np.float32)

    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        chunk = features_np[start:end]
        n_chunk = len(chunk)

        # Small (tail) chunks are padded to 128 points by cycling through the
        # chunk's own points. Duplicates leave the global max-pool unchanged,
        # whereas zero rows would inject fake points at the origin into every
        # real point's global feature and alter its prediction.
        if n_chunk < 128:
            chunk = chunk[np.arange(128) % n_chunk]

        tensor = torch.from_numpy(np.ascontiguousarray(chunk)).unsqueeze(0).to(device)  # (1, n, C)
        logits = model(tensor)  # (1, n, num_classes)
        probs = F.softmax(logits[0, :n_chunk], dim=-1)  # drop padded rows
        conf, preds = probs.max(dim=-1)
        preds = preds.cpu().numpy()
        conf = conf.cpu().numpy()

        # Per-class confidence threshold: low-confidence obstacle predictions -> background
        for cid in range(1, 5):
            thresh = CONFIDENCE_THRESHOLD_PER_CLASS.get(cid, CONFIDENCE_THRESHOLD_DEFAULT)
            preds[(preds == cid) & (conf < thresh)] = 0

        predictions[start:end] = preds
        confidences[start:end] = conf

        del tensor, logits, probs, preds, conf
    return predictions, confidences


# ============================================================================
# CLUSTERING + BOUNDING BOXES
# ============================================================================

def pca_oriented_bbox(points_m):
    """Fit a yaw-oriented 3D bounding box to a cluster of points.

    Boxes are described as (center, dimensions, yaw), with rotation only about
    +Z. The orientation is therefore estimated by PCA on the XY coordinates
    alone, and the vertical extent is measured directly along Z. This keeps
    ``dimensions`` in the frame that ``yaw`` describes:
    ``[length along yaw, width perpendicular to yaw, height]``.

    A full 3D PCA would measure extents along tilted axes that one yaw angle
    cannot represent; for example, a vertical pole's height would land in the
    horizontal length slot.

    Args:
        points_m: (N, 3) array of XYZ coordinates, N >= 1.

    Returns:
        dict with ``center_xyz`` (3,), ``dimensions`` (3,) and ``yaw`` (float,
        radians from +X towards +Y). Falls back to an axis-aligned box
        (yaw 0.0) for fewer than 2 points, a non-finite covariance, or an
        eigen-decomposition failure.
    """
    pts = np.asarray(points_m, dtype=np.float64)
    xy = pts[:, :2]
    z = pts[:, 2]

    xy_mean = xy.mean(axis=0)
    centered = xy - xy_mean

    axes = np.eye(2)  # columns: box length axis, box width axis
    yaw = 0.0
    if len(pts) >= 2:
        cov = np.cov(centered.T)
        if np.all(np.isfinite(cov)):
            try:
                eigvals, eigvecs = np.linalg.eigh(cov)
            except np.linalg.LinAlgError:
                pass  # keep the axis-aligned fallback
            else:
                major = eigvecs[:, np.argmax(eigvals)]
                # Width axis = major axis rotated +90 deg, i.e. the local y axis of
                # a counter-clockwise yaw rotation [[c, -s], [s, c]].
                minor = np.array([-major[1], major[0]])
                axes = np.column_stack([major, minor])
                yaw = float(np.arctan2(major[1], major[0]))

    projected = centered @ axes
    mins = projected.min(axis=0)
    maxs = projected.max(axis=0)
    center_xy = xy_mean + axes @ ((mins + maxs) / 2.0)
    z_min, z_max = z.min(), z.max()

    return {
        "center_xyz": np.array([center_xy[0], center_xy[1], (z_min + z_max) / 2.0]),
        "dimensions": np.array([maxs[0] - mins[0], maxs[1] - mins[1], z_max - z_min]),
        "yaw": yaw,
    }


def cluster_class_points(points_m, class_id, max_points=10000):
    """Group one class's points into spatial clusters with DBSCAN.

    ``eps`` and ``min_samples`` come from ``DBSCAN_PARAMS[class_id]``. When there
    are more than ``max_points`` points, DBSCAN runs on a random subsample. Every
    original point then takes the label of its nearest labelled sample if that
    sample is within ``2 * eps``; otherwise it is treated as noise.

    Returns:
        list of point arrays, one per cluster in ascending label order; noise is
        dropped. Empty when there are fewer points than ``min_samples``, or when
        the subsample yields no cluster.
    """
    cfg = DBSCAN_PARAMS[class_id]
    eps = cfg["eps"]
    min_samples = cfg["min_samples"]
    if len(points_m) < min_samples:
        return []

    subsampled = len(points_m) > max_points
    sample = points_m
    if subsampled:
        sample = points_m[np.random.choice(len(points_m), max_points, replace=False)]

    labels = DBSCAN(eps=eps, min_samples=min_samples, algorithm="ball_tree").fit_predict(sample)

    if subsampled:
        labelled = labels >= 0
        if not labelled.any():
            return []
        anchors = sample[labelled]
        anchor_labels = labels[labelled]
        _, nearest = BallTree(anchors).query(points_m, k=1)
        nearest = nearest.ravel()
        labels = anchor_labels[nearest]  # fancy indexing -> independent copy
        too_far = np.linalg.norm(points_m - anchors[nearest], axis=1) > eps * 2
        labels[too_far] = -1
        sample = points_m

    return [sample[labels == lbl] for lbl in np.unique(labels[labels >= 0])]


def merge_cable_clusters(clusters):
    """Greedily merge cable clusters that are nearly collinear and close together.

    For each cluster with at least 4 points and a finite covariance, the principal
    axis (eigenvector of the largest covariance eigenvalue, sign-normalised so its
    x component is >= 0) is computed. Clusters are then scanned in input order.
    Cluster ``i`` absorbs every later, not-yet-merged cluster ``j`` that meets both
    conditions:
    - its axis is within ``CABLE_MERGE_ANGLE_DEG`` of ``i``'s;
    - the centre distance, minus each cluster's farthest projection from its own
      centre along its own axis, is at most ``CABLE_MERGE_GAP_M``.
    Clusters without an axis are passed through unmerged.

    NOTE(review): comparisons always use cluster ``i``'s original axis and centre,
    not the growing merged cluster, so the result depends on input order. The gap
    uses the full centre-to-centre distance, so long parallel cables that are
    laterally close may also be merged. Confirm this is intended.

    Args:
        clusters: list of point arrays (assumed (M_i, 3)).

    Returns:
        list of point arrays; merged clusters are vertically stacked.
    """
    if len(clusters) <= 1:
        return clusters
    angle_thresh = np.radians(CABLE_MERGE_ANGLE_DEG)
    gap_thresh = CABLE_MERGE_GAP_M
    # Per-cluster summary: points, mean centre and principal axis (None when the
    # axis cannot be estimated; such clusters are never merged).
    infos = []
    for pts in clusters:
        if len(pts) < 4:
            infos.append({"points": pts, "center": pts.mean(axis=0), "axis1": None})
            continue
        centered = pts - pts.mean(axis=0)
        cov = np.cov(centered.T)
        if np.any(np.isnan(cov)) or np.any(np.isinf(cov)):
            infos.append({"points": pts, "center": pts.mean(axis=0), "axis1": None})
            continue
        try:
            eigvals, eigvecs = np.linalg.eigh(cov)
        except np.linalg.LinAlgError:
            infos.append({"points": pts, "center": pts.mean(axis=0), "axis1": None})
            continue
        # Eigenvector of the largest eigenvalue = direction of greatest spread.
        axis1 = eigvecs[:, eigvals.argsort()[::-1][0]]
        # Normalise the eigenvector's sign so its x component is non-negative.
        if axis1[0] < 0:
            axis1 = -axis1
        infos.append({"points": pts, "center": pts.mean(axis=0), "axis1": axis1})

    merged_flags = [False] * len(infos)
    result = []
    for i in range(len(infos)):
        if merged_flags[i]:
            continue
        current = infos[i]["points"]
        if infos[i]["axis1"] is not None:
            for j in range(i + 1, len(infos)):
                if merged_flags[j] or infos[j]["axis1"] is None:
                    continue
                # Angle between axes ignoring direction; min() keeps arccos's input <= 1.
                dot = min(abs(np.dot(infos[i]["axis1"], infos[j]["axis1"])), 1.0)
                if np.arccos(dot) > angle_thresh:
                    continue
                cdist = np.linalg.norm(infos[i]["center"] - infos[j]["center"])
                # Farthest projection of each cluster's points from its own centre along its own axis.
                ext_i = np.abs((infos[i]["points"] - infos[i]["center"]) @ infos[i]["axis1"]).max()
                ext_j = np.abs((infos[j]["points"] - infos[j]["center"]) @ infos[j]["axis1"]).max()
                # Estimated gap between the two clusters' ends; merge when within threshold.
                if cdist - ext_i - ext_j <= gap_thresh:
                    current = np.vstack([current, infos[j]["points"]])
                    merged_flags[j] = True
        result.append(current)
    return result


# ============================================================================
# POST-PROCESSING: SIZE FILTER + NMS + GEOMETRIC RECLASSIFICATION
# ============================================================================

def filter_boxes(boxes):
    """Drop boxes with too few supporting points or an oversized extent.

    A box is kept when ``num_points >= MIN_POINTS_PER_BOX[class_id]`` (default 3)
    and its largest dimension is not greater than ``MAX_DIM_PER_CLASS[class_id]``
    (default 500.0). Input order is preserved.
    """
    def _keep(box):
        cid = box["class_id"]
        enough_points = box["num_points"] >= MIN_POINTS_PER_BOX.get(cid, 3)
        oversized = max(box["dimensions"]) > MAX_DIM_PER_CLASS.get(cid, 500.0)
        return enough_points and not oversized

    return [box for box in boxes if _keep(box)]


def _box_iou_3d(box_a, box_b):
    """Approximate 3D IoU using axis-aligned overlap of PCA extents."""
    ca, da = box_a["center_xyz"], box_a["dimensions"]
    cb, db = box_b["center_xyz"], box_b["dimensions"]
    ha, hb = da / 2.0, db / 2.0
    overlap = np.maximum(0, np.minimum(ca + ha, cb + hb) - np.maximum(ca - ha, cb - hb))
    inter = overlap[0] * overlap[1] * overlap[2]
    vol_a = da[0] * da[1] * da[2]
    vol_b = db[0] * db[1] * db[2]
    union = vol_a + vol_b - inter
    if union <= 0:
        return 0.0
    return inter / union


def nms_boxes(boxes, iou_threshold=NMS_IOU_THRESHOLD):
    """Greedy non-maximum suppression, applied separately within each class.

    In each class, boxes are visited in descending ``num_points`` order (ties
    keep input order). Each kept box suppresses later same-class boxes whose
    ``_box_iou_3d`` with it exceeds ``iou_threshold``. Classes appear in the
    output in order of first appearance. Inputs of 0 or 1 boxes are returned
    unchanged.
    """
    if len(boxes) <= 1:
        return boxes

    grouped = {}
    for box in boxes:
        grouped.setdefault(box["class_id"], []).append(box)

    survivors = []
    for members in grouped.values():
        ordered = sorted(members, key=lambda b: b["num_points"], reverse=True)
        alive = [True] * len(ordered)
        for i, candidate in enumerate(ordered):
            if not alive[i]:
                continue
            survivors.append(candidate)
            for j in range(i + 1, len(ordered)):
                if alive[j] and _box_iou_3d(candidate, ordered[j]) > iou_threshold:
                    alive[j] = False
    return survivors


def reclassify_by_geometry(boxes):
    """Relabel antenna boxes whose shape suggests another class (in place).

    Only boxes with ``class_id == 1`` (Antenna) are considered:
    - elongated and thin (longest / middle > 5 and shortest < 1) -> Cable (2);
    - otherwise, longest > 15 with more than 200 points -> Wind Turbine (4).
    ``class_label`` is updated to match. Returns the same list object.
    """
    for box in boxes:
        if box["class_id"] != 1:
            continue

        longest, middle, shortest = sorted(box["dimensions"], reverse=True)
        looks_like_cable = middle > 0 and longest / middle > 5.0 and shortest < 1.0

        if looks_like_cable:
            new_cid = 2
        elif longest > 15.0 and box["num_points"] > 200:
            new_cid = 4
        else:
            continue

        box["class_id"] = new_cid
        box["class_label"] = CLASS_LABELS_CSV[new_cid]

    return boxes


# ============================================================================
# PREDICTIONS -> BOUNDING BOXES
# ============================================================================

def predictions_to_boxes(xyz_m, predictions, confidences=None,
                         use_per_class_conf=True):
    """Convert per-point predictions to bounding boxes via DBSCAN + PCA + post-processing.

    Pipeline: per-class clustering (collinear cable clusters merged) -> PCA bbox
    -> geometric reclassification -> confidence filter -> size filter -> NMS.

    Args:
        xyz_m: (N, 3) point coordinates.
        predictions: (N,) class IDs; only classes 1-4 are boxed.
        confidences: optional (N,) per-point confidences. A box's confidence is
            the mean over its points. When None, every box gets confidence 0.0
            and the box confidence filter is skipped; otherwise that filter
            would reject every box.
        use_per_class_conf: drop boxes below ``BOX_CONFIDENCE_THRESHOLD_PER_CLASS``
            (fallback ``BOX_CONFIDENCE_THRESHOLD_DEFAULT``).

    Returns:
        list of box dicts with keys center_xyz, dimensions, yaw, class_id,
        class_label, num_points, confidence.
    """
    has_conf = confidences is not None
    boxes = []
    for cid in range(1, 5):
        mask = predictions == cid
        if not mask.any():
            continue

        class_points = xyz_m[mask]
        class_conf = confidences[mask] if has_conf else None

        clusters = cluster_class_points(class_points, cid)
        if cid == 2 and len(clusters) > 1:
            clusters = merge_cable_clusters(clusters)

        # Clusters are point copies (no indices), so each point's confidence is
        # recovered by nearest-neighbour lookup; build the tree once per class.
        conf_tree = BallTree(class_points) if has_conf and len(clusters) > 0 else None

        for pts in clusters:
            if len(pts) < 3:
                continue

            box_confidence = 0.0
            if conf_tree is not None:
                _, indices = conf_tree.query(pts, k=1)
                box_confidence = float(class_conf[indices.ravel()].mean())

            bbox = pca_oriented_bbox(pts)
            boxes.append({
                "center_xyz": bbox["center_xyz"],
                "dimensions": bbox["dimensions"],
                "yaw": bbox["yaw"],
                "class_id": cid,
                "class_label": CLASS_LABELS_CSV[cid],
                "num_points": len(pts),
                "confidence": box_confidence,
            })

    # Geometric reclassification (before confidence filter, so the new class's threshold applies)
    boxes = reclassify_by_geometry(boxes)

    # Per-class box confidence filter — only meaningful when confidences were supplied.
    if use_per_class_conf and has_conf:
        boxes = [
            b for b in boxes
            if b["confidence"] >= BOX_CONFIDENCE_THRESHOLD_PER_CLASS.get(
                b["class_id"], BOX_CONFIDENCE_THRESHOLD_DEFAULT)
        ]

    # Post-processing
    boxes = filter_boxes(boxes)
    boxes = nms_boxes(boxes)
    return boxes


# End of the pipeline-definitions cell: confirm everything above was defined.
print("Pipeline functions loaded.")

# 3D Visualization with Plotly

In [None]:
def create_3d_wireframe_box(center, dims, yaw, color, name=""):
    """Create a 3D wireframe box as Plotly Scatter3d traces.

    The box has full extents ``dims = (w, l, h)`` along its local x, y and z
    axes. It is rotated by ``yaw`` radians about +Z (counter-clockwise seen from
    above) and centred at ``center``.

    All 12 edges go into a single line trace, separated by ``None`` gaps so they
    are not joined. One trace per box (instead of one per edge) keeps figures
    with many boxes light. It also means every edge shows ``name`` on hover;
    previously 11 of the 12 edges had an empty name.

    Returns:
        list containing one ``go.Scatter3d`` trace; a list so callers can keep
        iterating over the result.
    """
    w, l, h = dims
    hw, hl, hh = w / 2, l / 2, h / 2

    # 8 corners in local frame
    corners_local = np.array([
        [-hw, -hl, -hh], [ hw, -hl, -hh], [ hw,  hl, -hh], [-hw,  hl, -hh],
        [-hw, -hl,  hh], [ hw, -hl,  hh], [ hw,  hl,  hh], [-hw,  hl,  hh],
    ])

    # Rotate by yaw around Z axis, then translate to the box centre
    cos_y, sin_y = np.cos(yaw), np.sin(yaw)
    R = np.array([[cos_y, -sin_y, 0], [sin_y, cos_y, 0], [0, 0, 1]])
    corners = corners_local @ R.T + np.asarray(center, dtype=float)

    # 12 edges
    edges = [
        [0, 1], [1, 2], [2, 3], [3, 0],  # bottom
        [4, 5], [5, 6], [6, 7], [7, 4],  # top
        [0, 4], [1, 5], [2, 6], [3, 7],  # vertical
    ]

    xs, ys, zs = [], [], []
    for a, b in edges:
        # None breaks the polyline so consecutive edges are drawn as separate segments.
        xs += [corners[a, 0], corners[b, 0], None]
        ys += [corners[a, 1], corners[b, 1], None]
        zs += [corners[a, 2], corners[b, 2], None]

    trace = go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='lines',
        line=dict(color=color, width=4),
        name=name,
        showlegend=True,
        legendgroup=name,
        hoverinfo='name',
    )
    return [trace]


def visualize_frame_3d(xyz_m, predictions, boxes, frame_idx):
    """Create an interactive 3D Plotly figure for one frame.

    Points are randomly subsampled for display: background to at most
    ``MAX_BG_POINTS``, obstacles to at most ``MAX_OBS_POINTS``. They are coloured
    by predicted class, and each box is drawn as a wireframe via
    ``create_3d_wireframe_box``.

    Args:
        xyz_m: (N, 3) point coordinates.
        predictions: (N,) predicted class IDs (0 = background, 1-4 = obstacles).
        boxes: box dicts with center_xyz, dimensions, yaw, class_id, num_points
            and confidence.
        frame_idx: frame number shown in the title.

    Returns:
        ``go.Figure`` (not shown).
    """
    fig = go.Figure()

    def _subsample(indices, cap):
        # Random subset without replacement so the browser renders at most ``cap`` points.
        if len(indices) > cap:
            return np.random.choice(indices, cap, replace=False)
        return indices

    bg_mask = predictions == 0
    bg_indices = _subsample(np.flatnonzero(bg_mask), MAX_BG_POINTS)
    obs_indices = _subsample(np.flatnonzero(~bg_mask), MAX_OBS_POINTS)

    # Plot background points
    if len(bg_indices) > 0:
        fig.add_trace(go.Scatter3d(
            x=xyz_m[bg_indices, 0],
            y=xyz_m[bg_indices, 1],
            z=xyz_m[bg_indices, 2],
            mode='markers',
            marker=dict(size=POINT_SIZE, color=CLASS_COLORS_PLOTLY[0], opacity=0.15),
            name=f'Background ({len(bg_indices):,}pts)',
            hoverinfo='name',
        ))

    # Plot obstacle points by class
    for cid in [1, 2, 3, 4]:
        cid_indices = obs_indices[predictions[obs_indices] == cid]
        if len(cid_indices) == 0:
            continue
        fig.add_trace(go.Scatter3d(
            x=xyz_m[cid_indices, 0],
            y=xyz_m[cid_indices, 1],
            z=xyz_m[cid_indices, 2],
            mode='markers',
            marker=dict(size=POINT_SIZE + 1, color=CLASS_COLORS_PLOTLY[cid], opacity=0.8),
            name=f'{CLASS_NAMES[cid]} ({len(cid_indices):,}pts)',
        ))

    # Add bounding boxes as wireframes
    box_counts = {1: 0, 2: 0, 3: 0, 4: 0}
    for box in boxes:
        cid = box['class_id']
        box_counts[cid] += 1
        label = f"{CLASS_NAMES[cid]} box #{box_counts[cid]} ({box['num_points']}pts, conf={box['confidence']:.2f})"
        for trace in create_3d_wireframe_box(
            box['center_xyz'], box['dimensions'], box['yaw'],
            color=BOX_COLORS_PLOTLY[cid], name=label
        ):
            fig.add_trace(trace)

    # Layout
    box_summary = ", ".join(f"{CLASS_NAMES[c]}={box_counts[c]}" for c in [1, 2, 3, 4] if box_counts[c] > 0)
    title = f"Frame {frame_idx} — {len(boxes)} boxes"
    if box_summary:  # avoid an empty "()" when there are no boxes
        title += f" ({box_summary})"

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X (m)',
            yaxis_title='Y (m)',
            zaxis_title='Z (m)',
            aspectmode='data',  # preserve real proportions
            bgcolor='rgb(20, 20, 30)',
        ),
        width=1200,
        height=800,
        legend=dict(
            yanchor="top", y=0.99,
            xanchor="left", x=0.01,
            bgcolor="rgba(255,255,255,0.8)",
            # Explicit dark text: the legend otherwise inherits the layout's white
            # font and is unreadable on its near-white background.
            font=dict(size=10, color='black'),
        ),
        paper_bgcolor='rgb(30, 30, 40)',
        font=dict(color='white'),
    )

    return fig


# End of the visualization-definitions cell: confirm the plotting helpers are defined.
print("Visualization functions loaded.")

# Run inference and visualize

In [None]:
h5_path = os.path.join(INPUT_DIR, f"{SINGLE_SCENE}.h5")
print(f"Loading frame boundaries from {h5_path}...")
t0 = time.time()
frames_info = get_frame_boundaries(h5_path)
print(f"{len(frames_info)} frames found ({time.time() - t0:.1f}s)")
if not frames_info:
    raise RuntimeError(f"No frames found in {h5_path}")

# Select frame: clamp into [0, n_frames - 1]. Negative values would otherwise
# silently wrap around to frames counted from the end.
idx = min(max(FRAME_INDEX, 0), len(frames_info) - 1)
if idx != FRAME_INDEX:
    print(f"FRAME_INDEX={FRAME_INDEX} is out of range; using frame {idx} instead.")
start, end, ego_x, ego_y, ego_z, ego_yaw = frames_info[idx]
print(f"\nProcessing frame {idx} (points {start}-{end})...")

# Read + inference
xyz_m, features = read_frame_for_inference(h5_path, start, end)
print(f"  {len(xyz_m):,} points")

predictions, confidences = predict_frame(model, features, device)
del features

# Boxes
boxes = predictions_to_boxes(xyz_m, predictions, confidences)
print(f"  {len(boxes)} boxes detected")
for cid in [1, 2, 3, 4]:
    n = sum(1 for b in boxes if b['class_id'] == cid)
    if n > 0:
        print(f"    {CLASS_NAMES[cid]}: {n}")

# Visualize
print("\nRendering 3D visualization...")
fig = visualize_frame_3d(xyz_m, predictions, boxes, idx)
fig.show()
print("Done! Rotate/zoom/pan with your mouse.")

# Free the large per-frame arrays before the next cell loads another frame.
del xyz_m, predictions, confidences; gc.collect()

# Try another frame

Change `NEW_FRAME` below and run the cell to visualize a different frame.

No need to recompute the frame boundaries or reload the model — just pick a new index.

Interesting frames to try: **27, 30, 44, 65, 69, 77** (lots of obstacles).

In [None]:
# Change this and run this cell
NEW_FRAME = 65  # try different frames: 27, 30, 44, 65, 69, 77

# Clamp into [0, n_frames - 1]; negative values would otherwise wrap around.
idx = min(max(NEW_FRAME, 0), len(frames_info) - 1)
if idx != NEW_FRAME:
    print(f"NEW_FRAME={NEW_FRAME} is out of range; using frame {idx} instead.")
start, end, ego_x, ego_y, ego_z, ego_yaw = frames_info[idx]
xyz_m, features = read_frame_for_inference(h5_path, start, end)
predictions, confidences = predict_frame(model, features, device)
del features
boxes = predictions_to_boxes(xyz_m, predictions, confidences)
print(f"Frame {idx}: {len(xyz_m):,} points, {len(boxes)} boxes")
for cid in [1, 2, 3, 4]:
    n = sum(1 for b in boxes if b['class_id'] == cid)
    if n > 0:
        print(f"  {CLASS_NAMES[cid]}: {n}")

fig = visualize_frame_3d(xyz_m, predictions, boxes, idx)
fig.show()

# Free the large per-frame arrays so repeated runs don't accumulate memory.
del xyz_m, predictions, confidences; gc.collect()