![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/13.Snomed_Entity_Resolver_Model_Training.ipynb)

# 13. Snomed Entity Resolver Model Training

In [1]:
import json

from google.colab import files

license_keys = files.upload()

with open(list(license_keys.keys())[0]) as f:
    license_keys = json.load(f)

license_keys.keys()


Saving license_keys_272.json to license_keys_272.json


dict_keys(['PUBLIC_VERSION', 'JSL_VERSION', 'SECRET', 'SPARK_NLP_LICENSE', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'SPARK_OCR_LICENSE', 'SPARK_OCR_SECRET'])

In [None]:
import os

# Install java
! apt-get update -qq
! apt-get install -y openjdk-8-jdk-headless -qq > /dev/null

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]
! java -version

secret = license_keys['SECRET']

os.environ['SPARK_NLP_LICENSE'] = license_keys['SPARK_NLP_LICENSE']
os.environ['AWS_ACCESS_KEY_ID']= license_keys['AWS_ACCESS_KEY_ID']
os.environ['AWS_SECRET_ACCESS_KEY'] = license_keys['AWS_SECRET_ACCESS_KEY']
jsl_version = license_keys['JSL_VERSION']
version = license_keys['PUBLIC_VERSION']

! pip install --ignore-installed -q pyspark==2.4.4

! python -m pip install --upgrade spark-nlp-jsl==$jsl_version  --extra-index-url https://pypi.johnsnowlabs.com/$secret

! pip install --ignore-installed -q spark-nlp==$version

! pip -q install spark-nlp-display

import sparknlp

print (sparknlp.version())

import json
import os
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

import sparknlp_jsl
import sys, os, time
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp.util import *
from sparknlp_jsl.annotator import *

from sparknlp.pretrained import ResourceDownloader

from pyspark.sql import functions as F
from pyspark.ml import Pipeline, PipelineModel

# sparknlp_jsl.start(secret, public=version) if you want to start with different version of public sparknlp
from pyspark.sql.types import StructType, StructField, StringType

params = {"spark.driver.memory":"16G",
"spark.kryoserializer.buffer.max":"2000M",
"spark.driver.maxResultSize":"2000M"}
spark = sparknlp_jsl.start(secret, params=params)

## Load datasets

In [3]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.test.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.train.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/AskAPatient.fold-0.validation.txt

In [4]:
import pandas as pd

cols = ["conceptId","_term","term"]

aap_tr = pd.read_csv("AskAPatient.fold-0.train.txt",sep="\t",encoding="ISO-8859-1",header=None)
aap_tr.columns = cols
aap_tr["conceptId"] = aap_tr.conceptId.apply(str)

aap_ts = pd.read_csv("AskAPatient.fold-0.test.txt",sep="\t",header=None)
aap_ts.columns = cols
aap_ts["conceptId"] = aap_ts.conceptId.apply(str)

aap_vl = pd.read_csv("AskAPatient.fold-0.validation.txt",sep="\t",header=None)
aap_vl.columns = cols
aap_vl["conceptId"] = aap_vl.conceptId.apply(str)

In [5]:
aap_vl.head()

Unnamed: 0,conceptId,_term,term
0,267032009,Tired all the time,persisten feeling of tiredness
1,22298006,Myocardial infarction,HEART ATTACK
2,3877011000036101,Lipitor,LIPITOR
3,415690000,Sweating,sweated
4,248491001,Swollen knee,swelling at knee


In [6]:
aap_train_sdf = spark.createDataFrame(aap_tr)
aap_test_sdf = spark.createDataFrame(aap_ts)
aap_val_sdf = spark.createDataFrame(aap_vl)

# Chunk Entity Resolver (Glove Embeddings)

## Create Training Pipeline

In [7]:

document = DocumentAssembler()\
    .setInputCol("term")\
    .setOutputCol("document")

chunk = Doc2Chunk()\
    .setInputCols("document")\
    .setOutputCol("chunk")\

token = Tokenizer()\
    .setInputCols(['document'])\
    .setOutputCol('token')

embeddings = WordEmbeddingsModel.pretrained("embeddings_healthcare_100d", "en", "clinical/models")\
      .setInputCols(["document", "token"])\
      .setOutputCol("embeddings")

chunk_emb = ChunkEmbeddings()\
      .setInputCols("chunk", "embeddings")\
      .setOutputCol("chunk_embeddings")

snomed_training_pipeline = Pipeline(
    stages = [
    document,
    chunk,
    token,
    embeddings,
    chunk_emb])

snomed_training_model = snomed_training_pipeline.fit(aap_train_sdf)

snomed_data = snomed_training_model.transform(aap_train_sdf).cache()


embeddings_healthcare_100d download started this may take some time.
Approximate size to download 475.8 MB
[OK!]


In [8]:
snomed_extractor = ChunkEntityResolverApproach() \
    .setInputCols("token", "chunk_embeddings") \
    .setOutputCol("recognized") \
    .setNeighbours(1000) \
    .setAlternatives(25) \
    .setNormalizedCol("_term") \
    .setLabelCol("conceptId") \
    .setEnableWmd(True).setEnableTfidf(True).setEnableJaccard(True)\
    .setEnableSorensenDice(True).setEnableJaroWinkler(True).setEnableLevenshtein(True)\
    .setDistanceWeights([1, 2, 2, 1, 1, 1]) \
    .setAllDistancesMetadata(True)\
    .setPoolingStrategy("MAX") \
    .setThreshold(1e32)

In [9]:
%time model = snomed_extractor.fit(snomed_data)

CPU times: user 85 ms, sys: 19.1 ms, total: 104 ms
Wall time: 25.7 s


## Prediction Pipeline

In [10]:
prediction_Model = PipelineModel(stages=[snomed_training_model, model])

In [11]:

aap_train_pred= prediction_Model.transform(aap_train_sdf).cache()
aap_test_pred= prediction_Model.transform(aap_test_sdf).cache()
aap_val_pred= prediction_Model.transform(aap_val_sdf).cache()

In [12]:
aap_test_pred.selectExpr("conceptId","term","_term","recognized[0].result","recognized[0].metadata.resolved_text","recognized[0].metadata.all_k_resolutions").show(truncate=50)

+----------------+--------------------------------+------------------------------------+--------------------+-------------------------------------+--------------------------------------------------+
|       conceptId|                            term|                               _term|recognized[0].result|recognized[0].metadata[resolved_text]|         recognized[0].metadata[all_k_resolutions]|
+----------------+--------------------------------+------------------------------------+--------------------+-------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|           304297005| Decreased range of shoulder movement|Decreased range of shoulder movement:::Increase...|
|3384011000036100|                       Arthrotec|                           Arthrotec|    3384011000036100|                            Arthrotec|Arthrotec:::Celebrex 200 mg capsule: hard:::Cel...|
|    

In [13]:
aap_test_pred.selectExpr("conceptId","term","_term","recognized[0].result","recognized[0].metadata.resolved_text","recognized[0].metadata.all_k_resolutions").show(truncate=50)

+----------------+--------------------------------+------------------------------------+--------------------+-------------------------------------+--------------------------------------------------+
|       conceptId|                            term|                               _term|recognized[0].result|recognized[0].metadata[resolved_text]|         recognized[0].metadata[all_k_resolutions]|
+----------------+--------------------------------+------------------------------------+--------------------+-------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|           304297005| Decreased range of shoulder movement|Decreased range of shoulder movement:::Increase...|
|3384011000036100|                       Arthrotec|                           Arthrotec|    3384011000036100|                            Arthrotec|Arthrotec:::Celebrex 200 mg capsule: hard:::Cel...|
|    

## Train Using the entire dataset

In [14]:
all_data = aap_train_sdf.union(aap_test_sdf).union(aap_val_sdf)

snomed_training_model = snomed_training_pipeline.fit(all_data)

snomed_data = snomed_training_model.transform(all_data).cache()

%time model = snomed_extractor.fit(snomed_data)

CPU times: user 81.5 ms, sys: 23.3 ms, total: 105 ms
Wall time: 16.8 s


In [15]:
model.write().overwrite().save("chunkresolve_snomed_askapatient_hc_100d")

## Prediction on random texts

In [16]:
# instead of loading a large clinical_ner, we will use the one coming from 100d pretrained_pipeline
from sparknlp.pretrained import PretrainedPipeline

pp_ner = PretrainedPipeline('explain_clinical_doc_carp', 'en', 'clinical/models')

explain_clinical_doc_carp download started this may take some time.
Approx size to download 526.5 MB
[OK!]


In [17]:
pp_ner.model.stages

[DocumentAssembler_8aeb50463a0d,
 SentenceDetector_635a56ed49ab,
 REGEX_TOKENIZER_6f0bd3b85024,
 WORD_EMBEDDINGS_MODEL_a5c1afb0b657,
 POS_be8d41751649,
 NerDLModel_706522935b2e,
 NerConverter_b818c367ba56,
 dependency_68159e3d6dac,
 NerDLModel_01b90ff03d9e,
 NerConverter_335d7d4208fc,
 RelationExtractionModel_0a71121bf321,
 ASSERTION_DL_941a00a50db4]

In [18]:
ner_100d = pp_ner.model.stages[-4]

In [19]:
documentAssembler = DocumentAssembler()\
  .setInputCol("term")\
  .setOutputCol("document")

# Sentence Detector annotator, processes various sentences per line

sentenceDetector = SentenceDetector()\
  .setInputCols(["document"])\
  .setOutputCol("sentence")\
  .setCustomBounds([","])

# Tokenizer splits words in a relevant format for NLP

tokenizer = Tokenizer()\
  .setInputCols(["sentence"])\
  .setOutputCol("raw_token")\

stopwords = StopWordsCleaner()\
  .setInputCols(["raw_token"])\
  .setOutputCol("token")

word_embeddings = WordEmbeddingsModel.pretrained("embeddings_healthcare_100d", "en", "clinical/models")\
      .setInputCols(["document", "token"])\
      .setOutputCol("embeddings")

ner_100d\
  .setInputCols(["sentence", "token", "embeddings"]) \
  .setOutputCol("ner")

snomed_ner_converter = NerConverterInternal() \
  .setInputCols(["sentence", "token", "ner"]) \
  .setOutputCol("greedy_chunk")\
  .setWhiteList(['PROBLEM','TEST'])

chunk_embeddings = ChunkEmbeddings()\
  .setInputCols('greedy_chunk', 'embeddings')\
  .setOutputCol('chunk_embeddings')

snomed_resolver = \
    ChunkEntityResolverModel.load("chunkresolve_snomed_askapatient_hc_100d")\
    .setInputCols("token","chunk_embeddings").setOutputCol("snomed_resolution")

pipeline_snomed = Pipeline(
    stages = [
    documentAssembler,
    sentenceDetector,
    tokenizer,
    stopwords,
    word_embeddings,
    ner_100d,
    snomed_ner_converter,
    chunk_embeddings,
    snomed_resolver
  ])

empty_data = spark.createDataFrame([['']]).toDF("term")

model_snomed = pipeline_snomed.fit(empty_data)


embeddings_healthcare_100d download started this may take some time.
Approximate size to download 475.8 MB
[OK!]


In [20]:
model_snomed_lp = LightPipeline(model_snomed)


In [21]:

result = model_snomed_lp.annotate('I have a biceps muscle pain and extreme muscle pain in shoulders')

list(zip(result['greedy_chunk'],result['snomed_resolution']))

[('biceps muscle pain', '288227007'), ('extreme muscle pain', '76948002')]

In [22]:

result = model_snomed_lp.annotate('I have a flu and a headache')

list(zip(result['greedy_chunk'],result['snomed_resolution']))

[('flu and a headache', '6142004')]

In [23]:
from pyspark.sql import functions as F

snomed_output = model_snomed.transform(spark.createDataFrame([['I have a biceps muscle pain and extreme muscle pain in shoulders']]).toDF("term"))

snomed_output.select(F.explode(F.arrays_zip("greedy_chunk.result","greedy_chunk.metadata","snomed_resolution.result","snomed_resolution.metadata")).alias("snomed_result")) \
    .select(F.expr("snomed_result['0']").alias("chunk"),
            F.expr("snomed_result['1'].entity").alias("entity"),
            F.expr("snomed_result['3'].all_k_resolutions").alias("target_text"),
            F.expr("snomed_result['2']").alias("code"),
            F.expr("snomed_result['3'].confidence").alias("confidence")).show(truncate = 100)

+-------------------+-------+----------------------------------------------------------------------------------------------------+---------+----------+
|              chunk| entity|                                                                                         target_text|     code|confidence|
+-------------------+-------+----------------------------------------------------------------------------------------------------+---------+----------+
| biceps muscle pain|PROBLEM|Myalgia/myositis - upper arm:::Myalgia:::Neck pain:::Myalgia/myositis - shoulder:::Backache:::Foo...|288227007|    0.0915|
|extreme muscle pain|PROBLEM|Severe pain:::Muscle fatigue:::Muscle weakness:::Myopathy:::Abdominal pain:::Constant pain:::Myal...| 76948002|    0.1527|
+-------------------+-------+----------------------------------------------------------------------------------------------------+---------+----------+



# Sentence Entity Resolver (BioBert sentence embeddings) (after v2.7)

In [24]:
aap_train_sdf.show()

+----------------+--------------------+--------------------+
|       conceptId|               _term|                term|
+----------------+--------------------+--------------------+
|       108367008|Dislocation of joint|Dislocation of joint|
|3384011000036100|           Arthrotec|           Arthrotec|
|       166717003|Serum creatinine ...|Serum creatinine ...|
|3877011000036101|             Lipitor|             Lipitor|
|       402234004|         Foot eczema|         Foot eczema|
|       404640003|           Dizziness|           Dizziness|
|       271681002|        Stomach ache|        Stomach ache|
|        76948002|         Severe pain|         Severe pain|
|        36031001|        Burning feet|        Burning feet|
|        76948002|         Severe pain|         Severe pain|
|        42399005|       Renal failure|       Renal failure|
|       288227007|Myalgia/myositis ...|Myalgia/myositis ...|
|       419723007|       Mentally dull|       Mentally dull|
|       248490000|    Bl

In [25]:
aap_train_sdf.printSchema()

root
 |-- conceptId: string (nullable = true)
 |-- _term: string (nullable = true)
 |-- term: string (nullable = true)



In [26]:

documentAssembler = DocumentAssembler()\
  .setInputCol("_term")\
  .setOutputCol("sentence")

bert_embeddings = BertSentenceEmbeddings.pretrained("sent_biobert_pubmed_base_cased")\
  .setInputCols(["sentence"])\
  .setOutputCol("bert_embeddings")

snomed_training_pipeline = Pipeline(
    stages = [
    documentAssembler,
    bert_embeddings])

snomed_training_model = snomed_training_pipeline.fit(aap_train_sdf)

snomed_data = snomed_training_model.transform(aap_train_sdf)


sent_biobert_pubmed_base_cased download started this may take some time.
Approximate size to download 386.4 MB
[OK!]


In [27]:
bertExtractor = SentenceEntityResolverApproach()\
  .setNeighbours(25)\
  .setThreshold(1000)\
  .setInputCols("bert_embeddings")\
  .setNormalizedCol("_term")\
  .setLabelCol("conceptId")\
  .setOutputCol('snomed_code')\
  .setDistanceFunction("EUCLIDIAN")\
  .setCaseSensitive(False)

%time snomed_model = bertExtractor.fit(snomed_data)

CPU times: user 499 ms, sys: 99.8 ms, total: 599 ms
Wall time: 1h 8min 22s


In [28]:
# save if you will need that later
snomed_model.write().overwrite().save("biobertresolve_snomed_askapatient")

In [29]:
prediction_Model = PipelineModel(stages=[snomed_training_model, snomed_model])

aap_train_pred= prediction_Model.transform(aap_train_sdf).cache()
aap_test_pred= prediction_Model.transform(aap_test_sdf).cache()
aap_val_pred= prediction_Model.transform(aap_val_sdf).cache()

In [30]:
aap_test_pred.selectExpr("conceptId","term","_term","snomed_code[0].result","snomed_code[0].metadata.resolved_text","snomed_code[0].metadata.all_k_resolutions").show(truncate=50)

+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                            term|                               _term|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+--------------------------------+------------------------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       108367008|                     dislocating|                Dislocation of joint|            108367008|                  Dislocation of joint|Dislocation of joint:::Disorder of joint of sho...|
|3384011000036100|                       Arthrotec|                           Arthrotec|     3384011000036100|                             Arthrotec|                                         Arthro

In [31]:
aap_val_pred.selectExpr("conceptId","term","_term","snomed_code[0].result","snomed_code[0].metadata.resolved_text","snomed_code[0].metadata.all_k_resolutions").show(truncate=50)

+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       conceptId|                          term|                _term|snomed_code[0].result|snomed_code[0].metadata[resolved_text]|        snomed_code[0].metadata[all_k_resolutions]|
+----------------+------------------------------+---------------------+---------------------+--------------------------------------+--------------------------------------------------+
|       267032009|persisten feeling of tiredness|   Tired all the time|            267032009|                    Tired all the time|          Tired all the time:::Tightness in throat|
|        22298006|                  HEART ATTACK|Myocardial infarction|             22298006|                 Myocardial infarction|                             Myocardial infarction|
|3877011000036101|                       LIPITOR|              Lipitor|     3877