In [None]:
import csv
import json
import os

import fiftyone.types as fot
import fiftyone.zoo as foz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM

In [2]:
# Device settings: use the first CUDA device in half precision when CUDA is
# available, otherwise run on the CPU in single precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

In [3]:
# Model
def load_model(checkpoint="microsoft/Florence-2-base-ft"):
    """Load a causal-LM checkpoint together with its processor.

    Args:
        checkpoint: Model id or path handed to ``from_pretrained`` for both
            the model and the processor. Defaults to the checkpoint this
            notebook was originally hard-wired to, so ``load_model()`` keeps
            its previous behavior.

    Returns:
        A ``(model, processor)`` tuple. The model has been moved to the
        module-level ``device`` and cast to the module-level ``torch_dtype``;
        the processor is returned as loaded.
    """
    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint to run locally -- only pass checkpoints from trusted sources.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint, trust_remote_code=True
    ).to(device, dtype=torch_dtype)
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return model, processor

In [16]:
# Load the model and its processor once; the captioning cells below reuse both.
model, processor = load_model()

In [None]:
# Checking for a single image
# The task token is used three times (prompt, post-processing, result lookup),
# so it is defined once here; previously `task_prompt` was referenced without
# being defined, which raised NameError on a fresh kernel.
task_prompt = "<CAPTION>"

# Open inside a context manager so the file handle is released as soon as the
# RGB copy has been made.
with Image.open('tiny_coco_subset/data/000000000139.jpg') as raw_image:
    image = raw_image.convert("RGB")

inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(device, torch_dtype)

# Inference only: no gradients are needed for generation.
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=512,
        num_beams=3,
        do_sample=False
    )

# Decode with special tokens kept and let the processor parse the text for
# this task; the result is a dict looked up by the task token.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(
    generated_text,
    task=task_prompt,
    image_size=(image.width, image.height)
)

caption = parsed_answer.get(task_prompt, "No caption generated.")
print(caption)

In [None]:
# Processing for all images
# Captions every image in `image_folder` and writes one (file, caption) row per
# image to `output_csv`. Images that fail are recorded with the caption "ERROR"
# so a single bad file does not abort the run.

image_folder = "tiny_coco_subset/data"
output_csv = "generated_captions.csv"

# Defined here so the cell runs on a fresh kernel. Previously `task_prompt`
# was undefined, and the resulting NameError was swallowed by the except
# branch below, turning every row into "ERROR".
task_prompt = "<CAPTION>"

# os.listdir returns entries in arbitrary order; sorting makes the CSV row
# order reproducible between runs.
image_files = sorted(
    f for f in os.listdir(image_folder)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
)
captions_data = []

for image_file in tqdm(image_files):
    try:
        image_path = os.path.join(image_folder, image_file)
        # Close each file handle right after the RGB copy is made instead of
        # leaving one open per image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(device, torch_dtype)

        # Inference only: no gradients are needed for generation.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
                num_beams=3,
                do_sample=False
            )

        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )

        caption = parsed_answer.get(task_prompt, "No caption generated.")

        captions_data.append([image_file, caption])

    except Exception as e:
        # Deliberate best effort: report the failure and keep a placeholder
        # row so the CSV still lists every image.
        print(f"Error processing {image_file}: {e}")
        captions_data.append([image_file, "ERROR"])

# newline="" lets the csv module control line endings itself.
with open(output_csv, mode="w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["image_file", "caption"])
    writer.writerows(captions_data)

print(f"Captions saved to {output_csv}")


100%|██████████| 1000/1000 [09:41<00:00,  1.72it/s]

Captions saved to generated_captions.csv



