## Zero-Shot Prompting

In [None]:
import base64
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

from openai import OpenAI
from tqdm import tqdm

# Read the key from the environment so the secret never has to be pasted into
# the notebook; falls back to the original empty placeholder when unset.
api_key = os.environ.get("OPENAI_API_KEY", "")
client = OpenAI(api_key=api_key)

def encode_image(image_path):
    """
    Read an image file and return its contents as base64 text.

    Args:
        image_path (str): Location of the image file on disk.

    Returns:
        str: The file's bytes, base64-encoded and decoded as UTF-8, or None
            when reading/encoding raises (the error is printed).
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
    
def llm_chat_completion(messages, model="gpt-4o-mini", max_tokens=300, temperature=0.0):
    """
    Send a conversation to the chat-completions endpoint and return the reply text.

    Args:
        messages (list): Message dictionaries making up the conversation.
        model (str): Name of the model to query.
        max_tokens (int, optional): Upper bound on response tokens. Defaults to 300.
        temperature (float, optional): Sampling temperature. Defaults to 0.0.

    Returns:
        str or None: Content of the first returned choice, or None when the
            call raises (the error is printed).
    """
    request = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    try:
        completion = client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as e:
        print(f"Error calling LLM: {e}")
        return None

def build_few_shot_messages(few_shot_prompt, user_prompt="Please analyze the following image:", detail="auto"):
    """
    Generates few-shot example messages from image-caption pairs.

    Each example becomes a user message (prompt text + image) followed by an
    assistant message holding the expected caption.

    Args:
        few_shot_prompt (dict): Maps image paths to the example caption. The
            value may be the caption string itself or a dict carrying it under
            the "image_caption" key. None or an empty dict yields no examples.
        user_prompt (str, optional): Text placed next to each example image.
        detail (str, optional): Level of image detail to include. Defaults to "auto".

    Returns:
        list: A list of few-shot example messages (two per usable example).
    """
    few_shot_messages = []
    if not few_shot_prompt:
        return few_shot_messages

    for path, data in few_shot_prompt.items():
        # Reuse the single-image builder so example and query messages share one shape.
        example_message = build_user_message(path, user_prompt=user_prompt, detail=detail)
        if example_message is None:
            continue  # skip if the image could not be encoded

        # Accept both a bare caption and the documented {"image_caption": ...} form.
        if isinstance(data, dict) and "image_caption" in data:
            caption = data["image_caption"]
        else:
            caption = data

        few_shot_messages.append(example_message)
        few_shot_messages.append({"role": "assistant", "content": caption})
    return few_shot_messages

def build_user_message(image_path, user_prompt="Please analyze the following image:", detail="auto"):
    """
    Build the user-role chat message that carries a prompt and one image.

    Args:
        image_path (str): Path to the image file.
        user_prompt (str, optional): Text part sent alongside the image.
        detail (str, optional): Value for the image part's "detail" field. Defaults to "auto".

    Returns:
        dict or None: The message dictionary, or None when the image could not
            be encoded.
    """
    encoded = encode_image(image_path)
    if encoded:
        text_part = {"type": "text", "text": user_prompt}
        image_part = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{encoded}",
                "detail": detail,
            },
        }
        return {"role": "user", "content": [text_part, image_part]}
    return None

def get_image_caption(
    image_path,
    few_shot_prompt=None,
    system_prompt="You are a helpful assistant that can analyze images and provide captions.",
    user_prompt="Please analyze the following image:",
    model="gpt-4o-mini",
    max_tokens=300,
    detail="auto",
    llm_chat_func=llm_chat_completion,
    temperature=0.0
):
    """
    Gets a caption for an image using a LLM.

    The conversation sent is: system prompt, optional few-shot example pairs,
    then the user message containing the target image.

    Args:
        image_path (str): File path of the image to be analyzed.
        few_shot_prompt (dict, optional): Maps example image paths to the
            caption used as the assistant reply for that example.
        system_prompt (str, optional): Initial system prompt for the LLM.
        user_prompt (str, optional): Text sent with the target image and with
            every few-shot example image.
        model (str, optional): LLM model name (default "gpt-4o-mini").
        max_tokens (int, optional): Max tokens in the response (default 300).
        detail (str, optional): Level of detail for the image analysis (default "auto").
        llm_chat_func (callable, optional): Function to call the LLM; invoked with
            keyword arguments model, messages, max_tokens and temperature.
            Defaults to `llm_chat_completion`.
        temperature (float, optional): Sampling temperature (default 0.0).

    Returns:
        str or None: The generated caption, or None on error.
    """
    try:
        # Keyword arguments matter here: the builders take user_prompt before
        # detail, so passing detail positionally would land in the prompt slot.
        user_message = build_user_message(image_path, user_prompt=user_prompt, detail=detail)
        if not user_message:
            return None

        # Build message sequence
        messages = [{"role": "system", "content": system_prompt}]

        # Include few-shot examples if provided
        if few_shot_prompt:
            messages.extend(
                build_few_shot_messages(few_shot_prompt, user_prompt=user_prompt, detail=detail)
            )

        messages.append(user_message)

        # Call the LLM
        return llm_chat_func(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature
        )

    except Exception as e:
        print(f"Error getting caption: {e}")
        return None
    
def process_images_in_parallel(
    image_paths,
    model="gpt-4o-mini",
    system_prompt="You are a helpful assistant that can analyze images and provide captions.",
    user_prompt="Please analyze the following image:",
    few_shot_prompt=None,
    max_tokens=300,
    detail="auto",
    max_workers=5,
    temperature=0.0):
    """
    Processes a list of images in parallel to generate captions using a specified model.

    Args:
        image_paths (iterable): File paths of the images to be processed.
        model (str): The model to use for generating captions (default is "gpt-4o-mini").
        system_prompt (str): System prompt forwarded to `get_image_caption`.
        user_prompt (str): User prompt forwarded to `get_image_caption`.
        few_shot_prompt (dict, optional): Few-shot examples forwarded to `get_image_caption`.
        max_tokens (int): Maximum number of tokens in the generated captions (default is 300).
        detail (str): Level of detail for the image analysis (default is "auto").
        max_workers (int): Number of threads to use for parallel processing (default is 5).
        temperature (float): Sampling temperature forwarded to `get_image_caption` (default is 0.0).

    Returns:
        dict: A dictionary where keys are image paths and values are their
            corresponding captions (None for images whose captioning failed).
    """
    captions = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Forward by keyword so this call cannot silently drift out of sync
        # with the parameter order of get_image_caption.
        future_to_image = {
            executor.submit(
                get_image_caption,
                image_path,
                few_shot_prompt=few_shot_prompt,
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                model=model,
                max_tokens=max_tokens,
                detail=detail,
                temperature=temperature,
            ): image_path
            for image_path in image_paths
        }

        # Use tqdm to track progress; the total comes from the submitted jobs so
        # image_paths may be any iterable, not only a sized list.
        for future in tqdm(as_completed(future_to_image), total=len(future_to_image), desc="Processing images"):
            image_path = future_to_image[future]
            try:
                captions[image_path] = future.result()
            except Exception as e:
                print(f"Error processing {image_path}: {e}")
                captions[image_path] = None
    return captions

## Process Zero-Shot Captions

In [None]:
# Zero-shot run: image detail level and the folder holding the sample images.
IMAGE_QUALITY = "high"
PATH_TO_SAMPLES = "images/"

system_prompt = """You are an AI assistant that provides captions of images. 
You will be provided with an image. Analyze the content, context, and notable features of the images.
Provide a concise caption that covers the important aspects of the image."""

user_prompt = "Please analyze the following image:"

# Every directory entry is treated as an image to caption.
image_paths = [
    os.path.join(PATH_TO_SAMPLES, file_name)
    for file_name in os.listdir(PATH_TO_SAMPLES)
]

# No few-shot examples: the model only sees the system prompt and the image.
zero_shot_high_quality_captions = process_images_in_parallel(
    image_paths,
    model="gpt-4o-mini",
    system_prompt=system_prompt,
    user_prompt=user_prompt,
    few_shot_prompt=None,
    detail=IMAGE_QUALITY,
    max_workers=5,
)

In [None]:
# open() does not create missing folders, so make sure "results" exists first.
os.makedirs("results", exist_ok=True)
with open("results/zero_shot_captions_results.json", "w", encoding="utf-8") as f:
    json.dump(zero_shot_high_quality_captions, f)

## Perform Few-Shot Prompting

In [None]:
# Load the caption lookup; the context manager closes the file handle and the
# explicit encoding keeps non-ASCII captions readable on every platform.
with open("image_captions.json", encoding="utf-8") as captions_file:
    image_captions = json.load(captions_file)

FEW_SHOT_EXAMPLES_PATH = "few_shot_examples/"

# Keep only the captions whose key matches a file present in the examples folder,
# re-keyed by the full path to that file.
few_shot_samples = os.listdir(FEW_SHOT_EXAMPLES_PATH)
few_shot_captions = {
    os.path.join(FEW_SHOT_EXAMPLES_PATH, k): v
    for k, v in image_captions.items()
    if k in few_shot_samples
}

In [None]:
IMAGE_QUALITY = "high"

# Few-shot run: no prompts are passed, so the function's default system and
# user prompts apply.
few_shot_high_quality_captions = process_images_in_parallel(
    image_paths,
    few_shot_prompt=few_shot_captions,
    model="gpt-4o-mini",
    max_workers=5,
    detail=IMAGE_QUALITY,
)

In [None]:
# open() does not create missing folders, so make sure "results" exists first.
os.makedirs("results", exist_ok=True)
with open("results/few_shot_captions_results.json", "w", encoding="utf-8") as f:
    json.dump(few_shot_high_quality_captions, f)

## Perform Chain-of-Thought Prompting

In [None]:
# Hand-written chain-of-thought exemplars. Each one lists intermediate
# observations and ends with a "Final Caption:" line.

# Exemplar 1: blurry city street.
few_shot_samples_1_cot = """Observations:
Visible Blur: The foreground and parts of the image are out of focus or blurred, indicating either camera movement or subject motion during the shot.
Tall, Ornate Buildings: The structures have multiple floors, detailed balconies, and decorative facades, suggesting older or classic urban architecture.
Street-Level View: Parked cars line both sides of a road or narrow street, confirming an urban environment with typical city traffic and infrastructure.
Soft, Warm Light: The sunlight appears to be hitting the buildings at an angle, creating a warm glow on the façade and enhancing the sense of a real city scene rather than a staged setup.

Final Caption: A blurry photo of a city street with buildings."""

# Exemplar 2: white car on desert dunes.
few_shot_samples_2_cot = """Observations:  
Elevated Desert Dunes: The landscape is made up of large, rolling sand dunes in a dry, arid environment.  
Off-Road Vehicle: The white SUV appears equipped for travel across uneven terrain, indicated by its size and ground clearance.  
Tire Tracks in the Sand: Visible tracks show recent movement across the dunes, reinforcing that the vehicle is in motion and navigating a desert path.  
View from Inside Another Car: The dashboard and windshield framing in the foreground suggest the photo is taken from a passenger’s or driver’s perspective, following behind or alongside the SUV.  

Final Caption: A white car driving on a desert road."""


# Exemplar 3: man beside a tent in the desert at night.
few_shot_samples_3_cot = """Observations:
Towering Rock Formations: The steep canyon walls suggest a rugged desert landscape, with sandstone cliffs rising on both sides.
Illuminated Tents: Two futuristic-looking tents emit a soft glow, indicating a nighttime scene with lights or lanterns inside.
Starry Night Sky: The visible stars overhead reinforce that this is an outdoor camping scenario after dark.
Single Male Figure: A man, seen from behind, stands near one of the tents, indicating he is likely part of the camping group.

Final Caption: A man standing next to a tent in the desert."""

In [None]:
import copy

# Start from an independent copy of the plain few-shot captions, then swap in
# the chain-of-thought texts for the three example images.
few_shot_samples_cot = copy.deepcopy(few_shot_captions)
few_shot_samples_cot.update(
    {
        "few_shot_examples/photo_1.jpg": few_shot_samples_1_cot,
        "few_shot_examples/photo_3.jpg": few_shot_samples_2_cot,
        "few_shot_examples/photo_4.jpg": few_shot_samples_3_cot,
    }
)

In [None]:
IMAGE_QUALITY = "high"

# System prompt that asks for intermediate observations before the caption.
system_prompt_cot_reasoning = """You are an AI assistant that generates captions for images. 
You will be provided with an image. Your task is to carefully analyze the content, context, and notable features of the image. 
Break down your reasoning into clear, intermediate observations, considering both prominent and subtle details. 
Finally, create a concise caption that accurately summarizes the key elements and context of the image."""

user_prompt = "Please analyze the following image:"

# Chain-of-thought run: the few-shot examples carry the reasoning-style answers.
cot_high_quality_captions = process_images_in_parallel(
    image_paths,
    model="gpt-4o-mini",
    system_prompt=system_prompt_cot_reasoning,
    user_prompt=user_prompt,
    few_shot_prompt=few_shot_samples_cot,
    detail=IMAGE_QUALITY,
    max_workers=5,
)

In [None]:
import re
def extract_final_caption(text):
    """
    Extracts the content that comes after "Final Caption:" in the provided text.

    Only the remainder of the line holding the caption is returned, with
    surrounding whitespace removed. Whitespace after the colon is optional.

    Args:
    text (str): The text from which to extract the caption.

    Returns:
    str: The extracted caption. When no "Final Caption:" marker is present, or
    when `text` is not a string (e.g. None from a failed LLM call), `text` is
    returned unchanged.
    """
    # Non-string input (typically None) has nothing to search; pass it through.
    if not isinstance(text, str):
        return text

    match = re.search(r"Final Caption:\s*(.*)", text)
    if match is None:
        # No marker: fall back to the full response rather than dropping it.
        return text
    return match.group(1).strip()


In [None]:
captions = []

# Reduce every raw chain-of-thought response to its final caption, keyed by
# the same image path.
cot_high_quality_captions_extracted = {
    image_path: extract_final_caption(raw_response)
    for image_path, raw_response in cot_high_quality_captions.items()
}

In [None]:
# Show each image path followed by the model's full (unextracted) response.
for key, value in cot_high_quality_captions.items():
    print(key, value, sep="\n")

In [None]:
# open() does not create missing folders, so make sure "results" exists first.
os.makedirs("results", exist_ok=True)
with open("results/cot_high_quality_captions_results.json", "w", encoding="utf-8") as f:
    json.dump(cot_high_quality_captions_extracted, f)
    

## Utilize Object Detection Prompting

In [None]:
# Load the zero-shot object detector and its matching pre-processor from the
# same pretrained checkpoint.
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

OWLVIT_CHECKPOINT = "google/owlvit-base-patch32"

processor = AutoProcessor.from_pretrained(OWLVIT_CHECKPOINT)
model = AutoModelForZeroShotObjectDetection.from_pretrained(OWLVIT_CHECKPOINT)

In [None]:
IMAGE_QUALITY = "high"

# Ask the LLM for a comma-separated object list in a fixed "Output: ..." format.
system_prompt_object_detection = """You are provided with an image. You must identify all important objects in the image, and provide a standardized list of objects in the image.
Return your output as follows:
Output: object_1, object_2"""

user_prompt = "Extract the objects from the provided image:"

detected_objects = process_images_in_parallel(
    image_paths,
    model="gpt-4o-mini",
    system_prompt=system_prompt_object_detection,
    user_prompt=user_prompt,
    few_shot_prompt=None,
    detail=IMAGE_QUALITY,
    max_workers=5,
)

In [None]:
detected_objects_cleaned = {}

for key, value in detected_objects.items():
    # A failed LLM call leaves None for that image; treat it as "no objects"
    # instead of crashing on None.replace.
    raw_labels = (value or "").replace("Output: ", "").split(",")
    stripped_labels = (label.strip() for label in raw_labels)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the result
    # is stable across runs (a set would reorder it); blank labels are dropped.
    detected_objects_cleaned[key] = list(
        dict.fromkeys(label for label in stripped_labels if label)
    )

In [None]:
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import torch

def detect_and_draw_bounding_boxes(
    image_path,
    text_queries,
    model,
    processor,
    output_path,
    score_threshold=0.2
):
    """
    Detect objects in an image and draw bounding boxes over the original image using PIL.

    Parameters:
    - image_path (str): Path to the image file.
    - text_queries (list of str): List of text queries to process. When empty,
      the image is saved without annotations.
    - model: Pretrained model to use for detection.
    - processor: Processor to preprocess image and text queries.
    - output_path (str): Path to save the output image with bounding boxes
      (always written as JPEG).
    - score_threshold (float): Threshold to filter out low-confidence predictions.

    Returns:
    - output_image_pil: A PIL Image object with bounding boxes and labels drawn.
    """
    img = Image.open(image_path).convert("RGB")
    orig_w, orig_h = img.size  # original width, height

    if not text_queries:
        # Nothing to look for: keep the output path populated with a plain copy.
        img.save(output_path, "JPEG")
        return img

    # Put the inputs on the model's device; falls back to CPU (the previous
    # hard-coded choice) when the model exposes no `device` attribute.
    inputs = processor(
        text=text_queries,
        images=img,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).to(getattr(model, "device", "cpu"))

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # Best-scoring query per predicted box.
    logits = torch.max(outputs["logits"][0], dim=-1)       # shape (num_boxes,)
    scores = torch.sigmoid(logits.values).cpu().numpy()    # convert to probabilities
    labels = logits.indices.cpu().numpy()                  # class indices
    boxes_norm = outputs["pred_boxes"][0].cpu().numpy()    # shape (num_boxes, 4)

    draw = ImageDraw.Draw(img)

    for score, (cx, cy, w, h), label_idx in zip(scores, boxes_norm, labels):
        if score < score_threshold:
            continue

        # Boxes are read as normalized (center x, center y, width, height);
        # scale to pixels and convert to corner coordinates.
        cx_abs = cx * orig_w
        cy_abs = cy * orig_h
        w_abs = w * orig_w
        h_abs = h * orig_h
        x1 = cx_abs - w_abs / 2.0
        y1 = cy_abs - h_abs / 2.0
        x2 = cx_abs + w_abs / 2.0
        y2 = cy_abs + h_abs / 2.0

        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)

        label_text = text_queries[label_idx].replace("An image of ", "")
        text_str = f"{label_text}: {score:.2f}"

        # ImageDraw.textsize was removed in Pillow 10; textbbox is its
        # replacement. Keep the old call as a fallback for older Pillow.
        if hasattr(draw, "textbbox"):
            _, _, text_w, text_h = draw.textbbox((0, 0), text_str)
        else:
            text_w, text_h = draw.textsize(text_str)

        text_x, text_y = x1, max(0, y1 - text_h)  # place text slightly above box

        # White backdrop so the red label stays readable on any image.
        draw.rectangle(
            [text_x, text_y, text_x + text_w, text_y + text_h],
            fill="white"
        )
        draw.text((text_x, text_y), text_str, fill="red")

    img.save(output_path, "JPEG")

    return img


In [None]:
import json

# open() does not create missing folders, so make sure "results" exists first.
os.makedirs("results", exist_ok=True)
with open("results/detected_objects.json", "w", encoding="utf-8") as f:
    json.dump(detected_objects_cleaned, f)

In [None]:
# Image.save does not create missing folders, so create the output folder up front.
os.makedirs("images_with_bounding_boxes", exist_ok=True)

for key, value in tqdm(detected_objects_cleaned.items()):
    if not value:
        # No object labels for this image: there is nothing to query the detector with.
        print(f"No detected objects for {key}; skipping bounding boxes.")
        continue

    text_queries = ["An image of " + x for x in value]
    # Save under the same file name inside the bounding-box folder.
    output_path = os.path.join("images_with_bounding_boxes", os.path.basename(key))
    detect_and_draw_bounding_boxes(key, text_queries, model, processor, output_path, score_threshold=0.15)

In [None]:
IMAGE_QUALITY = "high"

# Point at the annotated copies, which are saved under images_with_bounding_boxes/
# with the original file name. Deriving the path from the base name works whatever
# folder the source images live in (a textual replace of "downloaded_images" left
# paths under "images/" unchanged, so the unannotated images were captioned).
image_paths_obj_detected_guided = [
    os.path.join("images_with_bounding_boxes", os.path.basename(x)) for x in image_paths
]

# NOTE: no trailing comma after these assignments. A trailing comma turns the
# value into a one-element tuple, which is not a valid message content string.
system_prompt_obj_det = """You are a helpful assistant that can analyze images and provide captions. You are provided with images that also contain bounding box annotations of the important objects in them, along with their labels.
Analyze the overall image and the provided bounding box information and provide an appropriate caption for the image."""

user_prompt = "Please analyze the following image:"

obj_det_zero_shot_high_quality_captions = process_images_in_parallel(
    image_paths_obj_detected_guided,
    system_prompt=system_prompt_obj_det,
    user_prompt=user_prompt,
    model="gpt-4o-mini",
    few_shot_prompt=None,
    detail=IMAGE_QUALITY,
    max_workers=5,
)

In [None]:
# open() does not create missing folders, so make sure "results" exists first.
os.makedirs("results", exist_ok=True)
with open("results/zero_shot_object_detection_guided_captions.json", "w", encoding="utf-8") as f:
    json.dump(obj_det_zero_shot_high_quality_captions, f)