In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleDetector(nn.Module):
    """Minimal single-object detector: a small CNN backbone plus an MLP head.

    For each image the head emits one row of ``5 + num_classes`` values laid
    out as ``[box(4), confidence, class logits(num_classes)]``. The four box
    values are regressed against whatever format the training targets use
    (``DetectionDataset`` yields ``[xmin, ymin, xmax, ymax]``).

    Args:
        num_classes: number of object classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Backbone (feature extractor): three conv/ReLU/max-pool stages;
        # each pool halves the spatial resolution (total stride 8).
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),  # 3 input channels (RGB)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # Halve resolution
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # The head expects a 64 x 16 x 16 feature map, which the backbone
        # alone only produces for 128x128 inputs; 256x256 inputs (the size the
        # notebook's transform resizes to) give 32x32 maps and would otherwise
        # fail in the first Linear layer. Adaptive pooling pins the grid to
        # 16x16 for any input size, is an identity for 128x128 inputs, and adds
        # no parameters (state_dict layout is unchanged).
        self.pool = nn.AdaptiveAvgPool2d((16, 16))
        # Detection head: one bounding box, a confidence and class logits.
        self.fc = nn.Sequential(
            nn.Linear(64 * 16 * 16, 256),
            nn.ReLU(),
            nn.Linear(256, 5 + num_classes),  # 4 box values + confidence + classes
        )

    def forward(self, x):
        """Map a (batch, 3, H, W) image batch to (batch, 5 + num_classes) outputs."""
        x = self.backbone(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [None]:
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader

class DetectionDataset(Dataset):
    """Image dataset with Pascal VOC-style XML bounding-box annotations.

    Every image ``<stem>.<ext>`` in ``image_dir`` is paired with
    ``annotation_dir/<stem>.xml``. Each ``<object>`` element in that file
    contributes one box (``<bndbox>`` xmin/ymin/xmax/ymax) and one label (the
    index of its ``<name>`` in ``classes``).

    Args:
        image_dir: directory holding the images.
        annotation_dir: directory holding the XML annotations.
        classes: ordered list of class names; defines the label indices.
        transform: optional callable applied to the RGB numpy image.

    ``__getitem__`` returns ``(image, {'boxes': float32 tensor (N, 4),
    'labels': int64 tensor (N,)})``. Boxes stay in the original image's pixel
    coordinates: ``transform`` is applied to the image only, so a resizing
    transform does NOT rescale the boxes.
    """

    # File extensions treated as images; anything else (e.g. .DS_Store) is skipped.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, image_dir, annotation_dir, classes, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.classes = classes
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic (os.listdir order is arbitrary).
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
            and os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        filename = self.image_files[idx]

        # Load image; cv2 reads BGR, so convert to RGB.
        image_path = os.path.join(self.image_dir, filename)
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread returns None on failure instead of raising
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Pair by file stem so any image extension works and a '.jpg' elsewhere
        # in the filename is left alone.
        stem = os.path.splitext(filename)[0]
        boxes, labels = self._parse_annotation(
            os.path.join(self.annotation_dir, stem + '.xml')
        )

        # Apply transforms (resize, normalize, etc.) to the image only.
        if self.transform:
            image = self.transform(image)

        return image, {'boxes': boxes, 'labels': labels}

    def _parse_annotation(self, annotation_path):
        """Read one XML file into (boxes float32 (N, 4), labels int64 (N,)) tensors."""
        root = ET.parse(annotation_path).getroot()
        boxes = []
        labels = []
        for obj in root.findall('object'):
            class_name = obj.find('name').text
            if class_name not in self.classes:
                raise ValueError(
                    f'Unknown class {class_name!r} in {annotation_path}; '
                    f'expected one of {self.classes}'
                )
            labels.append(self.classes.index(class_name))
            bbox = obj.find('bndbox')
            boxes.append(
                [float(bbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            )
        # reshape keeps shape (0, 4) for images with no objects
        # (torch.tensor([]) alone would be shape (0,)).
        return (
            torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            torch.tensor(labels, dtype=torch.int64),
        )

In [None]:
from torchvision import transforms

# Preprocessing for the RGB numpy images produced by DetectionDataset:
#   1. ToTensor  - HxWxC array -> CxHxW float tensor (uint8 input scaled to [0, 1]).
#   2. Resize    - fixed 256x256 so every image in a batch has the same shape.
#                  NOTE: only the image is resized; box targets are not rescaled.
#   3. Normalize - per-channel standardization with the common ImageNet statistics.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
def detection_loss(pred, target):
    """Single-object detection loss: box MSE plus class cross-entropy.

    Args:
        pred: model output of shape (batch_size, 5 + num_classes), laid out as
            ``[box(4), confidence, class logits...]``.
        target: dict with
            'boxes':  (batch_size, 4) or DataLoader-collated (batch_size, N, 4);
            'labels': (batch_size,) or DataLoader-collated (batch_size, N) class indices.
            When an object axis N is present, only the first object of each
            image is used because the model predicts a single box per image.

    Returns:
        Scalar tensor ``loc_loss + cls_loss``. The confidence output (index 4)
        is not supervised by this loss.
    """
    target_boxes = target['boxes']
    target_labels = target['labels']

    # Default collate keeps the per-image object axis. Without collapsing it,
    # mse_loss would broadcast (B, 4) against (B, 1, 4) into (B, B, 4)
    # (comparing every prediction with every image's box), and cross_entropy
    # rejects 2-D class-index targets.
    if target_boxes.dim() == 3:
        target_boxes = target_boxes[:, 0, :]
    if target_labels.dim() == 2:
        target_labels = target_labels[:, 0]

    # Localization loss (MSE on the 4 box values).
    loc_loss = F.mse_loss(pred[:, :4], target_boxes)

    # Classification loss (cross-entropy on the class logits after the confidence slot).
    cls_loss = F.cross_entropy(pred[:, 5:], target_labels)

    return loc_loss + cls_loss

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)  # reproducible weight init and DataLoader shuffling

# Configuration
classes = ['car', 'truck']  # Update with your classes
num_epochs = 10
batch_size = 4
learning_rate = 1e-3

# Initialize dataset and model.
# NOTE: the default DataLoader collate stacks per-image box tensors, so every
# image in a batch must have the same number of objects; the model itself
# predicts one box per image.
dataset = DetectionDataset('dataset/images', 'dataset/annotations', classes, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
model = SimpleDetector(num_classes=len(classes)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, targets in dataloader:
        images = images.to(device)
        gt = {key: targets[key].to(device) for key in ('boxes', 'labels')}

        # Forward pass and loss
        preds = model(images)
        loss = detection_loss(preds, gt)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the dataset is empty.
    mean_loss = running_loss / max(len(dataloader), 1)
    print(f'Epoch {epoch+1}, Loss: {mean_loss}')

In [None]:
def predict(image, model, classes, threshold=0.5, preprocess=None):
    """Run the detector on one image and return at most one detection.

    Args:
        image: a single image accepted by ``preprocess``; with the default
            notebook ``transform`` this is presumably an HxWx3 RGB uint8 array
            like the ones DetectionDataset loads (TODO confirm for callers).
        model: detector whose output row is ``[box(4), confidence logit,
            class logits...]``; it is switched to eval mode.
        classes: class names indexed by predicted class id.
        threshold: minimum confidence probability (sigmoid of the confidence
            output) required to report a detection.
        preprocess: callable mapping ``image`` to a CxHxW tensor; defaults to
            the notebook-level ``transform``.

    Returns:
        ``[(xmin, ymin, xmax, ymax, class_name, conf)]`` with plain floats if
        ``conf > threshold``, otherwise ``[]``.
    """
    if preprocess is None:
        preprocess = transform
    # Use the model's own device rather than relying on a notebook global.
    model_device = next((p.device for p in model.parameters()), torch.device('cpu'))

    model.eval()
    with torch.no_grad():
        image_tensor = preprocess(image).unsqueeze(0).to(model_device)
        row = model(image_tensor)[0]
        # The confidence slot is a raw linear output; map it to a probability
        # so it is comparable with a probability-style threshold.
        conf = torch.sigmoid(row[4]).item()
        if conf <= threshold:
            return []
        xmin, ymin, xmax, ymax = (float(v) for v in row[:4].cpu())
        class_id = int(torch.argmax(row[5:]).item())

    return [(xmin, ymin, xmax, ymax, classes[class_id], conf)]