## Fine-tune a Conditional DETR model

In [None]:
import gc
import pytorch_lightning as pl
import torch

from datasets import load_dataset
from functools import partial
from PIL import Image as PImage, ImageDraw as PImageDraw
from pytorch_lightning import Trainer, loggers as PLLoggers
from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.transforms import v2 as T
from transformers import AutoImageProcessor, AutoModelForObjectDetection

from dataset_utils.finetune_0915 import FTUtils

### Load HF Dataset

In [None]:
# Hub id of the pretrained checkpoint; used for both the image processor and the detector.
MODEL_NAME = "microsoft/conditional-detr-resnet-50"
# Hub id of the dataset handed to load_dataset.
DATASET_NAME = "acervos-digitais/ft-0915"

In [None]:
# Load the dataset; the rest of the notebook reads its "train" and "test" splits.
ft0915_ds = load_dataset(DATASET_NAME)

In [None]:
# Class names declared by the "category" feature of the train split; the position of
# each name in this list is used as its integer class id.
categories = ft0915_ds["train"].features["objects"].feature["category"].names

# Two-way lookup between class id and class name.
id2label = dict(enumerate(categories))
label2id = {name: idx for idx, name in id2label.items()}

### Test HF Dataset

In [None]:
# Pick one training example to inspect. The row is fetched once so the image is
# decoded a single time instead of once per field.
img_id = 11
sample = ft0915_ds["train"][img_id]
image = sample["image"]
annotations = sample["objects"]
# Drawing happens in place on `image`.
draw = PImageDraw.Draw(image)

In [None]:
# Overlay every ground-truth box and its class name on the sample image.
for box, class_idx in zip(annotations["bbox"], annotations["category"]):
  # Each box is unpacked as (x, y, width, height).
  x, y, w, h = box
  label = id2label[class_idx]

  draw.rectangle((x, y, x + w, y + h), outline="red", width=1)
  # The name is written twice: in black at the box's top edge and in magenta 12 px above it.
  draw.text((x + 2, y), label, fill=(0, 0, 0))
  draw.text((x + 2, y - 12), label, fill=(255, 0, 255))

display(image)

### Define Image transforms

In [None]:
# Training-time augmentation pipeline (torchvision transforms v2).
# transform_batch calls it with (image, bounding boxes, category list).
# Every entry fires independently with probability 0.5, either through its own `p`
# or through the RandomApply wrapper.
# NOTE(review): the geometric steps (perspective, rotation, affine) can move boxes
# partly or fully out of frame; nothing here drops degenerate boxes — confirm that
# FTUtils.as_coco / the image processor cope with them.
image_transform = T.Compose([
  T.RandomHorizontalFlip(p=0.5),
  T.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
  T.RandomEqualize(p=0.5),
  T.RandomPerspective(distortion_scale=0.6, p=0.5),
  T.RandomApply(transforms=[T.RandomRotation(degrees=35)], p=0.5),
  T.RandomApply(transforms=[T.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.25))], p=0.5),
  T.RandomApply(transforms=[T.ColorJitter(brightness=0.5, hue=0.3)], p=0.5)
])

In [None]:
def transform_batch(examples, transform, image_processor, return_pixel_mask=False):
  """Turn a batch of raw dataset rows into image-processor outputs.

  Args:
    examples: column-oriented batch with "image_id", "image" and "objects" lists;
      each "objects" entry holds "bbox" and "category". The "objects" dicts are
      updated in place.
    transform: optional augmentation callable invoked as
      transform(image, bboxes, categories) and returning the same triple; None
      skips augmentation.
    image_processor: callable applied to all images and annotations of the batch.
    return_pixel_mask: when False, "pixel_mask" is removed from the result.

  Returns:
    Whatever image_processor returns, minus "pixel_mask" unless it was requested.
  """
  batch_images, batch_targets = [], []

  rows = zip(examples["image_id"], examples["image"], examples["objects"])
  for img_id, pil_img, objs in rows:
    width, height = pil_img.size
    # Wrap boxes and image as tv_tensors so the transform can update them together;
    # the canvas size is given as (height, width).
    objs["bbox"] = tv_tensors.BoundingBoxes(objs["bbox"], format="XYWH", canvas_size=(height, width))
    tensor_img = tv_tensors.Image(pil_img.convert("RGB"))

    # Augment only when a transform was supplied.
    if transform is not None:
      tensor_img, objs["bbox"], objs["category"] = transform(tensor_img, objs["bbox"], objs["category"])

    batch_images.append(tensor_img)
    # FTUtils.as_coco is expected to produce the COCO-style annotation for this image.
    batch_targets.append(FTUtils.as_coco(img_id, objs))

  # The image processor handles resizing, rescaling and normalization.
  encoded = image_processor(images=batch_images, annotations=batch_targets, return_tensors="pt")

  if return_pixel_mask:
    return encoded

  encoded.pop("pixel_mask", None)
  return encoded

### Apply Image transforms

In [None]:
# Image processor that ships with the pretrained checkpoint.
detr_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# Both splits use the same processor and keep the pixel mask; only training augments.
_shared_transform_kwargs = {"image_processor": detr_processor, "return_pixel_mask": True}
train_transform = partial(transform_batch, transform=image_transform, **_shared_transform_kwargs)
validation_transform = partial(transform_batch, transform=None, **_shared_transform_kwargs)

# with_transform applies the function on access, so a training example is
# re-augmented every time it is read.
train_ds = ft0915_ds["train"].with_transform(train_transform)
val_ds = ft0915_ds["test"].with_transform(validation_transform)

### Prepare DataLoaders

In [None]:
def collate_fn(batch):
  """Merge per-example processor outputs into one model-ready batch.

  Examples whose images differ in height or width are zero-padded on the bottom and
  right up to the largest size in the batch, so they can be stacked. Their pixel
  masks are padded with zeros the same way; if the examples carry no mask, one is
  created (1 over the original image area, 0 over the padding) so the padding can
  be told apart from real pixels. Batches of equally sized examples are stacked
  unchanged.

  Args:
    batch: list of dicts, each with "pixel_values", "labels" and optionally
      "pixel_mask".

  Returns:
    dict with stacked "pixel_values", the list of per-example "labels", and a
    stacked "pixel_mask" when the examples provide one or padding was added.
  """
  pixel_values = [x["pixel_values"] for x in batch]
  masks = [x["pixel_mask"] for x in batch] if "pixel_mask" in batch[0] else None

  # torch.stack requires identical shapes; pad only when the spatial sizes disagree.
  sizes = [tuple(p.shape[-2:]) for p in pixel_values]
  if len(set(sizes)) > 1:
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)

    def pad_to_max(tensor):
      h, w = tensor.shape[-2:]
      # F.pad order for the last two dims is (left, right, top, bottom).
      return torch.nn.functional.pad(tensor, (0, max_w - w, 0, max_h - h))

    if masks is None:
      masks = [torch.ones(size, dtype=torch.long) for size in sizes]
    pixel_values = [pad_to_max(p) for p in pixel_values]
    masks = [pad_to_max(m) for m in masks]

  data = {}
  data["pixel_values"] = torch.stack(pixel_values)
  data["labels"] = [x["labels"] for x in batch]
  if masks is not None:
    data["pixel_mask"] = torch.stack(masks)
  return data

In [None]:
# Both loaders share the collate function, batch size and worker count; only the
# training loader shuffles.
_loader_kwargs = {"collate_fn": collate_fn, "batch_size": 4, "num_workers": 4}
train_dataloader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_ds, **_loader_kwargs)

In [None]:
def eval_detr(model, processor, dataset, min_threshold=0.2, thresholds=None):
  """Run the detector over `dataset` and return class-level (precision, recall).

  The model is switched to eval mode for the duration of the call and put back
  into its previous mode afterwards, so calling this from a training step does not
  evaluate with training-only layers active. Inputs are moved to the device of the
  model's parameters.

  Args:
    model: torch module called as model(pixel_values=..., pixel_mask=None).
    processor: image processor used to encode each image and to post-process the
      raw outputs via post_process_object_detection.
    dataset: iterable of rows with an "image" (exposing .size as (width, height))
      and "objects" holding a "category" list of class ids.
    min_threshold: score threshold passed to post_process_object_detection.
    thresholds: optional per-class score thresholds indexed by class id. When
      non-empty, a prediction is kept only if its score is strictly greater than
      the threshold of its class. None or an empty sequence disables this filter.

  Returns:
    (precision, recall), each rounded to 4 decimals, or 0 when its denominator is 0.
  """
  if thresholds is None:
    thresholds = ()

  # Follow the model's device instead of assuming CUDA.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device("cpu")

  num_correct = 0
  num_preds = 0
  num_labels = 0

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for row in dataset:
        img = row["image"]
        iw, ih = img.size

        inputs = processor(images=img, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device)

        outputs = model(pixel_values=pixel_values, pixel_mask=None)

        # target_sizes is (height, width) of the original image.
        ppo = processor.post_process_object_detection(outputs,
                                                      target_sizes=[(ih, iw)],
                                                      threshold=min_threshold)[0]

        preds = [l.item() for l in ppo["labels"]]

        if len(thresholds) > 0:
          scores = [s.item() for s in ppo["scores"]]
          preds = [p for p, s in zip(preds, scores) if s > thresholds[p]]

        labels = row["objects"]["category"]
        label_set = set(labels)

        # NOTE(review): a "correct" hit is a distinct predicted class that occurs in
        # the image, while the denominators count every prediction / every label
        # instance — confirm this mix is the intended metric.
        num_correct += sum(1 for p in set(preds) if p in label_set)
        num_preds += len(preds)
        num_labels += len(labels)
  finally:
    # Hand the model back in the mode the caller had it in.
    model.train(was_training)

  precision = round(num_correct / num_preds, 4) if num_preds != 0 else 0
  recall = round(num_correct / num_labels, 4) if num_labels != 0 else 0
  return precision, recall

In [None]:
class Detr(pl.LightningModule):
  """LightningModule that fine-tunes a Hugging Face object-detection model.

  Relies on notebook globals: id2label / label2id for the label set, ft0915_ds for
  the per-epoch precision/recall evaluation, and train_dataloader / val_dataloader
  for the data.
  """

  def __init__(self, model_name, image_processor, lr, lr_backbone, weight_decay):
    """Load the pretrained detector and remember the optimizer settings.

    Args:
      model_name: checkpoint id passed to AutoModelForObjectDetection.from_pretrained.
      image_processor: processor handed to eval_detr for metric computation.
      lr: learning rate for every parameter outside the backbone.
      lr_backbone: learning rate for parameters whose name contains "backbone".
      weight_decay: AdamW weight decay.
    """
    super().__init__()
    self.processor = image_processor

    # ignore_mismatched_sizes lets the checkpoint load even though the label set
    # (and presumably the classification head shape — confirm) differs from the
    # pretrained one.
    self.model = AutoModelForObjectDetection.from_pretrained(
      model_name,
      id2label=id2label,
      label2id=label2id,
      ignore_mismatched_sizes=True,
    ).to("cuda")

    # see https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
    self.lr, self.lr_backbone, self.weight_decay = lr, lr_backbone, weight_decay
    self.save_hyperparameters()

  def forward(self, pixel_values, pixel_mask):
    """Delegate straight to the wrapped detection model."""
    return self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

  def common_step(self, batch, batch_idx):
    """Run the model on a collated batch with labels; return (loss, loss_dict)."""
    pixel_values = batch["pixel_values"]
    pixel_mask = batch["pixel_mask"]
    # Labels arrive as a list of per-image dicts and are moved to this module's device here.
    targets = [
      {key: value.to(self.device) for key, value in target.items()}
      for target in batch["labels"]
    ]

    outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=targets)
    return outputs.loss, outputs.loss_dict

  def training_step(self, batch, batch_idx):
    """Log the training losses; on the first batch of an epoch also log train precision/recall."""
    loss, loss_dict = self.common_step(batch, batch_idx)

    self.log("training_loss", loss)
    for name, value in loss_dict.items():
      self.log("train_" + name, value.item())

    if batch_idx != 0:
      return loss

    # Once per epoch: evaluate on the raw (untransformed) train split.
    precision, recall = eval_detr(self.model, self.processor, ft0915_ds["train"])
    self.log("train_precision", precision)
    self.log("train_recall", recall)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log the validation losses; on the first batch also log precision/recall on the test split."""
    loss, loss_dict = self.common_step(batch, batch_idx)

    self.log("validation_loss", loss)
    for name, value in loss_dict.items():
      self.log("validation_" + name, value.item())

    if batch_idx != 0:
      return loss

    # Once per validation run: evaluate on the raw (untransformed) test split.
    precision, recall = eval_detr(self.model, self.processor, ft0915_ds["test"])
    self.log("validation_precision", precision)
    self.log("validation_recall", recall)
    return loss

  def configure_optimizers(self):
    """Build AdamW with a separate learning rate for the backbone parameters."""
    trainable = [(name, param) for name, param in self.named_parameters() if param.requires_grad]
    other_params = [param for name, param in trainable if "backbone" not in name]
    backbone_params = [param for name, param in trainable if "backbone" in name]

    param_groups = [
      {"params": other_params},
      {"params": backbone_params, "lr": self.lr_backbone},
    ]
    return torch.optim.AdamW(param_groups, lr=self.lr, weight_decay=self.weight_decay)

  def train_dataloader(self):
    """Return the notebook-level training DataLoader."""
    return train_dataloader

  def val_dataloader(self):
    """Return the notebook-level validation DataLoader."""
    return val_dataloader

In [None]:
# Drop a model left over from an earlier run of this notebook so its memory can be
# reclaimed before a new one is created.
try:
  del model
except NameError:
  # First run: no previous model exists. Only this case is ignored, so other
  # errors (and Ctrl-C) are no longer swallowed.
  pass

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build the Lightning module; backbone and non-backbone parameters use the same
# learning rate (1e-5) here, with AdamW weight decay 1e-4.
model = Detr(model_name=MODEL_NAME, image_processor=detr_processor,
             lr=1e-5, lr_backbone=1e-5, weight_decay=1e-4)

In [None]:
# Resume from the weights of an earlier run. Only the state_dict is restored;
# nothing else stored in the checkpoint is used here.
# map_location="cpu" keeps the checkpoint contents off the GPU: load_state_dict
# copies the weights into the model's own parameters, and `cp` stays referenced
# for the rest of the session.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which may reject Lightning checkpoints holding pickled hyper-parameters — confirm
# against the installed version.
cp = torch.load("lightning_logs/76+43e-1e-5lr-augm3/checkpoints/epoch=43-step=3300.ckpt",
                map_location="cpu")
model.load_state_dict(cp["state_dict"])

In [None]:
# TensorBoard logger rooted at the current directory; this run is stored as version "e256-augm3".
mLogger = PLLoggers.TensorBoardLogger(save_dir=".", version="e256-augm3")
# Train on GPU for up to 256 epochs with gradient clipping at 0.1.
trainer = Trainer(accelerator="gpu", max_epochs=256, gradient_clip_val=0.1, logger=mLogger)
# No dataloaders are passed: they come from Detr.train_dataloader / Detr.val_dataloader.
trainer.fit(model)

### Save to HF Hub

In [None]:
# Hub repository that receives the fine-tuned model and its image processor.
OUTPUT_MODEL_NAME = "acervos-digitais/conditional-detr-resnet-50-ft-0915-e256-augm3"

In [None]:
# Upload the wrapped Hugging Face model (not the Lightning module) and the
# processor to the same repository.
model.model.push_to_hub(OUTPUT_MODEL_NAME)
detr_processor.push_to_hub(OUTPUT_MODEL_NAME)

### Test Model

In [None]:
import torch

from datasets import load_dataset
from os import path
from PIL import Image as PImage, ImageDraw as PImageDraw, ImageFont as PImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection

from dataset_utils.finetune_0915 import FTUtils

# From here on MODEL_NAME refers to the fine-tuned model pushed above, not the base checkpoint.
MODEL_NAME = OUTPUT_MODEL_NAME

In [None]:
# Reload the dataset without any transform attached, so rows expose the raw
# "image" and "objects" fields that eval_detr reads.
ft0915_ds = load_dataset("acervos-digitais/ft-0915")

# Load processor and detector from the fine-tuned repository. This rebinds `model`
# from the Lightning wrapper to the plain Hugging Face model.
# NOTE(review): label maps come from FTUtils here, whereas training derived them
# from the dataset features — presumably identical; confirm.
detr_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForObjectDetection.from_pretrained(
  MODEL_NAME,
  id2label=FTUtils.ID2LABEL,
  label2id=FTUtils.LABEL2ID,
).to("cuda")

In [None]:
# Precision/recall on both splits with a single score threshold of 0.3.
# The split is passed directly: eval_detr only iterates it, so there is no need to
# materialise every decoded row in a list first.
for split in ("train", "test"):
  print(eval_detr(model, detr_processor, ft0915_ds[split], min_threshold=0.3, thresholds=[]))

In [None]:
# Per-class score thresholds (indexed by class id) applied after a permissive 0.15
# pre-filter. The split is passed directly; eval_detr only iterates it.
thresholds = [0.27, 0.27, 0.55]
for split in ("train", "test"):
  print(eval_detr(model, detr_processor, ft0915_ds[split], min_threshold=0.15, thresholds=thresholds))

In [None]:
# Per-class score thresholds (indexed by class id) applied after a permissive 0.15
# pre-filter. The split is passed directly; eval_detr only iterates it.
thresholds = [0.27, 0.27, 0.55]
for split in ("train", "test"):
  print(eval_detr(model, detr_processor, ft0915_ds[split], min_threshold=0.15, thresholds=thresholds))

In [None]:
# Alternative per-class thresholds: slightly lower for classes 0 and 1, much
# stricter for class 2. The split is passed directly; eval_detr only iterates it.
thresholds = [0.26, 0.26, 0.74]
for split in ("train", "test"):
  print(eval_detr(model, detr_processor, ft0915_ds[split], min_threshold=0.15, thresholds=thresholds))

In [None]:
from itertools import islice

# Visual spot-check: run the detector on the first 48 test rows (fewer if the split
# is smaller), print predictions next to the ground truth and draw the predicted boxes.
# islice stops after 48 rows instead of decoding the whole split into a list first.
for r in islice(ft0915_ds["test"], 48):
  img = r["image"]
  iw, ih = img.size
  draw = PImageDraw.Draw(img)

  inputs = detr_processor(images=img, return_tensors="pt")
  # Follow the model's device rather than assuming CUDA.
  pixel_values = inputs["pixel_values"].to(model.device)

  with torch.no_grad():
    outputs = model(pixel_values=pixel_values, pixel_mask=None)

  # target_sizes is (height, width) of the original image.
  ppo = detr_processor.post_process_object_detection(outputs,
                                                     target_sizes=[(ih, iw)],
                                                     threshold=0.25)[0]

  labels_list = [l.item() for l in ppo["labels"]]
  scores_list = [round(s.item(), 4) for s in ppo["scores"]]

  print("pred:", [(FTUtils.ID2LABEL[l], s) for l, s in zip(labels_list, scores_list)])
  print("labels:", [FTUtils.ID2LABEL[c] for c in r["objects"]["category"]])

  # Boxes are drawn from two opposite corners; convert to plain floats before
  # handing them to PIL.
  for box in ppo["boxes"]:
    x0, y0, x1, y1 = box.tolist()
    draw.rectangle(((x0, y0), (x1, y1)), outline=(255, 0, 0), width=2)

  display(img)