# Part 5. Augmenting search results

In this part we add functionality for richer information retrieval and set up a simple RAG (retrieval-augmented generation) system running on a local Ollama server. The system acts as a medical assistant: given a disease name, it retrieves accurate information about that disease from the knowledge base.

## Tools

In [103]:
import pandas as pd
import json
import requests

from neo4j import Driver
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from utils.generic import get_driver, Models
from utils.index_search_helpers import combined_search

In [2]:
driver = get_driver()

In [104]:
embed_model = HuggingFaceEmbedding(model_name=Models.BAAI_BGE_SMALL_EN_V1_5.value)



In [3]:
df = pd.read_csv("../data/processed/ncbi_specific_disease_singular_id.csv", sep=",")

## Simple RAG with a local ollama server

Let us define a query that retrieves the definition of a disease node or, if the node has none, the definition of its closest ancestor in the `SUB_CATEGORY_OF` hierarchy, so that it can be returned to the user.

In [26]:
description_query = """
    OPTIONAL MATCH (d:Disease)
    WHERE (d.DiseaseID IS NOT NULL AND ANY(id IN SPLIT(toString(d.DiseaseID), '|') WHERE id = $disease_id))
        OR (d.AltDiseaseIDs IS NOT NULL AND ANY(altId IN SPLIT(toString(d.AltDiseaseIDs), '|') WHERE altId = $disease_id))
    OPTIONAL MATCH path = (d)-[:SUB_CATEGORY_OF*0..]->(ancestor:Disease)
    WITH d, ancestor, path
    WHERE ancestor.Definition IS NOT NULL
    RETURN 
        d.Definition AS Definition,
        ancestor.DiseaseName as Ancestor_name,
        ancestor.Definition AS Ancestor_description,
        length(path) as Distance
    ORDER BY length(path) ASC
    LIMIT $limit
"""

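def get_node_description_dict(node: dict, driver: Driver, limit: int = 10):
    """Return the closest available description for a candidate node, or None if nothing matches."""
    with driver.session() as session:
        result = session.run(
            description_query,
            disease_id=node["MESH_ID"],
            limit=limit,
        )
        record = next(iter(result), None)
        # Convert the neo4j Record into a plain dict (None when no row was returned)
        return record.data() if record is not None else None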
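The Cypher query does two things: it finds a node whose pipe-delimited `DiseaseID` or `AltDiseaseIDs` field contains the requested ID, and it walks `SUB_CATEGORY_OF` edges (starting from the zero-length path to the node itself) to the nearest node that has a `Definition`. The sketch below reproduces that logic in plain Python over a tiny in-memory graph, which can help when reasoning about what the query returns. The `nodes`/`parents` dictionaries, the IDs in them and the function names are hypothetical illustrations, not part of the project's code.

```python
from collections import deque


def matches_id(node: dict, disease_id: str) -> bool:
    """Mirror the WHERE clause: split pipe-delimited ID fields and look for an exact match."""
    for field in ("DiseaseID", "AltDiseaseIDs"):
        value = node.get(field)
        if value is not None and disease_id in str(value).split("|"):
            return True
    return False


def nearest_definition(start: str, nodes: dict, parents: dict):
    """Breadth-first walk up SUB_CATEGORY_OF edges, starting at distance 0 (the node itself).

    Returns (ancestor_name, ancestor_definition, distance) for the closest node
    with a non-null Definition, or None if no such node is reachable.
    """
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        key, distance = queue.popleft()
        node = nodes[key]
        if node.get("Definition") is not None:
            return node["DiseaseName"], node["Definition"], distance
        for parent in parents.get(key, []):
            if parent not in seen:
                seen.add(parent)
                queue.append((parent, distance + 1))
    return None


# Hypothetical toy graph: a child term without a definition under a defined parent
nodes = {
    "child": {"DiseaseID": "MESH:X1", "AltDiseaseIDs": "OMIM:1|OMIM:2",
              "DiseaseName": "Child disease", "Definition": None},
    "parent": {"DiseaseID": "MESH:X0", "DiseaseName": "Parent disease",
               "Definition": "A defined parent term."},
}
parents = {"child": ["parent"]}

matched = [key for key, node in nodes.items() if matches_id(node, "OMIM:2")]
print(matched)  # ['child']
print(nearest_definition("child", nodes, parents))  # ('Parent disease', 'A defined parent term.', 1)
```

Breadth-first search visits nodes in order of distance, so the first defined node it reaches corresponds to the first row of the query's `ORDER BY length(path) ASC`.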

In [94]:
def retrieve_entities(disease_name, disease_embedding, driver=driver, limit=5):
    # Candidate nodes from the combined (name + embedding) search
    candidates = combined_search(
        disease_name=disease_name,
        embedding=disease_embedding,
        driver=driver,
        limit=limit,
    )

    entities = []

    for candidate in candidates:
        description = get_node_description_dict(candidate, limit=5, driver=driver)
        if description is None:
            # No matching node, or no definition on the node or any of its ancestors
            continue
        if description.get("Definition") is None:
            # Fall back to the closest ancestor that has a definition
            entities.append({
                "DiseaseID": candidate.get("MESH_ID"),
                "Ancestor_name": description.get("Ancestor_name"),
                "Ancestor_description": description.get("Ancestor_description"),
                "Distance_from_ancestor": description.get("Distance"),
            })
        else:
            entities.append({
                "DiseaseID": candidate.get("MESH_ID"),
                "Definition": description.get("Definition"),
            })

    return entities
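The branching inside `retrieve_entities` can be factored into a pure function, which makes it easy to check without a running Neo4j instance. This is a sketch that assumes the description record behaves like a plain dict with the keys returned by `description_query`; `build_entity` is a hypothetical helper, and the `MESH:X1` ancestor values are made up.

```python
def build_entity(mesh_id, description):
    """Turn one description record into the entity dict used in the prompt.

    Returns None when there is no record, the node's own definition when it has one,
    and otherwise the nearest ancestor's name, definition and distance.
    """
    if description is None:
        return None
    if description.get("Definition") is None:
        return {
            "DiseaseID": mesh_id,
            "Ancestor_name": description.get("Ancestor_name"),
            "Ancestor_description": description.get("Ancestor_description"),
            "Distance_from_ancestor": description.get("Distance"),
        }
    return {"DiseaseID": mesh_id, "Definition": description.get("Definition")}


# Record values from the breast cancer example; the ancestor case uses made-up values
own = build_entity("MESH:D001943", {
    "Definition": "Tumors or cancer of the human BREAST.",
    "Ancestor_name": "Breast Neoplasms",
    "Ancestor_description": "Tumors or cancer of the human BREAST.",
    "Distance": 0,
})
fallback = build_entity("MESH:X1", {
    "Definition": None,
    "Ancestor_name": "Parent disease",
    "Ancestor_description": "A defined parent term.",
    "Distance": 1,
})
print(own)  # {'DiseaseID': 'MESH:D001943', 'Definition': 'Tumors or cancer of the human BREAST.'}
print(fallback)
```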

We can test it on a single disease name.

In [34]:
test_disease_name = df.iloc[270]["Description"]
test_disease_id = df.iloc[270]["MESH ID"]
test_disease_embedding = df.iloc[270]["DiseaseEmbedding-BAAI-bge-small-en-v1_5"]

In [35]:
print(test_disease_name)
print(test_disease_id)

breast cancer
MESH:D001943


In [39]:
# The embedding is stored in the CSV as a JSON-encoded list, so decode it first
test = retrieve_entities(disease_name=test_disease_name, disease_embedding=json.loads(test_disease_embedding), driver=driver, limit=1)

In [40]:
test

[{'MESH:D001943': {'Definition:': 'Tumors or cancer of the human BREAST.',
   'Ancestor_name': 'Breast Neoplasms',
   'Ancestor_description': 'Tumors or cancer of the human BREAST.',
   'Distance': 0}}]

Now that we have the retriever, we can write the code for a simple RAG system using a local Ollama server. We use the `llama3` model as the generator LLM with a moderate temperature (0.5) and a prompt that explains the task.

In [111]:
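def query_ollama(prompt, model="llama3", url="http://localhost:11434/api/generate"):
    payload = {
        "model": model,
        "prompt": prompt,
        # Ollama reads sampling settings from the "options" object, not from top-level keys
        "options": {
            "temperature": 0.5,
            "num_predict": 256,  # upper bound on the number of generated tokens
        },
    }

    # /api/generate streams newline-delimited JSON objects by default
    response = requests.post(url, json=payload, stream=True)

    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"

    parts = []
    for line in response.iter_lines():
        if line:
            json_line = json.loads(line.decode("utf-8"))
            part = json_line.get("response", "")
            print(part, end="", flush=True)  # print tokens in real time, without extra newlines
            parts.append(part)
    print()
    # Return the full answer so that callers (e.g. pipeline) receive the generated text
    return "".join(parts)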
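With streaming enabled (the default for Ollama's `/api/generate`), the server sends one JSON object per line; each carries a text fragment in `response`, and the last one has `done` set to `true`. The sketch below isolates that parsing step so it can be exercised on canned lines without a running server; `collect_stream` and the sample lines are illustrative, not output captured from Ollama.

```python
import json


def collect_stream(lines):
    """Join the "response" fragments of a newline-delimited JSON stream, stopping at the final chunk."""
    parts = []
    for raw in lines:
        if not raw:
            continue  # iter_lines() can yield empty keep-alive lines
        chunk = json.loads(raw.decode("utf-8") if isinstance(raw, bytes) else raw)
        parts.append(chunk.get("response", ""))
        if chunk.get("done"):
            break
    return "".join(parts)


sample = [
    b'{"model": "llama3", "response": "DiseaseID: ", "done": false}',
    b"",
    b'{"model": "llama3", "response": "MESH:D001943", "done": false}',
    b'{"model": "llama3", "response": "", "done": true}',
]
print(collect_stream(sample))  # DiseaseID: MESH:D001943
```

Inside `query_ollama` the same helper could be applied as `collect_stream(response.iter_lines())`, at the cost of losing the real-time printing.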

In [137]:
def build_prompt(disease_name, disease_embedding):
    entities = retrieve_entities(disease_name, disease_embedding)

    # Task instructions for the generator (no leading indentation inside the string)
    prompt = (
        f"Give me retrieved information on {disease_name}.\n"
        "Do not add anything else.\n"
        "If there are multiple entries with the identical information - then return only one.\n"
        "Always include the DiseaseID.\n\n"
    )

    # One "key: value" block per retrieved entity, separated by blank lines
    prompt += "\n\n".join(
        "\n".join(f"{k}: {v}" for k, v in entity.items()) for entity in entities
    )
    return prompt
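Because `build_prompt` calls the retriever, it needs the database to run. The sketch below separates the text-assembly step so the exact prompt can be inspected for a hand-written entity list; `render_prompt` is a hypothetical helper that reuses the instruction text of `build_prompt` and assumes entities are separated by blank lines. The first entity comes from the breast cancer example; the second is made up.

```python
def render_prompt(disease_name, entities):
    """Assemble the instruction text and one 'key: value' block per entity, separated by blank lines."""
    instructions = (
        f"Give me retrieved information on {disease_name}.\n"
        "Do not add anything else.\n"
        "If there are multiple entries with the identical information - then return only one.\n"
        "Always include the DiseaseID."
    )
    blocks = ["\n".join(f"{key}: {value}" for key, value in entity.items()) for entity in entities]
    return instructions + "\n\n" + "\n\n".join(blocks)


entities = [
    {"DiseaseID": "MESH:D001943", "Definition": "Tumors or cancer of the human BREAST."},
    {"DiseaseID": "MESH:X1", "Ancestor_name": "Parent disease", "Distance_from_ancestor": 1},
]
print(render_prompt("breast cancer", entities))
```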

Now we can combine everything into a single pipeline that:
- receives a disease name as a string;
- calls the embedding model to get the embedding of the name;
- retrieves the relevant nodes;
- retrieves their definitions;
- builds a prompt;
- passes the prompt to the LLM;
- returns the response to the user.

In [116]:
def pipeline(disease_name):
    disease_embedding = embed_model.get_text_embedding(disease_name)

    prompt = build_prompt(disease_name, disease_embedding)
    
    response = query_ollama(prompt)
    
    return response
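Each stage of the pipeline is a plain function call, so the stages can be swapped for stubs when testing the wiring without an embedding model, Neo4j or Ollama. A minimal sketch, assuming the stages are passed in as callables; `make_pipeline` and the stub functions are hypothetical.

```python
from typing import Callable, List


def make_pipeline(
    embed: Callable[[str], List[float]],
    make_prompt: Callable[[str, List[float]], str],
    generate: Callable[[str], str],
) -> Callable[[str], str]:
    """Compose embedding, retrieval-backed prompt construction and generation into one function."""

    def run(disease_name: str) -> str:
        embedding = embed(disease_name)                # 1. embed the query string
        prompt = make_prompt(disease_name, embedding)  # 2. retrieve entities and build the prompt
        return generate(prompt)                        # 3. ask the LLM

    return run


# Hypothetical stubs standing in for the embedding model, the retriever and the LLM
def fake_embed(text: str) -> List[float]:
    return [float(len(text))]


def fake_prompt(name: str, embedding: List[float]) -> str:
    return f"Give me retrieved information on {name}. ({embedding[0]:.0f} chars)"


def echo_llm(prompt: str) -> str:
    return prompt.upper()


run = make_pipeline(fake_embed, fake_prompt, echo_llm)
print(run("breast cancer"))  # GIVE ME RETRIEVED INFORMATION ON BREAST CANCER. (13 CHARS)
```

With the notebook's own objects, the same wiring would read `make_pipeline(embed_model.get_text_embedding, build_prompt, query_ollama)`.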

In [138]:
test_disease_name_res = pipeline(test_disease_name)

Here is the retrieved information on breast cancer:

DiseaseID: MESH:D001943
Definition: Tumors or cancer of the human BREAST.

Let us try another one.

In [118]:
test_disease_name_2 = df.iloc[2740]["Description"]
test_disease_id_2 = df.iloc[2740]["MESH ID"]

In [119]:
test_disease_name_2

'myotonic dystrophy'

In [120]:
test_disease_id_2

'MESH:D009223'

In [139]:
test_disease_name_res_2 = pipeline(test_disease_name_2)

DiseaseID: MESH:D009223
Definition: Neuromuscular disorder characterized by PROGRESSIVE MUSCULAR ATROPHY; MYOTONIA, and various multisystem atrophies. Mild INTELLECTUAL DISABILITY may also occur. Abnormal TRINUCLEOTIDE REPEAT EXPANSION in the 3' UNTRANSLATED REGIONS of DMPK PROTEIN gene is associated with Myotonic Dystrophy 1. DNA REPEAT EXPANSION of zinc finger protein-9 gene intron is associated with Myotonic Dystrophy 2.

## Summary

In the previous parts we experimented with different setups for entity linking and retrieval; this notebook takes the best configuration and builds a medical assistant on top of it that retrieves information about diseases. Because the knowledge base is a knowledge graph, the assistant can still return a relevant description when the linked node has none, by falling back to its nearest ancestor that does. This is just one example of how the entity linking we developed can be used, and it has limitations: the answer can only be as good as the candidates returned by the search and the definitions stored in the graph, and a definition taken from an ancestor describes a broader category rather than the disease itself. Still, we believe such systems have great potential and plenty of room for improvement.