In [None]:
!pip install -qqq -U datasets
!pip install -qqq torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -qqq -U lightning
!pip install -qqq pycocotools faster-coco-eval

In [None]:
import os
from typing import Optional

import pytorch_lightning as pl
import torch
from datasets import load_dataset
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from torch.utils.data import DataLoader, Dataset
from torchmetrics.detection import MeanAveragePrecision
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_Weights,
    faster_rcnn,
    fasterrcnn_resnet50_fpn,
)
from torchvision.ops import box_convert
from torchvision.transforms import functional as F

In [None]:
class CustomDataset(Dataset):
    """Expose a Hugging Face split as a torchvision-style detection dataset.

    Each item is returned as ``(image, target)``:

    * ``image``: the item's ``"image"`` field converted with ``F.to_tensor``.
    * ``target``: a dict with ``"boxes"`` (float32, ``xyxy`` format) and
      ``"labels"`` (int64), the target format torchvision detectors consume.

    Args:
        hf_dataset: Indexable dataset whose items have ``"image"``, ``"bbox"``
            and ``"category_id"`` fields.
        transform: Optional callable ``(image, target) -> (image, target)``
            applied after conversion. It receives the target too, so geometric
            transforms can keep the boxes in sync with the image.
    """

    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __getitem__(self, idx):
        item = self.hf_dataset[idx]
        image = F.to_tensor(item["image"])

        # NOTE(review): assumes each item carries exactly one box as a flat
        # [x, y, w, h] list. Wrapping it in a list yields a (1, 4) tensor.
        # Boxes are forced to float32 because torchvision documents detection
        # targets as FloatTensor[N, 4], even if the source coordinates are ints.
        boxes = torch.tensor([item["bbox"]], dtype=torch.float32)
        boxes = box_convert(boxes, in_fmt="xywh", out_fmt="xyxy")

        # NOTE(review): torchvision reserves label 0 for background, so with
        # `category_id - 1` a category_id of 1 maps to background. Confirm the
        # dataset's category ids against NUM_CLASSES.
        labels = torch.tensor([item["category_id"] - 1], dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}

        # Previously `transform` was stored but never used.
        if self.transform is not None:
            image, target = self.transform(image, target)

        return image, target

    def __len__(self):
        return len(self.hf_dataset)

In [None]:
class FasterRCNNModule(pl.LightningModule):
    """Lightning wrapper around torchvision's Faster R-CNN (ResNet-50 FPN).

    The pretrained detector is loaded with its default weights. Its box
    predictor head is then replaced so it outputs ``num_classes`` classes.
    In torchvision, ``num_classes`` includes the background class.

    Args:
        num_classes: Number of output classes, background included.
        lr: Learning rate for the Adam optimizer (default ``1e-5``).
    """

    def __init__(self, num_classes, lr=1e-5):
        super().__init__()
        # Record constructor args in checkpoints so that
        # `FasterRCNNModule.load_from_checkpoint(path)` can rebuild the module.
        self.save_hyperparameters()
        self.lr = lr

        # `weights` alone selects the pretrained checkpoint. The legacy
        # `pretrained` flag is deprecated and redundant when `weights` is given.
        self.model = fasterrcnn_resnet50_fpn(
            weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT,
        )

        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = faster_rcnn.FastRCNNPredictor(
            in_features, num_classes
        )

        # mAP must be accumulated over the whole test set and computed once.
        # Averaging per-batch mAP values does not give the dataset mAP.
        self.test_map = MeanAveragePrecision(
            iou_type="bbox", box_format="xyxy", class_metrics=True
        )

    def forward(self, images, targets=None):
        """Run the detector.

        Returns a loss dict in training mode, or a list of prediction dicts
        in eval mode.
        """
        return self.model(images, targets)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        loss_dict = self(images, targets)

        # Total loss is the sum of every loss term the detector reports.
        loss = sum(loss_dict.values())
        # Explicit batch_size: Lightning cannot infer it from a tuple of images.
        self.log("train_loss", loss, batch_size=len(images))

        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch

        # The detector only returns losses in training mode, so switch the
        # inner model (not a notebook-level global) to train mode. Per the
        # original note, BatchNorm layers are already frozen.
        # NOTE(review): proposal sampling in train mode is random, so val_loss
        # is somewhat noisy.
        self.model.train()

        with torch.no_grad():
            loss_dict = self(images, targets)

        loss = sum(loss_dict.values())
        self.log("val_loss", loss, batch_size=len(images))

        return loss

    def test_step(self, batch, batch_idx):
        images, targets = batch
        preds = self(images)
        # Only accumulate here; the metric is computed once in on_test_epoch_end.
        self.test_map.update(preds, list(targets))

    def on_test_epoch_end(self):
        result = self.test_map.compute()

        self.log("map", result["map"])
        self.log("map_50", result["map_50"])
        self.log("map_75", result["map_75"])

        # Clear accumulated state so a later `trainer.test` starts fresh.
        self.test_map.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [None]:
# Number of images per DataLoader batch.
BATCH_SIZE: int = 16

# Worker processes for data loading. os.cpu_count() can return None, in which
# case fall back to 0 (load in the main process).
_detected_cpus = os.cpu_count()
CPU_COUNT: int = _detected_cpus if _detected_cpus else 0

In [None]:
def _collate_detection(batch):
    """Turn a list of ``(image, target)`` pairs into ``(images, targets)`` tuples."""
    # Detection images can differ in size, so they are kept as a tuple rather
    # than stacked. A named function is used instead of a lambda because
    # lambdas can never be pickled for spawn-started DataLoader workers.
    # NOTE(review): under `spawn` in an interactive session, workers still
    # need to import this from a real module.
    return tuple(zip(*batch))


def load_split(
    split: str,
    shuffle: bool,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    dataset_name: str = "bastienp/visible-watermark-pita",
) -> DataLoader:
    """Load one split of the Hugging Face dataset and wrap it in a DataLoader.

    Args:
        split: Split name passed to ``load_dataset`` (e.g. ``"train"``).
        shuffle: Whether the DataLoader shuffles samples each epoch.
        batch_size: Images per batch. ``None`` uses ``BATCH_SIZE``, read at
            call time.
        num_workers: DataLoader worker processes. ``None`` uses ``CPU_COUNT``,
            read at call time.
        dataset_name: Hugging Face dataset identifier.

    Returns:
        A DataLoader yielding ``(images, targets)`` tuples from ``CustomDataset``.
    """
    # Resolve defaults at call time so editing the config cell takes effect
    # without redefining this function.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if num_workers is None:
        num_workers = CPU_COUNT

    hf_dataset = load_dataset(dataset_name, split=split)
    dataset = CustomDataset(hf_dataset)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_collate_detection,
        num_workers=num_workers,
    )

In [None]:
# One loader per split. Only the training split is shuffled, so validation
# and test iterate in a fixed order.
train_dataloader, val_dataloader, test_dataloader = (
    load_split(split_name, shuffle=(split_name == "train"))
    for split_name in ("train", "val", "test")
)

In [None]:
# torchvision counts the background class in num_classes, so 2 means
# background plus one object class.
NUM_CLASSES = 2
SEED = 42
CHECKPOINT_PATH = "faster-rcnn-pita.ckpt"

# Seed the Python/NumPy/torch RNGs so head initialisation and shuffling are
# reproducible across Restart & Run All.
pl.seed_everything(SEED)

model = FasterRCNNModule(num_classes=NUM_CLASSES)

wandb_logger = WandbLogger(project="faster-rcnn-pita-dataset", log_model="all")

trainer = pl.Trainer(
    devices="auto",
    max_epochs=1,
    # NOTE(review): this reuses BATCH_SIZE as the gradient-accumulation
    # factor. The DataLoader already batches BATCH_SIZE images, so each
    # optimizer step sees BATCH_SIZE * BATCH_SIZE (= 256) images. Confirm
    # this is intended and not a separate setting.
    accumulate_grad_batches=BATCH_SIZE,
    log_every_n_steps=50,
    logger=wandb_logger,
    precision=16,
)

trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

trainer.save_checkpoint(CHECKPOINT_PATH)

# Evaluate the saved weights on the held-out test split. Passing the model
# explicitly avoids relying on state left over from `fit`.
trainer.test(model=model, dataloaders=test_dataloader, ckpt_path=CHECKPOINT_PATH)