![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/healthcare-nlp/05.1.Clinical_Entity_Resolver_Model_Training.ipynb)

# Clinical Entity Resolver Model Training

In [None]:
# Install the johnsnowlabs library to access Spark-OCR and Spark-NLP for Healthcare, Finance, and Legal.
! pip install -q johnsnowlabs

In [None]:
from google.colab import files
print('Please upload your John Snow Labs license using the button below')
license_keys = files.upload()

In [None]:
from johnsnowlabs import nlp, medical, visual

# After uploading your license, run this to install all licensed Python wheels and pre-download the jars for the Spark session JVM
nlp.install()

In [4]:
from johnsnowlabs import nlp, medical, visual
import pandas as pd

# Automatically load license data and start a session with all jars the user has access to
spark = nlp.start()

👌 Detected license file /content/5.0.0.spark_nlp_for_healthcare.json
👌 Launched cpu optimized session with with: 🚀Spark-NLP==5.0.0, 💊Spark-Healthcare==5.0.0, running on ⚡ PySpark==3.1.2


## Load datasets

In [5]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.test.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.train.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.validation.txt

In [6]:
# Each fold is a tab-separated file with no header row
cols = ["conceptId","ground_truth","concept_name"]

aap_tr = pd.read_csv("AskAPatient.fold-0.train.txt",sep="\t",encoding="ISO-8859-1",header=None)
aap_tr.columns = cols
# conceptId values are code labels, not numbers, so store them as strings
aap_tr["conceptId"] = aap_tr.conceptId.apply(str)

aap_ts = pd.read_csv("AskAPatient.fold-0.test.txt",sep="\t",header=None)
aap_ts.columns = cols
aap_ts["conceptId"] = aap_ts.conceptId.apply(str)

aap_vl = pd.read_csv("AskAPatient.fold-0.validation.txt",sep="\t",header=None)
aap_vl.columns = cols
aap_vl["conceptId"] = aap_vl.conceptId.apply(str)

In [7]:
aap_tr.head()

          conceptId             ground_truth             concept_name
0         108367008     Dislocation of joint     Dislocation of joint
1  3384011000036100                Arthrotec                Arthrotec
2         166717003  Serum creatinine raised  Serum creatinine raised
3  3877011000036101                  Lipitor                  Lipitor
4         402234004              Foot eczema              Foot eczema


In [8]:
# Create spark dataframes

aap_train_sdf = spark.createDataFrame(aap_tr)
aap_test_sdf = spark.createDataFrame(aap_ts)
aap_val_sdf = spark.createDataFrame(aap_vl)

# Sentence Entity Resolver (sBioBERT sentence embeddings)

In [9]:
aap_train_sdf.show()

+----------------+--------------------+--------------------+
|       conceptId|        ground_truth|        concept_name|
+----------------+--------------------+--------------------+
|       108367008|Dislocation of joint|Dislocation of joint|
|3384011000036100|           Arthrotec|           Arthrotec|
|       166717003|Serum creatinine ...|Serum creatinine ...|
|3877011000036101|             Lipitor|             Lipitor|
|       402234004|         Foot eczema|         Foot eczema|
|       404640003|           Dizziness|           Dizziness|
|       271681002|        Stomach ache|        Stomach ache|
|        76948002|         Severe pain|         Severe pain|
|        36031001|        Burning feet|        Burning feet|
|        76948002|         Severe pain|         Severe pain|
|        42399005|       Renal failure|       Renal failure|
|       288227007|Myalgia/myositis ...|Myalgia/myositis ...|
|       419723007|       Mentally dull|       Mentally dull|
|       248490000|    Bl

In [10]:
aap_train_sdf.printSchema()

root
 |-- conceptId: string (nullable = true)
 |-- ground_truth: string (nullable = true)
 |-- concept_name: string (nullable = true)



## Get Embeddings

Now we will compute the sentence embeddings of the `concept_name` column.

In [11]:
documentAssembler = nlp.DocumentAssembler()\
    .setInputCol("concept_name")\
    .setOutputCol("sentence")

bert_embeddings = nlp.BertSentenceEmbeddings.pretrained("sbiobert_base_cased_mli", "en", "clinical/models")\
    .setInputCols(["sentence"])\
    .setOutputCol("bert_embeddings")
    # .setCaseSensitive(False)

embeddings_pipeline = nlp.Pipeline(stages = [
    documentAssembler,
    bert_embeddings])

embeddings_model = embeddings_pipeline.fit(aap_train_sdf)
snomed_data = embeddings_model.transform(aap_train_sdf)

sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]


In [12]:
snomed_data.show()

+----------------+--------------------+--------------------+--------------------+--------------------+
|       conceptId|        ground_truth|        concept_name|            sentence|     bert_embeddings|
+----------------+--------------------+--------------------+--------------------+--------------------+
|       108367008|Dislocation of joint|Dislocation of joint|[{document, 0, 19...|[{sentence_embedd...|
|3384011000036100|           Arthrotec|           Arthrotec|[{document, 0, 8,...|[{sentence_embedd...|
|       166717003|Serum creatinine ...|Serum creatinine ...|[{document, 0, 22...|[{sentence_embedd...|
|3877011000036101|             Lipitor|             Lipitor|[{document, 0, 6,...|[{sentence_embedd...|
|       402234004|         Foot eczema|         Foot eczema|[{document, 0, 10...|[{sentence_embedd...|
|       404640003|           Dizziness|           Dizziness|[{document, 0, 8,...|[{sentence_embedd...|
|       271681002|        Stomach ache|        Stomach ache|[{document, 0

We now have a `bert_embeddings` column in our training DataFrame, which we will use as the input when training the model.

# Train SNOMED Model

In [13]:
bertExtractor = medical.SentenceEntityResolverApproach()\
  .setNeighbours(25)\
  .setThreshold(1000)\
  .setInputCols("bert_embeddings")\
  .setNormalizedCol("concept_name")\
  .setLabelCol("conceptId")\
  .setOutputCol('snomed_code')\
  .setDistanceFunction("EUCLIDIAN")\
  .setCaseSensitive(False)

%time snomed_model = bertExtractor.fit(snomed_data)

CPU times: user 1.13 s, sys: 163 ms, total: 1.29 s
Wall time: 1min 57s
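To build intuition for the resolver settings used above (`setNeighbours(25)`, `setThreshold(1000)` and a Euclidean distance function), here is a minimal pure-Python sketch of nearest-neighbour resolution over sentence embeddings. It is a conceptual illustration, not the library's implementation: `euclidean` and `resolve` are hypothetical helpers, the toy 3-dimensional vectors stand in for real sentence embeddings (which have far more dimensions), and treating the threshold as a maximum allowed distance is an assumption.

```python
import math


def euclidean(a, b):
    # Straight-line (L2) distance between two equal-length embedding vectors
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def resolve(query_vec, index, neighbours=25, threshold=1000.0):
    """Return up to `neighbours` (code, term, distance) tuples, nearest first.

    `index` is a list of (code, term, vector) rows built from training data.
    Candidates farther than `threshold` are dropped (assumed semantics).
    """
    scored = [(code, term, euclidean(query_vec, vec)) for code, term, vec in index]
    scored.sort(key=lambda row: row[2])
    return [row for row in scored if row[2] <= threshold][:neighbours]


# Toy 3-dimensional vectors standing in for real sentence embeddings
index = [
    ("404640003", "Dizziness", [0.9, 0.1, 0.0]),
    ("271681002", "Stomach ache", [0.0, 0.8, 0.3]),
    ("42399005", "Renal failure", [0.1, 0.0, 0.9]),
]
top = resolve([0.8, 0.2, 0.1], index, neighbours=2)
print(top[0][:2])  # → ('404640003', 'Dizziness')
```

With `neighbours=2` only the two closest candidates come back, which is the role `setNeighbours` plays in the resolver cell; lowering `threshold` (for example to `0.5`) would leave only the single close match.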


In [14]:
# Save the model if you will need it later
snomed_model.write().overwrite().save("biobertresolve_snomed_askapatient")

## Test Model

In [15]:
prediction_Model = nlp.PipelineModel(stages=[embeddings_model, snomed_model])

aap_test_pred = prediction_Model.transform(aap_test_sdf).cache()
aap_val_pred = prediction_Model.transform(aap_val_sdf).cache()

In [16]:
aap_test_pred.selectExpr("conceptId","concept_name","ground_truth","snomed_code[0].result","snomed_code[0].metadata.resolved_text","snomed_code[0].metadata.all_k_resolutions").show(truncate=50)

+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                    concept_name|                        ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|            387603000|                           balance off|balance off:::Impaired mobility:::Reduced mobil...|
|3384011000036100|                       Arthrotec|                           Arthrotec|     3384011000036100|                             Arthrotec|                                         Arthro

In [17]:
aap_val_pred.selectExpr("conceptId","concept_name","ground_truth","snomed_code[0].result","snomed_code[0].metadata.resolved_text","snomed_code[0].metadata.all_k_resolutions").show(truncate=50)

+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                  concept_name|         ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       267032009|persisten feeling of tiredness|   Tired all the time|             84229001|                     extreme tiredness|extreme tiredness:::feeling tired a lot:::feeli...|
|        22298006|                  HEART ATTACK|Myocardial infarction|             22298006|                          HEART ATTACH|HEART ATTACH:::HEADACHES:::LIGHT HEADED:::HAIR ...|
|3877011000036101|                       LIPITOR|              Lipitor|     3877
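The cells above let us eyeball predictions row by row. To summarise them as a number, a minimal sketch of top-1 / top-k accuracy follows. It assumes the candidate codes for each row can be collected as a `:::`-separated string in rank order, the same layout `all_k_resolutions` uses for terms; the metadata key that holds the codes (for example `all_k_results`) is an assumption, so check the metadata in your version. `topk_accuracy` is a hypothetical helper and the rows below are made-up inputs, not results from this notebook.

```python
def topk_accuracy(rows, k):
    """Fraction of rows whose true code appears among the first k candidates.

    Each row is (true_code, candidates), where candidates is a ':::'-joined
    string of resolved codes, nearest first (assumed format).
    """
    if not rows:
        return 0.0
    hits = sum(1 for true_code, cands in rows if true_code in cands.split(":::")[:k])
    return hits / len(rows)


# Made-up rows; in the notebook they could come from something like
# aap_test_pred.selectExpr("conceptId", "snomed_code[0].metadata.all_k_results").collect()
rows = [
    ("108367008", "387603000:::108367008:::22298006"),
    ("3384011000036100", "3384011000036100:::3877011000036101"),
]
print(topk_accuracy(rows, 1))  # → 0.5
print(topk_accuracy(rows, 3))  # → 1.0
```

Reporting both numbers is useful here: a resolver that misses at top-1 but usually has the right code within the first few candidates may still be fine for a human-in-the-loop review.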

# Train Model with Auxiliary Information

We can also add auxiliary information to our model. Here we add an aux column containing the ground truth of each code, so for every code the model returns, it also returns the corresponding ground truth in the `all_k_aux_labels` field of the metadata.
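Conceptually, the aux column attaches one extra label to each training row, and each returned candidate carries the label of the row it came from. The following plain-Python sketch mimics that association; `build_aux_lookup` and `aux_labels_for` are hypothetical helpers rather than library functions, and keeping the first label when a code repeats is an assumption. The sample pairs are taken from the training rows shown earlier.

```python
def build_aux_lookup(rows):
    """Map each code to its auxiliary label (here: the ground-truth term).

    rows: iterable of (conceptId, ground_truth) pairs from the training data.
    If a code appears more than once, the first label seen is kept (assumption).
    """
    lookup = {}
    for code, aux in rows:
        lookup.setdefault(code, aux)
    return lookup


def aux_labels_for(codes, lookup):
    # Aux labels aligned with the returned codes, ':::'-joined like the resolver metadata
    return ":::".join(lookup.get(code, "") for code in codes)


train_rows = [
    ("108367008", "Dislocation of joint"),
    ("76948002", "Severe pain"),
    ("76948002", "Severe pain"),
    ("404640003", "Dizziness"),
]
lookup = build_aux_lookup(train_rows)
print(aux_labels_for(["76948002", "404640003"], lookup))  # → Severe pain:::Dizziness
```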

In [18]:
bertExtractor_aux = medical.SentenceEntityResolverApproach()\
  .setNeighbours(25)\
  .setThreshold(1000)\
  .setInputCols("bert_embeddings")\
  .setNormalizedCol("concept_name")\
  .setLabelCol("conceptId")\
  .setOutputCol('snomed_code')\
  .setDistanceFunction("EUCLIDIAN")\
  .setCaseSensitive(False)\
  .setUseAuxLabel(True)\
  .setAuxLabelCol("ground_truth")

%time snomed_aux_model = bertExtractor_aux.fit(snomed_data)

CPU times: user 2.03 s, sys: 249 ms, total: 2.28 s
Wall time: 3min 36s


In [19]:
# Save the model if you will need it later
snomed_aux_model.write().overwrite().save("biobertresolve_snomed_askapatient_aux")

## Test Aux Model

Check the `all_k_aux_labels` field in the metadata.
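Since `all_k_resolutions` and `all_k_aux_labels` are both `:::`-separated strings, they can be zipped back into (resolved text, aux label) pairs after collecting the rows. A minimal sketch, assuming both fields list candidates in the same rank order; `pair_candidates` is a hypothetical helper, and the sample strings are shortened from the first row of the test output in the next cell.

```python
def pair_candidates(resolutions, aux_labels, sep=":::"):
    """Zip ':::'-separated all_k_resolutions and all_k_aux_labels strings.

    Returns (resolved_text, aux_label) pairs in rank order; stops at the shorter list.
    """
    return list(zip(resolutions.split(sep), aux_labels.split(sep)))


# Shortened from the first test row (the notebook output truncates the full strings)
res = "balance off:::Impaired mobility"
aux = "Impairment of balance:::Impaired mobility"
for rank, (text, label) in enumerate(pair_candidates(res, aux), start=1):
    print(rank, text, "->", label)
```

Pairing the two fields this way makes it easy to see when a colloquial resolved text such as "balance off" maps to a more formal ground-truth term.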

In [20]:
aux_prediction_Model = nlp.PipelineModel(stages=[embeddings_model, snomed_aux_model])

aap_test_pred_aux = aux_prediction_Model.transform(aap_test_sdf).cache()
aap_val_pred_aux = aux_prediction_Model.transform(aap_val_sdf).cache()

In [21]:
aap_test_pred_aux.selectExpr("conceptId","concept_name","ground_truth","snomed_code[0].result","snomed_code[0].metadata.resolved_text","snomed_code[0].metadata.all_k_resolutions", "snomed_code[0].metadata.all_k_aux_labels").show(truncate=50)

+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       conceptId|                    concept_name|                        ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|         snomed_code[0].metadata[all_k_aux_labels]|
+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|            387603000|                           balance off|balance off:::Impaired mobility:::Reduced mobil...|Impairment of balance:::Impaired mobility:::

In [22]:
aap_val_pred_aux.selectExpr("conceptId","concept_name","ground_truth","snomed_code[0].result","snomed_code[0].metadata.resolved_text","snomed_code[0].metadata.all_k_resolutions", "snomed_code[0].metadata.all_k_aux_labels").show(truncate=50)

+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       conceptId|                  concept_name|         ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|         snomed_code[0].metadata[all_k_aux_labels]|
+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       267032009|persisten feeling of tiredness|   Tired all the time|             84229001|                     extreme tiredness|extreme tiredness:::feeling tired a lot:::feeli...|Fatigue:::Tired all the time:::Feeling tired:::...|
|        22298006|                  HEART ATTACK|Myocardial 