In [1]:
import gradio as gr
from PIL import Image
import torch
import numpy as np
import faiss
import json

from transformers import (
    GitProcessor,
    GitForCausalLM,
    AutoTokenizer,
    AutoModelForCausalLM,
    CLIPProcessor,
    CLIPModel
)
from sentence_transformers import SentenceTransformer
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

In [3]:
git_processor = GitProcessor.from_pretrained("microsoft/git-large")
git_model = GitForCausalLM.from_pretrained("microsoft/git-large").to(device).eval()

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
tokenizer_llama = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model_llama = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
).eval()

In [5]:
text_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

In [6]:
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [7]:
wikiart_dataset = load_dataset("huggan/wikiart", split="train")

with open("../join_jsons/wikiart_10000_combined.json", "r", encoding="utf-8") as f:
    data = json.load(f)

In [8]:
image_index = faiss.read_index("../create_index/image_index_llama.faiss")
text_index = faiss.read_index("../create_index/text_index_llama.faiss")

In [9]:
def clean_caption(text):
    return text.replace("[ unused0 ]", "").strip()

In [10]:
def generate_captions(image: Image.Image):
    inputs = git_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

    captions = []
    with torch.no_grad():
        deterministic_ids = git_model.generate(
            pixel_values=inputs,
            max_new_tokens=30,
            do_sample=False
        )
        captions.append(clean_caption(git_processor.tokenizer.decode(deterministic_ids[0], skip_special_tokens=True)))

        sampled_ids = git_model.generate(
            pixel_values=inputs,
            max_new_tokens=30,
            do_sample=True,
            top_k=100,
            temperature=0.8,
            num_return_sequences=2
        )
        sampled = git_processor.tokenizer.batch_decode(sampled_ids, skip_special_tokens=True)
        captions.extend([clean_caption(c) for c in sampled])

    return captions

In [11]:
def refine_caption(base, desc1, desc2):
    prompt = f"""
Given the base caption that is true and factual:
\"{base}\"

And two descriptive captions:
1) {desc1}
2) {desc2}

Write a short, coherent description that is faithful to the base caption but incorporates descriptive elements from captions 1 and 2 without contradicting the original meaning.
"""
    inputs = tokenizer_llama(prompt, return_tensors="pt").to(model_llama.device)
    with torch.no_grad():
        output = model_llama.generate(**inputs, max_new_tokens=100, do_sample=False)
        text = tokenizer_llama.decode(output[0], skip_special_tokens=True)
        answer = text[len(prompt):].strip()
        for prefix in ["Example:", "example:"]:
            if answer.startswith(prefix):
                answer = answer[len(prefix):].strip()
        return answer

In [12]:
def get_text_embedding(text):
    emb = text_encoder.encode([text], normalize_embeddings=False).astype("float32")
    faiss.normalize_L2(emb)
    return emb

In [13]:
def get_image_embedding(image):
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    emb = image_features.cpu().numpy().astype("float32")
    faiss.normalize_L2(emb)
    return emb

In [14]:
def get_results_with_images(embedding, index, top_k=5):
    D, I = index.search(embedding, top_k)
    results = []
    for idx in I[0]:
        try:
            idx_int = int(idx)
            item = wikiart_dataset[idx_int]
            img = item["image"]
            caption = f"ID: {idx_int}"
            results.append((img, caption))
        except IndexError:
            continue
    return results

In [15]:
def search_similar_images(image: Image.Image):
    captions = generate_captions(image)
    refined = refine_caption(captions[0], captions[1], captions[2])

    text_emb = get_text_embedding(refined)
    image_emb = get_image_embedding(image)

    text_results = get_results_with_images(text_emb, text_index)
    image_results = get_results_with_images(image_emb, image_index)

    return refined, text_results, image_results

In [16]:
with gr.Blocks(title="🎨 Semantic WikiArt Search (BLIP + CLIP)") as demo:
    gr.Markdown("## Semantic WikiArt Search\nЗагрузите изображение и найдите похожие по описанию и изображению.")

    input_image = gr.Image(label="📥 Входное изображение", type="pil")

    caption_output = gr.Textbox(label="📜 Сгенерированное описание")

    gr.Markdown("### 🔍 Похожие по описанию (текстовое сходство)")
    text_gallery = gr.Gallery(columns=5, label="По описанию", height="auto")

    gr.Markdown("### 🎨 Похожие по изображению (визуальное сходство)")
    image_gallery = gr.Gallery(columns=5, label="По изображению", height="auto")

    def wrapper(image):
        caption, text_results, image_results = search_similar_images(image)
        return caption, text_results, image_results

    input_image.change(
        fn=wrapper,
        inputs=input_image,
        outputs=[caption_output, text_gallery, image_gallery]
    )

demo.launch()

* Running on local URL:  http://127.0.0.1:7870
* To create a public link, set `share=True` in `launch()`.




Traceback (most recent call last):
  File "/opt/homebrew/Cellar/jupyterlab/4.4.0/libexec/lib/python3.13/site-packages/gradio/queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<5 lines>...
    )
    ^
  File "/opt/homebrew/Cellar/jupyterlab/4.4.0/libexec/lib/python3.13/site-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<11 lines>...
    )
    ^
  File "/opt/homebrew/Cellar/jupyterlab/4.4.0/libexec/lib/python3.13/site-packages/gradio/blocks.py", line 2220, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<8 lines>...
    )
    ^
  File "/opt/homebrew/Cellar/jupyterlab/4.4.0/libexec/lib/python3.13/site-packages/gradio/blocks.py", line 1731, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignor

In [17]:
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)

Rerunning server... use `close()` to stop if you need to change `launch()` parameters.
----


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


* Running on public URL: https://619f08b350bae78de8.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


