In [1]:
import json

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import BlipForConditionalGeneration, BlipProcessor, CLIPModel, CLIPProcessor

# WikiArt paintings from the Hugging Face Hub, train split. The embedding loop
# reads each sample's "image", "artist", "genre" and "style" fields.
dataset = load_dataset("huggan/wikiart", split="train")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prefer a GPU backend when one is available: NVIDIA CUDA first, then
# Apple-silicon MPS, falling back to the CPU. Models and inputs are moved
# here with `.to(device)`.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [4]:
# Load the pretrained checkpoints. Only the models are moved to `device`;
# the processors (image preprocessing / tokenisation) are not.
#   CLIP -> used for both image and text embeddings.
#   BLIP -> used to generate a caption for each image.
clip_model, clip_processor = (
    CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device),
    CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"),
)

blip_processor, blip_model = (
    BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base"),
    BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device),
)

# Accumulator for the per-sample dicts appended by the embedding loop.
results = list()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# For every WikiArt sample, build one record holding:
#   * the CLIP embedding of the image,
#   * a BLIP-generated text caption of the image,
#   * the CLIP embedding of that caption,
#   * the sample's artist / genre / style fields, copied unchanged,
# then write all records to wikiart_embeddings.json.
#
# NOTE(review): embeddings are kept as Python float lists and written as
# indented JSON; for the full split this is large in memory and on disk.
# Consider .npy / parquet if that becomes a problem.

MAX_SAMPLES = None  # set to an int (e.g. 100) for a quick partial run; None = whole split


@torch.inference_mode()
def _embed_image(image) -> list[float]:
    """Return the CLIP image embedding of a single image as a flat list of floats."""
    clip_inputs = clip_processor(images=image, return_tensors="pt").to(device)
    # Batch of one: drop the leading batch dimension.
    return clip_model.get_image_features(**clip_inputs).squeeze(0).cpu().tolist()


@torch.inference_mode()
def _caption_image(image) -> str:
    """Generate a BLIP text caption for a single image."""
    blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
    caption_ids = blip_model.generate(**blip_inputs)
    return blip_processor.batch_decode(caption_ids, skip_special_tokens=True)[0]


@torch.inference_mode()
def _embed_text(text: str) -> list[float]:
    """Return the CLIP text embedding of `text` as a flat list of floats."""
    # truncation=True clips the token sequence to the tokenizer's max length,
    # so an unusually long caption is embedded instead of failing in the encoder.
    text_inputs = clip_processor(text=[text], return_tensors="pt", truncation=True).to(device)
    return clip_model.get_text_features(**text_inputs).squeeze(0).cpu().tolist()


# Reset on every run so re-executing this cell does not append duplicates to
# records left over from a previous run.
results = []

n_samples = len(dataset) if MAX_SAMPLES is None else min(MAX_SAMPLES, len(dataset))
for idx in tqdm(range(n_samples)):
    sample = dataset[idx]
    image = sample["image"]

    image_code = _embed_image(image)
    caption = _caption_image(image)
    text_code = _embed_text(caption)

    results.append({
        "id": idx,
        "image_code": image_code,
        "caption": caption,
        "text_code": text_code,
        "artist": sample["artist"],
        "genre": sample["genre"],
        "style": sample["style"],
    })

with open("wikiart_embeddings.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)
