# RAG on FHIR with Standard Deviation 

This notebook shows how to use Standard Deviation instead of K-Nearest Neighbor to find resources to supply for RAG. 

This notebook assumes you have already loaded data into the Knowledge Graph as per the notebook [FHIR_GRAPHS](https://github.com/samschifman/RAG_on_FHIR/blob/main/RAG_on_FHIR_with_KG/FHIR_GRAPHS.ipynb). This notebook is not intended to be run on its own. 

## Disclaimer
Nothing provided here is guaranteed or warrantied to work. It is provided as is and has not been tested extensively. Using this notebook is at the risk of the user. 

## Install and Import Libraries

In [None]:
!pip install sentence-transformers langchain neo4j

In [None]:
import os

from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOllama
from langchain.prompts import PromptTemplate

# Imports from other local python files
from NEO4J_Graph import Graph

## Establish Database Connection

This cell connects to the Neo4j instance. It relies on several environment variables. 

**PLEASE NOTE**: The variables have been changed to support multiple databases in the same instance. 

| Variable            | Description                          | Sample Value          |
|---------------------|--------------------------------------|-----------------------|
| FHIR_GRAPH_URL      | Where to find the instance of Neo4j. | bolt://localhost:7687 |
| FHIR_GRAPH_USER     | The username for the database.       | neo4j                 |
| FHIR_GRAPH_PASSWORD | The password for the database.       | password              |
| FHIR_GRAPH_DATABASE | The name of the database instance.   | neo4j                 |

In [None]:
NEO4J_URI = os.getenv('FHIR_GRAPH_URL')
USERNAME = os.getenv('FHIR_GRAPH_USER')
PASSWORD = os.getenv('FHIR_GRAPH_PASSWORD')
DATABASE = os.getenv('FHIR_GRAPH_DATABASE')

graph = Graph(NEO4J_URI, USERNAME, PASSWORD, DATABASE)
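
Because `os.getenv` returns `None` for an unset variable, a quick check can catch a missing setting before the connection fails with a less obvious error. The following is a minimal sketch; `missing_env_vars` and `REQUIRED_VARS` are hypothetical names introduced here and are not part of `NEO4J_Graph`.

```python
import os

# The four variables listed in the table above.
REQUIRED_VARS = (
    'FHIR_GRAPH_URL',
    'FHIR_GRAPH_USER',
    'FHIR_GRAPH_PASSWORD',
    'FHIR_GRAPH_DATABASE',
)


def missing_env_vars(names=REQUIRED_VARS, env=None):
    """Return the names that are unset or empty in the given environment."""
    env = os.environ if env is None else env
    return [name for name in names if not env.get(name)]


missing = missing_env_vars()
if missing:
    print(f'Missing environment variables: {", ".join(missing)}')
```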

## Setup Prompt Templates

This cell sets the prompt template to use when calling the LLM.

In [None]:
my_prompt = '''
System: The context below contains entries about the patient's healthcare. 
Please limit your answer to the information provided in the context. Do not make up facts. 
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
{context}
Human: {question}
'''

prompt = PromptTemplate.from_template(my_prompt)
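
To preview roughly what the LLM will receive, the two placeholders can be filled in by hand. This sketch uses plain `str.format` on a shortened copy of the template so that it runs without LangChain; calling `prompt.format(context=..., question=...)` on the `PromptTemplate` should give an equivalent substitution. The context and question values are made up for illustration.

```python
# Shortened copy of the template above; the placeholders are the same.
template = '''
System: The context below contains entries about the patient's healthcare.
----------------
{context}
Human: {question}
'''

# Made-up values, for illustration only.
filled = template.format(
    context='Blood pressure: 120/80 mmHg',
    question='What are the blood pressure readings?',
)
print(filled)
```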

## Define Helper Methods for Showing Results

These methods are used later to make it easy to see what will be sent to the LLM and to ask it questions. 

In [None]:
def show_similar(_question, _vector_index):
    response = _vector_index.similarity_search(_question)
    for page in response:
        print(page.page_content)
        print(' ')
    print(' ')
    print(f'Total number of responses: {len(response)}')


def show_answers(_question, _vector_qa, number_of_times):
    print('Answering...')
    for i in range(number_of_times):
        answer = _vector_qa.run(_question)
        print(f'Answer {i+1}:')
        print(answer)
        print(' ')

## Pick the LLM model to use

Ollama can run multiple models. I had the most luck with `mistral`, but you could try others. 
The available models are listed [here](https://ollama.ai/library).

In [None]:
ollama_model = 'mistral'

## Define Standard Deviation Wrapper

This cell wraps a `Neo4jVector` store in a class that calculates the standard deviation of the similarity scores and returns only the results whose score is within 1 Std Dev of the closest match. By default it looks at the 1,000 nearest neighbors to calculate the Std Dev. 
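
To make the filtering rule concrete, here is a minimal sketch on made-up similarity scores. It uses `statistics.pstdev` from the standard library, which computes the same population standard deviation that `np.std` computes by default. The scores and the helper name `within_one_std` are illustrative assumptions, not values from the FHIR data.

```python
from statistics import pstdev


def within_one_std(scores):
    """Keep scores no more than one population Std Dev below the best score.

    Assumes `scores` is sorted from most to least similar, as a vector
    index would return them.
    """
    if not scores:
        return []
    threshold = scores[0] - pstdev(scores)
    # >= keeps the best score itself even when every score is identical.
    return [s for s in scores if s >= threshold]


# Hypothetical similarity scores: two close matches and a long tail.
scores = [0.92, 0.90, 0.75, 0.60, 0.58, 0.55]
kept = within_one_std(scores)
print(kept)  # only the two close matches survive
```

Unlike a fixed `k`, the number of results kept depends on how the scores are spread: a tight cluster at the top followed by a long tail yields few results, while a gradual decline yields more.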

Optionally, the wrapper can also deduplicate the values returned. This can be useful when using this in conjunction with a `retrieval_query`, which is not shown here. I leave it to you to combine this with what is shown in [FHIR_GRAPHS](https://github.com/samschifman/RAG_on_FHIR/blob/main/RAG_on_FHIR_with_KG/FHIR_GRAPHS.ipynb), but once they are combined, being able to remove duplicates is key. 
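
The deduplication idea can be sketched on plain strings: split each result on an `Entry:\n` marker, then drop any entry that has already been seen in an earlier result. The helper name `unique_entries` and the sample text are hypothetical; the real text produced by a `retrieval_query` may look different.

```python
def unique_entries(seen, text, marker="Entry:\n"):
    """Return the entries in `text` not already in `seen`, and record them."""
    entries = [e for e in text.split(marker) if e and e not in seen]
    seen.extend(entries)
    return entries


seen = []
# Two made-up results that share the same supporting patient entry.
first = unique_entries(seen, "Entry:\nBlood pressure 120/80\nEntry:\nPatient: Jane Doe\n")
second = unique_entries(seen, "Entry:\nHeart rate 70\nEntry:\nPatient: Jane Doe\n")
print(first)
print(second)
```

The shared patient entry is kept in the first result and dropped from the second, so the context sent to the LLM does not repeat it.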

In [None]:
from typing import List, Optional, Any, Iterable, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

import numpy as np


def deduplicate(existing_entries: List, doc: Document) -> Document:
    entries = doc.page_content.split("Entry:\n")
    entries = list(filter(lambda e: e and e not in existing_entries, entries))
    existing_entries += entries
    doc.page_content = 'Primary Entry:\n' + 'Supporting Entry:\n'.join(entries)
    return doc


class StandardDevNeo4jVector(VectorStore):

    def __init__(self, vector_store: Neo4jVector, should_deduplicate:bool = False):
        self._vector_store = vector_store
        self._deduplicate = should_deduplicate

    def add_texts(self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, **kwargs: Any) -> List[str]:
        # Delegate to the wrapped store, forwarding all arguments unchanged.
        return self._vector_store.add_texts(texts, metadatas, **kwargs)

    def similarity_search(self, query: str, k: int = 1000, **kwargs: Any) -> List[Document]:
        # Fetch the k nearest neighbors with scores; the best match is assumed to come first.
        docs = self._vector_store.similarity_search_with_score(query=query, k=k, **kwargs)
        if not docs:
            return []
        scores = [score for _, score in docs]
        standard_deviation = np.std(scores)
        threshold = scores[0] - standard_deviation

        # Keep documents scoring within one Std Dev of the best match; using >=
        # keeps the best match itself even when the Std Dev is zero.
        filtered_docs = [doc for doc, score in docs if score >= threshold]

        if self._deduplicate:
            existing = list()
            for i, doc in enumerate(filtered_docs):
                filtered_docs[i] = deduplicate(existing, doc)
        return filtered_docs

    @classmethod
    def from_texts(cls: Type[VectorStore],
                   texts: List[str],
                   embedding: Embeddings,
                   metadatas: Optional[List[dict]] = None,
                   **kwargs: Any) -> VectorStore:
        return StandardDevNeo4jVector(Neo4jVector.from_texts(texts, embedding, metadatas, **kwargs))


## Create Vector Index Reference and QA Chain

Again, here I assume you have run the code in [FHIR_GRAPHS](https://github.com/samschifman/RAG_on_FHIR/blob/main/RAG_on_FHIR_with_KG/FHIR_GRAPHS.ipynb) to create the `fhir_text` index. This code wraps that index in a Std Dev wrapper and passes that to the QA chain. 

In [None]:
from langchain_community.embeddings import HuggingFaceBgeEmbeddings


vector_index = StandardDevNeo4jVector(Neo4jVector.from_existing_index(
    HuggingFaceBgeEmbeddings(model_name='BAAI/bge-small-en-v1.5'),
    url=NEO4J_URI,
    username=USERNAME,
    password=PASSWORD,
    database=DATABASE,
    index_name='fhir_text'
))


vector_qa = RetrievalQA.from_chain_type(
    llm=ChatOllama(model=ollama_model),
    chain_type='stuff',
    retriever=vector_index.as_retriever(),
    verbose=False,
    chain_type_kwargs={'verbose': False, 'prompt': prompt}
)

## Ask a Question of the RAG

This cell defines a question and then asks it. It will first show what is returned from the Std Dev Vector Index and then ask the question of the LLM, including those results in the prompt. 

In [None]:
question = "What are the blood pressure readings?"

show_similar(question, vector_index)
print(' ')
print(' ')

show_answers(question, vector_qa, 1)


Copyright &copy; 2024 Sam Schifman