In [1]:
# Load the Gemini API key from a local file so the secret never appears in the notebook.
with open('key.txt', "r", encoding="utf-8") as file:
    api_key = file.read().strip()

# Fail fast with a clear message instead of an opaque authentication error at request time.
if not api_key:
    raise ValueError("key.txt is empty; expected it to contain the Gemini API key.")

In [2]:
import time
import random
import torch
import torchvision.transforms as transforms
import google.generativeai as genai
from PIL import Image

from fixed_prompts import classification_p, description_p, class_ps
from cross_modal_encoder import encoder


# Hand the key read from key.txt to the google.generativeai client before any model is created.
genai.configure(api_key=api_key)


def gemini_process(prompt, image=None, temperature=0.99):
    """
    Generates text with Gemini ("models/gemini-1.5-flash") from a prompt, optionally with an image.
    Args:
    prompt: The text prompt sent to the model.
    image: Optional image sent after the prompt (callers in this notebook pass a PIL image).
    temperature: Sampling temperature for generating diverse responses.
    Returns:
    The generated response as a string, or the unchanged prompt when no text can be
    extracted from the response.
    """
    model = genai.GenerativeModel("models/gemini-1.5-flash")

    # The request is the prompt alone, or the prompt followed by the image.
    input_content = [prompt] if image is None else [prompt, image]

    response = model.generate_content(
        contents=input_content,
        generation_config={"temperature": temperature, "max_output_tokens": 256}
    )

    try:
        # Extract the text once; a response without candidates/parts raises here.
        result = response.candidates[0].content.parts[0].text
        print(f'Prompt: {prompt}\nResponse: {result}')
    except Exception as e:
        # Best-effort fallback: hand the prompt back so callers can still encode some text.
        print(f"Error in response generation: {e}")
        result = prompt

    # Quota: 15 rpm. The request above counts against the quota whether or not its
    # response was usable, so pause on both the success and the fallback path.
    time.sleep(4 + random.random())
    return result

  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def create_classifier(class_names, k=10):
    """
    Builds the weights of a zero-shot image classifier from text features.

    For each class, three text features are summed: the encoded class name, the encoded
    "A photo of ..." template, and the mean of k encoded LLM descriptions. The descriptions
    come from the module-level class_ps prompt templates, cycled k // len(class_ps) times.
    The sum is divided by its norm along the last dimension.
    Args:
    class_names: A list of class names.
    k: Number of class descriptions to be generated by the LLM for each class.
    Returns:
    A dict with "weights" (the stacked per-class features, transposed) and
    "class_names" (the list that was passed in).
    """
    assert k >= len(class_ps), "k should be greater than or equal to the number of class prompts."
    assert k % len(class_ps) == 0, "k should be a multiple of the number of class prompts."

    def mean_llm_feature(name):
        # Running sum of encoded descriptions, accumulated in place and then averaged over k.
        total = torch.zeros((1, encoder.output_feature_length))
        for _ in range(k // len(class_ps)):
            for template in class_ps:
                description = gemini_process(template.format(class_name=name), temperature=0.99)
                total += encoder.encode_text(description)
        total /= k
        return total

    def class_weight(name):
        # Name and template are encoded before any LLM call is issued for this class.
        name_feature = encoder.encode_text(name)
        photo_feature = encoder.encode_text(f"A photo of {name}")
        combined = name_feature + photo_feature + mean_llm_feature(name)
        return (combined / combined.norm(dim=-1, keepdim=True)).squeeze()

    stacked = torch.stack([class_weight(name) for name in class_names])
    return {"weights": stacked.T, "class_names": class_names}


def classify(image, classifier):
    """
    Performs zero-shot classification of a single image.

    The query is the sum of three features: the encoded image, the encoded text of the
    LLM's own class guess (prompted with the module-level classification_p), and the
    encoded text of the LLM's image description (prompted with the module-level
    description_p). The sum is divided by its norm and multiplied with the classifier weights.
    Args:
    image: Input testing image.
    classifier: A zero-shot classification model generated by create_classifier function.
    Returns:
    Predicted class name.
    """
    def ask_and_embed(prompt):
        # Send the prompt together with the image to the LLM, then encode its textual answer.
        return encoder.encode_text(gemini_process(prompt, image, temperature=0.99))

    visual_feature = encoder.encode_image(image)
    label_names = classifier["class_names"]

    # First LLM call: an initial guess among the known labels.
    guess_feature = ask_and_embed(classification_p.format(classes=label_names))
    # Second LLM call: a free-form description of the image.
    caption_feature = ask_and_embed(description_p)

    fused = visual_feature + guess_feature + caption_feature
    fused /= fused.norm(dim=-1, keepdim=True)

    scores = torch.matmul(fused, classifier["weights"])
    best = torch.argmax(scores, dim=-1).item()
    return label_names[best]

In [4]:
def load_image(image_path, image_size=224):
    """
    Loads an image from disk as an RGB PIL image.
    Args:
    image_path: Path to the image file.
    image_size: Currently unused (no resizing is performed); kept so that existing
        callers passing it keep working.
    Returns:
    A PIL image in RGB mode (not a tensor).
    """
    # Open inside a context manager so the file handle is released deterministically.
    # convert() returns a converted copy, which stays usable after the file is closed.
    with Image.open(image_path) as source:
        return source.convert("RGB")

In [5]:
# Candidate class labels for the demo.
labels = ['apple_pie', 'baby_back_ribs', 'baklava']
# k must be a multiple of len(class_ps) (asserted inside create_classifier). Each class
# triggers k Gemini calls, each followed by a 4-5 s rate-limit pause, so this cell is slow.
classifier = create_classifier(class_names=labels, k=10)

Prompt: 1. Describe what a apple_pie looks like in one or two sentences.
Response: A golden-brown apple pie is a round or rectangular pastry crust filled with sweet, spiced apples, often with a lattice top or a simple crust covering.  The filling may be visible through slits or gaps in the crust, showing its juicy, slightly caramelized apples.

Prompt: 2. How can you identify a apple_pie in one or two sentences?
Response: An apple pie is identified by its sweet, spiced filling of apples baked within a flaky pastry crust, often topped with a lattice or streusel.  It's typically round and golden-brown.

Prompt: 3. What does a apple_pie look like? Respond with one or two sentences.
Response: An apple pie typically looks like a golden-brown pastry crust filled with a sweet, cinnamon-spiced mixture of apples.  It might be topped with a lattice crust, streusel topping, or simply a dusting of sugar.

Prompt: 4. Describe an image from the internet of a apple_pie. Respond with one or two senten

In [18]:
# Classify one sample image from the apple_pie folder; classify() issues two Gemini
# requests for it (an initial label guess and a free-form description).
img = load_image('food-101/images/apple_pie/64846.jpg')
predicted_label = classify(img, classifier)
print("Predicted Label:", predicted_label)

Prompt: You are given an image and a list of class labels. Classify the image given the class labels. Answer using a single word if possible. Here are the class labels: ['apple_pie', 'baby_back_ribs', 'baklava']
Response: Baklava
Prompt: What do you see? Describe any object precisely, including its type or class.
Response: Here is a description of the object in the image:

The image shows a piece of what appears to be a baked dessert, possibly a crumble or clafoutis, on a white plate. 


**Type/Class:** The dessert is a type of baked confection.  More specifically, it looks like a fruit crumble or a similar type of dessert with a crisp, browned topping and a soft, possibly custardy interior.  The precise type cannot be determined definitively from the image.


**Description:** The dessert is roughly triangular in shape, with a golden-brown, slightly irregular surface suggesting a crumbly texture.  A dollop of whipped cream, white and smooth in appearance, is placed on top of it. A smal