# 04 — PointNet-lite Training (GPU)

**Story 2.1-2.3** — Train a per-point segmentation model on a T4 GPU.

- Model: PointNet-lite (~117k params)
- Input: 32k points × 4 features (x, y, z, reflectivity)
- Output: 5-class per-point segmentation
- Validation scene: scene_8 (held out from training)

In [None]:
# Mount Google Drive at /content/drive; the data and checkpoint directories
# used by this notebook are configured as paths under that mount point.
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): presumably required by the torch.onnx.export call in the
# export cell — confirm both packages are still needed for the installed torch.
!pip install -q onnx onnxscript

In [None]:
import os, sys, gc, time, json
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

# Project root on the mounted Google Drive.
DRIVE_BASE = "/content/drive/MyDrive/airbus_hackathon"
# Directory the scene .h5 files are read from.
DATA_DIR = f"{DRIVE_BASE}/data"
# Directory that receives checkpoints, the training history, the ONNX export
# and the training-curve image.
CKPT_DIR = f"{DRIVE_BASE}/checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)

# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    # total_memory is reported in bytes; shown here in decimal gigabytes.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Config

In [None]:
# === CONFIG ===
# Scene files looked up under DATA_DIR: scene_1.h5 through scene_10.h5.
SCENE_FILES = [f"scene_{i}.h5" for i in range(1, 11)]
# Scene name (file name without the ".h5" suffix) held out for validation.
VAL_SCENE = "scene_8"
NUM_CLASSES = 5
# Class ID -> display name, used for metric keys and reports.
CLASS_NAMES = {0: "background", 1: "antenna", 2: "cable", 3: "electric_pole", 4: "wind_turbine"}
# Per-class loss weights, indexed by class ID; handed to nn.CrossEntropyLoss.
# NOTE(review): presumably derived from inverse class frequency — confirm.
CLASS_WEIGHTS = [0.1, 14.14, 70.41, 49.94, 6.29]
# Label colour (r, g, b) -> class ID. Colours not listed here end up as
# class 0 in map_rgb_to_class.
CLASS_COLORS = {
    (38, 23, 180): 1, (177, 132, 47): 2,
    (129, 81, 97): 3, (66, 132, 9): 4,
}

# Fixed number of points every frame is resampled to before batching.
NUM_POINTS = 32000
BATCH_SIZE = 8  # GPU can handle more
EPOCHS = 50
# Initial learning rate for the optimizer.
LR = 1e-3

print("Config loaded.")

## 2. Dataset

In [None]:
def get_frame_boundaries(h5_path, dataset_name="lidar_points", chunk_size=2_000_000):
    """Split a point table into frames wherever the ego pose changes.

    The dataset is streamed in chunks of ``chunk_size`` rows so the whole
    table never has to fit in memory. A new frame starts at every row whose
    (ego_x, ego_y, ego_z, ego_yaw) differs from the previous row's, including
    rows that sit on a chunk edge.

    Args:
        h5_path: Path of the HDF5 file to scan.
        dataset_name: Name of the structured dataset holding the points.
        chunk_size: Number of rows read per chunk.

    Returns:
        A list of ``(start, end, ego_x, ego_y, ego_z, ego_yaw)`` tuples, one
        per frame, with ``end`` exclusive and the pose taken from the frame's
        first row and converted to ``int``. Empty when the dataset has no
        rows.
    """
    pose_fields = ("ego_x", "ego_y", "ego_z", "ego_yaw")
    starts, poses = [], []
    with h5py.File(h5_path, "r") as f:
        ds = f[dataset_name]
        n = ds.shape[0]
        prev_last = None  # pose of the last row of the previous chunk
        for offset in range(0, n, chunk_size):
            chunk = ds[offset:min(offset + chunk_size, n)]
            cols = [chunk[name] for name in pose_fields]

            is_start = np.zeros(len(chunk), dtype=bool)
            # The first row of a chunk starts a frame if it is the very first
            # row of the dataset or its pose differs from the row before it.
            # Raw values are compared, exactly as for rows inside a chunk.
            first = tuple(col[0] for col in cols)
            is_start[0] = prev_last is None or first != prev_last
            # Elementwise comparison instead of np.diff: no subtraction, so
            # no wrap-around concerns for unsigned fields.
            for col in cols:
                is_start[1:] |= col[1:] != col[:-1]

            # Record the pose while the chunk is still in memory; this avoids
            # a second pass with one tiny HDF5 read per frame.
            for local in np.flatnonzero(is_start):
                local = int(local)
                starts.append(offset + local)
                poses.append(tuple(int(col[local]) for col in cols))

            prev_last = tuple(col[-1] for col in cols)
            del chunk, cols, is_start

    ends = starts[1:] + [n]
    return [(s, e, *pose) for s, e, pose in zip(starts, ends, poses)]


def map_rgb_to_class(r, g, b, class_colors=None):
    """Translate per-point label colours into integer class IDs.

    Args:
        r, g, b: Array-likes of equal shape holding the colour channels.
        class_colors: Mapping ``(r, g, b) -> class_id``. Defaults to the
            module-level ``CLASS_COLORS`` table.

    Returns:
        An ``int64`` array shaped like ``r``. Points whose colour is not in
        the mapping get class 0.
    """
    if class_colors is None:
        class_colors = CLASS_COLORS
    # Coerce to arrays so that plain lists compare elementwise instead of
    # collapsing to a single Python bool.
    r, g, b = np.asarray(r), np.asarray(g), np.asarray(b)
    class_ids = np.zeros(r.shape, dtype=np.int64)
    for (cr, cg, cb), cid in class_colors.items():
        class_ids[(r == cr) & (g == cg) & (b == cb)] = cid
    return class_ids


class LidarSegDataset(Dataset):
    """Per-frame lidar segmentation dataset backed by scene HDF5 files.

    At construction every existing scene file is scanned with
    ``get_frame_boundaries`` and one index entry is stored per frame:
    ``(h5_path, start, end, ego_x, ego_y, ego_z, ego_yaw, scene_name,
    frame_idx)``. Points are only read from disk in ``__getitem__``.

    Args:
        data_dir: Directory containing the scene files.
        scene_files: File names (``<scene_name>.h5``) to index. Files that do
            not exist are skipped with a message.
        num_points: Fixed number of points returned per frame.
        augment: When true, apply random z-rotation, coordinate jitter and
            random point dropping to each sample.
    """

    def __init__(self, data_dir, scene_files, num_points=32000, augment=False):
        self.data_dir = data_dir
        self.num_points = num_points
        self.augment = augment
        self.index = []
        n_scenes = 0
        for sf in scene_files:
            h5_path = os.path.join(data_dir, sf)
            if not os.path.exists(h5_path):
                # Still best-effort, but say so instead of skipping silently.
                print(f"LidarSegDataset: skipping missing scene file {h5_path}")
                continue
            scene_name = sf.replace(".h5", "")
            n_scenes += 1
            frames = get_frame_boundaries(h5_path)
            for idx, (start, end, ex, ey, ez, eyaw) in enumerate(frames):
                self.index.append((h5_path, start, end, ex, ey, ez, eyaw, scene_name, idx))
        # Report the scenes actually indexed, not the number requested.
        print(f"LidarSegDataset: {len(self.index)} frames from {n_scenes} scenes")

    def __len__(self):
        return len(self.index)

    @staticmethod
    def _make_rng():
        """Return a NumPy generator seeded from torch's RNG.

        DataLoader gives every worker its own torch seed, whereas the global
        NumPy RNG state may be duplicated across workers. Deriving the seed
        from torch avoids workers producing identical random draws.
        """
        seed = int(torch.randint(0, 2**63 - 1, (1,)).item())
        return np.random.default_rng(seed)

    def __getitem__(self, i):
        """Load frame ``i`` and return ``(features, labels)``.

        Returns:
            features: ``float32`` tensor of shape ``(num_points, 4)`` holding
                x, y, z and reflectivity scaled by 1/255.
            labels: ``int64`` tensor of shape ``(num_points,)`` with class
                IDs. A frame with no valid point yields all-zero tensors.
        """
        h5_path, start, end = self.index[i][:3]
        # The file is opened per item, so no h5py handle is stored on the
        # dataset object.
        with h5py.File(h5_path, "r") as f:
            chunk = f["lidar_points"][start:end]
        # Rows with a non-positive distance are discarded.
        valid = chunk[chunk["distance_cm"] > 0]
        del chunk

        n_pts = len(valid)
        if n_pts == 0:
            return (torch.zeros(self.num_points, 4, dtype=torch.float32),
                    torch.zeros(self.num_points, dtype=torch.int64))

        # Spherical -> Cartesian in float64, stored as float32.
        # NOTE(review): distance is presumably centimetres and the raw angles
        # hundredths of a degree — confirm against the sensor spec.
        dist_m = valid["distance_cm"].astype(np.float64) / 100.0
        az_rad = np.radians(valid["azimuth_raw"].astype(np.float64) / 100.0)
        el_rad = np.radians(valid["elevation_raw"].astype(np.float64) / 100.0)
        cos_el = np.cos(el_rad)
        x = dist_m * cos_el * np.cos(az_rad)
        y = -dist_m * cos_el * np.sin(az_rad)
        z = dist_m * np.sin(el_rad)
        xyz = np.column_stack((x, y, z)).astype(np.float32)
        refl = (valid["reflectivity"].astype(np.float32) / 255.0).reshape(-1, 1)
        labels = map_rgb_to_class(valid["r"].astype(np.uint8),
                                  valid["g"].astype(np.uint8),
                                  valid["b"].astype(np.uint8))
        del valid, dist_m, az_rad, el_rad, cos_el, x, y, z

        rng = self._make_rng()

        # Fixed-size sampling: subsample without replacement when there are
        # enough points, otherwise pad by sampling with replacement.
        idx = rng.choice(n_pts, self.num_points, replace=n_pts < self.num_points)
        xyz, refl, labels = xyz[idx], refl[idx], labels[idx]

        if self.augment:
            # Random rotation about the z axis.
            theta = rng.uniform(0, 2 * np.pi)
            c, s = np.cos(theta), np.sin(theta)
            rot = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)
            xyz = xyz @ rot.T
            # Gaussian jitter on the coordinates (std 0.02).
            xyz += rng.normal(0, 0.02, xyz.shape).astype(np.float32)
            # Random point drop: discard up to 25% of the points and refill
            # with duplicates of the kept ones so the size stays fixed.
            drop = rng.uniform(0.0, 0.25)
            n_drop = int(self.num_points * drop)
            if n_drop > 0:
                keep = rng.choice(self.num_points, self.num_points - n_drop, replace=False)
                fill = rng.choice(keep, n_drop, replace=True)
                order = np.concatenate([keep, fill])
                rng.shuffle(order)
                xyz, refl, labels = xyz[order], refl[order], labels[order]

        features = np.concatenate([xyz, refl], axis=1)
        return torch.from_numpy(features), torch.from_numpy(labels)

print("Dataset class defined.")

## 3. Model

In [None]:
class SharedMLP(nn.Module):
    """Pointwise (kernel size 1) Conv1d, optional BatchNorm1d, then ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        bn: When true, add batch normalisation and drop the convolution bias.

    Input and output are channel-first: ``(batch, channels, points)``.
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, 1, bias=not bn)
        self.bn = nn.BatchNorm1d(out_ch) if bn else None

    def forward(self, x):
        x = self.conv(x)
        # Explicit None check: do not rely on the truthiness of an nn.Module.
        if self.bn is not None:
            x = self.bn(x)
        return F.relu(x, inplace=True)


class PointNetLiteSeg(nn.Module):
    """Lightweight PointNet-style network for per-point classification.

    Three shared pointwise layers lift each point to 256 channels. A max over
    the point axis gives one global descriptor, which is concatenated back
    onto every point's local feature before the segmentation head.

    Args:
        in_channels: Number of features per input point.
        num_classes: Number of output logits per point.
    """

    def __init__(self, in_channels=4, num_classes=5):
        super().__init__()
        # Per-point encoder.
        self.local1 = SharedMLP(in_channels, 64)
        self.local2 = SharedMLP(64, 128)
        self.local3 = SharedMLP(128, 256)
        # Segmentation head; its input is local (256) + global (256) features.
        self.seg1 = SharedMLP(256 + 256, 128)
        self.seg2 = SharedMLP(128, 64)
        self.seg3 = nn.Conv1d(64, num_classes, 1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Map ``(batch, points, in_channels)`` to ``(batch, points, num_classes)`` logits."""
        _, num_points, _ = x.shape
        # Conv1d expects channel-first input.
        feats = x.transpose(1, 2)
        for encoder in (self.local1, self.local2, self.local3):
            feats = encoder(feats)
        # Global descriptor: channel-wise max over all points, repeated so it
        # can be attached to every point.
        pooled = feats.max(dim=2, keepdim=True).values
        global_feats = pooled.expand(-1, -1, num_points)
        fused = torch.cat([feats, global_feats], dim=1)
        hidden = self.seg1(fused)
        hidden = self.dropout(hidden)
        hidden = self.seg2(hidden)
        logits = self.seg3(hidden)
        return logits.transpose(1, 2)


# Build the network on the selected device and report its trainable size.
model = PointNetLiteSeg(in_channels=4, num_classes=NUM_CLASSES).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"PointNet-lite: {n_params:,} parameters on {device}")

# Quick test
# Run the smoke forward pass in eval mode: in train mode the random input
# would be folded into the BatchNorm running statistics and dropout would be
# active. Training mode is restored afterwards.
model.eval()
with torch.no_grad():
    test = torch.randn(2, 1000, 4, device=device)
    out = model(test)
    print(f"Test forward: {test.shape} → {out.shape}")
model.train()

## 4. Load Data & Create Loaders

In [None]:
%%time

print("Loading dataset (indexing frame boundaries)...")
train_dataset = LidarSegDataset(DATA_DIR, SCENE_FILES, num_points=NUM_POINTS, augment=True)
val_dataset = LidarSegDataset(DATA_DIR, SCENE_FILES, num_points=NUM_POINTS, augment=False)

# Split by scene
train_idx = [i for i, e in enumerate(train_dataset.index) if e[7] != VAL_SCENE]
val_idx = [i for i, e in enumerate(val_dataset.index) if e[7] == VAL_SCENE]

train_subset = Subset(train_dataset, train_idx)
val_subset = Subset(val_dataset, val_idx)

print(f"Train: {len(train_subset)}, Val: {len(val_subset)}")

# Colab has 2 CPUs, use 2 workers
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=True)

## 5. Training Loop

In [None]:
def compute_metrics(preds, labels, num_classes=None, class_names=None):
    """Compute overall accuracy, per-class IoU and mean IoU.

    Args:
        preds: Integer tensor of predicted class IDs.
        labels: Integer tensor of ground-truth class IDs, same shape.
        num_classes: Number of classes to evaluate. Defaults to the
            module-level ``NUM_CLASSES``.
        class_names: Mapping ``class_id -> name`` used to build the metric
            keys. Defaults to the module-level ``CLASS_NAMES``.

    Returns:
        A dict with ``"accuracy"``, one ``"iou_<name>"`` per class and
        ``"mean_iou"``. A class that appears in neither tensor gets an IoU of
        NaN and is left out of the mean. With no such class at all (for
        example empty input) ``mean_iou`` is 0.0, and empty input also gives
        an accuracy of 0.0.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if class_names is None:
        class_names = CLASS_NAMES

    metrics = {}
    n_total = labels.numel()
    # Guard the division so empty input does not raise ZeroDivisionError.
    metrics["accuracy"] = (preds == labels).sum().item() / n_total if n_total else 0.0

    ious = []
    for c in range(num_classes):
        pred_c = preds == c
        label_c = labels == c
        union = (pred_c | label_c).sum().item()
        if union > 0:
            iou = (pred_c & label_c).sum().item() / union
            ious.append(iou)
        else:
            iou = float("nan")
        metrics[f"iou_{class_names[c]}"] = iou

    # Plain Python float keeps the result independent of NumPy scalar types.
    metrics["mean_iou"] = float(np.mean(ious)) if ious else 0.0
    return metrics


def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network mapping ``(batch, points, features)`` to
            ``(batch, points, classes)`` logits.
        loader: Iterable yielding ``(features, labels)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss taking ``(N, classes)`` logits and ``(N,)`` labels.

    Returns:
        ``(mean_loss, point_accuracy)``. The loss is averaged over the samples
        actually processed, so it is still correct when the loader drops its
        last incomplete batch.

    Raises:
        ValueError: If the loader yields no batch.
    """
    model.train()
    # Move batches to wherever the model lives instead of relying on a
    # module-level device variable.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss, total_correct, total_pts, total_samples = 0.0, 0, 0, 0
    for features, labels in loader:
        features, labels = features.to(run_device), labels.to(run_device)
        optimizer.zero_grad()
        logits = model(features)
        # Flatten (batch, points) into one axis for the per-point loss.
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss.backward()
        optimizer.step()

        batch_size = features.size(0)
        # criterion returns a batch mean; weight it by the batch size so the
        # epoch figure is a per-sample mean.
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        total_correct += (logits.argmax(-1) == labels).sum().item()
        total_pts += labels.numel()

    if total_samples == 0:
        raise ValueError("train_one_epoch: the loader produced no batches")
    return total_loss / total_samples, total_correct / total_pts


@torch.no_grad()
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network mapping ``(batch, points, features)`` to
            ``(batch, points, classes)`` logits.
        loader: Iterable yielding ``(features, labels)`` batches.
        criterion: Loss taking ``(N, classes)`` logits and ``(N,)`` labels.

    Returns:
        The dict produced by ``compute_metrics`` over all predictions, plus a
        ``"loss"`` entry averaged over the samples actually processed.

    Raises:
        ValueError: If the loader yields no batch.
    """
    model.eval()
    # Move batches to wherever the model lives instead of relying on a
    # module-level device variable.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss, total_samples = 0.0, 0
    all_preds, all_labels = [], []
    for features, labels in loader:
        features, labels = features.to(run_device), labels.to(run_device)
        logits = model(features)
        loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        batch_size = features.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        # Predictions are collected on the CPU so device memory is not held
        # for the whole validation set.
        all_preds.append(logits.argmax(-1).cpu())
        all_labels.append(labels.cpu())

    if total_samples == 0:
        # torch.cat on an empty list would fail with a far less helpful error.
        raise ValueError("validate: the loader produced no batches")

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    metrics = compute_metrics(all_preds, all_labels)
    metrics["loss"] = total_loss / total_samples
    return metrics

print("Training functions defined.")

In [None]:
# === TRAINING ===
# Class-weighted cross-entropy, AdamW and a cosine learning-rate schedule
# spanning the whole run.
weights = torch.tensor(CLASS_WEIGHTS, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Start below any reachable mIoU so the first epoch always writes
# best_model.pt; the export cell depends on that file existing.
best_miou = float("-inf")
history = []

print(f"Training for {EPOCHS} epochs on {device}...")
print(f"{'Epoch':>5} | {'TrLoss':>7} {'TrAcc':>6} | {'VaLoss':>7} {'VaAcc':>6} {'mIoU':>6} | "
      f"{'Ant':>5} {'Cab':>5} {'Pol':>5} {'Tur':>5} | {'LR':>8} {'Time':>5}")
print("-" * 100)

t_total = time.time()

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()

    # Read the learning rate before the scheduler advances, so the logged
    # value is the one this epoch actually trained with.
    lr = optimizer.param_groups[0]["lr"]

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_m = validate(model, val_loader, criterion)
    scheduler.step()

    elapsed = time.time() - t0

    log = {
        "epoch": epoch, "train_loss": train_loss, "train_acc": train_acc,
        "val_loss": val_m["loss"], "val_acc": val_m["accuracy"],
        "val_miou": val_m["mean_iou"], "lr": lr, "time_s": elapsed,
    }
    # Per-class IoU for every non-background class (class 0 is skipped).
    for c in range(1, NUM_CLASSES):
        log[f"val_iou_{CLASS_NAMES[c]}"] = val_m.get(f"iou_{CLASS_NAMES[c]}", 0)
    history.append(log)

    ant = val_m.get("iou_antenna", 0)
    cab = val_m.get("iou_cable", 0)
    pol = val_m.get("iou_electric_pole", 0)
    tur = val_m.get("iou_wind_turbine", 0)

    print(f"{epoch:5d} | {train_loss:7.4f} {train_acc:6.4f} | "
          f"{val_m['loss']:7.4f} {val_m['accuracy']:6.4f} {val_m['mean_iou']:6.4f} | "
          f"{ant:5.3f} {cab:5.3f} {pol:5.3f} {tur:5.3f} | "
          f"{lr:8.6f} {elapsed:5.0f}s")

    # Save best
    if val_m["mean_iou"] > best_miou:
        best_miou = val_m["mean_iou"]
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "val_miou": best_miou,
            "val_metrics": val_m,
            "n_params": n_params,
        }, os.path.join(CKPT_DIR, "best_model.pt"))
        print(f"  >>> New best mIoU={best_miou:.4f}")

    # Checkpoint every 10 epochs (includes optimizer and scheduler state).
    if epoch % 10 == 0:
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        }, os.path.join(CKPT_DIR, f"checkpoint_epoch{epoch}.pt"))

total_time = time.time() - t_total
print(f"\n{'='*60}")
print(f"TRAINING COMPLETE in {total_time:.0f}s ({total_time/60:.1f} min)")
print(f"Best mIoU: {best_miou:.4f}")
print(f"Model: {n_params:,} params")

## 6. Save History & ONNX Export

In [None]:
import math

# Save training history
# A class absent from the validation data has an IoU of NaN, which json would
# write as the non-standard token `NaN`. Non-finite floats are stored as null
# so the file is valid JSON; allow_nan=False makes any leftover fail loudly.
history_json = [
    {k: (None if isinstance(v, float) and not math.isfinite(v) else v) for k, v in h.items()}
    for h in history
]
with open(os.path.join(CKPT_DIR, "training_history.json"), "w") as f:
    json.dump(history_json, f, indent=2, allow_nan=False)
print("History saved.")

# Load best model and export to ONNX
# NOTE(review): weights_only=False unpickles arbitrary objects. It is used
# here on a checkpoint this notebook wrote itself (which also stores a metrics
# dict); do not point this at an untrusted file.
best_ckpt = torch.load(os.path.join(CKPT_DIR, "best_model.pt"), map_location=device, weights_only=False)
model.load_state_dict(best_ckpt["model_state_dict"])
# Export in eval mode so dropout and BatchNorm behave as at inference time.
model.eval()

dummy_input = torch.randn(1, NUM_POINTS, 4, device=device)
onnx_path = os.path.join(CKPT_DIR, "pointnet_lite.onnx")

# Both the batch axis and the point axis are declared dynamic.
torch.onnx.export(
    model, dummy_input, onnx_path,
    input_names=["points"],
    output_names=["logits"],
    dynamic_axes={"points": {0: "batch", 1: "num_points"},
                  "logits": {0: "batch", 1: "num_points"}},
    opset_version=17,
)

onnx_size = os.path.getsize(onnx_path) / 1e6
print(f"\nONNX exported: {onnx_path}")
print(f"ONNX size: {onnx_size:.2f} MB")
print(f"Best epoch: {best_ckpt['epoch']}, mIoU: {best_ckpt['val_miou']:.4f}")

## 7. Training Curves

In [None]:
import matplotlib.pyplot as plt

# Three panels built from the in-memory history: loss, validation mIoU and
# per-class validation IoU.
epochs_list = [h["epoch"] for h in history]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss
axes[0].plot(epochs_list, [h["train_loss"] for h in history], label="Train")
axes[0].plot(epochs_list, [h["val_loss"] for h in history], label="Val")
axes[0].set_xlabel("Epoch"); axes[0].set_ylabel("Loss")
axes[0].set_title("Loss"); axes[0].legend()

# mIoU
axes[1].plot(epochs_list, [h["val_miou"] for h in history], 'g-', linewidth=2)
axes[1].set_xlabel("Epoch"); axes[1].set_ylabel("mIoU")
axes[1].set_title(f"Val mIoU (best={best_miou:.4f})")

# Per-class IoU for every non-background class; the range follows NUM_CLASSES
# instead of a hard-coded class count.
for c in range(1, NUM_CLASSES):
    name = CLASS_NAMES[c]
    vals = [h.get(f"val_iou_{name}", 0) for h in history]
    axes[2].plot(epochs_list, vals, label=name)
axes[2].set_xlabel("Epoch"); axes[2].set_ylabel("IoU")
axes[2].set_title("Per-class IoU"); axes[2].legend()

plt.tight_layout()
# Save before show() so the written file contains the rendered figure.
plt.savefig(os.path.join(CKPT_DIR, "training_curves.png"), dpi=150)
plt.show()
print("Curves saved.")

## 8. Final Summary

In [None]:
# Recap of the run, assembled from values computed in the earlier cells.
print("=" * 60)
print("STORY 2.1-2.3 SUMMARY")
print("=" * 60)
print(f"Model: PointNet-lite ({n_params:,} params)")
print(f"Training: {EPOCHS} epochs, batch={BATCH_SIZE}, lr={LR}")
# Report the configured validation scene rather than a hard-coded name.
print(f"Train frames: {len(train_subset)}, Val frames: {len(val_subset)} ({VAL_SCENE})")
print(f"Best mIoU: {best_miou:.4f} (epoch {best_ckpt['epoch']})")
print(f"ONNX: {onnx_size:.2f} MB")
print(f"Total training time: {total_time/60:.1f} min")

# Per-class figures come from the metrics stored in the best checkpoint.
print("\nPer-class IoU (best epoch):")
for c in range(NUM_CLASSES):
    iou = best_ckpt['val_metrics'].get(f'iou_{CLASS_NAMES[c]}', 0)
    print(f"  {CLASS_NAMES[c]:15s}: {iou:.4f}")

print("\nNext: Story 2.4 (robustness test) + Story 3 (inference pipeline)")