# 07 — Visualizations (v7.3)

**Notebook Colab** pour générer les screenshots de visualisation (Livrable #4 Airbus).

Produit jusqu'à 10 PNGs montrant les nuages de points avec les bounding boxes 3D prédites, colorées par classe.
Deux vues par frame : **top-down (XY)** et **side view (XZ)**.

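**Pipeline :** HDF5 → inférence PointNetSegV4 (seuils de confiance par classe) → clustering DBSCAN → PCA bounding boxes → reclassification géométrique, filtres de confiance/taille et NMS → figure matplotlib à 2 panneaux → PNG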

In [None]:
# Mount Google Drive so that paths under /content/drive are reachable
# from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Notebook shell magic: quietly install the third-party packages that the
# next cell imports (h5py, scikit-learn, matplotlib).
!pip install -q h5py scikit-learn matplotlib

In [None]:
import gc
import os
import time

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import DBSCAN
from sklearn.neighbors import BallTree

# ==========================================================================
# PATHS
# ==========================================================================
# Everything lives under one Drive folder; the output folder is created on
# the spot so later savefig() calls cannot fail on a missing directory.
DRIVE_BASE = "/content/drive/MyDrive/airbus_hackathon"
INPUT_DIR = f"{DRIVE_BASE}/data"
OUTPUT_DIR = f"{DRIVE_BASE}/outputs/visualizations_v73"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Checkpoint selection: prefer the v5 checkpoint when the file exists,
# otherwise fall back to v4.
CKPT_V5 = f"{DRIVE_BASE}/checkpoints_v5/best_model_v5.pt"
CKPT_V4 = f"{DRIVE_BASE}/checkpoints_v4/best_model_v4.pt"
CKPT_PATH = CKPT_V5 if os.path.exists(CKPT_V5) else CKPT_V4

SINGLE_SCENE = "scene_8"  # Scene to visualize
NUM_FRAMES = 10           # Number of frames to visualize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Checkpoint: {CKPT_PATH}")
print(f"Input dir:  {INPUT_DIR}")
print(f"Output dir: {OUTPUT_DIR}")
print(f"Device:     {device}")
print(f"Scene:      {SINGLE_SCENE}")
print(f"Frames:     {NUM_FRAMES}")
if torch.cuda.is_available():
    print(f"GPU:        {torch.cuda.get_device_name()}")
    # The device-properties attribute is `total_memory` (in bytes);
    # `total_mem` does not exist and raised AttributeError on GPU runtimes.
    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"VRAM:       {vram_bytes / 1e9:.1f} GB")

# Config (v7.3)

In [None]:
# --- Model / batching ------------------------------------------------------
NUM_CLASSES = 5
IN_CHANNELS = 5          # x, y, z, reflectivity, norm_distance
CHUNK_SIZE = 64 * 1024   # points per forward pass

# --- Class naming ----------------------------------------------------------
# Internal ids: 0 is background, 1-4 are the obstacle classes.
CLASS_NAMES = dict(enumerate(
    ["Background", "Antenna", "Cable", "Electric pole", "Wind turbine"]
))

# Labels spelled the way the Airbus spec expects them.
CLASS_LABELS_CSV = dict(zip(
    (1, 2, 3, 4),
    ("Antenna", "Cable", "Electric Pole", "Wind Turbine"),
))

# Internal class_id (1-4) -> Airbus class_ID (0-3): a plain shift by one.
CLASS_ID_TO_AIRBUS = {cid: cid - 1 for cid in range(1, NUM_CLASSES)}

# --- Colours ---------------------------------------------------------------
# Per-point RGBA colours used for the scatter plots.
CLASS_COLORS = {
    0: (0.7, 0.7, 0.7, 0.05),     # background: nearly transparent grey
    1: (0.15, 0.09, 0.71, 0.9),   # antenna: blue (RGB 38,23,180)
    2: (0.69, 0.52, 0.18, 0.9),   # cable: gold (RGB 177,132,47)
    3: (0.51, 0.32, 0.38, 0.9),   # electric pole: mauve (RGB 129,81,97)
    4: (0.26, 0.52, 0.04, 0.9),   # wind turbine: green (RGB 66,132,9)
}

# Opaque edge colours for the bounding-box outlines and the legend.
BOX_COLORS = dict([
    (1, "#2617B4"),   # antenna: blue
    (2, "#B18430"),   # cable: gold
    (3, "#815161"),   # electric pole: mauve
    (4, "#428409"),   # wind turbine: green
])

# --- Clustering ------------------------------------------------------------
# DBSCAN neighbourhood radius / density requirement, tuned per class.
DBSCAN_PARAMS = {
    1: dict(eps=2.0, min_samples=15),   # Antenna
    2: dict(eps=5.0, min_samples=5),    # Cable
    3: dict(eps=2.0, min_samples=8),    # Electric pole
    4: dict(eps=5.0, min_samples=20),   # Wind turbine
}

# Cable fragments are merged when nearly collinear and close together.
CABLE_MERGE_ANGLE_DEG = 15.0
CABLE_MERGE_GAP_M = 10.0

# --- v7.3 confidence thresholds -------------------------------------------
# Point level: predictions below the class threshold fall back to background.
CONFIDENCE_THRESHOLD_DEFAULT = 0.3
CONFIDENCE_THRESHOLD_PER_CLASS = {
    1: 0.40,  # antenna
    2: 0.27,  # cable
    3: 0.25,  # electric_pole
    4: 0.30,  # wind_turbine
}

# Box level: boxes whose mean point confidence is below this are dropped.
BOX_CONFIDENCE_THRESHOLD_DEFAULT = 0.6
BOX_CONFIDENCE_THRESHOLD_PER_CLASS = {
    1: 0.70,  # antenna
    2: 0.55,  # cable
    3: 0.45,  # electric_pole
    4: 0.60,  # wind_turbine
}

# --- Box sanity filters and NMS --------------------------------------------
MIN_POINTS_PER_BOX = {1: 15, 2: 3, 3: 5, 4: 15}
MAX_DIM_PER_CLASS = {1: 200.0, 2: 400.0, 3: 100.0, 4: 250.0}
NMS_IOU_THRESHOLD = 0.3

print("Config v7.3 loaded.")
print(f"Per-class point conf:  {CONFIDENCE_THRESHOLD_PER_CLASS}")
print(f"Per-class box conf:    {BOX_CONFIDENCE_THRESHOLD_PER_CLASS}")

# Model

In [None]:
class SharedMLP(nn.Module):
    """Point-wise layer: 1x1 ``Conv1d``, optional ``BatchNorm1d``, then ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        bn: When true a BatchNorm1d follows the convolution and the
            convolution is created without a bias term.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        # The conv bias is only created when no batch norm follows.
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=(not bn))
        self.bn = nn.BatchNorm1d(out_ch) if bn else None

    def forward(self, x):
        """Apply conv -> (batch norm) -> in-place ReLU and return the result."""
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        return F.relu(out, inplace=True)


class PointNetSegV4(nn.Module):
    """PointNet v4 segmentation network with multi-scale skips (~1.88M params).

    Five point-wise encoder stages are applied in sequence. The first four
    stage outputs are concatenated with a max-pooled copy of the fifth
    (broadcast to every point) and fed to a three-stage segmentation head.

    Args:
        in_channels: Number of features per input point.
        num_classes: Number of logits produced per point.
    """

    def __init__(self, in_channels=5, num_classes=5):
        super().__init__()
        # Encoder: widening point-wise MLP stages.
        self.enc1 = SharedMLP(in_channels, 64)
        self.enc2 = SharedMLP(64, 128)
        self.enc3 = SharedMLP(128, 256)
        self.enc4 = SharedMLP(256, 512)
        self.enc5 = SharedMLP(512, 1024)
        # Decoder input: the four skip features plus the global feature.
        fused_channels = 64 + 128 + 256 + 512 + 1024
        self.seg1 = SharedMLP(fused_channels, 512)
        self.seg2 = SharedMLP(512, 256)
        self.seg3 = SharedMLP(256, 128)
        self.dropout1 = nn.Dropout(0.4)
        self.dropout2 = nn.Dropout(0.3)
        self.head = nn.Conv1d(128, num_classes, 1)

    def forward(self, x):
        """Return per-point logits.

        Args:
            x: Tensor indexed as (batch, points, channels).

        Returns:
            Tensor indexed as (batch, points, num_classes).
        """
        _, num_points, _ = x.shape
        feats = x.transpose(1, 2)  # Conv1d wants channels on dim 1

        stage_outputs = []
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            feats = stage(feats)
            stage_outputs.append(feats)

        # The deepest stage is max-pooled over points and broadcast back to
        # every point; the shallower stages are used as per-point skips.
        deepest = stage_outputs.pop()
        global_feat = deepest.max(dim=2, keepdim=True)[0].expand(-1, -1, num_points)
        fused = torch.cat(stage_outputs + [global_feat], dim=1)

        out = self.dropout1(self.seg1(fused))
        out = self.dropout2(self.seg2(out))
        out = self.head(self.seg3(out))
        return out.transpose(1, 2)


# Load model: build the network, restore the checkpoint weights and switch
# to eval mode.
print(f"Loading {CKPT_PATH}...")
model = PointNetSegV4(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects from the checkpoint file; only use it on checkpoints you trust.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=False)
# The checkpoint is a dict; the weights are stored under "model_state_dict".
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Report the parameter count and, when the checkpoint carries it, the
# validation metric recorded at save time.
n_params = sum(p.numel() for p in model.parameters())
print(f"PointNetSegV4: {n_params:,} params on {device}")
if "val_obstacle_miou" in ckpt:
    print(f"Epoch {ckpt.get('epoch', '?')}, val obstacle mIoU={ckpt['val_obstacle_miou']:.4f}")

# HDF5 Reader + Inference

In [None]:
# === HDF5 READER ===

def get_frame_boundaries(h5_path, dataset_name="lidar_points", chunk_size=2_000_000):
    """Split a point dataset into frames wherever the ego pose changes.

    The dataset is scanned in chunks of ``chunk_size`` rows. A new frame
    starts at every row whose (ego_x, ego_y, ego_z, ego_yaw) differs from
    the previous row, including across chunk boundaries.

    Args:
        h5_path: Path of the HDF5 file.
        dataset_name: Name of the structured dataset inside the file.
        chunk_size: Number of rows read per scan step.

    Returns:
        A list of ``(start, end, ego_x, ego_y, ego_z, ego_yaw)`` tuples, one
        per frame, where ``start``/``end`` are row indices (end exclusive)
        and the pose values are taken from the frame's first row and cast
        to ``int``. An empty dataset yields an empty list.
    """
    change_indices = []
    frames = []
    # A single file handle serves both the scan and the per-frame pose reads.
    with h5py.File(h5_path, "r") as f:
        ds = f[dataset_name]
        n = ds.shape[0]
        if n == 0:
            # No rows: there is no first row to read a pose from.
            return frames

        prev_last_pose = None
        for offset in range(0, n, chunk_size):
            chunk = ds[offset:min(offset + chunk_size, n)]
            ex, ey, ez, eyaw = chunk["ego_x"], chunk["ego_y"], chunk["ego_z"], chunk["ego_yaw"]

            # Pose change exactly on the chunk boundary (np.diff cannot see it).
            if prev_last_pose is not None:
                cur_first = (int(ex[0]), int(ey[0]), int(ez[0]), int(eyaw[0]))
                if cur_first != prev_last_pose:
                    change_indices.append(offset)

            # Pose changes inside the chunk: +1 turns the np.diff position
            # into the index of the first row of the new frame.
            changes = np.where(
                (np.diff(ex) != 0) | (np.diff(ey) != 0) |
                (np.diff(ez) != 0) | (np.diff(eyaw) != 0)
            )[0] + 1
            change_indices.extend(offset + int(c) for c in changes)

            prev_last_pose = (int(ex[-1]), int(ey[-1]), int(ez[-1]), int(eyaw[-1]))
            del chunk, ex, ey, ez, eyaw
            gc.collect()

        starts = [0] + change_indices
        ends = change_indices + [n]
        for s, e in zip(starts, ends):
            row = ds[s]
            frames.append((s, e, int(row["ego_x"]), int(row["ego_y"]),
                           int(row["ego_z"]), int(row["ego_yaw"])))
    return frames


def read_frame_for_inference(h5_path, start, end, dataset_name="lidar_points"):
    """Read rows ``start:end`` and build the model input for one frame.

    Rows whose ``distance_cm`` is not positive are discarded. The remaining
    spherical measurements are converted to Cartesian coordinates and
    combined with two normalised scalar features.

    Args:
        h5_path: Path of the HDF5 file.
        start: First row of the frame.
        end: One past the last row of the frame.
        dataset_name: Name of the structured dataset inside the file.

    Returns:
        ``(xyz, features)``: ``xyz`` is an (N, 3) float32 array and
        ``features`` an (N, 5) float32 array laid out as
        ``[x, y, z, reflectivity / 255, distance / 300]``.
    """
    with h5py.File(h5_path, "r") as h5:
        raw = h5[dataset_name][start:end]
    returns = raw[raw["distance_cm"] > 0]
    del raw

    # Trigonometry is done in float64; results are cast to float32 at the end.
    # NOTE(review): the raw angles are divided by 100 before np.radians,
    # presumably because they are stored in hundredths of a degree - confirm.
    range_m = returns["distance_cm"].astype(np.float64) / 100.0
    azimuth = np.radians(returns["azimuth_raw"].astype(np.float64) / 100.0)
    elevation = np.radians(returns["elevation_raw"].astype(np.float64) / 100.0)

    ground_range = range_m * np.cos(elevation)
    xs = ground_range * np.cos(azimuth)
    ys = -ground_range * np.sin(azimuth)  # y takes the negated sine term
    zs = range_m * np.sin(elevation)
    xyz = np.stack([xs, ys, zs], axis=1).astype(np.float32)

    reflectivity = (returns["reflectivity"].astype(np.float32) / 255.0)[:, None]
    norm_distance = (range_m.astype(np.float32) / 300.0)[:, None]
    features = np.hstack([xyz, reflectivity, norm_distance])
    return xyz, features


# === INFERENCE (v7.3 — per-class confidence thresholds) ===

@torch.no_grad()
def predict_frame(model, features_np, device, chunk_size=CHUNK_SIZE):
    """Run chunked inference and apply the per-class confidence thresholds.

    Args:
        model: Segmentation network returning (1, N, num_classes) logits.
        features_np: (N, C) float32 feature array for one frame.
        device: Torch device the model lives on.
        chunk_size: Maximum number of points per forward pass.

    Returns:
        ``(predictions, confidences)``: an int64 array of class ids and a
        float32 array holding the winning softmax probability per point.
        A point whose probability is below its class threshold is relabelled
        background (0); its stored confidence is left unchanged.
    """
    total = len(features_np)
    predictions = np.zeros(total, dtype=np.int64)
    confidences = np.zeros(total, dtype=np.float32)
    min_batch = 128  # shorter chunks are zero-padded up to this length

    for lo in range(0, total, chunk_size):
        hi = min(lo + chunk_size, total)
        batch = features_np[lo:hi]
        n_real = len(batch)
        if n_real < min_batch:
            padded = np.zeros((min_batch, batch.shape[1]), dtype=np.float32)
            padded[:n_real] = batch
            batch = padded

        logits = model(torch.from_numpy(batch).unsqueeze(0).to(device))
        # Only the real (non-padding) rows are turned into probabilities.
        probs = F.softmax(logits[0, :n_real], dim=-1)
        top_prob, top_class = probs.max(dim=-1)
        top_class = top_class.cpu().numpy()
        top_prob = top_prob.cpu().numpy()

        # Per-class confidence threshold: weak obstacle votes become background.
        for cid in range(1, 5):
            cutoff = CONFIDENCE_THRESHOLD_PER_CLASS.get(cid, CONFIDENCE_THRESHOLD_DEFAULT)
            too_weak = (top_class == cid) & (top_prob < cutoff)
            top_class[too_weak] = 0

        predictions[lo:hi] = top_class
        confidences[lo:hi] = top_prob

    return predictions, confidences


print("HDF5 reader + inference functions defined.")

# Clustering + Post-processing (v7.3)

In [None]:
def pca_oriented_bbox(points_m):
    """Fit a PCA-aligned bounding box to an (N, 3) point array.

    The box axes are the eigenvectors of the point covariance, ordered by
    decreasing eigenvalue. If the covariance is not finite or the
    eigen-decomposition fails, an axis-aligned box with yaw 0 is returned.

    Args:
        points_m: (N, 3) array of point coordinates.

    Returns:
        Dict with ``center_xyz`` (box centre in input coordinates),
        ``dimensions`` (extent along each box axis, principal axis first;
        along x, y, z for the fallback) and ``yaw`` (float angle of the
        principal axis' XY components, from ``arctan2``).
    """

    def axis_aligned_box():
        lo, hi = points_m.min(axis=0), points_m.max(axis=0)
        return {"center_xyz": (lo + hi) / 2.0, "dimensions": hi - lo, "yaw": 0.0}

    centroid = points_m.mean(axis=0)
    offsets = points_m - centroid
    cov = np.cov(offsets.T)
    if not np.all(np.isfinite(cov)):
        return axis_aligned_box()

    try:
        eigvals, eigvecs = np.linalg.eigh(cov)
    except np.linalg.LinAlgError:
        return axis_aligned_box()

    # Columns reordered so the largest-variance direction comes first.
    axes = eigvecs[:, np.argsort(eigvals)[::-1]]
    local = offsets @ axes
    lo, hi = local.min(axis=0), local.max(axis=0)

    # The box centre is the midpoint of the extents, mapped back from the
    # PCA frame; in general it differs from the point centroid.
    center = centroid + axes @ ((lo + hi) / 2.0)
    heading = np.arctan2(axes[1, 0], axes[0, 0])
    return {"center_xyz": center, "dimensions": hi - lo, "yaw": float(heading)}


def cluster_class_points(points_m, class_id, max_points=10000):
    """Cluster the points of one class with DBSCAN.

    When there are more than ``max_points`` points, DBSCAN runs on a random
    subsample. Every original point then inherits the label of its nearest
    clustered sample, unless that sample is farther than ``2 * eps`` away,
    in which case the point is treated as noise.

    Args:
        points_m: (N, 3) array of points predicted as ``class_id``.
        class_id: Key into ``DBSCAN_PARAMS``.
        max_points: Largest point count DBSCAN is run on directly.

    Returns:
        List of point arrays, one per cluster, in increasing label order.
        Noise points are dropped.
    """
    cfg = DBSCAN_PARAMS[class_id]
    eps, min_samples = cfg["eps"], cfg["min_samples"]
    n_total = len(points_m)
    if n_total < min_samples:
        return []

    subsampled = n_total > max_points
    if subsampled:
        picked = np.random.choice(n_total, max_points, replace=False)
        fit_points = points_m[picked]
    else:
        fit_points = points_m

    fit_labels = DBSCAN(eps=eps, min_samples=min_samples,
                        algorithm="ball_tree").fit_predict(fit_points)

    if subsampled:
        clustered = fit_labels >= 0
        if not clustered.any():
            return []
        anchors = fit_points[clustered]
        anchor_labels = fit_labels[clustered]
        # Propagate labels from the subsample back to the full point set.
        _, nearest = BallTree(anchors).query(points_m, k=1)
        nearest = nearest.ravel()
        labels = anchor_labels[nearest]
        gap = np.linalg.norm(points_m - anchors[nearest], axis=1)
        labels[gap > eps * 2] = -1
    else:
        labels = fit_labels

    return [points_m[labels == lbl] for lbl in sorted(set(labels) - {-1})]


def _describe_cable_cluster(pts):
    """Return the points, centroid and principal axis (or None) of a cluster."""
    centroid = pts.mean(axis=0)
    info = {"points": pts, "center": centroid, "axis1": None}
    if len(pts) < 4:
        return info
    cov = np.cov((pts - centroid).T)
    if not np.all(np.isfinite(cov)):
        return info
    try:
        eigvals, eigvecs = np.linalg.eigh(cov)
    except np.linalg.LinAlgError:
        return info
    principal = eigvecs[:, eigvals.argsort()[-1]]
    # Fix the sign so two parallel cables get comparable axes.
    if principal[0] < 0:
        principal = -principal
    info["axis1"] = principal
    return info


def merge_cable_clusters(clusters):
    """Greedily merge cable clusters that are nearly collinear and close.

    Cluster ``j`` is merged into an earlier cluster ``i`` when the angle
    between their principal axes is at most ``CABLE_MERGE_ANGLE_DEG`` and
    the distance between their centroids, minus each cluster's largest
    extent along its own axis, is at most ``CABLE_MERGE_GAP_M``. The test
    always uses the original statistics of cluster ``i``; they are not
    refreshed after a merge. Clusters without a usable axis (fewer than 4
    points or a failed decomposition) are passed through unmerged.

    Args:
        clusters: List of (N_k, 3) point arrays.

    Returns:
        List of point arrays after merging. Inputs with at most one cluster
        are returned unchanged.
    """
    if len(clusters) <= 1:
        return clusters

    max_angle = np.radians(CABLE_MERGE_ANGLE_DEG)
    max_gap = CABLE_MERGE_GAP_M
    infos = [_describe_cable_cluster(pts) for pts in clusters]

    absorbed = [False] * len(infos)
    merged = []
    for i, base in enumerate(infos):
        if absorbed[i]:
            continue
        group = base["points"]
        if base["axis1"] is not None:
            base_reach = np.abs((base["points"] - base["center"]) @ base["axis1"]).max()
            for j in range(i + 1, len(infos)):
                other = infos[j]
                if absorbed[j] or other["axis1"] is None:
                    continue
                # abs() makes the test independent of axis direction; the
                # clamp keeps arccos inside its domain.
                alignment = min(abs(np.dot(base["axis1"], other["axis1"])), 1.0)
                if np.arccos(alignment) > max_angle:
                    continue
                center_gap = np.linalg.norm(base["center"] - other["center"])
                other_reach = np.abs((other["points"] - other["center"]) @ other["axis1"]).max()
                if center_gap - base_reach - other_reach <= max_gap:
                    group = np.vstack([group, other["points"]])
                    absorbed[j] = True
        merged.append(group)
    return merged


def filter_boxes(boxes):
    """Drop boxes with too few points or an oversized dimension.

    A box is kept when its ``num_points`` reaches the class minimum from
    ``MIN_POINTS_PER_BOX`` (default 3) and its largest dimension does not
    exceed the class limit from ``MAX_DIM_PER_CLASS`` (default 500.0).

    Args:
        boxes: List of box dicts with ``class_id``, ``num_points`` and
            ``dimensions``.

    Returns:
        New list containing the boxes that passed both checks, in order.
    """
    kept = []
    for box in boxes:
        cid = box["class_id"]
        enough_points = box["num_points"] >= MIN_POINTS_PER_BOX.get(cid, 3)
        if enough_points and max(box["dimensions"]) <= MAX_DIM_PER_CLASS.get(cid, 500.0):
            kept.append(box)
    return kept


def _box_iou_3d(a, b):
    ca, da, cb, db = a["center_xyz"], a["dimensions"], b["center_xyz"], b["dimensions"]
    ha, hb = da / 2.0, db / 2.0
    overlap = np.maximum(0, np.minimum(ca + ha, cb + hb) - np.maximum(ca - ha, cb - hb))
    inter = overlap[0] * overlap[1] * overlap[2]
    union = da[0]*da[1]*da[2] + db[0]*db[1]*db[2] - inter
    return inter / union if union > 0 else 0.0


def nms_boxes(boxes, iou_threshold=NMS_IOU_THRESHOLD):
    """Greedy per-class non-maximum suppression.

    Within each class the boxes are ranked by ``num_points`` (largest
    first). A box is kept unless its IoU with an already kept box of the
    same class exceeds ``iou_threshold``.

    Args:
        boxes: List of box dicts.
        iou_threshold: IoU above which the lower-ranked box is suppressed.

    Returns:
        List of surviving boxes, grouped by class in order of first
        appearance. Inputs with at most one box are returned unchanged.
    """
    if len(boxes) <= 1:
        return boxes

    per_class = {}
    for box in boxes:
        per_class.setdefault(box["class_id"], []).append(box)

    kept = []
    for group in per_class.values():
        ranked = sorted(group, key=lambda bx: bx["num_points"], reverse=True)
        survivors = []
        for candidate in ranked:
            overlaps_kept = any(
                _box_iou_3d(winner, candidate) > iou_threshold
                for winner in survivors
            )
            if not overlaps_kept:
                survivors.append(candidate)
        kept.extend(survivors)
    return kept


def reclassify_by_geometry(boxes):
    """Re-label antenna boxes whose shape suggests another class.

    Only boxes with ``class_id == 1`` (antenna) are examined, in place:
    - elongated and flat (longest / middle > 5 and shortest < 1.0) -> Cable
    - otherwise, large with many points (longest > 15.0 and more than 200
      points) -> Wind Turbine

    Args:
        boxes: List of box dicts; matching entries are mutated.

    Returns:
        The same list object that was passed in.
    """
    for box in boxes:
        if box["class_id"] != 1:
            continue
        longest, middle, shortest = sorted(box["dimensions"], reverse=True)

        looks_like_cable = middle > 0 and longest / middle > 5.0 and shortest < 1.0
        if looks_like_cable:
            new_id = 2
        elif longest > 15.0 and box["num_points"] > 200:
            new_id = 4
        else:
            continue

        box["class_id"] = new_id
        box["class_label"] = CLASS_LABELS_CSV[new_id]
    return boxes


def predictions_to_boxes(xyz_m, predictions, confidences=None):
    """Turn per-point predictions into filtered bounding boxes (v7.3).

    Pipeline: per-class clustering -> cable merge -> PCA bbox -> geometric
    reclassification -> per-class confidence filter -> size filter -> NMS.

    Args:
        xyz_m: (N, 3) array of point coordinates.
        predictions: (N,) array of class ids (0 is background).
        confidences: Optional (N,) array of per-point confidences. When it
            is omitted there is nothing to threshold, so the confidence
            filter is skipped and each box's ``confidence`` is ``None``.
            (Previously every box got 0.0 and was then filtered out, so the
            function could only return an empty list.)

    Returns:
        List of box dicts with keys ``center_xyz``, ``dimensions``, ``yaw``,
        ``class_id``, ``class_label``, ``num_points`` and ``confidence``.
    """
    has_conf = confidences is not None
    boxes = []
    for cid in range(1, 5):
        mask = predictions == cid
        if not mask.any():
            continue
        class_points = xyz_m[mask]
        clusters = cluster_class_points(class_points, cid)
        if cid == 2 and len(clusters) > 1:
            clusters = merge_cable_clusters(clusters)
        if len(clusters) == 0:
            continue

        # Clusters are plain point arrays, so a nearest-neighbour lookup
        # maps each cluster point back to its row in class_points and
        # therefore to its confidence.
        class_conf = confidences[mask] if has_conf else None
        conf_tree = BallTree(class_points) if has_conf else None

        for pts in clusters:
            if len(pts) < 3:
                continue
            box_confidence = None
            if conf_tree is not None:
                _, indices = conf_tree.query(pts, k=1)
                box_confidence = float(class_conf[indices.ravel()].mean())
            bbox = pca_oriented_bbox(pts)
            boxes.append({
                "center_xyz": bbox["center_xyz"], "dimensions": bbox["dimensions"],
                "yaw": bbox["yaw"], "class_id": cid, "class_label": CLASS_LABELS_CSV[cid],
                "num_points": len(pts), "confidence": box_confidence,
            })

    # Geometric reclassification runs first so that the confidence
    # threshold applied below is the one of the final class.
    boxes = reclassify_by_geometry(boxes)

    # Per-class box confidence filter (only when confidences were supplied).
    if has_conf:
        boxes = [
            b for b in boxes
            if b["confidence"] >= BOX_CONFIDENCE_THRESHOLD_PER_CLASS.get(
                b["class_id"], BOX_CONFIDENCE_THRESHOLD_DEFAULT)
        ]

    boxes = filter_boxes(boxes)
    boxes = nms_boxes(boxes)
    return boxes


print("Clustering + post-processing v7.3 defined.")

# Visualization functions

In [None]:
def draw_rotated_box_2d(ax, cx, cy, w, h, yaw, color, linewidth=2):
    """Draw the outline of a rotated rectangle on a matplotlib axes.

    Args:
        ax: Axes-like object providing ``plot``.
        cx: X coordinate of the rectangle centre.
        cy: Y coordinate of the rectangle centre.
        w: Extent along the rectangle's own x axis (before rotation).
        h: Extent along the rectangle's own y axis (before rotation).
        yaw: Counter-clockwise rotation in radians.
        color: Line colour passed to ``ax.plot``.
        linewidth: Line width passed to ``ax.plot``.
    """
    half_w, half_h = w / 2, h / 2
    # Closed outline in the rectangle's own frame: the first corner is
    # repeated so a single polyline closes the shape.
    outline = np.array([
        [-half_w, -half_h],
        [+half_w, -half_h],
        [+half_w, +half_h],
        [-half_w, +half_h],
        [-half_w, -half_h],
    ])
    c, s = np.cos(yaw), np.sin(yaw)
    rotation = np.array([[c, -s], [s, c]])
    # Row vectors, hence the transpose; then shift to the centre.
    placed = outline @ rotation.T + np.array([cx, cy])
    ax.plot(placed[:, 0], placed[:, 1], color=color, linewidth=linewidth, solid_capstyle='round')


def render_frame(xyz_m, predictions, boxes, frame_idx, ego_info, output_path,
                 max_display_points=100000):
    """Render one frame as a two-panel PNG: top-down (XY) and side (XZ).

    Args:
        xyz_m: (N, 3) array of point coordinates.
        predictions: (N,) array of class ids (0 is background).
        boxes: List of box dicts with ``center_xyz``, ``dimensions``,
            ``yaw`` and ``class_id``.
        frame_idx: Frame number shown in the figure title.
        ego_info: Ego pose tuple; accepted for interface compatibility but
            not used by the rendering.
        output_path: Destination PNG path.
        max_display_points: Points are randomly subsampled down to this
            count for plotting only; the title still reports full counts.

    NOTE(review): ``dimensions`` come from a PCA frame (principal axis
    first), yet the top-down view draws them as (d[0], d[1]) rotated by yaw
    and the side view as an axis-aligned (d[0], d[2]) rectangle. For objects
    whose principal axis is not roughly along X (e.g. vertical structures)
    the drawn outline may not match the points - confirm the intended box
    convention.
    """
    predictions = np.asarray(predictions)

    # Subsample points for display.
    if len(xyz_m) > max_display_points:
        idx = np.random.choice(len(xyz_m), max_display_points, replace=False)
        xyz_disp, pred_disp = xyz_m[idx], predictions[idx]
    else:
        xyz_disp, pred_disp = xyz_m, predictions

    # Per-point RGBA via one table lookup instead of a Python loop over
    # every displayed point.
    palette = np.zeros((max(CLASS_COLORS) + 1, 4))
    for cid, rgba in CLASS_COLORS.items():
        palette[cid] = rgba
    colors = palette[pred_disp]

    bg_mask = pred_disp == 0
    obs_mask = ~bg_mask

    fig, axes = plt.subplots(1, 2, figsize=(20, 9))

    # --- Top-down view (XY) ---
    ax = axes[0]
    # Background first so obstacle points are painted on top of it.
    ax.scatter(xyz_disp[bg_mask, 0], xyz_disp[bg_mask, 1],
               c=colors[bg_mask], s=0.1, rasterized=True)
    ax.scatter(xyz_disp[obs_mask, 0], xyz_disp[obs_mask, 1],
               c=colors[obs_mask], s=1.5, rasterized=True)
    for box in boxes:
        c = box["center_xyz"]
        d = box["dimensions"]
        draw_rotated_box_2d(ax, c[0], c[1], d[0], d[1], box["yaw"],
                            color=BOX_COLORS[box["class_id"]], linewidth=2.5)
    ax.set_xlabel("X (m)", fontsize=11)
    ax.set_ylabel("Y (m)", fontsize=11)
    ax.set_title("Top-down view (XY)", fontsize=13, fontweight="bold")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)

    # --- Side view (XZ) ---
    ax = axes[1]
    ax.scatter(xyz_disp[bg_mask, 0], xyz_disp[bg_mask, 2],
               c=colors[bg_mask], s=0.1, rasterized=True)
    ax.scatter(xyz_disp[obs_mask, 0], xyz_disp[obs_mask, 2],
               c=colors[obs_mask], s=1.5, rasterized=True)
    for box in boxes:
        c = box["center_xyz"]
        d = box["dimensions"]
        # Side view: axis-aligned rectangle in (X, Z), no rotation applied.
        rect = patches.Rectangle(
            (c[0] - d[0] / 2, c[2] - d[2] / 2), d[0], d[2],
            linewidth=2.5, edgecolor=BOX_COLORS[box["class_id"]],
            facecolor="none", linestyle="-"
        )
        ax.add_patch(rect)
    ax.set_xlabel("X (m)", fontsize=11)
    ax.set_ylabel("Z (m)", fontsize=11)
    ax.set_title("Side view (XZ)", fontsize=13, fontweight="bold")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)

    # Legend: one entry per obstacle class with its box count (single pass).
    box_counts = {cid: 0 for cid in (1, 2, 3, 4)}
    for box in boxes:
        box_counts[box["class_id"]] = box_counts.get(box["class_id"], 0) + 1
    legend_elements = [
        Line2D([0], [0], color=BOX_COLORS[cid], linewidth=3,
               label=f"{CLASS_NAMES[cid]} ({box_counts[cid]})")
        for cid in (1, 2, 3, 4)
    ]
    fig.legend(handles=legend_elements, loc="lower center", ncol=4, fontsize=11,
               frameon=True, fancybox=True, shadow=True)

    # Title: counts are computed on the full, un-subsampled predictions.
    n_obs = int(np.count_nonzero(predictions > 0))
    fig.suptitle(
        f"Frame {frame_idx} — {len(boxes)} detections — "
        f"{n_obs:,} obstacle pts / {len(predictions):,} total — "
        f"PointNetSegV4 (v7.3)",
        fontsize=14, fontweight="bold", y=0.98
    )

    # Leave room for the legend (bottom) and the suptitle (top).
    fig.tight_layout(rect=[0, 0.05, 1, 0.95])
    fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f"  Saved: {output_path}", flush=True)


def select_diverse_frames(all_frame_results, num_frames=10):
    """Pick up to ``num_frames`` frames covering varied class combinations.

    Frames are ranked by number of distinct classes, then by box count
    (both descending; ties keep input order). The first pass takes the
    best-ranked frame of every distinct class combination; if slots remain,
    a second pass fills them with the best-ranked frames not yet chosen.

    Args:
        all_frame_results: List of frame dicts with ``boxes`` (each box has
            a ``class_id``) and ``frame_idx``.
        num_frames: Maximum number of frames to return.

    Returns:
        The input list itself when it has at most ``num_frames`` entries,
        otherwise a new list of selected frame dicts.
    """
    if len(all_frame_results) <= num_frames:
        return all_frame_results

    def class_combo(frame):
        return frozenset(box["class_id"] for box in frame["boxes"])

    ranked = sorted(
        all_frame_results,
        key=lambda frame: (len(class_combo(frame)), len(frame["boxes"])),
        reverse=True,
    )

    chosen = []
    combos_taken = set()

    # Pass 1: one representative per distinct class combination.
    for frame in ranked:
        if len(chosen) >= num_frames:
            break
        combo = class_combo(frame)
        if combo in combos_taken:
            continue
        combos_taken.add(combo)
        chosen.append(frame)

    # Pass 2: top up with the best-ranked frames that were skipped.
    if len(chosen) < num_frames:
        taken_ids = {frame["frame_idx"] for frame in chosen}
        for frame in ranked:
            if len(chosen) >= num_frames:
                break
            if frame["frame_idx"] in taken_ids:
                continue
            chosen.append(frame)
            taken_ids.add(frame["frame_idx"])

    return chosen


print("Visualization functions defined.")

# Generate visualizations

In [None]:
# === MAIN: Run inference on all frames, select best, render PNGs ===
#
# Pass 1 scans every frame but keeps only the boxes and the frame's row span.
# Pass 2 re-reads and re-predicts just the selected frames, so at most one
# point cloud is held in memory at a time.

h5_path = os.path.join(INPUT_DIR, f"{SINGLE_SCENE}.h5")
# Explicit exception instead of `assert`, which is stripped under `python -O`.
if not os.path.exists(h5_path):
    raise FileNotFoundError(f"File not found: {h5_path}")
scene_name = SINGLE_SCENE

print(f"Reading frame boundaries from {h5_path}...")
t0 = time.time()
frames_info = get_frame_boundaries(h5_path)
n_frames = len(frames_info)
print(f"Found {n_frames} frames ({time.time()-t0:.1f}s)")

# Pass 1: run inference on ALL frames to find the best ones.
print(f"\nRunning inference on all {n_frames} frames to select best {NUM_FRAMES}...")
all_results = []
t0 = time.time()

for idx, (start, end, ego_x, ego_y, ego_z, ego_yaw) in enumerate(frames_info):
    xyz_m, features = read_frame_for_inference(h5_path, start, end)
    if len(xyz_m) == 0:
        continue

    predictions, confidences = predict_frame(model, features, device)
    boxes = predictions_to_boxes(xyz_m, predictions, confidences)

    if boxes:  # Only consider frames with detections
        all_results.append({
            "frame_idx": idx,
            "span": (start, end),
            "boxes": boxes,
            "ego_info": (ego_x, ego_y, ego_z, ego_yaw),
        })

    # Drop the per-frame arrays before the next frame is loaded.
    del xyz_m, features, predictions, confidences

    if (idx + 1) % 20 == 0:
        print(f"  {idx+1}/{n_frames} frames processed, {len(all_results)} with detections", flush=True)

elapsed_inf = time.time() - t0
print(f"\nInference done: {len(all_results)} frames with detections ({elapsed_inf:.0f}s)")

# Select diverse frames
selected = select_diverse_frames(all_results, NUM_FRAMES)
print(f"Selected {len(selected)} frames for visualization")
del all_results
gc.collect()

# Pass 2: render selected frames (the rendering time includes re-reading
# and re-predicting each selected frame).
print(f"\nRendering visualizations...")
t_render = time.time()

for i, fr in enumerate(selected):
    filename = f"{scene_name}_frame{fr['frame_idx']:03d}.png"
    output_path = os.path.join(OUTPUT_DIR, filename)

    classes_in_frame = sorted(set(b["class_id"] for b in fr["boxes"]))
    class_names = [CLASS_NAMES[c] for c in classes_in_frame]
    print(f"\n  [{i+1}/{len(selected)}] Frame {fr['frame_idx']}: "
          f"{len(fr['boxes'])} boxes ({', '.join(class_names)})", flush=True)

    # Rebuild the point cloud and labels for this frame only; the boxes are
    # the ones computed in pass 1.
    start, end = fr["span"]
    xyz_m, features = read_frame_for_inference(h5_path, start, end)
    predictions, _ = predict_frame(model, features, device)
    del features

    render_frame(
        xyz_m, predictions, fr["boxes"],
        fr["frame_idx"], fr["ego_info"], output_path
    )

    del xyz_m, predictions
    gc.collect()

elapsed_render = time.time() - t_render

print(f"\n{'='*60}")
print(f"DONE! {len(selected)} visualizations saved to {OUTPUT_DIR}/")
print(f"Inference: {elapsed_inf:.0f}s, Rendering: {elapsed_render:.0f}s")
print(f"{'='*60}")

# List generated files (every PNG in the output folder, including any left
# over from earlier runs).
print(f"\nGenerated files:")
for name in sorted(os.listdir(OUTPUT_DIR)):
    if name.endswith(".png"):
        fpath = os.path.join(OUTPUT_DIR, name)
        size_mb = os.path.getsize(fpath) / 1e6
        print(f"  {name} ({size_mb:.1f} MB)")