# RAG on FHIR: Using NLP to load Observations

This notebook uses [John Snow Labs SparkNLP for Healthcare](https://www.johnsnowlabs.com/) to find blood pressure readings in clinical notes, create FHIR Observation resources, and load them as nodes in a Neo4J graph. It looks for the clinical notes in DocumentReference resources that are already in the graph. 

This notebook assumes you have already loaded data into the Knowledge Graph as per the notebook [FHIR_GRAPHS](https://github.com/samschifman/RAG_on_FHIR/blob/main/RAG_on_FHIR_with_KG/FHIR_GRAPHS.ipynb). This notebook is not intended to be run on its own. 

This notebook is intended only as an example of what could be done. It is not a full implementation. For example, it assumes the date of the DocumentReference should be the date of the resulting Observation, regardless of any date found in the note. 

## Disclaimer
Nothing provided here is guaranteed or warrantied to work. It is provided as is and has not been tested extensively. Using this notebook is at the risk of the user. 

In [None]:
# Imports needed

import json
import os
import base64

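import sparknlp
import sparknlp_jsl
from sparknlp.base import *           # DocumentAssembler
from sparknlp.annotator import *      # Tokenizer, PerceptronModel, WordEmbeddingsModel, DependencyParserModel, ...
from sparknlp_jsl.annotator import *  # MedicalNerModel, NerConverterInternal, RelationExtractionModel, ...
from sparknlp_display import RelationExtractionVisualizer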

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.sql import functions as F


import pandas as pd
pd.set_option('display.max_colwidth', 200)

import warnings
warnings.filterwarnings('ignore')

# Imports from other local python files
from NEO4J_Graph import Graph
from FHIR_to_graph import resource_to_node, resource_to_edges, flat_fhir_to_json_str, flatten_fhir, resource_name, FHIR_to_string

import uuid
import re


## Load License from John Snow Labs

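This cell loads the license for SparkNLP for Healthcare. It assumes the license JSON file is at `working/license/spark_nlp.json`. You will need to create that directory and put your license file there, or change the path. 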

Any method of loading the license into the environment described in JSL's documentation should work. 

In [None]:
with open('working/license/spark_nlp.json') as f:
    license_keys = json.load(f)

os.environ.update(license_keys)
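
If the license file is missing a key, the failure only shows up later, in the install or Spark cells. A small check at load time can surface it earlier. The sketch below is one way to do that; `load_license` and `REQUIRED_LICENSE_KEYS` are hypothetical names, and the key list is an assumption based on which keys the later cells read (`SECRET`, `PUBLIC_VERSION`, `JSL_VERSION`).

```python
import json
import os

# Assumed from how this notebook uses the license: SECRET (Spark session and
# pip index URL), PUBLIC_VERSION and JSL_VERSION (pip install cell).
# Real JSL license files may contain additional keys.
REQUIRED_LICENSE_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")

def load_license(path, required=REQUIRED_LICENSE_KEYS):
    """Load a JSON license file, check required keys, and export it to os.environ."""
    with open(path) as f:
        keys = json.load(f)
    missing = [k for k in required if k not in keys]
    if missing:
        raise KeyError(f"License file {path} is missing keys: {missing}")
    # os.environ only accepts string values
    os.environ.update({k: str(v) for k, v in keys.items()})
    return keys
```

With this helper, `license_keys = load_license('working/license/spark_nlp.json')` would replace the two statements in the cell above.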

## Install Libraries

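Below is the set of library versions that worked for me. The latest versions of these libraries failed due to incompatibilities. However, I cannot guarantee that these will work for you exactly; you may need to fiddle with the version numbers. `$PUBLIC_VERSION`, `$JSL_VERSION`, and `$SECRET` are filled in from the environment variables set by the license cell above.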

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.4.1 spark-nlp==$PUBLIC_VERSION

# Installing NLU
! pip install --upgrade -q nlu==4.0.1rc4 --no-dependencies

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display
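
Because the exact versions matter here, it can help to confirm what was actually installed. The optional sketch below uses the standard-library `importlib.metadata` to compare installed versions with the pins above. `EXPECTED` and `check_versions` are hypothetical names, and the `spark-nlp` / `spark-nlp-jsl` pins are assumed to come from the `PUBLIC_VERSION` and `JSL_VERSION` environment variables.

```python
import os
from importlib.metadata import PackageNotFoundError, version

# Pins from the install cell; the two Spark NLP pins are read from the
# environment (None if unset, in which case they are not checked).
EXPECTED = {
    "pyspark": "3.4.1",
    "nlu": "4.0.1rc4",
    "spark-nlp": os.environ.get("PUBLIC_VERSION"),
    "spark-nlp-jsl": os.environ.get("JSL_VERSION"),
}

def check_versions(expected):
    """Return {package: (installed, expected)} for every mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None  # package is not installed at all
        if want is not None and have != want:
            mismatches[pkg] = (have, want)
    return mismatches

print(check_versions(EXPECTED))
```

An empty dict means every pinned package matches.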

## Create the Spark Session

This cell creates the Spark session needed to run the NLP. 

Again, this is what worked for me. It mostly follows JSL's documentation, but I tried several examples from them before I found one that worked on my system. However, your system may be different, so please refer to their documentation if you have problems. 

In [None]:
params = {"spark.driver.memory":"16G",
          "spark.kryoserializer.buffer.max":"2000M",
          "spark.driver.maxResultSize":"2000M"}

print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

# Manual alternative to sparknlp_jsl.start(), kept for reference; it is not called below.
def start(SECRET):
    builder = SparkSession.builder \
        .appName("Spark NLP Licensed") \
        .master("local[*]") \
        .config("spark.driver.memory", "16G") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.kryoserializer.buffer.max", "2000M") \
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:"+sparknlp.version()) \
        .config("spark.jars", "https://pypi.johnsnowlabs.com/"+SECRET+"/spark-nlp-jsl-"+sparknlp_jsl.version()+".jar")

    return builder.getOrCreate()

# apple_silicon=True is for Apple silicon Macs; set it to False on other machines.
spark = sparknlp_jsl.start(license_keys['SECRET'], params=params, apple_silicon=True)

spark

## Create the NLP Pipeline

This cell defines the NLP pipeline that will be used to find the blood pressures. 

In [None]:
# Annotator that transforms a text column from dataframe into an Annotation ready for NLP
documentAssembler = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

sentenceDetector = SentenceDetectorDLModel.pretrained("sentence_detector_dl_healthcare","en","clinical/models") \
    .setInputCols(["document"]) \
    .setOutputCol("sentence")

# Tokenizer splits words in a relevant format for NLP
tokenizer = Tokenizer() \
    .setInputCols(["sentence"]) \
    .setOutputCol("token")

pos_tagger = PerceptronModel.pretrained("pos_clinical", "en", "clinical/models") \
    .setInputCols(["sentence", "token"]) \
    .setOutputCol("pos_tag")

# Clinical word embeddings trained on the PubMed dataset
word_embeddings = WordEmbeddingsModel.pretrained("embeddings_clinical","en","clinical/models") \
    .setInputCols(["sentence","token"]) \
    .setOutputCol("embeddings")

# NER model trained on i2b2 (sampled from MIMIC) dataset
clinical_ner = MedicalNerModel.pretrained("ner_jsl","en","clinical/models") \
    .setInputCols(["sentence","token","embeddings"]) \
    .setOutputCol("ner") \
    .setLabelCasing("upper") #decide if we want to return the tags in upper or lower case

ner_converter = NerConverterInternal() \
    .setInputCols(["sentence","token","ner"]) \
    .setOutputCol("ner_chunk")

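dependency_parser = DependencyParserModel.pretrained("dependency_conllu", "en") \
    .setInputCols(["sentence", "pos_tag", "token"]) \
    .setOutputCol("dependency")

# Relation extraction between test results and dates, restricted to
# blood pressure/date pairs. A prediction threshold of 0.0 keeps every
# candidate relation; a max syntactic distance of 5 limits how far apart
# the two entities may be in the dependency tree.
clinical_re_Model = RelationExtractionModel.pretrained("re_test_result_date", "en", 'clinical/models') \
    .setInputCols(["embeddings", "pos_tag", "ner_chunk", "dependency"]) \
    .setOutputCol("relation") \
    .setPredictionThreshold(0.0) \
    .setMaxSyntacticDistance(5) \
    .setRelationPairs(["blood_pressure-date", "date-blood_pressure"])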


nlpPipeline = Pipeline(
    stages=[
        documentAssembler,
        sentenceDetector,
        tokenizer,
        pos_tagger,
        word_embeddings,
        clinical_ner,
        ner_converter,
        dependency_parser,
        clinical_re_Model
    ])


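# The models are pretrained, so fitting on an empty DataFrame is only needed
# to turn the Pipeline into a PipelineModel.
empty_data = spark.createDataFrame([[""]]).toDF("text")

model = nlpPipeline.fit(empty_data)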

## Establish Database Connection

This cell connects to the Neo4J instance. It relies on several environment variables. 

**PLEASE NOTE**: The variables have been changed to support multiple databases in the same instance. 

| Variable            | Description                          | Sample Value          |
|---------------------|--------------------------------------|-----------------------|
| FHIR_GRAPH_URL      | Where to find the instance of Neo4j. | bolt://localhost:7687 |
| FHIR_GRAPH_USER     | The username for the database.       | neo4j                 |
| FHIR_GRAPH_PASSWORD | The password for the database.       | password              |
| FHIR_GRAPH_DATABASE | The name of the database instance.   | neo4j                 |

In [None]:
NEO4J_URI = os.getenv('FHIR_GRAPH_URL')
USERNAME = os.getenv('FHIR_GRAPH_USER')
PASSWORD = os.getenv('FHIR_GRAPH_PASSWORD')
DATABASE = os.getenv('FHIR_GRAPH_DATABASE')

graph = Graph(NEO4J_URI, USERNAME, PASSWORD, DATABASE)
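
`os.getenv` returns `None` for an unset variable, so a missing setting only fails later inside `Graph`. The sketch below fails fast instead. `read_graph_settings` and `GRAPH_ENV_DEFAULTS` are hypothetical names; using the table's sample values as fallbacks is an assumption suited only to a local development instance, and the password deliberately has no fallback.

```python
import os

# Fallbacks are the sample values from the table above (assumption: local dev only).
GRAPH_ENV_DEFAULTS = {
    "FHIR_GRAPH_URL": "bolt://localhost:7687",
    "FHIR_GRAPH_USER": "neo4j",
    "FHIR_GRAPH_PASSWORD": None,  # no fallback: the password must be set
    "FHIR_GRAPH_DATABASE": "neo4j",
}

def read_graph_settings(env=os.environ, defaults=GRAPH_ENV_DEFAULTS):
    """Return the connection settings, raising if any of them has no value."""
    settings = {name: env.get(name, default) for name, default in defaults.items()}
    missing = [name for name, value in settings.items() if not value]
    if missing:
        raise RuntimeError(f"Set these environment variables: {missing}")
    return settings

# Example use with the notebook's Graph class:
# s = read_graph_settings()
# graph = Graph(s["FHIR_GRAPH_URL"], s["FHIR_GRAPH_USER"],
#               s["FHIR_GRAPH_PASSWORD"], s["FHIR_GRAPH_DATABASE"])
```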

## Find DocumentReference Resources

This cell uses Cypher to find `DocumentReference` resources already in the Knowledge Graph. 

To show that it works, the cell then takes the first `DocumentReference`, extracts its `content_0_attachment_data` field, and decodes the Base64. Finally, it runs the decoded note through the NLP pipeline and shows the results. 

In [None]:
cypher = """
match (n:DocumentReference) return n
"""

document_reference_nodes = graph.query(cypher)

encoded = document_reference_nodes[0][0][0]["content_0_attachment_data"]
note = base64.b64decode(encoded).decode('ascii')
nlp = model.transform(spark.createDataFrame([[note]]).toDF("text"))
nlp.show(truncate=False)
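
Decoding with `'ascii'` raises `UnicodeDecodeError` as soon as a note contains a non-ASCII character (an accented name, for example). A possible alternative, sketched below with the hypothetical helper `decode_attachment`, decodes as UTF-8, which is a superset of ASCII, and replaces any invalid bytes instead of stopping the run. Whether replacement is acceptable for your notes is a judgment call.

```python
import base64

def decode_attachment(encoded, encoding="utf-8"):
    """Decode a Base64 FHIR Attachment.data value into text.

    errors="replace" substitutes U+FFFD for bytes that are not valid in
    the chosen encoding, so one bad byte does not abort the whole note.
    """
    raw = base64.b64decode(encoded)
    return raw.decode(encoding, errors="replace")

# In the cell above this would be:
# note = decode_attachment(document_reference_nodes[0][0][0]["content_0_attachment_data"])
```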

## Define Method to Extract Blood Pressure from NLP

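This method finds the blood pressure string in the NLP results. It looks for relations whose second entity is `BLOOD_PRESSURE` and returns the text of that entity from the first such relation. 

The cell then prints the string found in the NLP run above to show that it works. 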

In [None]:
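def get_bp_str(_nlp):
    # Explode the relation annotations so each row holds one relation's metadata map
    blood_pressure_str = _nlp.select(
        F.explode(_nlp.relation.metadata).alias('cols')
    ).filter(
        # keep only relations whose second argument is a blood pressure entity
        "cols['entity2']='BLOOD_PRESSURE'"
    ).select(
        F.expr("cols['chunk2']").alias("bp"),
        F.expr("cols['chunk1']").alias("date"),
    ).collect()[0]['bp']  # raises IndexError if the note has no such relation
    return blood_pressure_str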

blood_pressure_str = get_bp_str(nlp)
print(blood_pressure_str)
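
`get_bp_str` only checks `entity2`, although the pipeline allows both `blood_pressure-date` and `date-blood_pressure` pairs, so a reading in the first position is missed. The pure-Python sketch below checks both positions. It assumes each relation's metadata map has the `entity1`, `entity2`, `chunk1`, and `chunk2` keys (the last three are the ones `get_bp_str` already uses); `find_bp_relations` is a hypothetical name.

```python
def find_bp_relations(metadata_rows, label="BLOOD_PRESSURE"):
    """Return (bp_chunk, other_chunk) pairs for every relation with a BP argument.

    metadata_rows is an iterable of dicts, one per relation annotation.
    An empty list means the note had no blood pressure relation.
    """
    found = []
    for meta in metadata_rows:
        if meta.get("entity1") == label:
            found.append((meta.get("chunk1"), meta.get("chunk2")))
        elif meta.get("entity2") == label:
            found.append((meta.get("chunk2"), meta.get("chunk1")))
    return found

# With Spark, a MapType column collects as Python dicts, e.g.:
# rows = [r["cols"] for r in nlp.select(F.explode(nlp.relation.metadata).alias("cols")).collect()]
# pairs = find_bp_relations(rows)
```

Returning an empty list rather than indexing `[0]` lets the caller decide what to do with notes that have no reading.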


## Define Method to Parse Blood Pressure String

This method extracts the systolic and diastolic values from the blood pressure string and returns them as integers. 

In [None]:

def parse_bp(bp_str):
    matches = re.match(r'[a-zA-Z ]*(\d{1,3})/(\d{1,3})', bp_str)
    return int(matches.group(1)), int(matches.group(2))

print(parse_bp("BP was 137/88 mm Hg"))
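
`parse_bp` uses `re.match` with a prefix of letters and spaces only, so a string such as `"BP: 120/80"` (colon before the numbers) does not match, and `matches.group` then fails on `None`. The sketch below is a looser variant; `parse_bp_lenient` and its pattern are illustrative assumptions, not a validated parser. Like the original, it does not try to tell a date such as `01/02` from a reading, relying on the NER step to supply a blood pressure chunk.

```python
import re

# Two or three digits, a slash with optional surrounding spaces, two or three digits.
BP_PATTERN = re.compile(r"(\d{2,3})\s*/\s*(\d{2,3})")

def parse_bp_lenient(bp_str):
    """Return (systolic, diastolic) from anywhere in bp_str, or None if absent."""
    match = BP_PATTERN.search(bp_str)  # search scans the whole string
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))

for text in ["BP was 137/88 mm Hg", "BP: 120 / 80", "Blood pressure 118/76.", "no reading"]:
    print(text, "->", parse_bp_lenient(text))
```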

## Define Method to Extract Replacement Values

There are a number of values needed from the `DocumentReference` and the blood pressure string to fill in the new `Observation` resource. This method consolidates finding all those values in one place. 

In [None]:
def get_replacements(doc_ref, bp_str):
    id = uuid.uuid4()
    patient = doc_ref[0]["subject_reference"]
    encounter = doc_ref[0]["context_encounter_0_reference"]
    date_str = doc_ref[0]["date"]
    systolic, diastolic = parse_bp(bp_str)
    return id, patient, encounter, date_str, systolic, diastolic

id, patient, encounter, date_str, systolic, diastolic = get_replacements(document_reference_nodes[0][0], blood_pressure_str)
print(f'{id},  {patient},  {encounter},  {date_str}, {systolic}/{diastolic}')

## Define Method to Create Observation

This cell contains a template blood pressure `Observation` resource and the method to fill it in.

In [None]:
TEMPLATE_OBSERVATION = """
{
        "resourceType": "Observation",
        "id": "[ID]",
        "meta": {
          "profile": [
            "http://hl7.org/fhir/us/core/StructureDefinition/us-core-blood-pressure"
          ]
        },
        "status": "final",
        "category": [
          {
            "coding": [
              {
                "system": "http://terminology.hl7.org/CodeSystem/observation-category",
                "code": "vital-signs",
                "display": "Vital signs"
              }
            ]
          }
        ],
        "code": {
          "coding": [
            {
              "system": "http://loinc.org",
              "code": "85354-9",
              "display": "Blood pressure panel with all children optional"
            }
          ],
          "text": "Blood pressure panel with all children optional"
        },
        "subject": {
          "reference": "[PATIENT]"
        },
        "encounter": {
          "reference": "[ENCOUNTER]"
        },
        "effectiveDateTime": "[DATE]",
        "issued": "[DATE]",
        "component": [
          {
            "code": {
              "coding": [
                {
                  "system": "http://loinc.org",
                  "code": "8462-4",
                  "display": "Diastolic Blood Pressure"
                }
              ],
              "text": "Diastolic Blood Pressure"
            },
            "valueQuantity": {
              "value": [DIASTOLIC],
              "unit": "mm[Hg]",
              "system": "http://unitsofmeasure.org",
              "code": "mm[Hg]"
            }
          },
          {
            "code": {
              "coding": [
                {
                  "system": "http://loinc.org",
                  "code": "8480-6",
                  "display": "Systolic Blood Pressure"
                }
              ],
              "text": "Systolic Blood Pressure"
            },
            "valueQuantity": {
              "value": [SYSTOLIC],
              "unit": "mm[Hg]",
              "system": "http://unitsofmeasure.org",
              "code": "mm[Hg]"
            }
          }
        ]
      }
"""

def create_observation(id, patient, encounter, date_str, systolic, diastolic):
    resource = TEMPLATE_OBSERVATION
    resource = resource.replace("[ID]", id)
    resource = resource.replace("[PATIENT]", patient)
    resource = resource.replace("[ENCOUNTER]", encounter)
    resource = resource.replace("[DATE]", date_str)
    resource = resource.replace("[SYSTOLIC]", systolic)
    resource = resource.replace("[DIASTOLIC]", diastolic)
    return resource

print(create_observation(str(id), patient, encounter, date_str, str(systolic), str(diastolic)))
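
String replacement works for this template, but every value must be converted to a string first, and a quote character in a reference would produce invalid JSON. The sketch below builds the same resource as a Python dict and lets `json.dumps` handle escaping and numbers. `build_observation` and `bp_component` are hypothetical names; the codes and URLs are copied from `TEMPLATE_OBSERVATION`.

```python
import json

LOINC = "http://loinc.org"
UCUM = "http://unitsofmeasure.org"

def bp_component(code, display, value):
    """One component of the blood pressure panel, mirroring the template."""
    return {
        "code": {"coding": [{"system": LOINC, "code": code, "display": display}], "text": display},
        "valueQuantity": {"value": value, "unit": "mm[Hg]", "system": UCUM, "code": "mm[Hg]"},
    }

def build_observation(obs_id, patient, encounter, date_str, systolic, diastolic):
    """Build the same Observation as create_observation, as a dict."""
    panel = "Blood pressure panel with all children optional"
    return {
        "resourceType": "Observation",
        "id": str(obs_id),
        "meta": {"profile": ["http://hl7.org/fhir/us/core/StructureDefinition/us-core-blood-pressure"]},
        "status": "final",
        "category": [{"coding": [{
            "system": "http://terminology.hl7.org/CodeSystem/observation-category",
            "code": "vital-signs",
            "display": "Vital signs",
        }]}],
        "code": {"coding": [{"system": LOINC, "code": "85354-9", "display": panel}], "text": panel},
        "subject": {"reference": patient},
        "encounter": {"reference": encounter},
        "effectiveDateTime": date_str,
        "issued": date_str,
        "component": [
            bp_component("8462-4", "Diastolic Blood Pressure", int(diastolic)),
            bp_component("8480-6", "Systolic Blood Pressure", int(systolic)),
        ],
    }
```

`json.dumps(build_observation(id, patient, encounter, date_str, systolic, diastolic))` produces a string that the later cells can pass to `json.loads` just like the template output.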

## Create Observation Resources 

This cell iterates through the list of `DocumentReferences`, runs them through NLP, and creates `Observations` for them. 

In [None]:

observations = []
for doc_ref in document_reference_nodes[0]:
    encoded = doc_ref[0]["content_0_attachment_data"]
    note = base64.b64decode(encoded).decode('ascii')
    nlp = model.transform(spark.createDataFrame([[note]]).toDF("text"))
    blood_pressure_str = get_bp_str(nlp)
    id, patient, encounter, date_str, systolic, diastolic = get_replacements(doc_ref, blood_pressure_str)
    observations.append(create_observation(str(id), patient, encounter, date_str, str(systolic), str(diastolic)))
    
print(observations)
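
As written, one note without a blood pressure relation stops the whole loop (`collect()[0]` raises `IndexError`). The generic sketch below keeps going and records what was skipped. `build_all` is a hypothetical helper, and `extract` and `make` are placeholders: in this notebook, `extract` would decode `doc_ref[0]["content_0_attachment_data"]`, run `model.transform`, and call `get_bp_str`, while `make` would call `get_replacements` and `create_observation`.

```python
def build_all(doc_refs, extract, make):
    """Apply extract/make to each DocumentReference, collecting failures.

    extract(doc_ref) -> blood pressure string, or None if there is none
    make(doc_ref, bp_str) -> Observation JSON string
    Returns (observations, skipped), where skipped holds (doc_ref, reason).
    """
    observations, skipped = [], []
    for doc_ref in doc_refs:
        try:
            bp_str = extract(doc_ref)
            if bp_str is None:
                skipped.append((doc_ref, "no blood pressure found"))
                continue
            observations.append(make(doc_ref, bp_str))
        except Exception as exc:  # e.g. IndexError from an empty collect()
            skipped.append((doc_ref, repr(exc)))
    return observations, skipped
```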

## Create Cypher to Add Nodes and Edges

This cell iterates through the list of `Observations` created above and constructs the Cypher statements needed to add the nodes and their edges to the DB.

In [None]:

def inferred_resource_to_node(resource):
    resource_type = resource['resourceType']
    flat_resource = flat_fhir_to_json_str(flatten_fhir(resource), resource_name(resource), FHIR_to_string(resource))
    return f'CREATE (:{resource_type}:resource:inferred {flat_resource})'

nodes = []
edges = []
dates = set() # set is used here to make sure dates are unique
for resource_str in observations:
    resource = json.loads(resource_str)
    nodes.append(inferred_resource_to_node(resource))
    node_edges, node_dates = resource_to_edges(resource)
    edges += node_edges
    dates.update(node_dates)
    
print(nodes)
print(edges)
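
The statements above embed the properties directly in the Cypher text. An alternative is to pass them as a query parameter, which avoids quoting problems. Node labels cannot be parameterized in Cypher, so the resource type is still interpolated after a conservative check. `node_create_statement` is a hypothetical helper. It assumes the properties are a flat dict of primitive values (for example, what `flatten_fhir` produces, plus the extra values `flat_fhir_to_json_str` adds). Whether the local `Graph.query` accepts parameters is not shown here. With the official `neo4j` Python driver, the pair could be run as `driver.execute_query(query, parameters_=params, database_=DATABASE)`.

```python
import re

# Conservative label check: a letter followed by letters, digits, or underscores.
LABEL_PATTERN = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$")

def node_create_statement(resource_type, properties):
    """Return (query, parameters) that create one inferred resource node."""
    if not LABEL_PATTERN.match(resource_type):
        raise ValueError(f"Unsafe label: {resource_type!r}")
    query = f"CREATE (n:{resource_type}:resource:inferred $props)"
    return query, {"props": dict(properties)}
```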

## Create Nodes in DB

This cell creates the nodes and edges in the DB. An edge that fails to be created is reported and skipped rather than stopping the cell.

In [None]:
for node in nodes:
    graph.query(node)

for edge in edges:
    try:
        graph.query(edge)
    except Exception as e:
        print(f'Failed to create edge: {edge} ({e})')

## TODO: Add to embedding index

It is left to you to add the new nodes to the vector/embedding index if you want to use them in RAG. 

**Disclaimer:** Nothing provided here is guaranteed or warrantied to work. It is provided as is and has not been tested extensively. Using this notebook is at the risk of the user. 

Copyright &copy; 2024 Sam Schifman