In [None]:
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
# accelerate is required for `device_map=` model placement.
%pip install -q transformers torch bitsandbytes safetensors sentence_transformers datasets 'accelerate>=0.26.0'

In [None]:
import torch
from datasets import load_dataset
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from sentence_transformers import SentenceTransformer
from PIL import Image
import json
from tqdm import tqdm
import os

In [None]:

# --- Configuration -----------------------------------------------------------
# Images per generate() call. LLaVA-NeXT expands every image into many visual
# tokens (the count depends on image resolution), so lower this on CUDA OOM.
BATCH_SIZE = 64
# Checkpoint cadence for OUTPUT_PARTIAL, consumed by the captioning loop.
INTERMEDIATE_SAVE_EVERY = 500
OUTPUT_PARTIAL = "wikiart_captions_partial.json"
OUTPUT_FINAL = "wikiart_captions_final.json"

MODEL_ID = "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit"

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Models --------------------------------------------------------------------
processor = LlavaNextProcessor.from_pretrained(MODEL_ID)

# The checkpoint is pre-quantized with bitsandbytes (4-bit). Quantized models are
# placed on devices at load time through `device_map`; calling `.to(device)` on
# them afterwards is unsupported (it raises on older transformers/bitsandbytes).
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.eval()

text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

# --- Data ------------------------------------------------------------------------
wikiart_dataset = load_dataset("huggan/wikiart", split="train")

# llava-v1.6-mistral is instruction-tuned on Mistral's "[INST] ... [/INST]"
# format with the <image> placeholder inside the instruction; prompts outside
# that template tend to produce noticeably weaker captions.
prompt = "[INST] <image>\nDescribe the artwork's genre, style, subject and colors in detail. [/INST]"


# Resume support: when a previous run left a partial checkpoint, reload its
# records so the dataset indices they cover are skipped by the loop below.
if os.path.exists(OUTPUT_PARTIAL):
    with open(OUTPUT_PARTIAL, "r") as partial_file:
        results = json.load(partial_file)
else:
    results = []

processed_indices = {record["index"] for record in results}

def save_json_atomic(path, payload):
    """Serialize `payload` as JSON to `path` through a temp file + rename.

    `os.replace` swaps the file in a single step, so an interrupt or crash
    mid-write leaves the previous checkpoint intact instead of a truncated
    JSON document that would make the resume step fail to load.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(payload, f, indent=2)
    os.replace(tmp_path, path)


# Batched generation with a decoder-only LM needs LEFT padding: each image
# expands to a resolution-dependent number of tokens, so rows differ in length,
# and right padding would wedge pad tokens between prompt and generated text.
processor.tokenizer.padding_side = "left"

unsaved_count = 0  # captions produced since the last partial checkpoint

for start_idx in tqdm(range(0, len(wikiart_dataset), BATCH_SIZE)):
    stop_idx = min(start_idx + BATCH_SIZE, len(wikiart_dataset))

    # Decide what to do before touching the data, so resumed runs don't decode
    # images for batches that were already captioned.
    actual_indices = [idx for idx in range(start_idx, stop_idx) if idx not in processed_indices]
    if not actual_indices:
        continue

    # Iterate rows via `select`: slicing a Dataset (`ds[a:b]`) returns a dict of
    # columns, not a list of row dicts.
    images = []
    for row in wikiart_dataset.select(actual_indices):
        image = row["image"]
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        images.append(image)

    try:
        with torch.inference_mode():
            # One prompt per image: the processor pairs texts and images
            # positionally, and a single prompt string would only receive the
            # first image's <image>-token expansion.
            inputs = processor(
                images=images,
                text=[prompt] * len(images),
                return_tensors="pt",
                padding=True,
            ).to(model.device)
            output = model.generate(**inputs, max_new_tokens=250, do_sample=False)

            # With left padding every row's prompt ends at the same column, so
            # the generated continuation starts right after it.
            prompt_len = inputs["input_ids"].shape[-1]
            captions = processor.tokenizer.batch_decode(
                output[:, prompt_len:], skip_special_tokens=True
            )
            captions = [caption.strip() for caption in captions]

            embeddings = text_encoder.encode(captions, batch_size=32, convert_to_tensor=False)
    except Exception as e:
        # Best-effort: one failing batch (e.g. OOM on unusually large images)
        # should not abort the whole run. Its indices are never recorded, so a
        # later resumed run will retry them.
        print(f"Error in batch {start_idx}-{stop_idx - 1}: {e}")
        if device == "cuda":
            torch.cuda.empty_cache()
        continue

    for idx, caption, embedding in zip(actual_indices, captions, embeddings):
        results.append({
            "index": idx,
            "caption": caption,
            # encode() yields numpy arrays, which the json module can't serialize.
            "embedding": embedding.tolist(),
        })
    unsaved_count += len(actual_indices)

    # Count images rather than testing `start_idx % INTERMEDIATE_SAVE_EVERY`:
    # start_idx advances in BATCH_SIZE steps, so a modulo test would only fire
    # at multiples of lcm(BATCH_SIZE, INTERMEDIATE_SAVE_EVERY).
    if unsaved_count >= INTERMEDIATE_SAVE_EVERY:
        save_json_atomic(OUTPUT_PARTIAL, results)
        unsaved_count = 0

save_json_atomic(OUTPUT_FINAL, results)
