In [None]:
"""
===============================================================================
Medicine recommendation RAG pipeline
===============================================================================
"""

import os, re
from pathlib import Path

import pandas as pd
import numpy as np
import networkx as nx
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import joblib

# -------------------------
# OPTIONAL FAISS
# -------------------------
# FAISS is optional: when it cannot be imported, USE_FAISS is False and
# semantic retrieval uses brute-force cosine similarity instead.
try:
    import faiss
except ImportError:
    USE_FAISS = False
else:
    USE_FAISS = True

# -------------------------
# OPTIONAL GROQ CLIENT
# -------------------------
# When the SDK is missing, ``Groq`` is left as the sentinel None; the
# generation code checks for it and raises instead of failing at import time.
try:
    from groq import Groq
except ImportError:
    print("[WARN] groq SDK not installed.")
    Groq = None

# -------------------------
# CONFIG
# -------------------------
# Directory holding the prebuilt artifacts (embeddings, FAISS index, graph).
OUT_DIR = Path("kg_rag_artifacts")
# Source table; retrieval indexes it with embedding row numbers, so it must
# stay row-aligned with EMBEDDING_FILE.
DATA_CSV = "drugs_side_effects.csv"
EMBEDDING_FILE = OUT_DIR / "corpus_embeddings.npy"
FAISS_INDEX_FILE = OUT_DIR / "faiss.index"
KG_FILE = OUT_DIR / "medical_kg.graphml"

# Sentence-transformers model used to embed queries. Presumably this must be
# the same model that produced EMBEDDING_FILE (not verifiable here).
EMBEDDER_MODEL = "all-MiniLM-L6-v2"
# Groq chat model; can be overridden with the GROQ_MODEL environment
# variable so a retired model can be swapped without editing code.
GROQ_MODEL = os.getenv("GROQ_MODEL", "gemma2-9b-it")

# Read the API key from the environment instead of hard-coding it;
# None when the variable is unset.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# -------------------------
# HELPERS
# -------------------------
def clean_text(text: str) -> str:
    """Normalise a free-text value for matching and embedding.

    Missing values (anything ``pd.isna`` flags, e.g. None or NaN) become "".
    Otherwise the value is stringified, line breaks become spaces, every
    character other than ASCII letters/digits, whitespace and
    ``- , . ; : ( ) / %`` is replaced by a space, whitespace runs are
    collapsed to one space, and the ends are trimmed.
    """
    if pd.isna(text):
        return ""
    result = str(text)
    for pattern, replacement in (
        (r"[\r\n]+", " "),
        (r"[^A-Za-z0-9\s\-,\.;:()/%]", " "),
        (r"\s+", " "),
    ):
        result = re.sub(pattern, replacement, result)
    return result.strip()

# -------------------------
# LOAD NLP
# -------------------------
# Prefer the scispaCy biomedical pipeline. If scispacy or its model cannot be
# loaded, fall back to the general-purpose English pipeline.
try:
    import scispacy  # noqa: F401  (presence check for the biomedical stack)
    nlp = spacy.load("en_core_sci_sm")
except Exception:
    nlp = spacy.load("en_core_web_sm")

embedder = SentenceTransformer(EMBEDDER_MODEL)

# -------------------------
# LOAD DATA
# -------------------------
# Row i of the CSV, of the embedding matrix and of the FAISS index must all
# describe the same entry: semantic retrieval returns df.iloc[<embedding row>].
df = pd.read_csv(DATA_CSV).fillna("")

for col in ["drug_name", "side_effects", "medical_condition"]:
    df[f"{col}_clean"] = df[col].astype(str).apply(clean_text)

corpus_embeddings = np.load(EMBEDDING_FILE)
if corpus_embeddings.shape[0] != len(df):
    # A stale embeddings file would make retrieval return the wrong drugs
    # (or raise IndexError later) with no other warning; fail at load time.
    raise ValueError(
        f"{EMBEDDING_FILE} has {corpus_embeddings.shape[0]} rows but "
        f"{DATA_CSV} has {len(df)}; regenerate the embeddings."
    )

# FAISS (optional): fall back to brute-force similarity when the library is
# missing, the index file is missing, or the index is not row-aligned.
index = None
if not USE_FAISS:
    print("[WARN] FAISS not available, using brute-force similarity")
elif not FAISS_INDEX_FILE.exists():
    print(f"[WARN] FAISS index {FAISS_INDEX_FILE} not found, using brute-force similarity")
else:
    index = faiss.read_index(str(FAISS_INDEX_FILE))
    if index.ntotal != len(df):
        print(
            f"[WARN] FAISS index has {index.ntotal} vectors but the dataset has "
            f"{len(df)} rows, using brute-force similarity"
        )
        index = None

# Knowledge Graph
G = nx.read_graphml(KG_FILE)

# -------------------------
# NER
# -------------------------
def run_ner(text):
    """Extract ``(span_text, label)`` pairs from ``text`` with the spaCy pipeline.

    Named entities are preferred. If the pipeline finds none, noun chunks are
    returned instead, tagged with the synthetic label ``"NOUN_CHUNK"``.

    Returns:
        list of ``(span_text, label)`` tuples with stripped text; duplicates
        removed, first-occurrence order kept.
    """
    parsed = nlp(text)
    pairs = [(span.text.strip(), span.label_) for span in parsed.ents] or [
        (chunk.text.strip(), "NOUN_CHUNK") for chunk in parsed.noun_chunks
    ]
    return list(dict.fromkeys(pairs))

def extract_query_entities(symptoms, additional_info):
    """Collect cleaned, de-duplicated search terms for knowledge-graph matching.

    Terms are gathered in this order:
      1. the symptom strings themselves;
      2. named entities found in ``additional_info``, or its noun chunks when
         the pipeline finds no entity;
      3. every noun, proper noun or adjective token of ``additional_info``
         longer than two characters.

    Args:
        symptoms: iterable of symptom strings. A single string is treated as
            one symptom, and ``None`` as no symptoms.
        additional_info: free-text description; ``None`` is treated as "".

    Returns:
        list[str]: non-empty ``clean_text``-normalised terms, duplicates
        removed, first-occurrence order kept.
    """
    if symptoms is None:
        symptoms = []
    elif isinstance(symptoms, str):
        # A bare string would otherwise be iterated character by character,
        # producing one-letter tokens that substring-match almost every node.
        symptoms = [symptoms]
    info = "" if additional_info is None else additional_info

    terms = [clean_text(s) for s in symptoms]

    # Parse once and reuse the Doc for both passes. The entity pass mirrors
    # run_ner(): named entities, falling back to noun chunks when none exist.
    doc = nlp(info)
    spans = list(doc.ents) or list(doc.noun_chunks)
    terms += [clean_text(span.text.strip()) for span in spans]

    # Content-word pass: nouns, proper nouns and adjectives longer than 2 chars.
    terms += [
        clean_text(tok.text)
        for tok in doc
        if tok.pos_ in {"NOUN", "PROPN", "ADJ"} and len(tok.text) > 2
    ]

    # Drop empty strings and de-duplicate, keeping first-occurrence order.
    return list(dict.fromkeys(filter(None, terms)))

# -------------------------
# KG HELPERS
# -------------------------
def match_graph_nodes(tokens, max_matches=10, graph=None):
    """Return ids of graph nodes whose ``label`` contains any query token.

    Each token is tested case-insensitively as a substring of each node's
    ``label`` attribute. Tokens are tried in order. The result keeps
    first-match order and contains no duplicates.

    Args:
        tokens: iterable of query strings. Empty strings are skipped, because
            an empty needle would match every node.
        max_matches: hard cap on the number of distinct node ids returned;
            values <= 0 yield an empty list.
        graph: graph to search (anything exposing ``nodes(data=True)``).
            Defaults to the module-level knowledge graph ``G``.

    Returns:
        list of node ids, at most ``max_matches`` long.
    """
    if max_matches <= 0:
        return []
    graph = G if graph is None else graph

    # Lower-case every label once instead of once per token. str() guards
    # against None or non-string label attributes.
    labels = [
        (node, str(data.get("label") or "").lower())
        for node, data in graph.nodes(data=True)
    ]

    found = {}  # insertion-ordered set of matched node ids
    for token in tokens:
        needle = token.lower()
        if not needle:
            continue
        for node, label in labels:
            if needle in label and node not in found:
                found[node] = None
                if len(found) >= max_matches:
                    # Return as soon as the cap is hit; a plain `break` would
                    # only leave the inner loop and keep scanning later tokens.
                    return list(found)
    return list(found)

def expand_subgraph(seed_nodes, radius=2, graph=None):
    """Return the induced subgraph within ``radius`` hops of ``seed_nodes``.

    Edge direction is ignored while expanding. On a directed graph, both
    successors and predecessors count as neighbours. On an undirected graph,
    ``neighbors`` is used, because undirected networkx graphs provide no
    ``successors``/``predecessors``.

    Args:
        seed_nodes: iterable of node ids; ids absent from the graph are ignored.
        radius: number of hops to expand (0 returns just the seeds).
        graph: graph to expand in; defaults to the module-level graph ``G``.

    Returns:
        An independent copy of the induced subgraph, safe for callers to
        mutate. When no seed is present in the graph, an empty graph of the
        same class as ``graph`` is returned instead.
    """
    graph = G if graph is None else graph
    if not seed_nodes:
        return graph.__class__()
    seeds = [n for n in seed_nodes if n in graph]
    if not seeds:
        return graph.__class__()

    directed = graph.is_directed()
    visited = set(seeds)
    frontier = set(seeds)
    for _ in range(radius):
        if not frontier:
            break
        reached = set()
        for node in frontier:
            if directed:
                reached.update(graph.successors(node))
                reached.update(graph.predecessors(node))
            else:
                reached.update(graph.neighbors(node))
        # Only newly discovered nodes need expanding on the next hop.
        frontier = reached - visited
        visited |= frontier
    return graph.subgraph(visited).copy()

def subgraph_to_text(subg, max_triples=60):
    """Render graph edges as ``head --relation--> tail`` lines.

    Node ids are replaced by their ``label`` attribute when present. Edges
    without a ``relation`` attribute are shown as ``related_to``.

    Args:
        subg: graph whose edges are rendered, in ``subg.edges`` order.
        max_triples: keep only this many lines (ordinary slice semantics).

    Returns:
        Newline-joined string; "" for a graph with no edges.
    """
    def _name(node):
        return subg.nodes[node].get('label', node)

    lines = [
        f"{_name(head)} --{attrs.get('relation','related_to')}--> {_name(tail)}"
        for head, tail, attrs in subg.edges(data=True)
    ]
    return "\n".join(lines[:max_triples])

# -------------------------
# SEMANTIC SEARCH
# -------------------------
def semantic_retrieve(text, top_k=5):
    """Return the ``top_k`` rows of ``df`` most similar to ``text``.

    The query is cleaned with ``clean_text`` and embedded with ``embedder``
    (L2-normalised). If a FAISS ``index`` is loaded, it is searched. Otherwise
    cosine similarity against ``corpus_embeddings`` is computed by brute
    force. Embedding row ``i`` is assumed to describe ``df`` row ``i``.

    Args:
        text: free-text query.
        top_k: maximum number of rows to return; values <= 0 yield an empty
            frame.

    Returns:
        pandas.DataFrame: a copy of the matching rows, best match first. It
        may hold fewer than ``top_k`` rows when the index has fewer results.
    """
    if top_k <= 0:
        # Guard explicitly: a non-positive k must not fall through to the
        # top-k slicing below, where `[-0:]` would select every row.
        return df.iloc[[]].copy()

    qv = embedder.encode([clean_text(text)], normalize_embeddings=True)

    if USE_FAISS and index is not None:
        _, hits = index.search(np.asarray(qv, dtype="float32"), top_k)
        idx = hits[0]
        # FAISS pads the result with -1 when it has fewer than top_k hits;
        # -1 would otherwise silently select the last row of df.
        idx = idx[idx >= 0]
    else:
        sims = cosine_similarity(qv, corpus_embeddings)[0]
        idx = sims.argsort()[::-1][:top_k]

    return df.iloc[idx].copy()

# -------------------------
# GROQ GENERATION
# -------------------------
def generate_with_groq(question, context, *, client=None, model=None,
                       temperature=0.2, max_tokens=300):
    """Answer ``question`` with a Groq chat model, grounded on ``context``.

    Args:
        question: the user's question.
        context: retrieved text that the prompt tells the model to rely on
            exclusively.
        client: optional pre-built client exposing ``chat.completions.create``.
            Pass one to reuse a client across calls or to inject a stub. When
            omitted, a ``Groq`` client is built from ``GROQ_API_KEY``.
        model: chat model name; defaults to ``GROQ_MODEL``.
        temperature: sampling temperature passed to the API.
        max_tokens: completion length cap passed to the API.

    Returns:
        str: the stripped completion text, or "" if the API returned no
        content.

    Raises:
        RuntimeError: if no ``client`` is given and either the Groq SDK is
            not installed or ``GROQ_API_KEY`` is unset.
    """
    if client is None:
        if Groq is None:
            raise RuntimeError("Groq SDK not installed")
        if not GROQ_API_KEY:
            raise RuntimeError("GROQ_API_KEY not set")
        client = Groq(api_key=GROQ_API_KEY)

    prompt = f"""
Use ONLY the context below.

Context:
{context}

Question:
{question}

Answer:
"""

    resp = client.chat.completions.create(
        model=GROQ_MODEL if model is None else model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # message.content is optional in the OpenAI-compatible response schema;
    # calling .strip() on None would raise AttributeError.
    content = resp.choices[0].message.content
    return (content or "").strip()

# -------------------------
# MAIN ORCHESTRATOR
# -------------------------
def _semantic_context(symptoms, additional_info, question, top_k):
    """Format the ``top_k`` semantically retrieved dataset rows, one per line."""
    if top_k <= 0:
        return ""
    if symptoms is None:
        symptoms = []
    elif isinstance(symptoms, str):
        symptoms = [symptoms]
    parts = [str(s) for s in symptoms]
    parts += [additional_info or "", question or ""]
    rows = semantic_retrieve(" ".join(parts), top_k=top_k)
    # The *_clean columns are created for every row when the data is loaded.
    return "\n".join(
        f"{r.drug_name_clean} | condition: {r.medical_condition_clean} | "
        f"side effects: {r.side_effects_clean}"
        for r in rows.itertuples(index=False)
    )


def answer_via_kg_and_semantics(symptoms, additional_info, question, top_k=5):
    """Answer ``question`` from knowledge-graph facts plus semantically similar rows.

    Pipeline:
      1. extract query terms from ``symptoms`` and ``additional_info``;
      2. match the terms to KG nodes, expand a neighbourhood subgraph and
         render it as ``head --relation--> tail`` lines;
      3. retrieve the ``top_k`` most similar dataset rows with
         ``semantic_retrieve`` (``top_k <= 0`` skips this step);
      4. if ``GROQ_API_KEY`` is set, ask the LLM to answer from the combined
         context; otherwise return the combined context itself.

    Args:
        symptoms: list of symptom strings.
        additional_info: free-text patient description.
        question: the question to answer.
        top_k: number of semantically retrieved rows to include.

    Returns:
        str: the LLM answer, or the raw retrieval context when no API key is
        configured.
    """
    tokens = extract_query_entities(symptoms, additional_info)
    seeds = match_graph_nodes(tokens)
    kg_context = subgraph_to_text(expand_subgraph(seeds))
    semantic_context = _semantic_context(symptoms, additional_info, question, top_k)

    sections = []
    if kg_context:
        sections.append("Knowledge graph facts:\n" + kg_context)
    if semantic_context:
        sections.append("Similar dataset entries:\n" + semantic_context)
    context = "\n\n".join(sections)

    if GROQ_API_KEY:
        return generate_with_groq(question, context)
    return context

# -------------------------
# EXAMPLE
# -------------------------
if __name__ == "__main__":
    # Smoke-test run with a hard-coded example query.
    demo_symptoms = ["Fever", "Fatigue"]
    demo_info = "Mild fever and headache for two days."
    demo_question = "Which OTC drugs are safe?"

    result = answer_via_kg_and_semantics(demo_symptoms, demo_info, demo_question)
    print(result)
