# 04 — PointNet-lite Training v3 (GPU)

**Story 2.1-2.3** — Per-point LiDAR obstacle segmentation, trained on a Colab T4 GPU.

**v3 — key changes vs v1/v2:**
- **Balanced sampling**: ~50% obstacle / ~50% background per sample (core fix)
- **64k points** per sample (was 32k) — better spatial coverage
- **Batch size 8** — fits in the T4's ~15 GB of VRAM
- **5 input features**: x, y, z, reflectivity, normalized_distance
- Focal Loss, with obstacle-only mIoU as the model-selection metric (from v2)
- LR 3e-4 with warmup + gradient clipping (from v2)

In [None]:
# Mount Google Drive: DRIVE_BASE (next cell) lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

# ONNX tooling, presumably for the torch.onnx.export call in section 6.
!pip install -q onnx onnxscript

In [None]:
import os, sys, gc, time, json
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

# Drive layout: raw scene files under data/, v3 artifacts under checkpoints_v3/.
DRIVE_BASE = "/content/drive/MyDrive/airbus_hackathon"
DATA_DIR = os.path.join(DRIVE_BASE, "data")
CKPT_DIR = os.path.join(DRIVE_BASE, "checkpoints_v3")
os.makedirs(CKPT_DIR, exist_ok=True)

# Use the GPU when one is attached; the rest of the notebook moves tensors to `device`.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print(f"Device: {device}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name()}")
    # Total device memory in GB (1e9 bytes); only defined when CUDA is present.
    vram = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"VRAM: {vram:.1f} GB")

## 1. Config

In [None]:
# === CONFIG v3 ===
SCENE_FILES = [f"scene_{i}.h5" for i in range(1, 11)]
VAL_SCENE = "scene_8"  # held-out scene, given as the file name without ".h5"
NUM_CLASSES = 5
IN_CHANNELS = 5  # x, y, z, reflectivity, norm_distance
CLASS_NAMES = {0: "background", 1: "antenna", 2: "cable", 3: "electric_pole", 4: "wind_turbine"}
OBSTACLE_CLASSES = [1, 2, 3, 4]
# Label colour (r, g, b) in the HDF5 point table -> class id.
# Points with any other colour are labelled 0 (background).
CLASS_COLORS = {
    (38, 23, 180): 1, (177, 132, 47): 2,
    (129, 81, 97): 3, (66, 132, 9): 4,
}

# === RESOURCES ===
NUM_POINTS = 65536       # 64k (was 32k)
BATCH_SIZE = 8           # fits T4 15GB
OBSTACLE_RATIO = 0.5     # target 50% obstacle points per sample

# === TRAINING ===
EPOCHS = 80
LR = 3e-4
WARMUP_EPOCHS = 5
MAX_GRAD_NORM = 1.0

# Focal loss — per-class alpha indexed by class id, adjusted for balanced sampling
FOCAL_ALPHA = [0.10, 0.20, 0.35, 0.20, 0.15]
FOCAL_GAMMA = 2.0

# Fail fast on an inconsistent configuration. A FOCAL_ALPHA shorter than
# NUM_CLASSES, for example, would otherwise only surface as an indexing error
# inside the loss during the first training step.
if not (len(CLASS_NAMES) == len(FOCAL_ALPHA) == NUM_CLASSES):
    raise ValueError(
        f"CLASS_NAMES ({len(CLASS_NAMES)}) and FOCAL_ALPHA ({len(FOCAL_ALPHA)}) "
        f"must both have NUM_CLASSES={NUM_CLASSES} entries"
    )
if set(CLASS_COLORS.values()) != set(OBSTACLE_CLASSES):
    raise ValueError("CLASS_COLORS must map onto exactly the OBSTACLE_CLASSES ids")
if f"{VAL_SCENE}.h5" not in SCENE_FILES:
    raise ValueError(f"VAL_SCENE={VAL_SCENE!r} does not match any entry of SCENE_FILES")
if not 0.0 <= OBSTACLE_RATIO <= 1.0:
    raise ValueError(f"OBSTACLE_RATIO must be in [0, 1], got {OBSTACLE_RATIO}")

print(f"Config v3: {NUM_POINTS} pts, batch={BATCH_SIZE}, {IN_CHANNELS} features")
print(f"Balanced sampling: {OBSTACLE_RATIO*100:.0f}% obstacle target")

## 2. Dataset with Balanced Sampling

In [None]:
def get_frame_boundaries(h5_path, dataset_name="lidar_points", chunk_size=2_000_000):
    """Split a flat point table into frames wherever the ego pose changes.

    A new frame starts at row 0 and at every row whose
    (ego_x, ego_y, ego_z, ego_yaw) differs from the previous row. The table is
    read in chunks of ``chunk_size`` rows to bound memory, and the file is
    opened only once.

    Args:
        h5_path: path to the HDF5 file.
        dataset_name: name of the structured point dataset inside the file.
        chunk_size: number of rows read per chunk.

    Returns:
        List of ``(start, end, ego_x, ego_y, ego_z, ego_yaw)`` tuples, one per
        frame: ``[start, end)`` is the frame's row range and the pose values
        are the frame's first row converted with ``int()``. Empty list when the
        dataset has no rows.
    """
    pose_fields = ("ego_x", "ego_y", "ego_z", "ego_yaw")
    starts = []       # first row index of each frame
    poses = []        # int() pose of each frame's first row
    prev_last = None  # raw pose of the last row of the previous chunk

    with h5py.File(h5_path, "r") as f:
        ds = f[dataset_name]
        n = ds.shape[0]
        for offset in range(0, n, chunk_size):
            chunk = ds[offset:min(offset + chunk_size, n)]
            cols = [np.asarray(chunk[name]) for name in pose_fields]

            # changed[i] is True when row i starts a new frame.
            changed = np.zeros(len(chunk), dtype=bool)
            for col in cols:
                changed[1:] |= col[1:] != col[:-1]
            # Compare raw values across the chunk edge (not int()-truncated),
            # so an edge change is detected exactly like one inside a chunk.
            first = tuple(col[0] for col in cols)
            changed[0] = prev_last is None or first != prev_last

            for i in np.flatnonzero(changed):
                starts.append(offset + int(i))
                poses.append(tuple(int(col[i]) for col in cols))

            prev_last = tuple(col[-1] for col in cols)
            del chunk, cols, changed

    ends = starts[1:] + [n]
    return [(s, e, *pose) for s, e, pose in zip(starts, ends, poses)]


def map_rgb_to_class(r, g, b, color_map=None):
    """Map per-point label colours to integer class ids.

    Args:
        r, g, b: equal-length integer arrays holding the colour channels.
        color_map: ``{(r, g, b): class_id}``; defaults to ``CLASS_COLORS``.

    Returns:
        int64 array of class ids, one per point. Colours not present in
        ``color_map`` map to 0 (background).
    """
    if color_map is None:
        color_map = CLASS_COLORS
    class_ids = np.zeros(len(r), dtype=np.int64)
    for (cr, cg, cb), cid in color_map.items():
        class_ids[(r == cr) & (g == cg) & (b == cb)] = cid
    return class_ids


class LidarSegDatasetV3(Dataset):
    """Per-frame LiDAR segmentation dataset with balanced obstacle/background sampling.

    Each item is one frame (a row range found by ``get_frame_boundaries``)
    resampled to exactly ``num_points`` points.

    Args:
        data_dir: directory containing the scene HDF5 files.
        scene_files: file names such as ``"scene_1.h5"``; missing files are
            skipped with a warning.
        num_points: points per returned sample.
        obstacle_ratio: target fraction of obstacle points (label > 0) per
            sample. ``None`` disables balancing and samples uniformly, keeping
            the frame's natural class distribution (e.g. for unbiased validation).
        augment: apply random rotation / jitter / scale / flip / point drop.

    Items are ``(features, labels)``: float32 ``(num_points, 5)`` with columns
    x, y, z, reflectivity, normalized distance, and int64 ``(num_points,)``.
    """

    def __init__(self, data_dir, scene_files, num_points=65536,
                 obstacle_ratio=0.5, augment=False):
        self.data_dir = data_dir
        self.num_points = num_points
        self.obstacle_ratio = obstacle_ratio
        self.augment = augment
        # One entry per frame:
        # (h5_path, start, end, ego_x, ego_y, ego_z, ego_yaw, scene_name, frame_idx).
        # entry[7] (scene_name) is used by the notebook for the train/val split.
        self.index = []
        for sf in scene_files:
            h5_path = os.path.join(data_dir, sf)
            if not os.path.exists(h5_path):
                print(f"  [warn] scene file not found, skipped: {h5_path}")
                continue
            scene_name = sf.replace(".h5", "")
            frames = get_frame_boundaries(h5_path)
            for idx, (start, end, ex, ey, ez, eyaw) in enumerate(frames):
                self.index.append((h5_path, start, end, ex, ey, ez, eyaw, scene_name, idx))
        target = ("natural class distribution" if obstacle_ratio is None
                  else f"{obstacle_ratio*100:.0f}% obstacle target")
        print(f"LidarSegDatasetV3: {len(self.index)} frames, "
              f"{num_points} pts, {target}")

    def __len__(self):
        return len(self.index)

    def _convert_to_features(self, valid):
        """Convert raw HDF5 rows to model inputs.

        Returns ``(xyz, dist_norm, refl, labels)``: float32 ``(n, 3)`` Cartesian
        coordinates in metres, float32 ``(n, 1)`` distance / 300 m, float32
        ``(n, 1)`` reflectivity / 255, and int64 ``(n,)`` class ids.
        """
        # Distance is stored in centimetres; angles in hundredths of a degree.
        dist_m = valid["distance_cm"].astype(np.float64) / 100.0
        az_rad = np.radians(valid["azimuth_raw"].astype(np.float64) / 100.0)
        el_rad = np.radians(valid["elevation_raw"].astype(np.float64) / 100.0)
        cos_el = np.cos(el_rad)
        x = dist_m * cos_el * np.cos(az_rad)
        # NOTE(review): negated sin term presumably adapts the sensor's azimuth
        # direction to this frame convention — confirm against the data spec.
        y = -dist_m * cos_el * np.sin(az_rad)
        z = dist_m * np.sin(el_rad)
        xyz = np.column_stack((x, y, z)).astype(np.float32)

        # Distance normalised by 300 m: roughly [0, 1] within 300 m, > 1 beyond.
        dist_norm = (dist_m / 300.0).astype(np.float32).reshape(-1, 1)
        refl = (valid["reflectivity"].astype(np.float32) / 255.0).reshape(-1, 1)

        labels = map_rgb_to_class(
            valid["r"].astype(np.uint8),
            valid["g"].astype(np.uint8),
            valid["b"].astype(np.uint8),
        )
        return xyz, dist_norm, refl, labels

    @staticmethod
    def _draw(candidates, k):
        """Pick ``k`` entries of ``candidates``, repeating only if there are fewer than ``k``."""
        return np.random.choice(candidates, k, replace=len(candidates) < k)

    def _balanced_sample(self, xyz, dist_norm, refl, labels):
        """Resample to ``num_points`` rows with ~``obstacle_ratio`` obstacle points.

        Falls back to uniform sampling when balancing is disabled
        (``obstacle_ratio is None``) or when the frame has only one of the two
        groups. All four arrays are indexed with the same row selection.
        """
        obs_mask = labels > 0
        n_obs = int(obs_mask.sum())
        n_bg = len(labels) - n_obs

        if self.obstacle_ratio is None or n_obs == 0 or n_bg == 0:
            idx = self._draw(np.arange(len(labels)), self.num_points)
        else:
            n_obs_target = int(self.num_points * self.obstacle_ratio)
            idx = np.concatenate([
                self._draw(np.flatnonzero(obs_mask), n_obs_target),
                self._draw(np.flatnonzero(~obs_mask), self.num_points - n_obs_target),
            ])
            np.random.shuffle(idx)
        return xyz[idx], dist_norm[idx], refl[idx], labels[idx]

    def _augment(self, xyz, dist_norm, refl, labels):
        """Apply the random training-time augmentations; returns the new arrays."""
        # Random rotation about the z axis.
        theta = np.random.uniform(0, 2 * np.pi)
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)
        xyz = xyz @ rot.T
        # Per-point Gaussian jitter (sigma 0.02 m).
        xyz += np.random.normal(0, 0.02, xyz.shape).astype(np.float32)
        # Global scale. dist_norm is scaled by the same factor so the distance
        # feature stays consistent with the scaled coordinates.
        scale = np.random.uniform(0.9, 1.1)
        xyz *= scale
        dist_norm = dist_norm * np.float32(scale)
        # Random mirror in X and/or Y.
        if np.random.random() > 0.5:
            xyz[:, 0] *= -1
        if np.random.random() > 0.5:
            xyz[:, 1] *= -1
        # Random point drop: replace up to 25% of the points with duplicates of
        # kept ones, simulating lower density while keeping num_points fixed.
        n_drop = int(self.num_points * np.random.uniform(0.0, 0.25))
        if n_drop > 0:
            keep = np.random.choice(self.num_points, self.num_points - n_drop, replace=False)
            fill = np.random.choice(keep, n_drop, replace=True)
            order = np.concatenate([keep, fill])
            np.random.shuffle(order)
            xyz, dist_norm = xyz[order], dist_norm[order]
            refl, labels = refl[order], labels[order]
        return xyz, dist_norm, refl, labels

    def __getitem__(self, i):
        h5_path, start, end = self.index[i][:3]
        with h5py.File(h5_path, "r") as f:
            chunk = f["lidar_points"][start:end]
        # Keep only rows with a positive distance.
        valid = chunk[chunk["distance_cm"] > 0]
        del chunk

        if len(valid) == 0:
            # Nothing usable in this frame: all-zero features, all-background labels.
            return (torch.zeros(self.num_points, IN_CHANNELS, dtype=torch.float32),
                    torch.zeros(self.num_points, dtype=torch.int64))

        xyz, dist_norm, refl, labels = self._convert_to_features(valid)
        del valid

        xyz, dist_norm, refl, labels = self._balanced_sample(xyz, dist_norm, refl, labels)
        if self.augment:
            xyz, dist_norm, refl, labels = self._augment(xyz, dist_norm, refl, labels)

        # Column order matches IN_CHANNELS: x, y, z, reflectivity, norm_distance.
        features = np.concatenate([xyz, refl, dist_norm], axis=1)  # (N, 5)
        return torch.from_numpy(features), torch.from_numpy(labels)

print("Dataset v3 with balanced sampling defined.")

## 3. Model + Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Alpha-weighted multi-class focal loss, averaged over all elements.

    For each element with true class y and softmax probability p of that class:
        loss = alpha[y] * (1 - p) ** gamma * CE
    With gamma = 0 this reduces to alpha-weighted cross-entropy.

    Args:
        alpha: per-class weights, indexed by class id.
        gamma: focusing exponent that down-weights well-classified elements.
    """
    def __init__(self, alpha, gamma=2.0):
        super().__init__()
        # A buffer (not a parameter): follows .to(device) but is never trained.
        self.register_buffer('alpha', torch.tensor(alpha, dtype=torch.float32))
        self.gamma = gamma

    def forward(self, logits, targets):
        per_elem_ce = F.cross_entropy(logits, targets, reduction='none')
        p_true = (-per_elem_ce).exp()  # probability assigned to the true class
        class_weight = self.alpha[targets]
        focus = (1 - p_true).pow(self.gamma)
        return (class_weight * focus * per_elem_ce).mean()


class SharedMLP(nn.Module):
    """Point-wise Conv1d (kernel size 1) -> optional BatchNorm1d -> ReLU.

    The same linear map is applied at every position of a (B, in_ch, N)
    tensor, producing (B, out_ch, N).

    Args:
        in_ch: input channels.
        out_ch: output channels.
        bn: add BatchNorm1d. The conv bias is then disabled, since BatchNorm's
            own shift makes it redundant.
    """
    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, 1, bias=not bn)
        self.bn = nn.BatchNorm1d(out_ch) if bn else None

    def forward(self, x):
        x = self.conv(x)
        # Explicit None check instead of relying on nn.Module truthiness.
        if self.bn is not None:
            x = self.bn(x)
        return F.relu(x, inplace=True)


class PointNetLiteSegV3(nn.Module):
    """Per-point segmentation net: PointNet encoder with multi-scale skip connections.

    Encoder (shared MLPs): in_channels -> 64 -> 128 -> 256 -> 512.
    Decoder input: the 64/128/256-channel stage outputs concatenated with the
    max-pooled 512-channel global feature (960 channels), then
    256 -> dropout(0.3) -> 128 -> num_classes.

    Input (B, N, in_channels); output per-point logits (B, N, num_classes).
    """
    def __init__(self, in_channels=5, num_classes=5):
        super().__init__()
        # Encoder (construction order kept: it fixes parameter init and state_dict keys)
        self.enc1 = SharedMLP(in_channels, 64)
        self.enc2 = SharedMLP(64, 128)
        self.enc3 = SharedMLP(128, 256)
        self.enc4 = SharedMLP(256, 512)

        # Decoder over all encoder stages concatenated
        fused_channels = 64 + 128 + 256 + 512
        self.seg1 = SharedMLP(fused_channels, 256)
        self.seg2 = SharedMLP(256, 128)
        self.dropout = nn.Dropout(0.3)
        self.seg3 = nn.Conv1d(128, num_classes, 1)  # plain conv: raw logits

    def forward(self, x):
        _batch, n_points, _channels = x.shape
        h = x.transpose(1, 2)  # (B, N, C) -> (B, C, N) for Conv1d

        skips = []
        for stage in (self.enc1, self.enc2, self.enc3):
            h = stage(h)
            skips.append(h)  # 64, 128, 256 channels
        deepest = self.enc4(h)  # (B, 512, N)

        # Max over points -> one global descriptor, broadcast back to every point.
        pooled = deepest.max(dim=2, keepdim=True).values.expand(-1, -1, n_points)

        h = torch.cat(skips + [pooled], dim=1)  # (B, 960, N)
        h = self.seg2(self.dropout(self.seg1(h)))
        return self.seg3(h).transpose(1, 2)  # (B, N, num_classes)


model = PointNetLiteSegV3(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"PointNet-lite v3: {n_params:,} parameters on {device}")

# The probes below feed random data. Run them in eval mode: in train mode each
# forward pass would fold the random batch statistics into the BatchNorm
# running_mean / running_var before training has even started.
model.eval()

# Quick shape test
with torch.no_grad():
    test = torch.randn(2, 1000, IN_CHANNELS, device=device)
    out = model(test)
    print(f"Test forward: {test.shape} -> {out.shape}")

# VRAM estimate: inference peak; the 3x training factor is a rough heuristic.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        test_full = torch.randn(BATCH_SIZE, NUM_POINTS, IN_CHANNELS, device=device)
        out_full = model(test_full)
    peak = torch.cuda.max_memory_allocated() / 1e9
    print(f"Inference VRAM estimate: {peak:.2f} GB (training ~3x = {peak*3:.1f} GB)")
    del test_full, out_full
    torch.cuda.empty_cache()

model.train()  # back to training mode

## 4. Load Data & Create Loaders

In [None]:
%%time

print("Loading dataset (indexing frame boundaries)...")
train_dataset = LidarSegDatasetV3(DATA_DIR, SCENE_FILES, num_points=NUM_POINTS,
                                   obstacle_ratio=OBSTACLE_RATIO, augment=True)
val_dataset = LidarSegDatasetV3(DATA_DIR, SCENE_FILES, num_points=NUM_POINTS,
                                 obstacle_ratio=OBSTACLE_RATIO, augment=False)

# Split by scene
train_idx = [i for i, e in enumerate(train_dataset.index) if e[7] != VAL_SCENE]
val_idx = [i for i, e in enumerate(val_dataset.index) if e[7] == VAL_SCENE]

train_subset = Subset(train_dataset, train_idx)
val_subset = Subset(val_dataset, val_idx)

print(f"Train: {len(train_subset)}, Val: {len(val_subset)}")

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, pin_memory=True)

# Quick sanity check: verify balanced sampling
print("\nSanity check — class distribution in first batch:")
features_sample, labels_sample = next(iter(train_loader))
total = labels_sample.numel()
for c in range(NUM_CLASSES):
    count = (labels_sample == c).sum().item()
    print(f"  {CLASS_NAMES[c]:15s}: {count:8d} ({count/total*100:5.1f}%)")
obs_pct = (labels_sample > 0).sum().item() / total * 100
print(f"  {'OBSTACLE TOTAL':15s}: {(labels_sample > 0).sum().item():8d} ({obs_pct:5.1f}%)")
del features_sample, labels_sample

## 5. Training Loop

In [None]:
def compute_metrics(preds, labels, num_classes=None, class_names=None,
                    obstacle_classes=None):
    """Point accuracy and per-class IoU for integer predictions vs. labels.

    Args:
        preds, labels: integer tensors of class ids with the same shape.
        num_classes: number of classes; defaults to ``NUM_CLASSES``.
        class_names: ``{class_id: name}``; defaults to ``CLASS_NAMES``.
        obstacle_classes: ids averaged into ``obstacle_miou``; defaults to
            ``OBSTACLE_CLASSES``.

    Returns:
        dict with ``accuracy`` (NaN for empty input), ``iou_<name>`` for every
        class (NaN when the class is neither predicted nor labelled),
        ``obstacle_miou`` and ``mean_iou`` (means over classes with a defined
        IoU; 0.0 if there are none).
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if class_names is None:
        class_names = CLASS_NAMES
    if obstacle_classes is None:
        obstacle_classes = OBSTACLE_CLASSES
    obstacle_set = set(obstacle_classes)

    preds = preds.reshape(-1)
    labels = labels.reshape(-1)
    n = labels.numel()

    def _class_counts(ids):
        # Per-class histogram in one pass; ids outside [0, num_classes) are not
        # counted for any class.
        ids = ids[ids >= 0]
        return torch.bincount(ids, minlength=num_classes)[:num_classes].tolist()

    pred_counts = _class_counts(preds)
    label_counts = _class_counts(labels)
    hit_counts = _class_counts(labels[preds == labels])

    metrics = {"accuracy": (preds == labels).sum().item() / n if n else float("nan")}

    all_ious, obstacle_ious = [], []
    for c in range(num_classes):
        inter = hit_counts[c]
        union = pred_counts[c] + label_counts[c] - inter  # |P ∪ L| = |P| + |L| - |P ∩ L|
        iou = inter / union if union > 0 else float("nan")
        metrics[f"iou_{class_names[c]}"] = iou
        if union > 0:
            all_ious.append(iou)
            if c in obstacle_set:
                obstacle_ious.append(iou)

    metrics["obstacle_miou"] = float(np.mean(obstacle_ious)) if obstacle_ious else 0.0
    metrics["mean_iou"] = float(np.mean(all_ious)) if all_ious else 0.0
    return metrics


def train_one_epoch(model, loader, optimizer, criterion, max_grad_norm):
    """Run one optimisation pass over ``loader``.

    Inputs are moved to the device of the model's parameters; the class count
    is taken from the last logits dimension.

    Returns:
        ``(mean_loss, point_accuracy)``. The loss is averaged over the samples
        actually seen; with ``drop_last=True`` that is fewer than
        ``len(loader.dataset)``, so the dataset length is not used. Both values
        are NaN if the loader yields no batches.
    """
    model.train()
    dev = next(model.parameters()).device
    total_loss, n_samples = 0.0, 0
    total_correct, total_pts = 0, 0
    for features, labels in loader:
        features = features.to(dev, non_blocking=True)
        labels = labels.to(dev, non_blocking=True)
        optimizer.zero_grad()
        logits = model(features)
        loss = criterion(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        batch = features.size(0)
        total_loss += loss.item() * batch  # undo the per-batch mean
        n_samples += batch
        total_correct += (logits.argmax(-1) == labels).sum().item()
        total_pts += labels.numel()

    if n_samples == 0:
        return float("nan"), float("nan")
    return total_loss / n_samples, total_correct / total_pts


@torch.no_grad()
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Inputs go to the device of the model's parameters; the class count comes
    from the last logits dimension. Predictions and labels of all batches are
    gathered on the CPU and scored with ``compute_metrics``.

    Returns:
        The ``compute_metrics`` dict plus ``"loss"``: the criterion averaged
        over the samples seen.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    dev = next(model.parameters()).device
    total_loss, n_samples = 0.0, 0
    all_preds, all_labels = [], []
    for features, labels in loader:
        features, labels = features.to(dev), labels.to(dev)
        logits = model(features)
        loss = criterion(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
        total_loss += loss.item() * features.size(0)
        n_samples += features.size(0)
        all_preds.append(logits.argmax(-1).cpu())
        all_labels.append(labels.cpu())

    if n_samples == 0:
        raise ValueError("validate() received a loader that yields no batches")

    metrics = compute_metrics(torch.cat(all_preds), torch.cat(all_labels))
    metrics["loss"] = total_loss / n_samples
    return metrics

print("Training functions defined.")

In [None]:
# === TRAINING ===
criterion = FocalLoss(alpha=FOCAL_ALPHA, gamma=FOCAL_GAMMA).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)

# Warmup + cosine schedule. LambdaLR multiplies LR by lr_lambda(k), where k is
# the number of scheduler.step() calls so far (0 during the first epoch):
# linear warmup LR/WARMUP_EPOCHS .. LR, then cosine decay towards 0.
def lr_lambda(epoch):
    if epoch < WARMUP_EPOCHS:
        return (epoch + 1) / WARMUP_EPOCHS
    progress = (epoch - WARMUP_EPOCHS) / max(EPOCHS - WARMUP_EPOCHS, 1)
    return 0.5 * (1 + np.cos(np.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Start at -inf so epoch 1 always writes best_model_v3.pt, which section 6
# loads unconditionally (starting at 0.0 leaves no file if mIoU never exceeds 0).
best_obs_miou = float("-inf")
history = []

print(f"Training for {EPOCHS} epochs on {device}")
print(f"Focal Loss (gamma={FOCAL_GAMMA}), LR={LR}, warmup={WARMUP_EPOCHS}")
print(f"Balanced sampling: {OBSTACLE_RATIO*100:.0f}% obstacles, {NUM_POINTS} pts, batch={BATCH_SIZE}")
print()
print(f"{'Ep':>3} | {'TrLoss':>7} {'TrAcc':>6} | {'VaLoss':>7} {'ObmIoU':>6} | "
      f"{'Ant':>5} {'Cab':>5} {'Pol':>5} {'Tur':>5} | {'BG':>5} | {'LR':>8} {'T':>4}")
print("-" * 95)

t_total = time.time()

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    # LR used for this epoch; read before scheduler.step() advances it.
    lr = optimizer.param_groups[0]["lr"]

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, MAX_GRAD_NORM)
    val_m = validate(model, val_loader, criterion)
    scheduler.step()

    elapsed = time.time() - t0

    log = {
        "epoch": epoch, "train_loss": train_loss, "train_acc": train_acc,
        "val_loss": val_m["loss"], "val_acc": val_m["accuracy"],
        "val_obstacle_miou": val_m["obstacle_miou"],
        "val_miou": val_m["mean_iou"], "lr": lr, "time_s": elapsed,
    }
    for c in range(NUM_CLASSES):
        log[f"val_iou_{CLASS_NAMES[c]}"] = val_m.get(f"iou_{CLASS_NAMES[c]}", 0)
    history.append(log)

    # `or 0` only replaces missing/None entries; NaN (class absent) prints as nan.
    ant = val_m.get("iou_antenna", 0) or 0
    cab = val_m.get("iou_cable", 0) or 0
    pol = val_m.get("iou_electric_pole", 0) or 0
    tur = val_m.get("iou_wind_turbine", 0) or 0
    bg = val_m.get("iou_background", 0) or 0

    marker = ""
    if val_m["obstacle_miou"] > best_obs_miou:
        best_obs_miou = val_m["obstacle_miou"]
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "val_obstacle_miou": best_obs_miou,
            "val_metrics": val_m,
            "n_params": n_params,
        }, os.path.join(CKPT_DIR, "best_model_v3.pt"))
        marker = " ** BEST"

    print(f"{epoch:3d} | {train_loss:7.4f} {train_acc:6.4f} | "
          f"{val_m['loss']:7.4f} {val_m['obstacle_miou']:6.4f} | "
          f"{ant:5.3f} {cab:5.3f} {pol:5.3f} {tur:5.3f} | "
          f"{bg:5.3f} | {lr:8.6f} {elapsed:4.0f}s{marker}")

    # Checkpoint every 20 epochs, with enough state to resume the run.
    if epoch % 20 == 0:
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "best_obs_miou": best_obs_miou,
            "history": history,
        }, os.path.join(CKPT_DIR, f"checkpoint_v3_epoch{epoch}.pt"))

    # VRAM monitoring (first epoch only)
    if epoch == 1 and torch.cuda.is_available():
        vram_used = torch.cuda.max_memory_allocated() / 1e9
        print(f"  [VRAM peak: {vram_used:.1f} GB / {vram:.1f} GB ({vram_used/vram*100:.0f}%)]")

total_time = time.time() - t_total
print(f"\n{'='*60}")
print(f"TRAINING COMPLETE in {total_time:.0f}s ({total_time/60:.1f} min)")
print(f"Best obstacle mIoU: {best_obs_miou:.4f}")
print(f"Model: {n_params:,} params")
print(f"Model: {n_params:,} params")

## 6. Save History & ONNX Export

In [None]:
# Save training history. Non-finite values (per-class IoU is NaN when a class
# is neither labelled nor predicted in validation) are written as null.
# Otherwise json would emit the non-standard tokens NaN/Infinity, which strict
# JSON parsers reject.
def _finite_or_none(value):
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


history_json = [{key: _finite_or_none(val) for key, val in row.items()} for row in history]
with open(os.path.join(CKPT_DIR, "training_history_v3.json"), "w") as f:
    json.dump(history_json, f, indent=2, allow_nan=False)
print("History saved.")

# Load best model and export to ONNX.
# weights_only=False: this checkpoint, written by the training cell, also stores
# a metrics dict with non-tensor Python objects. Only load checkpoints you trust.
best_ckpt = torch.load(os.path.join(CKPT_DIR, "best_model_v3.pt"),
                       map_location=device, weights_only=False)
model.load_state_dict(best_ckpt["model_state_dict"])
model.eval()

dummy_input = torch.randn(1, NUM_POINTS, IN_CHANNELS, device=device)
onnx_path = os.path.join(CKPT_DIR, "pointnet_lite_v3.onnx")

# Batch size and point count are exported as dynamic axes.
torch.onnx.export(
    model, dummy_input, onnx_path,
    input_names=["points"],
    output_names=["logits"],
    dynamic_axes={"points": {0: "batch", 1: "num_points"},
                  "logits": {0: "batch", 1: "num_points"}},
    opset_version=17,
)

onnx_size = os.path.getsize(onnx_path) / 1e6
print(f"\nONNX exported: {onnx_path}")
print(f"ONNX size: {onnx_size:.2f} MB")
print(f"Best epoch: {best_ckpt['epoch']}, obstacle mIoU: {best_ckpt['val_obstacle_miou']:.4f}")
print(f"Best epoch: {best_ckpt['epoch']}, obstacle mIoU: {best_ckpt['val_obstacle_miou']:.4f}")

## 7. Training Curves

In [None]:
import matplotlib.pyplot as plt

epochs_list = [h["epoch"] for h in history]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss
axes[0].plot(epochs_list, [h["train_loss"] for h in history], label="Train")
axes[0].plot(epochs_list, [h["val_loss"] for h in history], label="Val")
axes[0].set_xlabel("Epoch"); axes[0].set_ylabel("Loss")
axes[0].set_title("Focal Loss"); axes[0].legend()

# Obstacle mIoU
axes[1].plot(epochs_list, [h["val_obstacle_miou"] for h in history], 'g-', linewidth=2, label="Obstacle mIoU")
axes[1].plot(epochs_list, [h["val_miou"] for h in history], 'b--', alpha=0.5, label="All mIoU")
axes[1].set_xlabel("Epoch"); axes[1].set_ylabel("mIoU")
axes[1].set_title(f"Val mIoU (best obs={best_obs_miou:.4f})"); axes[1].legend()

# Per-class IoU for the configured obstacle classes. Classes without a preset
# colour fall back to matplotlib's default colour cycle.
colors = {'antenna': 'blue', 'cable': 'orange', 'electric_pole': 'green', 'wind_turbine': 'red'}
for c in OBSTACLE_CLASSES:
    name = CLASS_NAMES[c]
    # `or 0` replaces missing/None entries; NaN values remain and show as gaps.
    vals = [h.get(f"val_iou_{name}", 0) or 0 for h in history]
    axes[2].plot(epochs_list, vals, label=name, color=colors.get(name))
axes[2].set_xlabel("Epoch"); axes[2].set_ylabel("IoU")
axes[2].set_title("Per-class IoU (obstacles)"); axes[2].legend()

plt.tight_layout()
plt.savefig(os.path.join(CKPT_DIR, "training_curves_v3.png"), dpi=150)
plt.show()
print("Curves saved.")

## 8. Final Summary

In [None]:
print("=" * 60)
print("TRAINING SUMMARY (v3 — balanced sampling)")
print("=" * 60)
print(f"Model: PointNet-lite v3 ({n_params:,} params)")
print(f"Input: {IN_CHANNELS} features (x,y,z,refl,dist)")
print(f"Loss: Focal (gamma={FOCAL_GAMMA})")
print(f"Balanced sampling: {OBSTACLE_RATIO*100:.0f}% obstacles")
print(f"Training: {EPOCHS} epochs, batch={BATCH_SIZE}, lr={LR}, warmup={WARMUP_EPOCHS}")
print(f"Points per sample: {NUM_POINTS:,}")
# Report the configured held-out scene rather than a hard-coded name.
print(f"Train frames: {len(train_subset)}, Val frames: {len(val_subset)} ({VAL_SCENE})")
print(f"Best obstacle mIoU: {best_obs_miou:.4f} (epoch {best_ckpt['epoch']})")
print(f"ONNX: {onnx_size:.2f} MB")
print(f"Total training time: {total_time/60:.1f} min")

if torch.cuda.is_available():
    vram_used = torch.cuda.max_memory_allocated() / 1e9
    print(f"VRAM peak: {vram_used:.1f} GB / {vram:.1f} GB")

print("\nPer-class IoU (best epoch):")
for c in range(NUM_CLASSES):
    iou = best_ckpt["val_metrics"].get(f"iou_{CLASS_NAMES[c]}", 0)
    # NaN means the class was neither labelled nor predicted in validation.
    if np.isnan(iou):
        iou = 0.0
    marker = " <-- obstacle" if c in OBSTACLE_CLASSES else ""
    print(f"  {CLASS_NAMES[c]:15s}: {iou:.4f}{marker}")

print("\nv1 -> v2 -> v3 comparison:")
print("  v1: obstacle mIoU ~ 0.05 (weighted CE, no balancing)")
print("  v2: obstacle mIoU ~ 0.03 (focal loss, still unbalanced)")
print(f"  v3: obstacle mIoU = {best_obs_miou:.4f} (focal + balanced sampling)")