In [1]:
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from transformers.image_transforms import corners_to_center_format
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import time
from pathlib import Path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_fp = Path('drive/MyDrive/simulated_cube_dataset/images/')
label_fp = Path('drive/MyDrive/simulated_cube_dataset/labels/')

def process_image(image):
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize((500,500))
    return image

def get_model():
    model_id = "IDEA-Research/grounding-dino-tiny"
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
    return processor, model

def run_model(model, image, text):
    processor, model = model
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_grounded_object_detection(outputs,
                                                               inputs.input_ids,
                                                               box_threshold=0.4,
                                                               text_threshold=0.4,
                                                               target_sizes=None)
    return inputs, outputs, results

def process_output(model, inputs, outputs, threshold):
    processor, model = model
    results = processor.post_process_grounded_object_detection(outputs,
                                                               inputs.input_ids,
                                                               box_threshold=threshold,
                                                               text_threshold=threshold,
                                                               target_sizes=None)
    return results

def check_outputs(results):
  undetected_outputs = []
  for i, r in enumerate(results):
    if len(r['labels']) == 0:
      undetected_outputs.append((i,r))
  if len(undetected_outputs) > 0:
    return undetected_outputs
  return None

def isolate_boxes(results):
  return [r['boxes'] for r in results]

def process_to_yolo(boxes):
    yolo_boxes = []
    for _box in boxes:
      _boxes = []
      for box in _box:
        center_xywh = corners_to_center_format(box)
        x, y, w, h = center_xywh.tolist()
        _boxes.append((0, x, y, w, h))
      yolo_boxes.append(_boxes)
    return yolo_boxes

def plot_boxes(images, boxes):
    fig, axs = plt.subplots(1, len(images), figsize=(10, 5))
    for i, (im, box) in enumerate(zip(images, boxes)):
        ax = axs[i]
        ax.imshow(im)
        for b in box:
            x1, y1, x2, y2 = [int(x) for x in b]
            ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False, color='red'))
    plt.show()

def determine_split(im_processed: int):
  '''
  takes in integer value, returns split and upper bound
  '''
  if im_processed < 800:
    return 'train', 799
  elif 800 <= im_processed < 900:
    return 'val', 899
  else:
    return 'test', 1000

def determine_batch(curr, upper_bound):
  """
  checks if 21 will overflow current split, returns proper value
  """
  next_batch = curr + 21
  res = upper_bound - next_batch
  if res < 0:
    return 21 + res
  else:
    return 21


def get_fns(cur, batch):
  return ["{:04d}".format(i) for i in range(cur, cur + batch + 1)]

def get_images(split, fns):
  images = [Image.open(image_fp / split / (fn + '.jpg')) for fn in fns]
  images = [process_image(im) for im in images]
  return images

def write_label(split, fn, boxes):
  path = label_fp / split / (fn + '.txt')
  with open(path, 'w') as fp:
    for box in boxes:
      res = " ".join(map(str, box)) + '\n'
      fp.write(res)

In [3]:
images_processed = 298

model = get_model()

threshold = 0.4
complete_start_time = time.time()
start_time = time.time()
while images_processed <= 1000:
  split, upper_bound = determine_split(images_processed)
  batch = determine_batch(images_processed, upper_bound)
  fns = get_fns(images_processed, batch)
  images = get_images(split, fns)
  texts = ['rubiks cube.'] * len(fns)
  print(f'Processing {len(fns)} images in {split} split')
  image_start = time.time()
  inputs, outputs, results = run_model(model, images, texts)
  while check_outputs(results) is not None:
    print('Some results did not produce boxes, reducing threshold by 0.025')
    threshold -= 0.025
    results = process_output(model, inputs, outputs, threshold)
  threshold = 0.4
  yolo_boxes = process_to_yolo(isolate_boxes(results))
  for fn, box in zip(fns, yolo_boxes):
    write_label(split, fn, box)
  del inputs
  del outputs
  del results
  image_end = time.time()
  print(f'Wrote {len(fns)} labels in {split} split')
  print(f'Time taken: {image_end - image_start}')
  print(f'Current progress: {images_processed + len(fns)}/{upper_bound}')
  print(f'Overall elapsed time: {image_end - complete_start_time}')
  if not batch:
    end_time = time.time()
    images_processed += 1
    print(f'Completed processing {split} split')
    print(f'Elapsed time: {end_time - start_time}')
    print(f'Moving to {determine_split(images_processed)[0]} split')
    start_time = time.time()

  images_processed += batch
end_time = time.time()
print(f'Completed processing all splits')
print(f'Final elapsed time: {end_time - complete_start_time}')



Processing 22 images in train split
Some results did not produce boxes, reducing threshold by 0.025
Some results did not produce boxes, reducing threshold by 0.025
Wrote 22 labels in train split
Time taken: 5.1421730518341064
Current progress: 320/799
Overall elapsed time: 6.023683309555054
Processing 22 images in train split
Wrote 22 labels in train split
Time taken: 5.006994247436523
Current progress: 341/799
Overall elapsed time: 11.936782598495483
Processing 22 images in train split
Wrote 22 labels in train split
Time taken: 5.000412464141846
Current progress: 362/799
Overall elapsed time: 31.002752780914307
Processing 22 images in train split
Wrote 22 labels in train split
Time taken: 5.012889385223389
Current progress: 383/799
Overall elapsed time: 54.70752191543579
Processing 22 images in train split
Wrote 22 labels in train split
Time taken: 5.076694965362549
Current progress: 404/799
Overall elapsed time: 77.98115372657776
Processing 22 images in train split
Wrote 22 labels in