In [6]:
!pip install transformers torch bitsandbytes safetensors
!pip install sentence_transformers
!pip install datasets
!pip install 'accelerate>=0.26.0' bitsandbytes



In [7]:
!pip install gradio faiss-cpu



In [22]:
import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss
from datasets import load_dataset

from transformers import (
    LlavaNextProcessor,
    LlavaNextForConditionalGeneration,
    CLIPProcessor,
    CLIPModel,
)
from sentence_transformers import SentenceTransformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

wikiart_dataset = load_dataset("huggan/wikiart", split="train")

processor = LlavaNextProcessor.from_pretrained("unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit")
model = LlavaNextForConditionalGeneration.from_pretrained(
    "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
).to(device).eval()

PROMPT = "<image> Describe the artwork's genre, style, subject and colors in detail."

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

image_index = faiss.read_index("image_index_Llava.faiss")
text_index = faiss.read_index("text_index_Llava.faiss")

def generate_caption(image: Image.Image):
    image = image.convert("RGB")
    inputs = processor(text=PROMPT, images=image, return_tensors="pt").to(device, torch.float16)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=250)
    caption = processor.batch_decode(output, skip_special_tokens=True)[0]
    return caption.strip()

def get_text_embedding(text):
    with torch.no_grad():
        embedding = text_encoder.encode([text], normalize_embeddings=True)
    return np.array(embedding).astype("float32")

def get_clip_image_embedding(image):
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    features = features.cpu().numpy().astype("float32")
    faiss.normalize_L2(features)
    return features

def get_results_with_images(embedding, index, top_k=5):
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        try:
            item = wikiart_dataset[int(idx)]
            results.append((item["image"], f"ID: {int(idx)}"))
        except IndexError:
            continue
    return results

def search_similar_images(image: Image.Image):
    caption = generate_caption(image)

    text_emb = get_text_embedding(caption)
    image_emb = get_clip_image_embedding(image)

    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)

    return caption, text_results, image_results

with gr.Blocks(title="🎨 Semantic WikiArt Search (LLava)") as demo:
    gr.Markdown("## Semantic WikiArt Search\nЗагрузите изображение и найдите похожие по описанию и изображению.")

    input_image = gr.Image(label="📥 Входное изображение", type="pil")

    caption_output = gr.Textbox(label="📜 Сгенерированное описание")

    gr.Markdown("### 🔍 Похожие по описанию (текстовое сходство)")
    text_gallery = gr.Gallery(columns=5, label="По описанию", height="auto")

    gr.Markdown("### 🎨 Похожие по изображению (визуальное сходство)")
    image_gallery = gr.Gallery(columns=5, label="По изображению", height="auto")

    def wrapper(image):
        caption, text_results, image_results = search_similar_images(image)
        return caption, text_results, image_results

    input_image.change(
        fn=wrapper,
        inputs=input_image,
        outputs=[caption_output, text_gallery, image_gallery]
    )

demo.launch()


Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

* Running on local URL:  http://127.0.0.1:7869
* To create a public link, set `share=True` in `launch()`.




Traceback (most recent call last):
  File "/venv/main/lib/python3.12/site-packages/gradio/queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/gradio/blocks.py", line 2220, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/gradio/blocks.py", line 1731, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/main/lib/python3.12/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^

In [23]:
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

Rerunning server... use `close()` to stop if you need to change `launch()` parameters.
----
* Running on public URL: https://baf487016f5554af17.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


