In [12]:
from ultralytics import YOLO
from transformers import pipeline
from PIL import Image
import torch
from PIL import Image, ImageDraw
import numpy as np
import orjson
from tqdm import tqdm
import os

In [4]:
yolo_model = YOLO("yolov8/runs/detect/yolov9c 0.99 0.769/weights/best.pt")  # load a pretrained model (recommended for training)

In [5]:
image_classifier = pipeline(task="zero-shot-image-classification", model="siglip/siglip-so400m-patch14-384", batch_size=4, device='cuda')

In [25]:
with open('../../data/vlm.jsonl', 'r') as f:
    instances = [orjson.loads(line.strip()) for line in f if line.strip() != ""]
results = []
val_percent = 0.2
val_split = int(len(instances) * val_percent)
train, val = instances[:-val_split], instances[-val_split:]
bs = 4
batched_instances = [val[i:i + bs] for i in range(0, len(val), bs)]

In [26]:
for batch_instance in tqdm(batched_instances):
    images = [Image.open(os.path.join('../../data/images/', i['image'])) for i in batch_instance]
    
    # YOLO object det
    yolo_result = yolo_model.predict(images, imgsz=1600, conf=0.365, iou=0.1, max_det=10, verbose=False)  # max F1, try augment=True and adjusting iou
    yolo_result = [(r.boxes.xyxy.tolist(), r.boxes.conf.tolist()) for r in yolo_result]
    yolo_result = [tuple(zip(*r)) for r in yolo_result]  # list of tuple[box, conf] in each image in xyxy format
    
    # crop the boxes out
    cropped_boxes = []
    for im, boxes in zip(images, yolo_result):
        im_boxes = []
        for (x1, y1, x2, y2), _ in boxes:
            im_boxes.append(im.crop((x1, y1, x2, y2)))
        cropped_boxes.append(im_boxes)
    
    captions_list = [[anno['caption'] for anno in img['annotations']] for img in batch_instance]  # list of list of str, len is n_img == 4
    assert len(cropped_boxes) == len(captions_list)
    
    # siglip inference
    siglip_results = []
    with torch.cuda.amp.autocast():
        for boxes, captions in zip(cropped_boxes, captions_list):
            r = image_classifier(boxes, candidate_labels=captions)
            image_to_text_scores = {caption: [] for caption in captions}  # {caption: [score1, score2, ...]}, scores in sequence of bbox
            for box in r:
                for label_score in box:
                    image_to_text_scores[label_score['label']].append(label_score['score'])
            siglip_results.append(image_to_text_scores)
    
    # combine the results
    visualize = False
    for im, cropped_box_PIL, yolo_box, similarity_scores, instance in zip(images, cropped_boxes, yolo_result, siglip_results, batch_instance):
        if visualize: im_cp = im.copy()
        result_for_im = {}
        for caption, caption_scores in similarity_scores.items():
            box_idx = np.argmax(caption_scores)
            highest_caption_score = max(caption_scores)
            box = cropped_box_PIL[box_idx]
            result_for_im[caption] = yolo_box[box_idx][0]  # dict[caption] = xyxy in list
            if visualize:
                draw = ImageDraw.Draw(im_cp)  # noqa
                (x1, y1, x2, y2), box_conf = yolo_box[box_idx]
                draw.rectangle(xy=((x1, y1), (x2, y2)), outline='red')
                draw.text((x1, y1), text=f'{caption} {box_conf:.2f} {highest_caption_score:.2f}', fill='red')
        if visualize: im_cp.show()
        results.append({'image': instance['image'], 'annotations': [{'bbox': v, 'caption': k} for k, v in result_for_im.items()]})
        # save every image in case of crash
        with open('yolo-siglip-zeroshot.json', 'wb+') as f:
            f.write(orjson.dumps(results))

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 256/256 [04:41<00:00,  1.10s/it]


In [7]:
# plot bbox
for im, boxes in zip(ims, yolo_result):
    im = im.copy()
    draw = ImageDraw.Draw(im)
    for (x1, y1, x2, y2), conf in boxes:
        draw.rectangle(xy=((x1, y1), (x2, y2)), outline='red')
        draw.text((x1, y1), text=f'{conf:.2f}', fill='red')
    im.show()

In [59]:
visualize = False
for im, cropped_box_PIL, yolo_box, similarity_scores in zip(ims, cropped_boxes, yolo_result, siglip_results):
    if visualize: im_cp = im.copy()
    result_for_im = {}
    for caption, caption_scores in similarity_scores.items():
        box_idx = np.argmax(caption_scores)
        highest_caption_score = max(caption_scores)
        box = cropped_box_PIL[box_idx]
        result_for_im[caption] = yolo_box[box_idx][0]  # dict[caption] = (xyxy in list, conf)
        if visualize:
            draw = ImageDraw.Draw(im_cp)
            (x1, y1, x2, y2), box_conf = yolo_box[box_idx]
            draw.rectangle(xy=((x1, y1), (x2, y2)), outline='red')
            draw.text((x1, y1), text=f'{caption} {box_conf:.2f} {highest_caption_score:.2f}', fill='red')
    if visualize: im_cp.show()
    results.append(result_for_im)

In [60]:
results

[{'grey missile': [705.0738525390625,
   506.7243347167969,
   782.65283203125,
   563.574951171875],
  'red, white, and blue light aircraft': [1030.6815185546875,
   77.49951934814453,
   1056.74853515625,
   110.44055938720703],
  'green and black missile': [705.0738525390625,
   506.7243347167969,
   782.65283203125,
   563.574951171875],
  'white and red helicopter': [527.7639770507812,
   118.3411865234375,
   624.7859497070312,
   161.6909637451172]},
 {'grey camouflage fighter jet': [400.4502868652344,
   158.0403289794922,
   455.9124450683594,
   193.24575805664062],
  'grey and white fighter plane': [1117.64501953125,
   514.673828125,
   1254.2855224609375,
   553.1058959960938],
  'white and black drone': [356.56414794921875,
   455.2095031738281,
   402.8783264160156,
   486.3287353515625],
  'white and black fighter jet': [400.4502868652344,
   158.0403289794922,
   455.9124450683594,
   193.24575805664062],
  'white missile': [400.4502868652344,
   158.0403289794922,
   