# Install dependencies

In [8]:
# Install transformers straight from the GitHub repository (no tag or commit is
# given, so this resolves to whatever the default branch holds at install time).
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install git+https://github.com/huggingface/transformers

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /private/var/folders/_3/b70qb3856jv2v4tk1htvvt1w0000gn/T/pip-req-build-mms3c4yp
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /private/var/folders/_3/b70qb3856jv2v4tk1htvvt1w0000gn/T/pip-req-build-mms3c4yp
^C
ERROR: Operation cancelled by user

In [9]:
# Provides process_vision_info, imported below. %pip (rather than !pip)
# installs into the environment of the running kernel.
%pip install qwen-vl-utils



In [10]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install torchvision



In [11]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install accelerate



In [5]:
import torch
import os
import json
import time
import re
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

In [None]:
# Run-mode switch. Nothing in the visible cells reads it -- TODO confirm where
# (or whether) it is consumed.
LOCAL_RUN = False
# Wall-clock timestamp (seconds since the epoch) captured when this cell runs.
START = time.time()

In [None]:
# Choose the compute backend: MPS when it reports as available, otherwise
# CUDA, otherwise fall back to the CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(device)

In [None]:
# Load the Qwen2-VL 7B instruct checkpoint and report how long the load took.
# Both dtype and device placement are requested as "auto".
# NOTE(review): placement is automatic here, while process_images moves its
# inputs to the separately computed `device` -- confirm the two always agree.
LOAD_MODEL = time.time()
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)
print(f"Model loaded in {time.time() - LOAD_MODEL} seconds")

In [None]:
# Load the processor published under the same checkpoint name as the model
# and report how long the load took. process_images uses it to build the chat
# prompt, prepare the model inputs and decode the generated token ids.
LOAD_PROCESSOR = time.time()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
print(f"Processor loaded in {time.time() - LOAD_PROCESSOR} seconds")

In [None]:
def extract_coordinates(output_text):
    """Pull the first well-formed four-number tuple out of the model's reply.

    Scans ``output_text`` for parenthesised groups made up only of digits,
    dots, commas and whitespace, and returns the first group that parses into
    exactly four floats. Groups that do not parse, or that hold a different
    number of values, are skipped rather than aborting the search.

    Args:
        output_text: Decoded text produced by the model. ``None`` or an empty
            string is treated as "nothing found".

    Returns:
        A list of four floats taken from the matching group, or the integer
        ``0`` when no group qualifies.
    """
    if not output_text:
        return 0

    coord_pattern = r'\(([\d\.\,\s]+)\)'

    for match in re.finditer(coord_pattern, output_text):
        try:
            coordinates = [float(part) for part in match.group(1).split(',')]
        except ValueError:
            # The character class also admits text that is not a number list,
            # e.g. "1..2", "0.1 0.2" or an empty field left by a trailing
            # comma. Skip such a group instead of crashing the whole run.
            continue

        if len(coordinates) == 4:
            return coordinates

    return 0

In [None]:
def process_images(data_directory):
    """Run the drone-detection prompt on every ``.jpg`` in a folder and save the results.

    For each file whose name ends in ``.jpg`` the image and a fixed prompt are
    sent through the module-level ``processor`` and ``model``; the decoded
    reply is parsed with ``extract_coordinates`` and stored under the file
    name. The collected results are written as JSON to
    ``<data_directory>/vlm-drones/output/inference_results.json``; missing
    parent folders are created first so the write cannot fail after all the
    inference work has been done.

    Relies on names defined in other cells: ``processor``, ``model``,
    ``device``, ``process_vision_info`` and ``extract_coordinates``.

    Args:
        data_directory: Folder containing the images to process.

    Returns:
        Dict mapping each processed file name to ``{"Inference": value}``,
        where ``value`` is whatever ``extract_coordinates`` returned. The same
        dict is what gets written to the JSON file.
    """
    prompt = "Detect drones in the image. If a drone is detected, return only the bounding box coordinates normalized between 0 and 1, in the format (x, y, w, h), where (x, y) is the top-left corner of the bounding box relative to the image dimensions, and w and h are the width and height relative to the image dimensions. No other text or information; only the coordinates."
    results_dict = {}

    # List the folder once, keep only the images that will actually be
    # processed, and sort so the processing order is reproducible.
    image_files = sorted(
        name for name in os.listdir(data_directory) if name.endswith('.jpg')
    )
    quantity = len(image_files)

    for index, filename in enumerate(image_files, start=1):
        image_path = os.path.join(data_directory, filename)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": f"file://{image_path}"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        generated_ids = model.generate(**inputs, max_new_tokens=128)
        # Drop the first len(in_ids) tokens of every output sequence
        # (presumably the echoed prompt) so only the remainder is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        # Number of images still waiting after the one just finished.
        print(f"Pending to process {quantity - index}")

        results_dict[filename] = {
            "Inference": extract_coordinates(output_text)
        }

    output_path = os.path.join(data_directory, 'vlm-drones/output/inference_results.json')
    # NOTE(review): the path is joined onto data_directory, so results land
    # beneath the images folder -- confirm that nesting is intended.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as json_file:
        json.dump(results_dict, json_file, indent=4)

    return results_dict

In [None]:
# Run inference over the image folder and report the elapsed wall-clock time.
# The folder path is relative, so it is resolved against the kernel's current
# working directory.
# NOTE(review): process_images joins its output path onto this folder, so the
# JSON is written to vlm-drones/images/vlm-drones/output/inference_results.json
# -- confirm that location is intended.
PROCESS_IMAGES = time.time()
process_images('vlm-drones/images/')
print(f"Images processed in {time.time() - PROCESS_IMAGES} seconds")