In [2]:
import sys
sys.path.append('./detr') 

from PIL import Image
import os
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms as T
import json
import shutil

In [3]:
# Training configuration shared by the cells below.
epochs = 5                             # total passes of the training loop
num_classes = 2                        # size of the detector's classification head
data_dir1 = "./data/just_car/"         # first dataset directory
data_dir2 = "./data/car_trees/"        # second dataset directory
output_file = "model/model.pth"        # where the trained model is written
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class CarDataset(torch.utils.data.Dataset):
    """
    Object-detection dataset built from a flat folder of rendered frames.

    Every ``rgb_<idx>.png`` image in ``root`` is paired with two sidecar files
    that share its index:
        - ``bounding_box_2d_tight_<idx>.npy``: one row per object,
          ``[object_id, x_min, y_min, x_max, y_max]``
        - ``bounding_box_2d_tight_labels_<idx>.json``: maps ``str(object_id)``
          to a dict holding the object's ``"class"`` name
    """

    # Class names that are kept, and the integer label emitted for each.
    # Objects whose class is not listed here are dropped from the target.
    # NOTE(review): torchvision detection models reserve label 0 for background,
    # so "ground" boxes labelled 0 are probably not learned as a foreground
    # class -- confirm whether "ground" should be dropped or relabelled.
    CLASS_TO_LABEL = {"ground": 0, "cars": 1}

    def __init__(self, root, transforms):
        """
        Initialize the CarDataset.
        Args:
            root (str): Root directory containing the dataset files
            transforms: Callable applied to each loaded image, or None
        """
        self.root = root
        self.transforms = transforms

        # All PNG files in the folder, sorted so indices are stable between runs
        self.imgs = sorted(f for f in os.listdir(root) if f.endswith('.png'))

        # Sidecar file names, aligned index-for-index with self.imgs
        self.npy_files = []
        self.json_files = []

        for png_file in self.imgs:
            # Extract the index from the filename (e.g. rgb_0000.png -> 0000)
            idx = png_file.replace('rgb_', '').replace('.png', '')

            self.npy_files.append(f"bounding_box_2d_tight_{idx}.npy")
            self.json_files.append(f"bounding_box_2d_tight_labels_{idx}.json")

    def __getitem__(self, idx):
        """
        Get a single item from the dataset.
        Args:
            idx (int): Index of the item to get
        Returns:
            tuple: (image, target) where target is a dictionary containing:
                  - boxes (Tensor): float32, shape (N, 4), [x_min, y_min, x_max, y_max]
                  - labels (Tensor): int64, shape (N,), class label for each box
                  - image_id (Tensor): Image index
                  - area (Tensor): float32, shape (N,), area of each box
            N is 0 when the image has no usable box; the tensors are then empty
            but keep the shapes above.
        """
        # Load image
        img_path = os.path.join(self.root, self.imgs[idx])
        img = Image.open(img_path).convert("RGB")

        # Load bounding boxes: rows of [object_id, x_min, y_min, x_max, y_max]
        npy_path = os.path.join(self.root, self.npy_files[idx])
        bboxes = np.load(npy_path)

        # Load the object_id -> class-name mapping
        json_path = os.path.join(self.root, self.json_files[idx])
        with open(json_path, 'r') as f:
            label_dict = json.load(f)

        boxes = []
        labels = []

        for box in bboxes:
            obj_id = int(box[0])
            x_min = float(box[1])
            y_min = float(box[2])
            x_max = float(box[3])
            y_max = float(box[4])

            # Skip degenerate boxes (zero or negative width/height)
            if x_max <= x_min or y_max <= y_min:
                continue

            obj_class_name = label_dict.get(str(obj_id), {}).get("class", "unknown")

            # Keep only the classes listed in CLASS_TO_LABEL
            class_label = self.CLASS_TO_LABEL.get(obj_class_name)
            if class_label is None:
                continue

            boxes.append([x_min, y_min, x_max, y_max])
            labels.append(class_label)

        # reshape(-1, 4) keeps the tensor two-dimensional even when no box was
        # retained; without it an empty list becomes a 1-D tensor and the column
        # indexing below raises IndexError.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # Box area = height * width
        areas = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": areas
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        """
        Get the total number of items in the dataset.
        Returns:
            int: Number of items in dataset
        """
        return len(self.imgs)

In [5]:
def get_transform(train):
    """
    Build the image preprocessing pipeline used by the datasets.

    Args:
        train (bool): Intended to distinguish training from validation
                     pipelines; it is currently ignored, so both get the
                     same transforms.

    Returns:
        torchvision.transforms.Compose: Pipeline that applies, in order:
            1. T.PILToTensor() - PIL image to tensor
            2. T.ConvertImageDtype(torch.float) - tensor to float dtype
    """
    return T.Compose([
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
    ])

In [6]:
def collate_fn(batch):
    """
    Transpose a batch of samples into per-field tuples for the DataLoader.

    Args:
        batch: Sequence of samples, each an (image, target) pair as returned
               by the dataset

    Returns:
        tuple: One tuple per sample field - for (image, target) samples this
               is (images, targets). An empty batch yields an empty tuple.
    """
    # zip(*batch) regroups the i-th element of every sample together
    columns = zip(*batch)
    return tuple(columns)

In [7]:
def create_model(num_classes):
    """
    Build a Faster R-CNN (ResNet-50 + FPN) detector with a fresh box-predictor head.

    Args:
        num_classes (int): Number of output classes for the new head
                           (including background)

    Returns:
        torch.nn.Module: torchvision's fasterrcnn_resnet50_fpn loaded with its
            'DEFAULT' pretrained weights, whose roi_heads.box_predictor has been
            replaced by a FastRCNNPredictor sized for ``num_classes``.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')

    # Reuse the input width of the existing classification layer so the
    # replacement head plugs into the same feature vector.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    return detector

In [8]:
# One dataset per data directory, each with its own transform pipeline
dataset1 = CarDataset(root=data_dir1, transforms=get_transform(train=True))
dataset2 = CarDataset(root=data_dir2, transforms=get_transform(train=True))

# Both loaders share the same settings: batches of 4, reshuffled every epoch,
# with samples regrouped by the custom collate function
loader_options = dict(batch_size=4, shuffle=True, collate_fn=collate_fn)
data_loader1 = torch.utils.data.DataLoader(dataset1, **loader_options)
data_loader2 = torch.utils.data.DataLoader(dataset2, **loader_options)

FileNotFoundError: [Errno 2] No such file or directory: './data/just_car/'

In [None]:
# Build the detector and place it on the configured device
model = create_model(num_classes)
model.to(device)

# Hand only the parameters that require gradients to SGD (lr 0.001, momentum 0.9)
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9)

In [None]:
# Switch the model to training mode before iterating
model.train()
for epoch in range(epochs):
    # The first two epochs train on dataset1; every later epoch trains on dataset2
    if epoch >= 2:
        current_loader = data_loader2
        print(f"Epoch [{epoch+1}/{epochs}]: Using dataset2")
    else:
        current_loader = data_loader1
        print(f"Epoch [{epoch+1}/{epochs}]: Using dataset1")

    # Batch count of the active loader, used only in the progress message
    len_dataloader = len(current_loader)

    for i, (imgs, annotations) in enumerate(current_loader):
        # Transfer the images and every target tensor to the training device
        imgs = [img.to(device) for img in imgs]
        annotations = [
            {key: value.to(device) for key, value in target.items()}
            for target in annotations
        ]

        # Clear gradients accumulated by the previous step
        optimizer.zero_grad()

        # Forward pass: the model returns a dictionary of loss terms,
        # which are added up into a single scalar to optimize
        loss_dict = model(imgs, annotations)
        losses = sum(loss_dict.values())

        # Backward pass and parameter update
        losses.backward()
        optimizer.step()

        # Report the loss on every 5th batch
        if i % 5 == 4:
            print(f"  Batch [{i+1}/{len_dataloader}], Loss: {losses.item():.4f}")

In [None]:
# Create the destination folder first: torch.save does not create missing
# parent directories and would fail if "model/" is absent.
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

# This pickles the whole model object rather than just its state_dict, so
# loading it back needs the same model class definitions to be importable.
torch.save(model, output_file)