In [None]:
import json
import re
import torch
import faiss
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM


In [None]:
# Read the raw records. The encoding is pinned to UTF-8 so decoding does not
# depend on the platform's default locale encoding.
with open("../data/medical_final_dataset.json", "r", encoding="utf-8") as f:
    data = json.load(f)

print("Loaded records:", len(data))


In [None]:
# Turn each dataset record into a retrieval document: the response text,
# prefixed by the record's context when a non-empty one is present.
rag_docs = []

for record in data:
    response = record["response"]
    context_prefix = record.get("context")

    rag_docs.append(
        {
            "text": context_prefix + " " + response if context_prefix else response,
            "metadata": record.get("metadata", {}),
        }
    )

print("RAG docs:", len(rag_docs))


In [None]:
def chunk_text(text, chunk_size=250, overlap=40):
    """Split ``text`` into overlapping windows of whitespace-separated words.

    Args:
        text: The string to split; words are obtained with ``str.split()``.
        chunk_size: Maximum number of words per chunk. Must be positive.
        overlap: Number of words shared between consecutive chunks. Must be
            smaller than ``chunk_size``, otherwise the window would never
            advance.

    Returns:
        A list of chunk strings (words re-joined with single spaces). Empty or
        whitespace-only input yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    words = text.split()
    step = chunk_size - overlap
    chunks = []

    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # Once a window reaches the last word, any further window would only
        # repeat words already covered by this one.
        if end >= len(words):
            break

    return chunks


In [None]:
# Flatten every document into its word-window chunks; each chunk keeps a
# reference to the metadata of the document it came from.
chunks = [
    {"text": piece, "metadata": doc["metadata"]}
    for doc in rag_docs
    for piece in chunk_text(doc["text"])
]

print("Total chunks:", len(chunks))


In [None]:
# Checkpoint identifier passed to both `from_pretrained` calls below.
biobert_name = "dmis-lab/biobert-base-cased-v1.1"
# Tokenizer and encoder come from the same checkpoint; both are module-level
# names used by `embed_chunks` and `retrieve`.
tokenizer = AutoTokenizer.from_pretrained(biobert_name)
model = AutoModel.from_pretrained(biobert_name)
# Switch the encoder to evaluation mode; in this notebook it is only called
# under `torch.no_grad()` to produce embeddings.
model.eval()


In [None]:
def mean_pooling(output, mask):
    """Average token embeddings over the positions selected by ``mask``.

    Args:
        output: Model output exposing ``last_hidden_state``; indexed here as
            (batch, tokens, hidden).
        mask: Attention mask of shape (batch, tokens); non-zero marks tokens
            that contribute to the average.

    Returns:
        A (batch, hidden) tensor of mask-weighted mean embeddings. A row whose
        mask is entirely zero yields a zero vector instead of NaN.
    """
    token_embeddings = output.last_hidden_state
    weights = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(1)
    # Clamp the denominator so an all-zero mask row cannot divide by zero.
    counts = weights.sum(1).clamp(min=1e-9)
    return summed / counts


In [None]:
# Prefer the GPU when one is visible to torch, otherwise stay on the CPU,
# then move the encoder to the chosen device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = model.to(device)

def embed_chunks(chunks, batch_size=32):
    """Encode the ``"text"`` of every chunk into a mean-pooled embedding.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. Inputs are
    padded per batch and truncated to 512 tokens.

    Args:
        chunks: Sequence of dicts, each with a ``"text"`` entry.
        batch_size: Number of chunks encoded per forward pass.

    Returns:
        A float32 NumPy array with one row per chunk, in input order.
    """
    batch_vectors = []

    for offset in tqdm(range(0, len(chunks), batch_size)):
        texts = [entry["text"] for entry in chunks[offset:offset + batch_size]]

        encoded = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        encoded = encoded.to(device)

        # Inference only: no gradients are needed for embedding extraction.
        with torch.no_grad():
            pooled = mean_pooling(model(**encoded), encoded["attention_mask"])

        batch_vectors.append(pooled.cpu().numpy())

    stacked = np.vstack(batch_vectors)
    return stacked.astype("float32")


In [None]:
# Encode every chunk with the encoder; `embed_chunks` stacks the per-batch
# results into a single float32 array.
embeddings = embed_chunks(chunks, batch_size=32)
print("Embeddings shape:", embeddings.shape)


In [None]:
# Normalise the embedding rows (the array is modified in place, nothing is
# reassigned) before they are added to the inner-product index below.
# `retrieve` normalises the query vector the same way.
faiss.normalize_L2(embeddings)

# Flat inner-product index whose dimensionality is the embedding width.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

print("FAISS index size:", index.ntotal)


In [None]:
# Persist the three retrieval artifacts to the working directory: the
# normalised embedding matrix, the FAISS index, and the chunk records.
np.save("embeddings.npy", embeddings)
faiss.write_index(index, "faiss.index")

# Chunks are written in list order; `retrieve` uses FAISS result positions to
# index into this same list.
with open("chunks.json", "w") as f:
    json.dump(chunks, f)


In [None]:
# Ordered (intent, pattern) pairs; the first pattern that matches wins.
# Short keywords are anchored on word boundaries so that "sign" does not fire
# on "design"/"significant", "cause" on "because", or "risk" on "asterisk".
_INTENT_QUESTION_PATTERNS = (
    ("symptoms", re.compile(r"symptom|\bsigns?\b")),
    ("risk", re.compile(r"\brisk|\bcaus(?:e[sd]?|ing|al|ation)\b")),
    ("treatment", re.compile(r"treatment")),
    ("definition", re.compile(r"\bwhat is\b|\bdefine")),
)


def extract_intent(question):
    """Classify a question into a coarse intent using keyword rules.

    Matching is case-insensitive and checked in priority order:
    symptoms, risk, treatment, definition.

    Args:
        question: The user question as a string.

    Returns:
        One of ``"symptoms"``, ``"risk"``, ``"treatment"``, ``"definition"``,
        or ``None`` when no rule matches.
    """
    q = question.lower()
    for intent, pattern in _INTENT_QUESTION_PATTERNS:
        if pattern.search(q):
            return intent
    return None


In [None]:
# Per-intent keyword patterns applied to lower-cased chunk text.
# Short, ambiguous keywords are anchored on word boundaries so that ubiquitous
# words such as "significant", "design" or "painting" no longer satisfy the
# symptom filter, and "asterisk"/"brisk" no longer satisfy the risk filter.
# Long, specific terms remain plain substring matches.
# Causal wording is accepted for "risk" because cause-style questions are
# classified as "risk" upstream.
_INTENT_TEXT_PATTERNS = {
    "symptoms": re.compile(
        r"symptom|\bsigns?\b|\bcough|shortness of breath"
        r"|\bpain(?:s|ful)?\b|\bfatigue"
    ),
    "risk": re.compile(
        r"\brisk|\bcaus(?:e[sd]?|ing|al|ation)\b|smoking|exposure"
    ),
    "treatment": re.compile(r"treatment|surgery|chemotherapy|radiation"),
    "definition": re.compile(r"is a disease|is a type"),
}


def intent_match(text, intent):
    """Report whether ``text`` looks relevant to the given question intent.

    Args:
        text: Candidate passage; matched case-insensitively.
        intent: One of ``"symptoms"``, ``"risk"``, ``"treatment"``,
            ``"definition"``. Any other value (including ``None``) applies no
            filtering.

    Returns:
        ``True`` when the passage contains a keyword for the intent, or when
        the intent is not one of the known values; ``False`` otherwise.
    """
    pattern = _INTENT_TEXT_PATTERNS.get(intent)
    if pattern is None:
        return True
    return pattern.search(text.lower()) is not None


In [None]:
def retrieve(query, top_k=10, max_results=3):
    """Return the best-matching chunk texts for ``query`` as one string.

    The query is embedded with the module-level encoder, the ``top_k`` nearest
    chunks are fetched from the FAISS ``index``, and candidates are then
    filtered by disease mention and question intent.

    Args:
        query: The user question.
        top_k: Number of nearest neighbours requested from the index.
        max_results: Maximum number of chunks kept after filtering.

    Returns:
        The kept chunk texts joined by blank lines, or an empty string when no
        candidate survives the filters.
    """
    enc = tokenizer(
        query,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        out = model(**enc)
        q_emb = mean_pooling(out, enc["attention_mask"]).cpu().numpy()

    # The index holds float32 rows, so give the query the same dtype and a
    # contiguous layout before the in-place normalisation.
    q_emb = np.ascontiguousarray(q_emb, dtype="float32")
    faiss.normalize_L2(q_emb)
    _, idx = index.search(q_emb, top_k)

    # NOTE(review): `extract_disease` is not defined in the visible part of
    # this notebook; confirm it exists before this cell on a fresh kernel.
    # It is presumably expected to return a lower-case disease name or a falsy
    # value, since it is compared against lower-cased text below.
    disease = extract_disease(query)
    intent = extract_intent(query)

    results = []

    for i in idx[0]:
        if len(results) >= max_results:
            break

        # FAISS pads the result with -1 when fewer than `top_k` vectors are
        # available; without this guard -1 would silently select chunks[-1].
        if i < 0:
            continue

        text = chunks[int(i)]["text"]

        # Disease filter
        if disease and disease not in text.lower():
            continue

        # Intent filter
        if intent and not intent_match(text, intent):
            continue

        results.append(text)

    return "\n\n".join(results)


In [None]:
# Manual smoke test: print the retrieved context for a sample question.
print(retrieve("What are the symptoms of non-small cell lung cancer?"))


In [None]:
llama_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Follow the notebook-wide `device` choice instead of hard-coding "cuda", so
# this cell also runs on CPU-only machines. Half precision is kept on the GPU;
# full precision is used on CPU.
llama_dtype = torch.float16 if device == "cuda" else torch.float32

llama_tokenizer = AutoTokenizer.from_pretrained(llama_name)
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_name,
    torch_dtype=llama_dtype
).to(device).eval()

# Show where the weights ended up.
print(next(llama_model.parameters()).device)


In [None]:
import re

def split_sentences(text):
    """Split ``text`` into sentences at whitespace following ``.``, ``!`` or ``?``.

    Returns:
        The non-empty sentences, each stripped of surrounding whitespace and
        keeping its terminating punctuation.
    """
    boundary = re.compile(r'(?<=[.!?])\s+')
    pieces = []
    for candidate in boundary.split(text.strip()):
        cleaned = candidate.strip()
        if cleaned:
            pieces.append(cleaned)
    return pieces


In [None]:
def strict_context_filter(answer, context):
    """Keep only the answer sentences that are grounded in ``context``.

    A sentence counts as grounded when, compared case-insensitively, it
    contains a context sentence or is contained in one.

    Args:
        answer: Text whose sentences are to be checked.
        context: Retrieved context the answer must be supported by.

    Returns:
        Up to three grounded sentences (original casing) joined by spaces, or
        ``"I don't have enough information."`` when none qualifies.
    """
    context_sentences = split_sentences(context.lower())

    def is_grounded(sentence):
        lowered = sentence.lower()
        for ctx_sentence in context_sentences:
            if ctx_sentence in lowered or lowered in ctx_sentence:
                return True
        return False

    grounded = [
        sentence for sentence in split_sentences(answer) if is_grounded(sentence)
    ]

    if grounded:
        # Cap the reply at three sentences.
        return " ".join(grounded[:3])

    return "I don't have enough information."


In [None]:
def build_prompt(context, question):
    """Assemble the extractive-answer prompt for the language model.

    Args:
        context: Retrieved passages inserted under the ``Context:`` header.
        question: User question inserted under the ``Question:`` header.

    Returns:
        The full prompt string, beginning and ending with a newline.
    """
    prompt_lines = [
        "",
        "You are a medical assistant.",
        "",
        "TASK:",
        "Select the exact sentences from the context that answer the question.",
        "Do NOT add new information.",
        "Do NOT explain.",
        "Do NOT use lists.",
        "If the answer is not present, say: I don't have enough information.",
        "",
        "Context:",
        f"{context}",
        "",
        "Question:",
        f"{question}",
        "",
        "Answer (use only sentences from context):",
        "",
    ]
    return "\n".join(prompt_lines)


In [None]:
def answer(question):
    """Answer ``question`` using retrieved context and the TinyLlama model.

    The retrieved context and question are turned into a prompt, a short
    greedy completion is generated, and only completion sentences grounded in
    the context are returned.

    Args:
        question: The user question.

    Returns:
        Up to three context-grounded sentences, or
        ``"I don't have enough information."`` when nothing was retrieved or
        nothing in the completion is grounded.
    """
    context = retrieve(question)

    # Nothing retrieved means nothing can be grounded, so skip generation.
    # The filter would return this same message for an empty context.
    if not context.strip():
        return "I don't have enough information."

    prompt = build_prompt(context, question)

    # Put the inputs on whatever device the model lives on rather than
    # assuming "cuda".
    inputs = llama_tokenizer(prompt, return_tensors="pt").to(llama_model.device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        out = llama_model.generate(
            **inputs,
            max_new_tokens=60,       # keep the completion very short
            do_sample=False,         # greedy decoding, no sampling
            repetition_penalty=1.2,
            eos_token_id=llama_tokenizer.eos_token_id
        )

    # `generate` returns the prompt tokens followed by the new tokens for a
    # causal LM. Decode only the continuation; otherwise the context echoed in
    # the prompt would itself pass the grounding filter.
    generated = out[0][prompt_length:]
    raw = llama_tokenizer.decode(generated, skip_special_tokens=True)

    return strict_context_filter(raw, context)


In [None]:
# Manual end-to-end smoke test: retrieve, generate and filter for a sample question.
print(answer("What are the symptoms of non-small cell lung cancer?"))
