# CARLA Object Classification with ResNet50

This notebook trains a **ResNet50** classifier on the same CARLA object-detection dataset used by the GNN:

- **Trains on** `carla-object-detection-dataset/images/train` + `labels/train`
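- **Evaluates on** `carla-object-detection-dataset/images/test` + `labels/test` (this split is also used to pick the best epoch, so the reported accuracy is not a fully held-out estimate)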
- Each annotated **object bounding box** is cropped from its image and becomes **one classification sample**
- Model: **ResNet50** (ImageNet-pretrained), fine-tuned for the CARLA classes


In [None]:
# 0) Imports & Config
import os
import glob
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as T
from PIL import Image
import numpy as np
import random
from tqdm import tqdm

# Reproducibility: seed the Python, NumPy and PyTorch RNGs with the same value.
SEED = 7
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Dataset layout: <DATA_ROOT>/{images,labels}/{train,test}
DATA_ROOT = 'carla-object-detection-dataset'
IMG_TRAIN, LBL_TRAIN = (os.path.join(DATA_ROOT, kind, 'train') for kind in ('images', 'labels'))
IMG_TEST, LBL_TEST = (os.path.join(DATA_ROOT, kind, 'test') for kind in ('images', 'labels'))

# Optimisation hyper-parameters.
BATCH_SIZE, EPOCHS, LR = 32, 15, 3e-4

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('Using device:', DEVICE)

Using device: cpu


In [None]:
# 1) VOC XML parsing utilities

def parse_voc_xml(xml_path: str):
    """Parse one Pascal-VOC style annotation file.

    Args:
        xml_path: Path to the XML annotation file.

    Returns:
        A pair ``((width, height), objects)`` where ``objects`` is a list of
        ``(name, [xmin, ymin, xmax, ymax])`` entries, one per ``<object>``
        element. Class names are whitespace-stripped. Coordinates are ints;
        float-formatted values such as ``"12.0"`` are accepted and truncated.

    Raises:
        ValueError: If a required element is missing or empty, or a number
            cannot be parsed.
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
        OSError: If the file cannot be read.
    """
    def _text(parent, tag):
        # Fail with a message naming the file and tag instead of an opaque
        # AttributeError on None.
        node = parent.find(tag) if parent is not None else None
        if node is None or node.text is None or not node.text.strip():
            raise ValueError(f"{xml_path}: missing or empty <{tag}> element")
        return node.text.strip()

    def _int(parent, tag):
        raw = _text(parent, tag)
        try:
            return int(raw)
        except ValueError:
            # Some annotation tools write coordinates as floats ("12.0");
            # plain int() rejects those.
            return int(float(raw))

    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    width, height = _int(size, 'width'), _int(size, 'height')

    objects = []
    for obj in root.findall('object'):
        box = obj.find('bndbox')
        bbox = [_int(box, tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        objects.append((_text(obj, 'name'), bbox))
    return (width, height), objects

def collect_classes(xml_dirs: List[str]) -> Dict[str, int]:
    """Build a ``class_name -> index`` mapping from VOC annotation folders.

    Every ``*.xml`` file directly inside each directory is parsed with
    ``parse_voc_xml`` and the object names are collected. Indices follow the
    alphabetical order of the names, so the mapping is deterministic.

    Args:
        xml_dirs: Directories to scan. Missing or empty ones contribute nothing.

    Returns:
        Dict mapping each distinct class name to a contiguous index from 0.
        Empty if no annotation could be read.

    Annotation files that fail to parse are skipped, but a warning is emitted
    so that dropped files do not go unnoticed.
    """
    import warnings  # local import keeps the notebook's import cell unchanged

    names = set()
    for xml_dir in xml_dirs:
        for xml_path in sorted(glob.glob(os.path.join(xml_dir, '*.xml'))):
            try:
                _, objs = parse_voc_xml(xml_path)
            except Exception as exc:
                # Best-effort scan: keep going, but report what was dropped.
                warnings.warn(f"Skipping unreadable annotation {xml_path}: {exc}")
                continue
            names.update(name for name, _ in objs)
    return {name: i for i, name in enumerate(sorted(names))}

# Label vocabulary gathered from both splits, plus its inverse for decoding predictions.
class_to_idx = collect_classes([LBL_TRAIN, LBL_TEST])
idx_to_class = dict((idx, name) for name, idx in class_to_idx.items())
print('Classes:', class_to_idx)

Classes: {'bike': 0, 'motobike': 1, 'traffic_light': 2, 'traffic_sign': 3, 'vehicle': 4}


In [None]:
# 2) Dataset 

@dataclass
class BBoxSample:
    """One annotated object: its source image, class index and bounding box."""

    # Path of the full image the object is cropped from.
    image_path: str
    # Index of the object's class in the class_to_idx mapping.
    label_idx: int
    # Box corners in the order (xmin, ymin, xmax, ymax).
    bbox: Tuple[int, int, int, int]

class CarlaBBoxDataset(Dataset):
    """Classification dataset with one sample per annotated bounding box.

    At construction every ``*.xml`` file in ``xml_dir`` is matched to an image
    with the same stem in ``img_dir`` (``.png`` first, then ``.jpg``). Each
    object whose name appears in ``class_to_idx`` becomes one ``BBoxSample``.
    Images are opened and cropped lazily in ``__getitem__``.

    Args:
        img_dir: Directory holding the images.
        xml_dir: Directory holding the VOC XML annotations.
        class_to_idx: Mapping from class name to integer label.
        transform: Optional callable applied to the cropped PIL image.
    """

    # Extensions probed for each annotation stem, in priority order.
    _IMAGE_EXTS = ('.png', '.jpg')

    def __init__(self, img_dir: str, xml_dir: str, class_to_idx: Dict[str, int], transform=None):
        import warnings  # local import keeps the notebook's import cell unchanged

        super().__init__()
        self.img_dir = img_dir
        self.xml_dir = xml_dir
        self.class_to_idx = class_to_idx
        self.transform = transform

        self.samples: List[BBoxSample] = []
        for xp in sorted(glob.glob(os.path.join(xml_dir, '*.xml'))):
            fname = os.path.splitext(os.path.basename(xp))[0]
            ip = self._find_image(img_dir, fname)
            if ip is None:
                # Annotation without a matching image: nothing to crop from.
                continue

            try:
                _, objects = parse_voc_xml(xp)
            except Exception as exc:
                # Best-effort load: skip the file but report it.
                warnings.warn(f"Skipping unreadable annotation {xp}: {exc}")
                continue

            for name, bbox in objects:
                if name not in self.class_to_idx:
                    continue
                xmin, ymin, xmax, ymax = bbox
                if xmax <= xmin or ymax <= ymin:
                    # A zero- or negative-area box gives an empty or invalid
                    # crop that would only fail later, inside the DataLoader.
                    warnings.warn(f"Skipping degenerate box {tuple(bbox)} in {xp}")
                    continue
                self.samples.append(BBoxSample(ip, self.class_to_idx[name], tuple(bbox)))

        print(f"Loaded {len(self.samples)} object crops from {xml_dir}")

    @staticmethod
    def _find_image(img_dir: str, stem: str):
        """Return the first existing ``<img_dir>/<stem><ext>`` path, or None."""
        for ext in CarlaBBoxDataset._IMAGE_EXTS:
            candidate = os.path.join(img_dir, f'{stem}{ext}')
            if os.path.exists(candidate):
                return candidate
        return None

    def __len__(self):
        """Number of object crops (not number of images)."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(crop, label)`` for sample ``idx``.

        ``crop`` is the RGB crop of the bounding box after ``transform`` (if
        any). ``label`` is a scalar ``torch.long`` tensor.
        """
        s = self.samples[idx]
        # The context manager releases the file handle right away. convert()
        # loads the pixel data first, so the crop stays valid after closing.
        with Image.open(s.image_path) as img:
            crop = img.convert('RGB').crop(s.bbox)

        if self.transform is not None:
            crop = self.transform(crop)

        label = torch.tensor(s.label_idx, dtype=torch.long)
        return crop, label

In [None]:
# 3) Build datasets & data loaders

def _to_normalized_tensor():
    """Fresh ToTensor + Normalize steps used at the tail of both pipelines."""
    return [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training pipeline: resize, light augmentation (flip + colour jitter), then tensor + normalisation.
train_tf = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    *_to_normalized_tensor(),
])

# Evaluation pipeline: same preprocessing without the random augmentations.
test_tf = T.Compose([T.Resize((224, 224)), *_to_normalized_tensor()])

train_ds = CarlaBBoxDataset(IMG_TRAIN, LBL_TRAIN, class_to_idx, transform=train_tf)
test_ds = CarlaBBoxDataset(IMG_TEST, LBL_TEST, class_to_idx, transform=test_tf)

# Options common to both loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=False)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

print('Train samples:', len(train_ds))
print('Test samples :', len(test_ds))

Loaded 2556 object crops from carla-object-detection-dataset/labels/train
Loaded 1839 object crops from carla-object-detection-dataset/labels/test
Train samples: 2556
Test samples : 1839


In [None]:
# 4) Build ResNet50 model

# One output logit per class found in the annotations.
num_classes = len(class_to_idx)

# Start from torchvision's ImageNet-pretrained ResNet50 weights (IMAGENET1K_V1).
resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)

# Replace the final fully connected layer with a new head sized for this dataset's classes.
in_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features, num_classes)

resnet50 = resnet50.to(DEVICE)

# All model parameters are handed to the optimizer (no layers are frozen here).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(resnet50.parameters(), lr=LR, weight_decay=1e-4)

# Dump the architecture for reference.
print(resnet50)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
# 5) Training & evaluation loops

def train_one_epoch(model, loader, optimizer, device, loss_fn=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
            Defaults to the notebook-level ``criterion``, so existing calls
            keep working. Pass it explicitly to avoid relying on that global.

    Returns:
        ``(avg_loss, avg_acc)`` averaged over the samples seen.
        Both are 0 when the loader is empty.
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level default

    model.train()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    for images, labels in tqdm(loader, desc='Train', leave=False):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a short last batch does not skew the average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

    denom = max(1, total_samples)  # guards against an empty loader
    return total_loss / denom, total_correct / denom


def evaluate(model, loader, device, loss_fn=None):
    """Compute average loss and accuracy of ``model`` over ``loader``.

    No gradients are computed and no parameters are updated.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
            Defaults to the notebook-level ``criterion``, so existing calls
            keep working. Pass it explicitly to avoid relying on that global.

    Returns:
        ``(avg_loss, avg_acc)`` averaged over the samples seen.
        Both are 0 when the loader is empty.
    """
    if loss_fn is None:
        loss_fn = criterion  # notebook-level default

    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Eval', leave=False):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            loss = loss_fn(logits, labels)

            # Weight each batch loss by its size so a short last batch does not skew the average.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_correct += (logits.argmax(dim=1) == labels).sum().item()
            total_samples += batch_size

    denom = max(1, total_samples)  # guards against an empty loader
    return total_loss / denom, total_correct / denom


# Track the best-scoring epoch so its weights can be restored after training.
best_acc = 0.0
best_state = None

for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch {epoch:02d}/{EPOCHS}")
    train_loss, train_acc = train_one_epoch(resnet50, train_loader, optimizer, DEVICE)
    # NOTE(review): the test split doubles as the validation set for model
    # selection, so "best test accuracy" is optimistically biased. Carve a
    # validation split out of the training data for an unbiased test score.
    val_loss, val_acc = evaluate(resnet50, test_loader, DEVICE)

    print(f"Train: loss={train_loss:.4f}, acc={train_acc:.3f} | Test: loss={val_loss:.4f}, acc={val_acc:.3f}")

    if val_acc > best_acc:
        best_acc = val_acc
        # state_dict() tensors share storage with the live parameters, and
        # .cpu() returns the same tensor when it is already on the CPU. Without
        # a forced copy the "snapshot" keeps changing with later epochs and the
        # restored weights end up being the last epoch's, not the best.
        best_state = {
            k: v.detach().to('cpu', copy=True)
            for k, v in resnet50.state_dict().items()
        }

print(f"\nBest test accuracy: {best_acc:.3f}")
if best_state is not None:
    resnet50.load_state_dict(best_state)


Epoch 01/15


                                                      

Train: loss=0.2249, acc=0.928 | Test: loss=0.0949, acc=0.971

Epoch 02/15


                                                      

Train: loss=0.1057, acc=0.968 | Test: loss=0.0496, acc=0.984

Epoch 03/15


                                                       

Train: loss=0.0562, acc=0.984 | Test: loss=0.0227, acc=0.986

Epoch 04/15


                                                        

Train: loss=0.0607, acc=0.980 | Test: loss=0.1241, acc=0.967

Epoch 05/15


                                                      

Train: loss=0.0554, acc=0.980 | Test: loss=0.0547, acc=0.988

Epoch 06/15


                                                         

Train: loss=0.0352, acc=0.989 | Test: loss=0.0983, acc=0.982

Epoch 07/15


                                                      

Train: loss=0.0391, acc=0.987 | Test: loss=0.0851, acc=0.973

Epoch 08/15


                                                      

Train: loss=0.0550, acc=0.985 | Test: loss=0.0905, acc=0.983

Epoch 09/15


                                                      

Train: loss=0.0762, acc=0.975 | Test: loss=0.1266, acc=0.974

Epoch 10/15


                                                      

Train: loss=0.0377, acc=0.989 | Test: loss=0.0307, acc=0.986

Epoch 11/15


                                                      

Train: loss=0.0186, acc=0.993 | Test: loss=0.0237, acc=0.991

Epoch 12/15


                                                      

Train: loss=0.0208, acc=0.995 | Test: loss=0.0549, acc=0.987

Epoch 13/15


                                                      

Train: loss=0.0199, acc=0.992 | Test: loss=0.0557, acc=0.984

Epoch 14/15


                                                      

Train: loss=0.0106, acc=0.996 | Test: loss=0.0442, acc=0.985

Epoch 15/15


                                                      

Train: loss=0.0105, acc=0.996 | Test: loss=0.0820, acc=0.977

Best test accuracy: 0.991




In [None]:
# 6) Save trained model

# Store the weights together with the label mapping and class count.
save_path = 'artifacts/resnet50_carla_classifier.pt'
os.makedirs(os.path.dirname(save_path), exist_ok=True)

checkpoint = dict(
    model_state=resnet50.state_dict(),
    class_to_idx=class_to_idx,
    num_classes=num_classes,
)
torch.save(checkpoint, save_path)

print('Saved model to:', save_path)

Saved model to: artifacts/resnet50_carla_classifier.pt


In [None]:
# 7) Quick inference demo on a few test samples

def show_some_predictions(model, loader, idx_to_class, device, max_batches=1, max_samples=10):
    """Print true vs. predicted class names for the first few samples of ``loader``.

    Args:
        model: Classifier returning per-class logits; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        idx_to_class: Mapping from label index to class name.
        device: Device the images are moved to before the forward pass.
        max_batches: Stop after this many batches. At least one batch is
            always inspected.
        max_samples: Stop after printing this many samples (default 10, the
            previously hard-coded limit). Values <= 0 print nothing.

    Returns:
        None. Results are written to stdout, one line per sample.
    """
    if max_samples <= 0:
        return

    model.eval()
    printed = 0
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))
            # One bulk transfer to Python ints instead of one .item() call per element.
            pred_ids = logits.argmax(dim=1).tolist()
            true_ids = labels.tolist()

            for true_id, pred_id in zip(true_ids, pred_ids):
                true_lbl = idx_to_class[true_id]
                pred_lbl = idx_to_class[pred_id]
                print(f'Sample {printed:03d}: true={true_lbl}, pred={pred_lbl}')
                printed += 1

                if printed >= max_samples:
                    return
            max_batches -= 1
            if max_batches <= 0:
                break

# Print true vs. predicted labels for the first few samples of the test loader.
show_some_predictions(resnet50, test_loader, idx_to_class, DEVICE)

Sample 000: true=vehicle, pred=vehicle
Sample 001: true=vehicle, pred=vehicle
Sample 002: true=vehicle, pred=vehicle
Sample 003: true=traffic_light, pred=traffic_light
Sample 004: true=traffic_light, pred=traffic_light
Sample 005: true=traffic_light, pred=traffic_light
Sample 006: true=traffic_light, pred=traffic_light
Sample 007: true=traffic_light, pred=traffic_light
Sample 008: true=traffic_light, pred=traffic_light
Sample 009: true=traffic_light, pred=traffic_light
