In [1]:
# Requirements: Spark NLP 2.4.5 and Spark NLP Enterprise (sparknlp_jsl) 2.4.6 must be installed.

In [2]:
import os
import sys, time
import sparknlp_jsl
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.pretrained import ResourceDownloader
import pyspark.sql.functions as F
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.feature import StringIndexerModel
from pyspark.ml.classification import OneVsRestModel

In [3]:
# Start a Spark session with the Spark NLP Enterprise license secret.
# The secret is read from the SECRET environment variable rather than being written
# into the notebook, so it never ends up in version control or shared copies.
# A missing variable raises KeyError here instead of failing later with an opaque error.
spark = sparknlp_jsl.start(os.environ["SECRET"])

In [4]:
# SNOMED sample (CSV with a header row). Assumed to provide at least `term`, `conceptId`
# and `topTerm`, which later cells use — TODO confirm against the file.
SNOMED_SAMPLE_PATH = "../../../data/resolution/snomed_sample.csv"

# Load the sample and lower-case `term` before it is assembled into documents.
concepts = spark.read.format("csv").option("header", "true").load(SNOMED_SAMPLE_PATH)\
    .withColumn("term", F.expr("lower(term)"))

In [5]:
# Characters the tokenizer splits on: apostrophe, comma, slash, space, period and common
# punctuation/brackets. Spelled as one string and exploded into single-character items.
tokenizer_chars = list("',/ .|@#%&$[]()-;=")

In [6]:
# Wrap each lower-cased `term` as a document, then split it into tokens on the
# characters listed in `tokenizer_chars`.
docAssembler = (
    DocumentAssembler()
    .setInputCol("term")
    .setOutputCol("document")
)

tokenizer = (
    Tokenizer()
    .setInputCols("document")
    .setOutputCol("token")
    .setSplitChars(tokenizer_chars)
)

# Fit the two-stage pipeline (document -> token) on the concepts frame.
pipelineModel = Pipeline(stages=[docAssembler, tokenizer]).fit(concepts)

In [7]:
# Pretrained clinical word embeddings, computed per token of each document.
embeddingsModel = (
    WordEmbeddingsModel
    .pretrained("embeddings_clinical", "en", "clinical/models")
    .setInputCols("document", "token")
    .setOutputCol("embeddings")
)

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]


In [10]:
# Treat each whole document as a single chunk.
doc2Chunk = Doc2Chunk().setInputCols("document").setOutputCol("chunk")

# Tokens of each chunk, consumed later by the resolver as `chunk_token`.
# NOTE(review): unlike `tokenizer`, no split chars are configured here; confirm the
# resolver is meant to see default tokenization rather than `tokenizer_chars`.
chunkTokenizer = (
    ChunkTokenizer()
    .setInputCols("chunk")
    .setOutputCol("chunk_token")
    .fit(concepts)
)

# Combine the word embeddings that fall inside each chunk into one chunk embedding.
chunkEmbeddings = (
    ChunkEmbeddings()
    .setInputCols("chunk", "embeddings")
    .setOutputCol("chunk_embeddings")
)

# Convenience pipeline: document -> chunk -> chunk_embeddings.
chunk_embedding_stages = [doc2Chunk, chunkEmbeddings]
pipelineChunkEmbeddings = PipelineModel(chunk_embedding_stages)

In [13]:
# Run the whole embedding chain once:
#   term -> document/token -> word embeddings -> chunk -> chunk tokens -> chunk embeddings.
# `chunkEmbeddings` is used directly rather than `pipelineChunkEmbeddings`. That pipeline
# starts with its own `doc2Chunk`, so chaining it after the standalone `doc2Chunk` would
# recompute the `chunk` column a second time.
concepts_embedded = PipelineModel([
    pipelineModel,
    embeddingsModel,
    doc2Chunk,
    chunkTokenizer,
    chunkEmbeddings,
]).transform(concepts)

In [14]:
# Persist the embedded concepts (Spark's default data source), overwriting any previous run,
# so the next cell can reload them without recomputing embeddings.
concepts_embedded.write.save("data/concepts_embedded", mode="overwrite")

In [15]:
# Reload the persisted embeddings; later cells read from this materialized copy.
concepts_embedded = spark.read.load(path="data/concepts_embedded")

In [16]:
# Embeddings coverage: for each conceptId, the share of its tokens whose word embedding is
# in-vocabulary (metadata isOOV == 'false'); the cell's value is the mean of those
# per-concept shares. The final average is computed in Spark, rather than collecting one
# row per concept to the driver with toPandas() (the previous orderBy was also unneeded).
token_coverage = concepts_embedded\
    .selectExpr("conceptId", "explode(embeddings) as embs")\
    .selectExpr("conceptId", "case when embs.metadata.isOOV=='false' then 1 else 0 end as coverage")
concept_coverage = token_coverage.groupby("conceptId").agg(F.expr("avg(coverage) as cov"))
concept_coverage.agg(F.avg("cov")).first()[0]

0.9727644122164317

In [17]:
# Optional exploration: token frequency distribution (uncomment to run).
#word_distribution = concepts_embedded.selectExpr("explode(token.result) as word").groupby("word").count()
#word_distribution.orderBy("count",ascending=True).show(100, False)

In [18]:
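# Optional exploration: number of distinct tokens (needs `word_distribution` from the previous cell).
#word_distribution.count()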

In [21]:
# SNOMED resolution: the pretrained ensemble resolver reads `chunk_token` and
# `chunk_embeddings` and writes its result to `snomed_code`, requesting 300 neighbours
# and 300 alternatives.
snomed_resolution = (
    EnsembleEntityResolverModel
    .pretrained("ensembleresolve_snomed_clinical", "en", "clinical/models")
    .setInputCols("chunk_token", "chunk_embeddings")
    .setOutputCol("snomed_code")
    .setNeighbours(300)
    .setAlternatives(300)
)

ensembleresolve_snomed_clinical download started this may take some time.
Approximate size to download 592.9 MB
[OK!]


In [22]:
# Start the wall-clock timer just before resolution; the elapsed time is printed in the
# last cell, after the evaluation table has been shown.
start = time.time()
transformed_full = snomed_resolution.transform(concepts_embedded)

In [23]:
# `resolution`: the ':::'-separated `all_k_results` metadata of the first resolver
# annotation, split into an array of candidate codes (compared against conceptId below).
all_k_results_expr = "split(snomed_code.metadata[0]['all_k_results'],':::')"
with_alternatives = transformed_full.withColumn("resolution", F.expr(all_k_results_expr))

In [24]:
# Per-row 0/1 evaluation flags.
#   good  : the resolver's top result equals the gold conceptId (top-1 accuracy).
#   hat{k}: the gold conceptId appears among the first k entries of `resolution`.
# NOTE: the resolver above is configured with setAlternatives(300), so `resolution` holds at
# most 300 candidates; `hat500` therefore effectively measures hit-rate within those 300.
HIT_AT_K = (5, 10, 20, 30, 500)

evaled = with_alternatives.withColumn(
    "good", F.expr("case when conceptId=snomed_code.result[0] then 1 else 0 end"))

# One column per cutoff instead of copy-pasted withColumn calls; expressions are identical.
for k in HIT_AT_K:
    evaled = evaled.withColumn(
        "hat{}".format(k),
        F.expr("case when array_contains(slice(resolution, 1, {}), conceptId) then 1 else 0 end".format(k)))

In [25]:
# Per top-level SNOMED term (topTerm): mean of each 0/1 flag, i.e. top-1 accuracy (`good`)
# and hit-rate within the first k alternatives (`hatK`), rounded to 2 decimals, plus the
# number of rows (`total`), largest groups first.
# Explicit aliases replace the previous reliance on Spark's auto-generated column names
# (`avg(good)`, `count(good)`), which are an implementation detail.
metric_cols = ["good", "hat5", "hat10", "hat20", "hat30", "hat500"]
summary_aggs = [F.round(F.mean(c), 2).alias(c) for c in metric_cols] + [F.count("good").alias("total")]

evaled.groupby("topTerm")\
    .agg(*summary_aggs)\
    .orderBy("total", ascending=False)\
    .show(100, False)

+-------------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|topTerm                                                      |good|hat5|hat10|hat20|hat30|hat500|total|
+-------------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|Procedure (procedure)                                        |0.98|0.98|0.98 |0.98 |0.98 |0.98  |738  |
|Finding by site (finding)                                    |0.93|0.93|0.93 |0.93 |0.93 |0.93  |722  |
|Body structure (body structure)                              |0.93|0.94|0.94 |0.95 |0.95 |0.95  |489  |
|Organism (organism)                                          |0.78|0.81|0.82 |0.82 |0.83 |0.83  |345  |
|Disease (disorder)                                           |0.81|0.81|0.81 |0.81 |0.81 |0.81  |325  |
|Substance (substance)                                        |0.89|0.93|0.93 |0.94 |0.95 |0.95  |307  |
|Clinical history and observation findings (finding)   

In [26]:
# Wall-clock minutes since `start` was set (covers resolution and evaluation).
elapsed_minutes = round((time.time() - start) / 60, 2)
print(elapsed_minutes, "minutes")

5.38 minutes
