# OwlV2 finetuning

## Install dependencies

In [1]:
! pip install transformers datasets evaluate accelerate orjson albumentations

Collecting albumentations
  Downloading albumentations-1.4.7-py3-none-any.whl (155 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m155.7/155.7 KB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting opencv-python-headless>=4.9.0
  Downloading opencv_python_headless-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (49.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.6/49.6 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting pydantic>=2.7.0
  Downloading pydantic-2.7.1-py3-none-any.whl (409 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m409.3/409.3 KB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting annotated-types>=0.4.0
  Downloading annotated_types-0.6.0-py3-none-any.whl (12 kB)
Collecting pydantic-core==2.18.2
  Downloading pydantic_core-2.18.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━

## Prepare dataset

In [None]:
! cp /host/train_old.zip data/train.zip
! cp /host/val_old.zip data/val.zip
! mkdir data/train data/val
! 7z e data/train.zip -odata/train
! 7z e data/val.zip -odata/val

In [10]:
from glob import glob
import orjson

ann_dict = {}

with open("data/vlm.jsonl") as ann_file:
  for line in ann_file:
    # {"image": "image_0.jpg", "annotations": [{"caption": "grey missile", "bbox": [912, 164, 48, 152]}, {"caption": "red, white, and blue light aircraft", "bbox": [1032, 80, 24, 28]}, {"caption": "green and black missile", "bbox": [704, 508, 76, 64]}, {"caption": "white and red helicopter", "bbox": [524, 116, 112, 48]}]}
    x = orjson.loads(line)
    ann_dict[x["image"]] = x

with open("data/train.jsonl", "wb") as train_ann_file:
  for train_path in glob("data/train/*"):
    train_ann_file.write(orjson.dumps(ann_dict[train_path.split("/")[-1]]) + b"\n")

with open("data/val.jsonl", "wb") as val_ann_file:
  for val_path in glob("data/val/*"):
    val_ann_file.write(orjson.dumps(ann_dict[val_path.split("/")[-1]]) + b"\n")

### Yield dataset in CPPE5 format

In [1]:
from datasets import Dataset
import orjson
from PIL import Image
from tqdm import tqdm

def get_split(split):
  ds_list = []
  with open(f"data/{split}.jsonl") as ann_file:
    for i, line in tqdm(enumerate(ann_file)):
      # {"image": "image_0.jpg", "annotations": [{"caption": "grey missile", "bbox": [912, 164, 48, 152]}, {"caption": "red, white, and blue light aircraft", "bbox": [1032, 80, 24, 28]}, {"caption": "green and black missile", "bbox": [704, 508, 76, 64]}, {"caption": "white and red helicopter", "bbox": [524, 116, 112, 48]}]}
      x = orjson.loads(line)
      anns = x["annotations"]
      img = Image.open(f"data/{split}/{x['image']}")
      # img.load() # bypass PIL lazy loading
      ds_list.append({
        "image_id": int(x["image"][6:-4]),
        # "image_id": i,
        "image": img,
        "width": img.width,
        "height": img.height,
        "objects": {
          # "id" key not used
          "area": [ann["bbox"][2] * ann["bbox"][3] for ann in anns],
          "bbox": [ann["bbox"] for ann in anns],
          # TODO: categories aren't fixed. How to supply text?
          "caption": [ann["caption"] for ann in anns],
        },
      })

  return ds_list

train_ds = Dataset.from_list(get_split("train"))
val_ds = Dataset.from_list(get_split("val"))

  from .autonotebook import tqdm as notebook_tqdm
4086it [00:00, 4869.16it/s]
1021it [00:00, 6404.15it/s]


In [18]:
import torch

o = [{'bbox': (764.0, 124.0, 60.0, 36.0), 'caption': 'black and white commercial aircraft'}, {'bbox': (888.0000000000001, 516.0, 47.999999999999886, 64.0), 'caption': 'green and white fighter plane'}, {'bbox': (804.0, 272.0, 48.000000000000114, 48.0), 'caption': 'white and blue helicopter'}, {'bbox': (712.0, 420.0, 60.0, 63.99999999999994), 'caption': 'grey and black fighter plane'}, {'bbox': (367.99999999999994, 228.00000000000003, 60.00000000000006, 35.99999999999997), 'caption': 'white and black helicopter'}]
torch.tensor([[obj["bbox"] for obj in o] for x in batch_list])

tensor([[0, 0, 0, 0],
        [0, 0, 0, 0]])

In [2]:
train_ds[0:2]

{'image_id': [3778, 1752],
 'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1520x870>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1520x870>],
 'width': [1520, 1520],
 'height': [870, 870],
 'objects': [{'area': [3168, 2464, 1584],
   'bbox': [[876, 140, 88, 36], [608, 456, 44, 56], [1236, 104, 44, 36]],
   'caption': ['yellow, black, and red helicopter',
    'green and white fighter plane',
    'yellow fighter jet']},
  {'area': [4576, 2816, 1792, 3584, 2688, 4800, 1008, 4560],
   'bbox': [[1028, 300, 88, 52],
    [1264, 64, 64, 44],
    [604, 196, 56, 32],
    [1208, 312, 128, 28],
    [716, 392, 56, 48],
    [1064, 48, 80, 60],
    [848, 312, 36, 28],
    [928, 444, 76, 60]],
   'caption': ['white, blue, and red commercial aircraft',
    'black and orange drone',
    'white, black, and red drone',
    'green and grey helicopter',
    'black cargo aircraft',
    'green missile',
    'yellow fighter plane',
    'white and red commercial aircraft']}]}

Below code yields the dataset as a generator, which is slower for iteration.

In [45]:
from datasets import Dataset
from PIL import Image
import orjson

def get_ds_generator(split):
  def ds_generator():
    with open(f"data/{split}.jsonl") as ann_file:
      for i, line in enumerate(ann_file):
        # {"image": "image_0.jpg", "annotations": [{"caption": "grey missile", "bbox": [912, 164, 48, 152]}, {"caption": "red, white, and blue light aircraft", "bbox": [1032, 80, 24, 28]}, {"caption": "green and black missile", "bbox": [704, 508, 76, 64]}, {"caption": "white and red helicopter", "bbox": [524, 116, 112, 48]}]}
        x = orjson.loads(line)
        anns = x["annotations"]
        img = Image.open(f"data/{split}/{x['image']}")
        yield {
          "image_id": int(x["image"][6:-4]),
          # "image_id": i,
          "image": img,
          "width": img.width,
          "height": img.height,
          "objects": {
            # "id" key not used
            "area": [ann["bbox"][2] * ann["bbox"][3] for ann in anns],
            "bbox": [ann["bbox"] for ann in anns],
            # TODO: categories aren't fixed. How to supply text?
            "caption": [ann["caption"] for ann in anns],
          },
        }
  
  return ds_generator

train_ds = Dataset.from_generator(get_ds_generator("train"))
val_ds = Dataset.from_generator(get_ds_generator("val"))

In [11]:
train_ds[0]

{'image_id': 3778,
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1520x870>,
 'width': 1520,
 'height': 870,
 'objects': {'area': [3168, 2464, 1584],
  'bbox': [[876, 140, 88, 36], [608, 456, 44, 56], [1236, 104, 44, 36]],
  'caption': ['yellow, black, and red helicopter',
   'green and white fighter plane',
   'yellow fighter jet']}}

### Transforms

In [2]:
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

checkpoint = "google/owlv2-large-patch14-ensemble"

model = AutoModelForZeroShotObjectDetection.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

In [3]:
import albumentations

transforms = albumentations.Compose(
  transforms=[
    # How noisy will their test images be? var = 2500 is extreme!
    albumentations.GaussNoise(var_limit=2500, p=0.5),
    albumentations.HorizontalFlip(p=0.5),
    albumentations.VerticalFlip(p=0.5),
    albumentations.RandomBrightnessContrast(
      p=0.5,
      brightness_limit=0.3,
      contrast_limit=0.3,
    ),
    # MixUp / Mosaic?
  ],
  bbox_params=albumentations.BboxParams(format="coco", label_fields=[]),
)

In [10]:
import torch

def transform_sample(batch):
  # batch is a Dict[str, list], NOT a list!
  # TODO: add augmentation

  # batch_transformed = {
  #   "input_ids": [],
  #   "attention_mask": [],
  #   "pixel_values": [],
  #   "labels": [],
  # }

  # for obj, img in zip(batch["objects"], batch["image"]):
  #   sample_transformed = processor(
  #     text=obj["caption"],
  #     images=img,
  #     return_tensors="pt",
  #   )
  #   batch_transformed["input_ids"].append(sample_transformed["input_ids"])
  #   batch_transformed["attention_mask"].append(sample_transformed["attention_mask"])
  #   batch_transformed["pixel_values"].append(sample_transformed["pixel_values"])
  #   batch_transformed["labels"].append(torch.tensor(obj["bbox"]))


  batch_transformed = processor(
    text=[obj["caption"] for obj in batch["objects"]],
    images=batch["image"],
    return_tensors="pt",
  )
  batch_transformed["labels"] = [torch.tensor(obj["bbox"]) for obj in batch["objects"]]
  
  return batch_transformed

def collate_fn(batch_list):
  batch = {}

  batch["input_ids"] = torch.cat([x["input_ids"] for x in batch_list])
  batch["attention_mask"] = torch.cat([x["attention_mask"] for x in batch_list])
  batch["pixel_values"] = torch.cat([x["pixel_values"] for x in batch_list])
  # batch["labels"] = torch.nn.utils.rnn.pad_sequence(
  #   sequences=[x["labels"] for x in batch_list],
  #   batch_first=True,
  #   padding_value=-1.,
  # )
  batch["labels"] = torch.cat([x["labels"] for x in batch_list])
  
  # print(batch)
  return batch

In [12]:
train_ds = train_ds.with_transform(transform_sample)
val_ds = val_ds.with_transform(transform_sample)

In [15]:
train_ds[0:2]

{'input_ids': tensor([[49406,  1579,  9397, 49407,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [49406,  1579,   537,  4287,  1395,  7706, 49407,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [49406,  5046, 11956, 49407,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [49406,   736,   537,  1579,  6438,  6565, 49407,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [49406,  1579,   537,  1746,  6438,  5363, 49407,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [49406,  5046,   537,   736,  6287,  7706, 49407,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [49406,  1579,  1395,  7706, 49407,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [49406,  1579,   537,   736,  6287,  7706, 49407,     0,     

In [17]:
from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=2)
next(iter(train_dl))

IndexError: index 2 is out of bounds for dimension 0 with size 2

## Utility functions for calculating loss

In [140]:
import logging

logging.basicConfig(
    level=logging.WARNING,  # Adjust as needed (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Using Detr-Loss calculation https://github.com/facebookresearch/detr/blob/main/models/matcher.py
# https://www.kaggle.com/code/bibhasmondal96/detr-from-scratch
class BoxUtils(object):
    @staticmethod
    def box_cxcywh_to_xyxy(x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
             (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=-1)

    @staticmethod
    def box_xyxy_to_cxcywh(x):
        x0, y0, x1, y1 = x.unbind(-1)
        b = [(x0 + x1) / 2, (y0 + y1) / 2,
             (x1 - x0), (y1 - y0)]
        return torch.stack(b, dim=-1)

    @staticmethod
    def rescale_bboxes(out_bbox, size):
        img_h, img_w = size
        b = BoxUtils.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b

    @staticmethod
    def box_area(boxes):
        """
        Computes the area of a set of bounding boxes, which are specified by its
        (x1, y1, x2, y2) coordinates.
        Arguments:
            boxes (Tensor[N, 4]): boxes for which the area will be computed. They
                are expected to be in (x1, y1, x2, y2) format
        Returns:
            area (Tensor[N]): area for each box
        """
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        
    @staticmethod
    # modified from torchvision to also return the union
    def box_iou(boxes1, boxes2):
        area1 = BoxUtils.box_area(boxes1)
        area2 = BoxUtils.box_area(boxes2)

        lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
        rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

        wh = (rb - lt).clamp(min=0)  # [N,M,2]
        inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

        union = area1[:, None] + area2 - inter

        iou = inter / union
        return iou, union

    @staticmethod
    def generalized_box_iou(boxes1, boxes2):
        """
        Generalized IoU from https://giou.stanford.edu/
        The boxes should be in [x0, y0, x1, y1] format
        Returns a [N, M] pairwise matrix, where N = len(boxes1)
        and M = len(boxes2)
        """
        # degenerate boxes gives inf / nan results
        # so do an early check
        assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
        assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
        iou, union = BoxUtils.box_iou(boxes1, boxes2)

        lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
        rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

        wh = (rb - lt).clamp(min=0)  # [N,M,2]
        area = wh[:, :, 0] * wh[:, :, 1]

        return iou - (area - union) / area

In [141]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.optimize import linear_sum_assignment

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        print("HungarianMatcher.forward() called!")
        print("outputs=", outputs)
        print("targets=", targets)

        logging.info(f"{outputs.keys()=}")
        bs, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["class_labels"] for v in targets])
        logging.info(f"forward - {tgt_ids}")
        tgt_ids = tgt_ids.int()
        logging.info(f"forward - {tgt_ids}")


        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be ommitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost betwen boxes
        cost_giou = -BoxUtils.generalized_box_iou(
            BoxUtils.box_cxcywh_to_xyxy(out_bbox),
            BoxUtils.box_cxcywh_to_xyxy(tgt_bbox)
        )

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        logging.info(f"loss_labels - {outputs.keys()}")
        assert 'logits' in outputs
        src_logits = outputs['logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)]).to(torch.int64)
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device).to(torch.int64)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(BoxUtils.generalized_box_iou(
            BoxUtils.box_cxcywh_to_xyxy(src_boxes),
            BoxUtils.box_cxcywh_to_xyxy(target_boxes))
        )
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        logging.info(f"{type(outputs)=}")
        logging.info(f"{type(targets)=}")
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
        return losses

## Training config

In [138]:
import torch
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

# TODO: scale params based on batch size
optimizer = torch.optim.AdamW(
  params=model.parameters(),
  lr=1e-5,
  betas=(0.9, 0.999),
  eps=1e-8,
  weight_decay=1e-2,
)

scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
  optimizer=optimizer,
  num_warmup_steps=100,
  num_training_steps=10000,
  num_cycles=10,
  last_epoch=-1,
)

In [161]:
from accelerate import Accelerator

accelerator = Accelerator()

model, optimizer, scheduler = accelerator.prepare(
  model, optimizer, scheduler
)

In [7]:
from transformers import Trainer

class CustomTrainer(Trainer):
  def compute_loss(self, model, inputs, return_outputs=False):
    labels = inputs.pop("labels")
    outputs = model(**inputs)
    
    

In [9]:
from transformers import TrainingArguments, Trainer
import numpy as np

training_args = TrainingArguments(
  output_dir="owlv2-large-patch14-ensemble",
  num_train_epochs=30,
  learning_rate=5e-6,
  # lr_scheduler_type=schedule,
  eval_strategy="epoch",
  # auto_find_batch_size=True,
  # TODO: adjust batch size
  per_device_train_batch_size=1,
  per_device_eval_batch_size=1,
  save_strategy="epoch",
  bf16=True,
  dataloader_num_workers=8,
  remove_unused_columns=False,

  # stupid workaround
  label_smoothing_factor=np.nextafter(0, 1),
)

trainer = CustomTrainer(
  model=model,
  args=training_args,
  train_dataset=train_ds,
  eval_dataset=val_ds,
  data_collator=collate_fn,
)

trainer.train()

OutOfMemoryError: CUDA out of memory. Tried to allocate 22.00 MiB. GPU  has a total capacity of 23.48 GiB of which 22.06 MiB is free. Including non-PyTorch memory, this process has 0 bytes memory in use. Of the allocated memory 1.50 GiB is allocated by PyTorch, and 133.85 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [139]:
from transformers import Trainer

def custom_loss(logits, labels):
  num_classes = 4 # what the fuck to do with this???
  matcher = HungarianMatcher(cost_class = 1, cost_bbox = 5, cost_giou = 2)
  weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}
  losses = ['labels', 'boxes', 'cardinality']
  criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)
  criterion.to(accelerator.device)
  loss = criterion(logits, labels)
  return loss

class CustomTrainer(Trainer):
  def compute_loss(self, model, inputs, return_outputs=False):
    labels = inputs.pop("labels")

    inputs["input_ids"] = inputs["input_ids"][0]
    inputs["attention_mask"] = inputs["attention_mask"][0]
    outputs = model(**inputs, return_dict=True)
    loss = custom_loss(outputs, labels)
    loss_ce = loss['loss_ce'].cpu().item()
    loss_bbox = loss['loss_bbox'].cpu().item()
    loss_giou = loss['loss_giou'].cpu().item()
    cardinality_error = loss['cardinality_error'].cpu().item()
    # print(
    #     f"loss_ce={loss_ce:.2f}",
    #     f"loss_bbox={loss_bbox:.2f}",
    #     f"loss_giou={loss_giou:.2f}",
    #     f"cardinality_error={cardinality_error:.2f}",
    #     sep="\t")
    loss = sum(loss.values())[0] #add
    return (loss, outputs) if return_outputs else loss

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="owlv2-large-patch14-ensemble",
  num_train_epochs=30,
  learning_rate=5e-6,
  # lr_scheduler_type=schedule,
  eval_strategy="epoch",
  # auto_find_batch_size=True,
  # TODO: adjust batch size
  per_device_train_batch_size=1,
  per_device_eval_batch_size=1,
  save_strategy="epoch",
  bf16=True,
  dataloader_num_workers=32,
  remove_unused_columns=False,
)

trainer = CustomTrainer(
  model=model,
  args=training_args,
  data_collator=collate_fn,
  train_dataset=train_ds,
  eval_dataset=train_ds,
  tokenizer=processor,
)

trainer.train()

## Train

In [None]:
trainer.train()