In [None]:
# ===============================================================
# Ingredient Search — Cards & Multi-Vector Index
# ===============================================================

import os
import json
import time
import re
import unicodedata
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from rapidfuzz import fuzz

# Optional FAISS acceleration
try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False


# -----------------------
# Config
# -----------------------
# Retrieval mode: "cards" (one vector per entity) or "multivector" (many vectors per entity)
MODE = "multivector"   # "cards" or "multivector"

# Input data produced by the preprocessing you ran earlier
CARDS_PATH = "novel_foods_cards.csv"            # columns: policy_item_id, canonical, entity_text
MULTIV_PATH = "novel_foods_multivectors.csv"    # columns: policy_item_id, section, language, text

# Model (multilingual)
MODEL_NAME = "sentence-transformers/distiluse-base-multilingual-cased-v2"
# Alternatives:
# MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"
# MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# Embeddings / index cache: the directory is created at import time so the
# build functions can write into it without checking.
CACHE_DIR = "indices_v2"
os.makedirs(CACHE_DIR, exist_ok=True)

# Cache artefacts for "cards" mode: embeddings, FAISS index, lookup table, metadata.
EMB_CARDS, IDX_CARDS, LOOKUP_CARDS, META_CARDS = (
    os.path.join(CACHE_DIR, filename)
    for filename in (
        "emb_cards.npy",
        "index_cards.faiss",
        "lookup_cards.csv",
        "meta_cards.json",
    )
)

# Cache artefacts for "multivector" mode, same four kinds of file.
EMB_MULTI, IDX_MULTI, LOOKUP_MULTI, META_MULTI = (
    os.path.join(CACHE_DIR, filename)
    for filename in (
        "emb_multi.npy",
        "index_multi.faiss",
        "lookup_multi.csv",
        "meta_multi.json",
    )
)

# Search hyperparams
TOP_K_DEFAULT = 5         # final results to return
RECALL_K = 200            # how many candidates to pull from ANN/sims before aggregation
ALPHA_SEM = 0.75          # weight of the semantic score; lexical gets 1 - ALPHA_SEM
MIN_CONFIDENCE = 0.50     # minimum blended score to keep

# Optional: per-section multipliers applied to the semantic score in multivector mode.
# Sections not listed here (e.g. COMMON_XX) default to 1.00; add specific
# boosts such as COMMON_DE=1.02 if needed.
SECTION_BOOST = dict(
    CANON_LAT=1.05,
    CANON_EN=1.00,
    SYN_LAT=0.95,
)


# ===============================================================
# Helpers
# ===============================================================
# Unicode minus sign plus the hyphen/dash range U+2010-U+2015, all mapped to "-".
_LEX_DASH_RE = re.compile(r"[\u2212\u2010-\u2015]")
# Everything that is NOT a letter/digit of any script (\w), a space, a hyphen,
# or a character in U+00C0-U+017F. "_" is listed separately because \w matches it.
_LEX_STRIP_RE = re.compile(r"[^\w \-\u00C0-\u017F]|_")
_LEX_SPACE_RE = re.compile(r"\s+")


def normalize_query_lex(s: str) -> str:
    """Normalize text for fuzzy (lexical) matching.

    Steps: NFKC-normalize and casefold, map Unicode dashes to ``-``, replace
    punctuation/symbols with spaces, then collapse whitespace.  Diacritics are
    kept (``café`` stays ``café``), and letters of non-Latin scripts such as
    Greek or Cyrillic are kept as well, so multilingual names still yield a
    usable string instead of an empty one.

    Args:
        s: Raw query or candidate text.

    Returns:
        The normalized string, or ``""`` when *s* is not a ``str``.
    """
    if not isinstance(s, str):
        return ""
    # NFKC before casefold so compatibility characters are expanded before
    # their case is folded; NFKC again afterwards because casefold can emit
    # decomposed sequences that should be recomposed.
    s = unicodedata.normalize("NFKC", s)
    s = unicodedata.normalize("NFKC", s.casefold())
    s = _LEX_DASH_RE.sub("-", s)
    s = _LEX_STRIP_RE.sub(" ", s)
    return _LEX_SPACE_RE.sub(" ", s).strip()

def ensure_cols(df: pd.DataFrame, cols):
    """Check that every name in *cols* is a column of *df*.

    Raises:
        ValueError: listing, in the order given, the names that are absent.
    """
    available = df.columns
    absent = [name for name in cols if name not in available]
    if not absent:
        return
    raise ValueError(f"Missing required columns: {absent}")


# ===============================================================
# Data loaders
# ===============================================================
def load_cards():
    """Read the cards CSV located at ``CARDS_PATH``.

    ``policy_item_id`` is read as ``str``.

    Returns:
        A ``(df, lookup)`` pair: the full frame, and an independent copy
        restricted to the three columns used when returning results.

    Raises:
        ValueError: from ``ensure_cols`` when a required column is missing.
    """
    required = ["policy_item_id", "canonical", "entity_text"]
    cards = pd.read_csv(CARDS_PATH, dtype={"policy_item_id": str})
    ensure_cols(cards, required)
    result_view = cards[required].copy()
    return cards, result_view

def load_multivectors():
    """Read the multivector CSV at ``MULTIV_PATH`` and attach canonical names.

    ``policy_item_id`` is read as ``str``.  When the cards CSV at
    ``CARDS_PATH`` exists, its ``canonical`` column is left-joined onto the
    multivector rows; otherwise ``canonical`` is added as an empty column.

    Returns:
        A ``(mv, lookup)`` pair.  ``lookup`` has exactly one row per row of
        ``mv``, in the same order, so that row *i* of ``lookup`` describes the
        *i*-th text that gets embedded.

    Raises:
        ValueError: from ``ensure_cols`` when a required column is missing.
    """
    mv = pd.read_csv(MULTIV_PATH, dtype={"policy_item_id": str})
    ensure_cols(mv, ["policy_item_id", "section", "language", "text"])

    # Canonical name map, used only for pretty printing.
    if os.path.exists(CARDS_PATH):
        cards = pd.read_csv(CARDS_PATH, dtype={"policy_item_id": str})
        # De-duplicate on the join key alone.  De-duplicating on the
        # (id, canonical) pair would leave several rows for an id that has
        # more than one canonical spelling; the left merge would then
        # multiply the matching multivector rows and shift every later row
        # out of alignment with the embedding matrix.
        can_map = cards[["policy_item_id", "canonical"]].drop_duplicates(
            subset="policy_item_id"
        )
    else:
        can_map = pd.DataFrame(columns=["policy_item_id", "canonical"])

    lookup = mv.merge(can_map, on="policy_item_id", how="left")
    return mv, lookup


# ===============================================================
# Build or load indices (cards)
# ===============================================================
def build_or_load_cards_index():
    """Load the embedding model and the card embeddings, from cache when valid.

    The cache (embeddings ``.npy``, lookup CSV, metadata JSON, optional FAISS
    index) is reused only when the stored model name and row count match the
    current cards file AND the cached artefacts agree with each other on the
    number of rows.  Otherwise the card texts are re-encoded and the cache is
    rewritten.

    Returns:
        ``(model, emb, index, lookup)`` where ``index`` is a FAISS index or
        ``None`` (callers then fall back to cosine similarity on ``emb``), and
        ``lookup`` holds one row per embedded text.
    """
    print("‚öôÔ∏è Loading model:", MODEL_NAME)
    model = SentenceTransformer(MODEL_NAME)

    df, lookup = load_cards()
    texts = df["entity_text"].astype(str).tolist()

    use_cache = (
        os.path.exists(EMB_CARDS) and os.path.exists(LOOKUP_CARDS) and os.path.exists(META_CARDS)
    )
    if use_cache:
        try:
            with open(META_CARDS, "r", encoding="utf-8") as f:
                meta = json.load(f)
            if meta.get("row_count") == len(texts) and meta.get("model") == MODEL_NAME:
                print("üîÅ Loading cached card embeddings‚Ä¶")
                emb = np.load(EMB_CARDS)
                df_lookup = pd.read_csv(LOOKUP_CARDS, dtype={"policy_item_id": str})
                # Search maps embedding row i to lookup row i, so a cache whose
                # files disagree on length would return the wrong entities.
                if emb.ndim != 2 or not (emb.shape[0] == len(df_lookup) == len(texts)):
                    raise ValueError(
                        f"cached embeddings {emb.shape} do not match "
                        f"{len(df_lookup)} lookup rows / {len(texts)} texts"
                    )
                index = None
                if FAISS_AVAILABLE and os.path.exists(IDX_CARDS):
                    index = faiss.read_index(IDX_CARDS)
                    if index.ntotal != emb.shape[0]:
                        raise ValueError(
                            f"cached FAISS index holds {index.ntotal} vectors, "
                            f"expected {emb.shape[0]}"
                        )
                    print("‚úÖ FAISS (cards) loaded.")
                else:
                    print("‚ö†Ô∏è FAISS not available; using cosine similarity for cards.")
                return model, emb, index, df_lookup
        except Exception as exc:
            # Best-effort cache: any failure falls through to a full rebuild,
            # but the reason is reported instead of being swallowed.
            print("‚ôªÔ∏è Cache mismatch; rebuilding card index.")
            print(f"   ({type(exc).__name__}: {exc})")

    print("‚öôÔ∏è Encoding card texts‚Ä¶")
    t0 = time.time()
    emb = model.encode(texts, show_progress_bar=True, normalize_embeddings=True)
    print(f"‚è±Ô∏è Embedded {len(texts)} cards in {time.time()-t0:.1f}s")

    np.save(EMB_CARDS, emb)
    lookup.to_csv(LOOKUP_CARDS, index=False)
    with open(META_CARDS, "w", encoding="utf-8") as f:
        json.dump({"model": MODEL_NAME, "row_count": len(texts)}, f)

    if FAISS_AVAILABLE:
        # Inner-product index; embeddings were requested with
        # normalize_embeddings=True above.
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(np.array(emb, dtype="float32"))
        faiss.write_index(index, IDX_CARDS)
        print("‚úÖ FAISS (cards) built & saved.")
    else:
        index = None
        print("‚ö†Ô∏è FAISS not installed ‚Äî using cosine similarity at query time (cards).")

    return model, emb, index, lookup


# ===============================================================
# Build or load indices (multi-vector)
# ===============================================================
def build_or_load_multivector_index():
    """Load the embedding model and the multivector embeddings, from cache when valid.

    The cache (embeddings ``.npy``, lookup CSV, metadata JSON, optional FAISS
    index) is reused only when the stored model name and row count match the
    current multivector file AND the cached artefacts agree with each other
    on the number of rows.  Otherwise every text row is re-encoded and the
    cache is rewritten.

    Returns:
        ``(model, emb, index, lookup)`` where ``index`` is a FAISS index or
        ``None`` (callers then fall back to cosine similarity on ``emb``), and
        ``lookup`` holds one row per embedded text.
    """
    print("‚öôÔ∏è Loading model:", MODEL_NAME)
    model = SentenceTransformer(MODEL_NAME)

    mv, lookup = load_multivectors()
    texts = mv["text"].astype(str).tolist()

    use_cache = (
        os.path.exists(EMB_MULTI) and os.path.exists(LOOKUP_MULTI) and os.path.exists(META_MULTI)
    )
    if use_cache:
        try:
            with open(META_MULTI, "r", encoding="utf-8") as f:
                meta = json.load(f)
            if meta.get("row_count") == len(texts) and meta.get("model") == MODEL_NAME:
                print("üîÅ Loading cached multivector embeddings‚Ä¶")
                emb = np.load(EMB_MULTI)
                df_lookup = pd.read_csv(LOOKUP_MULTI, dtype={"policy_item_id": str})
                # Search maps embedding row i to lookup row i, so a cache whose
                # files disagree on length would return the wrong entities.
                if emb.ndim != 2 or not (emb.shape[0] == len(df_lookup) == len(texts)):
                    raise ValueError(
                        f"cached embeddings {emb.shape} do not match "
                        f"{len(df_lookup)} lookup rows / {len(texts)} texts"
                    )
                index = None
                if FAISS_AVAILABLE and os.path.exists(IDX_MULTI):
                    index = faiss.read_index(IDX_MULTI)
                    if index.ntotal != emb.shape[0]:
                        raise ValueError(
                            f"cached FAISS index holds {index.ntotal} vectors, "
                            f"expected {emb.shape[0]}"
                        )
                    print("‚úÖ FAISS (multivector) loaded.")
                else:
                    print("‚ö†Ô∏è FAISS not available; using cosine similarity for multivector.")
                return model, emb, index, df_lookup
        except Exception as exc:
            # Best-effort cache: any failure falls through to a full rebuild,
            # but the reason is reported instead of being swallowed.
            print("‚ôªÔ∏è Cache mismatch; rebuilding multivector index.")
            print(f"   ({type(exc).__name__}: {exc})")

    print("‚öôÔ∏è Encoding multivector texts‚Ä¶")
    t0 = time.time()
    emb = model.encode(texts, show_progress_bar=True, normalize_embeddings=True)
    print(f"‚è±Ô∏è Embedded {len(texts)} rows in {time.time()-t0:.1f}s")

    np.save(EMB_MULTI, emb)
    lookup.to_csv(LOOKUP_MULTI, index=False)
    with open(META_MULTI, "w", encoding="utf-8") as f:
        json.dump({"model": MODEL_NAME, "row_count": len(texts)}, f)

    if FAISS_AVAILABLE:
        # Inner-product index; embeddings were requested with
        # normalize_embeddings=True above.
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(np.array(emb, dtype="float32"))
        faiss.write_index(index, IDX_MULTI)
        print("‚úÖ FAISS (multivector) built & saved.")
    else:
        index = None
        print("‚ö†Ô∏è FAISS not installed ‚Äî using cosine similarity at query time (multivector).")

    return model, emb, index, lookup


# ===============================================================
# Search (cards)
# ===============================================================
def search_cards(query, model, emb, index, df_lookup, top_k=TOP_K_DEFAULT):
    """Rank card entities against *query* by a blended semantic/lexical score.

    Up to ``RECALL_K`` nearest cards are fetched (FAISS when an index is
    given, cosine similarity on ``emb`` otherwise).  Each candidate's
    similarity is blended with a fuzzy-match score as
    ``ALPHA_SEM * semantic + (1 - ALPHA_SEM) * lexical``; candidates below
    ``MIN_CONFIDENCE`` are dropped and only the best hit per
    ``policy_item_id`` is kept.

    Args:
        query: Free-text ingredient name.
        model: Encoder exposing ``encode(list, normalize_embeddings=True)``.
        emb: Card embedding matrix, row-aligned with ``df_lookup``.
        index: FAISS index over ``emb``, or ``None``.
        df_lookup: Frame with ``policy_item_id``, ``canonical``, ``entity_text``.
        top_k: Maximum number of results.

    Returns:
        A list of result dicts sorted by ``score`` descending; empty when the
        query is blank, the lookup is empty, or nothing is confident enough.
    """
    def _cell(value):
        # Empty CSV cells come back from pandas as NaN; str(NaN) would yield
        # the literal "nan" and leak into both the display and the fuzzy score.
        return "" if pd.isna(value) else str(value)

    n_rows = len(df_lookup)
    if n_rows == 0 or not str(query).strip():
        return []

    recall = min(RECALL_K, n_rows)
    q_emb = model.encode([query], normalize_embeddings=True)

    # ANN / sims
    if FAISS_AVAILABLE and index is not None:
        scores, idx = index.search(np.array(q_emb, dtype="float32"), recall)
        idx, scores = idx[0], scores[0]
    else:
        sims = cosine_similarity(q_emb, emb)[0]
        idx = np.argsort(sims)[::-1][:recall]
        scores = sims[idx]

    q_norm = normalize_query_lex(query)
    best_by_ent = {}
    for i, s in zip(idx, scores):
        # FAISS pads missing neighbours with label -1; iloc[-1] would
        # silently return the last row instead.
        if i < 0:
            continue
        row = df_lookup.iloc[i]
        canon = _cell(row.get("canonical", ""))
        blob = _cell(row.get("entity_text", ""))

        # Lexical score against canonical and a shorter slice of the blob
        # (avoid huge text bias).
        lex1 = fuzz.token_set_ratio(q_norm, normalize_query_lex(canon)) / 100
        lex2 = fuzz.partial_ratio(q_norm, normalize_query_lex(blob[:500])) / 100
        lex = max(lex1, lex2)

        final = ALPHA_SEM * float(s) + (1 - ALPHA_SEM) * lex
        if final < MIN_CONFIDENCE:
            continue

        hit = {
            "policy_item_id": str(row["policy_item_id"]),
            "canonical": canon,
            "best_text": canon,
            "section": "CARD",
            "language": "",
            "semantic": round(float(s), 3),
            "lexical": round(lex, 3),
            "score": round(final, 3),
        }
        # De-dup by entity and keep best
        pid = hit["policy_item_id"]
        if (pid not in best_by_ent) or (hit["score"] > best_by_ent[pid]["score"]):
            best_by_ent[pid] = hit

    out = sorted(best_by_ent.values(), key=lambda x: x["score"], reverse=True)
    return out[:top_k]


# ===============================================================
# Search (multivector)
# ===============================================================
def section_boost(section: str) -> float:
    """Return the semantic-score multiplier for *section*.

    Sections listed in ``SECTION_BOOST`` use their configured value; every
    other section (including all ``COMMON_XX`` ones) gets the neutral 1.00.
    """
    return SECTION_BOOST.get(section, 1.00)

def search_multivector(query, model, emb, index, df_lookup, top_k=TOP_K_DEFAULT):
    """Rank entities against *query* using every name vector of every entity.

    Up to ``RECALL_K`` nearest text rows are fetched (FAISS when an index is
    given, cosine similarity on ``emb`` otherwise).  Each row's similarity is
    multiplied by ``section_boost(section)`` and blended with a fuzzy-match
    score as ``ALPHA_SEM * boosted + (1 - ALPHA_SEM) * lexical``.  The best
    row per ``policy_item_id`` is kept, then results below ``MIN_CONFIDENCE``
    are dropped.

    Args:
        query: Free-text ingredient name.
        model: Encoder exposing ``encode(list, normalize_embeddings=True)``.
        emb: Embedding matrix, row-aligned with ``df_lookup``.
        index: FAISS index over ``emb``, or ``None``.
        df_lookup: Frame with ``policy_item_id``, ``section``, ``language``,
            ``text`` and (optionally) ``canonical``.
        top_k: Maximum number of results.

    Returns:
        A list of result dicts sorted by ``score`` descending; empty when the
        query is blank, the lookup is empty, or nothing is confident enough.
    """
    def _cell(value):
        # Empty CSV cells and unmatched left-join rows are NaN; str(NaN)
        # would yield the literal "nan", which is truthy and would be
        # fuzzy-matched and printed as if it were a real name.
        return "" if pd.isna(value) else str(value)

    n_rows = len(df_lookup)
    if n_rows == 0 or not str(query).strip():
        return []

    recall = min(RECALL_K, n_rows)
    q_emb = model.encode([query], normalize_embeddings=True)

    # ANN / sims across ALL name texts
    if FAISS_AVAILABLE and index is not None:
        scores, idx = index.search(np.array(q_emb, dtype="float32"), recall)
        idx, scores = idx[0], scores[0]
    else:
        sims = cosine_similarity(q_emb, emb)[0]
        idx = np.argsort(sims)[::-1][:recall]
        scores = sims[idx]

    # Row-level blend + per-entity aggregation (max)
    q_norm = normalize_query_lex(query)
    best_by_ent = {}
    for i, s in zip(idx, scores):
        # FAISS pads missing neighbours with label -1; iloc[-1] would
        # silently return the last row instead.
        if i < 0:
            continue
        row = df_lookup.iloc[i]
        text = _cell(row.get("text", ""))
        canon = _cell(row.get("canonical", ""))
        section = _cell(row.get("section", ""))
        lang = _cell(row.get("language", ""))

        lex1 = fuzz.token_set_ratio(q_norm, normalize_query_lex(text)) / 100
        lex2 = fuzz.token_set_ratio(q_norm, normalize_query_lex(canon)) / 100 if canon else 0.0
        lex = max(lex1, lex2)

        # Only the semantic part is boosted; the reported "semantic" value
        # below stays the raw similarity.
        boosted_sem = float(s) * section_boost(section)
        final = ALPHA_SEM * boosted_sem + (1 - ALPHA_SEM) * lex

        hit = {
            "policy_item_id": str(row["policy_item_id"]),
            "canonical": canon,
            "best_text": text,
            "section": section,
            "language": lang,
            "semantic": round(float(s), 3),
            "lexical": round(lex, 3),
            "score": round(final, 3),
        }
        # Aggregate by entity (keep best scoring row per entity)
        pid = hit["policy_item_id"]
        if (pid not in best_by_ent) or (hit["score"] > best_by_ent[pid]["score"]):
            best_by_ent[pid] = hit

    out = sorted(best_by_ent.values(), key=lambda x: x["score"], reverse=True)
    # Confidence filter
    out = [r for r in out if r["score"] >= MIN_CONFIDENCE]
    return out[:top_k]


# ===============================================================
# CLI
# ===============================================================
def main():
    """Run the interactive ingredient-search prompt.

    Builds or loads the index selected by ``MODE``, then repeatedly reads a
    query from stdin and prints the top ``TOP_K_DEFAULT`` matches.  Typing
    ``exit`` (any case), EOF, or Ctrl-C ends the loop; an empty line simply
    re-prompts.

    Raises:
        ValueError: if ``MODE`` is neither ``"cards"`` nor ``"multivector"``.
    """
    print(f"üß† Ingredient Search ‚Äî Mode: {MODE}")

    if MODE == "cards":
        model, emb, index, df_lookup = build_or_load_cards_index()
        search_fn = search_cards
    elif MODE == "multivector":
        model, emb, index, df_lookup = build_or_load_multivector_index()
        search_fn = search_multivector
    else:
        raise ValueError("MODE must be 'cards' or 'multivector'")

    print("\n‚úÖ Ready. Type any ingredient name (or 'exit' to quit).")
    while True:
        try:
            query = input("\nüîç Enter ingredient name: ").strip()
        except (EOFError, KeyboardInterrupt):
            query = "exit"

        if query.lower() == "exit":
            print("üëã Exiting. Goodbye!")
            break

        # A bare Enter would otherwise embed the empty string and print
        # whatever happens to be nearest to it.
        if not query:
            continue

        results = search_fn(query, model, emb, index, df_lookup, top_k=TOP_K_DEFAULT)

        print(f"\nResults for '{query}':")
        print("=" * 70)
        if not results:
            print("No confident match (below threshold). Try another term.")
            continue

        for r in results:
            print(
                f"Entity: {r['canonical']} (ID: {r['policy_item_id']})\n"
                f"  Best match text: {r['best_text']}\n"
                f"  Section/Lang: {r['section']} / {r['language']}\n"
                f"  Scores ‚Üí semantic: {r['semantic']}, lexical: {r['lexical']}, final: {r['score']}\n"
                f"{'-'*70}"
            )


# Start the interactive search loop only when this file is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm


üß† Ingredient Search ‚Äî Mode: multivector
‚öôÔ∏è Loading model: sentence-transformers/distiluse-base-multilingual-cased-v2




‚öôÔ∏è Encoding multivector texts‚Ä¶


Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 57/57 [01:23<00:00,  1.47s/it]

‚è±Ô∏è Embedded 1821 rows in 83.8s
‚úÖ FAISS (multivector) built & saved.

‚úÖ Ready. Type any ingredient name (or 'exit' to quit).






Results for 'vit A':
Entity: Betaine (ID: 1673130)
  Best match text: N
  Section/Lang: SYN_UNK / UNK
  Scores ‚Üí semantic: 0.747, lexical: 0.333, final: 0.643
----------------------------------------------------------------------
Entity: Eleocharis dulcis (ID: 685187)
  Best match text: E. indica
  Section/Lang: SYN_UNK / UNK
  Scores ‚Üí semantic: 0.742, lexical: 0.308, final: 0.633
----------------------------------------------------------------------
Entity: Ligusticum striatum (ID: 1948733)
  Best match text: K.Y.Pan
  Section/Lang: SYN_UNK / UNK
  Scores ‚Üí semantic: 0.609, lexical: 0.333, final: 0.54
----------------------------------------------------------------------
Entity: Bambusa spp. (ID: 702102)
  Best match text: Μπαμπού (EL)
  Section/Lang: SYN_UNK / UNK
  Scores ‚Üí semantic: 0.607, lexical: 0.25, final: 0.518
----------------------------------------------------------------------
Entity: Acer nigrum (ID: 677721)
  Best match text: Acer nigrum
  Section/Lang: