# Foundation models for zero-shot detection and segmentation

Based on the [Ollama](https://github.com/ollama/ollama) project.

## Preparing Python tools

In [None]:
# Install supervision, used below by bbox_image to draw boxes and labels.
!pip install supervision

In [None]:
import locale

# Report UTF-8 as the preferred encoding regardless of the system locale.
# The replacement accepts the optional ``do_setlocale`` argument of the real
# ``locale.getpreferredencoding`` so callers that pass it keep working.
# NOTE(review): presumably a workaround for notebook hosts (e.g. Colab) whose
# locale is not UTF-8 — confirm it is still needed.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

In [None]:
import json
import base64
import requests
import torch
import PIL
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
import cv2
import supervision
import numpy
import datetime


# Cache of loaded detection models keyed by checkpoint name, so each model is
# loaded (and moved to its device) only once per session.
__detection_models = {}


def get_model(name: str) -> "ZShotModel":
    """Return the ZShotModel for checkpoint ``name``, loading it on first use.

    Args:
        name: Checkpoint name or path passed to ``ZShotModel(model=...)``.

    Returns:
        The cached ZShotModel instance for ``name``.
    """
    if name not in __detection_models:
        __detection_models[name] = ZShotModel(model=name)
    return __detection_models[name]


class ZShotModel:
    """Text-prompted zero-shot object detector built on a Hugging Face checkpoint.

    The model and its processor are loaded once in ``__init__`` (the model is
    moved to CUDA when available) and reused by every ``infer`` call.
    """

    def __init__(self, model: str):
        """Load the detection model and its processor.

        Args:
            model: Checkpoint name or path passed to ``from_pretrained``
                (e.g. "IDEA-Research/grounding-dino-base").
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForZeroShotObjectDetection.from_pretrained(
            model).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model)

    def infer(self, images: PIL.Image.Image | list[PIL.Image.Image],
              prompt: str | list[str],
              box_threshold: float=0.2,
              text_threshold: float=0.2):
        """Detect the objects named in ``prompt``; return the first image's result.

        Args:
            images: One image or a list of images.
            prompt: Either a ready-made query string or a list of object
                names. For a list, empty names are dropped and the rest are
                joined with " . ". In both cases the query is lowercased and
                a trailing "." is appended.
            box_threshold: Passed to the processor's
                ``post_process_grounded_object_detection``.
            text_threshold: Passed to the processor's
                ``post_process_grounded_object_detection``.

        Returns:
            The post-processed result for the first image (``results[0]``,
            holding e.g. "boxes", "scores" and "labels", with boxes rescaled
            to that image's size), extended with:
              - "class2map": sorted list of distinct labels (class id -> label),
              - "label2map": dict label -> class id,
              - "label_classes": class id of each detection, aligned with
                "labels".

        Raises:
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        # VERY important: text queries need to be lowercased + end with a dot
        if isinstance(prompt, list):
            objects = [str(item) for item in prompt if item != '']
            text = " . ".join(objects).lower() + "."
        elif isinstance(prompt, str):
            text = prompt.lower() + "."
        else:
            raise ValueError(
                f"Error, prompt with type \"{type(prompt)}\" is not supported")
        inputs = self.processor(
            images=images, text=text, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Rescale boxes to the images actually passed in. PIL's ``size`` is
        # (width, height) while target_sizes expects (height, width).
        image_list = images if isinstance(images, list) else [images]
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[img.size[::-1] for img in image_list]
        )

        result = results[0]
        labels = result["labels"]
        # Sort so class ids are stable across runs (set iteration order of
        # strings is not).
        class2label = sorted(set(labels))
        label2class = {label: index for index, label in enumerate(class2label)}
        result["class2map"] = class2label
        result["label2map"] = label2class
        result["label_classes"] = [label2class[label] for label in labels]
        return result


def load_answer(filename: str) -> dict[str, any]:
    with open(filename, "r") as file:
        step2_response = json.loads(file.read())
    return step2_response["response"] if "response" in step2_response else step2_response


def llm_infer(model: str,
              prompt: str,
              image_file: str | list[str] | None=None,
              log_pattern: str | None=None,
              url: str = "http://localhost:11434/api/generate",
              timeout: float | None = None,
              ) -> dict[str, object]:
    """Send one non-streaming generation request to an Ollama server.

    Args:
        model: Ollama model name (e.g. "llava", "llama3.1").
        prompt: Prompt text.
        image_file: Path, or list of paths, of images to attach. Each file is
            sent base64-encoded in the payload's "images" list.
        log_pattern: If given, the request payload is written to
            ``{log_pattern}_request.json`` and the raw reply body to
            ``{log_pattern}_reply.json``.
        url: Generation endpoint.
        timeout: Seconds to wait for the server, passed to ``requests.post``.
            ``None`` (the default) waits indefinitely.

    Returns:
        The decoded JSON reply. The generated text is under "response".

    Raises:
        requests.HTTPError: If the server answers with an error status. The
            reply is logged before raising.
        json.JSONDecodeError: If the reply body is not valid JSON.
    """
    payload = dict(model=model, prompt=prompt, stream=False)
    if image_file is not None:
        if isinstance(image_file, str):
            image_file = [image_file]
        images = []
        for path in image_file:
            with open(path, "rb") as file:
                images.append(base64.b64encode(file.read()).decode("ascii"))
        payload["images"] = images

    if log_pattern is not None:
        with open(f"{log_pattern}_request.json", "w", encoding="utf-8") as file:
            json.dump(payload, file)

    reply = requests.post(url, json=payload, timeout=timeout)
    # JSON bodies are UTF-8 (RFC 8259) and model output routinely contains
    # non-ASCII characters, so decode as UTF-8 rather than ASCII.
    reply_text = reply.content.decode("utf-8")

    if log_pattern is not None:
        with open(f"{log_pattern}_reply.json", "w", encoding="utf-8") as file:
            file.write(reply_text)

    # Fail loudly on e.g. an unknown model instead of returning an error
    # object that makes callers crash later on a missing "response" key.
    reply.raise_for_status()
    return json.loads(reply_text)


def text_to_objects(
        text: str,
        start: str="OUTPUT",
        end: str="INFO",
        rem: str=":*") -> list[str]:
    """Extract object names listed one per line in a section of an LLM answer.

    Every character in ``rem`` is first removed from the whole text, e.g.
    markdown "**" and the ":" after section titles. The section is then the
    text after the first ``start`` marker, up to the next ``start`` or
    ``end`` marker.

    Args:
        text: Raw LLM answer.
        start: Marker opening the object list.
        end: Marker closing the object list.
        rem: Characters to delete before parsing.

    Returns:
        Unique, whitespace-stripped, non-empty lines of the section, in order
        of first appearance.

    Raises:
        ValueError: If ``start`` does not occur in the cleaned text.
    """
    clean_text = text.translate(str.maketrans("", "", rem))
    sections = clean_text.split(start)
    if len(sections) < 2:
        raise ValueError(f"Marker {start!r} not found in the LLM answer")
    body = sections[1].split(end)[0]
    stripped = (line.strip() for line in body.splitlines())
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(line for line in stripped if line))


def bbox_image(image: PIL.Image.Image, bbox: dict[str, object]) -> None:
    """Draw detection boxes and labels on a copy of ``image`` and plot it.

    Args:
        image: Image to annotate. A copy is drawn on; the original is not
            modified.
        bbox: Detection result such as the one returned by
            ``ZShotModel.infer``. Uses "boxes" (tensor of xyxy boxes),
            "label_classes" (integer class id per box, passed to supervision
            as ``class_id``) and "labels" (text drawn for each box).
    """
    detections = supervision.Detections(
        xyxy=bbox["boxes"].cpu().numpy(),
        class_id=numpy.array(bbox["label_classes"], dtype=int),
    )
    annotated_frame = supervision.BoxAnnotator().annotate(
        scene=image.copy(), detections=detections)
    annotated_frame = supervision.LabelAnnotator().annotate(
        scene=annotated_frame,
        detections=detections,
        labels=bbox["labels"],
    )
    supervision.plot_image(annotated_frame, (16, 16))


def detect_all_objects(
        image: str,
        captioning_prompt: str,
        object_extraction_prompt: str,
        bbox_threshold: float=0.1,
        text_threshold: float=0.1,
        captioning_model: str="llava",
        parsing_model: str="llama3.1",
        detection_model: str="IDEA-Research/grounding-dino-base",
        start_objects_list_marker: str="OUTPUT",
        end_objects_list_marker: str="INFO",
        symbols_to_remove: str="*:",
        capturing_iterations: int=1,
        parsing_iterations: int=1
) -> dict[str, object]:
    """Caption an image, extract object names from the caption, detect them.

    Pipeline:
      1. ``captioning_model`` describes ``image`` with ``captioning_prompt``.
         This runs ``capturing_iterations`` times.
      2. For each description, ``parsing_model`` is asked
         ``parsing_iterations`` times to list the objects it mentions. The
         prompt is ``object_extraction_prompt`` with ``{}`` replaced by the
         description via ``str.format``. The list is parsed by
         ``text_to_objects`` using the marker/symbol arguments.
      3. The union of all extracted names, sorted, is the text prompt of
         the zero-shot ``detection_model``.

    LLM requests and replies are logged by ``llm_infer`` under the pattern
    ``<model>_<captioning_model>_<UTC timestamp>``. With several iterations,
    later calls to the same model overwrite earlier logs.

    Args:
        image: Path of the image file.
        bbox_threshold: Box threshold forwarded to the detector.
        text_threshold: Text threshold forwarded to the detector.

    Returns:
        The detection result returned by ``ZShotModel.infer``.
    """
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime(
        "%Y-%m-%d_%H.%M.%S_%f")
    calculation_id = f"{captioning_model}_{timestamp}"

    objects = set()
    for _ in range(capturing_iterations):
        caption_reply = llm_infer(
            model=captioning_model,
            prompt=captioning_prompt,
            image_file=image,
            log_pattern=f"{captioning_model}_{calculation_id}",
        )
        extraction_prompt = object_extraction_prompt.format(
            caption_reply["response"])
        for _ in range(parsing_iterations):
            parsing_reply = llm_infer(
                model=parsing_model,
                prompt=extraction_prompt,
                log_pattern=f"{parsing_model}_{calculation_id}",
            )
            objects.update(text_to_objects(
                text=parsing_reply["response"],
                start=start_objects_list_marker,
                end=end_objects_list_marker,
                rem=symbols_to_remove,
            ))

    # Load the pixels and close the file right away. The mode is normalized
    # (e.g. RGBA/grayscale -> RGB) so the processor gets 3-channel input.
    with PIL.Image.open(image) as image_file:
        image_data = image_file.convert("RGB")

    # Sorted so the detector prompt does not depend on set iteration order.
    return get_model(name=detection_model).infer(
        images=image_data,
        prompt=sorted(objects),
        box_threshold=bbox_threshold,
        text_threshold=text_threshold,
    )


## Installing Ollama and pulling LLMs

In [None]:
# Download the Ollama Linux (amd64) build into ./ollama and make it executable.
# NOTE(review): newer Ollama releases may be distributed as an archive
# (e.g. .tgz); verify this URL still returns the executable.
!curl -L https://ollama.com/download/ollama-linux-amd64 -o ollama
!chmod +x ollama

In [None]:
import subprocess
import time


# Start the Ollama server in the background; llm_infer and the requests below
# talk to it on localhost:11434.
ollama_server = subprocess.Popen(["./ollama", "serve"])

# Poll until the server answers instead of sleeping a fixed amount of time:
# the following cells (`ollama pull`, generation requests) fail if it is not
# up yet.
_server_deadline = time.monotonic() + 60
while True:
    try:
        requests.get("http://localhost:11434", timeout=1)
        break
    except requests.exceptions.RequestException:
        if ollama_server.poll() is not None:
            raise RuntimeError(
                f"ollama serve exited with code {ollama_server.returncode}")
        if time.monotonic() > _server_deadline:
            raise RuntimeError("Ollama server did not start within 60 seconds")
        time.sleep(0.5)

In [None]:
# Pull the two LLMs used below: "llava" (image captioning) and "llama3.1"
# (object-name extraction). They are the defaults of detect_all_objects.
!./ollama pull llava
!./ollama pull llama3.1

In [None]:
import requests


# Smoke test: send a short non-streaming prompt to llama3.1 and show the HTTP
# status code and raw response body, confirming the server and model respond.
result = requests.post("http://localhost:11434/api/generate", json={
        "model": "llama3.1",
        "prompt": "Why do you cry",
        "stream": False
    })
(result.status_code, result.content)

## Experiment #1

In [None]:
# Download the sample frame used in the experiments to ./image.jpg.
# NOTE(review): later cells read "/content/image.jpg", so this presumably runs
# with /content as the working directory (Colab) — confirm elsewhere.
!wget -q -O image.jpg https://github.com/ant-nik/neural_network_course/blob/main/practice_2_data/video_1_fixed/image_001.jpg?raw=true

In [None]:
# Prompt for the captioning (vision) model.
llava_prompt = "Describe entities on the image as detailed as possible."
# Prompt for the object-extraction model. "{}" is filled with the caption via
# str.format, so any literal braces added here must be doubled. The OUTPUT and
# INFO section names must match the start/end markers given to text_to_objects
# (its defaults are "OUTPUT" and "INFO").
llama_prompt = """Extract all nouns from the TEXT section that are physical objects, living beings, dressing, parts of living beings or physical objects.
Split answer in two parts: OUTPUT and INFO.
In OUTPUT section place extracted nouns without enumerations symbols and one entity per line.
Put detailed explanation of the answer to INFO section.

TEXT:
{}

OUTPUT:
"""

In [None]:
%matplotlib inline
# Run the caption -> extract -> detect pipeline on the sample image with the
# default thresholds, then draw the detected boxes and labels.
image = PIL.Image.open("/content/image.jpg")
detected_objects = detect_all_objects(
        image="/content/image.jpg",
        captioning_prompt=llava_prompt,
        object_extraction_prompt=llama_prompt)
bbox_image(image=image, bbox=detected_objects)

In [None]:
# print(load_answer("/content/llama3.1_llava_2024-08-23_14.35.15_810177_request.json")["prompt"])

# Object counts by confidence-score threshold

In [None]:
# Run the pipeline again (including fresh LLM calls) with very low thresholds
# so almost every candidate box is kept. Its scores feed the analysis below.
all_results = detect_all_objects(
        image="/content/image.jpg",
        captioning_prompt=llava_prompt,
        object_extraction_prompt=llama_prompt,
        bbox_threshold=0.01,
        text_threshold=0.01
)

In [None]:
# Threshold grid from 0.01 to 1 (100 points).
x = numpy.linspace(0.01, 1, 100)
# Number of detections scoring above each threshold, computed for all
# thresholds at once by broadcasting scores (rows) against thresholds (columns).
scores = all_results["scores"].cpu().numpy()
counts_above = (scores[:, None] > x[None, :]).sum(axis=0)
# Consecutive differences are non-positive: -y[i] is the number of detections
# whose score lies in (x[i], x[i + 1]].
y = numpy.diff(counts_above)

In [None]:
import plotly.express


# Plot the change in detection count per threshold step. numpy.diff returns
# one element fewer than x, so y[i] is paired with the interval's upper end x[i + 1].
plotly.express.line(x=x[1:], y=y)

In [None]:
# Distribution of the per-step changes in detection count.
plotly.express.histogram(y)

In [None]:
# Low quantiles of y. Since y is non-positive, these reflect the threshold
# steps with the largest drops in detection count.
numpy.quantile(y, [0.01, 0.05, 0.1, 0.15, 0.2])