In [1]:
# Toy corpus: each record pairs a stable id ("doc1" .. "doc10", numbered in
# list order) with a short passage naming a condition and a drug/therapy used
# for it. Every record is a dict with exactly the keys "id" and "text".
documents = [
    {"id": f"doc{number}", "text": passage}
    for number, passage in enumerate(
        [
            "Diabetes is caused by insulin resistance. Metformin is used to treat diabetes.",
            "BRCA1 is a gene associated with breast cancer. Tamoxifen treats breast cancer.",
            "Cancer is caused by genetic mutations. Chemotherapy treats cancer.",
            "Hypertension is linked to high blood pressure. Amlodipine is commonly prescribed for hypertension.",
            "Alzheimer’s disease is associated with amyloid-beta plaque accumulation. Donepezil is used to manage Alzheimer’s symptoms.",
            "Asthma is a chronic inflammatory airway disease. Inhaled corticosteroids are used to control asthma.",
            "COVID-19 is caused by the SARS-CoV-2 virus. Antiviral drugs like remdesivir are used for treatment.",
            "Parkinson’s disease is caused by the loss of dopamine-producing neurons. Levodopa is used to treat Parkinson’s disease.",
            "Tuberculosis is caused by Mycobacterium tuberculosis. Rifampicin is a key drug in tuberculosis treatment.",
            "Depression is linked to neurotransmitter imbalance. Selective serotonin reuptake inhibitors are used to treat depression.",
        ],
        start=1,
    )
]


# Vector RAG (FAISS)

In [2]:
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Parallel views of the corpus: position i in each list is documents[i], which
# is also FAISS row i once the embeddings are added below.
doc_ids = [entry["id"] for entry in documents]
texts = [entry["text"] for entry in documents]

# One embedding row per passage, in corpus order.
embeddings = embedder.encode(texts, convert_to_numpy=True)

# Exact (brute-force) L2 index sized to the embedding width.
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

  from .autonotebook import tqdm as notebook_tqdm


Vector search

In [3]:
def vector_search(query, k=2):
    """Return up to ``k`` corpus documents nearest to ``query`` in embedding space.

    Args:
        query: natural-language query string.
        k: number of neighbours requested; capped at the number of vectors in
            the FAISS index.

    Returns:
        List of document dicts from ``documents``, in the order FAISS returns
        them (closest first). Empty when ``k <= 0`` or the index is empty.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    _, ids = index.search(q_emb, k)
    # FAISS pads missing neighbours with -1; documents[-1] would silently hand
    # back the last document, so skip those slots.
    return [documents[i] for i in ids[0] if i >= 0]

# Graph RAG (Neo4j)

Entity extraction (spaCy named-entity recognizer; relations are not extracted yet)

In [4]:
import spacy
nlp = spacy.load("en_core_web_sm")


def extract_entities(text):
    """Return the surface text of each named entity spaCy finds in ``text``.

    Entities come back in the order of ``Doc.ents``; duplicates are kept.
    NOTE(review): en_core_web_sm is a general-English model, so it may miss
    many drug/disease names — confirm coverage on this corpus.
    """
    parsed = nlp(text)
    return list(map(lambda span: span.text, parsed.ents))

Neo4j connection

In [None]:
import os
from getpass import getpass

from neo4j import GraphDatabase

# Connection settings come from the environment so credentials are not stored
# in the notebook. URI and user fall back to the previous local defaults; the
# password is prompted for when NEO4J_PASSWORD is not set.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD") or getpass("Neo4j password: ")

driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

Build graph

In [6]:
def build_graph(docs):
    """(Re)build the Neo4j entity graph from ``docs``.

    Removes ``Entity`` and ``Document`` nodes left by a previous run, then
    writes one ``Document`` node per doc, one ``Entity`` node per distinct
    name returned by ``extract_entities``, and a ``MENTIONS`` relationship from
    each document to every entity found in its text.

    Args:
        docs: iterable of dicts with ``"id"`` and ``"text"`` keys.
    """
    docs = list(docs)  # iterated twice below; materialize generators once
    mentions = [
        {"doc_id": d["id"], "name": name}
        for d in docs
        # dict.fromkeys drops repeated mentions within one doc, keeping order.
        for name in dict.fromkeys(extract_entities(d["text"]))
    ]

    with driver.session() as session:
        # Scope the wipe to this notebook's labels: an unscoped
        # `MATCH (n) DETACH DELETE n` erases every node in the database.
        session.run(
            "MATCH (n) WHERE n:Entity OR n:Document DETACH DELETE n"
        ).consume()
        # Batched writes (UNWIND) instead of one round trip per entity.
        session.run(
            "UNWIND $ids AS doc_id MERGE (:Document {id: doc_id})",
            ids=[d["id"] for d in docs],
        ).consume()
        session.run(
            """
            UNWIND $rows AS row
            MATCH (doc:Document {id: row.doc_id})
            MERGE (e:Entity {name: row.name})
            MERGE (doc)-[:MENTIONS]->(e)
            """,
            rows=mentions,
        ).consume()

In [7]:
# (Re)build the Neo4j entity graph from the in-memory corpus.
build_graph(docs=documents)

# Graph Query (case-insensitive entity lookup)

In [8]:
def graph_search(entity):
    """Return names of graph ``Entity`` nodes whose name contains ``entity``.

    Matching is a case-insensitive substring test done in Cypher.

    Args:
        entity: text to look for; surrounding whitespace is ignored.

    Returns:
        List of distinct matching entity names. Empty when ``entity`` is
        ``None``, empty or blank, because ``CONTAINS ""`` is true for every
        node and would otherwise return the whole graph.
    """
    needle = (entity or "").strip()
    if not needle:
        return []
    query = """
    MATCH (e:Entity)
    WHERE toLower(e.name) CONTAINS toLower($entity)
    RETURN DISTINCT e.name AS name
    """
    with driver.session() as session:
        result = session.run(query, entity=needle)
        return [record["name"] for record in result]

# Hybrid RAG (KEY PART)

In [9]:
def hybrid_retrieval(query, k=2):
    """Combine vector recall with a graph lookup of the recalled entities.

    1. Recall the ``k`` nearest documents with ``vector_search``.
    2. Extract named entities from those documents.
    3. Look each entity up in the graph with ``graph_search``.

    Args:
        query: natural-language question.
        k: number of documents to recall in step 1 (default 2, as before).

    Returns:
        dict with ``"vector_docs"`` (list of document dicts) and
        ``"graph_entities"`` (de-duplicated entity names in first-seen order).
    """
    # Step 1: vector recall
    vector_docs = vector_search(query, k=k)

    # Step 2: entities from the recalled docs. dict.fromkeys de-duplicates
    # while keeping first-seen order; iterating a set of strings would make
    # the order (and the LLM context built from it) change between kernel
    # restarts because str hashing is randomized per process.
    expanded_entities = dict.fromkeys(
        ent for doc in vector_docs for ent in extract_entities(doc["text"])
    )

    # Step 3: graph expansion
    graph_entities = dict.fromkeys(
        name for ent in expanded_entities for name in graph_search(ent)
    )

    return {
        "vector_docs": vector_docs,
        "graph_entities": list(graph_entities),
    }

In [None]:
import os
from getpass import getpass

import requests

# Take the key from the environment (or an existing GROQ_API_KEY variable,
# else prompt) instead of `grok_api_key`, which no visible cell defines.
# GROQ_API_KEY is also the global name generate_answer reads.
GROQ_API_KEY = (
    os.environ.get("GROQ_API_KEY")
    or globals().get("GROQ_API_KEY")
    or getpass("Groq API key: ")
)

url = "https://api.groq.com/openai/v1/models"

headers = {
    "Authorization": f"Bearer {GROQ_API_KEY}"
}

# Fail fast instead of hanging indefinitely if the API is unreachable.
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()

models = response.json()["data"]

for m in models:
    print(m["id"])


meta-llama/llama-guard-4-12b
openai/gpt-oss-safeguard-20b
canopylabs/orpheus-v1-english
allam-2-7b
canopylabs/orpheus-arabic-saudi
meta-llama/llama-prompt-guard-2-22m
qwen/qwen3-32b
meta-llama/llama-4-maverick-17b-128e-instruct
llama-3.3-70b-versatile
meta-llama/llama-prompt-guard-2-86m
whisper-large-v3
groq/compound-mini
moonshotai/kimi-k2-instruct-0905
openai/gpt-oss-120b
whisper-large-v3-turbo
llama-3.1-8b-instant
moonshotai/kimi-k2-instruct
meta-llama/llama-4-scout-17b-16e-instruct
openai/gpt-oss-20b
groq/compound


In [11]:
import requests

def generate_answer(
    query,
    context,
    *,
    model="llama-3.1-8b-instant",
    temperature=0.2,
    max_tokens=200,
    api_key=None,
    timeout=30,
):
    """Ask a Groq-hosted chat model to answer ``query`` using only ``context``.

    Args:
        query: the user's question.
        context: retrieved text the model is instructed to rely on.
        model, temperature, max_tokens: chat-completion settings (defaults are
            the values previously hard-coded).
        api_key: Groq API key. If omitted, the ``GROQ_API_KEY`` global is used,
            then the ``GROQ_API_KEY`` environment variable.
        timeout: seconds to wait for the HTTP response.

    Returns:
        The assistant message content of the first choice.

    Raises:
        RuntimeError: if no API key is available or the API returns a
            non-200 status.
    """
    import os

    if api_key is None:
        api_key = globals().get("GROQ_API_KEY") or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "No Groq API key: set GROQ_API_KEY (variable or environment) "
            "or pass api_key=..."
        )

    url = "https://api.groq.com/openai/v1/chat/completions"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": "Answer strictly using the provided context. If the answer is not in the context, say 'Not found in context.'"
            },
            {
                "role": "user",
                "content": f"Question: {query}\n\nContext: {context}"
            }
        ],
        "temperature": temperature,
        "max_tokens": max_tokens
    }

    # Without a timeout, requests can block forever on a stalled connection.
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)

    if response.status_code != 200:
        raise RuntimeError(f"Groq API error {response.status_code}: {response.text}")

    return response.json()["choices"][0]["message"]["content"]


# Example 1

In [12]:
query = "Which drug treats a disease caused by genetic mutation?"

result = hybrid_retrieval(query)

# LLM context: the retrieved passages, plus the graph entities only when there
# are any (otherwise the prompt would end in a dangling "Related entities:").
context = " ".join(d["text"] for d in result["vector_docs"])
if result["graph_entities"]:
    context += " Related entities: " + ", ".join(result["graph_entities"])

answer = generate_answer(query, context)

print("=== Vector Retrieved Docs ===")
for d in result["vector_docs"]:
    print("-", d["text"])

print("\n=== Graph Entities ===")
print(result["graph_entities"])

print("\n=== Final Answer ===")
print(answer)

=== Vector Retrieved Docs ===
- Cancer is caused by genetic mutations. Chemotherapy treats cancer.
- Tuberculosis is caused by Mycobacterium tuberculosis. Rifampicin is a key drug in tuberculosis treatment.

=== Graph Entities ===
['Rifampicin']

=== Final Answer ===
Not found in context.


# Example 2

In [13]:
query = "What drug is used to treat diabetes?"

result = hybrid_retrieval(query)

# LLM context: the retrieved passages, plus the graph entities only when there
# are any (otherwise the prompt would end in a dangling "Related entities:").
context = " ".join(d["text"] for d in result["vector_docs"])
if result["graph_entities"]:
    context += " Related entities: " + ", ".join(result["graph_entities"])

answer = generate_answer(query, context)

print("=== Vector Retrieved Docs ===")
for d in result["vector_docs"]:
    print("-", d["text"])

print("\n=== Graph Entities ===")
print(result["graph_entities"])

print("\n=== Final Answer ===")
print(answer)

=== Vector Retrieved Docs ===
- Diabetes is caused by insulin resistance. Metformin is used to treat diabetes.
- Alzheimer’s disease is associated with amyloid-beta plaque accumulation. Donepezil is used to manage Alzheimer’s symptoms.

=== Graph Entities ===
['Metformin', 'Donepezil']

=== Final Answer ===
Metformin.
