# Part 5. Augmenting search results

In this part we build additional functionality for richer information retrieval and set up a simple RAG (retrieval-augmented generation) system running on a local Ollama server. The goal of the system is to act as a medical assistant that retrieves accurate information about a given disease from a knowledge base.

## Tools

In [66]:
import pandas as pd
from IPython.display import display, Markdown
import json
import requests
import pprint
import math

from neo4j import Driver
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from utils.generic import get_driver, Models
from utils.index_search_helpers import combined_search

In [2]:
driver = get_driver()

In [3]:
embed_model = HuggingFaceEmbedding(model_name=Models.BAAI_BGE_SMALL_EN_V1_5.value)



In [4]:
df = pd.read_csv("../data/processed/ncbi_specific_disease_singular_id.csv", sep=",")

## Simple RAG with a local ollama server

Let us define a query that retrieves the description of a disease (or, if the disease node has none, the description of its closest ancestor) so that it can be returned to the user.

In [10]:
description_query = """
    OPTIONAL MATCH (d:Disease)
    WHERE (d.DiseaseID IS NOT NULL AND ANY(id IN SPLIT(toString(d.DiseaseID), '|') WHERE id = $disease_id))
        OR (d.AltDiseaseIDs IS NOT NULL AND ANY(altId IN SPLIT(toString(d.AltDiseaseIDs), '|') WHERE altId = $disease_id))
    OPTIONAL MATCH path = (d)-[:SUB_CATEGORY_OF*0..]->(ancestor:Disease)
    WITH d, ancestor, path
    WHERE ancestor.Definition IS NOT NULL
    RETURN 
        d.Definition AS Definition,
        d.DiseaseName AS DiseaseName,
        ancestor.DiseaseName as AncestorName,
        ancestor.Definition AS AncestorDefinition,
        length(path) as Distance
    ORDER BY length(path) ASC
    LIMIT $limit
"""

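def get_node_description_dict(node: dict, driver: Driver, limit: int = 10):
    """Return the description row (as a dict) for the candidate's MESH_ID, or None."""
    with driver.session() as session:
        result = session.run(
            description_query,
            disease_id=node["MESH_ID"],
            limit=limit
        )
        # Rows are ordered by path length, so the first row holds the closest definition.
        # No row is returned if neither the node nor any ancestor has a Definition.
        record = next(iter(result), None)
        return record.data() if record is not None else None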
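The Cypher query does two things: it matches the requested ID against the `|`-separated `DiseaseID` and `AltDiseaseIDs` strings, and it walks `SUB_CATEGORY_OF` edges upwards (starting at path length 0, i.e. the node itself), keeping the closest node that has a `Definition`. The plain-Python sketch below mimics that logic on a toy in-memory hierarchy. `matches_id`, `nearest_definition` and the `parents`/`definitions` dictionaries are hypothetical helpers for illustration only; they are not part of `utils`.

```python
from collections import deque
from typing import Optional


def matches_id(stored_ids: Optional[str], disease_id: str) -> bool:
    """Mirror of the Cypher SPLIT/ANY check: IDs are stored as one '|'-separated string."""
    if stored_ids is None:
        return False
    return disease_id in str(stored_ids).split("|")


def nearest_definition(start: str, parents: dict, definitions: dict) -> Optional[tuple]:
    """Breadth-first walk up SUB_CATEGORY_OF edges (distance 0, 1, 2, ...).

    Returns (node, definition, distance) for the closest node with a definition,
    similar to ORDER BY length(path) ASC followed by taking the first row.
    """
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        node, distance = queue.popleft()
        definition = definitions.get(node)
        if definition is not None:
            return node, definition, distance
        for parent in parents.get(node, []):
            if parent not in seen:
                seen.add(parent)
                queue.append((parent, distance + 1))
    return None  # no definition anywhere up the hierarchy


# Toy hierarchy: C is a sub-category of B, which is a sub-category of A
parents = {"C": ["B"], "B": ["A"], "A": []}
print(nearest_definition("C", parents, {"A": "root definition"}))
```

Breadth-first order guarantees that the first node found with a definition is one of the closest, which is what sorting the Cypher paths by length achieves.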

In [70]:
def retrieve_entities(disease_name, disease_embedding, driver=driver, limit=5):
    candidates = combined_search(
        disease_name=disease_name,
        embedding=disease_embedding,
        driver=driver,
        limit=limit
        )
    
    entities = []
    
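    for candidate in candidates:
        description = get_node_description_dict(candidate, limit=limit, driver=driver)
        if description is None:
            # Neither the node nor any of its ancestors has a definition: skip it
            continue
        if (pd.isna(description.get("Definition"))):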
            entities.append({
                    "DiseaseID": candidate.get("MESH_ID"),
                    "DiseaseName": description.get("DiseaseName"),
                    "AncestorName": description.get("AncestorName"),
                    "AncestorDefinition": description.get("AncestorDefinition"),
                    "Distance": description.get("Distance"),
                })
        else:
            entities.append({
                    "DiseaseID": candidate.get("MESH_ID"),
                    "DiseaseName": description.get("DiseaseName"),
                    "Definition": description.get("Definition"),
                })
    
    return entities
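The branching in `retrieve_entities` can be isolated into a small pure function, which makes it testable without a database. The sketch below is a hypothetical refactoring (`to_entity` and `is_missing` are not in the original code); it produces the same dictionary shapes and skips candidates for which the description query returned no row.

```python
import math
from typing import Optional


def is_missing(value) -> bool:
    """True for None or a float NaN (how an absent Definition may show up)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def to_entity(candidate: dict, description: Optional[dict]) -> Optional[dict]:
    """Shape one retrieved candidate the same way retrieve_entities does."""
    if description is None:  # no matching row in the graph
        return None
    entity = {
        "DiseaseID": candidate.get("MESH_ID"),
        "DiseaseName": description.get("DiseaseName"),
    }
    if is_missing(description.get("Definition")):
        # Fall back to the closest ancestor that has a definition
        entity["AncestorName"] = description.get("AncestorName")
        entity["AncestorDefinition"] = description.get("AncestorDefinition")
        entity["Distance"] = description.get("Distance")
    else:
        entity["Definition"] = description.get("Definition")
    return entity


print(to_entity({"MESH_ID": "MESH:D001943"},
                {"DiseaseName": "Breast Neoplasms",
                 "Definition": "Tumors or cancer of the human BREAST."}))
```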

We can test it on a single disease name.

In [7]:
test_disease_name = df.iloc[270]["Description"]
test_disease_id = df.iloc[270]["MESH ID"]
test_disease_embedding = df.iloc[270]["DiseaseEmbedding-BAAI-bge-small-en-v1_5"]

In [8]:
print(test_disease_name)
print(test_disease_id)

breast cancer
MESH:D001943


In [59]:
test = retrieve_entities(disease_name=test_disease_name, disease_embedding=json.loads(test_disease_embedding), driver=driver, limit=1)

In [60]:
test

[{'DiseaseID': 'MESH:D001943',
  'DiseaseName': 'Breast Neoplasms',
  'Definition': 'Tumors or cancer of the human BREAST.'}]

Now that we have the retriever, we can write the code for a simple RAG system using a local Ollama server. We use the llama3 model as the generator LLM, with a moderate temperature (0.5) and a prompt that explains the task.

In [37]:
def query_ollama(prompt, model="llama3", url="http://localhost:11434/api/generate"):
    payload = {
        "model": model,
        "prompt": prompt,
        # Ollama reads sampling settings from "options"; num_predict caps the output length
        "options": {"temperature": 0.5, "num_predict": 1000},
    }

    # The endpoint streams newline-delimited JSON objects by default
    response = requests.post(url, json=payload, stream=True)

    if response.status_code != 200:
        return f"Error: {response.status_code}, {response.text}"

    full_response = ""
    for line in response.iter_lines():
        if line:
            json_line = json.loads(line.decode("utf-8"))
            part = json_line.get("response", "")
            # Double the newlines so Markdown keeps the line breaks
            full_response += part.replace("\n", "\n\n")

    # Render the accumulated response as Markdown and also return it to the caller
    display(Markdown(full_response))
    return full_response
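According to the Ollama REST API documentation, sampling parameters such as `temperature` and the token limit (`num_predict`) belong inside an `options` object of the `/api/generate` request; `max_tokens` is not one of its fields. The sketch below separates payload construction from stream parsing so both can be checked without a running server. `build_generate_payload` and `collect_stream` are hypothetical helpers, not part of the notebook.

```python
import json
from typing import Iterable


def build_generate_payload(prompt: str, model: str = "llama3",
                           temperature: float = 0.5, num_predict: int = 1000) -> dict:
    """Request body for Ollama's /api/generate endpoint (sampling settings go in "options")."""
    return {
        "model": model,
        "prompt": prompt,
        "stream": True,
        "options": {"temperature": temperature, "num_predict": num_predict},
    }


def collect_stream(lines: Iterable[bytes]) -> str:
    """Concatenate the "response" fields of a newline-delimited JSON stream."""
    parts = []
    for line in lines:
        if not line:
            continue  # skip empty keep-alive lines
        chunk = json.loads(line)
        parts.append(chunk.get("response", ""))
        if chunk.get("done"):
            break  # the final object is flagged with "done": true
    return "".join(parts)


# Simulated stream, shaped like the lines yielded by requests' iter_lines()
fake_stream = [b'{"response": "Breast ", "done": false}', b'',
               b'{"response": "Neoplasms", "done": false}', b'{"response": "", "done": true}']
print(collect_stream(fake_stream))
```

With `requests`, the two pieces combine as `collect_stream(requests.post(url, json=build_generate_payload(prompt), stream=True).iter_lines())`.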

In [72]:
def build_prompt(disease_name, disease_embedding):
    entities = retrieve_entities(disease_name, disease_embedding)
    
    prompt = f"""Act as a medical expert. Return the retrieved information on a given {disease_name}.   
    If there are multiple entries with the identical information - then return only one.
    If there is no Definition, then return the AncestorDefinition.
    Always include the DiseaseID and DiseaseName.\n\n"""

    # Using a set to track and eliminate duplicates based on DiseaseID and DiseaseName
    seen = set()
    
    for entity in entities:
        disease_id = entity.get('DiseaseID', 'Unknown')
        # A separate name avoids shadowing the disease_name argument
        entity_name = entity.get('DiseaseName', 'Unknown')
        
        # Avoiding duplicates
        if (disease_id, entity_name) in seen:
            continue
        seen.add((disease_id, entity_name))
        
        entry = f"DiseaseID: {disease_id}\nDiseaseName: {entity_name}"
        
        definition = entity.get('Definition')
        if definition and not (isinstance(definition, float) 
                               and math.isnan(definition)):
            entry += f"\nDefinition: {definition}"
        else:
            # If no Definition, use AncestorDefinition
            ancestor_definition = entity.get('AncestorDefinition')
            if ancestor_definition and not (isinstance(ancestor_definition, float) 
                                            and math.isnan(ancestor_definition)):
                entry += f"\nAncestorDefinition: {ancestor_definition}"
        
        # Adding Distance for completeness
        distance = entity.get('Distance', 'Unknown')
        entry += f"\nDistance: {distance}\n"
        
        prompt += "\n" + entry

    return prompt
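To see exactly what context the LLM receives, the per-entity formatting and de-duplication of `build_prompt` can be reproduced in isolation. `format_entry`, `format_context` and `_present` below are hypothetical helpers written to match the logic above; they are not used elsewhere in the notebook.

```python
import math


def _present(value) -> bool:
    """Treat None, empty strings and float NaN as missing, like build_prompt does."""
    return bool(value) and not (isinstance(value, float) and math.isnan(value))


def format_entry(entity: dict) -> str:
    entry = (f"DiseaseID: {entity.get('DiseaseID', 'Unknown')}\n"
             f"DiseaseName: {entity.get('DiseaseName', 'Unknown')}")
    if _present(entity.get("Definition")):
        entry += f"\nDefinition: {entity['Definition']}"
    elif _present(entity.get("AncestorDefinition")):
        # No own definition: fall back to the closest ancestor's definition
        entry += f"\nAncestorDefinition: {entity['AncestorDefinition']}"
    entry += f"\nDistance: {entity.get('Distance', 'Unknown')}\n"
    return entry


def format_context(entities: list) -> str:
    seen = set()
    blocks = []
    for entity in entities:
        key = (entity.get("DiseaseID", "Unknown"), entity.get("DiseaseName", "Unknown"))
        if key in seen:  # drop duplicate (DiseaseID, DiseaseName) pairs
            continue
        seen.add(key)
        blocks.append("\n" + format_entry(entity))
    return "".join(blocks)


breast = {"DiseaseID": "MESH:D001943", "DiseaseName": "Breast Neoplasms",
          "Definition": "Tumors or cancer of the human BREAST."}
print(format_context([breast, breast]))
```

Because entities that carry their own definition have no `Distance` key, their entry ends with `Distance: Unknown`.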

We can now combine these pieces into a single pipeline that:
1. Receives a disease name as a string.
2. Uses the embedding model to generate an embedding for the disease name.
3. Retrieves relevant nodes from the knowledge graph.
4. Fetches the definitions of the retrieved nodes (or of their closest ancestors).
5. Constructs a prompt from the retrieved information.
6. Passes this prompt to the Large Language Model (LLM).
7. Returns the generated response to the user.


In [15]:
def pipeline(disease_name):
    disease_embedding = embed_model.get_text_embedding(disease_name)

    prompt = build_prompt(disease_name, disease_embedding)
    
    response = query_ollama(prompt)
    
    return response
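The seven steps can also be written with each stage passed in as a function, so the data flow can be exercised with stubs, without Neo4j or a running Ollama server. This is a hypothetical variant: `run_pipeline` is not part of the notebook, and using it with the real components (`embed_model.get_text_embedding`, `retrieve_entities`, `query_ollama`) would require a prompt builder that takes the retrieved entities as input instead of retrieving them itself.

```python
from typing import Callable, List, Sequence


def run_pipeline(disease_name: str,
                 embed: Callable[[str], Sequence[float]],
                 retrieve: Callable[[str, Sequence[float]], List[dict]],
                 make_prompt: Callable[[str, List[dict]], str],
                 generate: Callable[[str], str]) -> str:
    embedding = embed(disease_name)               # step 2: embed the query
    entities = retrieve(disease_name, embedding)  # steps 3-4: nodes and definitions
    prompt = make_prompt(disease_name, entities)  # step 5: build the prompt
    return generate(prompt)                       # steps 6-7: call the LLM, return text


# Stubs stand in for the embedding model, the graph and the LLM
answer = run_pipeline(
    "breast cancer",
    embed=lambda text: [0.1, 0.2],
    retrieve=lambda name, emb: [{"DiseaseID": "MESH:D001943", "DiseaseName": "Breast Neoplasms"}],
    make_prompt=lambda name, ents: f"Describe {name}: " + ", ".join(e["DiseaseName"] for e in ents),
    generate=lambda prompt: prompt.upper(),
)
print(answer)
```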

In [73]:
test_disease_name_res = pipeline(test_disease_name)

As a medical expert, I can provide you with the retrieved information on breast cancer. Since there are multiple entries with identical information, I will return only one.



**Retrieved Information:**



* **DiseaseID:** MESH:D001943

* **DiseaseName:** Breast Neoplasms

* **Definition:** Tumors or cancer of the human BREAST.

* (No additional information provided)



This information is a general definition of breast neoplasms, which encompasses various types of breast cancer.

Let us try another one.

In [17]:
test_disease_name_2 = df.iloc[2740]["Description"]
test_disease_id_2 = df.iloc[2740]["MESH ID"]

In [18]:
test_disease_name_2

'myotonic dystrophy'

In [19]:
test_disease_id_2

'MESH:D009223'

Let us first look up the stored definition so that we can compare it with how the RAG system presents it.

In [49]:
def get_node_by_id(disease_id, driver=driver):
    with driver.session() as session:
        result = session.run(
            """OPTIONAL MATCH (d:Disease)
                WHERE (d.DiseaseID IS NOT NULL AND ANY(id IN SPLIT(toString(d.DiseaseID), '|') 
                    WHERE id = $disease_id))
                OR (d.AltDiseaseIDs IS NOT NULL AND ANY(altId IN SPLIT(toString(d.AltDiseaseIDs), '|')
                    WHERE altId = $disease_id))
                RETURN d.DiseaseName AS DiseaseName, d.Definition AS Definition
            """, disease_id = disease_id)

        pprint.pprint(result.data()[0])

In [50]:
get_node_by_id(test_disease_id_2)

{'Definition': 'Neuromuscular disorder characterized by PROGRESSIVE MUSCULAR '
               'ATROPHY; MYOTONIA, and various multisystem atrophies. Mild '
               'INTELLECTUAL DISABILITY may also occur. Abnormal TRINUCLEOTIDE '
               "REPEAT EXPANSION in the 3' UNTRANSLATED REGIONS of DMPK "
               'PROTEIN gene is associated with Myotonic Dystrophy 1. DNA '
               'REPEAT EXPANSION of zinc finger protein-9 gene intron is '
               'associated with Myotonic Dystrophy 2.',
 'DiseaseName': 'Myotonic Dystrophy'}


In [38]:
test_disease_name_res_2 = pipeline(test_disease_name_2)

As a medical expert, I can provide you with the retrieved information on Myotonic Dystrophy:



**DiseaseID:** MESH:D009223

**DiseaseName:** Myotonic Dystrophy



**Definition:**

Myotonic Dystrophy is a neuromuscular disorder characterized by progressive muscular atrophy, myotonia, and various multisystem atrophies. Mild intellectual disability may also occur.



**Associated Genetic Abnormalities:**



* Abnormal trinucleotide repeat expansion in the 3' untranslated regions of DMPK protein gene is associated with Myotonic Dystrophy 1.

* DNA repeat expansion of zinc finger protein-9 gene intron is associated with Myotonic Dystrophy 2.



This information provides a comprehensive overview of the disease, including its definition and genetic associations.

This is a good example of how the LLM can take the retrieved information and reshape it into a more user-friendly form. Let us check how the system behaves when a node has no definition of its own.

In [51]:
get_node_by_id("MESH:C563535")

{'Definition': nan, 'DiseaseName': 'Myotonic Myopathy with Cylindrical Spirals'}


In [76]:
test_disease_name_res_3 = pipeline("Myotonic Myopathy with Cylindrical Spirals")

As a medical expert, I've retrieved the information on Myotonic Myopathy with Cylindrical Spirals. Here's what I found:



**DiseaseID:** MESH:C563535

**DiseaseName:** Myotonic Myopathy with Cylindrical Spirals



**Definition:**

Myotonic myopathy with cylindrical spirals is a rare, autosomal dominant disorder characterized by slow relaxation of muscle fibers, leading to muscle stiffness, wasting, and weakness. This condition is also known for its unique histopathological feature of cylindrical spiral-shaped structures in the affected muscles.



**Additional Information:**

Myotonic myopathy with cylindrical spirals typically presents with progressive muscle weakness, wasting, and stiffness, affecting various muscle groups, including those involved in walking, grasping, and swallowing. The condition is usually diagnosed based on clinical findings, electromyography (EMG), and muscle biopsy. There is currently no cure for this disorder, but various treatments, such as physical therapy, medications, and surgery, may help manage its symptoms.



Please note that the information provided is a summary of the available data and should not be considered a substitute for professional medical advice or diagnosis. If you have any further questions or concerns, I recommend consulting with a qualified healthcare provider.

## Summary

We experimented with different setups for entity linking and retrieval in the previous parts; this notebook takes the best configuration and builds a medical assistant on top of it that retrieves relevant information about diseases. Because the knowledge base is a knowledge graph, we can still return useful information when a node has no definition of its own by falling back to its closest ancestor. This is just one example of how the entity linking we developed can be used, and it has limitations: for instance, the generator may add details that are not in the retrieved context, as the answer for Myotonic Myopathy with Cylindrical Spirals likely does, since that node has no stored definition. Nevertheless, we believe such a system has great potential for future use and plenty of room for improvement.