# CARLA Object-Graph GNN (Node Classification)

This notebook trains a **graph neural network (GNN)** on the CARLA object-detection
dataset, treating each image as a graph:

- **Nodes** = annotated objects (from VOC-style XML).
  - Node features = bbox geometry + CNN embedding of the cropped object.
- **Edges** = *k*-NN edges in XY space between object centers.
- **Task** = node classification (predict each object's class label).

The pipeline:

1. Build `class_to_idx` from all labels.
2. For each image, build a graph:
   - Extract object crops.
   - Embed with ResNet-18 (ImageNet-pretrained).
   - Concatenate geometry features.
   - Build k-NN edges.
3. Train a GATv2-based node classifier with
   - class-balanced weights,
   - validation-based early stopping.

---


In [None]:
import os
import glob
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import List, Dict, Tuple

import random
from collections import Counter, defaultdict

import numpy as np
from PIL import Image, ImageDraw

import torch
import torch.nn as nn
from torch.utils.data import Dataset, Subset

import torchvision
import torchvision.transforms as T

from torch_geometric.data import Data as GeometricData
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch_geometric.nn import GATv2Conv

from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import classification_report

from tqdm import tqdm
from IPython.display import display

SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_ROOT = 'carla-object-detection-dataset'
IMG_TRAIN = os.path.join(DATA_ROOT, 'images', 'train')
LBL_TRAIN = os.path.join(DATA_ROOT, 'labels', 'train')
IMG_TEST  = os.path.join(DATA_ROOT, 'images', 'test')
LBL_TEST  = os.path.join(DATA_ROOT, 'labels', 'test')

KNN_K      = 8
EMBED_DIM  = 256
HIDDEN_DIM = 256
GNN_LAYERS = 2
BATCH_SIZE = 8
DROPOUT    = 0.5
EPOCHS     = 40
LR         = 3e-4
WEIGHT_DECAY = 5e-4
PATIENCE   = 8
VAL_FRACTION = 0.2

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE


  from .autonotebook import tqdm as notebook_tqdm


'cpu'

## 1) VOC Parsing & Geometry Features


In [2]:
def parse_voc_xml(xml_path: str) -> Tuple[Tuple[int, int], List[Tuple[str, List[int]]]]:
    """
    Parse VOC-style XML and return:
    - (width, height)
    - list of (class_name, [xmin, ymin, xmax, ymax])
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    objects = []
    for obj in root.findall('object'):
        name = obj.find('name').text
        b = obj.find('bndbox')
        bbox = [
            int(b.find('xmin').text),
            int(b.find('ymin').text),
            int(b.find('xmax').text),
            int(b.find('ymax').text),
        ]
        objects.append((name, bbox))

    return (width, height), objects


def bbox_to_features(bbox, w: int, h: int) -> np.ndarray:
    """
    Convert a bounding box into normalized geometry features:
      [cx_norm, cy_norm, bw_norm, bh_norm, area_norm, aspect_ratio]
    """
    xmin, ymin, xmax, ymax = bbox
    bw = max(1, xmax - xmin)
    bh = max(1, ymax - ymin)
    cx = xmin + bw / 2.0
    cy = ymin + bh / 2.0

    cx_n = cx / w
    cy_n = cy / h
    bw_n = bw / w
    bh_n = bh / h
    area_n = (bw * bh) / float(w * h)
    aspect = bw / float(bh)

    return np.array([cx_n, cy_n, bw_n, bh_n, area_n, aspect], dtype=np.float32)


## 2) CNN Backbone for Visual Embeddings

We use a pretrained ResNet-18 as a crop encoder and project to `EMBED_DIM`.

We will:
- Use **stronger transforms** (jitter, blur) for **training crops**.
- Use **deterministic** transforms (resize + normalize) for **val/test** crops.


In [None]:
resnet = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
)
backbone = nn.Sequential(*list(resnet.children())[:-1])
backbone.eval().to(DEVICE)

proj = nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, EMBED_DIM),
)
proj.eval().to(DEVICE)


train_tx = T.Compose([
    T.Resize((224, 224)),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    T.RandomApply([T.GaussianBlur(3)], p=0.2),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

eval_tx = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])


def crop_and_embed(
    img: Image.Image,
    bbox,
    backbone: nn.Module,
    transform,
    proj: nn.Module,
    pad_scale: float = 1.5,
) -> np.ndarray:
    """
    Crop padded region around bbox, embed with CNN backbone and projection.
    """
    W, H = img.size
    xmin, ymin, xmax, ymax = bbox
    cx = (xmin + xmax) / 2.0
    cy = (ymin + ymax) / 2.0
    w = (xmax - xmin) * pad_scale
    h = (ymax - ymin) * pad_scale

    x1 = int(max(0, cx - w / 2.0))
    y1 = int(max(0, cy - h / 2.0))
    x2 = int(min(W, cx + w / 2.0))
    y2 = int(min(H, cy + h / 2.0))

    crop = img.crop((x1, y1, x2, y2))
    x = transform(crop).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        feat = backbone(x)
        emb  = proj(feat).cpu()

    return emb.squeeze(0).numpy()


## 3) Dataset & Graph Construction


In [None]:
@dataclass
class Sample:
    image_path: str
    xml_path: str


class CarlaGraphDataset(Dataset):
    """
    Dataset of graphs, one graph per image.

    For each XML:
      - Parse objects and bboxes.
      - Extract geometry + visual features per object.
      - Build k-NN graph in (cx, cy) space.
    """

    def __init__(
        self,
        img_dir: str,
        xml_dir: str,
        class_to_idx: Dict[str, int],
        backbone: nn.Module,
        proj: nn.Module,
        transform,
        device: str = DEVICE,
        knn_k: int = 4,
        embed_dim: int = 256,
        preload: bool = False,
        name: str = "split",
    ):
        super().__init__()
        self.img_dir = img_dir
        self.xml_dir = xml_dir
        self.class_to_idx = class_to_idx
        self.device = device
        self.knn_k = knn_k
        self.embed_dim = embed_dim
        self.preload = preload
        self.name = name

        self.backbone = backbone
        self.proj = proj
        self.transform = transform

        xmls = sorted(glob.glob(os.path.join(xml_dir, "*.xml")))
        self.samples: List[Sample] = []
        for xp in xmls:
            base = os.path.splitext(os.path.basename(xp))[0]
            ip_png = os.path.join(img_dir, f"{base}.png")
            ip_jpg = os.path.join(img_dir, f"{base}.jpg")

            if os.path.exists(ip_png):
                ip = ip_png
            elif os.path.exists(ip_jpg):
                ip = ip_jpg
            else:
                continue

            self.samples.append(Sample(ip, xp))

        self._cache = []
        if self.preload:
            print(f"[{self.name}] Preloading {len(self.samples)} graphs...")
            for i in tqdm(range(len(self.samples)), desc=f"Preloading-{self.name}"):
                self._cache.append(self._make_graph(i))

    def __len__(self) -> int:
        return len(self.samples)

    def _make_graph(self, idx: int) -> GeometricData:
        s = self.samples[idx]
        (W, H), objects = parse_voc_xml(s.xml_path)
        img = Image.open(s.image_path).convert("RGB")

        n = len(objects)
        if n == 0:
            x = torch.zeros((1, self.embed_dim + 6), dtype=torch.float)
            y = torch.tensor([-1], dtype=torch.long)
            edge_index = torch.empty((2, 0), dtype=torch.long)
            return GeometricData(x=x, y=y, edge_index=edge_index)

        feats = []
        labels = []
        centers = []

        for name, bbox in objects:
            geom = bbox_to_features(bbox, W, H)
            emb  = crop_and_embed(img, bbox, self.backbone, self.transform, self.proj)
            f = np.concatenate([geom, emb], axis=0)
            feats.append(f)
            labels.append(self.class_to_idx.get(name, -1))

            xmin, ymin, xmax, ymax = bbox
            cx = (xmin + xmax) / 2.0
            cy = (ymin + ymax) / 2.0
            centers.append([cx, cy])

        x = torch.tensor(np.stack(feats, axis=0), dtype=torch.float)
        y = torch.tensor(labels, dtype=torch.long)
        centers = np.asarray(centers, dtype=np.float32)

        x_geom = x[:, :6]
        x_vis  = x[:, 6:]
        if x.shape[0] > 1:
            g_mean = x_geom.mean(0)
            g_std  = x_geom.std(0, unbiased=False).clamp_min(1e-6)
            v_mean = x_vis.mean(0)
            v_std  = x_vis.std(0, unbiased=False).clamp_min(1e-6)

            x_geom = (x_geom - g_mean) / g_std
            x_vis  = (x_vis  - v_mean) / v_std

        x = torch.cat([x_geom, x_vis], dim=1)
        x = torch.nan_to_num(x, nan=0.0, posinf=1e3, neginf=-1e3)

        n_nodes = x.shape[0]
        knn_k = min(self.knn_k + 1, n_nodes)
        nbrs = NearestNeighbors(n_neighbors=knn_k, algorithm="auto").fit(centers)
        _, idxs = nbrs.kneighbors(centers)

        edges = set()
        for i in range(n_nodes):
            for j in idxs[i][1:]:
                j = int(j)
                edges.add((i, j))
                edges.add((j, i))

        if not edges:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        else:
            edge_index = torch.tensor(list(edges), dtype=torch.long).t().contiguous()

        return GeometricData(x=x, y=y, edge_index=edge_index)

    def __getitem__(self, idx: int) -> GeometricData:
        if self.preload and len(self._cache) == len(self.samples):
            return self._cache[idx]
        return self._make_graph(idx)


## 4) Class Mapping


In [5]:
def collect_classes(xml_dirs: List[str]) -> Dict[str, int]:
    names = set()
    for d in xml_dirs:
        for xp in glob.glob(os.path.join(d, "*.xml")):
            try:
                _, objs = parse_voc_xml(xp)
            except Exception:
                continue
            for name, _ in objs:
                names.add(name)
    names = sorted(names)
    return {name: i for i, name in enumerate(names)}


class_to_idx = collect_classes([LBL_TRAIN, LBL_TEST])
idx_to_class = {v: k for k, v in class_to_idx.items()}

print("Number of classes:", len(class_to_idx))
print("class_to_idx:", class_to_idx)


Number of classes: 5
class_to_idx: {'bike': 0, 'motobike': 1, 'traffic_light': 2, 'traffic_sign': 3, 'vehicle': 4}


## 5) Dataset Instances & Split (Train / Val / Test)


In [None]:
full_train_ds = CarlaGraphDataset(
    IMG_TRAIN, LBL_TRAIN,
    class_to_idx=class_to_idx,
    backbone=backbone,
    proj=proj,
    transform=train_tx,
    device=DEVICE,
    knn_k=KNN_K,
    embed_dim=EMBED_DIM,
    preload=False,
    name="train+val",
)

test_ds = CarlaGraphDataset(
    IMG_TEST, LBL_TEST,
    class_to_idx=class_to_idx,
    backbone=backbone,
    proj=proj,
    transform=eval_tx,
    device=DEVICE,
    knn_k=KNN_K,
    embed_dim=EMBED_DIM,
    preload=False,
    name="test",
)

print("Total train+val images:", len(full_train_ds))
print("Total test images     :", len(test_ds))

indices = list(range(len(full_train_ds)))
random.shuffle(indices)

val_size = int(len(indices) * VAL_FRACTION)
val_indices = indices[:val_size]
train_indices = indices[val_size:]

train_ds = Subset(full_train_ds, train_indices)
val_ds   = Subset(full_train_ds, val_indices)

print(f"Train images: {len(train_ds)}, Val images: {len(val_ds)}")

train_loader = GeoDataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = GeoDataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = GeoDataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)


Total train+val images: 779
Total test images     : 249
Train images: 624, Val images: 155


### 5.1) Class Distribution (for train subset)


In [None]:
train_label_counts = Counter()
for g in tqdm(train_ds, desc="Counting train labels"):
    y = g.y.numpy()
    for c in y:
        if c >= 0:
            train_label_counts[int(c)] += 1

print("Train label counts (node-level):")
for idx, cnt in sorted(train_label_counts.items()):
    print(f"  {idx_to_class[idx]:15s}: {cnt}")

num_classes = len(class_to_idx)


Counting train labels: 100%|██████████| 624/624 [00:33<00:00, 18.51it/s]

Train label counts (node-level):
  bike           : 121
  motobike       : 63
  traffic_light  : 802
  traffic_sign   : 77
  vehicle        : 932





## 6) GNN Model

We use a simple stack of GATv2Conv layers with ReLU + dropout,
followed by a linear classification head.


In [8]:
class NodeGNN(nn.Module):
    def __init__(
        self,
        in_dim: int,
        hidden: int,
        out_dim: int,
        layers: int = 2,
        dropout: float = 0.5,
        heads: int = 2,
    ):
        super().__init__()
        self.dropout = dropout

        convs = []
        dims = [in_dim] + [hidden] * layers
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            convs.append(
                GATv2Conv(
                    in_channels=d_in,
                    out_channels=d_out // heads,
                    heads=heads,
                    dropout=dropout,
                )
            )
        self.convs = nn.ModuleList(convs)
        self.head = nn.Linear(hidden, out_dim)

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = conv(x, edge_index)
            x = torch.relu(x)
            x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        logits = self.head(x)
        return logits


in_dim = EMBED_DIM + 6
model = NodeGNN(
    in_dim=in_dim,
    hidden=HIDDEN_DIM,
    out_dim=num_classes,
    layers=GNN_LAYERS,
    dropout=DROPOUT,
    heads=2,
).to(DEVICE)

model


NodeGNN(
  (convs): ModuleList(
    (0): GATv2Conv(262, 128, heads=2)
    (1): GATv2Conv(256, 128, heads=2)
  )
  (head): Linear(in_features=256, out_features=5, bias=True)
)

## 7) Loss Function, Class Weights, Optimizer & Scheduler


In [None]:
class_counts = torch.zeros(num_classes, dtype=torch.float)

for g in train_ds:
    y = g.y
    valid = y[y >= 0]
    for c in valid.tolist():
        class_counts[c] += 1

class_weights = 1.0 / (class_counts + 1e-6)
class_weights = class_weights / class_weights.mean() 
class_weights = class_weights.to(DEVICE)

print("Class counts:", class_counts.tolist())
print("Class weights:", class_weights.tolist())

criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)


Class counts: [121.0, 63.0, 802.0, 77.0, 932.0]
Class weights: [1.047610878944397, 2.012078285217285, 0.15805602073669434, 1.6462457180023193, 0.1360095739364624]




## 8) Training & Evaluation Loops


In [10]:
def step_batch(batch, train: bool = True):
    """
    One forward/backward pass on a batch.
    Returns: (loss, accuracy).
    """
    if train:
        model.train()
    else:
        model.eval()

    x = batch.x.to(DEVICE)
    ei = batch.edge_index.to(DEVICE)
    y = batch.y.to(DEVICE)

    valid_mask = y >= 0
    if valid_mask.sum() == 0:
        return 0.0, 0.0

    logits = model(x, ei)

    loss = criterion(logits[valid_mask], y[valid_mask])

    if train:
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

    with torch.no_grad():
        preds = logits.argmax(dim=-1)
        acc = (preds[valid_mask] == y[valid_mask]).float().mean().item()

    return loss.item(), acc


def evaluate(loader):
    total_loss = 0.0
    total_acc = 0.0
    n_batches = 0

    model.eval()
    for batch in loader:
        l, a = step_batch(batch, train=False)
        total_loss += l
        total_acc += a
        n_batches += 1

    if n_batches == 0:
        return 0.0, 0.0
    return total_loss / n_batches, total_acc / n_batches


## 9) Main Training Loop (with Validation Early Stopping)


In [None]:
best_val_acc = -1.0
best_state = None
best_epoch = 0
wait = 0  

history = {
    "train_loss": [],
    "train_acc": [],
    "val_loss": [],
    "val_acc": [],
}

for epoch in range(1, EPOCHS + 1):
    train_loss_sum, train_acc_sum, train_batches = 0.0, 0.0, 0
    for batch in train_loader:
        l, a = step_batch(batch, train=True)
        train_loss_sum += l
        train_acc_sum += a
        train_batches += 1

    train_loss = train_loss_sum / max(1, train_batches)
    train_acc = train_acc_sum / max(1, train_batches)

    val_loss, val_acc = evaluate(val_loader)

    scheduler.step(val_loss)

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    print(
        f"Epoch {epoch:02d}: "
        f"train loss {train_loss:.4f}, train acc {train_acc:.3f} | "
        f"val loss {val_loss:.4f}, val acc {val_acc:.3f}"
    )

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        wait = 0
        best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    else:
        wait += 1
        if wait >= PATIENCE:
            print(
                f"Early stopping at epoch {epoch}. "
                f"Best val acc: {best_val_acc:.3f} at epoch {best_epoch}"
            )
            break

if best_state is not None:
    model.load_state_dict(best_state, strict=False)

print(f"Best val acc: {best_val_acc:.3f} at epoch {best_epoch}")


Epoch 01: train loss 1.6958, train acc 0.349 | val loss 1.5497, val acc 0.500
Epoch 02: train loss 1.6103, train acc 0.396 | val loss 1.5103, val acc 0.503
Epoch 03: train loss 1.4971, train acc 0.407 | val loss 1.4384, val acc 0.473
Epoch 04: train loss 1.4715, train acc 0.445 | val loss 1.3866, val acc 0.530
Epoch 05: train loss 1.4325, train acc 0.408 | val loss 1.3504, val acc 0.467
Epoch 06: train loss 1.4002, train acc 0.420 | val loss 1.3049, val acc 0.558
Epoch 07: train loss 1.3505, train acc 0.434 | val loss 1.2434, val acc 0.471
Epoch 08: train loss 1.2822, train acc 0.418 | val loss 1.2115, val acc 0.527
Epoch 09: train loss 1.2373, train acc 0.429 | val loss 1.1635, val acc 0.519
Epoch 10: train loss 1.1891, train acc 0.460 | val loss 1.1578, val acc 0.545
Epoch 11: train loss 1.1955, train acc 0.462 | val loss 1.1246, val acc 0.497
Epoch 12: train loss 1.2085, train acc 0.456 | val loss 1.1097, val acc 0.511
Epoch 13: train loss 1.1006, train acc 0.473 | val loss 1.1314, 

## 10) Final Test Evaluation


In [12]:
test_loss, test_acc = evaluate(test_loader)
print(f"Final test loss: {test_loss:.4f}, test acc: {test_acc:.3f}")


Final test loss: 0.9726, test acc: 0.721


### 10.1) Per-Class Test Report


In [13]:
all_preds = []
all_targets = []

model.eval()
for batch in test_loader:
    x = batch.x.to(DEVICE)
    ei = batch.edge_index.to(DEVICE)
    y = batch.y.to(DEVICE)

    valid_mask = y >= 0
    if valid_mask.sum() == 0:
        continue

    with torch.no_grad():
        logits = model(x, ei)
        preds = logits.argmax(dim=-1)

    all_targets.extend(y[valid_mask].cpu().numpy().tolist())
    all_preds.extend(preds[valid_mask].cpu().numpy().tolist())

print(
    classification_report(
        all_targets,
        all_preds,
        target_names=[idx_to_class[i] for i in range(num_classes)],
        digits=3,
        zero_division=0,
    )
)


               precision    recall  f1-score   support

         bike      0.100     0.233     0.140        30
     motobike      0.102     0.600     0.174        10
traffic_light      0.891     0.843     0.866      1586
 traffic_sign      0.179     0.700     0.286        10
      vehicle      0.222     0.187     0.203       203

     accuracy                          0.759      1839
    macro avg      0.299     0.513     0.334      1839
 weighted avg      0.796     0.759     0.775      1839



## 11) Visualizing Predictions on a Sample Image


In [None]:
def visualize_sample(sample_idx=0, split='test', score: bool = False):
    """
    Visualize predicted labels overlaid on an image.
    """
    if split == 'test':
        ds = test_ds
    elif split == 'train':
        ds = full_train_ds
    else:
        raise ValueError("split must be 'train' or 'test'")

    g = ds[sample_idx]
    s = ds.samples[sample_idx]
    (W, H), objects = parse_voc_xml(s.xml_path)
    img = Image.open(s.image_path).convert('RGB')

    x = g.x.to(DEVICE)
    ei = g.edge_index.to(DEVICE)

    model.eval()
    with torch.no_grad():
        logits = model(x, ei)
        probs = torch.softmax(logits, dim=-1)
        preds = logits.argmax(dim=-1).cpu().numpy().tolist()

    draw = ImageDraw.Draw(img)
    for i, ((name, bbox), p) in enumerate(zip(objects, preds)):
        xmin, ymin, xmax, ymax = bbox
        draw.rectangle(bbox, outline=(255, 0, 0), width=2)

        label = idx_to_class[p]
        if score:
            conf = probs[i, p].item()
            label += f" ({conf:.2f})"

        text_pos = (xmin + 2, max(0, ymin - 12))
        draw.text(text_pos, label, fill=(255, 0, 0))

    display(img)



## 12) Save Trained Model & Metadata


In [15]:
os.makedirs("artifacts", exist_ok=True)
torch.save(
    {
        "model_state": model.state_dict(),
        "class_to_idx": class_to_idx,
        "idx_to_class": idx_to_class,
        "in_dim": in_dim,
        "hidden": HIDDEN_DIM,
        "num_classes": num_classes,
        "config": {
            "KNN_K": KNN_K,
            "EMBED_DIM": EMBED_DIM,
            "HIDDEN_DIM": HIDDEN_DIM,
            "GNN_LAYERS": GNN_LAYERS,
            "DROPOUT": DROPOUT,
            "LR": LR,
            "WEIGHT_DECAY": WEIGHT_DECAY,
            "PATIENCE": PATIENCE,
            "VAL_FRACTION": VAL_FRACTION,
        },
        "device": str(DEVICE),
    },
    "artifacts/carla_gnn_node_cls.pt",
)

print("Saved → artifacts/carla_gnn_node_cls.pt")


Saved → artifacts/carla_gnn_node_cls.pt
