# KITTI 2D Object Detection Training

This notebook expects you to download the KITTI 2D detection dataset yourself and set `DATASET_ROOT` to the dataset root directory.

Expected structure (paths relative to `DATASET_ROOT`):
- data_object_image_2/training/image_2
- data_object_label_2/training/label_2


In [None]:
import sys
import subprocess

# Install the lightweight third-party packages this notebook imports.
# torch, torchvision and numpy are imported further down but are not installed
# here; presumably the runtime already provides them — TODO confirm.
subprocess.check_call(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "-q",
        "opencv-python",
        "matplotlib",
        "tqdm",  # needed by `from tqdm.auto import tqdm` in the import cell
    ]
)


In [None]:
import os
import time
import random
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import cv2

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.ops import box_iou
from torch.optim.sgd import SGD
from tqdm.auto import tqdm


In [None]:
@dataclass
class Config:
    """Hyper-parameters and runtime settings shared by the notebook.

    The values are read as class attributes (e.g. ``Config.IMG_WIDTH``); an
    instance is only built when the settings are converted with ``asdict``.
    """

    # Size every image is resized to by the dataset; also passed to the
    # detector as min_size (height) and max_size (width).
    IMG_HEIGHT: int = 384
    IMG_WIDTH: int = 1248

    # Batch sizes of the training and validation data loaders.
    TRAIN_BATCH_SIZE: int = 8
    VAL_BATCH_SIZE: int = 8

    # Learning rate of all non-backbone parameters; parameters whose name
    # starts with "backbone." use BASE_LR * BACKBONE_LR_MULT.
    BASE_LR: float = 0.01
    BACKBONE_LR_MULT: float = 0.1

    # Number of training epochs, RNG seed, and the upper bound on the size of
    # the validation split (the split is additionally capped at 10% of the data).
    MAX_EPOCHS: int = 24
    SEED: int = 42
    VAL_SIZE: int = 1000

    # Evaluated once, when the class body is executed.
    DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # mAP evaluation: minimum score kept, max detections kept per image, and
    # the IoU thresholds 0.50, 0.55, ..., 0.95 (rounded to two decimals).
    SCORE_THRESH_EVAL: float = 0.05
    MAX_DETECTIONS_PER_IMAGE: int = 100
    IOU_THRESHOLDS: Tuple[float, ...] = tuple(np.round(np.arange(0.50, 0.96, 0.05), 2).tolist())

    # Arguments of the ReduceLROnPlateau scheduler (factor, patience, min_lr).
    LR_PLATEAU_FACTOR: float = 0.5
    LR_PLATEAU_PATIENCE: int = 2
    LR_MIN: float = 1e-5

    # max_norm used when clipping the gradient norm during training.
    GRAD_CLIP_NORM: float = 2.0


# Number of outputs of the replaced box-predictor head. The dataset maps three
# class names to ids 1-3; presumably the remaining id 0 is the background class
# expected by torchvision detection models — TODO confirm.
NUM_CLASSES = 4
# Directory that receives the checkpoints, the history file and the curves plot.
SAVE_DIR = Path("outputs/train_v2")


In [None]:
# Root of the local KITTI copy; both folders below are resolved against it.
DATASET_ROOT = Path("{YOUR_PROJECT}/data/kitti_detection/")  # TODO: change to your KITTI root
IMAGES_DIR = DATASET_ROOT.joinpath("data_object_image_2", "training", "image_2")
LABELS_DIR = DATASET_ROOT.joinpath("data_object_label_2", "training", "label_2")

# Fail right away when either folder is missing, before anything else runs.
if not (IMAGES_DIR.exists() and LABELS_DIR.exists()):
    raise FileNotFoundError(
        "KITTI paths not found. Set DATASET_ROOT so that "
        "data_object_image_2/training/image_2 and "
        "data_object_label_2/training/label_2 exist."
    )

print(f"--> Images: {IMAGES_DIR}", f"--> Labels: {LABELS_DIR}", sep="\n")


In [None]:
def seed_everything(seed: int = 42) -> None:
    """Seed the PyTorch, NumPy and Python random number generators.

    Args:
        seed: Value handed to every generator. When CUDA is available the
            seed is also applied to all CUDA devices.
    """
    for apply_seed in (torch.manual_seed, np.random.seed, random.seed):
        apply_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def ensure_dir(path: Path) -> None:
    """Create ``path`` as a directory together with any missing parents.

    Nothing happens when the directory already exists.
    """
    os.makedirs(path, exist_ok=True)


def config_to_dict() -> dict:
    """Return the default ``Config`` values as a plain dictionary.

    The ``DEVICE`` entry is replaced by its string form; every other entry is
    returned exactly as ``dataclasses.asdict`` produced it.
    """
    raw_settings = asdict(Config())
    return {
        key: (str(value) if key == "DEVICE" else value)
        for key, value in raw_settings.items()
    }


def plot_history(history: List[dict], save_dir: Path) -> None:
    """Plot loss, mAP and learning-rate curves and save them as a PNG.

    The figure is written to ``save_dir / "training_curves.png"``; the
    directory is created when missing.

    Args:
        history: One dict per epoch. The keys ``epoch``, ``train_loss``,
            ``val_loss``, ``map50``, ``map5095`` and ``lrs`` (one learning
            rate per optimizer parameter group) are read with ``dict.get``.
        save_dir: Directory that receives the image.
    """
    if not history:
        print("--> No history to plot.")
        return

    # Build the figure with the object-oriented API instead of pyplot.
    # Calling ``matplotlib.use("Agg")`` would switch the backend for the whole
    # process and break inline plots in later notebook cells; a bare Figure
    # can still be saved to PNG and needs no explicit close.
    from matplotlib.figure import Figure

    def _column(key: str) -> list:
        return [entry.get(key) for entry in history]

    epochs = _column("epoch")
    train_loss = _column("train_loss")
    val_loss = _column("val_loss")
    map50 = _column("map50")
    map5095 = _column("map5095")

    # Transpose the per-epoch LR lists into one series per parameter group.
    # The number of groups is taken from the first epoch; epochs that lack a
    # value for a group contribute NaN.
    lrs = [entry.get("lrs", []) for entry in history]
    lr_groups: List[List[float]] = []
    if lrs and lrs[0]:
        for gi in range(len(lrs[0])):
            lr_groups.append(
                [float(v[gi]) if v and gi < len(v) else float("nan") for v in lrs]
            )

    fig = Figure(figsize=(15, 4))
    loss_ax, metric_ax, lr_ax = fig.subplots(1, 3)

    loss_ax.plot(epochs, train_loss, label="train_loss")
    loss_ax.plot(epochs, val_loss, label="val_loss")
    loss_ax.set_xlabel("epoch")
    loss_ax.set_ylabel("loss")
    loss_ax.set_title("Loss")
    loss_ax.grid(True, alpha=0.3)
    loss_ax.legend()

    metric_ax.plot(epochs, map50, label="mAP@0.50")
    metric_ax.plot(epochs, map5095, label="mAP@0.50:0.95")
    metric_ax.set_xlabel("epoch")
    metric_ax.set_ylabel("mAP")
    metric_ax.set_title("Metrics")
    metric_ax.set_ylim(0.0, 1.0)
    metric_ax.grid(True, alpha=0.3)
    metric_ax.legend()

    if lr_groups:
        for gi, vals in enumerate(lr_groups):
            lr_ax.plot(epochs, vals, label=f"lr_g{gi}")
        # A log axis is only valid when every finite value is positive.
        lr_values = [v for group in lr_groups for v in group if np.isfinite(v)]
        if lr_values and all(v > 0 for v in lr_values):
            lr_ax.set_yscale("log")
        lr_ax.set_xlabel("epoch")
        lr_ax.set_ylabel("lr")
        lr_ax.set_title("Learning Rate")
        lr_ax.grid(True, alpha=0.3)
        lr_ax.legend()
    else:
        lr_ax.text(0.5, 0.5, "no lr data", ha="center", va="center")
        lr_ax.set_axis_off()

    fig.tight_layout()
    ensure_dir(save_dir)
    out_path = save_dir / "training_curves.png"
    fig.savefig(out_path, dpi=150)
    print(f"--> Curves saved: {out_path}")


def save_training_history(history: List[dict], save_dir: Path, best_path: Path) -> None:
    """Persist the per-epoch history and render the training curves.

    The history is stored as a pickled object array in
    ``save_dir / "history.npy"`` (reading it back therefore requires
    ``allow_pickle=True``), then ``plot_history`` is called.

    Args:
        history: One dict per completed epoch.
        save_dir: Output directory; created when it does not exist yet.
        best_path: Location of the best checkpoint; only printed here.
    """
    # np.save does not create missing directories, so make sure the target
    # exists before writing instead of relying on plot_history to do it later.
    ensure_dir(save_dir)

    hist_path = save_dir / "history.npy"
    np.save(hist_path, np.array(history, dtype=object), allow_pickle=True)
    print(f"--> History saved: {hist_path}")
    plot_history(history, save_dir)
    print(f"--> Best model: {best_path}")


In [None]:
class KittiDataset(Dataset):
    """KITTI 2D detection samples served as ``(image, target)`` pairs.

    Images are resized to ``Config.IMG_WIDTH`` x ``Config.IMG_HEIGHT`` and
    scaled to the ``[0, 1]`` range; boxes are rescaled to match. Only the
    classes listed in ``class_map`` are kept, every other label line is
    ignored.
    """

    def __init__(self, images_dir: Path, labels_dir: Path):
        """Index the image files found in ``images_dir``.

        Raises:
            FileNotFoundError: If either directory does not exist.
        """
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)

        self.class_map = {"Car": 1, "Pedestrian": 2, "Cyclist": 3}
        self.id_to_name = {class_id: name for name, class_id in self.class_map.items()}

        if not self.images_dir.exists():
            raise FileNotFoundError(f"Images directory not found: {self.images_dir}")
        if not self.labels_dir.exists():
            raise FileNotFoundError(f"Labels directory not found: {self.labels_dir}")

        # Keep only image files (case-insensitive extension match), sorted by name.
        image_suffixes = (".png", ".jpg", ".jpeg")
        file_names = [
            entry
            for entry in os.listdir(self.images_dir)
            if entry.lower().endswith(image_suffixes)
        ]
        file_names.sort()
        self.imgs = file_names

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.imgs)

    def _read_annotations(self, label_file: Path):
        """Parse one label file into raw ``[xmin, ymin, xmax, ymax]`` boxes and class ids.

        Lines with fewer than eight whitespace-separated fields, and lines
        whose first field is not in ``class_map``, are skipped. A missing
        file yields two empty lists.
        """
        raw_boxes = []
        raw_labels = []
        if not label_file.exists():
            return raw_boxes, raw_labels

        with open(label_file, "r") as handle:
            for raw_line in handle:
                fields = raw_line.split()
                if len(fields) < 8:
                    continue
                class_id = self.class_map.get(fields[0])
                if class_id is None:
                    continue
                # Fields 4..7 hold the box corners in original-image pixels.
                raw_boxes.append([float(value) for value in fields[4:8]])
                raw_labels.append(class_id)
        return raw_boxes, raw_labels

    def __getitem__(self, sample_idx: int):
        """Load one sample.

        Returns:
            A ``(img_tensor, target)`` tuple: ``img_tensor`` is a float32
            tensor in channel-first layout, ``target`` holds ``boxes``
            (float32, N x 4), ``labels`` (int64, N) and ``image_id``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_name = self.imgs[sample_idx]
        img_path = self.images_dir / img_name

        bgr = cv2.imread(str(img_path))
        if bgr is None:
            raise FileNotFoundError(f"Failed to read image: {img_path}")

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        orig_h, orig_w, _ = rgb.shape

        resized = cv2.resize(
            rgb, (Config.IMG_WIDTH, Config.IMG_HEIGHT), interpolation=cv2.INTER_LINEAR
        )
        resized /= 255.0

        raw_boxes, raw_labels = self._read_annotations(
            self.labels_dir / (Path(img_name).stem + ".txt")
        )

        if raw_boxes:
            boxes = torch.tensor(raw_boxes, dtype=torch.float32)
            labels = torch.tensor(raw_labels, dtype=torch.int64)

            # Map the boxes from original-image pixels to the resized image.
            boxes[:, 0::2] *= Config.IMG_WIDTH / float(orig_w)
            boxes[:, 1::2] *= Config.IMG_HEIGHT / float(orig_h)

            boxes[:, 0::2].clamp_(0, Config.IMG_WIDTH - 1)
            boxes[:, 1::2].clamp_(0, Config.IMG_HEIGHT - 1)

            # Drop boxes that have no positive width or height after clamping.
            has_width = boxes[:, 2] > boxes[:, 0]
            has_height = boxes[:, 3] > boxes[:, 1]
            valid = has_width & has_height
            boxes = boxes[valid]
            labels = labels[valid]
        else:
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)

        img_tensor = torch.from_numpy(resized).permute(2, 0, 1).contiguous()
        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor(sample_idx, dtype=torch.int64),
        }
        return img_tensor, target


def collate_fn(batch):
    """Regroup a list of samples into one tuple per sample field.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``;
    nothing is stacked, so samples may differ in size.
    """
    per_field = zip(*batch)
    return (*per_field,)


def move_batch_to_device(images, targets, device):
    """Move every image and every target value to ``device``.

    Args:
        images: Iterable of tensors.
        targets: Iterable of dicts whose values all provide ``.to(device)``.
        device: Destination passed to ``.to``.

    Returns:
        ``(images, targets)`` as two new lists; target dicts are rebuilt with
        their keys in the original order.
    """

    def _send(tensor):
        return tensor.to(device)

    moved_images = list(map(_send, images))
    moved_targets = [
        {key: _send(value) for key, value in target.items()} for target in targets
    ]
    return moved_images, moved_targets


def build_dataloader(
    dataset,
    batch_size: int,
    shuffle: bool,
    num_workers: int = 0,
    pin_memory: bool = False,
) -> DataLoader:
    """Create a ``DataLoader`` that batches detection samples with ``collate_fn``.

    Args:
        dataset: Dataset yielding ``(image, target)`` samples.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Number of loader worker processes. The default ``0``
            keeps loading in the main process, which is what this function
            always did before the parameter existed.
        pin_memory: Forwarded to ``DataLoader``; the default ``False`` matches
            the previous behaviour.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


In [None]:
def freeze_batchnorm(module: torch.nn.Module) -> None:
    """Put every ``BatchNorm2d`` layer inside ``module`` into eval mode.

    Other layers, including other normalisation types, keep their current
    mode. Parameters are not touched; only the layers' mode is switched.
    """
    batchnorm_layers = (
        layer for layer in module.modules() if isinstance(layer, torch.nn.BatchNorm2d)
    )
    for layer in batchnorm_layers:
        layer.eval()


def build_model(num_classes: int) -> torch.nn.Module:
    """Build a Faster R-CNN (ResNet-50 FPN v2) detector with a new box head.

    The model is created with torchvision's default weights for this
    architecture and with ``Config.IMG_HEIGHT`` / ``Config.IMG_WIDTH`` as its
    ``min_size`` / ``max_size``. Its box predictor is then replaced by a fresh
    ``FastRCNNPredictor`` that outputs ``num_classes`` classes.

    Args:
        num_classes: Number of outputs of the new predictor.

    Returns:
        The detector with the replaced box predictor.
    """
    from torchvision.models.detection import (
        FasterRCNN_ResNet50_FPN_V2_Weights,
        fasterrcnn_resnet50_fpn_v2,
    )
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    detector = fasterrcnn_resnet50_fpn_v2(
        weights=FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT,
        min_size=Config.IMG_HEIGHT,
        max_size=Config.IMG_WIDTH,
    )

    # Reuse the input width of the existing classification layer for the new head.
    roi_heads = detector.roi_heads
    head_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector


def build_optimizer(model: torch.nn.Module) -> torch.optim.Optimizer:
    """Create an SGD optimizer with a reduced learning rate for the backbone.

    Trainable parameters whose name starts with ``"backbone."`` form the first
    parameter group and use ``Config.BASE_LR * Config.BACKBONE_LR_MULT``; all
    other trainable parameters form the second group and use
    ``Config.BASE_LR``. Parameters with ``requires_grad=False`` are left out.

    Returns:
        SGD with momentum 0.9 and weight decay 1e-4.
    """
    trainable = [
        (name, param) for name, param in model.named_parameters() if param.requires_grad
    ]
    backbone_params = [param for name, param in trainable if name.startswith("backbone.")]
    head_params = [param for name, param in trainable if not name.startswith("backbone.")]

    head_lr = Config.BASE_LR
    backbone_lr = Config.BASE_LR * Config.BACKBONE_LR_MULT

    return SGD(
        [
            {"params": backbone_params, "lr": backbone_lr},
            {"params": head_params, "lr": head_lr},
        ],
        lr=head_lr,
        momentum=0.9,
        weight_decay=1e-4,
    )


def build_lr_scheduler(optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.ReduceLROnPlateau:
    """Create a ``ReduceLROnPlateau`` scheduler driven by a metric to maximise.

    Factor, patience and minimum learning rate come from ``Config``.
    """
    plateau_options = {
        "mode": "max",
        "factor": Config.LR_PLATEAU_FACTOR,
        "patience": Config.LR_PLATEAU_PATIENCE,
        "min_lr": Config.LR_MIN,
    }
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_options)


In [None]:
def safe_torch_save(payload: dict, path: Path) -> None:
    """Write ``payload`` to ``path`` via a temporary file and a final rename.

    The payload is first saved next to the target as ``<name>.tmp`` and only
    moved over ``path`` once ``torch.save`` has finished, so an existing file
    at ``path`` is never replaced by a partially written one. If saving or
    renaming fails (including on ``KeyboardInterrupt``), the temporary file is
    removed and the original error is re-raised.

    Args:
        payload: Object handed to ``torch.save``.
        path: Final destination; missing parent directories are created.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_suffix(path.suffix + ".tmp")
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup: do not leave a partial temp file behind, and do
        # not let a cleanup problem hide the error that brought us here.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


@torch.no_grad()
def evaluate_loss(model, data_loader, device):
    """Return the mean total loss of ``model`` over ``data_loader``.

    The model is switched to train mode with its BatchNorm2d layers put in
    eval mode, presumably because torchvision detection models only return
    their loss dict in training mode — TODO confirm. Gradients are disabled
    by the decorator. If the model was in eval mode on entry it is switched
    back to eval mode before returning.

    Returns:
        Sum of the per-batch total losses divided by the number of batches
        (``0.0`` for an empty loader).
    """
    restore_eval = not model.training
    model.train()
    freeze_batchnorm(model)

    running_total = 0.0
    batches_seen = 0
    for images, targets in data_loader:
        images, targets = move_batch_to_device(images, targets, device)
        batch_loss = sum(model(images, targets).values())
        running_total += float(batch_loss.detach().item())
        batches_seen += 1

    if restore_eval:
        model.eval()

    return running_total / max(1, batches_seen)


def _ap_from_pr_101(recall: np.ndarray, precision: np.ndarray) -> float:
    if recall.size == 0:
        return 0.0

    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))

    for i in range(mpre.size - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])

    recall_levels = np.linspace(0.0, 1.0, 101)
    prec_at_recall = np.zeros_like(recall_levels)

    for i, r in enumerate(recall_levels):
        inds = np.where(mrec >= r)[0]
        prec_at_recall[i] = np.max(mpre[inds]) if inds.size > 0 else 0.0

    return float(np.mean(prec_at_recall))


def _compute_ap_for_class(
    detections: List[Tuple[int, torch.Tensor, float]],
    gts: Dict[int, torch.Tensor],
    iou_thr: float,
) -> Optional[float]:
    npos = int(sum(int(v.shape[0]) for v in gts.values()))
    if npos == 0:
        return None

    dets = sorted(detections, key=lambda x: x[2], reverse=True)
    if len(dets) == 0:
        return 0.0

    matched: Dict[int, torch.Tensor] = {}
    for img_id, gt_boxes in gts.items():
        matched[img_id] = torch.zeros((gt_boxes.shape[0],), dtype=torch.bool)

    tp = np.zeros((len(dets),), dtype=np.float32)
    fp = np.zeros((len(dets),), dtype=np.float32)

    empty = torch.zeros((0, 4), dtype=torch.float32)

    for i, (img_id, det_box, det_score) in enumerate(dets):
        gt_boxes = gts.get(img_id, empty)
        if gt_boxes.numel() == 0:
            fp[i] = 1.0
            continue

        ious = box_iou(det_box.unsqueeze(0), gt_boxes).squeeze(0)
        max_iou, max_j = torch.max(ious, dim=0)
        max_iou_v = float(max_iou.item())
        j = int(max_j.item())

        if max_iou_v >= iou_thr and (not bool(matched[img_id][j].item())):
            tp[i] = 1.0
            matched[img_id][j] = True
        else:
            fp[i] = 1.0

    tp_cum = np.cumsum(tp)
    fp_cum = np.cumsum(fp)

    recall = tp_cum / (npos + 1e-12)
    precision = tp_cum / (tp_cum + fp_cum + 1e-12)

    return _ap_from_pr_101(recall, precision)


@torch.inference_mode()
def evaluate_map(
    model: torch.nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    class_ids: List[int],
    iou_thresholds: Tuple[float, ...],
    score_thresh: float,
    max_dets_per_image: int,
) -> Dict[str, float]:
    """Run the detector over ``data_loader`` and compute mAP metrics.

    The model is put in eval mode and stays in it. Predictions scoring below
    ``score_thresh`` are dropped and at most ``max_dets_per_image`` of the
    highest-scoring ones are kept per image. AP is then computed per class
    and per IoU threshold; classes without any ground-truth box are left out
    of the averages.

    Returns:
        A dict with ``"mAP@0.50"`` (mean over classes of the AP stored under
        threshold ``0.50``, or ``0.0`` if there is none),
        ``"mAP@0.50:0.95"`` (mean over thresholds of the per-threshold class
        means) and one ``"AP@0.50_cls<id>"`` entry per class, taken from the
        first threshold in ``iou_thresholds``.
    """
    model.eval()

    gts_by_class: Dict[int, Dict[int, torch.Tensor]] = {c: {} for c in class_ids}
    dets_by_class: Dict[int, List[Tuple[int, torch.Tensor, float]]] = {c: [] for c in class_ids}
    min_score = float(score_thresh)

    for images, targets in tqdm(data_loader, desc="Eval mAP", leave=False):
        outputs = model([img.to(device) for img in images])

        for prediction, target in zip(outputs, targets):
            image_id = int(target["image_id"].item())
            gt_boxes = target["boxes"].cpu()
            gt_labels = target["labels"].cpu()

            boxes = prediction["boxes"].detach().cpu().float()
            scores = prediction["scores"].detach().cpu().float()
            labels = prediction["labels"].detach().cpu().long()

            confident = scores >= min_score
            boxes, scores, labels = boxes[confident], scores[confident], labels[confident]

            if scores.numel() > max_dets_per_image:
                best = torch.topk(scores, k=max_dets_per_image, largest=True).indices
                boxes, scores, labels = boxes[best], scores[best], labels[best]

            for c in class_ids:
                # Record the ground truth for every class, even when empty, so
                # each evaluated image is known to the per-class AP routine.
                gts_by_class[c][image_id] = gt_boxes[gt_labels == c].float()

                of_class = labels == c
                if not of_class.any():
                    continue
                dets_by_class[c].extend(
                    (image_id, box, float(score.item()))
                    for box, score in zip(boxes[of_class], scores[of_class])
                )

    ap_per_thr: Dict[float, List[float]] = {t: [] for t in iou_thresholds}
    ap50_per_class: Dict[int, float] = {}

    for c in class_ids:
        class_aps = [
            _compute_ap_for_class(dets_by_class[c], gts_by_class[c], float(t))
            for t in iou_thresholds
        ]

        first_ap = class_aps[0]
        if first_ap is not None:
            ap50_per_class[c] = float(first_ap)

        for t, ap in zip(iou_thresholds, class_aps):
            if ap is None:
                continue
            ap_per_thr[t].append(float(ap))

    aps_at_half = ap_per_thr.get(0.50, [])
    map50 = float(np.mean(aps_at_half)) if aps_at_half else 0.0

    threshold_means = [
        float(np.mean(ap_per_thr[t])) for t in iou_thresholds if ap_per_thr[t]
    ]
    map5095 = float(np.mean(threshold_means)) if threshold_means else 0.0

    out_metrics: Dict[str, float] = {
        "mAP@0.50": map50,
        "mAP@0.50:0.95": map5095,
    }
    out_metrics.update({f"AP@0.50_cls{c}": ap for c, ap in ap50_per_class.items()})
    return out_metrics


In [None]:
def main(images_dir: Path, labels_dir: Path) -> None:
    """Train the detector on KITTI and write checkpoints, history and curves.

    Steps: seed the RNGs, split the dataset into train/validation parts,
    build model, optimizer and scheduler, run one loss-only sanity pass, then
    train for ``Config.MAX_EPOCHS`` epochs. After every epoch the validation
    loss and mAP are computed, the scheduler is stepped on mAP@0.50,
    ``last.pth`` is written and ``best_model.pth`` is updated when mAP@0.50
    improved.

    On ``KeyboardInterrupt`` the history collected so far is saved and the
    function returns normally; on any other exception the history is saved
    and the exception is re-raised. In both cases the checkpoints written at
    the end of the last completed epoch are left untouched.

    Args:
        images_dir: Directory with the KITTI training images.
        labels_dir: Directory with the matching label files.

    Raises:
        RuntimeError: If the dataset has fewer than two samples or the sanity
            check loss is not finite.
    """
    seed_everything(Config.SEED)

    print(f"--> torch={torch.__version__} torchvision={torchvision.__version__}")
    print(f"--> Device: {Config.DEVICE}")

    save_dir = SAVE_DIR
    ensure_dir(save_dir)
    print(f"--> Save dir: {save_dir}")

    dataset = KittiDataset(images_dir, labels_dir)
    if len(dataset) < 2:
        raise RuntimeError("Dataset too small.")

    # Validation split: at most VAL_SIZE samples and at most 10% of the data,
    # but never fewer than one sample.
    val_size = min(Config.VAL_SIZE, max(1, len(dataset) // 10))
    train_size = len(dataset) - val_size

    g = torch.Generator().manual_seed(Config.SEED)
    dataset_train, dataset_val = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=g
    )

    train_loader = build_dataloader(dataset_train, Config.TRAIN_BATCH_SIZE, shuffle=True)
    val_loader = build_dataloader(dataset_val, Config.VAL_BATCH_SIZE, shuffle=False)

    print(f"--> Training samples: {len(dataset_train)} | Val samples: {len(dataset_val)}")

    model = build_model(num_classes=NUM_CLASSES)
    model.to(Config.DEVICE)

    optimizer = build_optimizer(model)
    lr_scheduler = build_lr_scheduler(optimizer)

    class_ids = sorted(list(dataset.class_map.values()))
    id_to_name = dataset.id_to_name

    best_map50 = -1.0
    best_path = save_dir / "best_model.pth"
    last_path = save_dir / "last.pth"

    history: List[dict] = []

    # One gradient-free forward pass to catch a broken pipeline (non-finite
    # loss) before spending time on a full epoch.
    print("--> Sanity check: one forward pass...")
    model.train()
    freeze_batchnorm(model)

    images, targets = next(iter(train_loader))
    images, targets = move_batch_to_device(images, targets, Config.DEVICE)

    with torch.no_grad():
        loss_dict = model(images, targets)
        loss = sum(v for v in loss_dict.values())
    if not torch.isfinite(loss):
        raise RuntimeError(f"Sanity check loss is not finite: {loss.item()}")
    print({k: float(v) for k, v in loss_dict.items()})
    print(f"--> Sanity check total loss: {float(loss)}")

    print("--> Starting training...")

    try:
        for epoch in range(Config.MAX_EPOCHS):
            model.train()
            freeze_batchnorm(model)

            epoch_loss = 0.0
            t0 = time.time()

            pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{Config.MAX_EPOCHS}", leave=True)
            for images, targets in pbar:
                images, targets = move_batch_to_device(images, targets, Config.DEVICE)

                optimizer.zero_grad(set_to_none=True)

                loss_dict = model(images, targets)
                losses = sum(v for v in loss_dict.values())

                losses.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=Config.GRAD_CLIP_NORM)
                optimizer.step()

                loss_val = float(losses.detach().item())
                epoch_loss += loss_val
                pbar.set_postfix(loss=f"{loss_val:.4f}")

            train_loss = epoch_loss / max(1, len(train_loader))
            val_loss = evaluate_loss(model, val_loader, Config.DEVICE)

            val_metrics = evaluate_map(
                model=model,
                data_loader=val_loader,
                device=Config.DEVICE,
                class_ids=class_ids,
                iou_thresholds=Config.IOU_THRESHOLDS,
                score_thresh=Config.SCORE_THRESH_EVAL,
                max_dets_per_image=Config.MAX_DETECTIONS_PER_IMAGE,
            )

            map50 = float(val_metrics.get("mAP@0.50", 0.0))
            map5095 = float(val_metrics.get("mAP@0.50:0.95", 0.0))

            # The scheduler was built with mode="max", so it is fed the metric.
            lr_scheduler.step(map50)

            current_lrs = [pg["lr"] for pg in optimizer.param_groups]
            dt = time.time() - t0

            per_class_str_parts = []
            for c in class_ids:
                k = f"AP@0.50_cls{c}"
                if k in val_metrics:
                    per_class_str_parts.append(f"{id_to_name.get(c, str(c))}:{val_metrics[k]:.3f}")
            per_class_str = " ".join(per_class_str_parts) if per_class_str_parts else "n/a"

            print(
                f"Epoch {epoch + 1} | time {dt:.1f}s | "
                f"train_loss {train_loss:.4f} | val_loss {val_loss:.4f} | "
                f"mAP@0.50 {map50:.4f} | mAP@0.50:0.95 {map5095:.4f} | "
                f"AP@0.50 per class {per_class_str} | "
                f"lr {current_lrs}"
            )

            history.append(
                {
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "map50": map50,
                    "map5095": map5095,
                    "lrs": current_lrs,
                }
            )

            is_best = map50 > best_map50
            if is_best:
                best_map50 = map50

            checkpoint_state = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch + 1,
                "map50": map50,
                "map5095": map5095,
                "best_map50": best_map50,
                "config": config_to_dict(),
                "history": history,
            }

            safe_torch_save(checkpoint_state, last_path)

            if is_best:
                safe_torch_save(checkpoint_state, best_path)
                print(f"--> Best checkpoint updated: {best_path} (mAP@0.50={best_map50:.4f})")
    except KeyboardInterrupt:
        # The checkpoint dict is deliberately NOT written again here. The
        # state dicts inside it reference the live parameter and optimizer
        # tensors, so re-saving it mid-epoch would overwrite last.pth with
        # partially updated weights labelled with the previous epoch's
        # metadata. last.pth from the last completed epoch is already on disk.
        print("--> Training interrupted by user.")
        save_training_history(history, save_dir, best_path)
        return
    except Exception as e:
        # Same reasoning as above: keep the on-disk checkpoints as they are.
        print(f"--> Training failed: {e}. Checkpoints from the last completed epoch are kept.")
        save_training_history(history, save_dir, best_path)
        raise

    save_training_history(history, save_dir, best_path)
    print("--> Training complete.")


In [None]:
# Run the whole training pipeline on the KITTI image and label directories.
main(IMAGES_DIR, LABELS_DIR)
