# Finetuning with COCO Format

We can now set up a script to fine-tune with PyTorch Lightning on top of standard COCO-format data.

This is based upon the following example: https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-huggingface-detr-on-custom-dataset.ipynb#scrollTo=qk6sRB0lueHY

In [None]:
%pip install -U timm transformers supervision roboflow lightning mlflow pycocotools
%restart_python 

## Setup Configurations

This cell initialises all the variables that we need.

In [None]:
import lightning as pl
from lightning.pytorch.loggers import MLFlowLogger

import torch
from torch.utils.data import DataLoader
import torchvision

from transformers import DetrForObjectDetection, DetrImageProcessor

import supervision as sv
import os

import mlflow

# Catalog / schema / volume names that make up the /Volumes path below.
ds_catalog = "brian_ml_dev"
# NOTE(review): spelled "schame" (presumably "schema"); the name is kept as-is
# because other cells may refer to it.
ds_schame = "image_processing"
coco_volume = "coco_dataset"

# The image directory is the volume root itself; the annotation file sits in it.
volume_path = "/".join(["/Volumes", ds_catalog, ds_schame, coco_volume])
image_path = volume_path
annotation_json = volume_path + "/annotations.json"

# Local scratch directory and MLflow experiment path.
save_dir = "/local_disk0/train"
mlflow_experiment = "/Users/brian.law@databricks.com/brian_lightning"

# Hugging Face model id used for both the image processor and the model.
CHECKPOINT = "facebook/detr-resnet-50"

In [None]:
# Load the image processor published with the CHECKPOINT model. This single
# instance is handed to the dataset and also used by collate_fn for padding.
image_processor = DetrImageProcessor.from_pretrained(CHECKPOINT)

In [None]:
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO detection dataset whose samples are preprocessed for DETR.

    Wraps ``torchvision.datasets.CocoDetection`` and runs every
    ``(image, annotations)`` pair through ``image_processor`` before
    returning it.

    Args:
        image_directory_path: Directory holding the images referenced by the
            annotation file.
        image_processor: Processor called with ``images=``, ``annotations=``
            and ``return_tensors="pt"`` (e.g. a ``DetrImageProcessor``).
        train: Currently unused; retained so existing callers keep working.
        annotation_file_path: Path of the COCO annotation JSON (a string).
            When omitted, the module-level ``annotation_json`` is used.
    """

    def __init__(
        self,
        image_directory_path: str,
        image_processor,
        train: bool = True,
        annotation_file_path=None,
    ):
        if annotation_file_path is None:
            # Fall back to the notebook-wide default so existing calls behave
            # exactly as before.
            annotation_file_path = annotation_json
        super().__init__(image_directory_path, annotation_file_path)
        self.image_processor = image_processor

    def __getitem__(self, idx):
        """Return ``(pixel_values, target)`` for the sample at position ``idx``."""
        image, annotations = super().__getitem__(idx)
        # The processor takes the raw COCO annotations bundled with the image id.
        coco_target = {'image_id': self.ids[idx], 'annotations': annotations}
        encoding = self.image_processor(
            images=image, annotations=coco_target, return_tensors="pt"
        )
        # Remove only the leading axis (assumed to be a batch axis of size 1,
        # matching the ``[0]`` taken from the labels below). A bare squeeze()
        # would also collapse any other axis that happens to have size 1.
        pixel_values = encoding["pixel_values"].squeeze(0)
        target = encoding["labels"][0]

        return pixel_values, target

def collate_fn(batch):
    """Collate ``(pixel_values, target)`` samples into one model-ready batch.

    Samples may differ in image size, so they cannot be stacked directly.
    The module-level ``image_processor`` pads them to a common size and also
    returns a ``pixel_mask`` for the padded result. Targets are kept as a
    plain list, one entry per sample.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])

    padded = image_processor.pad(images, return_tensors="pt")

    return dict(
        pixel_values=padded['pixel_values'],
        pixel_mask=padded['pixel_mask'],
        labels=targets,
    )

In [None]:
# Build the dataset once; both loaders below read from it.
TRAIN_DATASET = CocoDetection(image_path, image_processor, train=True)
print("Number of training examples:", len(TRAIN_DATASET))

# Arguments common to the training and validation loaders.
_loader_kwargs = dict(dataset=TRAIN_DATASET, collate_fn=collate_fn, batch_size=4)

TRAIN_DATALOADER = DataLoader(shuffle=True, **_loader_kwargs)
# NOTE(review): the validation loader iterates the training dataset
# (unshuffled), so validation metrics are not measured on held-out data.
VAL_DATALOADER = DataLoader(**_loader_kwargs)

In [None]:
#### Pull a single batch from the training loader and print it as a sanity check

batch = next(iter(TRAIN_DATALOADER), None)
if batch is not None:
    print(batch)

In [None]:

# Category records as loaded from the annotation file, keyed by category id.
categories = TRAIN_DATASET.coco.cats
# Lookup from category id to category name.
# NOTE(review): Detr sizes its classification head with len(id2label); this
# presumably relies on category ids being contiguous and 0-based — confirm
# against the annotation file.
id2label = {
    category_id: category['name']
    for category_id, category in categories.items()
}

class Detr(pl.LightningModule):
    """LightningModule that fine-tunes a pretrained DETR checkpoint.

    The wrapped Hugging Face model is loaded from ``CHECKPOINT`` with
    ``num_labels=len(id2label)``. The total loss and every entry of the
    model's ``loss_dict`` are sent to the attached logger for each training
    and validation batch.

    Args:
        lr: Learning rate for trainable parameters outside the backbone.
        lr_backbone: Learning rate for trainable backbone parameters.
        weight_decay: Weight decay passed to AdamW.
    """

    def __init__(self, lr, lr_backbone, weight_decay):
        super().__init__()
        self.model = DetrForObjectDetection.from_pretrained(
            pretrained_model_name_or_path=CHECKPOINT,
            num_labels=len(id2label),
            # Permit checkpoint weights whose shapes differ from the model
            # built for ``num_labels`` classes.
            ignore_mismatched_sizes=True,
        )

        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay

        # Running batch counters used as the logging step. ``batch_idx``
        # restarts at 0 every epoch, so using it as the step would make
        # different epochs write to the same step values.
        self._train_batches_logged = 0
        self._val_batches_logged = 0

    def forward(self, pixel_values, pixel_mask):
        """Run the wrapped model without labels and return its output."""
        return self.model(pixel_values=pixel_values, pixel_mask=pixel_mask)

    def common_step(self, batch, batch_idx):
        """Run the model with labels and return ``(loss, loss_dict)``.

        ``batch`` is the dict produced by ``collate_fn``: ``pixel_values``,
        ``pixel_mask`` and a list of per-image label dicts.
        """
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        # The labels are a plain list of dicts, so each tensor is moved to
        # the module's device here.
        labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

        outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

        return outputs.loss, outputs.loss_dict

    def _log_losses(self, loss_name, component_prefix, loss, loss_dict, step):
        """Send the total loss and its components to the logger in one call."""
        if self.logger is None:
            # No logger attached to the trainer: nothing to record.
            return
        # Convert every value to a Python float so the total loss is logged
        # the same way as its components.
        metrics = {loss_name: loss.item()}
        for name, value in loss_dict.items():
            metrics[component_prefix + name] = value.item()
        self.logger.log_metrics(metrics, step)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        self._log_losses("training_loss", "train_", loss, loss_dict, self._train_batches_logged)
        self._train_batches_logged += 1

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        self._log_losses("validation/loss", "validation_", loss, loss_dict, self._val_batches_logged)
        self._val_batches_logged += 1

        return loss

    def configure_optimizers(self):
        """Return AdamW with a separate learning rate for the backbone."""
        # DETR authors decided to use different learning rate for backbone
        # you can learn more about it here:
        # - https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/main.py#L22-L23
        # - https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/main.py#L131-L139
        param_dicts = [
            {
                # Uses the optimizer-wide default ``lr``.
                "params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad],
            },
            {
                "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]
        return torch.optim.AdamW(param_dicts, lr=self.lr, weight_decay=self.weight_decay)

    def train_dataloader(self):
        """Return the module-level training dataloader."""
        return TRAIN_DATALOADER

    def val_dataloader(self):
        """Return the module-level validation dataloader."""
        return VAL_DATALOADER

In [None]:
model = Detr(lr=1e-4, lr_backbone=1e-5, weight_decay=1e-4)

# settings
MAX_EPOCHS = 10

# Switch on MLflow autologging before the Lightning logger is constructed.
mlflow.pytorch.autolog()

mlf_logger = MLFlowLogger(
    tracking_uri="databricks",
    experiment_name=mlflow_experiment,
    checkpoint_path_prefix="brian_testing",
)

# Trainer arguments in the Lightning >= 2.0 form. Lightning < 2.0 selected the
# device with ``gpus=1`` instead of ``devices`` / ``accelerator``.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=MAX_EPOCHS,
    gradient_clip_val=0.1,
    accumulate_grad_batches=8,
    log_every_n_steps=5,
    logger=mlf_logger,
)

# The model provides its own train/val dataloaders, so none are passed here.
trainer.fit(model)