# Proof of Concept Spark-NLP ETL Script

In [106]:
# the boilerplate

from johnsnowlabs import nlp, medical
import pandas as pd
import datetime
from pyspark.sql import functions as F

# Automatically load license data and start a session with all jars user has access to
spark = nlp.start()

Spark Session already created, some configs may not take.
👌 Detected license file /home/kate/projects/jsl-mimic-omop/spark_nlp_for_healthcare_spark_ocr_9323.json


## Pipeline

In [107]:
documentAssembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

sentenceDetector = nlp.SentenceDetector()\
    .setInputCols("document")\
    .setOutputCol("sentence")

tokenizer = nlp.Tokenizer()\
    .setInputCols("sentence")\
    .setOutputCol("token")

word_embeddings = nlp.WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models")\
    .setInputCols("sentence", "token")\
    .setOutputCol("word_embeddings")

# to get general health entities
ner_jsl = medical.NerModel.pretrained("ner_jsl", "en", "clinical/models")\
      .setInputCols(["sentence", "token", "word_embeddings"])\
      .setOutputCol("ner_jsl")

ner_jsl_converter = medical.NerConverter()\
      .setInputCols(["sentence", "token", "ner_jsl"])\
      .setOutputCol("clinical_ner_chunk")\
      .setWhiteList(["Procedure","Kidney_Disease","Cerebrovascular_Disease","Heart_Disease", "Medical_Device",
                     "Disease_Syndrome_Disorder", "ImagingFindings", "Symptom", "VS_Finding",
                     "EKG_Findings", "Communicable_Disease","Substance", #no drug ingredient
                     "Triglycerides","Alcohol","Smoking","Pregnancy","Hypertension","Obesity",
                     "Injury_or_Poisoning","Hyperlipidemia","BMI","Oncological","Psychological_Condition","LDL","Diabetes"
                     #  "Internal_organ_or_component","External_body_part_or_region","Modifier",
                     # These stay out until I get a relalation extraction thing going
                     ])

# to get DRUG entities
posology_ner = medical.NerModel().pretrained("ner_posology", "en", "clinical/models") \
    .setInputCols(["sentence", "token", "word_embeddings"]) \
    .setOutputCol("posology_ner")

posology_ner_chunk = medical.NerConverter()\
    .setInputCols("sentence","token","posology_ner")\
    .setOutputCol("posology_ner_chunk") \
    .setWhiteList(["DRUG"])

# merge the chunks into a single ner_chunk
chunk_merger = medical.ChunkMergeApproach()\
    .setInputCols("clinical_ner_chunk","posology_ner_chunk")\
    .setOutputCol("ner_chunk")\
    .setMergeOverlapping(False)
    
# Assertion model
clinical_assertion = medical.AssertionDLModel.pretrained("assertion_jsl_augmented", "en", "clinical/models") \
    .setInputCols(["sentence", "ner_chunk", "word_embeddings"]) \
    .setOutputCol("assertion")


# convert chunks to doc to get sentence embeddings of them
chunk2doc = nlp.Chunk2Doc()\
    .setInputCols("ner_chunk")\
    .setOutputCol("doc_final_chunk")


sbiobert_embeddings = nlp.BertSentenceEmbeddings.pretrained("sbiobert_base_cased_mli","en","clinical/models")\
    .setInputCols(["doc_final_chunk"])\
    .setOutputCol("sbert_embeddings")\
    .setCaseSensitive(False)

# This got dropped because too much unreliable information was coming back

# # filter TEST entity embeddings
# router_sentence_loinc = medical.Router() \
#     .setInputCols("sbert_embeddings") \
#     .setFilterFieldsElements(["Test"]) \
#     .setOutputCol("test_embeddings")

# filter SNOMED-ish entity embeddings
router_sentence_snomed = medical.Router() \
    .setInputCols("sbert_embeddings") \
    .setFilterFieldsElements(["Procedure","Kidney_Disease","Cerebrovascular_Disease","Heart_Disease",
                     "Disease_Syndrome_Disorder", "ImagingFindings", "Symptom", "VS_Finding", "Medical_Device",
                     "EKG_Findings", "Communicable_Disease","Substance", #no drug ingredient, no test
                    #  "Internal_organ_or_component","External_body_part_or_region",
                     "Modifier",
                     "Triglycerides","Alcohol","Smoking","Pregnancy","Hypertension","Obesity",
                     "Injury_or_Poisoning","Hyperlipidemia","BMI","Oncological","Psychological_Condition","LDL","Diabetes"]) \
    .setOutputCol("problem_embeddings")

# filter DRUG entity embeddings
router_sentence_rxnorm = medical.Router() \
    .setInputCols("sbert_embeddings") \
    .setFilterFieldsElements(["DRUG"]) \
    .setOutputCol("drug_embeddings")

# # use test_embeddings only
# loinc_resolver = medical.SentenceEntityResolverModel.pretrained("sbiobertresolve_loinc","en", "clinical/models") \
#     .setInputCols(["test_embeddings"]) \
#     .setOutputCol("loinc_code") \
#     .setDistanceFunction("EUCLIDEAN")

# use problem_embeddings only
icd_resolver = medical.SentenceEntityResolverModel.pretrained("sbiobertresolve_snomed_findings_aux_concepts","en", "clinical/models") \
    .setInputCols(["problem_embeddings"]) \
    .setOutputCol("snomed_condition")\
    .setDistanceFunction("EUCLIDEAN")

# use drug_embeddings only
rxnorm_resolver = medical.SentenceEntityResolverModel.pretrained("sbiobertresolve_rxnorm_augmented","en", "clinical/models") \
    .setInputCols(["drug_embeddings"]) \
    .setOutputCol("rxnorm_code")\
    .setDistanceFunction("EUCLIDEAN")


pipeline = nlp.Pipeline(
    stages=[
        documentAssembler,
        sentenceDetector,
        tokenizer,
        word_embeddings,
        ner_jsl,
        ner_jsl_converter,
        posology_ner,
        posology_ner_chunk,
        chunk_merger,
        clinical_assertion,
        chunk2doc,
        sbiobert_embeddings,
        # router_sentence_loinc,
        router_sentence_snomed,
        router_sentence_rxnorm,
        # loinc_resolver,
        icd_resolver,
        rxnorm_resolver
])

empty_data = spark.createDataFrame([['']]).toDF("text")
model = pipeline.fit(empty_data)

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
ner_jsl download started this may take some time.
[OK!]
ner_posology download started this may take some time.
[OK!]
assertion_jsl_augmented download started this may take some time.
[OK!]
sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]
sbiobertresolve_snomed_findings_aux_concepts download started this may take some time.
[OK!]
sbiobertresolve_rxnorm_augmented download started this may take some time.
[OK!]


## A parsing function

In [123]:
# This function is lifted from example code and modified to be more useful in a loop

def get_codes (person_id, note_id, result, input_cols=['rxnorm_code', 'snomed_condition'], aux=False):
    """A function that strips out key information from a LightPipeline result object and
    returns it as a pandas DataFrame

    Args:
        person_id (integer): The OMOP person_id associated with the note
        note_id (integer): The OMOP note_id associated with the note
        result (dict): A LightPipeline result, stored as a dictionary of cols of interest
        input_cols (list, optional): What pipeline output columns are being processed? Defaults to ['rxnorm_code', 'snomed_condition'].
        aux (bool, optional): Whether to process the aux information. Defaults to False.

    Returns:
        pandas.DataFrame: A dataframe with extracted codes and confidences for each input_col and assertions
        with one row per chunk.
    """

    # This maps an explicit OMOP vocabulary_id to each column being processed.
    vocabulary_ids = {"loinc_code": "LOINC", 
                      "rxnorm_code": "RxNorm",
                      "snomed_condition": "SNOMED"}

    # collect chunks and assertions
    chunks = []
    assertions =[]
    assertion_confidences = []
    begin = []
    end = []
    chunk_id = []
    
    for chunk in result['ner_chunk']:
        begin.append(chunk.begin)
        end.append(chunk.end)
        chunks.append(chunk.result)
        chunk_id.append(chunk.metadata['chunk'])
        
    for assertion in result['assertion']:
        assertions.append(assertion.result)
        assertion_confidences.append(assertion.metadata['confidence'])
        
    chunk_df = pd.DataFrame({'chunks':chunks, 'begin': begin, 'end':end, 'chunk_id':chunk_id, 'assertion':assertions, 'assertion_confidence':assertion_confidences})
    
    # collect codes and confidences for each input_col
    ner_results = []
    
    for col in input_cols:

        codes = []
        resolutions=[]
        confidences = []
        chunk_ids = []
        all_k_aux_labels=[]
        
        for code in result[col]:
            codes.append(code.result)
            confidences.append(code.metadata['confidence'])
            chunk_ids.append(code.metadata['chunk'])
            try:
                resolutions.append(code.metadata['resolved_text'])
            except:
                resolutions.append('')

            if aux:
                if code.metadata['all_k_aux_labels']:
                    try:
                        all_k_aux_labels.append(code.metadata['all_k_aux_labels'].split(':::')[0])
                    except:
                        all_k_aux_labels.append('')
                else:
                    all_k_aux_labels.append('')

        df = pd.DataFrame({'chunk_id':chunk_ids, 'code':codes, 'resolution':resolutions,'confidence':confidences, 'aux_label':all_k_aux_labels})
            
        df['vocabulary_id'] = vocabulary_ids[col]

        if not aux:
            df = df.drop(['aux_label'], axis=1)
                
        ner_results.append(df)

    merged_df = pd.concat(ner_results, axis=0, ignore_index=True)
    merged_df = pd.merge(merged_df, chunk_df, left_on='chunk_id', right_on='chunk_id')
    merged_df['person_id'] = person_id
    merged_df['note_id'] = note_id

    return merged_df

## Test on a little toy dataframe

In [109]:
test_notes = pd.DataFrame({'patient_id': [1, 2],
                           'document_id': [33, 34],
                           'note_text': ["""The patient is a 41-year-old Vietnamese female with a cough that started last week.
                                            She has had right-sided chest pain radiating to her back with fever starting yesterday.
                                            She has a history of pericarditis in May 2006 and developed cough with right-sided chest pain.
                                            CBC showed a WBC of 12.1 with 80% neutrophils and 10% bands. She does not have a history of asthma or COPD.
                                            Maternal history of mycoardial infarction at age 50. She uses an insulin pump and a wheelchair.
                                            MEDICATIONS
                                            1. Coumadin 1 mg daily. Last INR was on Tuesday, August 14, 2007, and her INR was 2.3.
                                            2. Amiodarone 100 mg p.o. daily.
                                        """, 
                                        """The patient has a headache and feels dizzy. She has a history of hypertension and diabetes. 
                                        She is currently taking Lisinopril 10 mg daily and Metformin 500 mg daily.
                                        Her mother died of a stroke at age 60. She has a history of smoking and alcohol use."""
                                        ]})
test_notes


Unnamed: 0,patient_id,document_id,note_text
0,1,33,The patient is a 41-year-old Vietnamese female...
1,2,34,The patient has a headache and feels dizzy. Sh...


In [124]:
light_model = nlp.LightPipeline(model)
test_notes['light_result']= light_model.fullAnnotate(list(test_notes['note_text']))

extracted_df = pd.concat([get_codes(row['patient_id'], row['document_id'], row['light_result'], aux=True) for index, row in test_notes.iterrows()])
extracted_df

Unnamed: 0,chunk_id,code,resolution,confidence,aux_label,vocabulary_id,chunks,begin,end,assertion,assertion_confidence,person_id,note_id
0,9,1654190,insulin detemir Pen Injector,0.2216,Clinical Drug Form,RxNorm,insulin,616,622,Past,0.9976,1,33
1,12,202421,coumadin [coumadin],0.9903,Brand Name,RxNorm,Coumadin,750,757,Past,1.0,1,33
2,13,703,amiodarone [amiodarone],0.4986,Ingredient,RxNorm,Amiodarone,881,890,Possible,0.4823,1,33
3,0,248592006,character of cough,0.2654,Observable Entity,SNOMED,cough,54,58,Past,0.9482,1,33
4,1,1264062004,burning chest pain,0.2747,No_Concept_Class,SNOMED,chest pain,152,161,Past,1.0,1,33
5,2,248431003,phase of fever,0.3945,Observable Entity,SNOMED,fever,190,194,Present,0.9997,1,33
6,3,391935006,pericardiolysis,0.6332,Procedure,SNOMED,pericarditis,281,292,Present,0.9818,1,33
7,4,248592006,character of cough,0.2654,Observable Entity,SNOMED,cough,320,324,Past,0.9379,1,33
8,5,1264062004,burning chest pain,0.2747,No_Concept_Class,SNOMED,chest pain,343,352,Past,1.0,1,33
9,6,401193004,asthma confirmed,0.2686,Context-dependent,SNOMED,asthma,491,496,Absent,1.0,1,33


## Parse 100 of our MIMIC patients

In [111]:
from sqlalchemy import create_engine
from sqlalchemy.types import Integer, Float, String

uri = f"postgresql+psycopg2://postgres:mypass@localhost:5432/postgres"
engine = create_engine(uri)

In [113]:
notes = pd.read_sql('''with ppl as (select distinct person_id from omop.note limit 100)
select note.note_id, ppl.person_id, note_text
from ppl
left join omop.note on ppl.person_id = note.person_id''', engine)

people = notes['person_id'].unique()
notes.head()

Unnamed: 0,note_id,person_id,note_text
0,386736537,392775850,"CVA (Stroke, Cerebral infarction), Ischemic\n ..."
1,386617569,392775850,Atrial fibrillation\nMarked left axis deviatio...
2,386617570,392775850,Atrial fibrillation\nLeft axis deviation - lef...
3,386617571,392775850,Atrial fibrillation with rapid ventricular res...
4,387605334,392775850,NPN MICU-B 7AM-7PM\nS/O: RESPIR: Remains intu...


In [114]:
print(notes.shape)

(5086, 3)


## Our processing loop

This runs in batches of 500 notes to balance between repeatedly loading/unloading models and overloading the server

In [117]:
batch_size = 500
batch_start = 0
start = datetime.datetime.now()

print(start.isoformat())

while batch_start < len(notes):
    extracted_df = pd.DataFrame()
    batch_notes = notes[batch_start:batch_start+batch_size]
    print(
        f"Batch Start: {batch_start}, number of notes: {len(batch_notes)}, start: {start.isoformat()}"
    )

    batch_notes["light_result"] = light_model.fullAnnotate(
        list(batch_notes["note_text"])
    )

    extracted_df = pd.concat(
        [
            extracted_df,
            pd.concat(
                [
                    get_codes(
                        row["person_id"], row["note_id"], row["light_result"], aux=True
                    )
                    for index, row in batch_notes.iterrows()
                ]
            ),
        ]
    )
    extracted_df.to_sql(
        "tmp_extracted_codes",
        engine,
        if_exists="append",
        index=False,
        schema="omop",
        dtype={
            "person_id": Integer(),
            "note_id": Integer(),
            "chunk_id": Integer(),
            "confidence": Float(),
            "vocabulary_id": String(),
            "code": String(),
            "resolution": String(),
            "assertion": String(),
            "assertion_confidence": Float(),
        },
        method="multi",
    )
    print(
        f"{len(extracted_df)} entities done in {datetime.timedelta(seconds=(datetime.datetime.now()-start).seconds)}"
    )
    print("==" * 20)
    start = datetime.datetime.now()
    batch_start += batch_size

2024-06-25T20:49:43.122575
Batch Start: 0, number of notes: 500, start: 2024-06-25T20:49:43.122575
11486 entities done in 0:34:04
Batch Start: 500, number of notes: 500, start: 2024-06-25T21:23:47.904187
8916 entities done in 0:18:30
Batch Start: 1000, number of notes: 500, start: 2024-06-25T21:42:18.726937
8041 entities done in 0:14:42
Batch Start: 1500, number of notes: 500, start: 2024-06-25T21:57:01.582953
9619 entities done in 0:17:54
Batch Start: 2000, number of notes: 500, start: 2024-06-25T22:14:55.694262
8274 entities done in 0:14:50
Batch Start: 2500, number of notes: 500, start: 2024-06-25T22:29:45.896461
9271 entities done in 0:25:10
Batch Start: 3000, number of notes: 500, start: 2024-06-25T22:54:55.981810
16697 entities done in 0:26:36
Batch Start: 3500, number of notes: 500, start: 2024-06-25T23:21:32.903358
6531 entities done in 0:12:54
Batch Start: 4000, number of notes: 500, start: 2024-06-25T23:34:27.386861
8970 entities done in 0:16:38
Batch Start: 4500, number of n

Let's see what we'd get converted to `note_nlp` form.

The filter on `confidence` can come off for a fuller view of results.
This needs to be tuned to customer requirements for confidence and assertions

In [134]:
note_nlp_form = pd.read_sql('''
SELECT e.note_id,
       concept_name, -- we don't need this for the final transform but it's helpful for a sanity check
       e.confidence, -- likewise, this doesn't come over but is helpful information while troubleshooting
       e.chunks                                               as snippet,
       e.begin                                                as offset,
       concept_id                                             as note_nlp_concept_id,
       e.code                                                 as note_nlp_source_concept_id,
       'JSL spark-nlp'                                        as nlp_system,
       '2024-06-25'                                           as nlp_date,
       case when e.assertion = 'Present' then 1 else 0 end    as term_exists,
       case when e.assertion = 'Past' then 'Past' else '' end as term_temporal,
       e.assertion                                            as term_modifiers,
       concept_code                                           as code,
       domain_id,
       c.vocabulary_id,
       concept_class_id
       
from omop.concept c
         inner join omop.TMP_EXTRACTED_CODES_BAK e
                    on e.code = c.concept_code and e.vocabulary_id = c.vocabulary_id
where c.domain_id in ('Measurement', 'Drug', 'Condition', 'Observation', 'Procedure')
  and not concept_class_id in ('Attribute', 'Morph Abnormality')
  and e.confidence >= 0.95
  and invalid_reason is NULL
                    
            ''',
            engine)

note_nlp_form.head(25)

Unnamed: 0,note_id,concept_name,confidence,snippet,offset,note_nlp_concept_id,note_nlp_source_concept_id,nlp_system,nlp_date,term_exists,term_temporal,term_modifiers,code,domain_id,vocabulary_id,concept_class_id
0,387121615,Multiplanar reconstruction,0.9964,Multiplanar\n reconstructions,746,4134630,261958001,JSL spark-nlp,2024-06-25,0,Past,Past,261958001,Observation,SNOMED,Qualifier Value
1,387121983,Motor vehicle accident,0.973,motor vehicle accident,530,435134,418399005,JSL spark-nlp,2024-06-25,0,,Family,418399005,Observation,SNOMED,Event
2,387606451,Percocet,0.9708,PERCOCET,512,19106396,42844,JSL spark-nlp,2024-06-25,0,,Family,42844,Drug,RxNorm,Brand Name
3,386361904,Biaxin,0.9931,Biaxin,848,19047891,203729,JSL spark-nlp,2024-06-25,0,,Planned,203729,Drug,RxNorm,Brand Name
4,386361904,Cesarean section,0.9626,Cesarean section,703,4015701,11466000,JSL spark-nlp,2024-06-25,0,Past,Past,11466000,Procedure,SNOMED,Procedure
5,386361904,Appendectomy,0.9595,Appendectomy,746,4198190,80146002,JSL spark-nlp,2024-06-25,0,Past,Past,80146002,Procedure,SNOMED,Procedure
6,387605943,Atrovent,0.9619,Atrovent,302,19020409,151390,JSL spark-nlp,2024-06-25,0,,Planned,151390,Drug,RxNorm,Brand Name
7,387598026,Tylenol,0.9631,tylenol,468,19042336,202433,JSL spark-nlp,2024-06-25,0,,Family,202433,Drug,RxNorm,Brand Name
8,387594431,Dilantin,0.9824,dilantin,1084,19043651,202740,JSL spark-nlp,2024-06-25,0,Past,Past,202740,Drug,RxNorm,Brand Name
9,387594431,Coughing,0.9985,coughing,847,4137801,263731006,JSL spark-nlp,2024-06-25,1,,Present,263731006,Observation,SNOMED,Observable Entity
