# Foundation models for zero-shot detection and segmentation

Based on the [Ollama](https://github.com/ollama/ollama) project, which serves the local LLMs (LLaVA and Llama 3.1) used below.

## Preparing Python tools

In [None]:
import json
import base64
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection


class ZShotModel:
    def __init__(self, model: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model=AutoModelForZeroShotObjectDetection.from_pretrained(
            model).to(self.device)
        self.processor=AutoProcessor.from_pretrained(model)

    def infer(self, images: any,
              prompt: str | list[str],
              box_threshold: float=0.2,
              text_threshold: float=0.2):
        # VERY important: text queries need to be lowercased + end with a dot
        if isinstance(prompt, list):
            objects = [item for item in prompt if not item=='']
            text = " . ".join([f"{item}" for item in objects]).lower() + "."
        elif isinstance(prompt, str):
            text = prompt.lower() + "."
        else:
            raise ValueError(
                f"Error, prompt with type \"{type(prompt)}\" is not supported")
        inputs = self.processor(
            images=images, text=text, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[image.size[::-1]]
        )

        return results


def load_answer(filename: str) -> dict[str, any]:
    with open(filename, "r") as file:
        step2_response = json.loads(file.read())
    return step2_response["response"] if "response" in step2_response else step2_response


def llm_infer(model: str,
              prompt: str,
              image_file: str | list[str] | None=None,
              log_pattern: str | None=None,
              url: str = "http://localhost:11434/api/generate"
              ) -> dict[str, any]:
    payload = dict(model=model, prompt=prompt, stream=False)
    if image_file is not None:
        if isinstance(image_file, str):
            image_file = [image_file]
        images = []
        for image in image_file:
            with open(image, "rb") as file:
                encoded_image = base64.b64encode(file.read()).decode("ascii")
            images.append(encoded_image)
        payload["images"] = images

    if log_pattern is not None:
        with open(f"{log_pattern}_request.json", "w") as file:
            file.write(json.dumps(payload))

    reply = requests.post(url, json=payload)

    if log_pattern is not None:
        with open(f"{log_pattern}_reply.json", "w") as file:
            file.write(reply.content.decode("ascii"))

    return json.loads(reply.content.decode("ascii"))

## Installing Ollama and pulling LLMs

In [None]:
# Download the Ollama Linux amd64 executable into the working directory and
# make it executable (it is later run as ./ollama).
# NOTE(review): newer Ollama releases may ship Linux builds as an archive
# rather than a raw binary — confirm this URL still serves an executable.
!curl -L https://ollama.com/download/ollama-linux-amd64 -o ollama
!chmod +x ollama

In [None]:
import subprocess
import time

# Start the Ollama server in the background, then wait until its HTTP
# endpoint answers instead of sleeping a fixed 3 s and hoping it is up.
ollama_server = subprocess.Popen(["./ollama", "serve"])
_ollama_deadline = time.monotonic() + 60
while True:
    try:
        requests.get("http://localhost:11434/", timeout=1)
        break
    except requests.exceptions.RequestException:
        if ollama_server.poll() is not None:
            raise RuntimeError("ollama serve exited before becoming ready")
        if time.monotonic() > _ollama_deadline:
            raise TimeoutError("ollama serve did not answer within 60 s")
        time.sleep(0.5)

In [None]:
# Pull the two models used below: llava (image description) and llama3.1
# (extracting object nouns from that description).
!./ollama pull llava
!./ollama pull llama3.1

In [None]:
import requests


# Smoke test: send llama3.1 a trivial prompt and display the raw HTTP status
# code and body.
smoke_test_payload = {
    "model": "llama3.1",
    "prompt": "Why do you cry",
    "stream": False,
}
result = requests.post("http://localhost:11434/api/generate",
                       json=smoke_test_payload)
(result.status_code, result.content)

In [None]:
# Grounding DINO checkpoint used for zero-shot detection.
GDINO_CHECKPOINT = "IDEA-Research/grounding-dino-base"
gdino = ZShotModel(model=GDINO_CHECKPOINT)

## Experiment #1

In [None]:
# Fetch the test frame image_001.jpg and save it as image.jpg in the working directory.
!wget -q -O image.jpg https://github.com/ant-nik/neural_network_course/blob/main/practice_2_data/video_1_fixed/image_001.jpg?raw=true

In [None]:
# Ask LLaVA for a detailed description of the downloaded frame. The request
# and reply are logged to llava_request.json / llava_reply.json.
# Use the relative path: wget saved the frame in the working directory, and
# Image.open("image.jpg") later reads it from there too. The former
# "/content/..." path only worked on Colab.
llava_reply = llm_infer(
    model="llava",
    prompt="Describe entities on the image as detailed as possible.",
    image_file="image.jpg",
    log_pattern="llava",
)
print(llava_reply)

In [None]:
# Show only the generated description stored in the logged reply.
llava_answer_text = load_answer("llava_reply.json")
print(llava_answer_text)

In [None]:
# Feed only LLaVA's generated description to llama3.1. Interpolating the whole
# reply dict would paste the repr of every field of the server reply into the
# prompt.
llava_description = llava_reply["response"]
llama_prompt = f"""Extract all nouns from the TEXT section that are physical objects, living beings, dressing, parts of living beings or physical objects.
Split answer in two parts: OUTPUT and INFO.
In OUTPUT section place extracted nouns without enumerations symbols and one entity per line.
Put detailed explanation of the answer to INFO section.

TEXT:
{llava_description}

OUTPUT:
"""
llama_reply = llm_infer(
    model="llama3.1",
    prompt=llama_prompt,
    log_pattern="llama",
)
print(llama_reply)

In [None]:
# Replace the reply dict with just the generated text from the logged reply
# (load_answer returns its "response" field).
llama_reply = load_answer(filename="llama_reply.json")
print(llama_reply)

In [None]:
# Keep only the OUTPUT section of the llama answer, which should list one
# object name per line. Asterisks (presumably markdown emphasis) and colons
# are removed first.
cleaned_reply = llama_reply.replace("*", "").replace(":", "")
if "OUTPUT" not in cleaned_reply:
    raise ValueError(f"llama reply has no OUTPUT section:\n{llama_reply}")
output_section = cleaned_reply.split("OUTPUT")[1].split("INFO")[0]
# Strip whitespace, drop blank lines and duplicates, and sort so the object
# order (and hence the detector prompt) is reproducible between runs.
# Unlike list.remove(''), this does not fail when there is no blank line.
objects = sorted({line.strip() for line in output_section.splitlines() if line.strip()})
objects

In [None]:
# Load the downloaded frame; it is used for detection and for annotation.
image_path = "image.jpg"
image = Image.open(image_path)

In [None]:
# Detect the extracted objects, spelling out the default 0.2 thresholds.
detected_objects = gdino.infer(images=image, prompt=objects,
                               box_threshold=0.2, text_threshold=0.2)
detected_objects

In [None]:
# Map each detected label to an integer class id for the annotators.
labels = detected_objects[0]["labels"]
# Sorted rather than raw set order: string hashes are randomized per
# interpreter session, so set order (and thus ids/colours) would change
# between runs. Assumes the labels are mutually comparable (e.g. all str).
unique_classes = sorted(set(labels))
class_to_index_map = {item: index for index, item in enumerate(unique_classes)}
classes = [class_to_index_map[item] for item in labels]
unique_classes

In [None]:
import locale


def _utf8_preferred_encoding(do_setlocale=True):
    """Stand-in for locale.getpreferredencoding that always reports UTF-8."""
    return "UTF-8"


# Workaround for environments whose locale does not report UTF-8 (presumably
# needed for the shell/pip cells in this notebook). The stand-in keeps the real
# function's optional `do_setlocale` parameter, so callers passing it do not
# hit a TypeError as they did with `lambda: "UTF-8"`.
locale.getpreferredencoding = _utf8_preferred_encoding

In [None]:
# Install the supervision package used below to draw boxes and labels.
!pip install supervision

In [None]:
import cv2
import supervision
import numpy


box_annotator = supervision.BoxAnnotator()
label_annotator = supervision.LabelAnnotator()

# Convert the first (only) image's detections into supervision's format.
first_result = detected_objects[0]
box_scores = first_result["scores"].cpu().numpy()
image_boxes = supervision.Detections(
    xyxy=first_result["boxes"].cpu().numpy(),
    confidence=box_scores,
    class_id=numpy.array(classes, dtype=int),
)

# One caption per box: the detected label followed by its confidence score.
box_labels = [
    f"{label} {score:0.2f}"
    for label, score in zip(first_result["labels"], box_scores)
]

annotated_frame = box_annotator.annotate(scene=image.copy(),
                                         detections=image_boxes)
annotated_frame = label_annotator.annotate(
    scene=annotated_frame,
    detections=image_boxes,
    labels=box_labels,
)


In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Display the annotated frame; (16, 16) is passed as plot_image's size argument.
supervision.plot_image(annotated_frame, (16, 16))

## Object count by confidence-score threshold

In [None]:
# Re-run detection with low (0.1) thresholds so that the threshold sweep
# below also sees low-confidence boxes.
low_threshold = 0.1
all_results = gdino.infer(images=image, prompt=objects,
                          box_threshold=low_threshold,
                          text_threshold=low_threshold)

In [None]:
# Sweep the confidence threshold over 100 values from 0.01 to 1.0.
x = numpy.linspace(0.01, 1, 100)
sweep_scores = all_results[0]["scores"].cpu().numpy()
# counts[i] = number of boxes scoring strictly above x[i]. The comparison is
# done in the scores' own dtype, as the former per-element tensor comparison did.
counts = (sweep_scores[None, :] > x[:, None].astype(sweep_scores.dtype)).sum(axis=1)
# y[i] = change (<= 0) in that count between thresholds x[i] and x[i + 1].
y = numpy.diff(counts)

In [None]:
import plotly.express


# Plot the count change against threshold. x[0] is dropped so x and y have
# equal length after numpy.diff.
threshold_figure = plotly.express.line(x=x[1:], y=y)
threshold_figure

In [None]:
# Distribution of the per-step count changes.
count_change_histogram = plotly.express.histogram(y)
count_change_histogram

In [None]:
# Lower quantiles of the per-step count changes.
low_quantiles = [0.01, 0.05, 0.1, 0.15, 0.2]
numpy.quantile(y, low_quantiles)