In [None]:
import os
import xml.etree.ElementTree as ET
import xml.dom.minidom
import cv2
import torch
import torchvision
import shutil
from tqdm import tqdm
from groundingdino.util.inference import load_model, load_image, predict, annotate

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# GroundingDINO model files (Swin-T backbone). To switch to the Swin-B model,
# point these at "github_files/GroundingDINO_SwinB_cfg.py" and
# "github_files/groundingdino_swinb_cogcoor.pth" instead.
config_path = "github_files/GroundingDINO_SwinT_OGC.py"
weights_path = "github_files/groundingdino_swint_ogc.pth"

# Input and output folders.
image_directory = "frames"  # source images to annotate
annotation_directory = "annotations"  # Pascal VOC XML files are written here
annotated_image_directory = "annotated_images"  # images with boxes drawn on

# Text prompt and score thresholds handed to GroundingDINO's predict().
text_prompt = "all single syringes"
box_threshold = 0.35
text_threshold = 0.10


def reset_directory(directory: str) -> None:
    """Recreate *directory* as an empty folder.

    If something already exists at that path, it is removed recursively
    (files and sub-folders) before a fresh, empty directory is created.
    Missing parent directories are created as needed.
    """
    already_present = os.path.exists(directory)
    if already_present:
        # Wipe the whole previous tree in one call.
        shutil.rmtree(directory)
    # exist_ok=True: no error is raised if the path exists again at this point.
    os.makedirs(directory, exist_ok=True)


# Start every run from empty output folders. reset_directory() both clears and
# (re)creates each folder, so no separate os.makedirs() call is needed.
reset_directory(annotation_directory)
reset_directory(annotated_image_directory)

# Load GroundingDINO model. Detection below runs predict(..., device="cpu"),
# so build the model for the same device; groundingdino's load_model()
# otherwise defaults to device="cuda".
model = load_model(config_path, weights_path, device="cpu")

# Function to create Pascal VOC format XML annotation
def create_pascal_voc_xml(image_filename: str, image_shape: list[int], boxes: list[list[float]]) -> str:
    annotation = ET.Element("annotation")

    folder = ET.SubElement(annotation, "folder")
    folder.text = "images"

    filename = ET.SubElement(annotation, "filename")
    filename.text = image_filename

    size = ET.SubElement(annotation, "size")
    width = ET.SubElement(size, "width")
    width.text = str(image_shape[1])

    height = ET.SubElement(size, "height")
    height.text = str(image_shape[0])

    depth = ET.SubElement(size, "depth")
    depth.text = str(image_shape[2])

    for box in boxes:
        obj = ET.SubElement(annotation, "object")
        name = ET.SubElement(obj, "name")
        name.text = "syringe"  # Set fixed label

        pose = ET.SubElement(obj, "pose")
        pose.text = "Unspecified"

        truncated = ET.SubElement(obj, "truncated")
        truncated.text = "0"

        difficult = ET.SubElement(obj, "difficult")
        difficult.text = "0"

        bndbox = ET.SubElement(obj, "bndbox")
        xmin = ET.SubElement(bndbox, "xmin")
        xmin.text = str(int(box[0]))

        ymin = ET.SubElement(bndbox, "ymin")
        ymin.text = str(int(box[1]))

        xmax = ET.SubElement(bndbox, "xmax")
        xmax.text = str(int(box[2]))

        ymax = ET.SubElement(bndbox, "ymax")
        ymax.text = str(int(box[3]))

    return xml.dom.minidom.parseString(ET.tostring(annotation)).toprettyxml()

# Collect the images to annotate. os.listdir() returns entries in arbitrary
# order, so sort them for a deterministic processing order.
valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}
image_filenames = sorted(
    f
    for f in os.listdir(image_directory)
    if os.path.splitext(f)[1].lower() in valid_extensions
)

# Output files are named after the image's stem, so e.g. "a.jpg" and "a.png"
# would overwrite each other's results. Track stems so that gets reported.
seen_stems: set[str] = set()

for image_filename in tqdm(image_filenames):
    image_path = os.path.join(image_directory, image_filename)
    stem = os.path.splitext(image_filename)[0]
    if stem in seen_stems:
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write(
            f"Warning: '{image_filename}' has the same stem as an earlier image; "
            f"its outputs overwrite the earlier '{stem}' files."
        )
    seen_stems.add(stem)

    # image_source is used for its (h, w, c) shape and for drawing; image is
    # the input handed to predict().
    image_source, image = load_image(image_path)
    h, w, c = image_source.shape

    # Run GroundingDINO detection. The returned phrases are unused because
    # every detection is labelled "syringe".
    boxes, logits, _phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device="cpu",
    )

    # Boxes are normalized (cx, cy, w, h): scale them to pixels and convert to
    # corner form (xmin, ymin, xmax, ymax) as Pascal VOC expects.
    boxes_pixels = boxes * torch.tensor([w, h, w, h])
    xyxy_boxes = torchvision.ops.box_convert(
        boxes=boxes_pixels, in_fmt="cxcywh", out_fmt="xyxy"
    ).numpy()

    # Save Pascal VOC annotation. The XML has no encoding declaration, so it
    # must be written as UTF-8 regardless of the platform's default encoding.
    voc_xml_annotation = create_pascal_voc_xml(
        image_filename=image_filename,
        image_shape=[h, w, c],
        boxes=xyxy_boxes,
    )
    voc_xml_filename = os.path.join(annotation_directory, f"{stem}.xml")
    with open(voc_xml_filename, "w", encoding="utf-8") as voc_xml_file:
        voc_xml_file.write(voc_xml_annotation)

    # Save the annotated image. cv2.imwrite signals failure through its return
    # value instead of raising, so check it.
    annotated_frame = annotate(
        image_source=image_source,
        boxes=boxes,
        logits=logits,
        phrases=["syringe"] * len(boxes),
    )
    output_file = os.path.join(annotated_image_directory, f"{stem}_annotated.jpg")
    if not cv2.imwrite(output_file, annotated_frame):
        tqdm.write(f"Warning: could not write annotated image '{output_file}'.")

print(f"Annotations saved to: {annotation_directory}")
print(f"Annotated images saved to: {annotated_image_directory}")