<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/vits/detection/DETR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, lr_scheduler
import torchvision.transforms as T
from torchvision import datasets, ops
from torchvision.models.feature_extraction import create_feature_extractor
from einops import rearrange

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import linear_sum_assignment

In [None]:
####################
# DATA PREPARATION #
####################

In [None]:
!git clone https://github.com/lizhogn/tiny_coco_dataset.git

In [None]:
# Human-readable label for every class ID, indexed directly by the ID:
# plot_im_with_boxes looks names up as CLASSES[class_id], and
# preprocess_target passes the annotation "category_id" through unchanged.
# The list has 92 entries (indices 0-91).
# 'N/A' entries are placeholders -- presumably for category IDs that the
# COCO annotations never use (TODO confirm against the annotation file).
# NOTE(review): the trailing 'empty' entry (index 91) looks like the
# "no object" label for the detector -- confirm where it is consumed.
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush', 'empty'
]

# RGB palette used to outline boxes: six base colors repeated 100 times,
# so COLORS[i] is a valid lookup for the first 600 box indices.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
] * 100


# Per-channel statistics used when normalizing images for the model.
_NORM_MEAN = (.485, .456, .406)
_NORM_STD = (.229, .224, .225)

# Transform that undoes that normalization. Normalize computes
# (x - mean) / std, so passing mean = -m/s and std = 1/s yields x * s + m,
# i.e. the original pixel values (useful before displaying an image).
revert_normalization = T.Normalize(
    mean=[-m / s for m, s in zip(_NORM_MEAN, _NORM_STD)],
    std=[1 / s for s in _NORM_STD],
)


def plot_im_with_boxes(im, boxes, probs=None, ax=None):
    """
    Plot an image and render bounding boxes with optional labels.

    Parameters
    ----------
    im : array-like or Tensor
        The image to display in HWC format. It is only drawn by this
        function when ``ax`` is None; with an explicit ``ax`` the caller is
        expected to have shown the image already.
    boxes : Tensor
        Bounding boxes in xyxy format with shape (N, 4), in the pixel
        coordinates of the displayed image.
    probs : Tensor, optional
        If 1-D: contains class IDs for each box.
        If 2-D: contains class probabilities per box (N, C); the arg-max
        class and its probability are shown.
        If None: no labels are drawn.
    ax : matplotlib.axes.Axes, optional
        Existing axes on which to draw. If omitted, the image is shown with
        ``plt.imshow`` and the current axes are used.

    Notes
    -----
    Uses COLORS to differentiate boxes and CLASSES to map class IDs
    to human-readable labels. The palette is cycled, so the number of
    boxes is not limited by ``len(COLORS)``.
    """

    if ax is None:
        plt.imshow(im)
        ax = plt.gca()

    for i, (xmin, ymin, xmax, ymax) in enumerate(boxes.tolist()):
        # Wrap around the palette instead of failing on large box counts.
        color = COLORS[i % len(COLORS)]

        patch = plt.Rectangle(
            (xmin, ymin), xmax - xmin, ymax - ymin,
            fill=False, color=color, linewidth=2)
        ax.add_patch(patch)

        # Without class information there is nothing to label.
        if probs is None:
            continue

        if probs.ndim == 1:
            # int() keeps the lookup valid even if the IDs arrive as floats.
            cl = int(probs[i].item())
            text = f'{CLASSES[cl]}'
        else:
            cl = probs[i].argmax().item()
            text = f'{CLASSES[cl]}: {probs[i,cl]:0.2f}'

        ax.text(xmin, ymin, text, fontsize=7,
                bbox=dict(facecolor='yellow', alpha=0.5))


def preprocess_target(anno, im_w, im_h):
    """
    Turn the raw COCO annotations of one image into training targets.

    Parameters
    ----------
    anno : list of dict
        COCO annotation records for a single image. Each record must
        provide "bbox" (xywh, in pixels) and "category_id"; "iscrowd" is
        optional.
    im_w : int
        Width of the original image in pixels.
    im_h : int
        Height of the original image in pixels.

    Returns
    -------
    classes : Tensor
        int64 class IDs of the boxes that survived filtering.
    boxes : Tensor
        float32 boxes of shape (K, 4) in cxcywh format, normalized by the
        image size to [0, 1].

    Notes
    -----
    Processing steps, in order:

    1. Drop records whose "iscrowd" is present and non-zero.
    2. Convert xywh to xyxy and clip the corners to the image.
    3. Drop boxes that have no positive width or height after clipping.
    4. Divide x by ``im_w`` and y by ``im_h``, then clip to [0, 1].
    5. Convert xyxy to cxcywh.
    """

    # A record without the "iscrowd" key is treated as a regular object.
    objects = [rec for rec in anno if rec.get('iscrowd', 0) == 0]

    # reshape(-1, 4) keeps a well-formed (0, 4) tensor when nothing is left.
    xyxy = torch.as_tensor(
        [rec["bbox"] for rec in objects], dtype=torch.float32
    ).reshape(-1, 4)

    # xywh -> xyxy: add the top-left corner to (w, h).
    xyxy[:, 2:] += xyxy[:, :2]
    # The strided slices are views, so the clamps modify xyxy in place.
    xyxy[:, 0::2].clamp_(min=0, max=im_w)
    xyxy[:, 1::2].clamp_(min=0, max=im_h)

    # Only boxes with strictly positive extent in both directions are kept.
    has_width = xyxy[:, 2] > xyxy[:, 0]
    has_height = xyxy[:, 3] > xyxy[:, 1]
    valid = has_height & has_width
    xyxy = xyxy[valid]

    labels = torch.tensor(
        [rec["category_id"] for rec in objects], dtype=torch.int64)
    labels = labels[valid]

    # Normalize pixel coordinates to [0, 1].
    xyxy[:, 0::2] /= im_w
    xyxy[:, 1::2] /= im_h
    xyxy.clamp_(min=0, max=1)

    return labels, ops.box_convert(xyxy, in_fmt='xyxy', out_fmt='cxcywh')


class MyCocoDetection(datasets.CocoDetection):
    """
    A thin wrapper around torchvision.datasets.CocoDetection that applies
    preprocessing transforms to images and annotations.

    Adds:
    - Image transforms: tensor conversion, normalization, fixed-size resize.
    - Annotation preprocessing: COCO xywh → normalized cxcywh.
    """

    def __init__(self, *args, edge=480, **kwargs):
        """
        Initialize the dataset wrapper.

        Parameters
        ----------
        *args, **kwargs : passed to CocoDetection
            Standard initialization parameters such as the image root
            directory and annotation file path.
        edge : int, optional (keyword-only)
            Side length in pixels of the square image produced by the
            resize step. Defaults to 480.

        Notes
        -----
        ``__getitem__`` reads ``img.size`` as (width, height), so the base
        class is assumed to hand back a PIL image -- NOTE(review): passing
        image transforms through ``kwargs`` would likely break that; confirm
        before doing so.
        """
        super().__init__(*args, **kwargs)
        self.edge = edge

        # Image pipeline: to tensor, normalize per channel, resize to a
        # fixed (edge, edge) square regardless of the original aspect ratio.
        self.T = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[.485, .456, .406],
                        std=[.229, .224, .225]),
            T.Resize((self.edge, self.edge), antialias=True)
        ])

        # Annotation pipeline; called as T_target(annotations, width, height).
        self.T_target = preprocess_target

    def __getitem__(self, idx):
        """
        Fetch an item and apply both image and annotation preprocessing.

        Parameters
        ----------
        idx : int
            Dataset index.

        Returns
        -------
        image_tensor : Tensor
            Preprocessed image of shape (3, edge, edge).
        (classes, boxes) : tuple
            Processed annotations where:
            - classes: Tensor of class IDs
            - boxes: Tensor of normalized cxcywh bounding boxes
        """
        img, target = super().__getitem__(idx)
        # The original size is needed to normalize the boxes; it must be
        # read before the image is resized.
        w, h = img.size

        input_ = self.T(img)
        classes, boxes = self.T_target(target, w, h)

        return input_, (classes, boxes)


def collate_fn(inputs):
    """
    Batch samples for a DataLoader: stack the images, keep targets ragged.

    Images share one size and can be stacked into a single tensor, whereas
    the number of boxes differs per image, so class and box tensors are
    returned as tuples with one entry per sample.

    Parameters
    ----------
    inputs : list
        Samples of the form (image_tensor, (classes, boxes)).

    Returns
    -------
    batched_images : Tensor
        The images stacked along a new leading dimension, (B, 3, H, W).
    (classes, boxes) : tuple
        Two tuples of length B holding the per-image class tensors and
        box tensors, in the same order as ``inputs``.
    """
    images = []
    per_image_classes = []
    per_image_boxes = []

    for sample in inputs:
        target = sample[1]
        images.append(sample[0])
        per_image_classes.append(target[0])
        per_image_boxes.append(target[1])

    batch = torch.stack(images)
    return batch, (tuple(per_image_classes), tuple(per_image_boxes))


In [None]:
# Image directory and annotation file inside the cloned tiny_coco_dataset
# repository.
train_image_dir = 'tiny_coco_dataset/tiny_coco/train2017/'
train_annotation_file = (
    'tiny_coco_dataset/tiny_coco/annotations/instances_train2017.json'
)

train_ds = MyCocoDetection(train_image_dir, train_annotation_file)

# Batches of four samples served in dataset order (no shuffling).
# collate_fn stacks the images and keeps the per-image class/box tensors
# separate, since their lengths differ from image to image.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)

print(f'\nNumber of training samples: {len(train_ds)}')
# Output recorded in the notebook: Number of training samples: 50

In [None]:
# Sanity check: draw the first batch with its ground-truth boxes.
input_, target = next(iter(train_loader))
target_classes, target_boxes = target
fig = plt.figure(figsize=(10, 10), constrained_layout=True)

# Undo the normalization once for the whole batch (rather than once per
# image) and move channels last so matplotlib can display the result.
images = revert_normalization(input_).permute(0, 2, 3, 1).cpu().clip(0, 1)

# Boxes are normalized to [0, 1]; scale them by the actual tensor size
# instead of a hard-coded edge length.
im_h, im_w = input_.shape[-2:]
box_scale = torch.tensor([im_w, im_h, im_w, im_h], dtype=torch.float32)

# The figure is a 2x2 grid, so at most four samples are shown; a smaller
# batch simply leaves the remaining cells empty.
for ix in range(min(4, len(input_))):
    t_cl = target_classes[ix]
    t_bbox = ops.box_convert(
        target_boxes[ix] * box_scale, in_fmt='cxcywh', out_fmt='xyxy')

    im = images[ix]

    ax = fig.add_subplot(2, 2, ix + 1)
    ax.imshow(im)
    plot_im_with_boxes(im, t_bbox, t_cl, ax=ax)
    ax.set_axis_off()