
# CARLA Object-Graph GNN (Node Classification)

**Trains on** `carla-object-detection-dataset/images/train` + `carla-object-detection-dataset/labels/train`  
**Tests on** `carla-object-detection-dataset/images/test` + `carla-object-detection-dataset/labels/test`

Each image becomes a graph:
- **Nodes** = annotated objects (from VOC XML). Node features = bbox geometry + a CNN embedding of the cropped object.
- **Edges** = *k*-NN edges in XY space between object centers.
- **Task** = node classification (predict each object's class).

> This is a clean, self-contained GNN baseline for **object class** prediction using the provided VOC-style labels.


In [1]:

# If you're on CPU-only or a simple setup, try:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# !pip install --no-cache-dir torch_geometric
# If that fails, follow the latest install matrix:
# https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
print('If installs fail, see the PyG docs. Proceeding to imports...')


If installs fail, see the PyG docs. Proceeding to imports...


In [2]:

import os
import glob
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import List, Dict

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch_geometric.data import Data as GeometricData
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch_geometric.nn import GATv2Conv

import torchvision
import torchvision.transforms as T
from PIL import Image, ImageDraw
import numpy as np
import random
from tqdm import tqdm
from collections import Counter

SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_ROOT = 'carla-object-detection-dataset'
IMG_TRAIN = os.path.join(DATA_ROOT, 'images', 'train')
LBL_TRAIN = os.path.join(DATA_ROOT, 'labels', 'train')
IMG_TEST  = os.path.join(DATA_ROOT, 'images', 'test')
LBL_TEST  = os.path.join(DATA_ROOT, 'labels', 'test')

KNN_K = 8
EMBED_DIM = 256
HIDDEN_DIM = 256
GNN_LAYERS = 2
BATCH_SIZE = 8
DROPOUT = 0.5
EPOCHS = 30
LR = 3e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE


  from .autonotebook import tqdm as notebook_tqdm


'cpu'

In [3]:

def parse_voc_xml(xml_path: str):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    objects = []
    for obj in root.findall('object'):
        name = obj.find('name').text
        b = obj.find('bndbox')
        bbox = [int(b.find('xmin').text), int(b.find('ymin').text),
                int(b.find('xmax').text), int(b.find('ymax').text)]
        objects.append((name, bbox))
    return (width, height), objects

def bbox_to_features(bbox, w, h):
    xmin, ymin, xmax, ymax = bbox
    bw = max(1, xmax - xmin)
    bh = max(1, ymax - ymin)
    cx = xmin + bw/2
    cy = ymin + bh/2
    cx_n = cx / w
    cy_n = cy / h
    bw_n = bw / w
    bh_n = bh / h
    area_n = (bw * bh) / float(w*h)
    aspect = bw / float(bh)
    return np.array([cx_n, cy_n, bw_n, bh_n, area_n, aspect], dtype=np.float32)

def crop_and_embed(img: Image.Image, bbox, backbone, transform, proj, pad_scale=1.5):
    W, H = img.size
    xmin, ymin, xmax, ymax = bbox
    cx = (xmin + xmax) / 2.0
    cy = (ymin + ymax) / 2.0
    w = (xmax - xmin) * pad_scale
    h = (ymax - ymin) * pad_scale
    x1 = int(max(0, cx - w/2)); y1 = int(max(0, cy - h/2))
    x2 = int(min(W, cx + w/2)); y2 = int(min(H, cy + h/2))
    crop = img.crop((x1, y1, x2, y2))
    x = transform(crop).unsqueeze(0)
    with torch.no_grad():
        feat = backbone(x.to(DEVICE)).cpu().squeeze(0)
        f = proj(feat.unsqueeze(0)).cpu().squeeze(0)
    return f.numpy()


In [4]:

@dataclass
class Sample:
    image_path: str
    xml_path: str

class CarlaGraphDataset(Dataset):
    def __init__(self, img_dir: str, xml_dir: str, class_to_idx: Dict[str, int],
                 device: str = DEVICE, knn_k: int = 4, embed_dim: int = 256,
                 preload: bool = False):
        super().__init__()
        self.img_dir = img_dir
        self.xml_dir = xml_dir
        self.device = device
        self.knn_k = knn_k
        self.class_to_idx = class_to_idx
        self.embed_dim = embed_dim
        self.preload = preload

        resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])  # Bx512x1x1
        self.backbone.eval().to(self.device)
        self.proj = nn.Sequential(nn.Flatten(), nn.Linear(512, embed_dim))
        self.proj.eval().to(self.device)

        self.tx = T.Compose([
            T.Resize((224, 224)),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
            T.RandomApply([T.GaussianBlur(3)], p=0.2),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        xmls = sorted(glob.glob(os.path.join(xml_dir, '*.xml')))
        self.samples: List[Sample] = []
        for xp in xmls:
            fname = os.path.splitext(os.path.basename(xp))[0]
            ip = os.path.join(img_dir, f'{fname}.png')
            if not os.path.exists(ip):
                ip_jpg = os.path.join(img_dir, f'{fname}.jpg')
                if os.path.exists(ip_jpg):
                    ip = ip_jpg
            if os.path.exists(ip):
                self.samples.append(Sample(ip, xp))

        self._cache = []
        if self.preload:
            for i in tqdm(range(len(self.samples)), desc='Preloading graphs'):
                self._cache.append(self._make_graph(i))

    def __len__(self):
        return len(self.samples)

    def _make_graph(self, idx: int) -> GeometricData:
        from sklearn.neighbors import NearestNeighbors

        s = self.samples[idx]
        (W, H), objects = parse_voc_xml(s.xml_path)
        img = Image.open(s.image_path).convert('RGB')
        n = len(objects)
        if n == 0:
            x = torch.zeros((1, self.embed_dim + 6), dtype=torch.float)
            y = torch.tensor([-1], dtype=torch.long)
            edge_index = torch.empty((2, 0), dtype=torch.long)
            return GeometricData(x=x, y=y, edge_index=edge_index)

        feats = []
        labels = []
        centers = []
        for name, bbox in objects:
            geom = bbox_to_features(bbox, W, H)
            emb = crop_and_embed(img, bbox, self.backbone, self.tx, self.proj, pad_scale=1.5)
            f = np.concatenate([geom, emb], axis=0)
            feats.append(f)
            labels.append(self.class_to_idx.get(name, -1))
            xmin, ymin, xmax, ymax = bbox
            cx = (xmin + xmax)/2.0
            cy = (ymin + ymax)/2.0
            centers.append([cx, cy])

        x = torch.tensor(np.stack(feats, 0), dtype=torch.float)
        y = torch.tensor(labels, dtype=torch.long)
        centers = np.asarray(centers, dtype=np.float32)

        # Safe per-graph standardization
        n_nodes = x.shape[0]
        x_geom = x[:, :6]
        x_vis  = x[:, 6:]
        if n_nodes > 1:
            g_mean = x_geom.mean(0)
            g_std  = x_geom.std(0, unbiased=False).clamp_min(1e-6)
            v_mean = x_vis.mean(0)
            v_std  = x_vis.std(0, unbiased=False).clamp_min(1e-6)
            x_geom = (x_geom - g_mean) / g_std
            x_vis  = (x_vis  - v_mean) / v_std
        x = torch.cat([x_geom, x_vis], dim=1)
        x = torch.nan_to_num(x, nan=0.0, posinf=1e3, neginf=-1e3)

        nbrs = NearestNeighbors(n_neighbors=min(self.knn_k+1, n), algorithm='auto').fit(centers)
        _, idxs = nbrs.kneighbors(centers)
        edges = set()
        for i in range(n):
            for j in idxs[i][1:]:
                edges.add((i, int(j)))
                edges.add((int(j), i))
        if len(edges) == 0:
            edge_index = torch.empty((2,0), dtype=torch.long)
        else:
            edge_index = torch.tensor(list(edges), dtype=torch.long).t().contiguous()

        return GeometricData(x=x, y=y, edge_index=edge_index)

    def __getitem__(self, idx: int) -> GeometricData:
        if self.preload and len(self._cache) == len(self):
            return self._cache[idx]
        return self._make_graph(idx)


In [5]:

def collect_classes(xml_dirs: List[str]) -> Dict[str, int]:
    names = set()
    for d in xml_dirs:
        for xp in glob.glob(os.path.join(d, '*.xml')):
            try:
                _, objs = parse_voc_xml(xp)
            except Exception:
                continue
            for name, _ in objs:
                names.add(name)
    names = sorted(list(names))
    return {name:i for i, name in enumerate(names)}

class_to_idx = collect_classes([LBL_TRAIN, LBL_TEST])
idx_to_class = {v:k for k,v in class_to_idx.items()}
class_to_idx


{'bike': 0, 'motobike': 1, 'traffic_light': 2, 'traffic_sign': 3, 'vehicle': 4}

In [6]:

train_ds = CarlaGraphDataset(IMG_TRAIN, LBL_TRAIN, class_to_idx, device=DEVICE, knn_k=KNN_K, embed_dim=EMBED_DIM, preload=False)
test_ds  = CarlaGraphDataset(IMG_TEST,  LBL_TEST,  class_to_idx, device=DEVICE, knn_k=KNN_K, embed_dim=EMBED_DIM, preload=False)
print('Train samples:', len(train_ds))
print('Test samples :', len(test_ds))

train_loader = GeoDataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = GeoDataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)


Train samples: 779
Test samples : 249


In [7]:

class NodeGNN(nn.Module):
    def __init__(self, in_dim, hidden, out_dim, layers=2, dropout=0.5, heads=2):
        super().__init__()
        self.dropout = dropout
        convs = []
        dims = [in_dim] + [hidden] * layers
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            convs.append(GATv2Conv(d_in, d_out // heads, heads=heads, dropout=dropout))
        self.convs = nn.ModuleList(convs)
        self.head = nn.Linear(hidden, out_dim)

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = conv(x, edge_index)
            x = torch.relu(x)
            x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        logits = self.head(x)
        return logits

num_classes = len(class_to_idx)
in_dim = EMBED_DIM + 6
model = NodeGNN(in_dim, HIDDEN_DIM, num_classes, layers=GNN_LAYERS, dropout=DROPOUT).to(DEVICE)

cnt = Counter()
for g in train_ds:
    for c in g.y.tolist():
        if c >= 0:
            cnt[c] += 1

weights = torch.ones(num_classes, dtype=torch.float)
for c, n in cnt.items():
    weights[c] = 1.0 / max(1, n)
weights = weights * (sum(cnt.values()) / (weights.sum() + 1e-6))
weights = weights.to(DEVICE)

criterion = nn.CrossEntropyLoss(weight=weights)
optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=5e-4)


In [8]:

def step_batch(batch, train=True):
    model.train() if train else model.eval()
    x = batch.x.to(DEVICE)
    ei = batch.edge_index.to(DEVICE)
    y = batch.y.to(DEVICE)
    valid = y >= 0
    if valid.sum() == 0:
        return 0.0, 0.0
    logits = model(x, ei)
    loss = criterion(logits[valid], y[valid])
    if train:
        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optim.step()
    with torch.no_grad():
        pred = logits.argmax(-1)
        acc = (pred[valid] == y[valid]).float().mean().item()
    return loss.item(), acc

def evaluate(loader):
    total_loss, total_acc, n = 0.0, 0.0, 0
    for batch in loader:
        l, a = step_batch(batch, train=False)
        total_loss += l
        total_acc  += a
        n += 1
    return (total_loss/n if n else 0.0), (total_acc/n if n else 0.0)

patience, best_va, wait = 6, -1.0, 0
best_state = None

for epoch in range(1, EPOCHS + 1):
    tl, ta, tn = 0.0, 0.0, 0
    for batch in train_loader:
        l, a = step_batch(batch, train=True)
        tl += l; ta += a; tn += 1
    tl /= max(1, tn); ta /= max(1, tn)

    vl, va = evaluate(test_loader)
    print(f"Epoch {epoch:02d}: train loss {tl:.4f}, train acc {ta:.3f} | test loss {vl:.4f}, test acc {va:.3f}")

    if va > best_va:
        best_va, wait = va, 0
        best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch}. Best test acc: {best_va:.3f}")
            break

if best_state is not None:
    model.load_state_dict(best_state, strict=False)

final_loss, final_acc = evaluate(test_loader)
print(f"Best test acc after early stopping: {final_acc:.3f} (loss {final_loss:.4f})")


Epoch 01: train loss 1.6389, train acc 0.349 | test loss 1.5512, test acc 0.442
Epoch 02: train loss 1.5696, train acc 0.415 | test loss 1.5909, test acc 0.308
Epoch 03: train loss 1.4929, train acc 0.405 | test loss 1.6170, test acc 0.240
Epoch 04: train loss 1.4164, train acc 0.418 | test loss 1.6245, test acc 0.256
Epoch 05: train loss 1.4126, train acc 0.428 | test loss 1.6464, test acc 0.235
Epoch 06: train loss 1.3453, train acc 0.433 | test loss 1.6128, test acc 0.280
Epoch 07: train loss 1.2824, train acc 0.437 | test loss 1.6155, test acc 0.233
Early stopping at epoch 7. Best test acc: 0.442
Best test acc after early stopping: 0.224 (loss 1.6229)


In [9]:

def visualize_sample(sample_idx=0, split='test', score=False):
    ds = test_ds if split=='test' else train_ds
    g = ds[sample_idx]
    s = ds.samples[sample_idx]
    (W, H), objects = parse_voc_xml(s.xml_path)
    img = Image.open(s.image_path).convert('RGB')

    x = g.x.to(DEVICE)
    ei = g.edge_index.to(DEVICE)
    model.eval()
    with torch.no_grad():
        logits = model(x, ei)
        probs = torch.softmax(logits, -1)
        preds = logits.argmax(-1).cpu().numpy().tolist()

    draw = ImageDraw.Draw(img)
    for i, ((name, bbox), p) in enumerate(zip(objects, preds)):
        xmin, ymin, xmax, ymax = bbox
        draw.rectangle(bbox, outline=(255,0,0), width=2)
        label = idx_to_class[p]
        if score:
            conf = probs[i, p].item() if probs.ndim==2 else 1.0
            label += f" ({conf:.2f})"
        draw.text((xmin+2, max(0,ymin-12)), label, fill=(255,0,0))
    display(img)

# visualize_sample(0, split='test', score=True)


In [10]:

import os
os.makedirs('artifacts', exist_ok=True)
torch.save({'model_state': model.state_dict(),
            'class_to_idx': class_to_idx,
            'in_dim': EMBED_DIM + 6,
            'hidden': HIDDEN_DIM,
            'num_classes': len(class_to_idx),
            'config': {
                'KNN_K': KNN_K,
                'EMBED_DIM': EMBED_DIM,
                'HIDDEN_DIM': HIDDEN_DIM,
                'GNN_LAYERS': GNN_LAYERS,
                'DROPOUT': DROPOUT
            }}, 'artifacts/carla_gnn.pt')
print('Saved → artifacts/carla_gnn.pt')


Saved → artifacts/carla_gnn.pt
