# 05 — Inference Pipeline (v6 — D-Day Ready)

**Notebook Colab GPU** pour l'inférence du jour J.

**Améliorations v6 :**
- class_ID corrigé : 0,1,2,3 (pas 1,2,3,4)
- class_label : "Electric Pole", "Wind Turbine" (majuscules)
- Box confidence filtering (seuil 0.6)
- DBSCAN resserré + min_points rehaussés
- TTA optionnel (4x rotation Z)

**Usage jour J :**
1. Uploader les fichiers d'évaluation sur Drive dans `eval_data/`
2. Lancer toutes les cellules (Runtime > Run all)
3. Les CSVs sont écrits dans `outputs/pred_v6/` (valeur de `OUTPUT_DIR`)

In [None]:
# Mount Google Drive so that files stored there are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Notebook shell command (not plain Python): quietly install h5py and scikit-learn.
!pip install -q h5py scikit-learn

In [None]:
import gc
import glob
import os
import time

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import DBSCAN
from sklearn.neighbors import BallTree

# ==========================================================================
# PATHS — edit here for D-day
# ==========================================================================
# Project root on the mounted Google Drive.
DRIVE_BASE = "/content/drive/MyDrive/airbus_hackathon"

# --- D-DAY: uncomment the EVAL line and comment out the TRAINING line ---
INPUT_DIR = f"{DRIVE_BASE}/data"                    # TRAINING (test)
# INPUT_DIR = f"{DRIVE_BASE}/eval_data"             # D-DAY (evaluation)

# Folder receiving the prediction CSVs; created if it does not exist yet.
OUTPUT_DIR = f"{DRIVE_BASE}/outputs/pred_v6"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Checkpoint (auto-select best): use the v5 file when it exists, otherwise
# fall back to the v4 path.
# NOTE(review): the v4 path itself is not checked for existence here.
CKPT_V5 = f"{DRIVE_BASE}/checkpoints_v5/best_model_v5.pt"
CKPT_V4 = f"{DRIVE_BASE}/checkpoints_v4/best_model_v4.pt"
CKPT_PATH = CKPT_V5 if os.path.exists(CKPT_V5) else CKPT_V4

# ==========================================================================
# OPTIONS
# ==========================================================================
USE_TTA = False          # True = 4x rotation averaging (slower, more robust)
SINGLE_SCENE = "scene_8" # None = process ALL .h5 files of the input folder

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Echo the effective configuration so it is visible in the notebook output.
print(f"Checkpoint: {CKPT_PATH}")
print(f"Input dir:  {INPUT_DIR}")
print(f"Output dir: {OUTPUT_DIR}")
print(f"Device:     {device}")
print(f"TTA:        {'ON (4x rotations)' if USE_TTA else 'OFF'}")
print(f"Mode:       {'Single scene (' + SINGLE_SCENE + ')' if SINGLE_SCENE else 'ALL scenes'}")
if torch.cuda.is_available():
    print(f"GPU:        {torch.cuda.get_device_name()}")
    # total_memory is in bytes; printed in decimal gigabytes.
    print(f"VRAM:       {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Config (v6 — calibrated on scene_8 validation)

In [None]:
# Number of segmentation classes, background (0) included.
NUM_CLASSES = 5
IN_CHANNELS = 5  # x, y, z, reflectivity, norm_distance
CHUNK_SIZE = 65536  # points per forward pass

# Internal class ids as used by the segmentation model.
CLASS_NAMES = {0: "background", 1: "antenna", 2: "cable", 3: "electric_pole", 4: "wind_turbine"}

# Airbus spec labels — capitalization matters!
CLASS_LABELS_CSV = {1: "Antenna", 2: "Cable", 3: "Electric Pole", 4: "Wind Turbine"}

# Internal class_id (1-4) → Airbus class_ID (0-3)
CLASS_ID_TO_AIRBUS = {1: 0, 2: 1, 3: 2, 4: 3}

# DBSCAN — tightened for noisy predictions (v6)
# Keyed by internal class id; eps is expressed in the units of the clustered
# point coordinates (presumably metres — the reader divides distance_cm by 100).
DBSCAN_PARAMS = {
    1: {"eps": 2.0, "min_samples": 15},    # Antenna
    2: {"eps": 5.0, "min_samples": 8},     # Cable
    3: {"eps": 2.0, "min_samples": 12},    # Electric pole
    4: {"eps": 5.0, "min_samples": 25},    # Wind turbine
}

# Cable clusters are merged when their main axes differ by at most this angle
# and the gap between them along those axes is at most this distance.
CABLE_MERGE_ANGLE_DEG = 15.0
CABLE_MERGE_GAP_M = 10.0

# Post-processing thresholds (v6 — calibrated on scene_8)
CONFIDENCE_THRESHOLD = 0.3       # per-point softmax filter
BOX_CONFIDENCE_THRESHOLD = 0.6   # per-box mean confidence filter
# Per-class minimum number of points and maximum box extent for a box to be kept.
MIN_POINTS_PER_BOX = {1: 15, 2: 5, 3: 10, 4: 25}
MAX_DIM_PER_CLASS = {1: 200.0, 2: 400.0, 3: 100.0, 4: 250.0}
# Same-class boxes overlapping more than this IoU are suppressed.
NMS_IOU_THRESHOLD = 0.3

# Header line written at the top of every output CSV.
CSV_HEADER = ("ego_x,ego_y,ego_z,ego_yaw,"
              "bbox_center_x,bbox_center_y,bbox_center_z,"
              "bbox_width,bbox_length,bbox_height,"
              "bbox_yaw,"
              "class_ID,class_label\n")

print("Config v6 loaded.")

# Model + Load checkpoint

In [None]:
class SharedMLP(nn.Module):
    """Point-wise layer: 1x1 ``Conv1d``, optional ``BatchNorm1d``, then ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        bn: When True, batch normalisation follows the convolution and the
            convolution is created without a bias term.

    The sub-module names ``conv`` and ``bn`` are part of the checkpoint
    format (state-dict keys) and must not be renamed.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        # bias=not bn: the conv bias is only created when no batch norm follows.
        self.conv = nn.Conv1d(in_ch, out_ch, 1, bias=not bn)
        self.bn = nn.BatchNorm1d(out_ch) if bn else None

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, in_ch, n_points)."""
        x = self.conv(x)
        # Explicit None check: do not depend on the truthiness of a module
        # object to decide whether normalisation is enabled.
        if self.bn is not None:
            x = self.bn(x)
        return F.relu(x, inplace=True)


class PointNetSegV4(nn.Module):
    """PointNet-style per-point segmentation network with multi-scale skips.

    Five point-wise encoder layers (64 → 1024 channels) feed a segmentation
    head that sees the first four encoder outputs concatenated with the
    max-pooled output of the fifth one. The original note quotes roughly
    1.88M parameters. Attribute names are kept as-is because they define the
    checkpoint's state-dict keys.
    """

    def __init__(self, in_channels=5, num_classes=5):
        super().__init__()
        # Encoder: enc1..enc5, registered in that order.
        widths = (in_channels, 64, 128, 256, 512, 1024)
        for level in range(1, len(widths)):
            setattr(self, f"enc{level}", SharedMLP(widths[level - 1], widths[level]))
        # Segmentation head input = four skip features + the global feature.
        skip_channels = sum(widths[1:])
        self.seg1 = SharedMLP(skip_channels, 512)
        self.seg2 = SharedMLP(512, 256)
        self.seg3 = SharedMLP(256, 128)
        self.dropout1 = nn.Dropout(0.4)
        self.dropout2 = nn.Dropout(0.3)
        self.head = nn.Conv1d(128, num_classes, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, n_points, channels) to per-point logits
        of shape (batch, n_points, num_classes)."""
        _, num_points, _ = x.shape
        feats = x.transpose(1, 2)  # Conv1d expects (batch, channels, n_points)
        per_level = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            feats = encoder(feats)
            per_level.append(feats)
        # Max-pool the deepest features over the points, then broadcast the
        # resulting global descriptor back to every point.
        deepest = per_level.pop()
        global_feat = deepest.max(dim=2, keepdim=True)[0].expand(-1, -1, num_points)
        out = torch.cat(per_level + [global_feat], dim=1)
        out = self.dropout1(self.seg1(out))
        out = self.dropout2(self.seg2(out))
        out = self.head(self.seg3(out))
        return out.transpose(1, 2)


# Load model
print(f"Loading {CKPT_PATH}...")
model = PointNetSegV4(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects — only load checkpoints that come from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=False)
# The checkpoint is a dict; the network weights live under "model_state_dict".
model.load_state_dict(ckpt["model_state_dict"])
# Inference mode: disables dropout and makes batch norm use its running statistics.
model.eval()

# Report the model size and, when stored in the checkpoint, its validation score.
n_params = sum(p.numel() for p in model.parameters())
print(f"PointNetSegV4: {n_params:,} params on {device}")
if "val_obstacle_miou" in ckpt:
    print(f"Epoch {ckpt.get('epoch', '?')}, val obstacle mIoU={ckpt['val_obstacle_miou']:.4f}")

# HDF5 Reader + Inference + TTA

In [None]:
# === HDF5 READER ===

def get_frame_boundaries(h5_path, dataset_name="lidar_points", chunk_size=2_000_000):
    """Split a point dataset into frames wherever the ego pose changes.

    The dataset is scanned in chunks of ``chunk_size`` rows so the whole file
    never has to be held in memory. A new frame starts at every row whose
    (ego_x, ego_y, ego_z, ego_yaw) differs from the previous row's.

    Args:
        h5_path: Path of the HDF5 file.
        dataset_name: Name of the structured dataset holding the points.
        chunk_size: Number of rows read per scan step.

    Returns:
        A list of ``(start, end, ego_x, ego_y, ego_z, ego_yaw)`` tuples, one
        per frame, with ``end`` exclusive and the pose taken from the frame's
        first row and converted with ``int``. The list is empty when the
        dataset has no rows.
    """
    pose_fields = ("ego_x", "ego_y", "ego_z", "ego_yaw")
    change_indices = []
    frames = []
    with h5py.File(h5_path, "r") as f:
        ds = f[dataset_name]
        n = ds.shape[0]
        if n == 0:
            # Nothing to split; reading row 0 below would raise an IndexError.
            return frames

        prev_last_pose = None
        for offset in range(0, n, chunk_size):
            chunk = ds[offset:min(offset + chunk_size, n)]
            poses = [chunk[name] for name in pose_fields]

            # A frame may start exactly on a chunk boundary. Compare the raw
            # values (no int truncation) so the test matches the in-chunk one.
            cur_first = tuple(p[0].item() for p in poses)
            if prev_last_pose is not None and cur_first != prev_last_pose:
                change_indices.append(offset)

            # Inside the chunk: a row opens a new frame when any pose field
            # differs from the row before it.
            changed = np.zeros(len(chunk) - 1, dtype=bool)
            for p in poses:
                changed |= np.diff(p) != 0
            change_indices.extend(offset + int(c) for c in np.flatnonzero(changed) + 1)

            prev_last_pose = tuple(p[-1].item() for p in poses)
            # Release the chunk before the next read to keep peak memory low.
            del chunk, poses, changed
            gc.collect()

        starts = [0] + change_indices
        ends = change_indices + [n]
        for s, e in zip(starts, ends):
            row = ds[s]
            frames.append((s, e, int(row["ego_x"]), int(row["ego_y"]),
                           int(row["ego_z"]), int(row["ego_yaw"])))
    return frames


def read_frame_for_inference(h5_path, start, end, dataset_name="lidar_points"):
    """Load rows ``start:end`` and build the model inputs for one frame.

    Rows whose ``distance_cm`` is not positive are dropped. The spherical
    measurements (distance, azimuth, elevation — the raw angles are divided by
    100 before being treated as degrees) are converted to Cartesian
    coordinates, with the sign of ``y`` flipped.

    Returns:
        ``(xyz, features)``: ``xyz`` is a float32 array of shape (n, 3) and
        ``features`` a float32 array of shape (n, 5) holding ``xyz``, the
        reflectivity divided by 255 and the distance divided by 300.
    """
    with h5py.File(h5_path, "r") as h5_file:
        raw = h5_file[dataset_name][start:end]
    returns = raw[raw["distance_cm"] > 0]

    # Trigonometry is done in float64; results are stored as float32.
    range_m = returns["distance_cm"].astype(np.float64) / 100.0
    azimuth = np.radians(returns["azimuth_raw"].astype(np.float64) / 100.0)
    elevation = np.radians(returns["elevation_raw"].astype(np.float64) / 100.0)

    ground_range = range_m * np.cos(elevation)
    xyz = np.column_stack((
        ground_range * np.cos(azimuth),
        -ground_range * np.sin(azimuth),
        range_m * np.sin(elevation),
    )).astype(np.float32)

    reflectivity = returns["reflectivity"].astype(np.float32) / 255.0
    range_norm = range_m.astype(np.float32) / 300.0
    features = np.column_stack((xyz, reflectivity, range_norm))
    return xyz, features


# === INFERENCE ===

@torch.no_grad()
def predict_frame_standard(model, features_np, device, chunk_size=CHUNK_SIZE):
    """Run the model chunk by chunk and return ``(predictions, confidences)``.

    ``predictions`` is an int64 array of class ids and ``confidences`` a
    float32 array holding the softmax probability of the arg-max class.
    Non-background points whose confidence is below ``CONFIDENCE_THRESHOLD``
    are relabelled as background (0); their stored confidence is left as is.
    """
    min_batch = 128  # chunks shorter than this are zero-padded up to it
    total = len(features_np)
    predictions = np.zeros(total, dtype=np.int64)
    confidences = np.zeros(total, dtype=np.float32)

    for lo in range(0, total, chunk_size):
        batch = features_np[lo:lo + chunk_size]
        n_real = len(batch)
        hi = lo + n_real

        if n_real < min_batch:
            model_input = np.zeros((min_batch, batch.shape[1]), dtype=np.float32)
            model_input[:n_real] = batch
        else:
            model_input = batch

        logits = model(torch.from_numpy(model_input).unsqueeze(0).to(device))
        # Only the rows of real points are kept; padded rows are discarded.
        probabilities = F.softmax(logits[0, :n_real], dim=-1)
        top_prob, top_class = probabilities.max(dim=-1)
        top_class = top_class.cpu().numpy()
        top_prob = top_prob.cpu().numpy()

        uncertain = (top_class > 0) & (top_prob < CONFIDENCE_THRESHOLD)
        top_class[uncertain] = 0

        predictions[lo:hi] = top_class
        confidences[lo:hi] = top_prob
    return predictions, confidences


# === TTA ===

@torch.no_grad()
def _get_logits_chunked(model, features_np, device, chunk_size=CHUNK_SIZE):
    """Return the raw logits, shape (n, NUM_CLASSES) float32, for all points.

    The points are pushed through the model ``chunk_size`` at a time; a chunk
    shorter than 128 points is zero-padded to 128 and the padded rows are
    dropped from the output.
    """
    min_batch = 128
    total = len(features_np)
    logits_out = np.zeros((total, NUM_CLASSES), dtype=np.float32)

    for lo in range(0, total, chunk_size):
        batch = features_np[lo:lo + chunk_size]
        n_real = len(batch)

        if n_real < min_batch:
            model_input = np.zeros((min_batch, batch.shape[1]), dtype=np.float32)
            model_input[:n_real] = batch
        else:
            model_input = batch

        batch_logits = model(torch.from_numpy(model_input).unsqueeze(0).to(device))
        logits_out[lo:lo + n_real] = batch_logits[0, :n_real].cpu().numpy()
    return logits_out


def _rotate_features_z(features_np, angle_rad):
    rotated = features_np.copy()
    cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
    x, y = features_np[:, 0], features_np[:, 1]
    rotated[:, 0] = cos_a * x - sin_a * y
    rotated[:, 1] = sin_a * x + cos_a * y
    return rotated


@torch.no_grad()
def predict_frame_tta(model, features_np, device, chunk_size=CHUNK_SIZE):
    """Test-time augmentation: average the logits of four Z-rotated copies
    (0°, 90°, 180°, 270°) and return ``(predictions, confidences)``.

    The averaged logits are turned into probabilities with a softmax;
    non-background points below ``CONFIDENCE_THRESHOLD`` become background.
    """
    rotations = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
    num_points = len(features_np)

    mean_logits = np.zeros((num_points, NUM_CLASSES), dtype=np.float32)
    for theta in rotations:
        # The unrotated pass reuses the input array instead of copying it.
        view = features_np if theta == 0 else _rotate_features_z(features_np, theta)
        mean_logits += _get_logits_chunked(model, view, device, chunk_size)
    mean_logits /= len(rotations)

    # Softmax, shifted by the row maximum for numerical stability.
    probabilities = np.exp(mean_logits - mean_logits.max(axis=1, keepdims=True))
    probabilities /= probabilities.sum(axis=1, keepdims=True)

    predictions = probabilities.argmax(axis=1).astype(np.int64)
    confidences = probabilities[np.arange(num_points), predictions].astype(np.float32)
    uncertain = (predictions > 0) & (confidences < CONFIDENCE_THRESHOLD)
    predictions[uncertain] = 0
    return predictions, confidences


# Select predict function
# Bound once here according to USE_TTA; both candidates share the same
# signature and return (predictions, confidences).
predict_frame = predict_frame_tta if USE_TTA else predict_frame_standard
print(f"Functions defined. Predict: {'TTA' if USE_TTA else 'standard'}")

# Clustering + Bounding Boxes + Post-processing (v6)

In [None]:
def pca_oriented_bbox(points_m):
    """Fit an oriented bounding box to ``points_m`` (shape (n, 3)) using PCA.

    Returns a dict with:
        ``center_xyz``: box centre in the input coordinate frame,
        ``dimensions``: extents along the principal axes, ordered by
            decreasing variance,
        ``yaw``: angle of the first principal axis projected on the XY plane.

    When the covariance is not finite or cannot be decomposed, an
    axis-aligned box with ``yaw`` 0.0 is returned instead.

    NOTE(review): ``dimensions`` follows the PCA axes, not a fixed
    width/length/height order — confirm this matches what the CSV expects.
    """
    def _axis_aligned_box():
        lower, upper = points_m.min(axis=0), points_m.max(axis=0)
        return {"center_xyz": (lower + upper) / 2.0, "dimensions": upper - lower, "yaw": 0.0}

    mean_xyz = points_m.mean(axis=0)
    offsets = points_m - mean_xyz
    covariance = np.cov(offsets.T)
    if not np.isfinite(covariance).all():
        return _axis_aligned_box()
    try:
        eigvals, eigvecs = np.linalg.eigh(covariance)
    except np.linalg.LinAlgError:
        return _axis_aligned_box()

    # Principal axes as columns, strongest variance first.
    axes = eigvecs[:, eigvals.argsort()[::-1]]
    in_axes = offsets @ axes
    lower, upper = in_axes.min(axis=0), in_axes.max(axis=0)
    # The box centre is the middle of the extents, which generally differs
    # from the point mean; map it back to the input frame.
    center = mean_xyz + axes @ ((lower + upper) / 2.0)
    heading = np.arctan2(axes[1, 0], axes[0, 0])
    return {"center_xyz": center, "dimensions": upper - lower, "yaw": float(heading)}


def cluster_class_points(points_m, class_id, max_points=10000):
    """Cluster the points of one class with DBSCAN and return one array per cluster.

    The DBSCAN parameters come from ``DBSCAN_PARAMS[class_id]``. When there
    are more than ``max_points`` points, DBSCAN runs on a random subsample
    (drawn with the global NumPy RNG) and every original point then takes the
    label of its nearest clustered sample, unless that sample is farther than
    ``2 * eps`` away, in which case the point is treated as noise.

    Returns an empty list when there are fewer points than ``min_samples`` or
    when no cluster is found.
    """
    settings = DBSCAN_PARAMS[class_id]
    eps = settings["eps"]
    min_samples = settings["min_samples"]

    total = len(points_m)
    if total < min_samples:
        return []

    subsampled = total > max_points
    if subsampled:
        chosen = np.random.choice(total, max_points, replace=False)
        fit_points = points_m[chosen]
    else:
        fit_points = points_m

    clusterer = DBSCAN(eps=eps, min_samples=min_samples, algorithm="ball_tree")
    labels = clusterer.fit_predict(fit_points)

    if subsampled:
        in_cluster = labels >= 0
        if not in_cluster.any():
            return []
        anchors = fit_points[in_cluster]
        anchor_labels = labels[in_cluster]
        # Propagate the labels of the clustered samples to the full point set.
        _, nearest = BallTree(anchors).query(points_m, k=1)
        nearest = nearest.ravel()
        propagated = anchor_labels[nearest]
        gap = np.linalg.norm(points_m - anchors[nearest], axis=1)
        propagated[gap > eps * 2] = -1
        labels = propagated

    # Label -1 is DBSCAN noise; clusters are returned in increasing label order.
    return [points_m[labels == lbl] for lbl in np.unique(labels[labels != -1])]


def merge_cable_clusters(clusters):
    """Merge cable clusters that look like pieces of the same straight cable.

    Each cluster of at least 4 points gets a main axis (first principal
    component). Going through the clusters in order, a later cluster is merged
    into an earlier one when their main axes differ by at most
    ``CABLE_MERGE_ANGLE_DEG`` and the distance between their centres, minus
    both half-extents along their own axes, is at most ``CABLE_MERGE_GAP_M``.
    Candidates are always compared with the earlier cluster as it was before
    any merge. Clusters without a usable axis are passed through unchanged.

    Returns the input list itself when it holds at most one cluster.
    """
    if len(clusters) <= 1:
        return clusters

    max_angle = np.radians(CABLE_MERGE_ANGLE_DEG)
    max_gap = CABLE_MERGE_GAP_M

    def _describe(pts):
        # Centre plus main axis; the axis stays None when it cannot be computed.
        center = pts.mean(axis=0)
        entry = {"points": pts, "center": center, "axis1": None}
        if len(pts) < 4:
            return entry
        cov = np.cov((pts - center).T)
        if not np.isfinite(cov).all():
            return entry
        try:
            eigvals, eigvecs = np.linalg.eigh(cov)
        except np.linalg.LinAlgError:
            return entry
        main_axis = eigvecs[:, eigvals.argsort()[-1]]
        # Fix the sign so that the x component of the axis is never negative.
        if main_axis[0] < 0:
            main_axis = -main_axis
        entry["axis1"] = main_axis
        return entry

    def _half_extent(entry):
        # Largest absolute projection of the points on the cluster's own axis.
        return np.abs((entry["points"] - entry["center"]) @ entry["axis1"]).max()

    entries = [_describe(pts) for pts in clusters]
    absorbed = [False] * len(entries)
    merged = []
    for i, base in enumerate(entries):
        if absorbed[i]:
            continue
        pieces = [base["points"]]
        if base["axis1"] is not None:
            base_extent = _half_extent(base)
            for j in range(i + 1, len(entries)):
                other = entries[j]
                if absorbed[j] or other["axis1"] is None:
                    continue
                alignment = min(abs(np.dot(base["axis1"], other["axis1"])), 1.0)
                if np.arccos(alignment) > max_angle:
                    continue
                center_gap = np.linalg.norm(base["center"] - other["center"])
                if center_gap - base_extent - _half_extent(other) <= max_gap:
                    pieces.append(other["points"])
                    absorbed[j] = True
        merged.append(np.vstack(pieces) if len(pieces) > 1 else base["points"])
    return merged


def filter_boxes(boxes):
    """Keep the boxes that have enough points and are not too large.

    The per-class limits come from ``MIN_POINTS_PER_BOX`` (default 3) and
    ``MAX_DIM_PER_CLASS`` (default 500.0); the size test uses the largest of
    the box dimensions.
    """
    kept = []
    for box in boxes:
        cid = box["class_id"]
        if not box["num_points"] >= MIN_POINTS_PER_BOX.get(cid, 3):
            continue
        if not max(box["dimensions"]) <= MAX_DIM_PER_CLASS.get(cid, 500.0):
            continue
        kept.append(box)
    return kept


def _box_iou_3d(a, b):
    ca, da, cb, db = a["center_xyz"], a["dimensions"], b["center_xyz"], b["dimensions"]
    ha, hb = da / 2.0, db / 2.0
    overlap = np.maximum(0, np.minimum(ca + ha, cb + hb) - np.maximum(ca - ha, cb - hb))
    inter = overlap[0] * overlap[1] * overlap[2]
    union = da[0]*da[1]*da[2] + db[0]*db[1]*db[2] - inter
    return inter / union if union > 0 else 0.0


def nms_boxes(boxes, iou_threshold=NMS_IOU_THRESHOLD):
    """Greedy per-class non-maximum suppression.

    Within each class the boxes are ranked by decreasing ``num_points``; a box
    is dropped when its IoU with an already kept box of the same class exceeds
    ``iou_threshold``. The input is returned unchanged when it holds at most
    one box.
    """
    if len(boxes) <= 1:
        return boxes

    per_class = {}
    for box in boxes:
        per_class.setdefault(box["class_id"], []).append(box)

    survivors = []
    for group in per_class.values():
        ranked = sorted(group, key=lambda b: b["num_points"], reverse=True)
        kept = []
        for candidate in ranked:
            overlaps_kept = any(
                _box_iou_3d(winner, candidate) > iou_threshold for winner in kept
            )
            if not overlaps_kept:
                kept.append(candidate)
        survivors.extend(kept)
    return survivors


def predictions_to_boxes(xyz_m, predictions, confidences=None):
    """Full pipeline: cluster → PCA bbox → confidence filter → size filter → NMS.

    Args:
        xyz_m: Point coordinates, one row per point.
        predictions: Per-point class ids, aligned with ``xyz_m``.
        confidences: Optional per-point confidences aligned with ``xyz_m``.
            When omitted, the box-confidence filter is skipped; previously
            every box received confidence 0.0 and was then discarded by that
            filter, so the function always returned an empty list.

    Returns:
        A list of box dicts with keys ``center_xyz``, ``dimensions``, ``yaw``,
        ``class_id``, ``class_label``, ``num_points`` and ``confidence`` (mean
        point confidence of the cluster, or 0.0 when no confidences were given).
    """
    have_conf = confidences is not None
    boxes = []
    for cid in range(1, 5):  # internal obstacle classes; 0 is background
        mask = predictions == cid
        if not mask.any():
            continue
        class_points = xyz_m[mask]
        clusters = cluster_class_points(class_points, cid)
        if cid == 2 and len(clusters) > 1:
            clusters = merge_cable_clusters(clusters)
        # Clusters with fewer than 3 points never produce a box.
        clusters = [pts for pts in clusters if len(pts) >= 3]
        if not clusters:
            continue

        # The clustering returns coordinates only, so per-point confidences
        # are recovered through a nearest-neighbour lookup built once per class.
        class_conf = conf_tree = None
        if have_conf:
            class_conf = confidences[mask]
            conf_tree = BallTree(class_points)

        for pts in clusters:
            box_confidence = 0.0
            if conf_tree is not None:
                _, indices = conf_tree.query(pts, k=1)
                box_confidence = float(class_conf[indices.ravel()].mean())
            bbox = pca_oriented_bbox(pts)
            boxes.append({
                "center_xyz": bbox["center_xyz"], "dimensions": bbox["dimensions"],
                "yaw": bbox["yaw"], "class_id": cid, "class_label": CLASS_LABELS_CSV[cid],
                "num_points": len(pts), "confidence": box_confidence,
            })

    # Box confidence filter — only meaningful when confidences were supplied.
    if have_conf and BOX_CONFIDENCE_THRESHOLD > 0:
        boxes = [b for b in boxes if b["confidence"] >= BOX_CONFIDENCE_THRESHOLD]
    boxes = filter_boxes(boxes)
    boxes = nms_boxes(boxes)
    return boxes


def boxes_to_csv_lines(boxes, ego_x, ego_y, ego_z, ego_yaw):
    """Render boxes as CSV rows in the Airbus deliverable format.

    Every row starts with the ego pose, followed by the box centre, the three
    dimensions and the yaw (all with 4 decimals), then the Airbus class id
    (internal id mapped through ``CLASS_ID_TO_AIRBUS``) and the class label.
    Each returned string ends with a newline.
    """
    ego_prefix = f"{ego_x},{ego_y},{ego_z},{ego_yaw}"
    rows = []
    for box in boxes:
        center = box["center_xyz"]
        size = box["dimensions"]
        numbers = (center[0], center[1], center[2],
                   size[0], size[1], size[2],
                   box["yaw"])
        geometry = ",".join(f"{value:.4f}" for value in numbers)
        airbus_cid = CLASS_ID_TO_AIRBUS[box["class_id"]]
        rows.append(f"{ego_prefix},{geometry},{airbus_cid},{box['class_label']}\n")
    return rows

# Progress marker shown in the notebook output once this cell has been executed.
print("Clustering + post-processing v6 defined.")

# RUN INFERENCE

In [None]:
# Collect input files: the single configured scene, or every .h5 file found
# in INPUT_DIR (sorted by name).
if SINGLE_SCENE:
    h5_files = [os.path.join(INPUT_DIR, f"{SINGLE_SCENE}.h5")]
else:
    h5_files = sorted(glob.glob(os.path.join(INPUT_DIR, "*.h5")))

print(f"Files to process: {len(h5_files)}")
for h5_file in h5_files:
    print(f"  {os.path.basename(h5_file)}")

# Fail here with a clear message instead of later inside the HDF5 reader.
# Explicit raises are used because assert statements are skipped under -O,
# and in single-scene mode the list is never empty even if the file is absent.
if not h5_files:
    raise FileNotFoundError(f"No .h5 files found in {INPUT_DIR}")
missing_files = [path for path in h5_files if not os.path.isfile(path)]
if missing_files:
    raise FileNotFoundError(f"Input file(s) not found: {', '.join(missing_files)}")

In [None]:
# === MAIN INFERENCE LOOP ===
# For every input file: split it into frames, segment each frame, turn the
# predictions into boxes and append them to a per-scene CSV.

total_boxes = 0
total_frames = 0
t_total = time.time()
all_scene_stats = []

for h5_path in h5_files:
    # The scene name is the file name without extension; it also names the CSV.
    scene_name = os.path.splitext(os.path.basename(h5_path))[0]
    output_csv = os.path.join(OUTPUT_DIR, f"{scene_name}.csv")

    print(f"\n{'='*60}")
    print(f"[{scene_name}] Processing {h5_path}")

    # t0 is taken before the frame scan, so the per-scene time includes it.
    t0 = time.time()
    frames_info = get_frame_boundaries(h5_path)
    n_frames = len(frames_info)
    print(f"[{scene_name}] {n_frames} frames ({time.time()-t0:.1f}s)")

    # Start the CSV from scratch ("w" truncates) with only the header line.
    with open(output_csv, "w") as f:
        f.write(CSV_HEADER)

    scene_boxes = 0
    class_counts = {1: 0, 2: 0, 3: 0, 4: 0}

    for idx in range(n_frames):
        start, end, ego_x, ego_y, ego_z, ego_yaw = frames_info[idx]

        xyz_m, features = read_frame_for_inference(h5_path, start, end)
        # A frame without any valid point is counted but yields no row.
        if len(xyz_m) == 0:
            total_frames += 1
            continue

        predictions, confidences = predict_frame(model, features, device)
        del features

        boxes = predictions_to_boxes(xyz_m, predictions, confidences)
        del xyz_m, predictions, confidences

        if boxes:
            lines = boxes_to_csv_lines(boxes, ego_x, ego_y, ego_z, ego_yaw)
            # Rows are appended frame by frame, so the results of completed
            # frames are already on disk if the run is interrupted.
            with open(output_csv, "a") as f:
                f.writelines(lines)
            scene_boxes += len(boxes)
            for box in boxes:
                class_counts[box["class_id"]] += 1

        # Free the frame's arrays before the next frame is loaded.
        del boxes; gc.collect()
        total_frames += 1

        # Progress line every 10 frames and on the last frame.
        if (idx + 1) % 10 == 0 or idx == n_frames - 1:
            print(f"  {idx+1}/{n_frames} frames, {scene_boxes} boxes")

    total_boxes += scene_boxes
    elapsed_scene = time.time() - t0
    # One stats row per scene; box counts are keyed by the internal class names.
    stats = {
        "scene": scene_name, "frames": n_frames, "boxes": scene_boxes,
        "time_s": elapsed_scene, **{CLASS_NAMES[c]: class_counts[c] for c in range(1, 5)}
    }
    all_scene_stats.append(stats)
    print(f"[{scene_name}] DONE — {scene_boxes} boxes, {elapsed_scene:.0f}s")
    print(f"  Per class: " + ", ".join(f"{CLASS_NAMES[c]}={class_counts[c]}" for c in range(1, 5)))
    print(f"  Output: {output_csv}")

    del frames_info; gc.collect()

# Global totals over all scenes; max(..., 1) avoids a division by zero.
elapsed = time.time() - t_total
print(f"\n{'='*60}")
print(f"INFERENCE COMPLETE")
print(f"Total: {total_boxes} boxes across {total_frames} frames")
print(f"Time: {elapsed:.0f}s ({elapsed/60:.1f} min)")
print(f"Avg: {elapsed/max(total_frames,1):.2f}s/frame, {total_boxes/max(total_frames,1):.1f} boxes/frame")
print(f"Output dir: {OUTPUT_DIR}")
print(f"{'='*60}")

# Validation : vérifier le format CSV

In [None]:
import pandas as pd

# Validate all output CSVs: exact header, class_ID within 0-3 and class_label
# within the Airbus label set.
print("=== CSV VALIDATION ===\n")
csv_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.csv")))

# Expected schema, written out independently of the code that produces the
# files so that this cell remains a real cross-check.
expected_cols = ["ego_x","ego_y","ego_z","ego_yaw",
                 "bbox_center_x","bbox_center_y","bbox_center_z",
                 "bbox_width","bbox_length","bbox_height",
                 "bbox_yaw","class_ID","class_label"]
allowed_class_ids = {0, 1, 2, 3}  # must be 0-3, NOT 1-4
allowed_labels = {"Antenna", "Cable", "Electric Pole", "Wind Turbine"}

# Finding no CSV at all must not be reported as a success.
all_ok = bool(csv_files)

for csv_path in csv_files:
    name = os.path.basename(csv_path)
    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header line at all.
        all_ok = False
        print(f"  {name}: empty file (no header), [FAIL]")
        continue

    # Check header
    cols_ok = list(df.columns) == expected_cols

    # Collect the distinct values only for columns that exist, so a wrong
    # header is reported as a mismatch instead of raising a KeyError.
    found_ids = sorted(df["class_ID"].unique().tolist()) if "class_ID" in df.columns else []
    found_labels = sorted(df["class_label"].unique().tolist()) if "class_label" in df.columns else []
    cid_ok = set(found_ids).issubset(allowed_class_ids)
    labels_ok = set(found_labels).issubset(allowed_labels)

    ok = cols_ok and cid_ok and labels_ok
    status = "OK" if ok else "FAIL"
    if not ok:
        all_ok = False

    print(f"  {name}: {len(df)} rows, class_IDs={found_ids}, [{status}]")
    if not cols_ok:
        print(f"    COLUMNS MISMATCH: {list(df.columns)}")
    if not cid_ok:
        print(f"    CLASS_ID ERROR: found {found_ids}, expected 0-3")
    if not labels_ok:
        print(f"    LABEL ERROR: found {found_labels}")

if not csv_files:
    print(f"NO CSVs FOUND in {OUTPUT_DIR} — RUN INFERENCE BEFORE VALIDATION")
else:
    print(f"\n{'ALL CSVs VALID' if all_ok else 'SOME CSVs HAVE ERRORS — FIX BEFORE SUBMISSION'}")

# Summary

In [None]:
import pandas as pd

# Recap of the run: model, settings, timing and per-scene statistics.
tta_state = "ON" if USE_TTA else "OFF"
recap_lines = (
    "=== SUMMARY ===\n",
    f"Model: PointNetSegV4, {n_params:,} params",
    f"Device: {device}, TTA: {tta_state}",
    f"Box confidence threshold: {BOX_CONFIDENCE_THRESHOLD}",
    f"Total time: {elapsed:.0f}s ({elapsed/60:.1f} min)\n",
)
for recap_line in recap_lines:
    print(recap_line)

# One table row per processed scene, when at least one scene was processed.
if all_scene_stats:
    df_stats = pd.DataFrame(all_scene_stats)
    print(df_stats.to_string(index=False))

print(f"\nOutput CSVs in: {OUTPUT_DIR}")

# D-day checklist; the entries are user-facing and printed verbatim.
print("\n--- CHECKLIST JOUR J ---")
checklist_items = (
    "Changer INPUT_DIR vers eval_data/",
    "Mettre SINGLE_SCENE = None",
    "Vérifier que le checkpoint est sur Drive",
    "Runtime > Run all",
    "Vérifier les CSVs (cellule validation)",
    "Télécharger les CSVs et soumettre",
)
for checklist_item in checklist_items:
    print("[ ] " + checklist_item)