In [1]:
from pathlib import Path
import os

# === Paths ===
# Everything is resolved relative to `target_dir`.
target_dir = "./"
DATA_PATH = Path(f"{target_dir}bioasq.json")   # point this at your actual BioASQ JSON file
DB_DIR = Path(f"{target_dir}chroma_bioasq")    # directory Chroma persists into
DB_DIR.mkdir(parents=True, exist_ok=True)      # created eagerly; no-op when it already exists

# === Collection ===
COLLECTION_NAME = "bioasq_q2q"

# === Embedding model ===
# Pick exactly one; the EMBED_MODEL_NAME environment variable overrides the default.
#   "BAAI/bge-small-en-v1.5"  -> fast, 384-d (default)
#   "intfloat/e5-base-v2"     -> requires 'query:'/'passage:' prefixes
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "BAAI/bge-small-en-v1.5")

# === Index behavior ===
# True drops the collection and rebuilds it from DATA_PATH.
REBUILD_INDEX = False

# === Optional LLM refinement ===
# Option A: local Ollama (install the app, then `ollama pull mistral` or `ollama pull llama3:8b`).
# The OLLAMA_MODEL environment variable overrides the model, e.g. "mistral" or "llama3:8b".
USE_OLLAMA = True
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "gemma3:1b")

# Option B: OpenRouter (free-tier friendly). Create an API key and export it:
#   export OPENROUTER_API_KEY="..."
USE_OPENROUTER = False
OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "mistralai/mistral-7b-instruct:free")

# Intended precedence when both flags are True: Ollama first, then OpenRouter.
# NOTE(review): no OpenRouter call is visible in this part of the notebook — confirm it is wired up.

In [2]:
import json
from typing import List, Dict, Tuple

def load_bioasq(path: Path) -> List[Dict]:
    """
    Read a BioASQ JSON file and return its question records.

    The file is decoded as UTF-8 and must contain an object of the form
    {"questions": [ {...}, ... ]}; the value under "questions" is returned as-is.
    A file without that key raises KeyError.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        payload = json.loads(handle.read())
    return payload["questions"]

def _clean_strings(items) -> List[str]:
    """Strip every string in `items`; drop non-strings and entries that are blank after stripping."""
    cleaned = []
    for item in items:
        if isinstance(item, str):
            text = item.strip()
            if text:
                cleaned.append(text)
    return cleaned


def _ideal_answer_text(value) -> str:
    """Turn an 'ideal_answer' value (list of strings or a single string) into one space-joined string."""
    if isinstance(value, list):
        return " ".join(_clean_strings(value))
    if isinstance(value, str):
        return value.strip()
    return ""


def _exact_answer_text(value) -> str:
    """Turn an 'exact_answer' value (string, list, or list of lists) into one '; '-joined string."""
    if isinstance(value, list):
        flat = []
        for item in value:
            if isinstance(item, list):
                # One level of nesting is flattened; deeper nesting is dropped by the string filter.
                flat.extend(item)
            elif isinstance(item, str):
                flat.append(item)
        return "; ".join(_clean_strings(flat))
    if isinstance(value, str):
        return value.strip()
    return ""


def to_qas(records: List[Dict]) -> List[Tuple[str, str, str]]:
    """
    Return a list of (id, question, answer) tuples, one per input record.

    - id: the record's "id", or "bioasq-<index>" when it is missing or empty.
    - question: the stripped "body" (empty string when missing).
    - answer: the joined 'ideal_answer' when it yields text, otherwise the joined
      'exact_answer', otherwise "".

    Entries that are blank after stripping are skipped, so the joined answer
    never contains doubled separators such as "a; ; b".
    Records are never filtered here; callers decide what to do with empty fields.
    """
    out = []
    for i, r in enumerate(records):
        qid = r.get("id") or f"bioasq-{i}"
        question = (r.get("body") or "").strip()
        ans = _ideal_answer_text(r.get("ideal_answer"))
        if not ans:
            ans = _exact_answer_text(r.get("exact_answer"))
        out.append((qid, question, ans))
    return out


In [3]:
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import Sequence

class Embedder:
    """
    Handles model loading and the query/document prefix rules.
    - E5 models: 'query: ' / 'passage: '  (required)  [Microsoft/Zilliz & SBERT docs]
    - BGE models: recommended query instruction prompt (docs often recommend adding it)  [SBERT/Haystack]
    """
    def __init__(self, model_name: str, device: str | None = None):
        self.model_name = model_name
        self.model = SentenceTransformer(model_name, device=device)
        self.model.max_seq_length = 512  # safe default

        name = model_name.lower()
        self.is_e5  = "e5"  in name
        self.is_bge = "bge" in name

        # Recommended query instruction for BGE (English)
        self.bge_query_instruction = "Represent this sentence for searching relevant passages: "

    def _prep_docs(self, texts: Sequence[str]) -> list[str]:
        if self.is_e5:
            return [f"passage: {t}" for t in texts]
        # BGE doesn’t require a doc prefix; empirically it’s fine as-is.
        return list(texts)

    def _prep_query(self, text: str) -> str:
        if self.is_e5:
            return f"query: {text}"
        if self.is_bge:
            return f"{self.bge_query_instruction}{text}"
        return text

    def embed_docs(self, texts: Sequence[str], batch_size: int = 64) -> np.ndarray:
        prepped = self._prep_docs(texts)
        return self.model.encode(
            prepped, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
        )

    def embed_query(self, text: str) -> np.ndarray:
        prepped = self._prep_query(text)
        return self.model.encode([prepped], normalize_embeddings=True, convert_to_numpy=True)[0]


  from tqdm.autonotebook import tqdm, trange


In [4]:
import chromadb
from chromadb.utils import embedding_functions  # (not used, but handy for alternatives)
from tqdm import tqdm
import uuid
from typing import Optional
from chromadb.api.models import Collection 

def get_client(path: Path):
    """Open a Chroma persistent client whose storage path is `path` (passed as a string)."""
    storage_path = str(path)
    return chromadb.PersistentClient(path=storage_path)

def get_or_create_collection(client, name: str):
    """
    Return the collection called `name`, creating it when the lookup fails.

    Any exception from `client.get_collection` is treated as "not found". A newly
    created collection is configured for cosine distance, which suits normalized
    embeddings; an existing collection is returned with whatever settings it has.
    """
    try:
        found = client.get_collection(name=name)
    except Exception:
        cosine_settings = {"hnsw:space": "cosine"}
        return client.create_collection(name=name, metadata=cosine_settings)
    else:
        return found

def drop_collection_if_exists(client, name: str):
    """Best-effort delete of collection `name`; any failure is ignored. Returns None."""
    try:
        client.delete_collection(name=name)
    except Exception:
        # Deliberately swallowed: every error from the delete call is treated as "nothing to drop".
        return

def build_index(
    data_path: Path,
    db_dir: Path,
    collection_name: str = COLLECTION_NAME,
    model_name: str = EMBED_MODEL_NAME,
    rebuild: bool = REBUILD_INDEX,
    batch_size: int = 256,
) -> tuple[Collection, Embedder]:
    """
    Build (or reuse) the question-to-question Chroma index.

    Args:
        data_path: BioASQ JSON file, read via `load_bioasq`.
        db_dir: directory passed to `get_client` for the persistent Chroma client.
        collection_name: name of the collection to open or create.
        model_name: embedding model name handed to `Embedder`.
        rebuild: when True, the collection is dropped first and always repopulated.
        batch_size: number of items sent to `collection.add` per call.

    Returns:
        (collection, embedder). The embedder is always constructed, even when the
        existing collection is reused, because callers need it to embed queries.

    Only records with both a non-empty question and a non-empty answer are indexed.
    The question text is what gets embedded and stored as the document; the answer
    travels along as metadata under the "answer" key.
    """
    client = get_client(db_dir)

    if rebuild:
        drop_collection_if_exists(client, collection_name)

    collection = get_or_create_collection(client, collection_name)
    embedder = Embedder(model_name)

    # If already populated and not rebuilding, skip. A failing count() is treated as "empty".
    try:
        existing_count = collection.count()
    except Exception:
        existing_count = 0

    if existing_count and not rebuild:
        print(f"[Chroma] Reusing collection '{collection_name}' with {existing_count} items.")
        return collection, embedder

    # Load data, keeping only complete question/answer pairs
    records = load_bioasq(data_path)
    qas = [(qid, q, a) for (qid, q, a) in to_qas(records) if q and a]
    print(f"Loaded {len(qas)} Q/A pairs from {data_path.name}")

    # Column-wise views of the pairs
    qids      = [qid for qid, _, _ in qas]
    questions = [q for _, q, _ in qas]
    answers   = [a for _, _, a in qas]

    # Compute embeddings for all questions up front
    all_embs = embedder.embed_docs(questions, batch_size=64)

    # Add to collection in chunks; list slicing clamps the final, shorter chunk by itself
    for start in tqdm(range(0, len(qas), batch_size), desc="Adding to Chroma"):
        sl = slice(start, start + batch_size)
        collection.add(
            ids=[qid or str(uuid.uuid4()) for qid in qids[sl]],  # random id as a fallback for an empty one
            documents=questions[sl],
            metadatas=[{"answer": answer} for answer in answers[sl]],
            embeddings=all_embs[sl].tolist(),  # pass precomputed embeddings
        )

    print(f"[Chroma] Built collection '{collection_name}' with {collection.count()} items.")
    return collection, embedder

# Build (or reuse) the index once; retrieve_answer reads these two module-level names.
collection, embedder = build_index(data_path=DATA_PATH, db_dir=DB_DIR)


Failed to send telemetry event ClientStartEvent: capture() takes 1 positional argument but 3 were given
Failed to send telemetry event ClientCreateCollectionEvent: capture() takes 1 positional argument but 3 were given


Loaded 5049 Q/A pairs from bioasq.json


Batches: 100%|█████████████████████████████████████████████████████████████████████████████| 79/79 [02:05<00:00,  1.59s/it]
Adding to Chroma:   0%|                                                                             | 0/20 [00:00<?, ?it/s]Failed to send telemetry event CollectionAddEvent: capture() takes 1 positional argument but 3 were given
Adding to Chroma: 100%|████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00,  1.75it/s]

[Chroma] Built collection 'bioasq_q2q' with 5049 items.





In [5]:
from typing import Any

def retrieve_answer(
    user_question: str,
    k: int = 5,
    return_matches: bool = True,
) -> dict[str, Any]:
    """
    Find the stored questions nearest to `user_question` and return the best answer.

    Uses the module-level `embedder` and `collection`.

    Args:
        user_question: free-text question to embed and search with.
        k: number of nearest neighbours requested from the collection.
        return_matches: when True, include every hit under "matches".

    Returns:
        With at least one hit:
            {"answer": <best answer>, "match": <best hit>, "matches": [<hit>, ...]}
        where "matches" is present only if `return_matches` is True and each hit is
        {"id", "matched_question", "retrieved_answer", "distance"}.
        With no hits: {"answer": None, "matches": []} (no "match" key).
    """
    q_emb = embedder.embed_query(user_question)

    # Stored embeddings are never read below, so they are not requested.
    results = collection.query(
        query_embeddings=[q_emb.tolist()],
        n_results=k,
        include=["documents", "metadatas", "distances"],
    )

    def _first_row(key: str) -> list:
        # Results are columnar with one row per query; a missing or None column counts as empty.
        column = results.get(key)
        if column is None:
            return []
        return column[0]

    ids = _first_row("ids")
    docs = _first_row("documents")
    metas = _first_row("metadatas")
    dists = _first_row("distances")

    if not ids:
        return {"answer": None, "matches": []}

    matches = [
        {
            "id": ids[i],
            "matched_question": docs[i],
            # A hit stored without metadata yields an empty answer instead of raising.
            "retrieved_answer": (metas[i] or {}).get("answer", ""),
            "distance": dists[i],
        }
        for i in range(len(ids))
    ]

    # Copy so that "match" and "matches"[0] stay independent dicts.
    best = dict(matches[0])
    payload = {"answer": best["retrieved_answer"], "match": best}
    if return_matches:
        payload["matches"] = matches
    return payload


In [6]:
def refine_with_ollama(question: str, raw_answer: str, matched_question: str, model: str = OLLAMA_MODEL) -> str | None:
    """
    Ask a local Ollama model to rephrase `raw_answer` as a short, direct reply.

    Returns the stripped model reply, or None when USE_OLLAMA is off or when
    anything in the attempt fails (including a missing `ollama` package); a
    failure prints a one-line notice instead of raising.
    """
    if USE_OLLAMA:
        try:
            # Imported here so that a missing package is handled like any other failure.
            import ollama

            instructions = (
                "You are a biomedical assistant. Using the retrieved answer, produce a concise, direct answer.\n"
                "If the retrieved answer already looks complete, lightly paraphrase for clarity without adding facts.\n\n"
                f"User question: {question}\n"
                f"Retrieved (matched) question: {matched_question}\n"
                f"Retrieved answer: {raw_answer}\n\n"
                "Final answer (one short paragraph):"
            )
            conversation = [{"role": "user", "content": instructions}]
            reply = ollama.chat(model=model, messages=conversation)
            return reply["message"]["content"].strip()
        except Exception as exc:
            print(f"[Ollama] Skipping refinement ({exc})")
    return None


In [7]:
def answer_question(user_question: str, k: int = 5, refine: bool = True) -> dict:
    """
    Retrieve the best stored answer for `user_question` and optionally refine it.

    Refinement goes through `refine_with_ollama` and only runs when `refine` is
    True and something was retrieved; an empty or None refinement keeps the raw
    answer as the final one.

    Returns a dict with the keys "question", "final_answer", "retrieved_answer",
    "matched_question" ("" when there was no match) and "matches".
    """
    retrieval = retrieve_answer(user_question, k=k, return_matches=True)
    raw = retrieval["answer"]

    top_hit = retrieval.get("match")
    match_q = top_hit["matched_question"] if top_hit else ""

    final_answer = raw
    if refine and raw:
        # Fall back to the raw text when the refiner returns nothing usable.
        final_answer = refine_with_ollama(user_question, raw, match_q) or raw

    return {
        "question": user_question,
        "final_answer": final_answer,
        "retrieved_answer": raw,
        "matched_question": match_q,
        "matches": retrieval.get("matches", []),
    }


In [8]:
# Smoke test: run one sample question end to end (swap in a question from your own dataset).
sample_q = "Is Hirschsprung disease Mendelian or multifactorial?"
result = answer_question(sample_q, k=5, refine=True)

print("Q:", result["question"])
# The three text fields share one layout: a blank line, then a label and the value.
for label, field in (
    ("\nMatched question:", "matched_question"),
    ("\nAnswer (final):", "final_answer"),
    ("\nAnswer (raw retrieved):", "retrieved_answer"),
):
    print(label, result[field])

print("\nTop matches (id, distance):")
for m in result["matches"][:5]:
    print(" -", m["id"], f"(distance={m['distance']:.4f})")


Failed to send telemetry event CollectionQueryEvent: capture() takes 1 positional argument but 3 were given


Q: Is Hirschsprung disease Mendelian or multifactorial?

Matched question: Is Hirschsprung disease a mendelian or a multifactorial disorder?

Answer (final): Hirschsprung disease is considered a complex disorder, with a mix of genetic and environmental factors potentially contributing. While Mendelian forms are prevalent, sporadic cases are often linked to multiplicative inheritance, involving multiple loci.

Answer (raw retrieved): Coding sequence mutations in RET, GDNF, EDNRB, EDN3, and SOX10 are involved in the development of Hirschsprung disease. The majority of these genes was shown to be related to Mendelian syndromic forms of Hirschsprung's disease, whereas the non-Mendelian inheritance of sporadic non-syndromic Hirschsprung disease proved to be complex; involvement of multiple loci was demonstrated in a multiplicative model.

Top matches (id, distance):
 - 55031181e9bde69634000014 (distance=0.0341)
 - 5503121de9bde69634000019 (distance=0.2310)
 - 55391825bc4f83e828000016 (dista