In [1]:
#Please make sure you have SparkNLP 2.4.2 and SparkNLP Enterprise 2.4.2

In [2]:
import sys, time

In [3]:
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.pretrained import ResourceDownloader
import pyspark.sql.functions as F
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.feature import StringIndexerModel
from pyspark.ml.classification import OneVsRestModel

In [4]:
concepts = concepts = spark.read.format("csv").option("header","true").load("../../../data/resolution/snomed_sample.csv")\
.withColumn("term", F.expr("lower(term)"))

In [5]:
tokenizer_chars = ["'",",","/"," ",".","|","@","#","%","&","$","[","]","(",")","-",";","="]

In [6]:
docAssembler = DocumentAssembler().setInputCol("term").setOutputCol("document")

tokenizer = Tokenizer().setInputCols("document").setOutputCol("token")\
    .setSplitChars(tokenizer_chars)

pipelineModel = Pipeline().setStages([docAssembler, tokenizer]).fit(concepts)

In [7]:
embeddingsModel = WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models")\
    .setInputCols("document", "token")\
    .setOutputCol("embeddings")

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]


In [8]:
doc2Chunk = Doc2Chunk().setInputCols("document").setOutputCol("chunk")

chunkEmbeddings = ChunkEmbeddings()\
    .setInputCols("chunk", "embeddings")\
    .setOutputCol("chunk_embeddings")

pipelineChunkEmbeddings = PipelineModel([doc2Chunk, chunkEmbeddings])

In [9]:
concepts_embedded = PipelineModel([pipelineModel, embeddingsModel, pipelineChunkEmbeddings]).transform(concepts)

In [10]:
concepts_embedded.write.mode("overwrite").save("data/concepts_embedded")

In [11]:
concepts_embedded = spark.read.load("data/concepts_embedded")

In [12]:
#Let's check embeddings coverage
concepts_embedded.selectExpr("conceptId","explode(embeddings) as embs")\
.selectExpr("conceptId","case when embs.metadata.isOOV=='false' then 1 else 0 end as coverage")\
.groupby("conceptId").agg(F.expr("avg(coverage) as cov")).orderBy("cov").toPandas()["cov"].mean()

0.9727644122164317

In [13]:
#word_distribution = concepts_embedded.selectExpr("explode(token.result) as word").groupby("word").count()
#word_distribution.orderBy("count",ascending=True).show(100, False)

In [14]:
#word_distribution.count()

In [15]:
#SNOMED Resolution
snomed_resolver_l1 = DocumentLogRegClassifierModel.pretrained("resolve_snomed_clinical_l1", "en", "clinical/models")\
    .setInputCols("token").setOutputCol("partition")
snomed_resolver_l2 = ResourceDownloader.downloadPipeline("resolve_snomed_clinical_l2", "en", "clinical/models")
snomed_resolver_l2.stages[-1].setInputCols("partition","token","chunk_embeddings")

resolve_snomed_clinical_l1 download started this may take some time.
Approximate size to download 15.3 MB
[OK!]
resolve_snomed_clinical_l2 download started this may take some time.
Approx size to download 583.4 MB
[OK!]


ChunkEntityResolverSelector_d41a7a88595b

In [16]:
snomed_resolution = PipelineModel([snomed_resolver_l1, RecursivePipelineModel(snomed_resolver_l2)])

In [17]:
start = time.time()
transformed_full = snomed_resolution.transform(concepts_embedded)

In [18]:
predicted = transformed_full.withColumn("prediction", F.expr("partition.result[0]")).cache()
metrics = predicted.withColumn("ok",F.expr("case when prediction==topTerm then 1 else 0 end"))\
                                   .groupby("topTerm").agg(F.expr("avg(ok) as recall"), F.expr("count(ok) as tr_cnt"))\
                                    .join(
predicted.withColumn("ok",F.expr("case when prediction==topTerm then 1 else 0 end"))\
                                   .groupby("prediction").agg(F.expr("avg(ok) as precision")),F.col("topTerm")==F.col("prediction")
).withColumn("f1", F.expr("2*precision*recall/(precision+recall)")).orderBy("f1")\
.selectExpr("topTerm","tr_cnt","round(precision,3) as train_precision","round(recall,3) as train_recall","round(f1, 3) as train_f1")

In [20]:
metrics.show(100, False)

+-------------------------------------------------------------+------+---------------+------------+--------+
|topTerm                                                      |tr_cnt|train_precision|train_recall|train_f1|
+-------------------------------------------------------------+------+---------------+------------+--------+
|General clinical state finding (finding)                     |5     |0.667          |0.4         |0.5     |
|Finding by method (finding)                                  |20    |0.727          |0.4         |0.516   |
|Evaluation finding (finding)                                 |58    |0.552          |0.552       |0.552   |
|Neurological finding (finding)                               |22    |0.571          |0.545       |0.558   |
|Wound finding (finding)                                      |9     |0.8            |0.444       |0.571   |
|Special concept (special concept)                            |8     |1.0            |0.5         |0.667   |
|Administrative sta

In [21]:
with_alternatives = predicted\
    .withColumn("resolution",F.expr("split(snomed_code.metadata[0]['all_k_results'],':|:')"))

In [22]:
evaled = with_alternatives\
    .withColumn("good", F.expr("case when conceptId=snomed_code.result[0] then 1 else 0 end"))\
    .withColumn("hat5", F.expr("case when array_contains(slice(resolution, 1, 5), conceptId) then 1 else 0 end"))\
    .withColumn("hat10", F.expr("case when array_contains(slice(resolution, 1, 10), conceptId) then 1 else 0 end"))\
    .withColumn("hat20", F.expr("case when array_contains(slice(resolution, 1, 20), conceptId) then 1 else 0 end"))\
    .withColumn("hat30", F.expr("case when array_contains(slice(resolution, 1, 30), conceptId) then 1 else 0 end"))\
    .withColumn("hat500", F.expr("case when array_contains(slice(resolution, 1, 500), conceptId) then 1 else 0 end"))

In [23]:
evaled.groupby("topTerm").agg(
    F.mean("good"), 
    F.mean("hat5"), 
    F.mean("hat10"), 
    F.mean("hat20"), 
    F.mean("hat30"), 
    F.mean("hat500"), 
    F.count("good")).orderBy("count(good)", ascending=False)\
.selectExpr("topTerm",
            "round(`avg(good)`, 2) as good",
            "round(`avg(hat5)`, 2) as hat5",
            "round(`avg(hat10)`, 2) as hat10",
            "round(`avg(hat20)`, 2) as hat20",
            "round(`avg(hat30)`, 2) as hat30",
            "round(`avg(hat500)`, 2) as hat500",
            "`count(good)` as total")\
.show(100,False)

+-------------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|topTerm                                                      |good|hat5|hat10|hat20|hat30|hat500|total|
+-------------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|Procedure (procedure)                                        |0.91|0.93|0.93 |0.93 |0.93 |0.93  |738  |
|Finding by site (finding)                                    |0.72|0.73|0.74 |0.74 |0.74 |0.74  |722  |
|Body structure (body structure)                              |0.9 |0.92|0.92 |0.92 |0.92 |0.92  |489  |
|Organism (organism)                                          |0.49|0.55|0.57 |0.57 |0.57 |0.57  |345  |
|Disease (disorder)                                           |0.68|0.69|0.69 |0.69 |0.69 |0.69  |325  |
|Substance (substance)                                        |0.73|0.79|0.8  |0.8  |0.8  |0.8   |307  |
|Clinical history and observation findings (finding)   

In [24]:
print(round((time.time()-start)/60, 2), "minutes")

4.35 minutes
