In [None]:
import pandas as pd
# Product table to caption; the loop below reads the `asin` and `imgUrl` columns.
df_products = pd.read_csv(filepath_or_buffer="selected_products.csv")
df_products.head(n=3)

Unnamed: 0,asin,title,imgUrl,productURL,stars,reviews,price,category_id,isBestSeller,boughtInLastMonth,id,category_name
0,B08ZDQX51W,Original Replacement Dell 130W Laptop Charger ...,https://m.media-amazon.com/images/I/61sADwl+YW...,https://www.amazon.com/dp/B08ZDQX51W,4.5,0,24.98,65,False,0,65,Laptop Accessories
1,B01BPCTXHC,Griffin Elevator Stand for Laptops - Lift Your...,https://m.media-amazon.com/images/I/710N2S69Nv...,https://www.amazon.com/dp/B01BPCTXHC,4.6,0,35.0,65,False,0,65,Laptop Accessories
2,B0B2RKB1HP,"Rolling Backpack, 17.3 inch Laptop Backpack wi...",https://m.media-amazon.com/images/I/61EWjAcjla...,https://www.amazon.com/dp/B0B2RKB1HP,4.6,0,89.99,65,False,0,65,Laptop Accessories


In [None]:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

# Hugging Face checkpoint shared by the model and its processor, so the two cannot drift apart.
MODEL_ID = "llava-hf/llava-1.5-13b-hf"

# Configure 4-bit quantization (bitsandbytes): weights are loaded in 4 bit and
# computations run in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantized model; device_map="auto" lets transformers/accelerate place the weights.
# NOTE: FlashAttention-2 is not a BitsAndBytesConfig option. To try it for faster inference,
# pass attn_implementation="flash_attention_2" to from_pretrained (needs the flash-attn package
# and a supported GPU).
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_ID)

In [None]:
from PIL import Image
import requests
import re

def get_image_caption(
    image_url,
    prompt="What product is shown in this image? Describe the product only",
    max_new_tokens=300,
    timeout=30,
):
    """Ask the LLaVA model to describe the product shown at ``image_url``.

    Uses the notebook-level ``model`` and ``processor`` loaded in an earlier cell.

    Args:
        image_url: URL of the product image to download.
        prompt: Question sent to the model together with the image.
        max_new_tokens: Upper bound on the number of generated tokens.
        timeout: Seconds to wait for the image server before giving up.

    Returns:
        The list returned by ``processor.batch_decode`` (the decoded prompt and
        answer for the single conversation), or ``""`` if the image could not be
        downloaded or decoded.
    """
    import io

    try:
        # Download the whole body (product images are presumably small): requests then
        # releases the connection, undoes any content-encoding, and wraps transport
        # errors in RequestException.
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        # convert() forces PIL to decode all pixels now, surfacing truncated files here.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError, ValueError, Image.DecompressionBombError) as exc:
        # Best effort: a missing/broken image yields an empty caption instead of aborting
        # the batch. KeyboardInterrupt is deliberately NOT caught so the kernel can be stopped.
        print(f"Could not load image {image_url!r}: {exc}")
        return ""

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    # Other prompt phrasings tried: "Use 3rd person perspective to describe it." /
    # "Describe the image as a product."
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, torch.float16)  # float16 cast is meant for the image tensor

    generate_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # The whole generated sequence is decoded, so the text still contains the prompt and
    # the "ASSISTANT:" marker that process_text looks for.
    return processor.batch_decode(generate_ids, skip_special_tokens=True)

def process_text(text, prefix="The image shows "):
    """Extract the product description from a LLaVA reply.

    ``text`` may be the decoded string itself or ``str()`` of the list returned by
    ``processor.batch_decode`` (the captioning loop passes the latter). The
    description is everything after ``"ASSISTANT: " + prefix``, with its first
    letter upper-cased.

    Args:
        text: Decoded model output, plain or as a list repr.
        prefix: Words the answer must start with; replies that do not start with
            it yield ``""``.

    Returns:
        The capitalised description, or ``""`` if no matching answer is found.
    """
    import ast

    candidates = [text]
    stripped = text.strip()
    # str(list) wraps each string in quotes and escapes quotes/newlines. Recover the
    # original strings rather than slicing the repr, so captions containing "]" or
    # quotes are neither truncated nor left with backslash escapes.
    if stripped.startswith("[") and stripped.endswith("]"):
        try:
            decoded = ast.literal_eval(stripped)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            decoded = None
        if isinstance(decoded, (list, tuple)):
            candidates = [str(item) for item in decoded]

    marker = "ASSISTANT: " + prefix
    for candidate in candidates:
        start = candidate.find(marker)
        if start == -1:
            continue
        description = candidate[start + len(marker):].strip()
        if not description:  # empty answer: nothing to capitalise
            return ""
        return description[0].upper() + description[1:]
    return ""

In [None]:
import csv

# Rows are appended to captions.csv, so re-running from 0 duplicates earlier entries.
# Set START_INDEX > 0 to resume an interrupted run.
START_INDEX = 0

# Open the output once instead of once per row; flush after every row so an interrupted
# run keeps its progress. utf-8 keeps non-ASCII captions writable on any platform locale.
with open("captions.csv", mode="a", newline="", encoding="utf-8") as captions_file:
    writer = csv.writer(captions_file)
    for index, row in enumerate(df_products.itertuples()):
        if index < START_INDEX:
            continue

        print(f"Index No: {index}\nID: {row.asin}\n Image URL: {row.imgUrl}")

        caption = get_image_caption(row.imgUrl)
        caption = process_text(str(caption))

        print(f"Generated Caption: {caption}\n\n")

        writer.writerow([row.asin, row.imgUrl, caption])
        captions_file.flush()