![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/13.Snomed_Entity_Resolver_Model_Training.ipynb)

# 13. Snomed Entity Resolver Model Training

In [None]:
import json, os
from google.colab import files

if 'spark_jsl.json' not in os.listdir():
  license_keys = files.upload()
  os.rename(list(license_keys.keys())[0], 'spark_jsl.json')

with open('spark_jsl.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
os.environ.update(license_keys)
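
The rest of this notebook uses the `SECRET`, `JSL_VERSION` and `PUBLIC_VERSION` keys from `spark_jsl.json`. Below is a minimal sketch of a guard that fails early when one of them is missing or empty. The helper `check_license_keys` is hypothetical and not part of any John Snow Labs library.

```python
# Keys that later cells of this notebook read from spark_jsl.json.
REQUIRED_KEYS = ("SECRET", "JSL_VERSION", "PUBLIC_VERSION")

def check_license_keys(keys, required=REQUIRED_KEYS):
    """Hypothetical helper: return the sorted list of required keys that are missing or empty."""
    return sorted(k for k in required if not keys.get(k))

# Usage in the notebook:
# missing = check_license_keys(license_keys)
# if missing:
#     raise KeyError(f"spark_jsl.json is missing: {missing}")
```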

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.4.1 spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

In [None]:
import json
import os
import sys, time

import sparknlp_jsl
import sparknlp

from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType

from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.util import *
from sparknlp_jsl.annotator import *
from sparknlp.pretrained import ResourceDownloader


params = {"spark.driver.memory":"16G",
          "spark.kryoserializer.buffer.max":"2000M",
          "spark.driver.maxResultSize":"2000M"}

# Switch the Colab runtime to GPU first; gpu=True starts the session with GPU support
spark = sparknlp_jsl.start(license_keys['SECRET'],params=params, gpu=True)

print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark

Spark NLP Version : 5.3.1
Spark NLP_JSL Version : 5.3.1


## Healthcare NLP for Data Scientists Course

If you are not familiar with the components in this notebook, see the [Healthcare NLP for Data Scientists Udemy Course](https://www.udemy.com/course/healthcare-nlp-for-data-scientists/) and the [MOOC Notebooks](https://github.com/JohnSnowLabs/spark-nlp-workshop/tree/master/Spark_NLP_Udemy_MOOC/Healthcare_NLP) for each component.

## Load datasets

In [None]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.test.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.train.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.validation.txt

In [None]:
import pandas as pd

cols = ["conceptId","ground_truth","concept_name"]

aap_tr = pd.read_csv("AskAPatient.fold-0.train.txt",sep="\t",encoding="ISO-8859-1",header=None)
aap_tr.columns = cols
aap_tr["conceptId"] = aap_tr.conceptId.apply(str)

aap_ts = pd.read_csv("AskAPatient.fold-0.test.txt",sep="\t",header=None)
aap_ts.columns = cols
aap_ts["conceptId"] = aap_ts.conceptId.apply(str)

aap_vl = pd.read_csv("AskAPatient.fold-0.validation.txt",sep="\t",header=None)
aap_vl.columns = cols
aap_vl["conceptId"] = aap_vl.conceptId.apply(str)
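
Each AskAPatient file is headerless and tab-separated, with three columns: the concept ID, the ground-truth concept name, and the phrase as written. The cell above parses the IDs as integers and then converts them to strings. As an alternative, you could read every column as text with `dtype=str`, so long IDs such as `3384011000036100` are never parsed as numbers. Below is a minimal sketch of that approach. The helper `read_aap` is hypothetical, and the two sample rows are copied from the `aap_tr.head()` output.

```python
import io

import pandas as pd

COLS = ["conceptId", "ground_truth", "concept_name"]

def read_aap(path_or_buffer, encoding=None):
    """Hypothetical helper: read a headerless, tab-separated AskAPatient file as text.

    The notebook reads the train file with encoding="ISO-8859-1"; pass it here the same way.
    """
    return pd.read_csv(path_or_buffer, sep="\t", header=None, names=COLS,
                       dtype=str, encoding=encoding)

# Two rows copied from aap_tr.head(), held in memory instead of on disk.
sample = io.StringIO(
    "108367008\tDislocation of joint\tDislocation of joint\n"
    "3384011000036100\tArthrotec\tArthrotec\n"
)
df = read_aap(sample)
```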

In [None]:
aap_tr.head()

|   | conceptId | ground_truth | concept_name |
|---|---|---|---|
| 0 | 108367008 | Dislocation of joint | Dislocation of joint |
| 1 | 3384011000036100 | Arthrotec | Arthrotec |
| 2 | 166717003 | Serum creatinine raised | Serum creatinine raised |
| 3 | 3877011000036101 | Lipitor | Lipitor |
| 4 | 402234004 | Foot eczema | Foot eczema |


In [None]:
# Create spark dataframes

aap_train_sdf = spark.createDataFrame(aap_tr)
aap_test_sdf = spark.createDataFrame(aap_ts)
aap_val_sdf = spark.createDataFrame(aap_vl)

# Sentence Entity Resolver (sBioBert sentence embeddings)

In [None]:
aap_train_sdf.show()

+----------------+--------------------+--------------------+
|       conceptId|        ground_truth|        concept_name|
+----------------+--------------------+--------------------+
|       108367008|Dislocation of joint|Dislocation of joint|
|3384011000036100|           Arthrotec|           Arthrotec|
|       166717003|Serum creatinine ...|Serum creatinine ...|
|3877011000036101|             Lipitor|             Lipitor|
|       402234004|         Foot eczema|         Foot eczema|
|       404640003|           Dizziness|           Dizziness|
|       271681002|        Stomach ache|        Stomach ache|
|        76948002|         Severe pain|         Severe pain|
|        36031001|        Burning feet|        Burning feet|
|        76948002|         Severe pain|         Severe pain|
|        42399005|       Renal failure|       Renal failure|
|       288227007|Myalgia/myositis ...|Myalgia/myositis ...|
|       419723007|       Mentally dull|       Mentally dull|
|       248490000|    Bl

In [None]:
aap_train_sdf.printSchema()

root
 |-- conceptId: string (nullable = true)
 |-- ground_truth: string (nullable = true)
 |-- concept_name: string (nullable = true)



## Get Embeddings

Next, we compute sentence embeddings for the `concept_name` column.

In [None]:
documentAssembler = DocumentAssembler()\
    .setInputCol("concept_name")\
    .setOutputCol("sentence")

bert_embeddings = BertSentenceEmbeddings.pretrained("sbiobert_base_cased_mli", "en", "clinical/models")\
    .setInputCols(["sentence"])\
    .setOutputCol("bert_embeddings")
    # .setCaseSensitive(False)

embeddings_pipeline = Pipeline(stages = [
    documentAssembler,
    bert_embeddings])

embeddings_model = embeddings_pipeline.fit(aap_train_sdf)
snomed_data = embeddings_model.transform(aap_train_sdf)

sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]


In [None]:
snomed_data.show()

+----------------+--------------------+--------------------+--------------------+--------------------+
|       conceptId|        ground_truth|        concept_name|            sentence|     bert_embeddings|
+----------------+--------------------+--------------------+--------------------+--------------------+
|       108367008|Dislocation of joint|Dislocation of joint|[{document, 0, 19...|[{sentence_embedd...|
|3384011000036100|           Arthrotec|           Arthrotec|[{document, 0, 8,...|[{sentence_embedd...|
|       166717003|Serum creatinine ...|Serum creatinine ...|[{document, 0, 22...|[{sentence_embedd...|
|3877011000036101|             Lipitor|             Lipitor|[{document, 0, 6,...|[{sentence_embedd...|
|       402234004|         Foot eczema|         Foot eczema|[{document, 0, 10...|[{sentence_embedd...|
|       404640003|           Dizziness|           Dizziness|[{document, 0, 8,...|[{sentence_embedd...|
|       271681002|        Stomach ache|        Stomach ache|[{document, 0

The training dataframe now has a `bert_embeddings` column, which we will use as the input for training the model.
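
Each cell in `bert_embeddings` is an array of annotations. `DocumentAssembler` wraps the whole `concept_name` value as a single document, so each row should contain one `sentence_embeddings` annotation, and that annotation's `embeddings` field holds the vector. The pure-Python sketch below mimics that structure with a dictionary. The helper `row_vector` is hypothetical, the metadata is left empty for simplicity, and the 3-dimensional vector is invented (real sbiobert vectors have hundreds of dimensions). In Spark you can select the same field with an expression such as `bert_embeddings[0].embeddings`.

```python
def row_vector(annotations):
    """Hypothetical helper: return the vector of the single annotation in a bert_embeddings cell."""
    if len(annotations) != 1:
        raise ValueError(f"expected exactly one annotation, got {len(annotations)}")
    return annotations[0]["embeddings"]

# Simplified stand-in for one bert_embeddings cell; the vector values are invented.
cell = [{
    "annotatorType": "sentence_embeddings",
    "begin": 0,
    "end": 19,  # "Dislocation of joint" spans characters 0..19
    "result": "Dislocation of joint",
    "metadata": {},
    "embeddings": [0.12, -0.40, 0.33],
}]
vec = row_vector(cell)
```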

# Train SNOMED Model

In [None]:
bertExtractor = SentenceEntityResolverApproach()\
  .setNeighbours(25)\
  .setThreshold(1000)\
  .setInputCols("bert_embeddings")\
  .setNormalizedCol("concept_name")\
  .setLabelCol("conceptId")\
  .setOutputCol('snomed_code')\
  .setDistanceFunction("EUCLIDIAN")\
  .setCaseSensitive(False)\
  .setDatasetInfo("the model version:531")

%time snomed_model = bertExtractor.fit(snomed_data)

CPU times: user 1.95 s, sys: 252 ms, total: 2.2 s
Wall time: 5min 13s
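
Conceptually, the resolver compares the embedding of an input phrase with the stored embeddings of the training `concept_name` values and returns the codes of the closest matches. `setNeighbours(25)` limits how many candidates are kept, and `setDistanceFunction` chooses the metric. The sketch below is a simplified, brute-force NumPy illustration of that idea, not the library's implementation. It assumes `setThreshold` acts as a plain maximum-distance cut-off. The function `knn_resolve` and the toy 2-D vectors are hypothetical.

```python
import numpy as np

def knn_resolve(query, index_vecs, index_codes, k=25, threshold=1000.0, metric="euclidean"):
    """Hypothetical sketch: return up to k (code, distance) pairs closest to `query`.

    Brute-force search; candidates farther than `threshold` are dropped
    (an assumption about what the resolver's threshold means).
    """
    q = np.asarray(query, dtype=float)
    X = np.asarray(index_vecs, dtype=float)
    if metric == "euclidean":
        dists = np.linalg.norm(X - q, axis=1)
    elif metric == "cosine":
        sims = X @ q / (np.linalg.norm(X, axis=1) * np.linalg.norm(q))
        dists = 1.0 - sims
    else:
        raise ValueError(f"unknown metric: {metric}")
    order = np.argsort(dists, kind="stable")[:k]  # nearest first
    return [(index_codes[i], float(dists[i])) for i in order if dists[i] <= threshold]

# Toy index: codes from the training data, with invented 2-D vectors.
codes = ["404640003", "271681002", "76948002"]  # Dizziness, Stomach ache, Severe pain
vecs = [[0.0, 1.0], [1.0, 0.0], [3.0, 4.0]]
res = knn_resolve([0.1, 0.9], vecs, codes, k=2)
```

Brute-force search is enough to show the ranking logic. A production resolver over many concepts would typically rely on a more efficient index.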


In [None]:
# Save the model to disk if you need it later
snomed_model.write().overwrite().save("biobertresolve_snomed_askapatient")

## Test Model

In [None]:
prediction_Model = PipelineModel(stages=[embeddings_model, snomed_model])

aap_test_pred = prediction_Model.transform(aap_test_sdf).cache()
aap_val_pred = prediction_Model.transform(aap_val_sdf).cache()

In [None]:
aap_test_pred.selectExpr(
    "conceptId",
    "concept_name",
    "ground_truth",
    "snomed_code[0].result",
    "snomed_code[0].metadata.resolved_text",
    "snomed_code[0].metadata.all_k_resolutions",
).show(truncate=50)

+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                    concept_name|                        ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|            387603000|                           balance off|balance off:::Impaired mobility:::Reduced mobil...|
|3384011000036100|                       Arthrotec|                           Arthrotec|     3384011000036100|                             Arthrotec|                                         Arthro

In [None]:
aap_val_pred.selectExpr(
    "conceptId",
    "concept_name",
    "ground_truth",
    "snomed_code[0].result",
    "snomed_code[0].metadata.resolved_text",
    "snomed_code[0].metadata.all_k_resolutions",
).show(truncate=50)

+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                  concept_name|         ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       267032009|persisten feeling of tiredness|   Tired all the time|             84229001|                     extreme tiredness|extreme tiredness:::feeling tired a lot:::feeli...|
|        22298006|                  HEART ATTACK|Myocardial infarction|             22298006|                          HEART ATTACH|HEART ATTACH:::HEADACHES:::LIGHT HEADED:::HAIR ...|
|3877011000036101|                       LIPITOR|              Lipitor|     3877

# Train Model with Auxiliary Information

We can also add auxiliary information to the model. Here we add an aux column that contains the ground-truth name of each code. For every code the model returns, it then also returns that code's ground truth in the `all_k_aux_labels` field of the metadata.
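
The aux labels come back as a single `:::`-delimited string, presumably aligned position by position with `all_k_resolutions`. Under that alignment assumption, the minimal sketch below pairs each candidate's resolved text with its aux label. The helper `pair_aux_labels` is hypothetical. The inputs are the fully visible prefixes of the first test row in the aux output further down.

```python
def pair_aux_labels(resolutions, aux_labels, sep=":::"):
    """Hypothetical helper: zip `sep`-delimited resolutions with aux labels (assumes same order)."""
    res = resolutions.split(sep) if resolutions else []
    aux = aux_labels.split(sep) if aux_labels else []
    if len(res) != len(aux):
        raise ValueError(f"length mismatch: {len(res)} resolutions vs {len(aux)} aux labels")
    return list(zip(res, aux))

pairs = pair_aux_labels("balance off:::Impaired mobility",
                        "Impairment of balance:::Impaired mobility")
# → [("balance off", "Impairment of balance"), ("Impaired mobility", "Impaired mobility")]
```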

In [None]:
bertExtractor_aux = SentenceEntityResolverApproach()\
  .setNeighbours(25)\
  .setThreshold(1000)\
  .setInputCols("bert_embeddings")\
  .setNormalizedCol("concept_name")\
  .setLabelCol("conceptId")\
  .setOutputCol('snomed_code')\
  .setDistanceFunction("EUCLIDIAN")\
  .setCaseSensitive(False)\
  .setUseAuxLabel(True)\
  .setAuxLabelCol("ground_truth")

%time snomed_aux_model = bertExtractor_aux.fit(snomed_data)

CPU times: user 1.69 s, sys: 248 ms, total: 1.94 s
Wall time: 6min 46s


In [None]:
# Save the model to disk if you need it later
snomed_aux_model.write().overwrite().save("biobertresolve_snomed_askapatient_aux")

## Test Aux Model

Check the `all_k_aux_labels` field in the metadata.
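
Because the aux column holds ground-truth names, one quick check is whether `ground_truth` appears among the first k aux labels. Below is a minimal sketch. The helper `aux_hit_at_k` is hypothetical, and comparing case-insensitively is a choice made here for robustness. The aux string is the visible prefix of the first validation row in the output below, whose ground truth is `Tired all the time`.

```python
def aux_hit_at_k(ground_truth, aux_labels, k, sep=":::"):
    """Hypothetical helper: True if ground_truth is among the first k aux labels (case-insensitive)."""
    labels = [label.strip().lower() for label in aux_labels.split(sep)[:k]]
    return ground_truth.strip().lower() in labels

aux = "Fatigue:::Tired all the time:::Feeling tired"
hit_at_1 = aux_hit_at_k("Tired all the time", aux, k=1)  # False: top label is "Fatigue"
hit_at_2 = aux_hit_at_k("Tired all the time", aux, k=2)  # True
```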

In [None]:
aux_prediction_Model = PipelineModel(stages=[embeddings_model, snomed_aux_model])

aap_test_pred_aux = aux_prediction_Model.transform(aap_test_sdf).cache()
aap_val_pred_aux = aux_prediction_Model.transform(aap_val_sdf).cache()

In [None]:
aap_test_pred_aux.selectExpr(
    "conceptId",
    "concept_name",
    "ground_truth",
    "snomed_code[0].result",
    "snomed_code[0].metadata.resolved_text",
    "snomed_code[0].metadata.all_k_resolutions",
    "snomed_code[0].metadata.all_k_aux_labels",
).show(truncate=50)

+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       conceptId|                    concept_name|                        ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|         snomed_code[0].metadata[all_k_aux_labels]|
+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|            387603000|                           balance off|balance off:::Impaired mobility:::Reduced mobil...|Impairment of balance:::Impaired mobility:::

In [None]:
aap_val_pred_aux.selectExpr(
    "conceptId",
    "concept_name",
    "ground_truth",
    "snomed_code[0].result",
    "snomed_code[0].metadata.resolved_text",
    "snomed_code[0].metadata.all_k_resolutions",
    "snomed_code[0].metadata.all_k_aux_labels",
).show(truncate=50)

+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       conceptId|                  concept_name|         ground_truth|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|         snomed_code[0].metadata[all_k_aux_labels]|
+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+--------------------------------------------------+
|       267032009|persisten feeling of tiredness|   Tired all the time|             84229001|                     extreme tiredness|extreme tiredness:::feeling tired a lot:::feeli...|Fatigue:::Tired all the time:::Feeling tired:::...|
|        22298006|                  HEART ATTACK|Myocardial 