# imports & helpers

In [None]:
import os, re, json, unicodedata, pathlib, math
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import urlparse, urlunparse
import pandas as pd
import numpy as np
from tqdm import tqdm
from slugify import slugify
import time, math

from sentence_transformers import SentenceTransformer
from neo4j import GraphDatabase, basic_auth

In [None]:
# ---------- Text cleanup (Thai + general) ----------
# Runs of horizontal "space-like" characters that should collapse into one plain space.
_ws_re = re.compile(r"[ \t\u00A0\u200B\u200C\u200D\u2060]+")  # spaces + NBSP + ZW* + word joiner
# A line break together with all whitespace (including further line breaks) around it.
_nl_re = re.compile(r"\s*\n\s*")


def normalize_text(x: Optional[str]) -> str:
    """Return a cleaned, NFC-normalized version of *x*.

    Args:
        x: Any value; non-strings are converted with ``str()``. ``None`` and
            float NaN (how pandas represents a missing cell) count as "no text".

    Returns:
        The cleaned string, or ``""`` for missing input.
    """
    # Without the NaN check, str(float("nan")) would leak the literal "nan"
    # into titles / ids for rows whose source value is missing.
    if x is None or (isinstance(x, float) and math.isnan(x)):
        return ""
    x = unicodedata.normalize("NFC", str(x))
    x = x.replace("\r", "\n")
    # Collapse every line break (and the whitespace around it) to one "\n", then trim the ends
    x = _nl_re.sub("\n", x).strip()
    x = _ws_re.sub(" ", x)
    # Normalize bullet markers ("•" or "-") at the start of a line to "• "
    x = re.sub(r"\n[•\-\u2022]\s*", "\n• ", x)
    # Safeguard only: the collapse above already leaves no runs of 3+ newlines
    x = re.sub(r"\n{3,}", "\n\n", x)
    return x

def coalesce(*args, default=""):
    """Return ``str()`` of the first argument that is not None and not blank.

    An argument is blank when its string form is empty or whitespace-only.
    The returned string is not stripped. *default* is returned when no
    argument qualifies.
    """
    for candidate in args:
        if candidate is None:
            continue
        text = str(candidate)
        if text.strip():
            return text
    return default

def ensure_list(x):
    """Return *x* unchanged if it is a list, otherwise wrap it in a one-element list."""
    if isinstance(x, list):
        return x
    return [x]

# ---------- JSON loader (array-of-objects) ----------
def load_json(path_json: str) -> pd.DataFrame:
    """Load a JSON file of records into a DataFrame.

    The file is expected to hold a single JSON array of objects. If parsing it
    that way fails, it is retried as JSON Lines (one object per line).

    Args:
        path_json: Path of the JSON file.

    Returns:
        A DataFrame with one row per object.

    Raises:
        FileNotFoundError: If the path does not exist.
        ValueError: If the file parses neither as a JSON array nor as JSON
            Lines; the error from the JSON-array attempt is the one raised.
    """
    p = pathlib.Path(path_json)
    if not p.exists():
        raise FileNotFoundError(f"{p} not found.")
    try:
        # Expect a single JSON array of objects
        return pd.read_json(p, orient="records")
    except ValueError as array_err:
        # The file may have been saved as JSON Lines instead; try that before giving up.
        try:
            return pd.read_json(p, orient="records", lines=True)
        except ValueError:
            # Report the failure of the primary (array) format, not of the fallback.
            raise array_err from None

# ---------- Neo4j credentials loader ----------
def load_neo4j_creds(from_txt: str) -> Dict[str, str]:
    """Read Neo4j connection settings from a ``KEY=VALUE`` text file.

    Blank lines, lines starting with ``#`` and lines without ``=`` are ignored.
    Only the first ``=`` splits a line, so values may themselves contain ``=``.

    Args:
        from_txt: Path of the credentials file.

    Returns:
        A dict with the keys ``uri``, ``user``, ``pwd`` and ``db``.

    Raises:
        ValueError: If any of NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD or
            NEO4J_DATABASE is absent from the file.
    """
    d = {}
    # "utf-8-sig" drops a leading byte-order mark if present; with plain "utf-8"
    # the BOM would become part of the first key and trigger a false "missing key".
    with open(from_txt, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            key, sep, value = line.partition("=")
            if sep:
                d[key.strip()] = value.strip()
    req = ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD", "NEO4J_DATABASE"]
    missing = [k for k in req if k not in d]
    if missing:
        raise ValueError(f"Missing keys in creds file: {missing}")
    return {
        "uri": d["NEO4J_URI"],
        "user": d["NEO4J_USERNAME"],
        "pwd": d["NEO4J_PASSWORD"],
        "db": d["NEO4J_DATABASE"],
    }

# ---------- Neo4j driver ----------
class Neo4jClient:
    """Thin wrapper around a Neo4j driver that always targets one database."""

    def __init__(self, uri: str, user: str, pwd: str, database: str):
        """Create the driver for *uri* with basic auth and remember *database*."""
        credentials = basic_auth(user, pwd)
        self.driver = GraphDatabase.driver(uri, auth=credentials)
        self.database = database

    def close(self):
        """Close the underlying driver."""
        self.driver.close()

    def run(self, cypher: str, **params):
        """Run *cypher* with *params* in a fresh session and return the records as a list."""
        with self.driver.session(database=self.database) as session:
            # Materialize the result while the session is still open
            return [record for record in session.run(cypher, params)]

def pcount(label: str, client: Neo4jClient):
    """Return the number of nodes carrying *label*, or 0 if the query yields no record.

    *label* is interpolated directly into the Cypher text, so it must be a
    trusted label name.
    """
    records = client.run(f"MATCH (n:{label}) RETURN count(n) AS c")
    if not records:
        return 0
    return records[0]["c"]


def canonicalize_url(u: str) -> str:
    """Return a canonical form of *u* used as the node key.

    Surrounding whitespace is stripped, scheme and host are lowercased,
    params/query/fragment are dropped, and trailing slashes are removed.
    The path keeps its original case.
    """
    parts = urlparse(u.strip())
    kept = (
        parts.scheme.lower(),
        parts.netloc.lower(),
        parts.path,
        "",  # params
        "",  # query
        "",  # fragment
    )
    return urlunparse(kept).rstrip("/")


def embed_passages(texts, batch_size=64):
    """Encode *texts* with the global ``model`` and return a float32 array.

    Every text is given a ``"passage: "`` prefix unless its string form
    already starts with ``passage:`` (case-insensitive). Encoding is requested
    with ``normalize_embeddings=True`` and without a progress bar.
    """
    prefixed = []
    for t in texts:
        as_text = str(t)
        if as_text.lower().startswith("passage:"):
            prefixed.append(as_text)
        else:
            prefixed.append(f"passage: {t}")
    vectors = model.encode(
        prefixed,
        batch_size=batch_size,
        show_progress_bar=False,
        normalize_embeddings=True,
    )
    return np.asarray(vectors, dtype=np.float32)


def chunker(seq, n):
    """Yield consecutive slices of *seq*, each at most *n* items long."""
    yield from (seq[start:start + n] for start in range(0, len(seq), n))

# load dataframes & normalize

In [None]:
# Input files (JSON only); each is read with load_json
threads_json = "agnos_forum_threads.json"
diseases_json = "agnos_diseases.json"

df_threads = load_json(path_json=threads_json)
df_diseases = load_json(path_json=diseases_json)

# --- Light cleanup / column harmonization ---

# Unified thread schema:
#   url, thread_id, title, thread_category, answer_by_doctor, title_category
# For each unified field, the source column names it may appear under,
# listed in priority order (the first one present in the data is used).
colmap_candidates = dict(
    title=["title", "question", "thread_title", "big_title"],
    thread_category=["thread_category", "category", "chip_category"],
    answer_by_doctor=["answer_by_doctor", "doctor_answer", "answer", "answer_text"],
    thread_id=["thread_id", "slug", "id"],
    url=["url", "link", "page_url"],
)
def pick_col(df, keys):
    """Return the first name in *keys* that is a column of *df*, or None if none is."""
    return next((name for name in keys if name in df.columns), None)

# Resolve, for every unified thread field, which source column (if any) provides it
t_title_col, t_cat_col, t_ans_col, t_id_col, t_url_col = (
    pick_col(df_threads, colmap_candidates[field])
    for field in ("title", "thread_category", "answer_by_doctor", "thread_id", "url")
)

# Title, category and url are mandatory; answer and id are optional
if None in (t_title_col, t_cat_col, t_url_col):
    raise RuntimeError("Threads JSON missing required columns (title/category/url).")

def derive_thread_id(u: str) -> str:
    """Derive a thread id from the last path segment of URL *u*.

    The first run of digits in that segment is returned when there is one;
    otherwise the slugified segment. If anything fails along the way, the
    slugified whole of *u* is returned instead.
    """
    try:
        last_segment = u.strip("/").rsplit("/", 1)[-1]
        digits = re.search(r"(\d+)", last_segment)
        if digits:
            return digits.group(1)
        return slugify(last_segment, lowercase=True)
    except Exception:
        return slugify(u, lowercase=True)

def _thread_field(record, column):
    """Normalized text of ``record[column]``; empty string when no column was resolved."""
    return normalize_text(record.get(column, "")) if column else ""


# Build one unified record per source row
threads = []
for _, r in df_threads.iterrows():
    url = canonicalize_url(_thread_field(r, t_url_col))
    title = _thread_field(r, t_title_col)
    thread_category = _thread_field(r, t_cat_col)
    answer_by_doctor = _thread_field(r, t_ans_col)
    # Prefer the id supplied by the data; fall back to one derived from the URL
    thread_id = _thread_field(r, t_id_col) or derive_thread_id(url)
    # Text that gets embedded later: title followed by category
    title_category = " ".join((title, thread_category)).strip()
    threads.append(dict(
        url=url,
        thread_id=thread_id,
        title=title,
        thread_category=thread_category,
        answer_by_doctor=answer_by_doctor if answer_by_doctor else None,
        title_category=title_category,
    ))
df_threads_u = pd.DataFrame(threads)

# DISEASES unified schema:
# url, slug, thai_name, english_name, title, info_and_causes, symptoms, diagnosis, treatment,
# specialist_doctor_recommended, precautions, additional_info, symptom_disease
def find_first(df, candidates):
    """Return the earliest entry of *candidates* that names a column of *df*, else None."""
    present = [name for name in candidates if name in df.columns]
    return present[0] if present else None

def _disease_col(*candidates):
    """First of *candidates* that is a column of ``df_diseases``, or None."""
    return find_first(df_diseases, list(candidates))


# Source column for every unified disease field (None when the data has none of the names).
# Thai section headings are accepted alongside their English equivalents.
d_url = _disease_col("url", "page_url", "link")
d_title = _disease_col("title", "thai_title", "thai_name")
d_en = _disease_col("english_name", "en_name", "en_title")
d_th = _disease_col("thai_name", "thai_title", "title")
d_info = _disease_col("ข้อมูลและสาเหตุของโรค", "info_and_causes", "causes", "ข้อมูลสาเหตุ")
d_symp = _disease_col("อาการ", "symptoms")
d_dx = _disease_col("แนวทางการวินิจฉัย", "diagnosis")
d_tx = _disease_col("แนวทางการรักษา", "treatment")
d_adv = _disease_col("คำแนะนำจากผู้เชี่ยวชาญ", "specialist_doctor_recommended", "advice")
d_warn = _disease_col("ข้อควรระวัง", "precautions", "warning")
d_more = _disease_col("ข้อมูลเพิ่มเติม", "additional_info", "more_info")

# Only the URL is mandatory: it is the key diseases are merged on
if d_url is None:
    raise RuntimeError("Diseases JSON missing required column: url")

def _disease_field(record, column):
    """Normalized text of ``record[column]``; empty string when no column was resolved."""
    return normalize_text(record.get(column, "")) if column else ""


# Build one unified record per source row
diseases = []
for _, r in df_diseases.iterrows():
    url = canonicalize_url(_disease_field(r, d_url))
    # Slug comes from the last path segment of the canonical URL
    slug = slugify(url.strip("/").rsplit("/", 1)[-1], lowercase=True)
    thai_name = _disease_field(r, d_th)
    english_name = _disease_field(r, d_en)
    if d_title:
        title = _disease_field(r, d_title)
    else:
        # No title column at all: fall back to whichever name is available
        title = thai_name or english_name
    info_and_causes = _disease_field(r, d_info)
    symptoms = _disease_field(r, d_symp)
    diagnosis = _disease_field(r, d_dx)
    treatment = _disease_field(r, d_tx)
    specialist = _disease_field(r, d_adv)
    precautions = _disease_field(r, d_warn)
    additional_info = _disease_field(r, d_more)
    # Text that gets embedded later: causes followed by symptoms
    symptom_disease = " ".join((info_and_causes, symptoms)).strip()
    diseases.append(dict(
        url=url,
        slug=slug,
        thai_name=thai_name or title,
        english_name=english_name or "",
        title=title,
        info_and_causes=info_and_causes,
        symptoms=symptoms,
        diagnosis=diagnosis,
        treatment=treatment,
        specialist_doctor_recommended=specialist,
        precautions=precautions,
        additional_info=additional_info,
        symptom_disease=symptom_disease,
    ))
df_diseases_u = pd.DataFrame(diseases)

# Quick overview of what was loaded; the last expression is the cell's displayed value
print("Threads:", df_threads_u.shape, "Diseases:", df_diseases_u.shape)
df_threads_u.head(2), df_diseases_u.head(2)


# Dev mode for faster iteration

In [None]:
# ===== DEV MODE: fast subset without renaming downstream vars =====
DEV_MODE   = False          # True: work on a small sample; False: run on the full data
N_THREADS  = 20           # max threads kept when DEV_MODE is on
N_DISEASES = 10           # max diseases kept when DEV_MODE is on

def _subset(df, text_col, n):
    """Return at most *n* rows of *df* suitable for a quick development run.

    Rows whose *text_col* is shorter than 6 characters (as a string) are
    dropped, then duplicates by ``url`` (if that column exists). If more than
    *n* rows remain, *n* of them are sampled with a fixed seed. The result
    has a fresh 0..k-1 index; *df* itself is not modified.
    """
    has_text = df[text_col].astype(str).str.len() >= 6
    picked = df.loc[has_text].copy()
    if "url" in picked.columns:
        picked = picked.drop_duplicates(subset=["url"])
    if len(picked) > n:
        # Fixed seed keeps the dev sample reproducible between runs
        picked = picked.sample(n=n, random_state=42)
    return picked.reset_index(drop=True)

if not DEV_MODE:
    print("[DEV_MODE] OFF — using full datasets.")
else:
    # Remember the full sizes so the reduction can be reported
    _orig_t = len(df_threads_u)
    _orig_d = len(df_diseases_u)
    # Rebind the same names so every later cell works on the subset unchanged
    df_threads_u = _subset(df_threads_u, "title_category", N_THREADS)
    df_diseases_u = _subset(df_diseases_u, "symptom_disease", N_DISEASES)
    print(f"[DEV_MODE] Threads {len(df_threads_u)}/{_orig_t} | Diseases {len(df_diseases_u)}/{_orig_d}")


# connect to Neo4j (reads creds from the provided TXT)

In [None]:
# Credentials file (KEY=VALUE lines) read by load_neo4j_creds
CREDS_TXT = "Neo4j-92e8f832-Created-2025-08-28.txt"
creds = load_neo4j_creds(CREDS_TXT)

# Shared client used by every cell below
client = Neo4jClient(
    uri=creds["uri"],
    user=creds["user"],
    pwd=creds["pwd"],
    database=creds["db"],
)

# NOTE(review): no query is run in this cell, so this message only reports the
# target the driver was created for — confirm connectivity with a later cell.
print("Connected to:", creds["uri"], " DB:", creds["db"])

# embedding model (BAAI/bge-m3) & dimension

In [None]:
# Embedding model shared by all embed_* helpers. Embeddings are requested
# L2-normalized at encode time, which suits cosine similarity.
model_name = "BAAI/bge-m3"
model = SentenceTransformer(model_name)

# The dimension is read from the model and reused when creating the vector indexes
emb_dim = model.get_sentence_embedding_dimension()
print(f"Model: {model_name} Embedding dim: {emb_dim}")


# upsert Threads with embeddings (batched)

In [None]:
# One throwaway encode so the first real batch does not pay the start-up cost
_ = model.encode(["query: warmup"], normalize_embeddings=True)

# Thread nodes are keyed by URL; creating the constraint is idempotent
client.run("CREATE CONSTRAINT thread_url_unique IF NOT EXISTS FOR (t:Thread) REQUIRE t.url IS UNIQUE")

# Parameter rows for the upsert. URLs were canonicalized at load time already;
# doing it again here is a cheap safeguard.
rows = []
for rec in df_threads_u.to_dict("records"):
    rows.append(dict(
        url=canonicalize_url(rec["url"]),
        thread_id=rec["thread_id"],
        title=rec["title"],
        thread_category=rec["thread_category"],
        answer_by_doctor=rec.get("answer_by_doctor"),
        title_category=rec["title_category"],
    ))

# Find existing Thread nodes that already have embeddings, along with the text
# currently stored next to each embedding, so edited threads can be detected.
# NOTE(review): this assumes the stored title_category is the text the embedding
# was computed from — nodes refreshed by an older update-only run may violate that;
# use REEMBED_EXISTING once to bring those back in sync.
existing_rows = client.run("""
MATCH (t:Thread)
WHERE t.title_category_emb IS NOT NULL
RETURN t.url AS url, t.title_category AS title_category
""")
existing = { canonicalize_url(r["url"]): r["title_category"] for r in existing_rows }

# Toggle whether to re-embed existing rows
REEMBED_EXISTING = False  # set True to recompute embeddings for everything


def _thread_embedding_is_current(row):
    """True when the node already has an embedding for exactly this row's text."""
    return row["url"] in existing and existing[row["url"]] == row["title_category"]


if REEMBED_EXISTING:
    to_embed = rows
    to_update_only = []
else:
    # New threads AND threads whose embedded text changed must be (re-)embedded;
    # otherwise the update-only path would overwrite the text but keep a stale vector.
    to_embed = [r for r in rows if not _thread_embedding_is_current(r)]
    to_update_only = [r for r in rows if _thread_embedding_is_current(r)]

print(f"Threads to embed: {len(to_embed)} | to update-only: {len(to_update_only)}")

# --- Update-only: rewrite the scalar properties, leave the stored embedding untouched ---
if to_update_only:
    cypher_update_only = """
    UNWIND $rows AS row
    MERGE (t:Thread {url: row.url})
    SET  t.thread_id = row.thread_id,
         t.title = row.title,
         t.thread_category = row.thread_category,
         t.answer_by_doctor = row.answer_by_doctor,
         t.title_category = row.title_category
    """
    BATCH = 64
    for bi, part in enumerate(chunker(to_update_only, BATCH), 1):
        started = time.perf_counter()
        client.run(cypher_update_only, rows=part)
        elapsed = time.perf_counter() - started
        print(f"[update-only] Batch {bi} — {len(part)} rows | write {elapsed:.1f}s")

# --- Embed + upsert: write the scalar properties together with a fresh embedding ---
if to_embed:
    cypher_embed = """
    UNWIND $rows AS row
    MERGE (t:Thread {url: row.url})
    SET  t.thread_id = row.thread_id,
         t.title = row.title,
         t.thread_category = row.thread_category,
         t.answer_by_doctor = row.answer_by_doctor,
         t.title_category = row.title_category,
         t.title_category_emb = row.title_category_emb
    """
    BATCH = 64  # small for responsive logs; increase later
    total = len(to_embed)
    num_batches = math.ceil(total / BATCH)
    print(f"Embedding & upserting {total} Thread rows in {num_batches} batches (BATCH={BATCH})")
    for bi, part in enumerate(chunker(to_embed, BATCH), 1):
        embed_started = time.perf_counter()
        passages = [row["title_category"] for row in part]
        embs = embed_passages(passages, batch_size=min(64, len(part)))
        # Attach each vector to its row as a plain list so it can be sent as a parameter
        for row, vector in zip(part, embs):
            row["title_category_emb"] = vector.tolist()
        write_started = time.perf_counter()
        client.run(cypher_embed, rows=part)
        finished = time.perf_counter()
        print(f"[embed] Batch {bi}/{num_batches} — {len(part)} rows | embed {write_started-embed_started:.1f}s | write {finished-write_started:.1f}s")

print("Thread nodes now:", pcount("Thread", client))


# upsert Diseases with embeddings (batched)

In [None]:
# Disease nodes are keyed by URL; creating the constraint is idempotent
client.run("CREATE CONSTRAINT disease_url_unique IF NOT EXISTS FOR (d:Disease) REQUIRE d.url IS UNIQUE")

# Parameter rows for the upsert. URLs were canonicalized at load time already;
# doing it again here is a cheap safeguard.
rows = []
for rec in df_diseases_u.to_dict("records"):
    rows.append(dict(
        url=canonicalize_url(rec["url"]),
        slug=rec["slug"],
        thai_name=rec["thai_name"],
        english_name=rec.get("english_name", ""),
        title=rec["title"],
        info_and_causes=rec["info_and_causes"],
        symptoms=rec["symptoms"],
        diagnosis=rec["diagnosis"],
        treatment=rec["treatment"],
        specialist_doctor_recommended=rec["specialist_doctor_recommended"],
        precautions=rec["precautions"],
        additional_info=rec["additional_info"],
        symptom_disease=rec["symptom_disease"],
    ))

# Skip diseases whose embedding in the DB is still current (compare using canonical URLs).
# The embedded text is fetched too, so a disease whose symptom_disease text changed
# is re-embedded instead of being skipped with a stale vector.
existing_rows = client.run("""
MATCH (d:Disease)
WHERE d.symptom_disease_emb IS NOT NULL
RETURN d.url AS url, d.symptom_disease AS symptom_disease
""")
existing = { canonicalize_url(r["url"]): r["symptom_disease"] for r in existing_rows }


def _disease_embedding_is_current(row):
    """True when the node already has an embedding for exactly this row's text."""
    return row["url"] in existing and existing[row["url"]] == row["symptom_disease"]


before = len(rows)
rows = [r for r in rows if not _disease_embedding_is_current(r)]
skipped = before - len(rows)
print(f"Skipping {skipped} already-embedded Disease nodes; processing {len(rows)} new/changed.")

# Embed and upsert whatever is left after the skip filter
if rows:
    BATCH = 16  # small batch for responsive logs (tweak as you like)
    total = len(rows)
    num_batches = math.ceil(total / BATCH)

    cypher = """
    UNWIND $rows AS row
    MERGE (d:Disease {url: row.url})
    SET  d.slug = row.slug,
         d.thai_name = row.thai_name,
         d.english_name = row.english_name,
         d.title = row.title,
         d.info_and_causes = row.info_and_causes,
         d.symptoms = row.symptoms,
         d.diagnosis = row.diagnosis,
         d.treatment = row.treatment,
         d.specialist_doctor_recommended = row.specialist_doctor_recommended,
         d.precautions = row.precautions,
         d.additional_info = row.additional_info,
         d.symptom_disease = row.symptom_disease,
         d.symptom_disease_emb = row.symptom_disease_emb
    """

    print(f"Starting upsert of {total} Disease rows in {num_batches} batches (BATCH={BATCH})")
    for bi, part in enumerate(chunker(rows, BATCH), 1):
        embed_started = time.perf_counter()
        passages = [row["symptom_disease"] for row in part]
        embs = embed_passages(passages, batch_size=min(64, len(part)))
        # Attach each vector to its row as a plain list so it can be sent as a parameter
        for row, vector in zip(part, embs):
            row["symptom_disease_emb"] = vector.tolist()
        write_started = time.perf_counter()
        client.run(cypher, rows=part)
        finished = time.perf_counter()
        print(f"Batch {bi}/{num_batches} — {len(part)} rows | embed {write_started-embed_started:.1f}s | write {finished-write_started:.1f}s")
else:
    print("No Disease rows to upsert. (All have embeddings.)")

print("Disease nodes now:", pcount("Disease", client))


# create vector indexes (native KNN)

In [None]:
# Create native VECTOR INDEXes (Aura supports this; 5.x+).
# The dimension is passed as the $dim query parameter; both indexes use cosine similarity.
vector_dim = int(emb_dim)

# Index over Thread.title_category_emb
client.run("""
CREATE VECTOR INDEX thread_titlecat_idx IF NOT EXISTS
FOR (t:Thread) ON (t.title_category_emb)
OPTIONS {
  indexConfig: {
    `vector.dimensions`: $dim,
    `vector.similarity_function`: 'cosine'
  }
}
""", dim=vector_dim)

# Index over Disease.symptom_disease_emb
client.run("""
CREATE VECTOR INDEX disease_symptom_idx IF NOT EXISTS
FOR (d:Disease) ON (d.symptom_disease_emb)
OPTIONS {
  indexConfig: {
    `vector.dimensions`: $dim,
    `vector.similarity_function`: 'cosine'
  }
}
""", dim=vector_dim)

print("Vector indexes created (or already exist).")


# sanity checks (counts + small samples)

In [None]:
# Node counts per label
for heading, label in (("Threads:", "Thread"), ("Diseases:", "Disease")):
    print(heading, pcount(label, client))

# Fetch a few nodes of each kind before printing anything
res_t = client.run("MATCH (t:Thread) RETURN t.url AS url, t.title AS title LIMIT 3")
res_d = client.run("MATCH (d:Disease) RETURN d.url AS url, d.title AS title LIMIT 3")

for heading, records in (("\nSample Threads:", res_t), ("\nSample Diseases:", res_d)):
    print(heading)
    for r in records:
        print("-", r["title"], "->", r["url"])


# test KNN: threads (symptom/query) & diseases (name/query)

In [None]:
def embed_query(q: str) -> List[float]:
    """Embed the search text *q* with the global ``model`` and return it as a list of floats.

    ``"query: "`` is prepended unless *q* already starts with ``query:``
    (case-insensitive). The embedding is requested normalized and cast to float32.
    """
    q_prep = q if q.lower().startswith("query:") else f"query: {q}"
    encoded = model.encode([q_prep], normalize_embeddings=True)
    return encoded[0].astype(np.float32).tolist()

def knn_threads(query: str, k: int = 5):
    """Return the *k* Thread nodes nearest to *query* in the ``thread_titlecat_idx`` index.

    Each returned record has the fields ``url``, ``title``, ``category`` and ``score``.
    """
    cypher = """
    CALL db.index.vector.queryNodes('thread_titlecat_idx', $k, $v) 
    YIELD node, score
    RETURN node.url AS url, node.title AS title, node.thread_category AS category, score
    """
    query_vector = embed_query(query)
    return client.run(cypher, v=query_vector, k=int(k))

def knn_diseases(query: str, k: int = 2):
    """Return the *k* Disease nodes nearest to *query* in the ``disease_symptom_idx`` index.

    Each returned record has the fields ``url``, ``title``, ``th_name``,
    ``en_name`` and ``score``.
    """
    cypher = """
    CALL db.index.vector.queryNodes('disease_symptom_idx', $k, $v)
    YIELD node, score
    RETURN node.url AS url, node.title AS title, node.thai_name AS th_name, node.english_name AS en_name, score
    """
    query_vector = embed_query(query)
    return client.run(cypher, v=query_vector, k=int(k))

In [None]:
# Example inputs: a free-text symptom description and a disease name
q_symptoms = "ผดผื่นที่แขน"
q_disease = "โรคไหล่อักเสบ"  # e.g., user types a disease name

print("=== THREADS by symptom/query ===")
for hit in knn_threads(q_symptoms, k=5):
    score_txt = f"{hit['score']:.4f}"
    category_txt = f"({hit.get('category')})"
    print(score_txt, "-", hit["title"], category_txt, "->", hit["url"])

print("\n=== DISEASES by disease-name/query ===")
for hit in knn_diseases(q_disease, k=2):
    en = hit.get("en_name") or ""
    # Show the English name only when there is one
    en_txt = f"/ {en}" if en else ""
    print(f"{hit['score']:.4f}", "-", hit["th_name"], en_txt, "->", hit["url"])

# (optional) compact counts of embedded nodes

In [None]:
# Count, per label, the nodes that actually carry an embedding
_embedded_count_queries = (
    ("Threads with embeddings:", "c_threads", """
MATCH (t:Thread) WHERE t.title_category_emb IS NOT NULL
RETURN count(t) AS c_threads
"""),
    ("Diseases with embeddings:", "c_diseases", """
MATCH (d:Disease) WHERE d.symptom_disease_emb IS NOT NULL
RETURN count(d) AS c_diseases
"""),
)
for heading, column, query_text in _embedded_count_queries:
    res = client.run(query_text)
    print(heading, res[0][column])


# (cleanup helper) close driver when done

In [None]:
# Close the driver held by `client` (Neo4jClient.close delegates to driver.close()).
client.close()
print("Closed Neo4j driver.")
