In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def show_augmented_sample(
    image, boxes, labels=None, label_map=None, figsize=(8, 8), title=None
):
    """
    Plot ``image`` and outline every box in lime, optionally tagging each box.

    - image: anything ``Axes.imshow`` accepts (e.g. NumPy array (H, W, 3) or PIL image)
    - boxes: iterable of [x, y, w, h] boxes in image pixel coordinates
    - labels: optional sequence aligned with ``boxes``; entry i tags box i
    - label_map: optional dict translating a label into display text; when it is
      missing or empty the label is shown via ``str``
    - figsize: matplotlib figure size
    - title: figure title; a generic caption is used when it is empty/None
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(image)

    for idx, (x, y, w, h) in enumerate(boxes):
        ax.add_patch(
            patches.Rectangle(
                (x, y), w, h, linewidth=2, edgecolor="lime", facecolor="none"
            )
        )
        if labels is None:
            continue
        raw = labels[idx]
        text = label_map[raw] if label_map else str(raw)
        # Tag sits at the box's top-left corner.
        ax.text(
            x, y, text,
            color="black", fontsize=12, bbox=dict(facecolor="lime", alpha=0.5),
        )

    ax.set_title(title or "Augmented image with boxes")
    ax.axis("off")
    plt.show()

In [None]:
import albumentations as A

# Training-time augmentation. Boxes are given in COCO [x, y, w, h] format,
# clipped with ``clip=True``, and each box carries its "category_ids" entry
# through the pipeline as a label field.
train_transforms = A.Compose(
    transforms=[
        # geometric
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ElasticTransform(p=0.3),
        # photometric
        A.ColorJitter(p=0.4),
        A.GaussianBlur(p=0.1),
        A.RandomBrightnessContrast(p=0.3),
        # scale / rotation
        A.RandomScale(scale_limit=0.2, p=0.2),
        A.Rotate(limit=15, p=0.2),
    ],
    bbox_params=A.BboxParams(format="coco", label_fields=["category_ids"], clip=True),
)

In [None]:
import torchvision
import os
from PIL import Image
import PIL
import numpy as np
from transformers import DetrImageProcessor
from pathlib import Path

class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO detection dataset with optional albumentations augmentation,
    encoded for DETR by ``processor``.

    Each item is ``(pixel_values, target)``: the processor's pixel values with
    the batch dimension removed, and the processor's label dict for the image.

    Args:
        img_folder: split folder holding the images and ``_annotations.coco.json``.
        processor: image processor called as
            ``processor(images=..., annotations=..., return_tensors="pt")``.
        transforms: optional albumentations pipeline taking ``image``,
            ``bboxes`` (COCO format) and a ``category_ids`` label field.
        train: accepted for backward compatibility; every split uses the same
            annotation file name, so it does not change which file is read.
        debug: if True, plot each augmented sample with ``show_augmented_sample``.
    """

    def __init__(self, img_folder, processor, transforms=None, train=True, debug=False):
        ann_file = os.path.join(img_folder, "_annotations.coco.json")
        super().__init__(img_folder, ann_file)
        self.processor = processor
        self.debug = debug
        self.our_transforms = transforms

    def _augment(self, img, target, idx):
        """Augment a PIL image and its COCO annotations; return ``(image, annotations)``.

        The pipeline may drop boxes (e.g. ones clipped away), so instead of
        real category ids each annotation's *position* in ``target`` is sent
        through the ``category_ids`` label field. The surviving positions then
        identify exactly which source annotation each output box belongs to.
        """
        augmented = self.our_transforms(
            image=np.array(img),
            bboxes=[ann["bbox"] for ann in target],
            category_ids=list(range(len(target))),
        )
        kept = [int(j) for j in augmented["category_ids"]]
        boxes = [[float(v) for v in box] for box in augmented["bboxes"]]
        category_ids = [int(target[j]["category_id"]) for j in kept]

        if self.debug:
            show_augmented_sample(
                augmented["image"], boxes, category_ids, title=f"Sample {idx}"
            )

        new_target = [
            {
                "id": target[j]["id"],
                "image_id": target[j]["image_id"],
                "category_id": cat_id,
                "bbox": box,
                # Scale/rotate/elastic change box size, so the area is taken
                # from the augmented box (w * h) rather than the stale original.
                "area": box[2] * box[3],
                # NOTE: segmentation is not augmented; it is carried over unchanged.
                "segmentation": target[j]["segmentation"],
                "iscrowd": target[j]["iscrowd"],
            }
            for j, box, cat_id in zip(kept, boxes, category_ids)
        ]
        return Image.fromarray(augmented["image"]), new_target

    def __getitem__(self, idx):
        """Load sample ``idx``, optionally augment it, and encode it for DETR."""
        # PIL image + list of COCO annotation dicts for this image.
        img, target = super().__getitem__(idx)

        # Images without annotations are passed through un-augmented.
        if self.our_transforms and len(target) > 0:
            img, target = self._augment(img, target, idx)

        # Preprocess image and target (convert target to DETR format, resize +
        # normalize both).
        encoding = self.processor(
            images=img,
            annotations={"image_id": self.ids[idx], "annotations": target},
            return_tensors="pt",
        )
        pixel_values = encoding["pixel_values"].squeeze(0)  # remove batch dimension only
        target = encoding["labels"][0]  # remove batch dimension
        return pixel_values, target

In [None]:
from transformers import DetrImageProcessor

# Image processor from the pretrained DETR checkpoint, with a fixed 640x640 target size.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
processor.size = {"height": 640, "width": 640}

# Split folders; each holds its images plus "_annotations.coco.json".
train_path = "./Eramia_dataset/Eramia_COCO/train/"
val_path = "./Eramia_dataset/Eramia_COCO/val"
test_path = "./Eramia_dataset/Eramia_COCO/test"

# Augmented training set, plus a twin that plots every augmented sample.
train_dataset = CocoDetection(train_path, processor, transforms=train_transforms)
train_dataset_debug = CocoDetection(
    train_path, processor, transforms=train_transforms, debug=True
)

# Evaluation splits are never augmented.
val_dataset = CocoDetection(val_path, processor, transforms=None)
test_dataset = CocoDetection(test_path, processor, transforms=None)

In [None]:
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(val_dataset)}")

In [None]:
# Category id -> class name, read from the training split's COCO categories.
cats = train_dataset.coco.cats
id2label = dict((int(cat_id), info["name"]) for cat_id, info in cats.items())

In [None]:
import os
from PIL import Image, ImageDraw

# based on https://github.com/woctezuma/finetune-detr/blob/master/finetune_detr.ipynb
# Pick one training image at random to visualise (unseeded: changes per run).
image_ids = train_dataset.coco.getImgIds()
image_id = image_ids[np.random.randint(len(image_ids))]

def print_img(img_id, dataset, data_path, annotations=None):
    """Display one image with its COCO boxes tagged by class name.

    Args:
        img_id: COCO image id in ``dataset.coco``.
        dataset: object exposing a pycocotools ``coco`` attribute
            (e.g. a ``CocoDetection``).
        data_path: folder the image's ``file_name`` is relative to.
        annotations: optional list of COCO annotation dicts; defaults to the
            ground-truth annotations stored for ``img_id``.

    Uses the notebook-level ``id2label`` mapping for class names.
    """
    file_name = dataset.coco.loadImgs(img_id)[0]["file_name"]
    # Close the file handle once the pixels are loaded; converting to RGB makes
    # grayscale/RGBA files render as ordinary colour images.
    with Image.open(os.path.join(data_path, file_name)) as pil_image:
        image = pil_image.convert("RGB")

    if annotations is None:
        annotations = dataset.coco.imgToAnns[img_id]

    bboxes = [annotation["bbox"] for annotation in annotations]
    class_names = [id2label[annotation["category_id"]] for annotation in annotations]

    show_augmented_sample(
        image,
        bboxes,
        class_names,
        title=f"Sample {img_id}",
    )

print_img(img_id=image_id, dataset=train_dataset, data_path=train_path)

In [None]:
from torch.utils.data import DataLoader

batch_size = 32

def collate_fn(batch):
    """Combine ``(pixel_values, target)`` samples into one DETR batch dict.

    Images are padded with the notebook-level ``processor``, which also returns
    the ``pixel_mask``; targets are kept as a plain list of per-image dicts.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])
    padded = processor.pad(images, return_tensors="pt")
    return {
        "pixel_values": padded["pixel_values"],
        "pixel_mask": padded["pixel_mask"],
        "labels": targets,
    }


# Training: shuffled, loaded by 8 worker processes.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=8,
)
# Debug loader keeps the default num_workers (main process) — presumably so the
# per-sample plots render in the notebook.
train_dataloader_debug = DataLoader(
    train_dataset_debug,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
# Evaluation loaders: fixed order, no shuffling.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn)

In [None]:
# Pull one batch through the debug loader, which plots every augmented sample.
# Keep the batch in a variable instead of echoing the full tensor repr; show its shape.
debug_batch = next(iter(train_dataloader_debug))
debug_batch["pixel_values"].shape

In [None]:
import torch

def convert_to_xywh(boxes):
    """Convert an (N, 4) tensor of [xmin, ymin, xmax, ymax] boxes to [x, y, w, h]."""
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)

def prepare_for_coco_detection(predictions):
    """Flatten per-image detections into a list of COCO result dicts.

    Args:
        predictions: mapping ``image_id -> {"boxes", "scores", "labels"}`` with
            tensor values; boxes are expected as [xmin, ymin, xmax, ymax]
            (assumed to come from the processor's post-processing — confirm
            against the caller).

    Returns:
        list of ``{"image_id", "category_id", "bbox", "score"}`` dicts with
        ``bbox`` in COCO [x, y, w, h] format. Images with no result dict or
        zero boxes contribute nothing.
    """
    coco_results = []
    for original_id, prediction in predictions.items():
        # ``len(prediction)`` would count dict keys, not detections, so check
        # the number of predicted boxes instead.
        if not prediction or len(prediction["boxes"]) == 0:
            continue

        boxes = convert_to_xywh(prediction["boxes"]).tolist()
        scores = prediction["scores"].tolist()
        labels = prediction["labels"].tolist()

        coco_results.extend(
            {
                "image_id": original_id,
                "category_id": label,
                "bbox": box,
                "score": score,
            }
            for box, score, label in zip(boxes, scores, labels)
        )
    return coco_results

In [None]:
from coco_eval import CocoEvaluator
from tqdm.notebook import tqdm

import pytorch_lightning as pl
from transformers import DetrForObjectDetection, ResNetBackbone
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import json 
import tempfile

class Detr(pl.LightningModule):
    """Lightning wrapper around Hugging Face ``DetrForObjectDetection``.

    Args:
        lr: learning rate for every parameter whose name lacks "backbone".
        lr_backbone: learning rate for parameters whose name contains "backbone".
        weight_decay: AdamW weight decay.
        pretrain: hub id of a DETR checkpoint to fine-tune (Ex:
            "facebook/detr-resnet-50"); ``None`` builds a randomly initialised
            DETR with the "facebook/detr-resnet-50" architecture.
        backbone: ``None`` keeps the model's own backbone; "random" swaps in a
            randomly initialised DETR backbone; any other value is a hub id
            loaded with ``ResNetBackbone`` (Ex: "microsoft/resnet-50").

    Relies on the notebook globals ``id2label``, ``processor``, ``val_path``,
    ``train_dataloader`` and ``val_dataloader``.
    """

    def __init__(self, lr, lr_backbone, weight_decay, pretrain=None, backbone=None):
        super().__init__()
        self.save_hyperparameters()
        if pretrain is None:
            self.pretrain = "random"
            self.model = DetrForObjectDetection(self._scratch_config())
        else:
            self.pretrain = pretrain
            self.model = DetrForObjectDetection.from_pretrained(
                pretrain,  # Ex: "facebook/detr-resnet-50"
                revision="no_timm",
                num_labels=len(id2label),
                ignore_mismatched_sizes=True,  # class head is resized to our label set
            )

        # Keep the backbone *spec* (None / "random" / hub id); the notebook also
        # uses it to name the TensorBoard run.
        self.backbone = backbone
        if backbone is not None:
            self._replace_backbone(backbone)

        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay

        # COCO-format detections collected over one validation epoch.
        self.val_outputs = []

    @staticmethod
    def _scratch_config():
        """DETR config ("facebook/detr-resnet-50", no_timm revision, our label
        count) with ``use_pretrained_backbone`` disabled.

        Only the config is fetched, not the model weights.
        """
        config = DetrForObjectDetection.config_class.from_pretrained(
            "facebook/detr-resnet-50", revision="no_timm", num_labels=len(id2label)
        )
        config.use_pretrained_backbone = False
        return config

    def _replace_backbone(self, backbone):
        """Replace the backbone module that DETR's forward pass uses.

        NOTE(review): in transformers' DETR implementation that module is
        ``model.model.backbone`` (CNN at ``conv_encoder.model``); assigning to
        ``self.model.backbone`` would only register an unused extra module.
        Verify against the installed transformers version.
        """
        if backbone == "random":
            fresh = DetrForObjectDetection(self._scratch_config())
            self.model.model.backbone = fresh.model.backbone
            return

        conv_encoder = self.model.model.backbone.conv_encoder
        new_cnn = ResNetBackbone.from_pretrained(backbone)  # Ex: "microsoft/resnet-50"
        # DETR's input projection was sized from the original backbone's last
        # feature map, so the replacement must output the same channel count.
        expected = conv_encoder.intermediate_channel_sizes[-1]
        if new_cnn.channels[-1] != expected:
            raise ValueError(
                f"Backbone {backbone!r} outputs {new_cnn.channels[-1]} channels; "
                f"DETR expects {expected}."
            )
        conv_encoder.model = new_cnn

    def forward(self, pixel_values, pixel_mask):
        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

        return outputs

    def common_step(self, batch, batch_idx):
        """Run DETR with labels on a collated batch; return ``(loss, loss_dict, outputs)``."""
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

        return outputs.loss, outputs.loss_dict, outputs

    def training_step(self, batch, batch_idx):
        loss, loss_dict, _ = self.common_step(batch, batch_idx)
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        for k, v in loss_dict.items():
            self.log("train_" + k, v.item())

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses and collect COCO-format detections for the epoch."""
        loss, loss_dict, outputs = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss)
        for k, v in loss_dict.items():
            self.log("validation_" + k, v.item())

        # Rescale predicted boxes to each image's original size; threshold=0
        # keeps every query so COCOeval sees the full score ranking.
        orig_target_sizes = torch.stack(
            [target["orig_size"].to(self.device) for target in batch["labels"]], dim=0
        )
        results = processor.post_process_object_detection(
            outputs, target_sizes=orig_target_sizes, threshold=0
        )

        for target, pred in zip(batch["labels"], results):
            image_id = int(target["image_id"])
            boxes = pred["boxes"].cpu().numpy()
            scores = pred["scores"].cpu().numpy()
            pred_labels = pred["labels"].cpu().numpy()

            for box, score, label in zip(boxes, scores, pred_labels):
                x1, y1, x2, y2 = box.tolist()
                self.val_outputs.append(
                    {
                        "image_id": image_id,
                        "category_id": int(label),
                        "bbox": [x1, y1, x2 - x1, y2 - y1],  # COCO [x, y, w, h]
                        "score": float(score),
                    }
                )

        return loss

    def on_validation_epoch_end(self):
        """Evaluate the epoch's detections with COCOeval and log mAP metrics."""
        predictions, self.val_outputs = self.val_outputs, []
        if not predictions:
            # Nothing to score; pycocotools' loadRes inspects the first result
            # and presumably fails on an empty list.
            return

        coco_gt = COCO(f"{val_path}/_annotations.coco.json")
        # loadRes also accepts an in-memory list of result dicts, so no temporary
        # JSON file is written (or left behind) every epoch.
        coco_dt = coco_gt.loadRes(predictions)

        coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()

        # Log primary metrics
        self.log("val/cocomAP", coco_eval.stats[0], prog_bar=True)
        self.log("val/AP50", coco_eval.stats[1])
        self.log("val/AP75", coco_eval.stats[2])

    def configure_optimizers(self):
        """AdamW with a separate learning rate for backbone parameters."""
        param_dicts = [
            {"params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad]},
            {
                "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=self.lr, weight_decay=self.weight_decay)

        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

Here we define the model and verify its outputs.

In [None]:
# Fine-tune the pretrained DETR checkpoint and keep its own backbone; backbone
# parameters get a 10x smaller learning rate than the rest of the model.
model = Detr(
    pretrain="facebook/detr-resnet-50",
    backbone=None,
    lr=1e-4,
    lr_backbone=1e-5,
    weight_decay=1e-4,
)

The logits have shape `(batch_size, num_queries, number of classes + 1)`. The model internally adds an extra "no object" class, which explains the one additional output in the class dimension.

Next, let's train! We train for a maximum of 300 epochs (at least 10) and use gradient clipping. You can refresh TensorBoard (logs are written under `runs/`) to check the various losses.

In [None]:
from pytorch_lightning import loggers as pl_loggers

# One TensorBoard run name per (pretrain, backbone) pair; "/" in hub ids is
# replaced so the name stays a single path component.
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir='runs/',
    name="_".join([str(model.pretrain), str(model.backbone)]).replace("/", "_"),
)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

model_checkpoint_path = Path(tb_logger.log_dir) / "models"

# Keep the 3 checkpoints with the highest validation COCO mAP.
checkpoint_best = ModelCheckpoint(
    dirpath=model_checkpoint_path,
    filename='best_e-{epoch}_s-{step}_ap-{val/cocomAP:.2f}',
    monitor="val/cocomAP",
    mode="max",
    save_top_k=3,
    save_weights_only=False,
    auto_insert_metric_name=False,
)
# Keep a checkpoint every 10 epochs (save_top_k=-1 keeps all of them) plus last.ckpt.
checkpoint_epoch = ModelCheckpoint(
    dirpath=model_checkpoint_path,
    filename='e-{epoch:03d}',
    every_n_epochs=10,
    save_top_k=-1,
    save_last=True,
    save_weights_only=False,
    auto_insert_metric_name=False,
)

In [None]:
trainer = Trainer(
    logger=tb_logger,
    callbacks=[checkpoint_best, checkpoint_epoch],
    enable_checkpointing=True,
    default_root_dir=model_checkpoint_path,
    min_epochs=10,
    max_epochs=300,
    check_val_every_n_epoch=1,
    gradient_clip_val=0.1,
    # For a quick smoke test add: limit_train_batches=0.1, limit_val_batches=0.1
)
# To resume instead: trainer.fit(model, ckpt_path="runs/facebook_detr-resnet-50/version_14/models/last.ckpt")
trainer.fit(model)