# VERSION 1: WITHOUT TEXT INDEX (Pure Vector Search - FASTEST)

In [None]:
# qdrant_flat_adapter_NO_TEXT_INDEX.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import sys
import time
import uuid
import logging
import pickle
import numpy as np
import pandas as pd
from qdrant_client import QdrantClient
from qdrant_client.http import models as qm


In [None]:
# Load embeddings
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so only load this artifact when it comes from a trusted source.
# The loaded object is later handed to infer_dim_by_model_flat and
# upload_from_flat, which walk it as model -> db -> field -> leaf dict.
with open('resources/embeddings/biochirp_embeddings.pkl', 'rb') as f:
    biochirp_embeddings = pickle.load(f)


In [None]:

# =========================
# Logging
# =========================

def setup_logging(level: int = logging.INFO) -> logging.Logger:
    """Return the shared ``qdrant_upload`` logger, configured for stdout.

    A stdout handler is attached only when the logger has no handlers yet, so
    calling this repeatedly (e.g. re-running a notebook cell) does not
    duplicate output. The level is applied on every call.

    Args:
        level: Logging level to set on the logger.

    Returns:
        The ``qdrant_upload`` logger instance.
    """
    log = logging.getLogger("qdrant_upload")
    already_configured = bool(log.handlers)
    if not already_configured:
        line_format = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
        stdout_handler = logging.StreamHandler(stream=sys.stdout)
        stdout_handler.setFormatter(line_format)
        log.addHandler(stdout_handler)
    log.setLevel(level)
    return log


# Module-wide logger used by every helper below.
LOG = setup_logging(logging.INFO)

In [None]:
# @dataclass
# class QdrantConfig:
    
#     url: str = "http://localhost:6333"
#     distance: qm.Distance = qm.Distance.COSINE
#     create_payload_indexes: bool = True
#     index_text_field: bool = False  # ? DISABLED for maximum speed
    
#     # HNSW optimization parameters
#     hnsw_m: int = 16                # edges per node (16 is balanced)
#     hnsw_ef_construct: int = 100    # construction quality
#     hnsw_ef: int = 64               # search quality (runtime tunable)

# def get_client(cfg: QdrantConfig) -> QdrantClient:
#     """Prefer gRPC for heavy ingest; fall back to HTTP if 6334 is closed."""
#     try:
#         return QdrantClient(
#             host="localhost", port=6333, grpc_port=6334,
#             prefer_grpc=True, timeout=300.0
#         )
#     except Exception:
#         LOG.warning("gRPC unavailable, falling back to HTTP")
#         return QdrantClient(host="localhost", port=6333, prefer_grpc=False, timeout=300.0)

In [None]:
@dataclass
class QdrantConfig:
    """Connection and index settings passed to the Qdrant helper functions."""

    # Qdrant host name; for Docker setups use the service/container name.
    host: str = "localhost"
    # Port forwarded to QdrantClient as `port`.
    port: int = 6333
    # Port forwarded to QdrantClient as `grpc_port`.
    grpc_port: int = 6334
    # Vector distance metric used when a collection is created.
    distance: qm.Distance = qm.Distance.COSINE
    # When True, keyword payload indexes are created on model/db/field.
    create_payload_indexes: bool = True
    # Text indexing is disabled for maximum speed.
    # NOTE(review): no function visible in this file reads this flag.
    index_text_field: bool = False
    
    # HNSW optimization parameters
    hnsw_m: int = 16                # edges per node (16 is balanced)
    hnsw_ef_construct: int = 100    # construction quality
    # search quality (runtime tunable)
    # NOTE(review): the search helpers visible here take their own hnsw_ef
    # argument and do not read this field.
    hnsw_ef: int = 64


def get_client(cfg: QdrantConfig) -> QdrantClient:
    """Prefer gRPC for heavy ingest; fall back to HTTP if gRPC is unusable.

    The host, REST port and gRPC port are all taken from ``cfg`` (the host was
    previously hard-coded, so a non-local ``cfg.host`` was silently ignored).

    Args:
        cfg: Connection settings (``host``, ``port``, ``grpc_port``).

    Returns:
        A ``QdrantClient`` using gRPC when a probe request succeeds, otherwise
        an HTTP-only client for the same host and port.
    """
    try:
        grpc_client = QdrantClient(
            host=cfg.host,
            port=cfg.port,
            grpc_port=cfg.grpc_port,
            prefer_grpc=True,
            timeout=300.0,
        )
        # Building the client may not contact the server at all, so issue one
        # cheap request; otherwise a closed gRPC port would only surface later,
        # in the middle of the upload, and the fallback below would never run.
        grpc_client.get_collections()
        return grpc_client
    except Exception:
        LOG.warning("gRPC unavailable, falling back to HTTP")
        return QdrantClient(
            host=cfg.host,
            port=cfg.port,
            prefer_grpc=False,
            timeout=300.0,
        )

# def get_client(cfg: QdrantConfig) -> QdrantClient:
#     """Prefer gRPC for heavy ingest; fall back to HTTP if 6334 is closed."""
#     try:
#         return QdrantClient(
#             host="localhost", port=6333, grpc_port=6334,
#             prefer_grpc=True, timeout=300.0
#         )
#     except Exception:
#         LOG.warning("gRPC unavailable, falling back to HTTP")
#         return QdrantClient(host="localhost", port=6333, prefer_grpc=False, timeout=300.0)

In [None]:




# =========================
# Config / Client (NO TEXT INDEX)
# =========================



def model_to_collection(model_name: str) -> str:
    """Map a model name to its collection name.

    Every ``/`` in the model name becomes ``_`` and the result is prefixed
    with ``emb_``.
    """
    safe_name = "_".join(model_name.split("/"))
    return "emb_" + safe_name

# =========================
# Deterministic Point IDs
# =========================

# Fixed namespace so that the same (model, db, field, row) always maps to the
# same UUIDv5 point id.
QDRANT_ID_NAMESPACE = uuid.UUID("3e0d7e58-5c5e-4c5e-9a0b-7a6be3f2b9aa")


def make_point_uuid(model: str, db: str, field: str, i: int) -> str:
    """Return a deterministic point id for one embedding row.

    The id is the UUIDv5 of ``"model|db|field|i"`` under
    ``QDRANT_ID_NAMESPACE``, rendered as a string.
    """
    composite_key = "|".join(str(part) for part in (model, db, field, i))
    point_id = uuid.uuid5(QDRANT_ID_NAMESPACE, composite_key)
    return str(point_id)

# =========================
# Collection / Index setup (NO TEXT INDEX)
# =========================

def ensure_collection_for_model(
    client: QdrantClient,
    model_name: str,
    dim: int,
    cfg: QdrantConfig,
) -> str:
    """Create the collection for ``model_name`` if it does not exist yet.

    A new collection gets the vector size/distance and HNSW settings from
    ``cfg`` and, when ``cfg.create_payload_indexes`` is set, keyword payload
    indexes on ``model``, ``db`` and ``field``. An existing collection is left
    untouched (no indexes are added and its dimension is not checked).

    Args:
        client: Qdrant client used for all calls.
        model_name: Model whose collection should exist.
        dim: Vector dimensionality for a newly created collection.
        cfg: Distance, HNSW and payload-index settings.

    Returns:
        The collection name.
    """
    coll = model_to_collection(model_name)

    # Guard clause: nothing to do for a collection that is already there.
    if client.collection_exists(coll):
        LOG.info(f"[{coll}] exists (skip create)")
        return coll

    LOG.info(f"[{coll}] creating collection (dim={dim}, distance={cfg.distance})")
    vector_params = qm.VectorParams(size=dim, distance=cfg.distance)
    hnsw_params = qm.HnswConfigDiff(
        m=cfg.hnsw_m,
        ef_construct=cfg.hnsw_ef_construct,
        full_scan_threshold=10000,
    )
    optimizer_params = qm.OptimizersConfigDiff(indexing_threshold=20000)
    client.create_collection(
        collection_name=coll,
        vectors_config=vector_params,
        hnsw_config=hnsw_params,
        optimizers_config=optimizer_params,
    )

    if not cfg.create_payload_indexes:
        return coll

    # Only the fields used in search filters get a KEYWORD index.
    keyword = qm.PayloadSchemaType.KEYWORD
    for field in ("model", "db", "field"):
        LOG.info(f"[{coll}] creating payload index on '{field}' ({keyword})")
        client.create_payload_index(coll, field_name=field, field_schema=keyword)

    # The 'text' payload field deliberately gets no TEXT index.
    LOG.info(f"[{coll}] TEXT index on 'text' field: DISABLED (pure vector search)")
    return coll

# =========================
# Dimension inference
# =========================

def infer_dim_by_model_flat(a: Dict[str, Any]) -> Dict[str, int]:
    """Derive the embedding dimensionality of every model.

    ``a`` is walked as ``model -> db -> field -> leaf``. Only dict leaves whose
    ``"embeddings"`` entry is a 2-D ``np.ndarray`` are considered; the second
    axis length is taken as the dimension.

    Returns:
        Mapping of model name to its single embedding dimension.

    Raises:
        ValueError: If a model has no usable embeddings, or if its leaves
            disagree on the dimension.
    """
    dims: Dict[str, int] = {}
    for model, dbs in a.items():
        candidates = (
            leaf.get("embeddings")
            for db in dbs.values()
            for leaf in db.values()
            if isinstance(leaf, dict)
        )
        widths = {
            int(arr.shape[1])
            for arr in candidates
            if isinstance(arr, np.ndarray) and arr.ndim == 2
        }
        if len(widths) == 0:
            raise ValueError(f"No embeddings found for model '{model}' in flat-leaf schema.")
        if len(widths) > 1:
            raise ValueError(f"Model '{model}' has mixed dimensions: {sorted(widths)}")
        (dims[model],) = widths
    return dims

# =========================
# Chunking helper
# =========================

def _iter_row_chunks(n_rows: int, chunk_rows: int):
    """Yield ``(start, end)`` half-open row ranges of at most ``chunk_rows`` rows.

    The ranges cover ``0..n_rows`` in order; the last one may be shorter.
    """
    yield from (
        (start, min(start + chunk_rows, n_rows))
        for start in range(0, n_rows, chunk_rows)
    )

# =========================
# Upsert with retry + adaptive split
# =========================

def _try_upsert(
    client: QdrantClient,
    collection: str,
    ids: List[str],
    vecs: List[np.ndarray],
    payloads: List[Dict[str, Any]],
) -> None:
    """Send one batch to Qdrant without waiting for it to be applied.

    The per-row vectors are stacked into a single array and sent together with
    the ids and payloads as one ``qm.Batch``. Any exception from stacking or
    from the client propagates to the caller.
    """
    stacked = np.vstack(vecs)
    batch = qm.Batch(ids=ids, vectors=stacked, payloads=payloads)
    client.upsert(collection_name=collection, points=batch, wait=False)

def _flush_with_retry(
    client: QdrantClient,
    collection: str,
    ids: List[str],
    vecs: List[np.ndarray],
    payloads: List[Dict[str, Any]],
    global_state: Dict[str, Any] | None,
    max_retries: int = 5,
) -> None:
    """Upsert a batch, splitting and retrying with backoff on failure.

    Strategy:
      1. Try the whole batch once; on success log throughput/ETA and return.
      2. Otherwise split it into a left and a right half and retry the halves
         up to ``max_retries`` times with exponential backoff (1s, 2s, 4s...).
         A half that has been accepted is never sent or counted again.
      3. Halves still failing after that are retried recursively (with
         ``max_retries=3``) if they hold more than one point; a single point
         that cannot be written is logged as an error and dropped.

    This function is best-effort: upsert failures are logged, not raised.

    Args:
        client: Qdrant client.
        collection: Target collection name.
        ids: Point ids of the batch.
        vecs: One vector per id.
        payloads: One payload dict per id.
        global_state: Optional progress dict; ``"uploaded"`` is incremented by
            the number of points accepted and ``"total"`` (if present) is used
            for the percentage/ETA log line.
        max_retries: Number of split-retry rounds before recursing.
    """
    n = len(ids)
    if n == 0:
        return

    try:
        t0 = time.perf_counter()
        _try_upsert(client, collection, ids, vecs, payloads)
        dt = time.perf_counter() - t0
        if global_state is not None:
            global_state["uploaded"] += n
            uploaded = global_state["uploaded"]
            total = global_state.get("total", None)
            rate = n / dt if dt > 0 else float("inf")
            if total:
                pct = 100.0 * uploaded / total if total > 0 else 100.0
                remaining = max(total - uploaded, 0)
                eta = remaining / rate if rate > 0 else float("inf")
                LOG.info(
                    f"[{collection}] batch={n} in {dt:.2f}s | rate={rate:,.1f} pts/s | "
                    f"uploaded={uploaded:,}/{total:,} ({pct:.2f}%) | ETA={eta:.1f}s"
                )
            else:
                LOG.info(f"[{collection}] batch={n} in {dt:.2f}s | rate={rate:,.1f} pts/s | uploaded={uploaded:,}")
        return
    except Exception as exc:
        LOG.warning(f"[{collection}] upsert failed for batch size {n}: {exc!r}")

    # Halves still waiting to be accepted. An empty half (left half of a
    # one-point batch) is skipped: stacking zero vectors would raise and make
    # every retry of a single point fail regardless of the server.
    mid = n // 2
    pending = [
        (label, ids[lo:hi], vecs[lo:hi], payloads[lo:hi])
        for label, lo, hi in (("left", 0, mid), ("right", mid, n))
        if hi > lo
    ]

    delay = 1.0
    for attempt in range(1, max_retries + 1):
        try:
            while pending:
                _, part_ids, part_vecs, part_pl = pending[0]
                _try_upsert(client, collection, part_ids, part_vecs, part_pl)
                # Drop the half as soon as it is accepted so a later failure
                # of the other half does not re-send or re-count it.
                pending.pop(0)
                if global_state is not None:
                    global_state["uploaded"] += len(part_ids)
            LOG.info(f"[{collection}] success after split retry (attempt {attempt}) for total {n}")
            return
        except Exception as exc:
            LOG.warning(f"[{collection}] retry {attempt}/{max_retries} failed: {exc!r}; sleeping {delay:.1f}s")
            time.sleep(delay)
            delay *= 2

    # Retries exhausted: recurse into whatever is still pending.
    for label, part_ids, part_vecs, part_pl in pending:
        if len(part_ids) > 1:
            _flush_with_retry(client, collection, part_ids, part_vecs, part_pl, global_state, max_retries=3)
        else:
            LOG.error(f"[{collection}] could not upsert even single {label} point: {part_ids}")

def _flush_batch(
    client: QdrantClient,
    collection: str,
    ids: List[str],
    vecs: List[np.ndarray],
    payloads: List[Dict[str, Any]],
    *,
    global_state: Dict[str, Any] | None = None,
) -> None:
    """Flush the accumulated batch and empty the three buffers in place.

    Does nothing when ``ids`` is empty. The buffers are cleared only after
    ``_flush_with_retry`` returns, so they keep their content if it raises.
    """
    if len(ids) == 0:
        return
    _flush_with_retry(client, collection, ids, vecs, payloads, global_state)
    for buffer in (ids, vecs, payloads):
        buffer.clear()

# =========================
# Bulk upload with HNSW optimization
# =========================

def _advance_progress_bar(pbar: Any, uploaded: int, total: int) -> None:
    """Move ``pbar`` forward to ``min(uploaded, total)``; never backwards."""
    if pbar is None:
        return
    target = min(uploaded, total)
    if target > pbar.n:
        pbar.update(target - pbar.n)


def upload_from_flat(
    client: QdrantClient,
    a: Dict[str, Any],
    dim_by_model: Dict[str, int],
    cfg: QdrantConfig,
    batch_size: int = 5_000,
    use_tqdm: bool = True,
    chunk_rows: int = 100_000,
    cooldown_after_timeouts: float = 0.5,
) -> None:
    """Upload every embedding in ``a`` to Qdrant, one collection per model.

    ``a`` is walked as ``model -> db -> field -> leaf`` where a leaf is a dict
    holding ``"embeddings"`` (2-D ``np.ndarray``) and optionally ``"texts"``
    and ``"metadata"`` lists. Leaves without a 2-D array are skipped. When
    ``texts``/``metadata`` are not lists of the same length as the embeddings
    they are replaced by empty strings / empty dicts.

    For each model the collection is ensured, HNSW graph building is switched
    off (``m=0``) for the duration of the ingest and switched back on
    afterwards. The re-enable step runs in a ``finally`` block so a failed
    upload does not leave the collection without an index configuration.

    Args:
        client: Qdrant client.
        a: Nested embeddings structure described above.
        dim_by_model: Vector dimension per model name.
        cfg: Collection/HNSW settings.
        batch_size: Points per upsert; halved (not below 256) whenever a flush
            raises.
        use_tqdm: Show a tqdm progress bar when tqdm is importable.
        chunk_rows: Rows of one leaf processed per outer chunk.
        cooldown_after_timeouts: Seconds to sleep after a flush raised.
    """
    try:
        from tqdm import tqdm
    except Exception:
        # Progress bar is optional; carry on without it.
        tqdm = None
        use_tqdm = False

    for model_name, dbs in a.items():
        coll = ensure_collection_for_model(client, model_name, dim_by_model[model_name], cfg)

        # Disable HNSW during upload so the graph is built once at the end.
        LOG.info(f"[{coll}] Disabling HNSW indexing for fast upload...")
        client.update_collection(
            collection_name=coll,
            hnsw_config=qm.HnswConfigDiff(m=0)
        )

        pbar = None
        try:
            total_points = 0
            for db_name, fields in dbs.items():
                for field_name, leaf in fields.items():
                    if isinstance(leaf, dict) and isinstance(leaf.get("embeddings"), np.ndarray):
                        arr = leaf["embeddings"]
                        if arr.ndim == 2:
                            total_points += arr.shape[0]
            LOG.info(f"[{coll}] planned points: {total_points:,}")

            global_state = {"uploaded": 0, "total": total_points}
            if use_tqdm and tqdm and total_points > 0:
                pbar = tqdm(total=total_points, desc=f"{coll}", unit="pt")

            t_model_start = time.perf_counter()
            for db_name, fields in dbs.items():
                for field_name, leaf in fields.items():
                    if not isinstance(leaf, dict):
                        continue
                    arr = leaf.get("embeddings")
                    texts = leaf.get("texts", [])
                    metas = leaf.get("metadata", [])

                    if not (isinstance(arr, np.ndarray) and arr.ndim == 2):
                        continue
                    if arr.dtype != np.float32:
                        arr = arr.astype(np.float32)

                    n = arr.shape[0]
                    if not isinstance(texts, list) or len(texts) != n:
                        texts = ["" for _ in range(n)]
                    if not isinstance(metas, list) or len(metas) != n:
                        metas = [{} for _ in range(n)]

                    LOG.info(f"[{coll}] slice model={model_name} db={db_name} field={field_name} rows={n:,}")

                    for chunk_start, chunk_end in _iter_row_chunks(n, chunk_rows):
                        LOG.info(
                            f"[{coll}]  chunk rows {chunk_start:,}:{chunk_end:,} "
                            f"(size={chunk_end-chunk_start:,})"
                        )
                        ids, vecs, payloads = [], [], []
                        timeouts_in_chunk = 0

                        for i in range(chunk_start, chunk_end):
                            pid = make_point_uuid(model_name, db_name, field_name, i)
                            ids.append(pid)
                            vecs.append(arr[i])
                            row_meta = metas[i] if isinstance(metas[i], dict) else {}
                            # row_meta is spread last, so metadata keys win
                            # over the standard payload keys on a name clash.
                            payloads.append({
                                "model": model_name,
                                "db": db_name,
                                "field": field_name,
                                "text": texts[i],
                                "composite_id": f"{model_name}|{db_name}|{field_name}|{i}",
                                **row_meta,
                            })

                            if len(ids) >= batch_size:
                                try:
                                    _flush_batch(client, coll, ids, vecs, payloads, global_state=global_state)
                                    # Follow the real uploaded count instead
                                    # of assuming exactly batch_size points.
                                    _advance_progress_bar(pbar, global_state["uploaded"], total_points)
                                except Exception as exc:
                                    # Buffers are kept, so the points are
                                    # retried with the next (smaller) flush.
                                    timeouts_in_chunk += 1
                                    LOG.warning(f"[{coll}] batch flush bubbled exception: {exc!r}")
                                    batch_size = max(256, batch_size // 2)
                                    time.sleep(cooldown_after_timeouts)

                        try:
                            _flush_batch(client, coll, ids, vecs, payloads, global_state=global_state)
                            _advance_progress_bar(pbar, global_state["uploaded"], total_points)
                        except Exception as exc:
                            timeouts_in_chunk += 1
                            LOG.warning(f"[{coll}] remainder flush bubbled exception: {exc!r}")
                            batch_size = max(256, batch_size // 2)
                            time.sleep(cooldown_after_timeouts)

                        if timeouts_in_chunk > 0 and batch_size > 256:
                            LOG.info(f"[{coll}] reducing batch_size to {batch_size} due to timeouts in chunk")
        finally:
            if pbar:
                pbar.close()

            # Re-enable HNSW after upload (builds the index once), even when
            # the upload above failed part-way.
            LOG.info(f"[{coll}] Re-enabling HNSW indexing...")
            client.update_collection(
                collection_name=coll,
                hnsw_config=qm.HnswConfigDiff(
                    m=cfg.hnsw_m,
                    ef_construct=cfg.hnsw_ef_construct
                )
            )

        dt_model = time.perf_counter() - t_model_start
        up = global_state["uploaded"]
        r = up / dt_model if dt_model > 0 else float("inf")
        LOG.info(f"[{coll}] done: uploaded={up:,}/{total_points:,} in {dt_model:.2f}s ({r:,.1f} pts/s)")




In [None]:
# cfg = QdrantConfig(url="http://localhost:6333")
# Explicit connection settings; these equal the QdrantConfig defaults.
cfg = QdrantConfig(host="localhost", port=6333, grpc_port=6334)

client = get_client(cfg)

# One vector dimension per model, derived from the loaded embeddings.
dim_by_model = infer_dim_by_model_flat(biochirp_embeddings)

# Upload with HNSW disabled during ingest, re-enabled after
upload_from_flat(client, biochirp_embeddings, dim_by_model, cfg, batch_size=1024)

# Example search
from sentence_transformers import SentenceTransformer

# Load one SentenceTransformer per uploaded model, keyed by model name.
model_names = list(dim_by_model.keys())
model_cache = {name: SentenceTransformer(name) for name in model_names}

In [None]:
# NOTE(review): `a` is never assigned at module level in the visible code (it
# only appears as a function parameter name), so this cell presumably raises
# NameError when the notebook runs top to bottom — confirm and remove.
a

In [None]:
# =========================
# Fast Search (No Text Filtering)
# =========================

def search_fast(
    client: QdrantClient,
    collection_name: str,
    query_vector: List[float],
    limit: int = 100,
    hnsw_ef: int = 64,  # Tune this: 32=fastest, 128=more accurate
    filter_conditions: Optional[qm.Filter] = None,
) -> List[qm.ScoredPoint]:
    """Run an approximate vector similarity search with no text filtering.

    Args:
        client: Qdrant client.
        collection_name: Collection to search.
        query_vector: Query embedding.
        limit: Maximum number of hits to return.
        hnsw_ef: Search-time HNSW ``ef`` passed in the search params.
        filter_conditions: Optional payload filter applied to the search.

    Returns:
        Whatever ``client.search`` returns for the query.
    """
    params = qm.SearchParams(hnsw_ef=hnsw_ef, exact=False)
    hits = client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        search_params=params,
        limit=limit,
        query_filter=filter_conditions,
    )
    return hits


In [None]:
# def search_fast(
#     client: QdrantClient,
#     collection_name: str,
#     query_vector: List[float],
#     model_name: str,
#     db_name: str,
#     field_name: str,
#     limit: int = 100,
#     hnsw_ef: int = 64,
# ) -> List[qm.ScoredPoint]:
#     """
#     Fast vector similarity search without text filtering.
#     """
#     from kneed import KneeLocator
#     # ? Correct filter pattern matching your function
#     flt = qm.Filter(must=[
#         qm.FieldCondition(key="model", match=qm.MatchValue(value=model_name)),
#         qm.FieldCondition(key="db", match=qm.MatchValue(value=db_name)),
#         qm.FieldCondition(key="field", match=qm.MatchValue(value=field_name)),
#     ])
    
#     return client.search(
#         collection_name=collection_name,
#         query_vector=query_vector,
#         search_params=qm.SearchParams(
#             hnsw_ef=hnsw_ef,  # ? 32=fastest, 64=balanced, 128=accurate
#             exact=False,
#         ),
#         limit=limit,
#         with_payload=True,
#         query_filter=flt,
#     )


# def search_reference_term_all_models_FAST(
#     client: QdrantClient,
#     reference_term: str,
#     target_field: str,
#     model_cache: Dict[str, "SentenceTransformer"],
#     limit_per_model: int = 50,
#     use_knee_cutoff: bool = True,
#     db_whitelist: Optional[List[str]] = None,
#     hnsw_ef: int = 64,  # ? NEW: Tunable search speed parameter
# ) -> pd.DataFrame:
#     """
#     Fast version with optimized HNSW search parameters.
#     """
#     from kneed import KneeLocator
#     rows: List[Dict[str, Any]] = []

#     for model_name, model in model_cache.items():
#         coll = model_to_collection(model_name)
#         if not client.collection_exists(coll):
#             LOG.warning(f"[{coll}] collection missing; skipping model='{model_name}'")
#             continue

#         q_vec = model.encode(
#             reference_term,
#             convert_to_tensor=True,
#             normalize_embeddings=True
#         )

#         dbs = db_whitelist if db_whitelist is not None else distinct_values(client, coll, "db")

#         for db_name in dbs:
#             # ? Correct filter pattern
#             flt = qm.Filter(must=[
#                 qm.FieldCondition(key="model", match=qm.MatchValue(value=model_name)),
#                 qm.FieldCondition(key="db", match=qm.MatchValue(value=db_name)),
#                 qm.FieldCondition(key="field", match=qm.MatchValue(value=target_field)),
#             ])

#             hits = client.search(
#                 collection_name=coll,
#                 query_vector=q_vec,
#                 limit=limit_per_model,
#                 with_payload=True,
#                 query_filter=flt,
#                 search_params=qm.SearchParams(
#                     hnsw_ef=hnsw_ef,  # ? Speed optimization
#                     exact=False,
#                 ),
#             )
            
#             if not hits:
#                 continue

#             scores = np.array([h.score for h in hits], dtype=np.float32)
#             if use_knee_cutoff and len(scores) > 1:
#                 sorted_scores = np.sort(scores)[::-1]
#                 try:
#                     knee = KneeLocator(range(len(sorted_scores)), sorted_scores,
#                                        curve="convex", direction="decreasing", S=0.5)
#                     threshold = float(sorted_scores[knee.knee])
#                 except Exception:
#                     threshold = 0.0
#             else:
#                 threshold = 0.0

#             for h in hits:
#                 if h.score >= threshold:
#                     pl = h.payload or {}
#                     row = {
#                         "reference_term": reference_term,
#                         "model": pl.get("model", model_name),
#                         "db": pl.get("db", db_name),
#                         "field": pl.get("field", target_field),
#                         "score": float(h.score),
#                         "cutoff_used": float(threshold),
#                         "text": pl.get("text", ""),
#                     }
#                     for k, v in pl.items():
#                         if k not in row:
#                             row[k] = v
#                     rows.append(row)

#     df = pd.DataFrame(rows)
#     return df.sort_values("score", ascending=False).reset_index(drop=True) if not df.empty else df


In [None]:
def distinct_values(
    client: QdrantClient,
    collection_name: str,
    field_name: str,
    limit: int = 1000,
) -> List[str]:
    """
    Get distinct values for a payload field using Qdrant's Facet API.

    If the facet request fails for any reason, the collection is scrolled
    page by page instead and the values are collected client-side. The
    fallback honours ``limit`` as well: scrolling stops once ``limit`` distinct
    values have been seen, and values are returned in first-seen order.

    Args:
        client: QdrantClient instance
        collection_name: Name of the collection
        field_name: Payload field to get distinct values from
        limit: Maximum number of unique values to return

    Returns:
        List of distinct values for the field (at most ``limit`` entries)
    """
    try:
        # Use Qdrant Facet API (available in Qdrant 1.12+)
        facet_result = client.facet(
            collection_name=collection_name,
            key=field_name,
            limit=limit,
        )

        # Extract unique values from facet hits
        return [hit.value for hit in facet_result.hits]

    except Exception as e:
        LOG.warning(f"Facet API failed for field '{field_name}': {e!r}, falling back to scroll method")

    # Fallback: scroll through points and collect unique values. A dict is
    # used as an insertion-ordered set so the result order is reproducible.
    unique_values: Dict[Any, None] = {}
    offset = None

    while len(unique_values) < limit:
        records, offset = client.scroll(
            collection_name=collection_name,
            limit=100,
            offset=offset,
            with_payload=[field_name],
            with_vectors=False,
        )

        for record in records:
            if record.payload and field_name in record.payload:
                value = record.payload[field_name]
                # A list-valued field contributes each of its elements.
                values = value if isinstance(value, list) else [value]
                unique_values.update(dict.fromkeys(values))

        # An empty page or a missing next-page offset both mean the end.
        if not records or offset is None:
            break

    return list(unique_values)[:limit]




def search_reference_term_all_models_FAST(
    client: QdrantClient,
    reference_term: str,
    target_field: str,
    model_cache: Dict[str, "SentenceTransformer"],
    limit_per_model: int = 50,
    use_knee_cutoff: bool = True,
    db_whitelist: Optional[List[str]] = None,
    hnsw_ef: int = 64,
) -> pd.DataFrame:
    """
    Search ``reference_term`` in every model's collection, one query per db.

    For each model the term is encoded once, then searched per database with a
    filter on ``model``/``db``/``field``. With ``use_knee_cutoff`` the score
    threshold of each result list is the score at the knee found by ``kneed``;
    when no knee is found, ``kneed`` is missing, or the cutoff is disabled, the
    threshold is ``0.0``.

    Args:
        client: Qdrant client.
        reference_term: Text to search for.
        target_field: Value the ``field`` payload key must match.
        model_cache: Mapping of model name to a loaded encoder.
        limit_per_model: Maximum hits requested per (model, db) query.
        use_knee_cutoff: Apply the knee-based score threshold.
        db_whitelist: Databases to query; when ``None`` the distinct ``db``
            values of the collection are used.
        hnsw_ef: Search-time HNSW ``ef``.

    Returns:
        One row per retained hit (payload keys become extra columns), sorted by
        descending score; an empty DataFrame when nothing matched.
    """
    # Resolve the optional dependency once instead of inside the hit loop, and
    # say so when it is missing rather than silently disabling the cutoff.
    knee_locator_cls = None
    if use_knee_cutoff:
        try:
            from kneed import KneeLocator as knee_locator_cls
        except ImportError:
            LOG.warning("kneed is not installed; knee cutoff disabled (threshold=0.0)")

    rows: List[Dict[str, Any]] = []

    for model_name, model in model_cache.items():
        coll = model_to_collection(model_name)
        if not client.collection_exists(coll):
            LOG.warning(f"[{coll}] collection missing; skipping model='{model_name}'")
            continue

        # Send the query as a plain list of floats; a framework tensor is not
        # a documented query_vector type for the client.
        q_vec = np.asarray(
            model.encode(
                reference_term,
                convert_to_numpy=True,
                normalize_embeddings=True,
            ),
            dtype=np.float32,
        ).tolist()

        # Use facet API to get distinct db values if no whitelist provided
        if db_whitelist is not None:
            dbs = db_whitelist
        else:
            dbs = distinct_values(client, coll, "db")
            LOG.info(f"[{coll}] Found {len(dbs)} databases: {dbs}")

        for db_name in dbs:
            flt = qm.Filter(must=[
                qm.FieldCondition(key="model", match=qm.MatchValue(value=model_name)),
                qm.FieldCondition(key="db", match=qm.MatchValue(value=db_name)),
                qm.FieldCondition(key="field", match=qm.MatchValue(value=target_field)),
            ])

            hits = client.search(
                collection_name=coll,
                query_vector=q_vec,
                limit=limit_per_model,
                with_payload=True,
                query_filter=flt,
                search_params=qm.SearchParams(
                    hnsw_ef=hnsw_ef,
                    exact=False,
                ),
            )

            if not hits:
                continue

            scores = np.array([h.score for h in hits], dtype=np.float32)

            # NOTE(review): with the 0.0 fallback, hits with a negative score
            # are dropped even when the knee cutoff is off — confirm intended.
            threshold = 0.0
            if knee_locator_cls is not None and len(scores) > 1:
                sorted_scores = np.sort(scores)[::-1]
                try:
                    knee = knee_locator_cls(range(len(sorted_scores)), sorted_scores,
                                            curve="convex", direction="decreasing", S=0.3)
                    # kneed reports None when the curve has no knee.
                    if knee.knee is not None:
                        threshold = float(sorted_scores[knee.knee])
                except Exception as exc:
                    LOG.debug(f"[{coll}] knee detection failed for db='{db_name}': {exc!r}")
                    threshold = 0.0

            for h in hits:
                if h.score >= threshold:
                    pl = h.payload or {}
                    row = {
                        "reference_term": reference_term,
                        "model": pl.get("model", model_name),
                        "db": pl.get("db", db_name),
                        "field": pl.get("field", target_field),
                        "score": float(h.score),
                        "cutoff_used": float(threshold),
                        "text": pl.get("text", ""),
                    }
                    # Remaining payload keys become extra columns; the
                    # standard keys above are never overwritten.
                    for k, v in pl.items():
                        if k not in row:
                            row[k] = v
                    rows.append(row)

    df = pd.DataFrame(rows)
    return df.sort_values("score", ascending=False).reset_index(drop=True) if not df.empty else df


In [None]:
from sentence_transformers import SentenceTransformer

# Setup
# Default connection settings; index_text_field=False is also the default.
cfg = QdrantConfig(index_text_field=False)  # Fast version
client = get_client(cfg)
# Models to query; commented-out entries are skipped.
model_names = [
    # "FremyCompany/BioLORD-2023-S",
    "malteos/scincl",
    "pritamdeka/S-PubMedBERT-MS-MARCO",
    # "cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
    "nuvocare/WikiMedical_sent_biobert",
]
# One loaded encoder per model name, as expected by
# search_reference_term_all_models_FAST.
model_cache = {name: SentenceTransformer(name) for name in model_names}



In [None]:
# list(set(df["text"]))

In [None]:
# FAST search with optimized HNSW parameters
df = search_reference_term_all_models_FAST(
    client=client,
    reference_term="cancer",
    target_field="disease_name",
    model_cache=model_cache,
    limit_per_model=200,
    use_knee_cutoff=True,
    db_whitelist=["hcdt", "ttd", "ctd"],
    # db_whitelist=["ttd"],
    hnsw_ef=512,  # ? Tune: 32=fastest, 64=balanced, 128=accurate
)

# print(df.head(20))


In [None]:
df

In [None]:
df

In [None]:
df.columns

In [None]:
len(set(df["text"]))

In [None]:
len(set(df["text"]))

In [None]:

# # =========================
# # Usage Example
# # =========================

# # if __name__ == "__main__":


# query = "Tuberculosis treatment"
# model_name = model_names[0]
# query_embedding = model_cache[model_name].encode(query)

# results = search_fast(
#     client=client,
#     collection_name=model_to_collection(model_name),
#     query_vector=query_embedding,
#     limit=200,
#     hnsw_ef=64,  # ? Adjust for speed/accuracy tradeoff
#     filter_conditions=qm.Filter(
#         must=[
#             qm.FieldCondition(
#                 key="db",
#                 match=qm.MatchAny(any=["hcdt", "ttd", "ctd"])
#             )
#         ]
#     )
# )

# print(f"Found {len(results)} results")


In [None]:
# NOTE(review): the only visible assignment to `results` is in the
# commented-out usage example, so this cell presumably raises NameError unless
# that example is restored — confirm and remove or re-enable.
results

In [None]:
# from qdrant_client import QdrantClient

# client = QdrantClient(host="localhost", port=6333)
# collection_name = "emb_malteos_scincl"
# client.recreate_collection(
#     collection_name=collection_name,
#     vector_size=768,  # use your model's embedding dimension
#     distance="Cosine",
# )

# for record in records:  # your data loader here
#     client.upsert(
#         collection_name=collection_name,
#         points=[{
#             "id": record["id"],
#             "vector": record["embedding"],
#             "payload": {
#                 "text": record["text"],  # THIS FIELD IS NEEDED!
#                 # other metadata
#             }
#         }]
#     )
