In [1]:
"""
    Wyszukiwanie obrazów po promptcie CLIP (text->image) na bazie gotowych embeddingów.

    Wymaga:
    - CSV z kolumnami:
        "Cutout Image Path"
        "openai/clip-vit-base-patch32_Embedding" (JSON string list[float], znormalizowane lub nie)
    - transformers (CLIPModel, CLIPProcessor)

    UI:
    - pole tekstowe (prompt)
    - przycisk "Szukaj"
    - wynik: TOP 20 miniatur + similarity
"""

import html
import io
import json
from pathlib import Path

import ipywidgets as W
import numpy as np
import pandas as pd
import torch
from IPython.display import display
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# --- input file and column names ---
csv_path = Path("/Users/olga/MetaLogic/outputs/cutouts_with_embeddings.csv")
IMG_COL = "Cutout Image Path"
EMB_COL = "openai/clip-vit-base-patch32_Embedding"

# --- CLIP model ---
MODEL_ID = "openai/clip-vit-base-patch32"
# Device preference order: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
processor = CLIPProcessor.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID).to(device)
model.eval()  # inference only

# --- load the precomputed image embeddings (done once, at cell execution) ---
df = pd.read_csv(csv_path)
# Each cell of EMB_COL holds a JSON-encoded list of floats.
emb_list = [json.loads(cell) for cell in df[EMB_COL]]
# One row per CSV record, one column per embedding dimension.
E = np.vstack(emb_list).astype(np.float32)

# Normalise every row to unit length so a dot product equals cosine similarity.
# Re-normalising already-normalised rows is harmless; the epsilon guards
# against division by zero for an all-zero row.
E /= np.linalg.norm(E, axis=1, keepdims=True) + 1e-12

def _thumb_bytes(p: Path, size=(520, 520)) -> bytes:
    """Return a JPEG-encoded thumbnail of the image stored at *p*.

    The image is converted to RGB, shrunk in place with ``Image.thumbnail``
    using *size* as the bounding box, and encoded as JPEG at quality 85.

    Args:
        p: Path of the image file to open.
        size: ``(width, height)`` bound handed to ``Image.thumbnail``.

    Returns:
        The encoded JPEG data.

    Raises:
        Whatever ``Image.open`` / ``convert`` / ``save`` raise for a missing
        or unreadable file; the caller is expected to handle it.
    """
    # Image.open is lazy and keeps the file handle open. convert() returns an
    # independent copy, so the source can be closed right after it.
    with Image.open(p) as src:
        img = src.convert("RGB")
    img.thumbnail(size)
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=85)
    return buf.getvalue()

# --- UI widgets ---
# Free-text prompt box, pre-filled with an example query.
txt = W.Text(
    value="samochód syrena na ulicy",
    description="Prompt:",
    layout=W.Layout(width="900px"),
)
# Button whose click handler runs the search.
btn = W.Button(
    description="Szukaj (CLIP)",
    button_style="primary",
)
# Output area that receives messages and result cards.
out = W.Output()

def _encode_prompt(prompt: str) -> np.ndarray:
    """Return the unit-length CLIP text embedding of *prompt* as a 1-D float32 array."""
    with torch.no_grad():
        inputs = processor(text=[prompt], return_tensors="pt", padding=True).to(device)
        t = model.get_text_features(**inputs)
        t = t / t.norm(dim=-1, keepdim=True)
    # Move to CPU and flatten the single-prompt batch into a vector.
    return t.detach().cpu().float().numpy().reshape(-1)


def _result_card(rank: int, path: Path, sim: float) -> "W.VBox":
    """Build one result card: an HTML caption above a thumbnail (or an error label)."""
    try:
        im = W.Image(value=_thumb_bytes(path), format="jpeg", width=520, height=520)
    except Exception:
        # Deliberate best effort: one unreadable file must not abort the
        # whole result list, so show a placeholder instead.
        im = W.Label(f"Nie można otworzyć: {path}")

    # Paths come from the CSV and are rendered as HTML, so escape them;
    # otherwise characters such as "&" or "<" would corrupt the markup.
    title = W.HTML(
        f"<b>{rank:02d}. sim={sim:.3f} — {html.escape(path.name)}</b>"
        f"<br><code>{html.escape(str(path))}</code>"
    )
    return W.VBox([title, im], layout=W.Layout(border="1px solid #ddd", padding="8px"))


def _search(_, top_k: int = 20):
    """Button callback: rank the stored image embeddings against the typed prompt.

    Reads the prompt from ``txt``, embeds it with the CLIP text encoder,
    scores every row of ``E`` by dot product (cosine similarity, since both
    sides are normalised) and renders the best matches inside ``out`` as a
    two-column grid of cards.

    Args:
        _: Argument supplied by ``Button.on_click``; ignored.
        top_k: Maximum number of results to show (default 20).
    """
    prompt = (txt.value or "").strip()
    with out:
        out.clear_output()
        if not prompt:
            print("Wpisz prompt.")
            return

        sims = E @ _encode_prompt(prompt)  # one score per CSV row
        top_idx = np.argsort(-sims)[:top_k]  # highest similarity first

        cards = [
            _result_card(rank, Path(str(df[IMG_COL].iloc[i])), float(sims[i]))
            for rank, i in enumerate(top_idx, start=1)
        ]

        # Lay the cards out two per row.
        rows = [
            W.HBox(cards[r:r + 2], layout=W.Layout(gap="12px"))
            for r in range(0, len(cards), 2)
        ]
        display(W.VBox(rows, layout=W.Layout(gap="12px")))

# Register the search callback, then render the prompt controls followed by
# the output area that the callback writes into.
btn.on_click(_search)
display(W.HBox(children=[txt, btn]), out)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


HBox(children=(Text(value='samochód syrena na ulicy', description='Prompt:', layout=Layout(width='900px')), Bu…

Output()