![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/13.Snomed_Entity_Resolver_Model_Training.ipynb)

# 13. Snomed Entity Resolver Model Training

In [4]:
import os

# Secret that is later handed to sparknlp_jsl.start(); it is read from the
# SECRET environment variable so it never has to be typed into the notebook.
jsl_secret = os.getenv('SECRET')

import sparknlp
sparknlp_version = sparknlp.version()
import sparknlp_jsl
jsl_version = sparknlp_jsl.version()

# Do not echo the secret itself: cell outputs are stored with the notebook and
# would leak the credential to anyone who can read the file. Only report
# whether a value was found (os.getenv returns None when the variable is unset).
print("SECRET is set" if jsl_secret else "SECRET is not set")

In [None]:
import json
import os
import sparknlp_jsl
import sparknlp
from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession
import sys, time
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.util import *
from sparknlp_jsl.annotator import *

from sparknlp.pretrained import ResourceDownloader
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType

# Spark configuration overrides passed to the session builder below.
params = {
    "spark.driver.memory": "16G",
    "spark.kryoserializer.buffer.max": "2000M",
    "spark.driver.maxResultSize": "2000M",
}

# Start the Spark session through sparknlp_jsl using the secret read earlier.
spark = sparknlp_jsl.start(jsl_secret, params=params)

# Report the library versions, one per line.
print(sparknlp.version(), sparknlp_jsl.version(), sep="\n")

## Load datasets

In [6]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.test.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.train.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.validation.txt

In [7]:
import pandas as pd

cols = ["conceptId","_term","term"]

# Decode every split the same way. Previously only the train split was read as
# ISO-8859-1 while test/validation fell back to the pandas default codec, so
# the same non-ASCII byte could decode differently (or raise) per split.
AAP_ENCODING = "ISO-8859-1"


def load_aap_split(path, encoding=AAP_ENCODING):
    """Read one headerless, tab-separated AskAPatient split into pandas.

    Args:
        path: Location of the split file.
        encoding: Text encoding used to decode the file.

    Returns:
        A DataFrame whose columns are renamed to ``cols`` and whose
        ``conceptId`` values are converted to ``str``.

    Raises:
        ValueError: If the file does not have exactly ``len(cols)`` columns
            (raised by the column assignment).
    """
    split_df = pd.read_csv(path, sep="\t", encoding=encoding, header=None)
    split_df.columns = cols
    # Concept IDs are identifiers, not quantities: keep them as strings.
    split_df["conceptId"] = split_df.conceptId.apply(str)
    return split_df


aap_tr = load_aap_split("AskAPatient.fold-0.train.txt")
aap_ts = load_aap_split("AskAPatient.fold-0.test.txt")
aap_vl = load_aap_split("AskAPatient.fold-0.validation.txt")

In [8]:
# Preview the first five validation rows to sanity-check the parsed columns.
aap_vl.head(n=5)

Unnamed: 0,conceptId,_term,term
0,267032009,Tired all the time,persisten feeling of tiredness
1,22298006,Myocardial infarction,HEART ATTACK
2,3877011000036101,Lipitor,LIPITOR
3,415690000,Sweating,sweated
4,248491001,Swollen knee,swelling at knee


In [9]:
# Turn each pandas split into a Spark DataFrame (order: train, test, validation).
aap_train_sdf, aap_test_sdf, aap_val_sdf = [
    spark.createDataFrame(pandas_split)
    for pandas_split in (aap_tr, aap_ts, aap_vl)
]

# Sentence Entity Resolver (sBioBert sentence embeddings) (after v2.7)

In [10]:
# Print the first 20 training rows, with long cell values truncated.
aap_train_sdf.show(n=20, truncate=True)

+----------------+--------------------+--------------------+
|       conceptId|               _term|                term|
+----------------+--------------------+--------------------+
|       108367008|Dislocation of joint|Dislocation of joint|
|3384011000036100|           Arthrotec|           Arthrotec|
|       166717003|Serum creatinine ...|Serum creatinine ...|
|3877011000036101|             Lipitor|             Lipitor|
|       402234004|         Foot eczema|         Foot eczema|
|       404640003|           Dizziness|           Dizziness|
|       271681002|        Stomach ache|        Stomach ache|
|        76948002|         Severe pain|         Severe pain|
|        36031001|        Burning feet|        Burning feet|
|        76948002|         Severe pain|         Severe pain|
|        42399005|       Renal failure|       Renal failure|
|       288227007|Myalgia/myositis ...|Myalgia/myositis ...|
|       419723007|       Mentally dull|       Mentally dull|
|       248490000|    Bl

In [11]:
# Print the column names and types of the Spark training DataFrame.
aap_train_sdf.printSchema()

root
 |-- conceptId: string (nullable = true)
 |-- _term: string (nullable = true)
 |-- term: string (nullable = true)



In [12]:
# Stage 1: read the "_term" column and expose it as the "sentence" column.
documentAssembler = (
    DocumentAssembler()
    .setInputCol("_term")
    .setOutputCol("sentence")
)

# Stage 2: pretrained sentence-embedding model applied to "sentence";
# its vectors are written to the "bert_embeddings" column.
bert_embeddings = (
    BertSentenceEmbeddings.pretrained("sbiobert_base_cased_mli", "en", "clinical/models")
    .setInputCols(["sentence"])
    .setOutputCol("bert_embeddings")
)

snomed_training_pipeline = Pipeline(stages=[documentAssembler, bert_embeddings])

# Fit on the training split, then embed that same split for resolver training.
snomed_training_model = snomed_training_pipeline.fit(aap_train_sdf)
snomed_data = snomed_training_model.transform(aap_train_sdf)


sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]


In [13]:
bertExtractor = SentenceEntityResolverApproach()\
  .setNeighbours(25)\
  .setThreshold(1000)\
  .setInputCols("bert_embeddings")\
  .setNormalizedCol("_term")\
  .setLabelCol("conceptId")\
  .setOutputCol('snomed_code')\
  .setDistanceFunction("EUCLIDIAN")\
  .setCaseSensitive(False)

%time snomed_model = bertExtractor.fit(snomed_data)

CPU times: user 9.59 s, sys: 1.06 s, total: 10.6 s
Wall time: 30min


In [14]:
# Persist the trained resolver so it can be reloaded later without refitting;
# an existing directory with the same name is overwritten.
(
    snomed_model.write()
    .overwrite()
    .save("biobertresolve_snomed_askapatient")
)

In [15]:
# Chain the fitted embedding pipeline and the trained resolver into one model.
prediction_Model = PipelineModel(stages=[snomed_training_model, snomed_model])

# NOTE(review): the DocumentAssembler inside snomed_training_model reads the
# "_term" column, so these predictions are computed from "_term" rather than
# from "term" - confirm this is the intended input before treating the
# results as an evaluation.
# Score every split and cache the results because they are displayed below.
aap_train_pred, aap_test_pred, aap_val_pred = [
    prediction_Model.transform(split_sdf).cache()
    for split_sdf in (aap_train_sdf, aap_test_sdf, aap_val_sdf)
]

In [16]:
# Test split: gold columns next to the first resolution's code, resolved text
# and candidate list; cell values are truncated to 50 characters.
aap_test_pred.selectExpr(
    "conceptId",
    "term",
    "_term",
    "snomed_code[0].result",
    "snomed_code[0].metadata.resolved_text",
    "snomed_code[0].metadata.all_k_resolutions",
).show(truncate=50)

+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                            term|                               _term|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|            108367008|                  Dislocation of joint|Dislocation of joint:::Joint injury:::Disorder ...|
|3384011000036100|                       Arthrotec|                           Arthrotec|     3384011000036100|                             Arthrotec|                                         Arthro

In [17]:
# Validation split: gold columns next to the first resolution's code, resolved
# text and candidate list; cell values are truncated to 50 characters.
aap_val_pred.selectExpr(
    "conceptId",
    "term",
    "_term",
    "snomed_code[0].result",
    "snomed_code[0].metadata.resolved_text",
    "snomed_code[0].metadata.all_k_resolutions",
).show(truncate=50)

+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                          term|                _term|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       267032009|persisten feeling of tiredness|   Tired all the time|            267032009|                    Tired all the time|Tired all the time:::Always sleepy:::Constant pain|
|        22298006|                  HEART ATTACK|Myocardial infarction|             22298006|                 Myocardial infarction|                             Myocardial infarction|
|3877011000036101|                       LIPITOR|              Lipitor|     3877