In [6]:
import json
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch
from PIL import Image
from tqdm import tqdm
import re

In [2]:
# Load the Qwen-VL tokenizer and model. trust_remote_code is required because the
# Qwen-VL tokenizer/model classes are shipped as code inside the hub repository.
# NOTE(review): that code is executed locally; consider pinning `revision=` for
# reproducibility and supply-chain safety.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-VL",
    trust_remote_code=True,
)
model = (
    AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)
    .eval()   # inference mode (disables dropout etc.)
    .cuda()   # move weights to the default CUDA device
)

The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".


Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]

In [31]:
def generate_caption(
    img_pth,
    prompt='In two or more sentences, describe the image in extensive detail:',
    max_new_tokens=300,
):
    """Generate a free-text caption for one image with the global Qwen-VL model.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        img_pth: Image path (or URL) handed to ``tokenizer.from_list_format``.
        prompt: Instruction text placed after the image in the query.
        max_new_tokens: Upper bound on the number of tokens generated.

    Returns:
        The generated continuation, stripped of surrounding whitespace and cut at
        the first ``<|endoftext|>`` marker if the model produced one. If generation
        stopped at ``max_new_tokens`` the (possibly truncated) text is returned as is.
    """
    query = tokenizer.from_list_format([
        {'image': img_pth},
        {'text': prompt},
    ])
    inputs = tokenizer(query, return_tensors='pt').to(model.device)
    pred = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # For a decoder-only model, generate() returns prompt tokens followed by the
    # continuation. Decode only the continuation so the caption can never be mixed
    # up with prompt text. The previous regex keyed on the first ':' and did not
    # span newlines, so multi-line captions were lost or truncated.
    prompt_len = inputs['input_ids'].shape[1]
    generated = tokenizer.decode(pred.cpu()[0][prompt_len:], skip_special_tokens=False)

    # Drop the end-of-text marker and anything emitted after it.
    return generated.split('<|endoftext|>', 1)[0].strip()

In [4]:
# Image files in the working directory. The extension match is case-insensitive
# (so "photo.JPG" is included), and the list is sorted because os.listdir order is
# arbitrary: the captioning loop assigns "pid" by position, so order must be stable.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}
images = sorted(
    name
    for name in os.listdir("./")
    if os.path.isfile(name) and os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
)

In [32]:
# Caption every image and collect one record per image. "pid" is the image's
# position in `images`, and "name" is the filename without its final extension.
caps = []
for i, img_pth in enumerate(tqdm(images)):
    img_cap = generate_caption(img_pth)
    if i == 0:
        # Show one sample caption as a sanity check. tqdm.write prints above the
        # progress bar instead of splitting it mid-line the way print() does.
        tqdm.write(img_cap)
    caps.append({
        "path": img_pth,
        "pid": str(i),
        "name": ".".join(img_pth.split(".")[:-1]),
        "caption": img_cap,
    })

 11%|█         | 2/19 [00:02<00:16,  1.02it/s]

This image depicts a bar chart that contains four bars. The bars are labeled as follows: Lorem, Ipsum, Dolor, and Sit. The bars are all of the same height, and they are all of the same width. The bar chart is drawn on a white background.


100%|██████████| 19/19 [00:21<00:00,  1.15s/it]


In [35]:
path: str = "cap.json"  # destination file for the caption records

In [36]:
# Persist all caption records to `path` as JSON indented by 4 spaces.
with open(path, mode='w') as out_file:
    out_file.write(json.dumps(caps, indent=4))