In [None]:
import json
import os
import random
from collections import defaultdict

import google.generativeai as genai
import matplotlib.pyplot as plt
from datasets import load_dataset
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("philschmid/amazon-product-descriptions-vlm")
train_split = dataset['train']
# Pull the three train-split columns used below: images, product descriptions, and product ids.
product_images, product_desc, uniq_ids = (
    train_split[column] for column in ('image', 'description', 'Uniq Id')
)

Define a function that sends a product image and its description to Gemini (via `google.generativeai`) to generate synthetic captions, giving us a larger dataset. The prompt asks for six different caption styles so that the generated captions are diverse, which helps prevent overfitting.

In [None]:
# Read the API key from the environment instead of hard-coding it in the notebook:
# a key pasted here leaks with every shared copy. Raises KeyError if the variable is unset.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

# Maps each product's Uniq Id to the list of captions generated for it.
uniq_to_generated_description = defaultdict(list)

# Uniq Ids that the generation loop recorded as failures (candidates for a retry).
error_imgs = []

def generate_captions(image, description, uniq_id, model_name="gemini-1.5-flash"):
    """Ask Gemini for six stylistically different captions of one product.

    Args:
        image: the product image, sent to the model alongside the text prompt
            (presumably a PIL image from the dataset's ``image`` column -- confirm).
        description: the original product description, embedded in the prompt.
        uniq_id: the product's ``Uniq Id``; only used in the error message.
        model_name: name of the Gemini model to query.

    Returns:
        The model's text with surrounding whitespace stripped (captions are
        requested one per line), or ``"Error: No response for <uniq_id>"`` when
        the response carries no usable text. Callers detect failure by that prefix.
    """
    prompt = f"""
    Here is a product image and its description:
    "{description}"

    Generate 6 alternative captions:
    "**Concise**: A short, to-the-point caption (under 15 words).",
    "**Marketing-Oriented**: A persuasive, engaging caption that highlights key benefits.",
    "**Technical & Detailed**: A caption focused on specifications, materials, and functionality.",
    "**Visual-Only Description (Basic)**: Describe only the product’s appearance without mentioning its use.",
    "**Visual-Only Description (Creative)**: A descriptive, engaging take on the product’s look with vivid imagery.",
    "**Keyword-Rich**: Use common product-related keywords.",

    Separate each caption with a newline character. Do **not** format in markdown.
    """

    model = genai.GenerativeModel(model_name)
    response = model.generate_content([prompt, image])  # Send image + text

    error_message = "Error: No response for " + str(uniq_id)
    # The response object is always truthy, so checking `if response` never caught
    # anything. Reading `.text` raises ValueError when no text part came back
    # (e.g. the prompt or candidate was blocked), so that is the failure to catch.
    try:
        text = response.text
    except ValueError:
        return error_message
    text = text.strip()
    return text if text else error_message

In [254]:
with open("./data.json", "r") as file:
    data = json.load(file)  # Load JSON data

# Rebuild the id -> captions mapping, keeping only entries whose value is a list.
uniques = defaultdict(list)
uniques.update(
    (uniq_id, captions)
    for uniq_id, captions in data.items()
    if isinstance(captions, list)
)

len(uniques)

1344

Currently we have captions for 1,344 images.

Because the Gemini API is rate-limited, we cannot generate captions for every image in one run. Instead, we generate them in batches, manually updating `start` and `end` before each run.

In [None]:
# Half-open batch window [start, end) into the dataset. Edit these between runs to
# stay under the Gemini rate limit. Slicing clamps `end` to the dataset size, so a
# value past the last index means "to the end". For a single-item retry, use
# end = start + 1.
start = 1286
end = 1346
current_imgs = product_images[start:end]
current_desc = product_desc[start:end]
current_ids = uniq_ids[start:end]

# Clear earlier captions for every id in the batch so that regenerating replaces
# them instead of appending duplicates.
for batch_id in current_ids:
    uniq_to_generated_description[batch_id] = []

In [247]:
n_products = len(product_images)  # number of products in the train split
n_products

1345

In [275]:
for image, description, uniq_id in zip(current_imgs, current_desc, current_ids):
    generated_captions = generate_captions(image, description, uniq_id)
    # generate_captions signals failure by returning an "Error: ..." string rather
    # than raising. Record the id for a retry instead of storing the message as a caption.
    if generated_captions.startswith("Error:"):
        if uniq_id not in error_imgs:
            error_imgs.append(uniq_id)
        continue
    # Captions come back one per line. Blank lines are only separators, not errors.
    captions = [line.strip() for line in generated_captions.splitlines() if line.strip()]
    # Assign rather than append, so re-running this cell does not duplicate captions.
    uniq_to_generated_description[uniq_id] = captions

In [276]:
# Number of products with an entry in the caption dict. A bare expression lets
# Jupyter render the value, and len() on a dict already counts its keys.
len(uniq_to_generated_description)

1345


Store the generated captions in `data.json` for later use.

In [278]:
# Merge this session's captions into what is already on disk before writing.
# The load cell reads data.json into `uniques`, not into
# `uniq_to_generated_description`. After a kernel restart, the in-memory dict
# therefore holds only the current batch, and dumping it directly would erase
# every earlier batch.
try:
    with open("./data.json", "r") as file:
        saved_captions = json.load(file)
except FileNotFoundError:
    saved_captions = {}

# Skip ids whose list is empty (reset for a batch but generation failed), so a
# failed retry does not wipe captions saved by an earlier run.
saved_captions.update(
    {uniq_id: captions for uniq_id, captions in uniq_to_generated_description.items() if captions}
)

# Write to a temporary file and then swap it in, so an interrupted dump cannot
# leave a truncated data.json behind.
with open("./data.json.tmp", "w") as file:
    json.dump(saved_captions, file)
os.replace("./data.json.tmp", "./data.json")