[![Labellerr](https://storage.googleapis.com/labellerr-cdn/%200%20Labellerr%20template/notebook.webp)](https://www.labellerr.com)

# KOSMOS-2

---

[![labellerr](https://img.shields.io/badge/Labellerr-BLOG-black.svg)](https://www.labellerr.com/blog/<BLOG_NAME>)
[![Youtube](https://img.shields.io/badge/Labellerr-YouTube-b31b1b.svg)](https://www.youtube.com/@Labellerr)
[![Github](https://img.shields.io/badge/Labellerr-GitHub-green.svg)](https://github.com/Labellerr/Hands-On-Learning-in-Computer-Vision)
[![Scientific Paper](https://img.shields.io/badge/Official-Paper-blue.svg)](https://arxiv.org/abs/2306.14824)

In [None]:
from PIL import Image
import requests
from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
from io import BytesIO
import os

## Helper Function

In [None]:
def show_image(source, timeout=30):
    """
    Display an image from a URL or a local file path.

    Failures are reported with a printed message instead of an exception, so a
    bad source does not abort the notebook run.

    Args:
        source (str | os.PathLike): The URL (http:// or https://) or local
            file path of the image.
        timeout (float): Seconds to wait for the server when ``source`` is a
            URL. Without a timeout, ``requests`` can wait indefinitely.
    """
    try:
        # Normalise pathlib.Path and other path-like objects to str so the
        # prefix check below also works for them.
        source = os.fspath(source)
        if source.startswith(("http://", "https://")):
            # Load image from URL; raise on HTTP errors so an error page is
            # never handed to PIL.
            response = requests.get(source, timeout=timeout)
            response.raise_for_status()
            image_file = BytesIO(response.content)
        elif os.path.exists(source):
            # Load image from local file path
            image_file = source
        else:
            raise ValueError("Invalid source. Provide a valid URL or local file path.")

        # The context manager closes the underlying file once the image has
        # been rendered. NOTE: `display` is not imported in this file;
        # presumably it is the IPython builtin available inside Jupyter.
        with Image.open(image_file) as img:
            display(img)

    except Exception as e:
        # Deliberately best-effort: report the problem and keep going.
        print(f"Error displaying image: {e}")

In [None]:
# Hugging Face Hub checkpoint shared by the processor and the model.
checkpoint = "microsoft/kosmos-2-patch14-224"

# The processor turns (text, image) pairs into model inputs and decodes
# generated ids back to text; the model performs the image-conditioned
# generation.
processor = AutoProcessor.from_pretrained(checkpoint)
model = Kosmos2ForConditionalGeneration.from_pretrained(checkpoint)

In [None]:
# Fetch the sample image.
# - The timeout stops the cell from hanging on an unresponsive host.
# - raise_for_status() surfaces HTTP errors instead of passing an error page to PIL.
# - Response.content, unlike Response.raw, is decoded from any gzip/deflate
#   content-encoding.
url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content))

prompt = "<grounding> An image of"

# Tokenize the prompt and preprocess the image into model inputs.
inputs = processor(text=prompt, images=image, return_tensors="pt")

generated_ids = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    image_embeds=None,
    image_embeds_position_mask=inputs["image_embeds_position_mask"],
    use_cache=True,
    max_new_tokens=64,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# With cleanup_and_extract=False the result is a single post-processed string.
processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
print(processed_text)

# With the default arguments the result is a (caption, entities) pair.
caption, entities = processor.post_process_generation(generated_text)
print(caption)

print(entities)

In [None]:
def kosmos2_generate(image, prompt="<grounding> An image of", cleanup_and_extract=False, max_new_tokens=256):
    """
    Generate text from an image using the Kosmos-2 model.

    Relies on the module-level ``model`` and ``processor`` loaded earlier in
    the notebook.

    Args:
        image (PIL.Image.Image): The input image.
        prompt (str): The text prompt to guide the generation.
        cleanup_and_extract (bool): Forwarded to
            ``processor.post_process_generation``. It controls the return type
            (see Returns).
        max_new_tokens (int): Maximum number of tokens to generate
            (default 256).

    Returns:
        str | tuple: A single post-processed string when
        ``cleanup_and_extract`` is False, or a ``(caption, entities)`` pair
        when it is True.
    """
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds=None,
        image_embeds_position_mask=inputs["image_embeds_position_mask"],
        use_cache=True,
        max_new_tokens=max_new_tokens,
    )

    # batch_decode returns one string per sequence; only one image/prompt
    # pair is passed in, so take the first.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=cleanup_and_extract)

    return processed_text

In [None]:
# Fetch the snowman image with a timeout and HTTP status check. Decode the
# body via Response.content, which unlike Response.raw honours
# content-encoding. The resize forces 500x500 regardless of the original
# aspect ratio.
url1 = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
response1 = requests.get(url1, timeout=30)
response1.raise_for_status()
image1 = Image.open(BytesIO(response1.content)).resize((500, 500))
display(image1)

In [None]:
# Free-form question; with the default cleanup_and_extract=False the cell shows a single string.
kosmos2_generate(image=image1, prompt="what will happen to snowman?")

In [None]:
# Open-ended scene description; the result is a single post-processed string.
kosmos2_generate(image=image1, prompt="what is happening in this image?")

In [None]:
# Same grounded captioning prompt as the function's default.
kosmos2_generate(image=image1, prompt="<grounding> An image of")

In [None]:
question = ["what is the color of the snowman?",
            "<grounding> what is the color of the snowman's hat?",
            "what is snowman doing in the image?",
            "is something bad going to happen in the image?",
            "what will happen to snowman?",
            "what will happen if the snowman melts?"]

for q in question:
    print("Question:", q)
    # cleanup_and_extract=True makes kosmos2_generate return a
    # (caption, entities) pair. With the default (False) it returns one
    # string, so answer[0] and answer[1] would just be its first two characters.
    answer = kosmos2_generate(image1, prompt=q, cleanup_and_extract=True)
    caption, entities = answer
    print("Answer:", caption, "\n", entities)
    print("-" * 50)

In [None]:
# Fetch the elephant image with a timeout and HTTP status check. Decode the
# body via Response.content, which unlike Response.raw honours content-encoding.
url = "https://farm7.staticflickr.com/6076/6081598580_50d7e63633_z.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
image2 = Image.open(BytesIO(response.content))
display(image2)

In [None]:
question = ["<grounding> An image of",
            "Which animals are in this image?",
            "What are the elephants doing in this image?"]

for q in question:
    print("Question:", q)
    # cleanup_and_extract=True makes kosmos2_generate return a
    # (caption, entities) pair. With the default (False) it returns one
    # string, so answer[0] and answer[1] would just be its first two characters.
    answer = kosmos2_generate(image2, prompt=q, cleanup_and_extract=True)
    caption, entities = answer
    print("Answer:", caption, "\n", entities)
    print("-" * 50)

In [None]:
# Fetch the image with a timeout and HTTP status check. Decode the body via
# Response.content, which unlike Response.raw honours content-encoding.
url = "https://i.pinimg.com/736x/9e/c3/26/9ec3269cf6bdcccbfcf64c1cd4c9a453.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
image3 = Image.open(BytesIO(response.content))
display(image3)

In [None]:
question = ["<grounding> An image of ",
            "Where is this image taken?",
            "What is the tallest object in the image?",
            "Any person in the image?",
            "<grounding> where is person in the image?"]

for q in question:
    print("Question:", q)
    # cleanup_and_extract=True makes kosmos2_generate return a
    # (caption, entities) pair. With the default (False) it returns one
    # string, so answer[0] and answer[1] would just be its first two characters.
    answer = kosmos2_generate(image3, prompt=q, cleanup_and_extract=True)
    caption, entities = answer
    print("Answer:", caption, "\n", entities)
    print("-" * 50)

In [None]:
# Fetch the image with a timeout and HTTP status check. Decode the body via
# Response.content, which unlike Response.raw honours content-encoding.
# NOTE(review): this is the same URL used for image3; confirm whether a
# different image was intended here.
url = "https://i.pinimg.com/736x/9e/c3/26/9ec3269cf6bdcccbfcf64c1cd4c9a453.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
image4 = Image.open(BytesIO(response.content))
display(image4)

In [None]:
question = ["<grounding> An image of ",
            "Where is this image taken?",
            "What is the tallest object in the image?",
            "Any person in the image?",
            "where is person in the image?"]

for q in question:
    print("Question:", q)
    # cleanup_and_extract=True makes kosmos2_generate return a
    # (caption, entities) pair. With the default (False) it returns one
    # string, so answer[0] and answer[1] would just be its first two characters.
    answer = kosmos2_generate(image4, prompt=q, cleanup_and_extract=True)
    caption, entities = answer
    print("Answer:", caption, "\n", entities)
    print("-" * 50)