In [1]:
import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss
import json

from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    CLIPProcessor,
    CLIPModel
)
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wikiart_dataset = load_dataset("huggan/wikiart", split="train")
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

In [3]:
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [5]:
image_index = faiss.read_index("../create_index/image_index.faiss")
text_index = faiss.read_index("../create_index/text_index.faiss")

In [6]:
def generate_caption(image: Image.Image):
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        caption_ids = blip_model.generate(**inputs)
    caption = blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    return caption

In [7]:
def get_clip_text_embedding(text):
    inputs = clip_processor(text=[text], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features

In [8]:
def get_clip_image_embedding(image):
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features

In [9]:
def get_results_with_images(embedding, index, top_k=5):
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        try:
            idx_int = int(idx)
            item = wikiart_dataset[idx_int]
            img = item["image"]
            caption = f"ID: {idx_int}"
            results.append((img, caption))
        except IndexError:
            continue
    return results

In [10]:
def search_similar_images(image: Image.Image):
    caption = generate_caption(image)

    text_emb = get_clip_text_embedding(caption)
    image_emb = get_clip_image_embedding(image)

    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)

    return caption, text_results, image_results

In [11]:
with gr.Blocks(title="🎨 Semantic WikiArt Search (BLIP + CLIP)") as demo:
    gr.Markdown("## Semantic WikiArt Search\nЗагрузите изображение и найдите похожие по описанию и изображению.")

    input_image = gr.Image(label="📥 Входное изображение", type="pil")

    caption_output = gr.Textbox(label="📜 Сгенерированное описание")

    gr.Markdown("### 🔍 Похожие по описанию (текстовое сходство)")
    text_gallery = gr.Gallery(columns=5, label="По описанию", height="auto")

    gr.Markdown("### 🎨 Похожие по изображению (визуальное сходство)")
    image_gallery = gr.Gallery(columns=5, label="По изображению", height="auto")

    def wrapper(image):
        caption, text_results, image_results = search_similar_images(image)
        return caption, text_results, image_results

    input_image.change(
        fn=wrapper,
        inputs=input_image,
        outputs=[caption_output, text_gallery, image_gallery]
    )

demo.launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.




In [13]:
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

Rerunning server... use `close()` to stop if you need to change `launch()` parameters.
----


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


* Running on public URL: https://7f0275af3323561e9d.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


