In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch

# Prefer the GPU, but fall back to the CPU when CUDA is not available so that
# model loading and inference further down still run on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import torchvision
import os
from PIL import Image, ImageDraw
import PIL
import numpy as np
from transformers import DetrImageProcessor
from pathlib import Path

class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset that returns processor-encoded samples.

    Images and annotations are read from ``img_folder``; the annotation file
    is expected at ``<img_folder>/_annotations.coco.json``.

    Args:
        img_folder: Directory holding the images and the annotation file.
        processor: Callable invoked as
            ``processor(images=..., annotations=..., return_tensors="pt")``
            whose result exposes ``"pixel_values"`` and ``"labels"``.
        transforms: Optional augmentation callable invoked with the keyword
            arguments ``image``, ``bboxes`` and ``category_ids``; its result
            must expose the same three keys.
        train: Kept for interface compatibility. Both values resolve to the
            same annotation file name.
        debug: When true, every augmented sample is displayed with
            ``show_augmented_sample``.
    """

    def __init__(self, img_folder, processor, transforms=None, train=True, debug = False):
        # Both splits use the same file name, so `train` does not influence the path.
        ann_file = os.path.join(img_folder, "_annotations.coco.json")
        super().__init__(img_folder, ann_file)
        self.processor = processor
        self.debug = debug
        self.our_transforms = transforms

    def __getitem__(self, idx):
        """Return ``(pixel_values, target)`` for the sample at position ``idx``."""
        # read in PIL image and target in COCO format
        img, target = super().__getitem__(idx)

        if self.our_transforms and len(target) > 0:
            augmented = self.our_transforms(
                image=np.array(img),
                bboxes=[ann['bbox'] for ann in target],
                category_ids=[ann['category_id'] for ann in target],
            )
            img = PIL.Image.fromarray(augmented["image"])
            boxes = augmented["bboxes"]
            labels = augmented["category_ids"]

            if self.debug:
                show_augmented_sample(
                    augmented["image"],
                    augmented["bboxes"],
                    augmented["category_ids"],
                    title=f"Sample {idx}",
                )

            # Rebuild the annotations around the transformed boxes and labels.
            # NOTE(review): the remaining fields are matched by position, which
            # assumes the transform neither drops nor reorders boxes, and "area"
            # keeps its pre-augmentation value -- TODO confirm this is acceptable.
            target = [
                {
                    "id": source["id"],
                    "image_id": source["image_id"],
                    "category_id": int(label),
                    "bbox": box,
                    "area": source['area'],
                    "segmentation": source['segmentation'],
                    "iscrowd": source['iscrowd'],
                }
                for source, box, label in zip(target, boxes, labels)
            ]

        # preprocess image and target (converting target to DETR format, resizing + normalization of both image and target)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")
        # Drop only the leading batch axis; a bare squeeze() would also remove
        # any other axis that happens to have size 1.
        pixel_values = encoding["pixel_values"].squeeze(0)
        target = encoding["labels"][0] # remove batch dimension

        return pixel_values, target

In [None]:
from transformers import DetrImageProcessor

# Shared image processor; its `size` attribute is overridden to 640x640.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
processor.size = {"height": 640, "width": 640}

# Root folders of the three COCO-format datasets.
nora_path = "./Nora_dataset/coco_sliced"
open100_path = "./OPEN100_dataset/coco_sliced"
synth_path = "./Eramia_dataset/Eramia_COCO/test"


def _build_dataset(folder, debug=False):
    """Create an un-augmented CocoDetection over `folder` using the shared processor."""
    return CocoDetection(
        img_folder=folder,
        processor=processor,
        transforms=None,
        debug=debug,
    )


# One regular and one debug-flagged dataset per source.
nora_dataset = _build_dataset(nora_path)
nora_dataset_debug = _build_dataset(nora_path, debug=True)
open100_dataset = _build_dataset(open100_path)
open100_dataset_debug = _build_dataset(open100_path, debug=True)
synth_dataset = _build_dataset(synth_path)
synth_dataset_debug = _build_dataset(synth_path, debug=True)

Number of examples in each of the three datasets:

In [None]:
# Report the size of each dataset, one line per dataset.
for _caption, _dataset in (
    ("Number of Nora examples:", nora_dataset),
    ("Number of Open100 examples:", open100_dataset),
    ("Number of Synth examples:", synth_dataset),
):
    print(_caption, len(_dataset))

In [None]:
# Category records taken from the Nora dataset's COCO annotations.
cats = nora_dataset.coco.cats
# Map each integer category id to the category's name.
id2label = dict(
    (int(cat_id), cat_info['name']) for cat_id, cat_info in cats.items()
)

In [None]:
def show_augmented_sample(
    image, boxes, labels=None, pred_boxes=None, pred_labels=None, confs=None, label_map=None, figsize=(8, 8), title=None
):
    """
    Show an image with ground-truth and (optionally) predicted boxes overlaid.

    - image: NumPy array (H, W, 3) or PIL image
    - boxes: list of ground-truth boxes as [x, y, w, h], drawn in lime
    - labels: optional list with one label per entry of `boxes`
    - pred_boxes: optional list of predicted boxes as [x, y, w, h], drawn in red
    - pred_labels: optional list with one label per entry of `pred_boxes`
    - confs: optional list with one confidence per entry of `pred_boxes`;
      prediction captions are drawn only when both `pred_labels` and `confs`
      are given
    - label_map: optional dict mapping a label to the text to display
    - figsize: matplotlib figure size
    - title: optional figure title
    """
    # Height of the image in pixels, used to keep prediction captions inside
    # the picture. NumPy arrays expose it as shape[0]; PIL images report
    # `size` as (width, height).
    if hasattr(image, "shape"):
        image_height = image.shape[0]
    else:
        image_height = image.size[1]

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.imshow(image)

    for i, box in enumerate(boxes):
        x, y, w, h = box
        rect = patches.Rectangle(
            (x, y), w, h, linewidth=2, edgecolor="lime", facecolor="none"
        )
        ax.add_patch(rect)

        if labels is not None:
            label = labels[i]
            label_str = label_map[label] if label_map else str(label)
            ax.text(
                x,
                y,
                label_str,
                color="black",
                fontsize=12,
                bbox=dict(facecolor="lime", alpha=0.5),
            )

    if pred_boxes is not None:
        for i, box in enumerate(pred_boxes):
            x, y, w, h = box
            rect = patches.Rectangle(
                (x, y), w, h, linewidth=2, edgecolor="red", facecolor="none"
            )
            ax.add_patch(rect)

            if pred_labels is not None and confs is not None:
                label = pred_labels[i]
                conf = confs[i]
                label_str = label_map[label] if label_map else str(label)
                ax.text(
                    x,
                    # Caption sits at the bottom edge of the box, clamped so it
                    # stays within the image.
                    min(y + h, image_height - 3),
                    f"{label_str} {conf:.2f}",
                    color="black",
                    fontsize=12,
                    bbox=dict(facecolor="red", alpha=0.5),
                )

    ax.set_title(title or "Augmented image with boxes")
    ax.axis("off")
    plt.show()

In [None]:
import os

# Sample the id from open100_dataset, the dataset that the print_img call at
# the end of this cell displays; an id drawn from a different dataset is not
# guaranteed to exist in it.
image_ids = open100_dataset.coco.getImgIds()
image_id = image_ids[np.random.randint(0, len(image_ids))]

def print_img(img_id,dataset,annotations=None,predictions=None):
    """Display one dataset image with its ground-truth and predicted boxes.

    Args:
        img_id: COCO image id to look up in ``dataset.coco``.
        dataset: Dataset exposing ``coco`` (COCO API object) and ``root``
            (image directory).
        annotations: Optional list of dicts with ``"bbox"`` and
            ``"category_id"``; defaults to the annotations stored for
            ``img_id`` in ``dataset.coco.imgToAnns``.
        predictions: Optional list of dicts with ``"bbox"``,
            ``"category_id"`` and ``"score"``.

    Each annotation box and each prediction dict is also printed to stdout.
    Category ids are translated to names through the module-level
    ``id2label`` mapping.
    """
    image_info = dataset.coco.loadImgs(img_id)[0]
    image = Image.open(os.path.join(dataset.root, image_info['file_name']))

    if annotations is None:
        annotations = dataset.coco.imgToAnns[img_id]

    bboxes = []
    class_names = []
    for annotation in annotations:
        print(annotation["bbox"])
        bboxes.append(annotation["bbox"])
        class_names.append(id2label[annotation["category_id"]])

    pred_bboxes = []
    pred_class_names = []
    confs = []
    if predictions is None:
        predictions = []

    for prediction in predictions:
        print(prediction)
        pred_bboxes.append(prediction["bbox"])
        pred_class_names.append(id2label[prediction["category_id"]])
        confs.append(prediction["score"])

    show_augmented_sample(
        image,
        bboxes,
        class_names,
        pred_bboxes,
        pred_class_names,
        confs,
        title=f"Sample {img_id}",
    )

# Preview the sampled image together with its stored annotations.
print_img(img_id=image_id, dataset=open100_dataset)

In [None]:
from torch.utils.data import DataLoader

# Number of samples per batch for every dataloader built below.
batch_size = 60

def collate_fn(batch):
    """Merge a list of (pixel_values, target) samples into one batch dict.

    The images are padded together via `processor.pad`; the targets are
    passed through unchanged as a list under the 'labels' key.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])

    padded = processor.pad(images, return_tensors="pt")
    return {
        'pixel_values': padded['pixel_values'],
        'pixel_mask': padded['pixel_mask'],
        'labels': targets,
    }


# Settings shared by every dataloader in this notebook.
_loader_kwargs = dict(collate_fn=collate_fn, batch_size=batch_size, num_workers=8)

# Only the main Nora loader shuffles; all others iterate in dataset order.
nora_dataloader = DataLoader(nora_dataset, shuffle=True, **_loader_kwargs)
nora_dataloader_debug = DataLoader(nora_dataset, shuffle=False, **_loader_kwargs)

open100_dataloader = DataLoader(open100_dataset, **_loader_kwargs)
open100_dataloader_debug = DataLoader(open100_dataset, shuffle=False, **_loader_kwargs)

synth_dataloader = DataLoader(synth_dataset, **_loader_kwargs)
synth_dataloader_debug = DataLoader(synth_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# next(iter(nora_dataloader))

In [None]:
# next(iter(open100_dataloader))

In [None]:
# next(iter(synth_dataloader))

## Evaluate the model

In [None]:
import torch
from transformers import AutoModelForObjectDetection, AutoImageProcessor

# processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")

def load_model(ckpt_path):
    """Build a DETR detector and load fine-tuned weights from a checkpoint.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. It may be
            a plain state dict or a dict that wraps one under ``"state_dict"``;
            in the wrapped case a leading ``"model."`` is stripped from each key.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Warns:
        UserWarning: if some checkpoint keys do not match the model. The
            weights are loaded non-strictly, so without the warning a
            mismatching checkpoint would go unnoticed.
    """
    import warnings

    # Load the base model architecture (replace with your backbone/model type)
    model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50",
                                                        revision="no_timm",
                                                        num_labels=len(id2label),
                                                        ignore_mismatched_sizes=True)
    # Load checkpoint.
    # NOTE: torch.load may unpickle arbitrary objects; only open trusted files.
    model_checkpoint = torch.load(ckpt_path, map_location="cpu")

    # If checkpoint wrapped inside "state_dict", unwrap and clean keys if needed
    if "state_dict" in model_checkpoint:
        prefix = "model."
        # Strip the prefix only where it really is a prefix, never mid-key.
        model_checkpoint = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in model_checkpoint["state_dict"].items()
        }

    # Load weights into the model and surface any key mismatch.
    load_result = model.load_state_dict(model_checkpoint, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {ckpt_path} does not fully match the model: "
            f"{len(load_result.missing_keys)} missing key(s) "
            f"{load_result.missing_keys[:5]}, "
            f"{len(load_result.unexpected_keys)} unexpected key(s) "
            f"{load_result.unexpected_keys[:5]}"
        )

    model.to(device)
    model.eval()

    return model

In [None]:
def convert_to_xywh(boxes):
    """Convert a batch of (xmin, ymin, xmax, ymax) boxes to (x, y, width, height).

    `boxes` is a tensor whose dimension 1 holds the four coordinates; the
    result has the same layout.
    """
    left, top, right, bottom = boxes.unbind(dim=1)
    widths = right - left
    heights = bottom - top
    return torch.stack([left, top, widths, heights], dim=1)

def convert_to_xywh_cpu(box):
    """Convert one (xmin, ymin, xmax, ymax) box to an (x, y, width, height) tuple."""
    left, top, right, bottom = box
    width = right - left
    height = bottom - top
    return left, top, width, height

def prepare_for_coco_detection(predictions):
    """Flatten per-image predictions into a list of COCO-style result dicts.

    Args:
        predictions: Dict mapping an image id to a dict with the tensors
            ``"boxes"`` (rows of xmin, ymin, xmax, ymax), ``"scores"`` and
            ``"labels"``.

    Returns:
        A list with one dict per detection, holding ``"image_id"``,
        ``"category_id"``, ``"bbox"`` (as x, y, width, height) and ``"score"``.
    """
    coco_results = []
    for original_id, prediction in predictions.items():
        # Skip images without detections. Measuring the length of the
        # prediction dict itself would only count its keys, so look at the boxes.
        if not prediction or len(prediction["boxes"]) == 0:
            continue

        boxes = convert_to_xywh(prediction["boxes"]).tolist()
        scores = prediction["scores"].tolist()
        labels = prediction["labels"].tolist()

        coco_results.extend(
            {
                "image_id": original_id,
                "category_id": label,
                "bbox": box,
                "score": score,
            }
            for box, label, score in zip(boxes, labels, scores)
        )
    return coco_results

In [None]:
from coco_eval import CocoEvaluator
from tqdm.notebook import tqdm

def evaluate(dataloader, model, threshold=0.5):
    """Run `model` over `dataloader` and print COCO bbox metrics.

    Args:
        dataloader: Loader whose batches carry ``"pixel_values"``,
            ``"pixel_mask"`` and ``"labels"``, and whose ``dataset.coco``
            holds the ground truth.
        model: Detection model called with ``pixel_values`` and ``pixel_mask``.
        threshold: Minimum score for a detection to be handed to the
            evaluator. Defaults to 0.5, the value previously hard-coded here.

    Returns:
        The ``CocoEvaluator`` after accumulation and summary.
    """
    evaluator = CocoEvaluator(coco_gt=dataloader.dataset.coco, iou_types=["bbox"])
    for batch in tqdm(dataloader):
        # get the inputs
        pixel_values = batch["pixel_values"].to(device)
        pixel_mask = batch["pixel_mask"].to(device)
        targets = batch["labels"] # these are in DETR format, resized + normalized

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)

        # turn into a list of dictionaries (one item for each example in the batch)
        # Only the original sizes are needed on the device, not every label tensor.
        orig_target_sizes = torch.stack(
            [target["orig_size"] for target in targets], dim=0
        ).to(device)
        results = processor.post_process_object_detection(
            outputs, target_sizes=orig_target_sizes, threshold=threshold
        )

        # provide to metric
        # metric expects a list of dictionaries, each item
        # containing image_id, category_id, bbox and score keys
        predictions = {
            target['image_id'].item(): output
            for target, output in zip(targets, results)
        }
        predictions = prepare_for_coco_detection(predictions)
        evaluator.update(predictions)

    evaluator.synchronize_between_processes()
    evaluator.accumulate()
    evaluator.summarize()
    return evaluator

In [None]:
# Load the three checkpoints in one pass; order matches the names on the left.
model_coco, model_imagenet, model_ipd = (
    load_model(_ckpt_path)
    for _ckpt_path in (
        "runs/COCO_pretrain/best_coco.ckpt",
        "runs/ImageNet/best_imagenet.ckpt",
        "runs/classify_history/best_classify.ckpt",
    )
)

In [None]:
# for dataloader in [synth_dataloader,nora_dataloader,open100_dataloader]:
#     print()
#     print(dataloader.dataset.root)
#     for model in [model_coco, model_imagenet, model_ipd]:
#         evaluate(dataloader,model)

## Inference (+ visualization)

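Let's visualize DETR's predictions next to the ground-truth boxes. `plot_inference` shows the first image of each of `n_images` consecutive batches, optionally skipping `jumps` batches first.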

In [None]:
import random

def plot_inference(model, dataloader, n_images=1, jumps=0):
    """Plot predictions against ground truth for the first image of several batches.

    Args:
        model: Detection model called with ``pixel_values`` and ``pixel_mask``.
        dataloader: Loader whose batches carry ``"pixel_values"`` and
            ``"labels"`` (and normally ``"pixel_mask"``), and whose
            ``dataset.coco`` holds the ground truth.
        n_images: Number of batches to visualise; one image is shown per batch.
        jumps: Number of batches to skip before plotting starts.

    Stops early, with a printed notice, if the dataloader runs out of batches.
    """
    iter_dl = iter(dataloader)
    for _ in range(jumps):
        next(iter_dl, None)

    for _ in range(n_images):
        batch = next(iter_dl, None)
        if batch is None:
            print("Dataloader exhausted; nothing more to plot.")
            break

        # We can use the image_id in target to know which image it is
        targets = batch["labels"]
        pixel_values = batch["pixel_values"].to(device)
        # Pass the padding mask when the batch provides one.
        pixel_mask = batch.get("pixel_mask")
        if pixel_mask is not None:
            pixel_mask = pixel_mask.to(device)

        # Rescale predictions to the original image sizes: the boxes are drawn
        # on the image file loaded from disk, not on the processed tensor.
        orig_target_sizes = torch.stack(
            [target["orig_size"] for target in targets], dim=0
        ).to(device)

        with torch.no_grad():
            # forward pass to get class logits and bounding boxes
            outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)

        # Keep the result for the first image of the batch only.
        postprocessed_outputs = processor.post_process_object_detection(
            outputs, target_sizes=orig_target_sizes, threshold=0.3
        )[0]
        print("Processed:", postprocessed_outputs)

        # load image based on ID
        image_id = targets[0]["image_id"].item()

        # print GT
        print("GT")
        annotations = dataloader.dataset.coco.imgToAnns[image_id]

        # print Pred
        predictions = [
            {
                "bbox": convert_to_xywh_cpu(bbox),
                "category_id": label,
                "score": score,
            }
            for label, bbox, score in zip(
                postprocessed_outputs["labels"].cpu().tolist(),
                postprocessed_outputs["boxes"].cpu().tolist(),
                postprocessed_outputs["scores"].cpu().tolist(),
            )
        ]

        print("Pred")
        print_img(image_id, dataloader.dataset, predictions=predictions, annotations=annotations)

In [None]:
# Synthetic data: skip 10 batches, then plot one image from each of the next 10.
plot_inference(model_ipd, synth_dataloader, n_images=10, jumps=10)

In [None]:
# Open100 data: plot one image from each of the first 10 batches.
plot_inference(model_ipd, open100_dataloader, n_images=10)

In [None]:
# Open100 data: skip 20 batches, then plot one image from each of the next 10.
plot_inference(model_ipd, open100_dataloader, n_images=10, jumps=20)