## Read Dataset and Check Object Annotations

In [None]:
import gc
import numpy as np
import pytorch_lightning as pl
import torch

from datasets import load_dataset
from functools import partial
from PIL import ImageDraw as PImageDraw
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor, AutoModelForObjectDetection

from dataset_utils.finetune_0915 import FTUtils

### Load HF Dataset

In [None]:
# Download (or load from cache) the fine-tuning dataset from the Hugging Face Hub.
ft0915_ds = load_dataset(path="acervos-digitais/ft-0915")

In [None]:
# When the dataset ships without a "val" split, carve one out of "train":
# 15% of the rows, with a fixed seed so the split is reproducible.
if "val" not in ft0915_ds:
  train_val = ft0915_ds["train"].train_test_split(test_size=0.15, seed=1010)
  ft0915_ds["train"], ft0915_ds["val"] = train_val["train"], train_val["test"]

In [None]:
categories = ft0915_ds["train"].features["objects"].feature["category"].names

# Class index <-> class name lookups, passed to the model when it is loaded.
id2label = dict(enumerate(categories))
label2id = {name: idx for idx, name in id2label.items()}

### Load Model and Image Processor

In [None]:
# Hub id of the pretrained checkpoint that is fine-tuned below.
MODEL_NAME: str = "microsoft/conditional-detr-resnet-50"

In [None]:
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

# id2label/label2id come from this dataset, so the checkpoint's classification
# head does not match; ignore_mismatched_sizes lets it load anyway.
model = AutoModelForObjectDetection.from_pretrained(
  MODEL_NAME,
  ignore_mismatched_sizes=True,
  label2id=label2id,
  id2label=id2label,
)

### Inspect a Sample from the HF Dataset

In [None]:
# Pick one training row and prepare a drawing context on its image.
img_id = 11
sample = ft0915_ds["train"][img_id]
image = sample["image"]
annotations = sample["objects"]
draw = PImageDraw.Draw(image)

In [None]:
# Overlay every annotated box (unpacked as x, y, width, height) in red and its
# class name in green on the sample image, then display it.
for box, class_idx in zip(annotations["bbox"], annotations["category"]):
  x, y, w, h = box
  draw.rectangle((x, y, x + w, y + h), outline="red", width=1)
  draw.text((x, y), id2label[class_idx], fill=(0, 255, 0))

display(image)

### Define Image transforms

In [None]:
def transform_batch(examples, transform, image_processor, return_pixel_mask=False):
  """Convert a batch of dataset rows into image-processor outputs.

  Args:
    examples: batch with parallel "image_id", "image" and "objects" columns.
    transform: augmentation hook; currently ignored (augmentation is disabled).
    image_processor: called on the RGB image arrays plus COCO-format annotations
      to resize, rescale and normalize them.
    return_pixel_mask: keep the processor's "pixel_mask" entry when True.

  Returns:
    The processor output, with "pixel_mask" removed unless requested.
  """
  rgb_arrays, coco_annotations = [], []
  for img_id, img, objs in zip(examples["image_id"], examples["image"], examples["objects"]):
    rgb_arrays.append(np.array(img.convert("RGB")))
    # FTUtils turns the row's objects into the COCO-style dict the processor expects.
    coco_annotations.append(FTUtils.as_coco(img_id, objs))

  encoded = image_processor(images=rgb_arrays, annotations=coco_annotations, return_tensors="pt")

  if not return_pixel_mask:
    encoded.pop("pixel_mask", None)

  return encoded

### Apply Image transforms

In [None]:
# Train and validation/test currently share identical preprocessing (no augmentation).
_transform_kwargs = dict(transform=None, image_processor=image_processor, return_pixel_mask=True)
train_transform = partial(transform_batch, **_transform_kwargs)
validation_transform = partial(transform_batch, **_transform_kwargs)

# Transforms are applied lazily, whenever rows are accessed.
train_ds = ft0915_ds["train"].with_transform(train_transform)
val_ds = ft0915_ds["val"].with_transform(validation_transform)
test_ds = ft0915_ds["test"].with_transform(validation_transform)

### Prepare DataLoaders

In [None]:
def collate_fn(batch):
  """Merge per-example processor outputs into one model batch.

  Pixel values (and pixel masks, when the first example has one) are stacked
  into tensors; labels stay a list with one entry per example.
  """
  collated = {
    "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
    "labels": [item["labels"] for item in batch],
  }
  # The mask is optional: transform_batch may drop it.
  if "pixel_mask" in batch[0]:
    collated["pixel_mask"] = torch.stack([item["pixel_mask"] for item in batch])
  return collated

In [None]:
# Batches of 4 built with collate_fn; only the training data is shuffled.
train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_ds, batch_size=4, collate_fn=collate_fn)

In [None]:
class Detr(pl.LightningModule):
  """Lightning module that fine-tunes ``MODEL_NAME`` for object detection.

  The wrapped Hugging Face model is loaded with this notebook's
  ``id2label``/``label2id`` mappings; ``ignore_mismatched_sizes=True`` lets the
  pretrained checkpoint load even though its label set differs.

  Args:
    lr: AdamW learning rate for every trainable non-backbone parameter.
    lr_backbone: learning rate for trainable parameters whose name contains
      "backbone".
    weight_decay: AdamW weight decay, shared by both parameter groups.
  """

  def __init__(self, lr, lr_backbone, weight_decay):
    super().__init__()
    self.model = AutoModelForObjectDetection.from_pretrained(
      MODEL_NAME,
      id2label=id2label,
      label2id=label2id,
      ignore_mismatched_sizes=True
    )

    # see https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
    self.lr = lr
    self.lr_backbone = lr_backbone
    self.weight_decay = weight_decay

  def forward(self, pixel_values, pixel_mask=None):
    """Run the wrapped model on images (and an optional padding mask), without labels."""
    return self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

  def common_step(self, batch, batch_idx):
    """Compute the model loss for one collated batch.

    Args:
      batch: dict from ``collate_fn`` with "pixel_values", "labels" (one dict of
        tensors per image) and, optionally, "pixel_mask".
      batch_idx: unused.

    Returns:
      ``(loss, loss_dict)`` as reported by the wrapped model.
    """
    pixel_values = batch["pixel_values"]
    # collate_fn only adds "pixel_mask" when the transform kept it.
    pixel_mask = batch.get("pixel_mask")
    labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

    outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
    return outputs.loss, outputs.loss_dict

  def training_step(self, batch, batch_idx):
    """Log the total and per-component training losses; return the total loss."""
    loss, loss_dict = self.common_step(batch, batch_idx)
    # Pass batch_size explicitly: the batch mixes image tensors with per-image
    # label tensors of varying length, so Lightning cannot infer it reliably.
    batch_size = len(batch["labels"])
    self.log("training_loss", loss, batch_size=batch_size)
    for k, v in loss_dict.items():
      self.log("train_" + k, v.item(), batch_size=batch_size)

    return loss

  def validation_step(self, batch, batch_idx):
    """Log the total and per-component validation losses; return the total loss."""
    loss, loss_dict = self.common_step(batch, batch_idx)
    batch_size = len(batch["labels"])
    self.log("validation_loss", loss, batch_size=batch_size)
    for k, v in loss_dict.items():
      self.log("validation_" + k, v.item(), batch_size=batch_size)

    return loss

  def configure_optimizers(self):
    """Build AdamW with a separate learning rate for backbone parameters.

    Only parameters with ``requires_grad`` are optimized; a parameter counts as
    backbone when its qualified name contains "backbone".
    """
    backbone_params, non_backbone_params = [], []
    for name, param in self.named_parameters():
      if not param.requires_grad:
        continue
      (backbone_params if "backbone" in name else non_backbone_params).append(param)

    param_dicts = [
      {"params": non_backbone_params},
      {"params": backbone_params, "lr": self.lr_backbone},
    ]
    return torch.optim.AdamW(param_dicts, lr=self.lr, weight_decay=self.weight_decay)

  def train_dataloader(self):
    """Return the notebook-level training DataLoader."""
    return train_dataloader

  def val_dataloader(self):
    """Return the notebook-level validation DataLoader."""
    return val_dataloader

In [None]:
# Drop the standalone `model` (if it still exists) so its weights can be freed;
# Detr loads its own copy. Checking globals() keeps this cell safe to re-run,
# whereas `if model:` raised NameError once `model` had been deleted.
if "model" in globals():
  del model

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Head/transformer LR 1e-4, backbone LR 1e-5, AdamW weight decay 1e-4.
model = Detr(weight_decay=1e-4, lr_backbone=1e-5, lr=1e-4)

In [None]:
trainer = Trainer(
  accelerator="gpu",
  max_epochs=48,
  gradient_clip_val=0.1,  # clip gradients at 0.1
)
trainer.fit(model)

In [None]:
# Hub repo that receives the fine-tuned weights and image processor.
OUTPUT_MODEL_NAME: str = "acervos-digitais/conditional-detr-resnet-50-ft-0915"

In [None]:
# Publish the fine-tuned HF model (unwrapped from Lightning) first, then the processor.
for artifact in (model.model, image_processor):
  artifact.push_to_hub(OUTPUT_MODEL_NAME)

### Test Model

In [None]:
import torch

from datasets import load_dataset
from os import path
from PIL import Image as PImage, ImageDraw as PImageDraw, ImageFont as PImageFont
from transformers import AutoImageProcessor, AutoModelForObjectDetection

from dataset_utils.finetune_0915 import FTUtils

# Fine-tuned checkpoint (the repo the training section pushes to).
MODEL_NAME: str = "acervos-digitais/conditional-detr-resnet-50-ft-0915"

In [None]:
ft0915_ds = load_dataset("acervos-digitais/ft-0915")

# Reload the fine-tuned processor and model from the Hub and move the model to CUDA.
image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForObjectDetection.from_pretrained(
  MODEL_NAME,
  label2id=FTUtils.LABEL2ID,
  id2label=FTUtils.ID2LABEL,
)
model = model.to("cuda")

In [None]:
from itertools import islice

# Visual sanity check: run the fine-tuned model on the first 10 test rows,
# print ground-truth vs. predicted class names and draw predicted boxes in red.
SCORE_THRESHOLD = 0.13  # detections scoring below this are discarded

# islice stops after 10 rows instead of decoding the whole split via list().
for r in islice(ft0915_ds["test"], 10):
  # Match the training transform, which feeds RGB images to the processor.
  img = r["image"].convert("RGB")
  iw, ih = img.size
  draw = PImageDraw.Draw(img)

  inputs = image_processor(images=img, return_tensors="pt")
  # Follow the model's device instead of hard-coding "cuda".
  pixel_values = inputs["pixel_values"].to(model.device)

  with torch.no_grad():
    outputs = model(pixel_values=pixel_values, pixel_mask=None)

  # Rescale predictions to the original (height, width) and apply the threshold.
  ppo = image_processor.post_process_object_detection(
    outputs,
    target_sizes=[(ih, iw)],
    threshold=SCORE_THRESHOLD,
  )[0]

  print("labels:", [FTUtils.ID2LABEL[c] for c in r["objects"]["category"]])
  print("pred:", [FTUtils.ID2LABEL[label.item()] for label in ppo["labels"]])

  # Each box is drawn as corners (box[0], box[1]) -> (box[2], box[3]).
  for box in ppo["boxes"]:
    draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=(255, 0, 0), width=2)

  display(img)
