In [None]:
in[0]: import sys
sys.path.append('./detr')  # Add this line before other imports

from PIL import Image
import os
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision import transforms as T
import json
import shutil
from detr.models.matcher import HungarianMatcher
from detr.models.detr import SetCriterion

In [None]:
# Training configuration shared by the cells below.
epochs = 5
num_classes = 2  # Two object classes: ground (label 0) and cars (label 1)
data_dir1 = "./data/just_car/"
data_dir2 = "./data/car_trees/"
output_file = "my_model_3_detr.pth"
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class CarDataset(torch.utils.data.Dataset):
    """Detection dataset built from ``rgb_<id>.png`` images and their annotation files.

    For every ``*.png`` file in ``root`` the dataset expects two companion files
    in the same folder, named after the image id (the file name with ``rgb_``
    and ``.png`` removed):

    * ``bounding_box_2d_tight_<id>.npy`` - one row per object, read positionally
      as ``[object_id, x_min, y_min, x_max, y_max]`` in pixel coordinates.
    * ``bounding_box_2d_tight_labels_<id>.json`` - maps ``str(object_id)`` to a
      dict whose ``"class"`` entry is the class name.

    Only the classes listed in ``CLASS_TO_LABEL`` are kept; every other object
    is dropped.

    Args:
        root: Folder containing the images and annotation files.
        transforms: Callable applied to the PIL image only, or ``None``.
    """

    # Class names found in the label JSON mapped to the integer ids used for training.
    CLASS_TO_LABEL = {"ground": 0, "cars": 1}

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        # Sorted so that a given index always refers to the same image.
        self.imgs = sorted(f for f in os.listdir(root) if f.endswith('.png'))

        # Derive the companion annotation file names from each image name.
        self.npy_files = []
        self.json_files = []
        for png_file in self.imgs:
            idx = png_file.replace('rgb_', '').replace('.png', '')
            self.npy_files.append(f"bounding_box_2d_tight_{idx}.npy")
            self.json_files.append(f"bounding_box_2d_tight_labels_{idx}.json")

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the sample at position ``idx``.

        ``target`` is a dict with:

        * ``"boxes"``: float32 tensor of shape ``(N, 4)`` holding
          ``(center_x, center_y, width, height)`` divided by the image size,
          which is the box format DETR's matcher and criterion read.
        * ``"labels"``: int64 tensor of shape ``(N,)``.
        * ``"image_id"``: tensor ``[idx]``.
        * ``"area"``: float32 tensor of shape ``(N,)``, normalized width * height.

        Boxes with non-positive width or height, and objects whose class is not
        in ``CLASS_TO_LABEL``, are skipped, so ``N`` may be 0.
        """
        # Load image
        img_path = os.path.join(self.root, self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        w, h = img.size  # PIL reports (width, height)

        # Load bounding boxes: rows of [object_id, x_min, y_min, x_max, y_max]
        bboxes = np.load(os.path.join(self.root, self.npy_files[idx]))

        # Load the object_id -> {"class": name} mapping
        with open(os.path.join(self.root, self.json_files[idx]), 'r') as f:
            label_dict = json.load(f)

        # Keep only well-formed boxes that belong to a known class.
        boxes = []
        labels = []
        for box in bboxes:
            obj_id = int(box[0])
            x_min, y_min, x_max, y_max = (float(box[i]) for i in range(1, 5))
            if x_max <= x_min or y_max <= y_min:
                continue  # degenerate box

            obj_class_name = label_dict.get(str(obj_id), {}).get("class", "unknown")
            class_label = self.CLASS_TO_LABEL.get(obj_class_name)
            if class_label is None:
                continue  # class we do not train on
            boxes.append([x_min, y_min, x_max, y_max])
            labels.append(class_label)

        # reshape(-1, 4) keeps the (0, 4) shape when no box survived the filter.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # Normalize corner coordinates to [0, 1] using the image dimensions.
        boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)

        # Convert corners to (cx, cy, w, h). DETR's HungarianMatcher and
        # SetCriterion convert targets back with box_cxcywh_to_xyxy, so leaving
        # them as corners would silently train against wrong geometry.
        x0, y0, x1, y1 = boxes.unbind(-1)
        widths = x1 - x0
        heights = y1 - y0
        areas = widths * heights
        boxes = torch.stack([(x0 + x1) / 2, (y0 + y1) / 2, widths, heights], dim=-1)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": areas
        }

        # Transforms only touch the image; the boxes are already resolution independent.
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        """Number of PNG images found in ``root``."""
        return len(self.imgs)

In [None]:
def get_transform(train):
    """Build the image preprocessing pipeline.

    The pipeline turns a PIL image into a float tensor and normalizes it with
    the ImageNet channel statistics (DETR was trained with these).

    Args:
        train: Currently unused; the same pipeline is returned either way.

    Returns:
        A ``T.Compose`` applying PILToTensor, ConvertImageDtype(float) and Normalize.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return T.Compose([
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

In [None]:

def collate_fn(batch):
    """Regroup a list of per-sample tuples into one tuple per field.

    A batch of ``(image, target)`` pairs becomes ``(images, targets)``, where
    each element is itself a tuple. No stacking is done here, so samples may
    carry a different number of boxes. An empty batch yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:

def create_model(num_classes):
    """Load the pretrained DETR ResNet-50 from torch.hub and replace its classifier.

    Args:
        num_classes: Number of object classes in the dataset.

    Returns:
        The hub model with ``class_embed`` swapped for a freshly initialized
        linear layer of ``num_classes + 1`` outputs.
    """
    detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
    embed_dim = detr.transformer.d_model
    # DETR expects the number of classes to include an extra "no-object" class
    n_outputs = num_classes + 1
    detr.class_embed = torch.nn.Linear(embed_dim, n_outputs)
    return detr

In [None]:
# Weight of each loss term in the total training loss.
weight_dict = dict(loss_ce=1, loss_bbox=5, loss_giou=2)
# Loss groups the criterion should compute.
losses = ['labels', 'boxes', 'cardinality']

# The matching costs reuse the loss weights (1 / 5 / 2).
matcher = HungarianMatcher(
    cost_class=weight_dict['loss_ce'],
    cost_bbox=weight_dict['loss_bbox'],
    cost_giou=weight_dict['loss_giou'],
)

criterion_kwargs = dict(
    num_classes=num_classes,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=0.1,
    losses=losses,
)
criterion = SetCriterion(**criterion_kwargs)
criterion.to(device)

In [None]:
# Build one dataset per data folder; both use the same preprocessing.
dataset1 = CarDataset(data_dir1, get_transform(train=True))
dataset2 = CarDataset(data_dir2, get_transform(train=True))

# Identical loader settings for both datasets: shuffled mini-batches of 4,
# collated with collate_fn.
data_loader1, data_loader2 = (
    torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
    for ds in (dataset1, dataset2)
)

In [None]:
# Instantiate DETR with a classifier sized for our classes, then place it on the training device
model = create_model(num_classes=num_classes)
model.to(device=device)

In [None]:

# Optimize only the trainable parameters, using AdamW (the optimizer recommended for DETR)
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.AdamW(params, lr=1e-4)


In [None]:

# Training loop
# The first two epochs train on dataset1, the remaining epochs on dataset2.
model.train()
for epoch in range(epochs):
    if epoch < 2:
        current_loader, loader_name = data_loader1, "dataset1"
    else:
        current_loader, loader_name = data_loader2, "dataset2"
    print(f"Epoch [{epoch+1}/{epochs}]: Using {loader_name}")

    len_dataloader = len(current_loader)
    for i, (imgs, targets) in enumerate(current_loader):
        # Stacking requires every image in the batch to have the same size.
        imgs = torch.stack(imgs).to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        outputs = model(imgs)

        # Use the separate criterion
        loss_dict = criterion(outputs, targets)
        # Weighted sum of the terms that have a weight; entries without one
        # are skipped. Stored under its own name so the module-level `losses`
        # list defined with the criterion is not overwritten by a tensor.
        total_loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)

        total_loss.backward()
        optimizer.step()

        # Report progress every 5 batches.
        if (i+1) % 5 == 0:
            print(f"  Batch [{i+1}/{len_dataloader}], Loss: {total_loss.item():.4f}")


In [None]:

# Save the trained model's state dictionary (weights only, not the pickled module).
# To restore: rebuild the architecture with create_model(num_classes) and call
# model.load_state_dict(torch.load(output_file)).
torch.save(model.state_dict(), output_file)
