In [None]:
!pip install -q qdrant-client pandas pyarrow numpy tqdm sentence-transformers


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/337.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m327.7/337.3 kB[0m [31m9.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.3/337.3 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import numpy as np
from qdrant_client.models import Batch
from qdrant_client.http.exceptions import UnexpectedResponse
from tqdm import trange
import numpy as np, pandas as pd, uuid, math, json, sys, traceback, re
from qdrant_client.models import Distance, VectorParams, HnswConfigDiff, OptimizersConfigDiff
import pandas as pd
from qdrant_client.models import PayloadSchemaType



In [None]:
from qdrant_client import QdrantClient

import os

# Read the connection settings from the environment so real credentials never have to be
# pasted into (and saved with) the notebook. The original placeholders stay as fallbacks,
# so editing them in place still works.
client = QdrantClient(
    url=os.environ.get("QDRANT_URL", "XXXXXXX"),
    api_key=os.environ.get("QDRANT_API_KEY", "XXXXXXX"),
)

# Quick connectivity check: list the collections visible to this client.
print(client.get_collections())

collections=[CollectionDescription(name='product_bge'), CollectionDescription(name='faq_bge')]


## Upsert product and faq collections

In [None]:
# Parquet files holding the pre-computed embeddings (paths are relative to the working directory).
faq_parquet     = "faq_vectors_bge.parquet"
product_parquet = "product_vectors_bge.parquet"

# Load both tables; later cells read their "id" and "embedding" columns.
faq_df  = pd.read_parquet(faq_parquet)
prod_df = pd.read_parquet(product_parquet)

def get_dim(df):
    """Return the embedding length taken from the first row of ``df``.

    Raises AssertionError (message "DataFrame trống!") when ``df`` has no rows.
    """
    row_count = len(df)
    assert row_count > 0, "DataFrame trống!"
    first_row = df.iloc[0]
    return len(first_row["embedding"])

# Take the vector size from the FAQ table, falling back to the product table when FAQ is empty.
# The same DIM is used for both collections below.
DIM = get_dim(faq_df if len(faq_df) else prod_df)
print(f" DIM = {DIM}, faq={len(faq_df)} rows, product={len(prod_df)} rows")

def ensure_collection(name: str, dim: int):
    """Create the collection ``name`` on the module-level ``client`` unless it already exists.

    New collections use cosine distance with vectors of size ``dim``, HNSW settings
    m=48 / ef_construct=200 and an optimizer indexing threshold of 20000.
    Prints which of the two outcomes happened; returns None.
    """
    if not client.collection_exists(name):
        vector_params = VectorParams(size=dim, distance=Distance.COSINE)
        client.create_collection(
            collection_name=name,
            vectors_config=vector_params,
            hnsw_config=HnswConfigDiff(m=48, ef_construct=200),
            optimizers_config=OptimizersConfigDiff(indexing_threshold=20000),
        )
        print(f" Created collection: {name}")
    else:
        print(f"Collection '{name}' đã tồn tại.")

# Make sure both target collections exist, using the same vector size for each.
ensure_collection("faq_bge", DIM)
ensure_collection("product_bge", DIM)


In [None]:
def as_str(x, default=""):
    """Coerce ``x`` to ``str``, mapping ``None`` and float NaN to ``default``.

    Strings are returned unchanged; every other value goes through ``str()``.
    """
    is_missing = x is None or (isinstance(x, float) and math.isnan(x))
    if is_missing:
        return default
    if isinstance(x, str):
        return x
    return str(x)

def to_uuid5_list(series, prefix):
    """Map each value of ``series`` to a deterministic UUIDv5 string.

    The UUID name is "<prefix>:<value>" under ``uuid.NAMESPACE_URL``, so the same
    (prefix, value) pair always produces the same id.
    """
    point_ids = []
    for raw_value in series:
        name = f"{prefix}:{as_str(raw_value)}"
        point_ids.append(str(uuid.uuid5(uuid.NAMESPACE_URL, name)))
    return point_ids

def ensure_meta_dict(m):
    """Return ``m`` as a dict, decoding JSON strings and falling back to ``{}``.

    - dict input is returned as-is.
    - str input is parsed as JSON; invalid JSON yields ``{}``.
    - JSON that decodes to something other than an object (list, number, string,
      null) also yields ``{}``, because callers use ``.get`` on the result.
    - any other input yields ``{}``.
    """
    if isinstance(m, dict):
        return m
    if isinstance(m, str):
        try:
            parsed = json.loads(m)
        except json.JSONDecodeError:
            return {}
        # Valid JSON is not necessarily an object; only dicts are usable as metadata.
        return parsed if isinstance(parsed, dict) else {}
    return {}

def ensure_matrix(col):
    """Return the embeddings in ``col`` as a 2-D float32 array.

    A 2-D ndarray is only cast; anything else must provide ``.to_numpy()`` and is
    stacked row by row with ``np.vstack`` before the cast.
    """
    already_2d = isinstance(col, np.ndarray) and col.ndim == 2
    stacked = col if already_2d else np.vstack(col.to_numpy())
    return stacked.astype("float32")

# Captures the run of letters/digits/-/&/./whitespace that follows "thương hiệu:" (case-insensitive).
BRAND_RX = re.compile(r"thương\s*hiệu\s*:\s*([A-Za-z0-9À-ỹ\-\&\.\s]+)", re.I)

def extract_brand(r, meta_data):
    """Find a brand for row ``r`` and report where it came from.

    Lookup order: ``r["brand"]`` ("top"), ``meta_data["brand"]`` ("meta"), then the
    BRAND_RX pattern applied to ``r["text"]`` ("regex").
    Returns ``(brand, source)``, or ``(None, None)`` when nothing is found.
    """
    for source_name, container in (("top", r), ("meta", meta_data)):
        candidate = container.get("brand")
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip(), source_name

    match = BRAND_RX.search(as_str(r.get("text", "")))
    if match is None:
        return None, None
    return match.group(1).strip(" ;,.\n"), "regex"

def extract_thumbnail(r, meta_data):
    """Return a stripped thumbnail string for row ``r``, or None.

    ``r["thumbnail"]`` is preferred; otherwise the first truthy value among the
    metadata keys "thumbnail", "image" and "img" is used. Non-string or blank
    values are ignored.
    """
    lookups = (
        lambda: r.get("thumbnail"),
        lambda: meta_data.get("thumbnail") or meta_data.get("image") or meta_data.get("img"),
    )
    for fetch in lookups:
        value = fetch()
        if isinstance(value, str) and value.strip():
            return value.strip()
    return None

def resolve_parent_uid(r, meta_data):
    """Work out the parent identifier for a chunk row.

    Resolution order:
      1. a non-blank string in ``r["parent_uid"]`` (stripped);
      2. ``r["id"]``: with two or more "::" separators the last segment is dropped,
         with exactly one the id is returned unchanged;
      3. ``"prod_<source_id>"`` when ``meta_data["source_id"]`` is truthy;
      4. otherwise the empty string.
    """
    explicit = r.get("parent_uid")
    if isinstance(explicit, str) and explicit.strip():
        return explicit.strip()

    chunk_id = as_str(r.get("id", ""))
    separator_count = chunk_id.count("::") if chunk_id else 0
    if separator_count >= 2:
        return chunk_id.rsplit("::", 1)[0]
    if separator_count == 1:
        return chunk_id

    source_id = meta_data.get("source_id")
    return f"prod_{source_id}" if source_id else ""

def make_payloads(df_part, id_col="id"):
    """Build one payload dict per row of ``df_part``.

    Each payload carries the external id (from ``id_col``), the text fields
    (title/url/category truncated to 512/1024/256 characters), the resolved
    parent_uid, brand, thumbnail and brand_source, plus price_numeric, rating
    and review_count copied as-is from the row's metadata.
    """
    def _payload_for(row):
        meta = ensure_meta_dict(row.get("meta", {}))

        parent = resolve_parent_uid(row, meta)
        brand_name, brand_origin = extract_brand(row, meta)
        thumb = extract_thumbnail(row, meta)

        return {
            "id":          as_str(row.get(id_col, "")),
            "type":        as_str(row.get("type", "")),
            "title":       as_str(row.get("title", ""))[:512],
            "url":         as_str(row.get("url", ""))[:1024],
            "category":    as_str(row.get("category", ""))[:256],
            "text":        as_str(row.get("text", "")),
            "parent_uid":  parent,
            "brand":       brand_name or "",
            "thumbnail":   thumb or "",
            "brand_source": brand_origin or "",
            "price_numeric": meta.get("price_numeric"),
            "rating":        meta.get("rating"),
            "review_count":  meta.get("review_count"),
        }

    return [_payload_for(row) for _, row in df_part.iterrows()]

def upsert_df_smart(client, collection, df, init_batch=512, id_col="id"):
    """Upsert every row of ``df`` into ``collection`` using adaptive batch sizes.

    Point ids are UUIDv5 values derived from ``collection`` and the ``id_col`` value,
    vectors come from the "embedding" column and payloads from ``make_payloads``.

    When the server rejects a request with a message that looks like a size problem,
    the batch size is halved (never below ``min(64, current size)``) and the same
    window is retried; after a success the size doubles back up towards ``init_batch``.

    Raises:
        ValueError: ``init_batch`` is below 1, or a required column is missing.
        RuntimeError: a batch is still rejected once the size can no longer shrink.
        UnexpectedResponse: any other server error (re-raised after printing the traceback).
    """
    # A non-positive batch size would never advance `start` and loop forever.
    if init_batch < 1:
        raise ValueError("init_batch phải >= 1")

    for col in [id_col, "embedding"]:
        if col not in df.columns:
            raise ValueError(f"Thiếu cột bắt buộc: {col}")

    # Convert all embeddings once; batches below are slices of this matrix.
    vectors_all = ensure_matrix(df["embedding"])

    N = len(df)
    bs = init_batch
    start = 0
    while start < N:
        end = min(N, start + bs)
        part = df.iloc[start:end]

        ids = to_uuid5_list(part[id_col], prefix=collection)
        vecs = vectors_all[start:end]
        payloads = make_payloads(part, id_col=id_col)

        try:
            client.upsert(
                collection_name=collection,
                points=Batch(ids=ids, vectors=vecs, payloads=payloads),
                wait=True
            )
            start = end
            # Recover gradually after an earlier shrink.
            if bs < init_batch:
                bs = min(init_batch, int(bs * 2))
        except UnexpectedResponse as e:
            msg = str(e)
            # NOTE(review): plain substring matching; "400" also matches unrelated
            # bad-request errors, which then surface as the RuntimeError below.
            if any(key in msg for key in [
                "larger than allowed", "Payload error",
                "Request Entity Too Large", "413", "400"
            ]):
                # Cap the floor at the current size so a batch smaller than 64 is never
                # *grown* in response to a "too large" error.
                new_bs = max(min(64, bs), bs // 2)
                print(f"Request quá lớn (window={end-start}, bs={bs}) --- giảm xuống {new_bs}", file=sys.stderr)
                if new_bs == bs:
                    raise RuntimeError("Batch vẫn vượt giới hạn.") from e
                bs = new_bs
            else:
                traceback.print_exc()
                raise

    print(f"Upsert xong: {collection} (rows={N})")


In [None]:
# Push the FAQ embeddings into the "faq_bge" collection.
print("Upsert FAQ:")
upsert_df_smart(client, "faq_bge", faq_df, init_batch=512, id_col="id")

In [None]:
# Push the product embeddings into the "product_bge" collection.
print("Upsert Product:")
upsert_df_smart(client,"product_bge", prod_df, init_batch=512, id_col="id")

## Create payload index

In [None]:
from qdrant_client.models import PayloadSchemaType
# Create a keyword payload index on "source_id" in product_bge.
# NOTE(review): make_payloads() in this notebook writes no top-level "source_id" key
# (it only reads meta["source_id"] to derive parent_uid) — confirm the points actually
# carry this field, otherwise this index has nothing to cover.
client.create_payload_index(
    collection_name="product_bge",
    field_name="source_id",
    field_schema=PayloadSchemaType.KEYWORD,
)

UpdateResult(operation_id=179, status=<UpdateStatus.COMPLETED: 'completed'>)

In [None]:
# Create one payload index per (field, schema) pair on product_bge.
for field, schema in [
    ("category", PayloadSchemaType.KEYWORD),
    ("brand", PayloadSchemaType.KEYWORD),
    ("parent_uid", PayloadSchemaType.KEYWORD),
    ("type", PayloadSchemaType.KEYWORD),
    ("rating", PayloadSchemaType.FLOAT),
    ("price_numeric", PayloadSchemaType.FLOAT),
    ("review_count", PayloadSchemaType.INTEGER),
]:
    try:
        client.create_payload_index(
            collection_name="product_bge",
            field_name=field,
            field_schema=schema,
        )
        print(f"Created payload index: {field} ({schema})")
    # Best effort: a failure on one field is printed and the loop moves on to the next.
    except Exception as e:
        print(f"{field} → {e}")


Created payload index: category (keyword)
Created payload index: brand (keyword)
Created payload index: parent_uid (keyword)
Created payload index: type (keyword)
Created payload index: rating (float)
Created payload index: price_numeric (float)
Created payload index: review_count (integer)


## Create sparse vector

In [None]:
from __future__ import annotations
import re
import unicodedata
from typing import Dict, Any, Optional, List, Tuple
from collections import defaultdict
import os, json, math, re, unicodedata, pickle
import numpy as np
import pandas as pd


In [None]:
def strip_accents(s):
    """Return ``s`` reduced to ASCII with diacritics removed.

    The Vietnamese letters "đ"/"Đ" have no Unicode decomposition, so NFKD plus an
    ASCII encode would silently delete them ("đệm" -> "em"). They are mapped to
    "d"/"D" first. Any other character without an ASCII form is still dropped.
    """
    s = s.replace("đ", "d").replace("Đ", "D")
    return unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")

def norm_space(s):
    """Collapse every whitespace run in ``s`` to one space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", s)
    return collapsed.strip()

def normalize_text_for_sparse(s):
    """Normalize ``s`` for sparse (keyword) matching.

    Strips accents, lowercases, replaces every character other than word
    characters, whitespace and ``- % . /`` with a space, then collapses whitespace.
    """
    lowered = strip_accents(s).lower()
    cleaned = re.sub(r"[^\w\s\-\%\./]", " ", lowered)
    return norm_space(cleaned)

def unify_number(x):
    """Parse a numeric string, treating "," as the decimal point and ignoring spaces."""
    cleaned = x.translate({ord(","): ".", ord(" "): None})
    return float(cleaned)

def mm_from(value, unit):
    """Convert ``value`` to millimetres; ``unit`` may be "cm", "m", or anything else.

    A missing unit is treated as "mm"; unrecognised units leave ``value`` unchanged.
    """
    factor = {"cm": 10, "m": 1000}.get((unit or "mm").lower())
    return value if factor is None else value * factor

def g_from(value, unit):
    """Convert ``value`` to grams: "kg" (any case) is scaled by 1000, other units pass through."""
    return value * 1000 if unit.lower() == "kg" else value

def capacity_ml(value, unit):
    """Convert ``value`` to millilitres: "l" (any case) is scaled by 1000, other units pass through."""
    return value * 1000 if unit.lower() == "l" else value

# Three numbers separated by "x"/"×", each optionally preceded by one letter from the
# bracketed set, with an optional trailing unit (mm|cm|m). Groups: a, b, c, unit.
RE_DIM_VI = re.compile(
    r"(?:(?:[cclwnhdrt]\s*)?)(\d+(?:[.,]\d+)?)\s*[x×]\s*"
    r"(?:(?:[cclwnhdrt]\s*)?)(\d+(?:[.,]\d+)?)\s*[x×]\s*"
    r"(?:(?:[cclwnhdrt]\s*)?)(\d+(?:[.,]\d+)?)(?:\s*(mm|cm|m))?",
    re.I
)
# A single number followed by an inch marker.
RE_INCH = re.compile(r"(\d+(?:[.,]\d+)?)\s*(?:\"|inch|in)\b", re.I)
# Two numbers joined by "-" or "–" and followed by an inch marker (a compatibility range).
RE_INCH_RANGE = re.compile(r"(\d+(?:[.,]\d+)?)\s*[-–]\s*(\d+(?:[.,]\d+)?)\s*(?:inch|\"|\b in\b)", re.I)

# Number + unit patterns; group 1 is the number, group 2 (where present) the unit.
RE_CAP_ML = re.compile(r"(\d+(?:[.,]\d+)?)\s*(ml|l)\b", re.I)
RE_MAH    = re.compile(r"(\d{3,6})\s*mah\b", re.I)
RE_WEIGHT = re.compile(r"(\d+(?:[.,]\d+)?)\s*(kg|g)\b", re.I)
RE_POWER  = re.compile(r"(\d+(?:[.,]\d+)?)\s*(w|kw)\b", re.I)
RE_VOLT   = re.compile(r"(\d+(?:[.,]\d+)?)\s*v\b", re.I)
RE_FREQ   = re.compile(r"(\d+(?:[.,]\d+)?)\s*hz\b", re.I)
# "ram" is optional here, so this matches any "<1-3 digits>gb".
RE_RAM     = re.compile(r"(\d{1,3})\s*gb\s*(?:ram)?\b", re.I)
# The "ssd"/"hdd" suffix is optional as well, so RE_RAM and RE_STORAGE can match the same text.
RE_STORAGE = re.compile(r"(\d{1,4})\s*(tb|gb)\s*(?:ssd|hdd)?\b", re.I)

# Width x height made of 3-5 digit numbers, optional "px"/"pixel" suffix.
RE_RES    = re.compile(r"(\d{3,5})\s*[x×]\s*(\d{3,5})\s*(px|pixel)?", re.I)

# Model code introduced by a label ("ma sp", "model", "model no.", "sku").
RE_MODEL_LINE = re.compile(r"(?:ma\s*sp|m[aã]\s*sp|model(?:\s*no\.)?|sku)\s*[:\-]?\s*([A-Z0-9\-\._]{3,})", re.I)
# Unlabelled candidate: upper-case letters/digits with at most one -, _ or . separator (case-sensitive).
RE_MODEL_FREE = re.compile(r"\b([A-Z0-9]{3,}[-_\.]?[A-Z0-9]{2,})\b")

# Colour and material keyword lists (accented and unaccented Vietnamese, plus English).
RE_COLOR = re.compile(r"\b(den|đen|do|đỏ|xanh|xam|xám|trang|trắng|hong|hồng|vang|vàng|nau|nâu|tim|tím|bac|bạc|gold|silver|black|white|blue|red|green|gray)\b", re.I)
RE_MATERIAL = re.compile(r"\b(oxford|inox|nhua|nhựa|vai|vải|da|polyester|nylon|aluminum|nhom|nhôm|thep|thép|thuy\s*tinh|thủy\s*tinh)\b", re.I)

# Lines matching this are skipped by make_text_lite (applied to accent-stripped, lower-cased text).
RE_MARKETING = re.compile(
    r"(cam\s*k[eê]t|dich\s*vu|doi\s*tra|khuyen\s*mai|cs[kh]|5\*|phuc\s*vu|hai\s*long|lien\s*he|cua\s*hang)",
    re.I
)
# Lines matching this are kept by make_text_lite: either a number with a unit,
# or one of the listed specification keywords.
TECH_PAT = re.compile(
    r"(\d+(?:[.,]\d+)?\s*(mm|cm|m|inch|\"|ml|mah|w|hz|v|gb|tb|px)\b)"
    r"|(\b(kich\s*thuoc|ma\s*sp|m[aã]\s*sp|model|sku|chat\s*lieu|mau\s*sac|dung\s*tich|trong\s*luong|man\s*hinh|ram|rom|ssd|hdd|tan\s*so|dien\s*ap)\b)",
    re.I
)

def split_details_block(text):
    """Split ``text`` at the first "details:" marker (case-insensitive).

    Returns ``(details, head)``: ``details`` is everything after the marker and
    ``head`` everything up to and including it. Without a marker the result is
    ``("", text)``.

    The marker is searched in ``text`` itself. Searching an accent-stripped copy
    and then slicing the original, as before, used offsets from a string of a
    different length whenever accent stripping dropped characters, so the cut
    landed in the wrong place.
    """
    m = re.search(r"details\s*:\s*", text, re.I)
    if not m:
        return "", text
    start = m.end()
    return text[start:], text[:start]

def make_text_lite(title, text):
    """Build a short, specification-focused summary of a product.

    Uses the title plus the "details" section of ``text`` (or all of ``text`` when
    there is none), keeps the non-marketing lines that match TECH_PAT, and joins
    at most 30 of them with ". ". If no line qualifies, the first three non-empty
    lines are used instead.
    """
    details, _ = split_details_block(text)
    combined = f"{title}\n{details or text}"
    lines = [norm_space(raw_line) for raw_line in combined.split("\n") if raw_line.strip()]

    def _is_technical(line):
        plain = strip_accents(line.lower())
        if RE_MARKETING.search(plain):
            return False
        return bool(TECH_PAT.search(plain))

    kept = [line for line in lines if _is_technical(line)]
    if not kept:
        kept = lines[:3]
    return ". ".join(kept[:30])

def extract_slots(title, text):
    """Extract structured attributes ("slots") from a product title and description.

    Only the first match of each pattern is used. Possible keys in the returned dict:
    dimensions_mm / dimensions_raw, screen_size_inch, compat_inch_range,
    capacity_ml / capacity_raw, battery_mah, weight_g / weight_raw, power_w,
    voltage_v, frequency_hz, ram_gb, storage_gb, resolution, model / model_conf,
    color, material, features. Keys are present only when something matched.

    NOTE(review): most numeric patterns run on ``raw_norm``, where
    normalize_text_for_sparse has already replaced "," and '"' with spaces, so a
    decimal comma ("1,5kg") is read as "5kg" and the '"' inch marker cannot match
    there. Left unchanged here; confirm the intended behaviour.
    """
    raw = f"{title}\n{text}"
    raw_noacc_lower = strip_accents(raw).lower()
    raw_norm = normalize_text_for_sparse(raw)

    slots = {}

    # --- dimensions (L x W x H), converted to millimetres ---
    m = RE_DIM_VI.search(strip_accents(raw))
    if m:
        a, b, c, u = m.groups()
        a, b, c = map(unify_number, (a, b, c))
        slots["dimensions_mm"] = [mm_from(a, u), mm_from(b, u), mm_from(c, u)]
        slots["dimensions_raw"] = f"{m.group(1)}x{m.group(2)}x{m.group(3)} {u or 'mm'}"

    # --- inch size and inch compatibility range ---
    m = RE_INCH.search(raw_norm)
    if m:
        slots["screen_size_inch"] = unify_number(m.group(1))
    m = RE_INCH_RANGE.search(strip_accents(raw))
    if m:
        lo, hi = map(unify_number, m.groups())
        slots["compat_inch_range"] = [min(lo, hi), max(lo, hi)]

    # --- capacity and battery ---
    m = RE_CAP_ML.search(raw_norm)
    if m:
        v, u = m.groups()
        slots["capacity_ml"] = capacity_ml(unify_number(v), u)
        slots["capacity_raw"] = f"{v}{u}"
    m = RE_MAH.search(raw_norm)
    if m:
        v = m.group(1)
        slots["battery_mah"] = int(unify_number(v))

    # --- weight, converted to grams ---
    m = RE_WEIGHT.search(raw_norm)
    if m:
        v, u = m.groups()
        slots["weight_g"] = g_from(unify_number(v), u)
        slots["weight_raw"] = f"{v}{u}"

    # --- power: RE_POWER also captures "kw", which must be scaled to watts ---
    m = RE_POWER.search(raw_norm)
    if m:
        watts = float(unify_number(m.group(1)))
        if m.group(2).lower() == "kw":
            watts *= 1000
        slots["power_w"] = watts

    # --- voltage and frequency ---
    for rex, key in [(RE_VOLT, "voltage_v"), (RE_FREQ, "frequency_hz")]:
        m = rex.search(raw_norm)
        if m:
            slots[key] = float(unify_number(m.group(1)))

    # --- memory and storage (storage normalised to GB) ---
    m = RE_RAM.search(raw_norm)
    if m:
        slots["ram_gb"] = int(unify_number(m.group(1)))
    m = RE_STORAGE.search(raw_norm)
    if m:
        v, unit = m.groups()
        val = float(unify_number(v))
        slots["storage_gb"] = val * 1024 if unit.lower() == "tb" else val

    # --- resolution ---
    m = RE_RES.search(raw_norm)
    if m:
        w, h, _ = m.groups()
        slots["resolution"] = f"{w}x{h}"

    # --- model code: a labelled match is "high" confidence, a free-form one "medium" ---
    m = RE_MODEL_LINE.search(strip_accents(raw))
    if m:
        slots["model"] = m.group(1).upper()
        slots["model_conf"] = "high"
    else:
        # "mã sp" cannot occur in accent-stripped text; it is covered by "ma sp".
        if any(k in raw_noacc_lower for k in ["ma sp", "mã sp", "model", "sku"]):
            m2 = RE_MODEL_FREE.search(strip_accents(raw).upper())
            if m2:
                slots["model"] = m2.group(1)
                slots["model_conf"] = "medium"

    # --- colours: prefer an explicit "mau sac:" line, else keyword matches ---
    colors = []
    color_line = re.search(r"(m[aà]u\s*s[aă]c\s*:\s*)(.+)", strip_accents(raw), re.I)
    if color_line:
        payload = color_line.group(2)
        for tok in re.split(r"[\/,\;\|]", payload):
            t = norm_space(strip_accents(tok).lower())
            if t:
                colors.append(t)
    if not colors:
        colors = list(set(RE_COLOR.findall(strip_accents(raw))))
    if colors:
        slots["color"] = [c.replace(" ", "_") for c in colors]

    # --- materials: prefer an explicit "chat lieu:" line, else keyword matches ---
    materials = []
    mat_line = re.search(r"(ch[aă]t\s*li[eê]u\s*:\s*)(.+)", strip_accents(raw), re.I)
    if mat_line:
        payload = mat_line.group(2)
        mats = RE_MATERIAL.findall(strip_accents(payload).lower())
        if mats:
            materials = mats
        else:
            # No known material keyword: fall back to the first two alphabetic words.
            cand = re.findall(r"[a-zA-Z]{3,}", strip_accents(payload))
            materials = cand[:2]
    else:
        mats2 = RE_MATERIAL.findall(raw_noacc_lower)
        if mats2:
            materials = list(set(mats2))
    if materials:
        slots["material"] = [m.lower() for m in materials]

    # --- feature flags detected by fixed (accent-stripped) phrases ---
    feats = []
    rn = raw_noacc_lower
    if ("chong tham" in rn) or ("chong nuoc" in rn):
        feats.append("chong_tham")
    if "chong soc" in rn:
        feats.append("chong_soc")
    if ("gan len thanh keo" in rn) or ("gan vali" in rn):
        feats.append("gan_vali")
    if "dem to ong" in rn:
        feats.append("dem_to_ong")
    if feats:
        slots["features"] = feats

    return slots

def make_kv_compact(slots, brand=None, category_norm=None):
    """Render ``slots`` (plus optional brand/category) as a compact "key:value | ..." string.

    Entries appear in a fixed order and only for keys present in ``slots``. Brand and
    category are accent-stripped and lower-cased. Returns "" when nothing applies.
    """
    parts = []
    if brand:
        parts.append(f"brand:{strip_accents(brand).lower()}")
    if category_norm:
        parts.append(f"category:{strip_accents(category_norm).lower()}")
    if "model" in slots:
        parts.append(f"model:{slots['model']}")
    if "dimensions_mm" in slots:
        # Whole numbers are printed without a decimal part, others rounded to one decimal.
        L, W, H = [int(v) if float(v).is_integer() else round(float(v), 1) for v in slots["dimensions_mm"]]
        parts.append(f"size_mm:{L}x{W}x{H}")
    if "compat_inch_range" in slots:
        lo, hi = slots["compat_inch_range"]
        parts.append(f"compat_inch:{lo}-{hi}")
    if "screen_size_inch" in slots:
        parts.append(f"screen:{slots['screen_size_inch']:.1f}inch")
    if "capacity_ml" in slots:
        parts.append(f"capacity_ml:{int(slots['capacity_ml'])}")
    if "battery_mah" in slots:
        parts.append(f"battery:{int(slots['battery_mah'])}mah")
    if "weight_g" in slots:
        parts.append(f"weight_g:{int(slots['weight_g'])}")
    if "power_w" in slots:
        parts.append(f"power_w:{slots['power_w']}")
    if "voltage_v" in slots:
        parts.append(f"voltage_v:{slots['voltage_v']}")
    if "frequency_hz" in slots:
        # Single quotes inside the f-string: reusing the outer double quotes is a
        # SyntaxError on Python versions before 3.12.
        parts.append(f"freq_hz:{slots['frequency_hz']}")
    if "ram_gb" in slots:
        parts.append(f"ram:{int(slots['ram_gb'])}gb")
    if "storage_gb" in slots:
        val = float(slots["storage_gb"])
        parts.append(f"storage:{int(val) if val.is_integer() else val}gb")
    if "resolution" in slots:
        parts.append(f"res:{slots['resolution']}")
    if "material" in slots:
        parts.append("material:" + "_".join(slots["material"]))
    if "color" in slots:
        parts.append("color:" + "_".join(slots["color"]))
    if "features" in slots:
        parts.append("feat:" + "_".join(slots["features"]))
    return " | ".join(parts)

def make_text_dense_input(record, category_norm=None):
    """Assemble the multi-line text used as input for the dense embedding.

    Non-empty sections are joined with newlines in this order: title,
    "Brand: ...", "Category: ...", the compact key/value string, and the lite text.
    ``category_norm`` takes precedence over ``record["category"]``.
    """
    title = record.get("title", "") or ""
    body = record.get("text", "") or ""
    brand = record.get("brand")
    category = category_norm or record.get("category")

    kv = make_kv_compact(extract_slots(title, body), brand=brand, category_norm=category)
    text_lite = make_text_lite(title, body)

    sections = [
        title,
        f"Brand: {brand}" if brand else None,
        f"Category: {category}" if category else None,
        kv,
        text_lite,
    ]
    return "\n".join(section for section in sections if section)

def make_text_sparse_input(record, category_norm=None):
    """Assemble the normalized single-line text used for sparse (keyword) matching.

    Joins the compact key/value string, title, brand and lite text with " | " and
    passes the result through normalize_text_for_sparse.
    """
    title = record.get("title", "") or ""
    body = record.get("text", "") or ""
    brand = record.get("brand")
    category = category_norm or record.get("category")

    kv = make_kv_compact(extract_slots(title, body), brand=brand, category_norm=category)
    text_lite = make_text_lite(title, body)

    fields = (kv, title, brand or "", text_lite)
    return normalize_text_for_sparse(" | ".join(fields))

def build_product_texts(record, category_norm=None):
    """Return every derived text representation of one product record.

    Keys of the result: "text_dense", "text_sparse", "kv_compact", "text_lite"
    and "slots". ``category_norm`` takes precedence over ``record["category"]``.
    """
    title = record.get("title", "") or ""
    body = record.get("text", "") or ""
    category = category_norm or record.get("category")

    slots = extract_slots(title, body)
    return {
        "text_dense": make_text_dense_input(record, category_norm=category),
        "text_sparse": make_text_sparse_input(record, category_norm=category),
        "kv_compact": make_kv_compact(slots, brand=record.get("brand"), category_norm=category),
        "text_lite": make_text_lite(title, body),
        "slots": slots,
    }

def _uniq(seq):
    seen = set()
    out = []
    for x in seq:
        if x and x not in seen:
            seen.add(x)
            out.append(x)
    return out

def build_parent_sparse_from_chunks(rows,
                                    category_norm_map=None,
                                    title_boost=1.5, brand_boost=1.5, kv_boost=2.0,
                                    max_titles=3, max_text_lite_chars=4000):
    """Merge chunk rows into one sparse-search document per parent product.

    Args:
        rows: iterable of dict-like chunk records; each must have "parent_uid" and
            may have "title", "text", "brand" and "category".
        category_norm_map: optional mapping from raw category to normalized
            category; unmapped categories are used as-is.
        title_boost, brand_boost, kv_boost: how many times the title / brand /
            key-value text is repeated in the sparse document. The value is
            truncated with ``int()``, so the default 1.5 repeats once.
            NOTE(review): fractional boosts therefore have no effect — confirm
            whether rounding up was intended.
        max_titles: number of distinct titles included.
        max_text_lite_chars: cap on the merged lite text length.

    Returns:
        dict mapping parent_uid to a dict with "text_sparse_parent", "kv_compact",
        "brand" and "category_norm".

    Raises:
        KeyError: a row has no "parent_uid".
    """
    # Group chunks by parent.
    buckets = defaultdict(list)
    for r in rows:
        buckets[r["parent_uid"]].append(r)

    out = {}
    for pid, items in buckets.items():
        titles = _uniq([i.get("title", "") for i in items])
        # First non-empty brand / category among the chunks represents the parent.
        brand = next((i.get("brand") for i in items if i.get("brand")), "")
        cat_raw = next((i.get("category") for i in items if i.get("category")), "")
        catn = category_norm_map.get(cat_raw, cat_raw) if category_norm_map else cat_raw

        kv_all, lite_all = [], []
        for it in items:
            slots = extract_slots(it.get("title", ""), it.get("text", ""))
            kv = make_kv_compact(slots, brand=brand, category_norm=catn)
            lite = make_text_lite(it.get("title", ""), it.get("text", ""))
            if kv:
                kv_all.append(kv)
            if lite:
                lite_all.append(lite)

        kv_merged = " | ".join(_uniq(kv_all))
        lite_merged = ". ".join(_uniq(lite_all))[:max_text_lite_chars]

        # Boosting is done by repeating the text an integer number of times.
        title_part = (" | ".join(titles[:max_titles]) + " ") * int(title_boost)
        brand_part = ((brand or "") + " ") * int(brand_boost)
        kv_part = (kv_merged + " ") * int(kv_boost)

        raw_sparse = " | ".join([
            kv_part.strip(),
            title_part.strip(),
            brand_part.strip(),
            lite_merged
        ])

        text_sparse_parent = normalize_text_for_sparse(raw_sparse)

        out[pid] = {
            "text_sparse_parent": text_sparse_parent,
            "kv_compact": kv_merged,
            "brand": brand,
            "category_norm": catn
        }
    return out


In [None]:
import pandas as pd

# Load the enriched product chunks (one JSON object per line).
df_chunks = pd.read_json(
    '/content/drive/MyDrive/tiki_chatbot/products_chunked.enriched.jsonl',
    lines=True
)

# Text columns must be strings downstream, so missing values become "".
COLS_TO_STR = ["title", "text", "brand", "category", "parent_uid"]
df_chunks[COLS_TO_STR] = df_chunks[COLS_TO_STR].fillna("")

# Build the per-parent sparse documents exactly once. Any error is allowed to propagate:
# swallowing it would leave `parent_sparse` undefined (or stale) for the cells below.
parent_sparse = build_parent_sparse_from_chunks(df_chunks.to_dict("records"))
print("Done.")



In [None]:
def strip_accents(s):
    """Return ``s`` reduced to ASCII with diacritics removed.

    This cell redefines the helper of the same name from the earlier cell.
    The Vietnamese letters "đ"/"Đ" have no Unicode decomposition, so NFKD plus an
    ASCII encode would silently delete them ("đen" -> "en"). They are mapped to
    "d"/"D" first. Any other character without an ASCII form is still dropped.
    """
    s = s.replace("đ", "d").replace("Đ", "D")
    return unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")

def normalize_text_for_sparse(s):
    """Normalize any value for sparse (keyword) matching.

    ``s`` is converted with ``str()``, accent-stripped and lower-cased; characters
    other than word characters, whitespace and ``- % . /`` become spaces; whitespace
    runs are collapsed and the result is trimmed.
    """
    text = strip_accents(str(s)).lower()
    for pattern, replacement in ((r"[^\w\s\-\%\./]", " "), (r"\s+", " ")):
        text = re.sub(pattern, replacement, text)
    return text.strip()

def tokenize(s):
    """Split normalized text into tokens, dropping single characters that are not digits."""
    return [
        token
        for token in normalize_text_for_sparse(s).split()
        if len(token) != 1 or token.isdigit()
    ]


class BM25Okapi:
    """Minimal in-memory BM25 scorer over a tokenized corpus.

    Args:
        corpus_tokens: sequence of documents, each a list of tokens.
        k1, b: BM25 term-frequency saturation and length-normalization parameters.
        epsilon: factor applied to negative IDF values (terms present in more than
            about half of the documents), shrinking them towards zero.

    Attributes:
        term_freqs: per-document ``{token: count}`` dicts.
        doc_freqs: ``{token: number of documents containing it}``.
        idf: ``{token: idf}`` computed as ``log((N - df + 0.5) / (df + 0.5) + 1e-9)``.
    """

    def __init__(self, corpus_tokens, k1=1.5, b=0.75, epsilon=0.25):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon

        self.corpus_size = len(corpus_tokens)
        self.doc_len = np.array([len(doc) for doc in corpus_tokens], dtype=np.float32)
        self.avgdl = float(self.doc_len.mean()) if self.corpus_size else 0.0

        self.term_freqs = []
        self.doc_freqs = {}
        # token -> [(doc index, term frequency), ...] in document order, so scoring
        # only visits documents that actually contain the query token.
        self._postings = {}
        for doc_idx, doc in enumerate(corpus_tokens):
            tf = {}
            for w in doc:
                tf[w] = tf.get(w, 0) + 1
            self.term_freqs.append(tf)
            for w, freq in tf.items():
                self.doc_freqs[w] = self.doc_freqs.get(w, 0) + 1
                self._postings.setdefault(w, []).append((doc_idx, freq))

        self.idf = {}
        for w, df in self.doc_freqs.items():
            val = math.log((self.corpus_size - df + 0.5) / (df + 0.5) + 1e-9)
            if val < 0:
                # Very common terms get a damped (still negative) weight.
                val *= self.epsilon
            self.idf[w] = val

    def get_scores(self, query_tokens):
        """Return a float32 array with one BM25 score per document.

        Tokens unknown to the corpus are ignored; a repeated query token
        contributes once per occurrence. An empty corpus yields an empty array.
        """
        scores = np.zeros(self.corpus_size, dtype=np.float32)
        if self.corpus_size == 0:
            return scores
        avgdl = self.avgdl or 1.0  # guard against division by zero for all-empty documents
        for w in query_tokens:
            idf = self.idf.get(w)
            if idf is None:
                continue
            for i, f in self._postings.get(w, ()):
                denom = f + self.k1 * (1 - self.b + self.b * (self.doc_len[i] / avgdl))
                scores[i] += idf * (f * (self.k1 + 1)) / (denom + 1e-9)
        return scores

class ParentBM25Index:
    """BM25 index over one sparse text document per parent product.

    ``fit`` builds the index from a ``{parent_uid: info}`` mapping, ``search``
    returns the best-scoring parents as a DataFrame, and ``save`` / ``load``
    persist the tokenized corpus and metadata to a folder.
    """

    # Column layout of every search result, including empty ones.
    _RESULT_COLUMNS = ["parent_uid", "score", "brand", "category_norm", "kv_compact"]

    def __init__(self, k1=1.5, b=0.75, epsilon=0.25):
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon

        self.parents = []       # parent ids, aligned with docs_tokens
        self.meta = {}          # parent id -> info dict passed to fit()
        self.docs_tokens = []   # tokenized sparse text per parent
        self._bm25 = None       # scorer, built by fit() / load()

    def fit(self, parent_rows):
        """Rebuild the index from ``parent_rows`` ({parent_uid: info dict}).

        Parents whose "text_sparse_parent" is missing or blank are skipped.
        """
        self.parents.clear()
        self.meta = {}
        self.docs_tokens.clear()

        for pid, info in parent_rows.items():
            ts = (info.get("text_sparse_parent") or "").strip()
            if not ts:
                continue
            self.parents.append(pid)
            self.meta[pid] = info
            self.docs_tokens.append(tokenize(ts))

        self._bm25 = BM25Okapi(self.docs_tokens, k1=self.k1, b=self.b, epsilon=self.epsilon)

    def is_ready(self):
        """True once a scorer exists and at least one parent is indexed."""
        return (self._bm25 is not None) and (len(self.parents) == len(self.docs_tokens) > 0)

    def search(self, query, topk=20):
        """Return up to ``topk`` parents ordered by descending BM25 score.

        The result always has the columns parent_uid, score, brand, category_norm
        and kv_compact; it is empty when the index is not ready or ``topk`` is not
        positive. Parents with a zero score can appear when fewer than ``topk``
        parents match the query.
        """
        if not self.is_ready():
            return pd.DataFrame(columns=self._RESULT_COLUMNS)

        q_tokens = tokenize(query)
        scores = self._bm25.get_scores(q_tokens)

        k = min(topk, len(scores))
        # argpartition needs a valid kth; a non-positive k simply means "no results".
        if k <= 0:
            return pd.DataFrame(columns=self._RESULT_COLUMNS)
        idx = np.argpartition(-scores, kth=k-1)[:k]
        idx = idx[np.argsort(-scores[idx])]

        rows = []
        for i in idx:
            pid = self.parents[i]
            m = self.meta.get(pid, {})
            rows.append({
                "parent_uid": pid,
                "score": float(scores[i]),
                "brand": m.get("brand"),
                "category_norm": m.get("category_norm"),
                "kv_compact": m.get("kv_compact"),
            })
        return pd.DataFrame(rows)

    def save(self, folder):
        """Write parents.pkl, docs_tokens.pkl, meta.json and params.json into ``folder``."""
        os.makedirs(folder, exist_ok=True)
        with open(os.path.join(folder, "parents.pkl"), "wb") as f:
            pickle.dump(self.parents, f)
        with open(os.path.join(folder, "docs_tokens.pkl"), "wb") as f:
            pickle.dump(self.docs_tokens, f)
        with open(os.path.join(folder, "meta.json"), "w", encoding="utf-8") as f:
            json.dump(self.meta, f, ensure_ascii=False, indent=2)
        with open(os.path.join(folder, "params.json"), "w", encoding="utf-8") as f:
            json.dump({"k1": self.k1, "b": self.b, "epsilon": self.epsilon}, f)

    @classmethod
    def load(cls, folder):
        """Recreate an index from a folder written by ``save``.

        SECURITY: the .pkl files are read with ``pickle.load``, which can execute
        arbitrary code; only load folders you created yourself.
        """
        with open(os.path.join(folder, "params.json"), "r", encoding="utf-8") as f:
            p = json.load(f)
        obj = cls(k1=p.get("k1", 1.5), b=p.get("b", 0.75), epsilon=p.get("epsilon", 0.25))
        with open(os.path.join(folder, "parents.pkl"), "rb") as f:
            obj.parents = pickle.load(f)
        with open(os.path.join(folder, "docs_tokens.pkl"), "rb") as f:
            obj.docs_tokens = pickle.load(f)
        with open(os.path.join(folder, "meta.json"), "r", encoding="utf-8") as f:
            obj.meta = json.load(f)
        # The scorer itself is not persisted; it is rebuilt from the token lists.
        obj._bm25 = BM25Okapi(obj.docs_tokens, k1=obj.k1, b=obj.b, epsilon=obj.epsilon)
        return obj


In [None]:
# Build the parent-level BM25 index from the sparse documents computed above.
pidx = ParentBM25Index(k1=1.3, b=0.72)
pidx.fit(parent_sparse)

# Persist the index into the "bm25_parent_index" folder (relative to the working directory).
pidx.save("bm25_parent_index")

