In [None]:
import faiss
import numpy as np
import json

from import_llave import load_llave_model
from embed_utils import embed_text

# Load the precomputed embedding matrix and its per-row metadata.
# Row i of the matrix is paired with META[i] (search_query reads META[i]["path"]).
E = np.load("llave_embeddings.npy")  # expected shape (N, D)
with open("llave_metadata.json", "r", encoding="utf-8") as f:
    META = json.load(f)

# Fail fast on inconsistent inputs. Later cells assume a 2-D (N, D) matrix
# (row-wise normalization, E.shape[1] as the index dimension), and FAISS row ids
# are used directly as META positions, so the counts must match one-to-one.
assert E.ndim == 2, f"Expected a 2-D (N, D) embedding matrix, got shape {E.shape}."
assert len(META) == len(E), (
    f"Metadata and embedding count mismatch: {len(META)} metadata entries "
    f"vs {len(E)} embeddings."
)

# Cast to float32 and rescale every row to unit L2 length. This always runs:
# rows that are already unit-length come out essentially unchanged, and the
# 1e-12 epsilon keeps all-zero rows at zero instead of dividing by zero.
E = E.astype(np.float32)
E = E / (np.linalg.norm(E, axis=1)[:, np.newaxis] + 1e-12)

# Flat (exact, no training step) inner-product index over all rows. Because the
# rows are unit-length, the inner product it ranks by equals cosine similarity.
D = E.shape[1]
index = faiss.IndexFlatIP(D)
index.add(E)

# Model components. `model`, `tokenizer` and `device` are what search_query
# passes to embed_text; `image_processor` is unpacked but not used in this cell.
tokenizer, model, image_processor, device = load_llave_model()

def search_query(query: str, topk: int = 5):
    """Embed a text query and print/return its nearest neighbours in ``index``.

    Relies on the module-level ``index`` (FAISS inner-product index over
    unit-normalized embeddings), ``META`` (per-row metadata; each entry's
    ``"path"`` is reported), and the loaded ``model``/``tokenizer``/``device``.

    Args:
        query: Free text to embed with ``embed_text``.
        topk: Maximum number of hits. Clamped to the number of indexed vectors;
            if the result is < 1 (non-positive ``topk`` or empty index), no
            search is run.

    Returns:
        List of ``(score, path)`` tuples in descending score order. Each hit is
        also printed as ``rank. score=...  →  path``.
    """
    # Never ask FAISS for more neighbours than exist (it would pad with id -1).
    k = min(int(topk), index.ntotal)
    if k < 1:
        return []

    # np.array (not np.asarray) always copies, so the in-place normalization
    # below cannot mutate an array owned by embed_text.
    vec = np.array(embed_text(query, model, tokenizer, device), dtype=np.float32)
    vec = vec.reshape(1, -1)
    vec /= np.linalg.norm(vec) + 1e-12

    scores, indices = index.search(vec, k)

    hits = []
    for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1):
        # Defensive: FAISS marks missing results with -1, and META[-1] would
        # silently return the last entry instead of failing.
        if idx < 0:
            continue
        path = META[idx]["path"]
        print(f"{rank:>2}. score={score:.4f}  →  {path}")
        hits.append((float(score), path))
    return hits

In [None]:
# Example query. The original passed this text to input() as the *prompt*, so the
# search actually ran on whatever was typed (or ""). A literal string searches for
# the intended phrase and keeps the notebook runnable with Restart & Run All.
q = "a forklift lifting a box"
# Assigning the result keeps the cell output to the printed ranking only.
hits = search_query(q, topk=5)