# 05 — Inference Pipeline

**Story 3** — Full inference: HDF5 → PointNetSegV4 → DBSCAN → Bounding Boxes → CSV

Run on Colab T4/A100. All code inline (no `src/` dependency).

In [None]:
# Mount Google Drive so that the project folder under /content/drive
# (data, checkpoints, outputs) is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# IPython shell escape (notebook-only syntax): quietly install the HDF5
# reader and scikit-learn used by the cells below.
!pip install -q h5py scikit-learn

In [None]:
import gc
import glob
import os
import sys
import time

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import DBSCAN

# All inputs and outputs live under a single project folder on the mounted Drive.
DRIVE_BASE = "/content/drive/MyDrive/airbus_hackathon"
DATA_DIR = f"{DRIVE_BASE}/data"  # searched for input *.h5 scenes
CKPT_DIR = f"{DRIVE_BASE}/checkpoints_v4"  # holds best_model_v4.pt
OUTPUT_DIR = f"{DRIVE_BASE}/outputs/pred_v4"  # receives one prediction CSV per scene
# Create the output folder up front; this is a no-op when it already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    # The device-properties object exposes the capacity in bytes as
    # `total_memory`; there is no `total_mem` attribute.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Config

In [None]:
# Label space of the segmentation head: id 0 is background, ids 1-4 are
# the obstacle classes that get clustered into boxes.
NUM_CLASSES = 5
IN_CHANNELS = 5  # x, y, z, reflectivity, norm_distance

# Internal (snake_case) class names, indexed by class id.
CLASS_NAMES = dict(enumerate(
    ["background", "antenna", "cable", "electric_pole", "wind_turbine"]
))
# Human-readable labels written to the CSV; background is never exported.
CLASS_LABELS_CSV = dict(zip(
    (1, 2, 3, 4),
    ("Antenna", "Cable", "Electric pole", "Wind turbine"),
))

# Per-class DBSCAN settings, keyed by class id.
DBSCAN_PARAMS = {
    1: dict(eps=2.0, min_samples=10),
    2: dict(eps=5.0, min_samples=5),
    3: dict(eps=2.0, min_samples=10),
    4: dict(eps=5.0, min_samples=15),
}

# Thresholds used when merging collinear cable clusters.
CABLE_MERGE_ANGLE_DEG = 15.0
CABLE_MERGE_GAP_M = 10.0

CHUNK_SIZE = 64 * 1024  # points per forward pass

# CSV header — Airbus deliverable format
CSV_HEADER = ",".join((
    "ego_x", "ego_y", "ego_z", "ego_yaw",
    "bbox_center_x", "bbox_center_y", "bbox_center_z",
    "bbox_width", "bbox_length", "bbox_height",
    "bbox_yaw",
    "class_ID", "class_label",
)) + "\n"

print("Config loaded.")

## Model (PointNetSegV4 — inline)

In [None]:
class SharedMLP(nn.Module):
    """Pointwise layer: kernel-size-1 Conv1d, optional BatchNorm1d, then ReLU.

    The convolution carries a bias only when batch norm is disabled.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        use_bias = not bn
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=use_bias)
        self.bn = nn.BatchNorm1d(out_ch) if bn else None

    def forward(self, x):
        """Apply conv -> (batch norm) -> ReLU to a channels-first tensor."""
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        # In-place ReLU on the freshly produced activation tensor.
        return F.relu(out, inplace=True)


class PointNetSegV4(nn.Module):
    """PointNet v4 per-point segmentation — multi-scale skips, ~1.88M params.

    Five pointwise encoder stages are followed by a decoder whose input is
    the concatenation of the first four stage outputs and a global feature
    obtained by max-pooling the fifth stage over all points.
    """

    def __init__(self, in_channels=5, num_classes=5):
        super().__init__()
        # Encoder stages (attribute names are part of the checkpoint layout).
        self.enc1 = SharedMLP(in_channels, 64)
        self.enc2 = SharedMLP(64, 128)
        self.enc3 = SharedMLP(128, 256)
        self.enc4 = SharedMLP(256, 512)
        self.enc5 = SharedMLP(512, 1024)
        # Decoder: skip features from enc1-enc4 plus the 1024-wide global feature.
        skip_channels = 64 + 128 + 256 + 512
        self.seg1 = SharedMLP(skip_channels + 1024, 512)
        self.seg2 = SharedMLP(512, 256)
        self.seg3 = SharedMLP(256, 128)
        self.dropout1 = nn.Dropout(0.4)
        self.dropout2 = nn.Dropout(0.3)
        self.head = nn.Conv1d(128, num_classes, 1)

    def forward(self, x):
        """Map point features of shape (B, N, C) to logits of shape (B, N, num_classes)."""
        _, num_points, _ = x.shape
        feat = x.transpose(1, 2)  # Conv1d expects channels first: (B, C, N)

        stage_outputs = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            feat = encoder(feat)
            stage_outputs.append(feat)

        # Replace the last stage by its max over points, broadcast back to N.
        pooled = stage_outputs.pop().max(dim=2, keepdim=True)[0]
        stage_outputs.append(pooled.expand(-1, -1, num_points))

        out = torch.cat(stage_outputs, dim=1)
        out = self.dropout1(self.seg1(out))
        out = self.dropout2(self.seg2(out))
        out = self.head(self.seg3(out))
        return out.transpose(1, 2)


# Load model
ckpt_path = os.path.join(CKPT_DIR, "best_model_v4.pt")
print(f"Loading {ckpt_path}...")

# Build the network with the configured channel/class counts, then restore
# the trained weights from the checkpoint.
model = PointNetSegV4(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects
# from the file — only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
# The checkpoint is a dict; the weights are stored under "model_state_dict".
model.load_state_dict(ckpt["model_state_dict"])
# Switch to evaluation mode before running inference.
model.eval()

n_params = sum(p.numel() for p in model.parameters())
print(f"PointNetSegV4: {n_params:,} params on {device}")
# Validation metadata is optional in the checkpoint; report it when present.
if "val_obstacle_miou" in ckpt:
    print(f"Epoch {ckpt.get('epoch', '?')}, val obstacle mIoU={ckpt['val_obstacle_miou']:.4f}")

## HDF5 Reader + Inference Functions

In [None]:
def get_frame_boundaries(h5_path, dataset_name="lidar_points", chunk_size=2_000_000):
    """Split a point dataset into frames wherever the ego pose changes.

    The dataset is scanned in slices of ``chunk_size`` rows so that it never
    has to be held in memory at once. A new frame starts at every row whose
    (ego_x, ego_y, ego_z, ego_yaw) differs from the previous row's; the same
    exact comparison is used inside a slice and across slice seams.

    Args:
        h5_path: Path of the HDF5 file to read.
        dataset_name: Name of the compound dataset holding the points.
        chunk_size: Number of rows read per slice.

    Returns:
        A list of ``(start, end, ego_x, ego_y, ego_z, ego_yaw)`` tuples, one
        per frame, where ``start``/``end`` are row indices (end exclusive)
        and the pose values are those of the frame's first row cast to
        ``int``. An empty dataset yields an empty list.
    """
    pose_fields = ("ego_x", "ego_y", "ego_z", "ego_yaw")
    frame_starts = []  # global row index at which each frame begins
    frame_poses = []   # pose of the row at that index

    with h5py.File(h5_path, "r") as f:
        ds = f[dataset_name]
        n = ds.shape[0]
        prev_last_pose = None
        for offset in range(0, n, chunk_size):
            chunk = ds[offset:min(offset + chunk_size, n)]
            cols = [chunk[name] for name in pose_fields]

            # A frame starts at the top of the slice for the very first row
            # and whenever the pose differs from the previous slice's last row.
            first_pose = tuple(col[0] for col in cols)
            if prev_last_pose is None or first_pose != prev_last_pose:
                frame_starts.append(offset)
                frame_poses.append(first_pose)

            # Rows inside the slice whose pose differs from the row before.
            changed = np.logical_or.reduce([col[1:] != col[:-1] for col in cols])
            local_starts = np.flatnonzero(changed) + 1
            frame_starts.extend((offset + local_starts).tolist())
            # Record the poses now, while the slice is in memory, instead of
            # re-reading one row per frame afterwards.
            frame_poses.extend(zip(*(col[local_starts].tolist() for col in cols)))

            prev_last_pose = tuple(col[-1] for col in cols)
            # Release the slice before reading the next one.
            del chunk, cols, changed, local_starts
            gc.collect()

    frame_ends = frame_starts[1:] + [n]
    return [
        (s, e, int(px), int(py), int(pz), int(pyaw))
        for s, e, (px, py, pz, pyaw) in zip(frame_starts, frame_ends, frame_poses)
    ]


def read_frame_for_inference(h5_path, start, end, dataset_name="lidar_points"):
    """Load rows ``start:end`` of one frame and build the model inputs.

    Rows with ``distance_cm`` <= 0 are dropped. The remaining returns are
    converted from range/azimuth/elevation to Cartesian coordinates, with
    range divided by 100 and both angles divided by 100 before being passed
    through ``np.radians``.

    Returns: xyz_m (N,3), features (N,5) — both float32; the feature columns
    are x, y, z, reflectivity / 255 and range / 300.
    """
    with h5py.File(h5_path, "r") as h5:
        rows = h5[dataset_name][start:end]
    rows = rows[rows["distance_cm"] > 0]

    # Do the trigonometry in float64, then store the result as float32.
    range_m = rows["distance_cm"].astype(np.float64) / 100.0
    azimuth = np.radians(rows["azimuth_raw"].astype(np.float64) / 100.0)
    elevation = np.radians(rows["elevation_raw"].astype(np.float64) / 100.0)

    horizontal = range_m * np.cos(elevation)
    xyz = np.column_stack((
        horizontal * np.cos(azimuth),
        -horizontal * np.sin(azimuth),  # y carries a sign flip relative to sin(azimuth)
        range_m * np.sin(elevation),
    )).astype(np.float32)

    reflectivity = rows["reflectivity"].astype(np.float32) / 255.0
    range_norm = range_m.astype(np.float32) / 300.0
    features = np.hstack([xyz, reflectivity[:, None], range_norm[:, None]])
    return xyz, features


@torch.no_grad()
def predict_frame(model, features_np, device, chunk_size=65536):
    """Predict one class id per point, feeding the model chunk by chunk.

    At most ``chunk_size`` points go through a forward pass. A chunk with
    fewer than 128 points is zero-padded to 128 rows first; predictions for
    the padding rows are discarded.

    Returns an int64 array with one label per row of ``features_np``.
    """
    min_rows = 128
    total = len(features_np)
    labels = np.zeros(total, dtype=np.int64)

    for lo in range(0, total, chunk_size):
        hi = min(lo + chunk_size, total)
        batch = features_np[lo:hi]
        count = len(batch)
        if count < min_rows:
            filled = np.zeros((min_rows, batch.shape[1]), dtype=np.float32)
            filled[:count] = batch
            batch = filled

        inputs = torch.from_numpy(batch).unsqueeze(0).to(device)
        logits = model(inputs)
        labels[lo:hi] = logits[0, :count].argmax(dim=-1).cpu().numpy()
        # Drop the tensors before the next chunk is moved to the device.
        del inputs, logits

    return labels

# Cell-completion marker so the notebook output shows the helpers were defined.
print("Reader + inference functions defined.")

## Clustering + Bounding Boxes

In [None]:
def _axis_aligned_bbox(points_m):
    """Axis-aligned box with zero yaw, used when PCA is not usable."""
    lo = points_m.min(axis=0)
    hi = points_m.max(axis=0)
    return {"center_xyz": (lo + hi) / 2.0, "dimensions": hi - lo, "yaw": 0.0}


def pca_oriented_bbox(points_m):
    """Fit a box aligned with the principal axes of a point set.

    Returns a dict with ``center_xyz`` (box centre in the input frame),
    ``dimensions`` (extents along the principal axes, ordered by decreasing
    eigenvalue) and ``yaw`` (angle of the first principal axis projected
    onto the XY plane). Falls back to an axis-aligned box with yaw 0 when
    the covariance is not finite or its eigen-decomposition fails.
    """
    centroid = points_m.mean(axis=0)
    offsets = points_m - centroid
    cov = np.cov(offsets.T)
    if not np.all(np.isfinite(cov)):
        return _axis_aligned_bbox(points_m)

    try:
        evals, evecs = np.linalg.eigh(cov)
    except np.linalg.LinAlgError:
        return _axis_aligned_bbox(points_m)

    # Columns of `axes` are the principal directions, strongest first.
    axes = evecs[:, evals.argsort()[::-1]]
    coords = offsets @ axes
    lo, hi = coords.min(axis=0), coords.max(axis=0)

    # The box midpoint is found in PCA coordinates and mapped back.
    center = centroid + axes @ ((lo + hi) / 2.0)
    heading = np.arctan2(axes[1, 0], axes[0, 0])
    return {"center_xyz": center, "dimensions": hi - lo, "yaw": float(heading)}


def cluster_class_points(points_m, class_id, max_points=10000):
    """Cluster the points of one class with DBSCAN.

    ``eps`` and ``min_samples`` come from ``DBSCAN_PARAMS[class_id]``. When
    there are more than ``max_points`` points, DBSCAN runs on a random
    subsample of that size; every original point then inherits the label of
    its nearest clustered sample, unless that sample is farther away than
    ``2 * eps``, in which case the point is treated as noise.

    Args:
        points_m: (N, 3) array of point coordinates.
        class_id: Key into ``DBSCAN_PARAMS``.
        max_points: Largest point count DBSCAN is run on directly.

    Returns:
        A list of point arrays, one per cluster, ordered by cluster label.
        Noise points are dropped. Empty when there are fewer than
        ``min_samples`` points or no cluster is found.
    """
    params = DBSCAN_PARAMS[class_id]
    eps, min_samples = params["eps"], params["min_samples"]
    if len(points_m) < min_samples:
        return []

    dbscan = DBSCAN(eps=eps, min_samples=min_samples, algorithm="ball_tree")
    if len(points_m) <= max_points:
        labels = dbscan.fit_predict(points_m)
    else:
        from sklearn.neighbors import BallTree

        # NOTE: the subsample is drawn from NumPy's global RNG, so results
        # are only reproducible if the caller seeds it.
        idx = np.random.choice(len(points_m), max_points, replace=False)
        sample = points_m[idx]
        sample_labels = dbscan.fit_predict(sample)

        clustered = sample_labels >= 0
        if not clustered.any():
            return []

        # Propagate labels to all points via their nearest clustered sample.
        # query() already returns the neighbour distance, so it is reused
        # rather than recomputed.
        tree = BallTree(sample[clustered])
        dists, nearest = tree.query(points_m, k=1)
        labels = sample_labels[clustered][nearest.ravel()]
        labels[dists.ravel() > eps * 2] = -1

    return [points_m[labels == lbl] for lbl in np.unique(labels) if lbl != -1]


def _cable_principal_axis(pts):
    """Return the dominant PCA direction of a cluster, or None if unusable.

    Clusters with fewer than four points, a non-finite covariance or a
    failed eigen-decomposition have no axis. The sign is normalised so that
    the x component is non-negative.
    """
    if len(pts) < 4:
        return None
    cov = np.cov((pts - pts.mean(axis=0)).T)
    if not np.all(np.isfinite(cov)):
        return None
    try:
        eigvals, eigvecs = np.linalg.eigh(cov)
    except np.linalg.LinAlgError:
        return None
    axis = eigvecs[:, eigvals.argsort()[-1]]
    return -axis if axis[0] < 0 else axis


def merge_cable_clusters(clusters):
    """Merge cable clusters that are roughly collinear and close together.

    Two clusters are merged when the angle between their principal axes is
    at most ``CABLE_MERGE_ANGLE_DEG`` and the distance between their centres,
    minus each cluster's half-extent along its own axis, is at most
    ``CABLE_MERGE_GAP_M``. Each cluster is compared against the original
    geometry of the earlier cluster it may join (a single greedy pass).

    Returns a list of point arrays; inputs of length <= 1 are returned as is.
    """
    if len(clusters) <= 1:
        return clusters

    max_angle = np.radians(CABLE_MERGE_ANGLE_DEG)
    max_gap = CABLE_MERGE_GAP_M
    centers = [pts.mean(axis=0) for pts in clusters]
    axes = [_cable_principal_axis(pts) for pts in clusters]

    def half_extent(k):
        # Largest absolute projection of cluster k onto its own axis.
        return np.abs((clusters[k] - centers[k]) @ axes[k]).max()

    absorbed = [False] * len(clusters)
    merged = []
    for i, seed in enumerate(clusters):
        if absorbed[i]:
            continue
        group = seed
        if axes[i] is not None:
            for j in range(i + 1, len(clusters)):
                if absorbed[j] or axes[j] is None:
                    continue
                # |cos| makes the test independent of axis direction.
                cos_angle = min(abs(np.dot(axes[i], axes[j])), 1.0)
                if np.arccos(cos_angle) > max_angle:
                    continue
                center_dist = np.linalg.norm(centers[i] - centers[j])
                if center_dist - half_extent(i) - half_extent(j) <= max_gap:
                    group = np.vstack([group, clusters[j]])
                    absorbed[j] = True
        merged.append(group)
    return merged


def predictions_to_boxes(xyz_m, predictions):
    """Turn per-point class predictions into a list of bounding-box dicts.

    For each obstacle class (ids 1-4) the matching points are clustered;
    cable clusters (id 2) are additionally merged when more than one is
    found. Clusters with fewer than three points are skipped. Each box dict
    holds ``center_xyz``, ``dimensions``, ``yaw``, ``class_id``,
    ``class_label`` and ``num_points``.
    """

    def make_box(pts, cid):
        fitted = pca_oriented_bbox(pts)
        return {
            "center_xyz": fitted["center_xyz"],
            "dimensions": fitted["dimensions"],
            "yaw": fitted["yaw"],
            "class_id": cid,
            "class_label": CLASS_LABELS_CSV[cid],
            "num_points": len(pts),
        }

    boxes = []
    for cid in (1, 2, 3, 4):
        selected = predictions == cid
        if not selected.any():
            continue
        clusters = cluster_class_points(xyz_m[selected], cid)
        if cid == 2 and len(clusters) > 1:
            clusters = merge_cable_clusters(clusters)
        boxes.extend(make_box(pts, cid) for pts in clusters if len(pts) >= 3)
    return boxes

# Cell-completion marker so the notebook output shows the helpers were defined.
print("Clustering + bbox functions defined.")

## Run Inference

In [None]:
# === SELECT INPUT FILES ===
# For evaluation: point to the eval data directory
# For testing: use a single training scene

INPUT_DIR = DATA_DIR  # change to eval dir when ready
# Sorted so the listing (and any later loop over it) has a stable order.
h5_files = sorted(glob.glob(os.path.join(INPUT_DIR, "*.h5")))
print(f"Found {len(h5_files)} HDF5 files:")
# List only the file names; the directory is the same for all of them.
for f in h5_files:
    print(f"  {os.path.basename(f)}")

In [None]:
# === RUN ON ONE SCENE (for testing) ===
# Change to loop over all files for final evaluation

h5_path = os.path.join(DATA_DIR, "scene_8.h5")  # validation scene
scene_name = "scene_8"
output_csv = os.path.join(OUTPUT_DIR, f"{scene_name}.csv")

print(f"Processing {scene_name}...")
t0 = time.time()

# Frame boundaries
frames_info = get_frame_boundaries(h5_path)
n_frames = len(frames_info)
print(f"{n_frames} frames found ({time.time()-t0:.1f}s)")

total_boxes = 0
class_counts = {1: 0, 2: 0, 3: 0, 4: 0}
# Most recent frame that contained valid points, kept for the visualization
# cell (stays None only if no frame had any).
last_xyz = None
last_preds = None
last_boxes = None

# Write CSV: the file is opened once for the whole scene instead of being
# reopened for append on every frame.
with open(output_csv, "w") as f:
    f.write(CSV_HEADER)

    for idx, (start, end, ego_x, ego_y, ego_z, ego_yaw) in enumerate(frames_info):
        xyz_m, features = read_frame_for_inference(h5_path, start, end)

        if len(xyz_m) > 0:
            predictions = predict_frame(model, features, device, chunk_size=CHUNK_SIZE)
            del features  # release the inputs before clustering

            boxes = predictions_to_boxes(xyz_m, predictions)

            lines = []
            for box in boxes:
                c = box["center_xyz"]
                d = box["dimensions"]
                # NOTE(review): `dimensions` follows the PCA axis order of
                # pca_oriented_bbox (largest extent first), not world axes —
                # confirm this matches the width/length/height columns.
                lines.append(
                    f"{ego_x},{ego_y},{ego_z},{ego_yaw},"
                    f"{c[0]:.4f},{c[1]:.4f},{c[2]:.4f},"
                    f"{d[0]:.4f},{d[1]:.4f},{d[2]:.4f},"
                    f"{box['yaw']:.4f},"
                    f"{box['class_id']},{box['class_label']}\n"
                )
                class_counts[box["class_id"]] += 1
            f.writelines(lines)
            total_boxes += len(boxes)

            # Holding references is enough: nothing mutates these arrays, and
            # the previous frame's data is released when they are rebound.
            last_xyz, last_preds, last_boxes = xyz_m, predictions, boxes
            del predictions, boxes

        del xyz_m
        gc.collect()

        # Report progress for every frame position, including empty frames.
        if (idx + 1) % 10 == 0 or idx == n_frames - 1:
            f.flush()  # keep the on-disk CSV roughly in step with the progress log
            print(f"  {idx+1}/{n_frames} frames, {total_boxes} boxes")

elapsed = time.time() - t0
frames_for_avg = max(n_frames, 1)  # avoid dividing by zero on an empty scene
print(f"\nDONE — {total_boxes} boxes from {n_frames} frames in {elapsed:.0f}s ({elapsed/60:.1f} min)")
print(f"Avg: {elapsed/frames_for_avg:.2f}s/frame, {total_boxes/frames_for_avg:.1f} boxes/frame")
print("Per class: " + ", ".join(f"{CLASS_NAMES[c]}={class_counts[c]}" for c in class_counts))
print(f"Output: {output_csv}")

del frames_info
gc.collect()

## Visualization

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# RGB colour (components in 0-1) per predicted class id.
CLASS_COLORS_PLOT = {
    0: [0.7, 0.7, 0.7],  # background — gray
    1: [0.15, 0.09, 0.71],  # antenna — blue
    2: [0.69, 0.52, 0.18],  # cable — brown
    3: [0.51, 0.32, 0.38],  # electric_pole — mauve
    4: [0.26, 0.52, 0.04],  # wind_turbine — green
}

# last_xyz / last_preds / last_boxes are set by the inference cell; they stay
# None when no frame was captured for plotting.
if last_xyz is not None:
    # Subsample for plotting
    n_plot = min(50000, len(last_xyz))
    idx_plot = np.random.choice(len(last_xyz), n_plot, replace=False)

    # One RGB row per plotted point, looked up from its predicted class.
    colors = np.array([CLASS_COLORS_PLOT[p] for p in last_preds[idx_plot]])

    fig = plt.figure(figsize=(16, 10))

    # Top-down view (XY)
    ax1 = fig.add_subplot(121)
    ax1.scatter(last_xyz[idx_plot, 0], last_xyz[idx_plot, 1],
                c=colors, s=0.3, alpha=0.5)
    ax1.set_xlabel("X (m)"); ax1.set_ylabel("Y (m)")
    ax1.set_title(f"Top-down — {scene_name} last frame")
    ax1.set_aspect("equal")

    # Mark bounding box centers
    if last_boxes:
        for box in last_boxes:
            c = box["center_xyz"]
            color = CLASS_COLORS_PLOT[box["class_id"]]
            ax1.plot(c[0], c[1], 'x', color=color, markersize=10, markeredgewidth=2)
            ax1.annotate(box["class_label"], (c[0], c[1]),
                         fontsize=7, color=color, ha='center', va='bottom')

    # Side view (XZ)
    ax2 = fig.add_subplot(122)
    ax2.scatter(last_xyz[idx_plot, 0], last_xyz[idx_plot, 2],
                c=colors, s=0.3, alpha=0.5)
    ax2.set_xlabel("X (m)"); ax2.set_ylabel("Z (m)")
    ax2.set_title(f"Side view — {scene_name} last frame")

    # Same box-centre markers as the top-down view, without text labels.
    if last_boxes:
        for box in last_boxes:
            c = box["center_xyz"]
            color = CLASS_COLORS_PLOT[box["class_id"]]
            ax2.plot(c[0], c[2], 'x', color=color, markersize=10, markeredgewidth=2)

    plt.tight_layout()
    plt.show()
    # Close the figure explicitly so it does not stay open after display.
    plt.close(fig)
else:
    print("No frame data for visualization.")

## Summary Stats

In [None]:
import pandas as pd

# Re-read the CSV written by the inference cell and summarise its contents.
df = pd.read_csv(output_csv)
print(f"Output CSV: {output_csv}")
print(f"Total rows (boxes): {len(df)}")
# Each CSV row repeats its frame's ego pose, so distinct poses count the
# frames that produced at least one box.
print(f"Unique frames: {df[['ego_x','ego_y','ego_z','ego_yaw']].drop_duplicates().shape[0]}")
print(f"\nBoxes per class:")
print(df['class_label'].value_counts())
print(f"\nBbox size stats (m):")
for col in ['bbox_width', 'bbox_length', 'bbox_height']:
    print(f"  {col}: mean={df[col].mean():.2f}, median={df[col].median():.2f}, max={df[col].max():.2f}")
print(f"\nFirst 5 rows:")
# Left as the cell's final expression so the notebook renders the table.
df.head()

## Full Evaluation Run (all files)

Uncomment and run when evaluation files are available.

In [None]:
# === FULL EVALUATION RUN ===
# Uncomment when eval files arrive (2 scenes × 4 densities = 8 files)

# EVAL_DIR = f"{DRIVE_BASE}/eval_data"  # <-- adjust path
# eval_files = sorted(glob.glob(os.path.join(EVAL_DIR, "*.h5")))
# print(f"Eval files: {len(eval_files)}")
#
# for h5_path in eval_files:
#     scene_name = os.path.splitext(os.path.basename(h5_path))[0]
#     output_csv = os.path.join(OUTPUT_DIR, f"{scene_name}.csv")
#     print(f"\n{'='*60}")
#     print(f"Processing {scene_name}...")
#     t0 = time.time()
#
#     frames_info = get_frame_boundaries(h5_path)
#     n_frames = len(frames_info)
#     with open(output_csv, "w") as f:
#         f.write(CSV_HEADER)
#
#     scene_boxes = 0
#     for idx in range(n_frames):
#         start, end, ego_x, ego_y, ego_z, ego_yaw = frames_info[idx]
#         xyz_m, features = read_frame_for_inference(h5_path, start, end)
#         if len(xyz_m) == 0:
#             continue
#         predictions = predict_frame(model, features, device, chunk_size=CHUNK_SIZE)
#         del features
#         boxes = predictions_to_boxes(xyz_m, predictions)
#         del xyz_m, predictions
#         if boxes:
#             lines = []
#             for box in boxes:
#                 c, d = box["center_xyz"], box["dimensions"]
#                 lines.append(
#                     f"{ego_x},{ego_y},{ego_z},{ego_yaw},"
#                     f"{c[0]:.4f},{c[1]:.4f},{c[2]:.4f},"
#                     f"{d[0]:.4f},{d[1]:.4f},{d[2]:.4f},"
#                     f"{box['yaw']:.4f},"
#                     f"{box['class_id']},{box['class_label']}\n"
#                 )
#             with open(output_csv, "a") as f:
#                 f.writelines(lines)
#             scene_boxes += len(boxes)
#         del boxes; gc.collect()
#         if (idx + 1) % 10 == 0:
#             print(f"  {idx+1}/{n_frames} frames, {scene_boxes} boxes")
#
#     print(f"DONE — {scene_boxes} boxes in {time.time()-t0:.0f}s")
#     del frames_info; gc.collect()
#
# print("\nAll eval files processed!")