In [None]:
from datasets import load_dataset

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


import albumentations as A

from transformers import DetrFeatureExtractor, AutoModelForObjectDetection, TrainingArguments,Trainer

## Model: DETR with a ResNet-50 backbone

In [None]:
# Class vocabulary for the detection head: class name -> integer id.
label2id = dict(logo=0, text=1)

# Reverse lookup (integer id -> class name), built from the mapping above.
id2label = {idx: name for name, idx in label2id.items()}

In [None]:

# A single pretrained DETR (ResNet-50 backbone) checkpoint id, shared by the
# feature extractor loaded here and the detection model loaded in the next cell.
# NOTE(review): newer transformers releases deprecate DetrFeatureExtractor in
# favour of DetrImageProcessor — confirm against the pinned version.
feature_extractor_checkpoint = "facebook/detr-resnet-50"
feature_extractor = DetrFeatureExtractor.from_pretrained(
    feature_extractor_checkpoint
)

In [None]:

# Load DETR with a classification head sized for our own label set.
# The checkpoint's head presumably has a different number of classes, so
# ignore_mismatched_sizes=True lets those layers be re-initialised instead of
# raising a shape-mismatch error.
model = AutoModelForObjectDetection.from_pretrained(
    feature_extractor_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

## Dataset

In [None]:
# Fetch the visible-watermark dataset from the Hugging Face Hub; the result is
# indexed by split name (the "train" split is used below).
dataset = load_dataset(path="bastienp/visible-watermark-pita")

In [None]:
# Display the dataset summary: available splits and their columns.
dataset

In [None]:
# Sanity check: render the first raw training image (an image object —
# transform_aug_ann later calls .convert("RGB") on it, so presumably PIL).
# NOTE(review): once the with_transform cell has run, dataset["train"] returns
# transformed features, so re-running this cell afterwards will not work.
dataset["train"][0]["image"]

## Preprocessing

In [None]:
# Geometric preprocessing applied jointly to each image and its boxes before the
# DETR feature extractor. Boxes are COCO format: [x, y, width, height] in pixels.
#
# label_fields=["category_ids"] declares that the `category_ids` keyword holds
# one label per box. When a box is dropped (e.g. it lies entirely outside the
# centre crop), its label is then dropped with it, so boxes and labels stay aligned.
# With label_fields=[] the labels are passed through untouched and can outnumber
# the surviving boxes.
#
# NOTE(review): CenterCrop(224, 224) after Resize(480, 480) keeps only the
# central ~47% of each side. Watermarks near the borders are cropped away and
# their boxes discarded — confirm this is intended.
preprocess = A.Compose([
    A.Resize(480, 480),
    A.CenterCrop(224, 224),
], bbox_params=A.BboxParams(format="coco", label_fields=["category_ids"]))

# Same geometry for visualisation. It declares no label field, so it can be
# called with just `image=` and `bboxes=`.
preprocess_viz = A.Compose([
    A.Resize(480, 480),
    A.CenterCrop(224, 224),
], bbox_params=A.BboxParams(format="coco", label_fields=[]))

In [None]:
def clamp_coco_bbox(bbox, img_width, img_height):
    """Clip a COCO-format box ``[x, y, width, height]`` to the image bounds.

    The box is converted to corner form, and each corner is clipped to
    ``[0, img_width] x [0, img_height]``. The result is then converted back.
    Clipping corners (rather than the origin and size independently) keeps the
    right/bottom edge in place when a negative origin is pulled in to 0.

    Args:
        bbox: Sequence ``(x, y, width, height)`` in pixels.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        list: ``[x, y, width, height]`` inside the image. Width and height are
        never negative; a box entirely outside the image collapses to size 0.
    """
    x, y, width, height = bbox

    # Corners, each clipped into the image rectangle.
    x_min = min(max(x, 0), img_width)
    y_min = min(max(y, 0), img_height)
    x_max = min(max(x + width, 0), img_width)
    y_max = min(max(y + height, 0), img_height)

    # max(0, ...) guards against a negative input width/height.
    return [x_min, y_min, max(0, x_max - x_min), max(0, y_max - y_min)]


In [None]:

def formatted_anns(image_id, category, area, bbox):
    """Build COCO-style annotation dicts for one image.

    Args:
        image_id: Id stored on every annotation.
        category: Sequence of category ids, one per object. Its length decides
            how many annotations are produced.
        area: Sequence of box areas, indexed in parallel with ``category``.
        bbox: Sequence of ``[x, y, width, height]`` boxes, indexed in parallel
            with ``category``. Each box is copied into a ``list``.

    Returns:
        list[dict]: One dict per object, with keys ``image_id``,
        ``category_id``, ``iscrowd`` (always 0), ``area`` and ``bbox``.
    """
    # COCO spells the crowd flag "iscrowd" (all lowercase). Consumers of the
    # COCO annotation format, such as DETR's preprocessing in transformers,
    # read that key, so a camel-cased "isCrowd" would be silently ignored.
    return [
        {
            "image_id": image_id,
            "category_id": category_id,
            "iscrowd": 0,
            "area": area[i],
            "bbox": list(bbox[i]),
        }
        for i, category_id in enumerate(category)
    ]

In [None]:
# transforming a batch
def transform_aug_ann(examples):
    """Apply `preprocess` to a batch of rows and encode it for DETR.

    Used as a `datasets` on-the-fly transform, so `examples` is a dict of
    column lists. Each row is expected to carry one box in `bbox` (COCO
    [x, y, width, height], pixels) and one `category_id`.

    Per row:
      1. convert the image to an RGB numpy array;
      2. clip the box to the image, using the array's real width and height;
      3. resize/crop image and box together with `preprocess`;
      4. recompute each surviving box's area in the transformed image.

    The rows are then turned into COCO-style targets via `formatted_anns`.
    The result of `feature_extractor(..., return_tensors="pt")` is returned.

    The dataset's own `area` column is deliberately not used: it describes the
    box in the original image, not after resize/crop.
    """
    image_ids = examples["image_id"]
    images, bboxes, areas, categories = [], [], [], []
    for image, bbox, category in zip(examples["image"], examples["bbox"], examples["category_id"]):
        pixels = np.array(image.convert("RGB"))
        # numpy images are (height, width, channels): height comes first.
        img_height, img_width = pixels.shape[:2]
        box = clamp_coco_bbox(bbox, img_width, img_height)

        # A zero-width/height box (entirely outside the image) fails
        # albumentations' bbox validation, so do not pass it in at all.
        has_box = box[2] > 0 and box[3] > 0
        out = preprocess(
            image=pixels,
            bboxes=[box] if has_box else [],
            category_ids=[category] if has_box else [],
        )

        kept_boxes = [list(b) for b in out["bboxes"]]
        # At most one box goes in per row. Truncating the labels to the number
        # of surviving boxes therefore keeps them aligned, even if `preprocess`
        # does not filter `category_ids` together with `bboxes`.
        kept_labels = list(out["category_ids"])[: len(kept_boxes)]

        images.append(out["image"])
        bboxes.append(kept_boxes)
        categories.append(kept_labels)
        areas.append([b[2] * b[3] for b in kept_boxes])

    targets = [
        {"image_id": id_, "annotations": formatted_anns(id_, cat_, ar_, box_)}
        for id_, cat_, ar_, box_ in zip(image_ids, categories, areas, bboxes)
    ]

    return feature_extractor(images=images, annotations=targets, return_tensors="pt")

In [None]:
# Attach the augmentation lazily: transform_aug_ann runs on each batch as it is
# read, rather than being precomputed. `raw_train` keeps a handle on the
# untransformed split.
raw_train = dataset["train"]
dataset["train"] = raw_train.with_transform(transform=transform_aug_ann)

In [None]:
def collate_fn(batch):
    """Merge per-example encodings into one training batch.

    Images are padded to a common size with ``feature_extractor.pad``, which
    also provides the matching ``pixel_mask``. Labels are not stacked; they are
    returned as a plain list with one entry per example.

    Returns:
        dict with keys ``pixel_values``, ``pixel_mask`` and ``labels``.
    """
    padded = feature_extractor.pad(
        [example["pixel_values"] for example in batch], return_tensors="pt"
    )
    return {
        "pixel_values": padded["pixel_values"],
        "pixel_mask": padded["pixel_mask"],
        "labels": [example["labels"] for example in batch],
    }

## Train

In [None]:

training_args = TrainingArguments(
    output_dir="detr-resnet-50_finetuned",
    # optimisation
    num_train_epochs=10,
    per_device_train_batch_size=8,
    learning_rate=1e-5,
    weight_decay=1e-4,
    # logging and checkpointing
    logging_steps=50,
    save_steps=200,
    save_total_limit=2,
    # The raw columns (image, bbox, ...) that transform_aug_ann reads are not
    # model.forward() arguments. Trainer would otherwise drop them before the
    # on-the-fly transform runs.
    remove_unused_columns=False,
)

In [None]:
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=` — confirm against the pinned version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collate_fn,
    tokenizer=feature_extractor,
)


In [None]:
# Run fine-tuning with the settings in training_args. Checkpoints go under
# training_args.output_dir every save_steps steps.
trainer.train()


In [None]:
# Save the final model to the same directory the training checkpoints use.
# Reading it from training_args keeps the path defined in one place.
trainer.save_model(training_args.output_dir)

In [None]:
df_logs = pd.DataFrame(trainer.state.log_history)

# log_history mixes the periodic logging rows (which carry "loss") with other
# entries, e.g. the end-of-training summary, that do not. Keep the former and
# use the optimiser step as the x-axis instead of the row position.
train_logs = df_logs.dropna(subset=["loss"]).set_index("step")

# Loss and learning rate (1e-5 here) are on very different scales. On a shared
# y-axis the learning-rate curve would sit flat at zero, so draw it on a
# secondary axis.
ax = train_logs[["loss", "learning_rate"]].plot(
    secondary_y=["learning_rate"],
    title="Training Metrics",
)
ax.set(xlabel="step", ylabel="loss");