# 🦉 OWL-ViT inference playground

OWL-ViT is an **open-vocabulary object detector**. Given a free-text query, it will find objects matching that query. It can also do **one-shot object detection**, i.e. detect objects based on a single example image.

This Colab allows you to query the model interactively, to get a feeling for its capabilities. For details on the model, check out the [paper](https://arxiv.org/abs/2205.06230) or the [code](https://github.com/google-research/scenic/tree/main/scenic/projects/owl_vit).

> ❗ Note: The free public Colab runtime has enough memory for the ViT-B/16 model. For optimal results, use a Pro or local runtime and the ViT-L/14 model.

> ❗ Note: This Colab is optimized for fast interactive exploration. It does not apply some of the optimizations and augmentations that would be used in a rigorous evaluation setting, so results from this Colab may not match the paper.

## How to use this Colab
1. Use a GPU or TPU Colab runtime.
2. Run all cells in the Colab from top to bottom.
3. Go to the cells for [Text-conditioned object detection](#scrollTo=aNzcyP1sbJ9w&uniqifier=1) or [Image-conditioned object detection](#scrollTo=TFlZhrDTQbiY&uniqifier=1) and have fun!

**If you run into any problems, please [file an issue](https://github.com/google-research/scenic/issues/new?title=OWL-ViT+inference+playround:+[add+title]) on GitHub.**



# Setup

OWL-ViT is implemented in [Scenic](https://github.com/google-research/scenic). The cell below installs the Scenic codebase from GitHub and imports it.

In [None]:
# Flush and unmount Google Drive before the setup cells run.
# NOTE(review): presumably this protects a mounted Drive from the `rm -rf *`
# in the clone cell further down - confirm.
from google.colab import drive
drive.flush_and_unmount()

In [None]:
# colab workspaces dont have a lot of fonts installed, and this one is pretty nice
!curl -L -O https://github.com/source-foundry/Hack/releases/download/v3.003/Hack-v3.003-ttf.tar.gz
!tar -xzvf Hack-v3.003-ttf.tar.gz
!mkdir /usr/share/fonts/truetype/ttf
!mv ttf/Hack-Regular.ttf /usr/share/fonts/truetype/ttf/Hack-Regular.ttf

In [None]:
import os

# Guard cell: the setup cells below wipe the working directory, so refuse to
# continue anywhere but on Colab (detected through the COLAB_GPU variable).
running_on_colab = 'COLAB_GPU' in os.environ
if not running_on_colab:
  print("Running this notebook on your own pc is dangerous due to the setup of scenic (which i didnt mess with)... Use colab or a docker container[if using docker, remove these few lines]")
  exit()
else:
  print("I'm running on Colab")

In [None]:
# One-time environment setup: clone Scenic into the current directory, install
# it together with the OWL-ViT requirements, then fetch big_vision.
# The `Cloned` global acts as a "done" marker so re-running the cell is a no-op.
if 'Cloned' not in globals(): # Only clone once (this is time-consuming, and unnecessary if already done)
  # WARNING: destructive - everything in the current working directory is deleted
  # so that `git clone ... .` can target an empty directory.
  !rm -rf *
  !rm -rf .config
  !rm -rf .git
  !git clone https://github.com/google-research/scenic.git .
  !python -m pip install -q .
  !python -m pip install -r ./scenic/projects/owl_vit/requirements.txt

  # Also install big_vision, which is needed for the mask head:
  !mkdir /big_vision
  !git clone https://github.com/google-research/big_vision.git /big_vision
  !python -m pip install -r /big_vision/big_vision/requirements.txt
  # big_vision is not pip-installed as a package; make it importable via sys.path instead.
  import sys
  sys.path.append('/big_vision/')
  !echo "Done."
  Cloned = True

In [None]:
# brambox provides the annotation/detection I/O (bb.io) and the PR/AP/F1 statistics (bb.stat) used below.
!python -m pip install brambox

In [None]:
# Mount Google Drive; the data, output and save folders configured below live under ./drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Input locations on the mounted Drive.
# Annotation files are loaded with brambox's 'anno_pascalvoc' parser; the matching
# images are looked up in images_folder_path as "<image name>.jpg".
annotations_folder_path: str = "./drive/MyDrive/AI/Data/Annotations/"
images_folder_path: str = "./drive/MyDrive/AI/Data/JPEGImages/"
# NOTE(review): not referenced anywhere in the visible part of this notebook.
fixed_images_path: str = "./drive/MyDrive/AI/Data/Fixed_support/"

# Base folders for rendered images/plots (Outputs) and pickled detections (Saves).
# Run-specific sub-paths are derived from these in the "Set run-type" cell.
output_path_base: str = "./drive/MyDrive/AI/Outputs/"
saves_path_base: str= "./drive/MyDrive/AI/Saves/"

In [None]:
import os

from bokeh import io as bokeh_io
import jax
from google.colab import output as colab_output
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageFont, ImageDraw
from scenic.projects.owl_vit import models
from scenic.projects.owl_vit.configs import clip_l14 as config_module
from scenic.projects.owl_vit.notebooks import inference
from scenic.projects.owl_vit.notebooks import interactive
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import tensorflow as tf

import brambox as bb
import itertools

# Hide all GPUs from TensorFlow.
# NOTE(review): presumably so TF does not reserve accelerator memory that the JAX model needs - confirm.
tf.config.experimental.set_visible_devices([], 'GPU')
# Send bokeh output to the notebook, without the "BokehJS loaded" banner.
bokeh_io.output_notebook(hide_banner=True)

In [None]:
class BBAnnotatedImage:
    """A PIL image bundled with its bounding-box annotations.

    The annotations are a DataFrame with (at least) the columns
    'x_top_left', 'y_top_left', 'width', 'height' and 'class_label'.
    On construction the image is normalised to landscape orientation and
    downscaled, and the annotation boxes are transformed accordingly.
    """

    # Font files tried in order when rendering text. "arial.ttf" is usually not
    # installed on Colab, so fall back to the Hack font this notebook installs.
    _FONT_CANDIDATES = ("arial.ttf", "Hack-Regular.ttf")

    def __init__(self, *, img_dir = "", filename= "", annotations:pd.DataFrame, scale=1) -> None:
        """Load `filename` from `img_dir` and attach `annotations`.

        Args:
            img_dir: directory that contains the image file.
            filename: image file name (required).
            annotations: bounding boxes belonging to this image.
            scale: integer downscale factor applied (with floor division) to
                the image size and to every box coordinate.

        Raises:
            ValueError: if `filename` is empty.
            FileNotFoundError: if the image file does not exist.
        """
        if not filename:
            raise ValueError("filename cannot be empty")
        image_path = os.path.join(img_dir, filename)
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"image file {filename} not found in {img_dir}")
        self.filename: str = filename
        self.image: Image.Image = Image.open(image_path)
        self.annotations: pd.DataFrame = annotations

        # If height > width, rotate image and annotations so every image is landscape.
        if self.height > self.width:
            original_width = self.width
            self.image = self.image.rotate(90, expand=True)

            def rotate_box(annotation):
                # Old y becomes new x; new y is measured from the old right edge.
                annotation['x_top_left'], annotation['y_top_left'] = annotation['y_top_left'], original_width - annotation['x_top_left'] - annotation['width']
                annotation['width'], annotation['height'] = annotation['height'], annotation['width']

            self._remap_annotations(rotate_box)

        # Downscale the image and the boxes by the same integer factor.
        self.image = self.image.resize((self.width // scale, self.height // scale))

        def shrink_box(annotation):
            annotation['x_top_left'], annotation['y_top_left'] = annotation['x_top_left'] // scale, annotation['y_top_left'] // scale
            annotation['width'], annotation['height'] = annotation['width'] // scale, annotation['height'] // scale

        self._remap_annotations(shrink_box)

    def _remap_annotations(self, remap) -> None:
        """Rebuild `self.annotations` after passing every row through `remap` (which edits the row in place)."""
        # Start from an empty frame so the column layout is kept even without rows,
        # and concatenate once instead of once per row.
        frames = [pd.DataFrame(columns = self.annotations.columns)]
        for _, annotation in self.annotations.iterrows():
            remap(annotation)
            frames.append(annotation.to_frame().T)
        self.annotations = pd.concat(frames)

    @staticmethod
    def _load_font(size):
        """Return the first available candidate font at `size`, or PIL's default font if none can be opened."""
        for font_name in BBAnnotatedImage._FONT_CANDIDATES:
            try:
                return ImageFont.truetype(font_name, size)
            except OSError:
                continue
        return ImageFont.load_default()

    @property
    def width(self) -> int:
        """Width of the (possibly rotated and rescaled) image in pixels."""
        return self.image.width

    @property
    def height(self) -> int:
        """Height of the (possibly rotated and rescaled) image in pixels."""
        return self.image.height

    @staticmethod
    def from_xml(xml_file:str, img_dir:str = "", scale=1) -> 'BBAnnotatedImage':
        """Build an instance from a Pascal VOC annotation file.

        The image file name is taken from the 'image' value of the first
        annotation, with '.jpg' appended.
        """
        annotations:pd.DataFrame = bb.io.load('anno_pascalvoc', xml_file)
        filename = annotations.iloc[[0]]['image'].values[0] + '.jpg'
        return BBAnnotatedImage(filename = filename, img_dir = img_dir, annotations = annotations, scale=scale)

    def get_object_image_cutout(self, annotation_row:pd.Series) -> Image.Image:
        """Return the image region covered by the box in `annotation_row`."""
        left = annotation_row['x_top_left']
        top = annotation_row['y_top_left']
        return self.image.crop((left, top, left + annotation_row['width'], top + annotation_row['height']))

    @property
    def annotation_classes(self) -> list[str]:
        """Class label of every annotation, in row order."""
        return self.annotations['class_label'].tolist()

    def drawn_annotations(self, resize=10) -> Image.Image:
        """Return a copy of the image with all boxes, labels and the file name drawn in red.

        The result is shrunk by the factor `resize` after drawing.
        """
        img = self.image.copy()
        draw = ImageDraw.Draw(img)
        # Load the font once instead of once per annotation.
        font = self._load_font(150)
        for _, annotation in self.annotations.iterrows():
            left, top = annotation['x_top_left'], annotation['y_top_left']
            draw.rectangle((left, top, left + annotation['width'], top + annotation['height']), outline='red', width=20)
            draw.text((left, top), annotation['class_label'], fill="red", font=font)
        draw.text((10, 10), self.filename, fill="red", font=font)
        img = img.resize((int(img.width / resize), int(img.height / resize)))
        return img

    @property
    def object_image_cutouts(self) -> list[Image.Image]:
        """Cropped image of every annotated object, in row order."""
        return [self.get_object_image_cutout(annotation) for _, annotation in self.annotations.iterrows()]

    @property
    def objects(self) -> list[tuple[str, Image.Image]]:
        """(class label, cropped image) pairs for every annotated object."""
        return list(zip(self.annotation_classes, self.object_image_cutouts))

    @property
    def npimage(self):
        """The image as a numpy array."""
        return np.asarray(self.image)

In [None]:
def create_folder_if_not_exists(folder_name):
    """Create `folder_name`, including missing parent folders; do nothing if it already exists.

    Uses `exist_ok=True` instead of a separate existence check, so there is no
    window in which the folder can appear between the check and the creation.

    Raises:
        FileExistsError: if `folder_name` exists but is not a directory.
    """
    os.makedirs(folder_name, exist_ok=True)

In [None]:
def bb_load_dataset(annotations_folder:str, images_folder:str, area_column:bool=True, scale = 1) -> list[BBAnnotatedImage]:
    """Load every annotation file in `annotations_folder` as a BBAnnotatedImage.

    Args:
        annotations_folder: folder whose (non-directory) entries are all
            treated as Pascal VOC annotation files.
        images_folder: folder that holds the annotated images.
        area_column: when True, add an 'area' column (width * height) to the
            annotations of every image.
        scale: downscale factor forwarded to BBAnnotatedImage.

    Returns:
        One BBAnnotatedImage per annotation file, ordered by file name.

    Raises:
        FileNotFoundError: if `annotations_folder` does not exist.
    """
    with os.scandir(annotations_folder) as entries:
        # Directory listing order is arbitrary; sort so that index-based access
        # and the query/test split derived from this list are reproducible.
        annotation_files = sorted(entry.name for entry in entries if not entry.is_dir())

    annotated_images:list[BBAnnotatedImage] = [
        # os.path.join works whether or not the folder path ends with a separator.
        BBAnnotatedImage.from_xml(os.path.join(annotations_folder, file_name), images_folder, scale=scale)
        for file_name in annotation_files
    ]
    if area_column:
        for annotated_image in annotated_images:
            boxes = annotated_image.annotations
            boxes['area'] = boxes['width'] * boxes['height']
    return annotated_images

dataset:list[BBAnnotatedImage] = bb_load_dataset(annotations_folder_path, images_folder_path, scale=4)

In [None]:
def filter_by_confidence(dataframe, minimum_score:float):
  """Return the rows whose 'confidence' is strictly greater than `minimum_score`."""
  confident_enough = dataframe['confidence'] > minimum_score
  return dataframe.loc[confident_enough]

def get_detections_for_image(detections, image_name):
    """Return the rows of `detections` whose 'image' value equals `image_name`."""
    belongs_to_image = detections['image'] == image_name
    return detections.loc[belongs_to_image]

def filter_by_iou(dataframe, iou:float):
  """Return the rows whose 'iou' is strictly smaller than `iou`."""
  below_threshold = dataframe['iou'] < iou
  return dataframe.loc[below_threshold]

In [None]:
def draw_boxes(image:BBAnnotatedImage, data:pd.DataFrame, alligntop=True, draw_legend=True, line=True, manual_image:Image=None, forced_color = False):
    """Draw the boxes in `data` that belong to `image` and return the result as an RGB image.

    Args:
        image: annotated image; its file name (up to the first '.') selects the
            rows of `data` to draw, and its picture is the canvas unless
            `manual_image` is given.
        data: boxes with 'image', 'class_label', 'x_top_left', 'y_top_left',
            'width' and 'height' columns, in pixel coordinates.
        alligntop: put the legend in the top-left corner (True) or the
            bottom-left corner (False).
        draw_legend: whether to draw the class/colour legend.
        line: draw box outlines (True) or translucent filled boxes (False).
        manual_image: optional canvas to draw on instead of `image.image`;
            it is not modified.
        forced_color: draw all outlines in black (only used when `line` is True).

    Returns:
        A new RGB image with the boxes (and legend) drawn on it.

    Note:
        Colours are assigned per class from the notebook-global `annotations`
        frame, so every call uses the same class-to-colour mapping.
    """
    def load_font(font_names, size):
        # Try the given font files in order; fall back to PIL's default font
        # when none of them can be opened (fonts are sparse on Colab).
        for font_name in font_names:
            try:
                return ImageFont.truetype(font_name, size)
            except OSError:
                continue
        return ImageFont.load_default()

    def box_corners(row):
        return [row['x_top_left'], row['y_top_left'], (row['x_top_left'] + row['width']), (row['y_top_left'] + row['height'])]

    image_name = image.filename.split('.')[0]
    boxed = image.image.copy() if manual_image is None else manual_image
    # convert() returns a new image, so a caller-supplied canvas stays untouched.
    boxed = boxed.convert("RGBA")
    data = get_detections_for_image(data, image_name)
    default_colors = [
        (31, 119, 180),
        (255, 127, 14),
        (44, 160, 44),
        (214, 39, 40),
        (148, 103, 189),
        (140, 86, 75),
        (227, 119, 194),
        (127, 127, 127),
        (188, 189, 34),
        (23, 190, 207),
    ]
    # Cycle through the palette so that more than ten classes still get a colour
    # instead of raising KeyError further down.
    colors_assigned = dict(zip(annotations['class_label'].unique(), itertools.cycle(default_colors)))

    if line:
        draw = ImageDraw.Draw(boxed)
        for _, row in data.iterrows():
            color = (0, 0, 0) if forced_color else colors_assigned[row['class_label']]
            draw.rectangle(box_corners(row), outline=color, width=5)
    else:
        for _, row in data.iterrows():
            color = colors_assigned[row['class_label']]
            # Each box is painted on its own transparent layer and blended in,
            # which keeps the underlying picture visible through the fill.
            overlay = Image.new('RGBA', boxed.size, (0,0,0,0))
            ImageDraw.Draw(overlay).rectangle(box_corners(row), fill=color+(102,))
            boxed = Image.alpha_composite(boxed, overlay)

    if draw_legend:
        # Create the drawing context only now: alpha_composite above returns a
        # new image, so a context created earlier would draw on a discarded one.
        draw = ImageDraw.Draw(boxed)
        legend_x = 10
        if alligntop:
            # legend in top left corner
            legend_y = 10
            legend_size = image.height // 20
            font = load_font(("Hack-Regular.ttf", "arial.ttf"), legend_size)
            for class_label, color in colors_assigned.items():
                draw.rectangle([legend_x, legend_y, legend_x + legend_size, legend_y + legend_size], fill=color)
                draw.text((legend_x + legend_size + 10, legend_y), class_label, fill=color, font=font)
                legend_y += legend_size + 10
        else:
            # legend in bottom left corner, built upwards so the first class ends on top
            legend_y = boxed.height - 10
            legend_size = boxed.height // 20
            font = load_font(("arial.ttf", "Hack-Regular.ttf"), legend_size)
            for class_label, color in reversed(list(colors_assigned.items())):
                draw.rectangle([legend_x, legend_y - legend_size, legend_x + legend_size, legend_y], fill=color)
                draw.text((legend_x + legend_size + 10, legend_y - legend_size), class_label, fill=color, font=font)
                legend_y -= legend_size + 10
    boxed = boxed.convert("RGB")
    return boxed


In [None]:
def calc_iou(detections, image_boxes) -> None: #inplace
    """Fill the 'iou' column of `detections` for the boxes of one image, in place.

    For every box, 'iou' becomes the largest intersection-over-union it has
    with any overlapping box of the same image that "wins" against it, i.e.
    that has a higher confidence (on equal confidence the earlier row wins).
    Existing 'iou' values are only ever raised, never lowered.

    Args:
        detections: frame with 'x_top_left', 'y_top_left', 'width', 'height',
            'confidence' and 'iou' columns; modified in place.
        image_boxes: the rows of `detections` that belong to one image; only
            its index is used to select rows (labels are assumed unique).
    """
    indices = image_boxes.index.values.tolist()
    if not indices:
        return

    # Pull the columns out once; per-pair DataFrame lookups dominate the runtime otherwise.
    rows = detections.loc[indices]
    x1 = rows['x_top_left'].to_numpy(dtype=float)
    y1 = rows['y_top_left'].to_numpy(dtype=float)
    box_w = rows['width'].to_numpy(dtype=float)
    box_h = rows['height'].to_numpy(dtype=float)
    x2 = x1 + box_w
    y2 = y1 + box_h
    area = box_w * box_h
    conf = rows['confidence'].to_numpy(dtype=float)
    # Running maximum per box. Tracking it here (rather than re-reading a row
    # snapshot) guarantees a later, smaller overlap never replaces a larger one.
    best = rows['iou'].to_numpy(dtype=float).copy()

    for i in range(len(indices) - 1):
        rest = slice(i + 1, None)  # every box after i; each pair is visited once
        inter_w = np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest])
        inter_h = np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest])
        area_inter = inter_w * inter_h
        union = area[i] + area[rest] - area_inter
        # Boxes that do not touch are skipped; so are degenerate pairs with an
        # empty union, which would otherwise divide by zero.
        valid = (inter_w >= 0) & (inter_h >= 0) & (union > 0)
        iou = np.zeros_like(area_inter)
        np.divide(area_inter, union, out=iou, where=valid)

        other_wins = conf[rest] > conf[i]
        # Box i loses against higher-confidence boxes ...
        i_loses = valid & other_wins
        if i_loses.any():
            best[i] = max(best[i], iou[i_loses].max())
        # ... and every other overlapping box loses against box i.
        j_loses = valid & ~other_wins
        best[rest] = np.where(j_loses, np.maximum(best[rest], iou), best[rest])

    detections.loc[indices, 'iou'] = best


## Set up the model
This takes a minute or two.

In [None]:
# Build the OWL-ViT model from the clip_l14 config imported above as `config_module`
# and load the checkpoint the config points to.
config = config_module.get_config(init_mode='canonical_checkpoint')
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)
# `model` is used as a notebook-global by the embedding helpers and inference cells below.
model = inference.Model(config, module, variables)
model.warm_up()

In [None]:
#this uses the global model so yeah its here

def get_image_query_embedding(image:BBAnnotatedImage, annotation_index:int):
  """Embed one annotated object of `image` as an image query.

  Uses the notebook-global `model`. Returns a `(class_label, query_embedding)`
  pair for the annotation at position `annotation_index`.
  """
  annotation = image.annotations.iloc[annotation_index]

  # Every coordinate is divided by the width: width is the greatest side and
  # the image is padded internally.
  side = image.width
  top = annotation.y_top_left
  left = annotation.x_top_left
  bottom = top + annotation.height
  right = left + annotation.width
  normalized_box = (top / side, left / side, bottom / side, right / side)

  query_embeddings, _ = model.embed_image_query(image.npimage, normalized_box)
  return annotation.class_label, query_embeddings

def get_all_query_embeddings(image:BBAnnotatedImage) -> list[tuple[str, np.ndarray]]:
  """Return a `(class_label, query_embedding)` pair for every annotation of `image`, in row order."""
  return [
      get_image_query_embedding(image, annotation_index)
      for annotation_index in range(len(image.annotations))
  ]

def get_all_query_embeddings_stubs(image:BBAnnotatedImage):
  """Return a `(class_label, 0)` placeholder pair for every annotation of `image`.

  Same shape of result as `get_all_query_embeddings`, but without calling the
  model; the embedding is the integer 0.
  """
  # Read the label column once instead of materialising all rows per annotation.
  return [(class_label, 0) for class_label in image.annotation_classes]

# Set run-type

In [None]:
# True: query the model with the class names as text. False: query with
# embeddings of annotated object crops (image-conditioned detection).
use_text_queries = False
# True: build the image queries from stub embeddings instead of calling the model.
no_inference = True

# True: all class labels are collapsed into the single label 'object'.
classless = False

# True: add the extra text queries "rock" and "sand"; their detections are discarded afterwards.
use_negative_queries = False #text only

# Minimum number of annotations a class needs to be used, and the number of
# embeddings averaged per class (unless all generated embeddings are used).
N = 20 #image only
use_all_generated_image_embeddings = False  # we eliminate images we take our query embeddings from anyways, might aswell use them completely


# Path fragments that encode the run configuration.
class_string = "classless" if classless else "classed"
neg_string = "use_negative_queries" if use_negative_queries else "only_positive_queries"
text_string = "text" if use_text_queries else "img"
extra_img_string = f"full_{N}_" if (use_all_generated_image_embeddings) else f"{N}_"
extra_img_string = extra_img_string if not use_text_queries else ""

# NOTE: for image queries these paths end in e.g. "20_" without a trailing
# separator, so later cells that append to them produce names prefixed with it.
output_path = f"{output_path_base}/owlvit/{text_string}/{class_string}/{neg_string}/{extra_img_string}"
create_folder_if_not_exists(output_path)

saves_path = f"{saves_path_base}/owlvit/{text_string}/{class_string}/{neg_string}/{extra_img_string}"
create_folder_if_not_exists(saves_path)

### Set annotations, dataset and embeddings

In [None]:
# Ground-truth annotations of every image in the dataset. When image queries are
# used, this is narrowed down to the test split further below.
annotations:pd.DataFrame = pd.concat([image.annotations for image in dataset])

if use_text_queries:
  # Text queries: the class names themselves (plus generic shell terms when classless).
  if classless:
    positive_text_queries = ("shell", "seashell") + tuple(annotations.class_label.unique())
  else:
    positive_text_queries = tuple(annotations.class_label.unique())
  if use_negative_queries:
    negative_text_queries = ("rock", "sand")
  else:
    negative_text_queries = ()
  queries = negative_text_queries + positive_text_queries
  print(queries)
  query_embeddings = model.embed_text_queries(queries)

  test_dataset = dataset

else:
  # Image queries: some images are set aside to build one query embedding per
  # class; the remaining images form the test split.
  test_dataset = dataset.copy()
  print(test_dataset)
  query_dataset = set()

  # With no_inference the model is never called and placeholder embeddings are collected.
  embed_image_objects = get_all_query_embeddings_stubs if no_inference else get_all_query_embeddings

  # we take out the query images from which we make the query embeddings
  # if a class has less than N annotations we dont use it
  annotations_per_class = annotations['class_label'].value_counts()
  print(annotations_per_class)
  removed_classes = []
  queries_dict = {annotation_class: [] for annotation_class, _ in annotations_per_class.items()}
  for annotation_class, amount in reversed(list(annotations_per_class.items())): #loop over counts in reverse
    print(amount)
    if amount < N:
      print(f"Removing: {annotation_class} as it has less annotations than N({N})")
      removed_classes.append(annotation_class)
      continue

    # NOTE(review): `> N` asks for N + 1 embeddings although only N are averaged
    # below, so a class with exactly N annotations consumes all of its images.
    # Left unchanged because it alters the query/test split - confirm intent.
    if len(queries_dict[annotation_class]) > N:
      print(f"Skipping: {annotation_class} as it is already populated with more embeddings than N({N})")
      continue #we dont need to generate any more queries for classes that are already sufficiently populated

    print(f"Generating query embeddings for {annotation_class}")
    for image in test_dataset:
      # An image picked for an earlier class already contributed the embeddings
      # of all its objects; embedding it again would duplicate them (and the
      # inference work), so only images not yet in the query set are processed.
      if image not in query_dataset and annotation_class in list(image.annotations.class_label.values):
        query_dataset.add(image)
        for query_class, embedding in embed_image_objects(image):
          queries_dict[query_class].append(embedding)
      if len(queries_dict[annotation_class]) > N:
        print(f"Done: {annotation_class} is populated with more embeddings than N({N})")
        break

  # Query images must not be evaluated on: keep only the other images, in their original order.
  test_dataset = [image for image in test_dataset if image not in query_dataset]

  annotations = pd.concat([image.annotations for image in test_dataset])

  # Classes without enough annotations are neither queried nor evaluated.
  for c in removed_classes:
    annotations = annotations[annotations['class_label'] != c]
    queries_dict.pop(c)


  query_embeddings_list = []
  queries = []

  # One query per class: the mean of its collected embeddings.
  for key, value in queries_dict.items():
    print(key, len(value))
    queries.append(key)
    if use_all_generated_image_embeddings:
      query_embeddings_list.append(np.mean(value, axis=0))
    else:
      query_embeddings_list.append(np.mean(value[:N], axis=0))

  query_embeddings = np.array(query_embeddings_list)



if classless:
    annotations['class_label'] = 'object'

# Detections at or below this confidence are dropped before the IoU post-processing.
if use_text_queries:
  lowest_confidence = 0.10 if classless else 0.05
else:
  lowest_confidence = 0.6

# Run

## Inference

In [None]:
# Run the model once on a single image to obtain `boxes`; the next cell uses
# len(boxes) to pre-size the detections table.
if use_text_queries:
  # This isnt all that useful, but it is used to count the amount of boxes that the model gives.
  # Could this be done with our own images or even a null-image?
  # Sure, but why bother if it already works (doesnt take that long either)

  IMAGE_DIR = 'gs://scenic-bucket/owl_vit/example_images'
  %matplotlib inline

  images = {}

  # Load every example image from the bucket, keeping at most the first three channels.
  for i, filename in enumerate(tf.io.gfile.listdir(IMAGE_DIR)):
    with tf.io.gfile.GFile(os.path.join(IMAGE_DIR, filename), 'rb') as f:
      image = mpl.image.imread(
          f, format=os.path.splitext(filename)[-1])[..., :3]
    # Images whose values lie in [0, 1] are rescaled to [0, 255].
    if np.max(image) <= 1.:
      image *= 255
    images[i] = image
  IMAGE_ID =   2
  image = images[IMAGE_ID]

  _, _, boxes = model.embed_image(image)

else:
  # NOTE: requires the dataset to contain at least two images.
  test_image_annotated = dataset[1]
  _, _, boxes = model.embed_image(test_image_annotated.npimage)

In [None]:
# Pre-allocate one row per (test image, predicted box); `boxes` comes from the probe cell above.
detections:pd.DataFrame = pd.DataFrame(index=range(len(test_dataset) * len(boxes)), columns=['image', 'class_label', 'id', 'x_top_left', 'y_top_left', 'width', 'height', 'confidence', 'area'])

# Next free row in `detections`.
index = 0

for q in range(len(test_dataset)):
    print(f"{q+1}/{len(test_dataset)}")
# for q in range(1):
    test_image_annotated = test_dataset[q]
    # Detections are keyed by the file name up to the first '.'.
    test_image_name = test_image_annotated.filename.split(".")[0]
    test_image = test_image_annotated.npimage

    _, _, boxes = model.embed_image(test_image)
    query_index, scores = model.get_scores(test_image, query_embeddings, len(queries))

    # NOTE(review): `zipped_scores` is never read; the loop below indexes the three sequences directly.
    zipped_scores = zip(query_index, scores, boxes) # tuples: (label_id, score, [x, y, w, h])
    for i in range(len(boxes)):
        box = boxes[i]
        score = scores[i]
        class_label = queries[query_index[i]]
        # NOTE(review): despite the names, the first two values are used as the box
        # centre (half the size is subtracted below). All values are scaled by the
        # image width - presumably because the model pads the image to a square of
        # that side; confirm against the model API.
        x_top_left, y_top_left, width, height = box
        detections.loc[index] = [test_image_name, class_label, i, ((x_top_left - width/2) * test_image_annotated.width), ((y_top_left - height/2) * test_image_annotated.width) , width * test_image_annotated.width, height * test_image_annotated.width, score, width * height * test_image_annotated.height * test_image_annotated.width]
        index += 1

        #new_row:pd.DataFrame = pd.DataFrame(columns=detections.columns, data=[[test_image_name, class_label, i, x_top_left, y_top_left, width, height, score]])

        #detections.iloc[q*len(queries)+i] = new_row

# add area column
# (this overwrites the per-row 'area' values written in the loop above)
detections['image'] = detections['image'].astype('category')
detections['area'] = detections['width'] * detections['height']

# Detections matched to a negative text query are thrown away.
if use_negative_queries and use_text_queries:
  for q in negative_text_queries:
      detections = detections[detections['class_label'] != q]
# In classless mode every query label is replaced by the single label 'object'.
if classless:
  for q in queries:
    detections.loc[detections['class_label'] == q, 'class_label'] = 'object'

# Checkpoint the raw detections so the post-processing cells can be re-run without inference.
bb.io.save(detections, "pandas", saves_path + "unprocessed_detections.pkl")

## Post process results (add IoU)

Yes, this is horribly inefficient and should definitely not be done in python (or be this unoptimised). But it is what it is

In [None]:
# Reload the raw detections saved by the inference cell.
detections = bb.io.load("pandas", saves_path + "unprocessed_detections.pkl")

# 'iou' will hold, per detection, its overlap with a higher-confidence detection
# of the same image (filled in by calc_iou).
detections['iou'] = 0.0

# Drop low-confidence detections first; this also keeps the pairwise IoU pass small.
detections = filter_by_confidence(detections, lowest_confidence)

# 1-based progress counter for the log line below.
index = 1
total_len = len(detections.image.unique())

for image in detections.image.unique():
  img_detections = get_detections_for_image(detections, image)
  print(f"{index}/{total_len} [{len(img_detections)}]")
  index += 1
  # Updates the 'iou' column of `detections` in place for this image's rows.
  calc_iou(detections, img_detections)

bb.io.save(detections, "pandas", saves_path + "processed_detections.pkl")

## Results

In [None]:
# Reload the post-processed detections (with the 'iou' column).
detections = bb.io.load("pandas", saves_path + "processed_detections.pkl")

# Cast the numeric detection columns to float before computing statistics.
detections['height'] = detections['height'].astype(float)
detections['confidence'] = detections['confidence'].astype(float)
detections['area'] = detections['area'].astype(float)
detections['iou'] = detections['iou'].astype(float)
detections['x_top_left'] = detections['x_top_left'].astype(float)
detections['y_top_left'] = detections['y_top_left'].astype(float)
detections['width'] = detections['width'].astype(float)


# Same for the annotations: numeric columns to float, flag columns to bool.
annotations['x_top_left'] = annotations['x_top_left'].astype(float)
annotations['y_top_left'] = annotations['y_top_left'].astype(float)
annotations['width'] = annotations['width'].astype(float)
annotations['height'] = annotations['height'].astype(float)
annotations['occluded'] = annotations['occluded'].astype(float)
annotations['truncated'] = annotations['truncated'].astype(float)
annotations['lost'] = annotations['lost'].astype(bool)
annotations['difficult'] = annotations['difficult'].astype(bool)
annotations['ignore'] = annotations['ignore'].astype(bool)

# NMS (iou) thresholds 0.6 .. 1.0 and confidence thresholds from lowest_confidence
# (rounded to one decimal) up to 0.9, both in steps of 0.1.
iou_thresholds = [round(0.1 * x, 2) for x in range(6,11)]
score_thresholds = [round(x/10, 6) for x in range(round(lowest_confidence * 10), 10)] #lower than the lowest score is meaninless ofc

# Every (iou threshold, score threshold) combination.
options_matrix = list(itertools.product(iou_thresholds, score_thresholds))

# Which slice of the test images is rendered to disk.
save_images = True
image_start = 50
image_amount = 50

# Clamp the end of the slice to the size of the test set.
image_stop = image_start + image_amount
if image_stop > len(test_dataset):
  image_stop = len(test_dataset)



# Render the selected test images once per (iou, score) threshold combination.
if save_images:
  for iou, score in options_matrix:
    det_iou = filter_by_iou(detections, iou)
    det = filter_by_confidence(det_iou, score)

    img_path = f"{output_path}iou{iou}/score{score}/"
    create_folder_if_not_exists(img_path)
    print(f"iou={iou}-score={score}: #detections={len(det)}")
    for image in test_dataset[image_start:image_stop]:
      # Ground truth as translucent filled boxes ...
      boxed:Image = draw_boxes(image, annotations, line=False)
      # ... with the detections outlined on top (all black in classless mode).
      boxed = draw_boxes(image, det, draw_legend=True, manual_image=boxed, forced_color=classless)
      # Save by path: PIL then opens the file in binary mode itself, instead of
      # being handed a file object that was opened in text mode.
      boxed.save(img_path + image.filename)

# Coarse sweep over the NMS thresholds: plot one precision-recall curve per
# threshold and remember the thresholds that give the best AP and the best F1.

# Best-AP bookkeeping: value, threshold, PR curve and the F1 reached at that threshold.
max_ap = 0
max_ap_nms = 0
max_ap_pr = None
max_ap_fscore = 0

# Best-F1 bookkeeping: peak point, value, threshold, PR curve and the AP reached at that threshold.
max_fscore = None
max_fscore_f1 = 0
max_fscore_nms = 0
max_fscore_pr = None
max_fscore_ap = 0

fig, ax = plt.subplots(figsize=(10,10))

for iou_threshold in iou_thresholds:
  print(iou_threshold)

  # Keep detections whose overlap with a higher-confidence detection is below the threshold.
  det = filter_by_iou(detections, iou_threshold)
  # NOTE(review): threshold=0.5 is presumably the IoU required to match a detection to an annotation - confirm in the brambox docs.
  pr = bb.stat.pr(det, annotations, threshold=0.5, smooth=True)
  ap = bb.stat.ap(pr)
  fscore = bb.stat.fscore(pr)
  peakf1 = bb.stat.peak(fscore)
  ax.plot(pr['recall'], pr['precision'], label=f"nms={iou_threshold}, AP={100 * ap:.2f}%, F1={100 * peakf1.f1:.2f}%")

  print(f"nms={iou_threshold}, AP={100 * ap:.2f}%, F1={100 * peakf1.f1:.2f}%")

  if ap > max_ap:
      max_ap = ap
      max_ap_nms = iou_threshold
      max_ap_pr = pr
      max_ap_fscore = peakf1.f1
  if peakf1.f1 > max_fscore_f1:
      max_fscore_f1 = peakf1.f1
      max_fscore = peakf1
      max_fscore_nms = iou_threshold
      max_fscore_pr = pr
      max_fscore_ap = ap
      # NOTE(review): `topf1` is not read anywhere in the visible notebook.
      topf1 = bb.stat.point(pr, peakf1.f1)

plt.title("Precision-Recall curve for different nms values")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()

fig.savefig(f"{output_path}iouPR.png")
plt.close()

print(f"best ap: {max_ap * 100}%, best fscore: {max_fscore_f1 * 100}%")
print(f"maximize ap: nms={max_ap_nms}")
print(f"maximize fscore: nms={max_fscore_nms}")


# Fine sweep: repeat the search for the best NMS threshold in steps of 0.01
# over 0.01 .. 1.00, this time without plotting.
iou_thresholds = [round(0.01 * x, 2) for x in range(1,101)]


# Reset the best-AP bookkeeping from the coarse sweep.
max_ap = 0
max_ap_nms = 0
max_ap_pr = None
max_ap_fscore = 0

# Reset the best-F1 bookkeeping from the coarse sweep.
max_fscore = None
max_fscore_f1 = 0
max_fscore_nms = 0
max_fscore_pr = None
max_fscore_ap = 0

for iou_threshold in iou_thresholds:

  # Keep detections whose overlap with a higher-confidence detection is below the threshold.
  det = filter_by_iou(detections, iou_threshold)
  pr = bb.stat.pr(det, annotations, threshold=0.5, smooth=True)
  ap = bb.stat.ap(pr)
  fscore = bb.stat.fscore(pr)
  peakf1 = bb.stat.peak(fscore)
  if ap > max_ap:
      max_ap = ap
      max_ap_nms = iou_threshold
      max_ap_pr = pr
      max_ap_fscore = peakf1.f1
  if peakf1.f1 > max_fscore_f1:
      max_fscore_f1 = peakf1.f1
      max_fscore = peakf1
      max_fscore_nms = iou_threshold
      max_fscore_pr = pr
      max_fscore_ap = ap
      # NOTE(review): `topf1` is not read anywhere in the visible notebook.
      topf1 = bb.stat.point(pr, peakf1.f1)


print(f"best ap: {max_ap * 100}%, best fscore: {max_fscore_f1 * 100}%")
print(f"maximize ap: nms={max_ap_nms}")
print(f"maximize fscore: nms={max_fscore_nms}")

# Per-class precision-recall curves, using the NMS threshold that maximised AP in the fine sweep.
opt_det = filter_by_iou(detections, max_ap_nms)

fig, ax = plt.subplots(figsize=(10,10))

for shell_class in detections.class_label.unique():
  # Evaluate each class against only its own detections and annotations.
  shell_det = opt_det.loc[opt_det['class_label'] == shell_class]
  shell_anno = annotations[annotations['class_label'] == shell_class]
  pr = bb.stat.pr(shell_det, shell_anno, threshold=0.5, smooth=True)
  ap = bb.stat.ap(pr)
  fscore = bb.stat.fscore(pr)
  peakf1 = bb.stat.peak(fscore)
  ax.plot(pr['recall'], pr['precision'], label=f"{shell_class}: AP={100 * ap:.2f}%, F1={100 * peakf1.f1:.2f}%")

  print(f"{shell_class}: AP={100 * ap:.2f}%, F1={100 * peakf1.f1:.2f}%")

plt.title("Precision-Recall curve for each class at optimal nms threshold")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()

fig.savefig(f"{output_path}classesPR.png")
plt.close()

# Merged for continuous running

In [None]:
# Values to sweep over in the merged run: every combination of these three
# settings is executed by the loop below.
classless_options = [True, False]
N_options = [1, 5, 10, 20, 50]
use_all_generated_image_embeddings_options = [True, False]

options_matrix = list(itertools.product(classless_options, N_options, use_all_generated_image_embeddings_options))

# The merged run is fixed to image queries without negative text queries.
use_text_queries = False
use_negative_queries = False #text only
for classless, N, use_all_generated_image_embeddings in options_matrix:
  print(f"Running classless: {classless}, N:{N}, use_all_generated_image_embeddings:{use_all_generated_image_embeddings}")
  class_string = "classless" if classless else "classed"
  neg_string = "use_negative_queries" if use_negative_queries else "only_positive_queries"
  text_string = "text" if use_text_queries else "img"
  extra_img_string = f"full_{N}_" if (not use_text_queries and use_all_generated_image_embeddings) else f"{N}_"

  output_path = f"{output_path_base}/owlvit/{text_string}/{class_string}/{neg_string}/{extra_img_string}"

  saves_path = f"{saves_path_base}/owlvit/{text_string}/{class_string}/{neg_string}/{extra_img_string}"

  if use_text_queries:
    annotations:pd.DataFrame = pd.concat([image.annotations for image in dataset])

    if classless:
      positive_text_queries = ("shell", "seashell") + tuple(annotations.class_label.unique())
    else:
      positive_text_queries = tuple(annotations.class_label.unique())
    if use_negative_queries:
      negative_text_queries = ("rock", "sand")
    else:
      negative_text_queries = ()
    queries = negative_text_queries + positive_text_queries
    query_embeddings = model.embed_text_queries(queries)

    test_dataset = dataset

  else:
    annotations:pd.DataFrame = pd.concat([image.annotations for image in dataset])

    test_dataset = dataset.copy()
    query_dataset = set()

    # we take out the query images from which we make the query embeddings
    # if a class has less than N annotations we dont use it
    annotations_per_class = annotations['class_label'].value_counts()
    removed_classes = []
    queries_dict = {}
    for annotation_class, _ in annotations_per_class.items():
      queries_dict[annotation_class] = []
    for annotation_class, amount in reversed(list(annotations_per_class.items())): #loop over counts in reverse
      if amount < N:
        removed_classes.append(annotation_class)
      else:
        if len(queries_dict[annotation_class]) > N:
          continue #we dont need to generate any more queries for classes that are already sufficiently populated

        for image in test_dataset:
          if annotation_class in list(image.annotations.class_label.values):
            query_dataset.add(image)
            img_query_embeddings = get_all_query_embeddings(image)
            for query_class, embedding in img_query_embeddings:
              queries_dict[query_class].append(embedding)
          if len(queries_dict[annotation_class]) > N:
            break

    for image in query_dataset:
      test_dataset.remove(image)

    annotations = pd.concat([image.annotations for image in test_dataset])

    for c in removed_classes:
      annotations = annotations[annotations['class_label'] != c]
      queries_dict.pop(c)


    query_embeddings_list = []
    queries = []

    for key, value in queries_dict.items():
      queries.append(key)
      if use_all_generated_image_embeddings:
        query_embeddings_list.append(np.mean(value, axis=0))
      else:
        query_embeddings_list.append(np.mean(value[:N], axis=0))

    query_embeddings = np.array(query_embeddings_list)



  if classless:
      annotations['class_label'] = 'object'


  if use_text_queries:
    # This isn't all that useful, but it is used to count the number of boxes that the model returns.
    # Could this be done with our own images or even a null image?
    # Sure, but why bother if it already works (it doesn't take that long either).

    IMAGE_DIR = 'gs://scenic-bucket/owl_vit/example_images'
    %matplotlib inline

    images = {}

    for i, filename in enumerate(tf.io.gfile.listdir(IMAGE_DIR)):
      with tf.io.gfile.GFile(os.path.join(IMAGE_DIR, filename), 'rb') as f:
        image = mpl.image.imread(
            f, format=os.path.splitext(filename)[-1])[..., :3]
      if np.max(image) <= 1.:
        image *= 255
      images[i] = image
    IMAGE_ID =   2
    image = images[IMAGE_ID]

    _, _, boxes = model.embed_image(image)

  else:
    test_image_annotated = dataset[1]
    _, _, boxes = model.embed_image(test_image_annotated.npimage)

  detections:pd.DataFrame = pd.DataFrame(index=range(len(test_dataset) * len(boxes)), columns=['image', 'class_label', 'id', 'x_top_left', 'y_top_left', 'width', 'height', 'confidence', 'area'])

  index = 0

  for q in range(len(test_dataset)):
      test_image_annotated = test_dataset[q]
      test_image_name = test_image_annotated.filename.split(".")[0]
      test_image = test_image_annotated.npimage

      _, _, boxes = model.embed_image(test_image)
      query_index, scores = model.get_scores(test_image, query_embeddings, len(queries))

      zipped_scores = zip(query_index, scores, boxes) # tuples: (label_id, score, [x, y, w, h])
      for i in range(len(boxes)):
          box = boxes[i]
          score = scores[i]
          class_label = queries[query_index[i]]
          x_top_left, y_top_left, width, height = box
          detections.loc[index] = [test_image_name, class_label, i, ((x_top_left - width/2) * test_image_annotated.width), ((y_top_left - height/2) * test_image_annotated.width) , width * test_image_annotated.width, height * test_image_annotated.width, score, width * height * test_image_annotated.height * test_image_annotated.width]
          index += 1

          #new_row:pd.DataFrame = pd.DataFrame(columns=detections.columns, data=[[test_image_name, class_label, i, x_top_left, y_top_left, width, height, score]])

          #detections.iloc[q*len(queries)+i] = new_row

  # Cast the image names to a categorical column and (re)compute the area column from width * height.
  detections['image'] = detections['image'].astype('category')
  detections['area'] = detections['width'] * detections['height']

  if use_negative_queries and use_text_queries:
    for q in negative_text_queries:
        detections = detections[detections['class_label'] != q]
  if classless:
    for q in queries:
      detections.loc[detections['class_label'] == q, 'class_label'] = 'object'

  bb.io.save(detections, "pandas", saves_path + "unprocessed_detections.pkl")

  detections = bb.io.load("pandas", saves_path + "unprocessed_detections.pkl")

  detections['iou'] = 0.0
  lowest_confidence = 0.01 if classless else 0.10
  lowest_confidence = lowest_confidence if use_text_queries else 0.6

  detections = filter_by_confidence(detections, lowest_confidence)

  index = 1
  total_len = len(detections.image.unique())

  for image in detections.image.unique():
    img_detections = get_detections_for_image(detections, image)
    index += 1
    calc_iou(detections, img_detections)

  bb.io.save(detections, "pandas", saves_path + "processed_detections.pkl")

  detections = bb.io.load("pandas", saves_path + "processed_detections.pkl")

  detections['height'] = detections['height'].astype(float)
  detections['confidence'] = detections['confidence'].astype(float)
  detections['area'] = detections['area'].astype(float)
  detections['iou'] = detections['iou'].astype(float)
  detections['x_top_left'] = detections['x_top_left'].astype(float)
  detections['y_top_left'] = detections['y_top_left'].astype(float)
  detections['width'] = detections['width'].astype(float)


  annotations['x_top_left'] = annotations['x_top_left'].astype(float)
  annotations['y_top_left'] = annotations['y_top_left'].astype(float)
  annotations['width'] = annotations['width'].astype(float)
  annotations['height'] = annotations['height'].astype(float)
  annotations['occluded'] = annotations['occluded'].astype(float)
  annotations['truncated'] = annotations['truncated'].astype(float)
  annotations['lost'] = annotations['lost'].astype(bool)
  annotations['difficult'] = annotations['difficult'].astype(bool)
  annotations['ignore'] = annotations['ignore'].astype(bool)

  iou_thresholds = [round(0.1 * x, 2) for x in range(6,11)]
  score_thresholds = [round(x/10, 6) for x in range(round(lowest_confidence * 10), 10)] #lower than the lowest score is meaninless ofc

  options_matrix = list(itertools.product(iou_thresholds, score_thresholds))

  save_images = True
  image_amount = 15
  if image_amount > len(test_dataset):
    image_amount = len(test_dataset)



  if save_images:
    for iou, score in options_matrix:
      det_iou = filter_by_iou(detections, iou)
      det = filter_by_confidence(det_iou, score)

      img_path = f"{output_path}iou{iou}/score{score}/"
      create_folder_if_not_exists(img_path)
      for image in test_dataset[0:image_amount]:
        boxed:Image = draw_boxes(image, annotations, line=False)
        boxed = draw_boxes(image, det, draw_legend=True, manual_image=boxed)
        with open(img_path + image.filename, "w") as imfp:
          boxed.save(imfp)

  max_ap = 0
  max_ap_nms = 0
  max_ap_pr = None
  max_ap_fscore = 0

  max_fscore = None
  max_fscore_f1 = 0
  max_fscore_nms = 0
  max_fscore_pr = None
  max_fscore_ap = 0

  fig, ax = plt.subplots(figsize=(10,10))

  for iou_threshold in iou_thresholds:

    det = filter_by_iou(detections, iou_threshold)
    pr = bb.stat.pr(det, annotations, threshold=0.5, smooth=True)
    ap = bb.stat.ap(pr)
    fscore = bb.stat.fscore(pr)
    peakf1 = bb.stat.peak(fscore)
    ax.plot(pr['recall'], pr['precision'], label=f"nms={iou_threshold}, AP={100 * ap:.2f}%, F1={100 * peakf1.f1:.2f}%")

    print(f"nms={iou_threshold}, AP={100 * ap:.2f}%, F1={100 * peakf1.f1:.2f}%")

    if ap > max_ap:
        max_ap = ap
        max_ap_nms = iou_threshold
        max_ap_pr = pr
        max_ap_fscore = peakf1.f1
    if peakf1.f1 > max_fscore_f1:
        max_fscore_f1 = peakf1.f1
        max_fscore = peakf1
        max_fscore_nms = iou_threshold
        max_fscore_pr = pr
        max_fscore_ap = ap
        topf1 = bb.stat.point(pr, peakf1.f1)

  plt.title("Precision-Recall curve for different nms values")
  ax.set_xlabel("Recall")
  ax.set_ylabel("Precision")
  ax.legend()

  fig.savefig(f"{output_path}iouPR.png")
  plt.close()

  print(f"best ap: {max_ap * 100}%, best fscore: {max_fscore_f1 * 100}%")
  print(f"maximize ap: nms={max_ap_nms}")
  print(f"maximize fscore: nms={max_fscore_nms}")
