In [1]:
#Please make sure you have SparkNLP 2.4.2 and SparkNLP Enterprise 2.4.2

In [2]:
import sys, time

In [3]:
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.pretrained import ResourceDownloader
import pyspark.sql.functions as F
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.feature import StringIndexerModel
from pyspark.ml.classification import OneVsRestModel

In [4]:
concepts = concepts = spark.read.format("csv").option("header","true").load("../../../data/resolution/rxnorm_sample.csv")\
.withColumn("STR", F.expr("lower(STR)"))

In [5]:
tokenizer_chars = ["'",",","/"," ",".","|","@","#","%","&","$","[","]","(",")","-",";","="]

In [6]:
docAssembler = DocumentAssembler().setInputCol("STR").setOutputCol("document")

tokenizer = Tokenizer().setInputCols("document").setOutputCol("token")\
    .setSplitChars(tokenizer_chars)

pipelineModel = Pipeline().setStages([docAssembler, tokenizer]).fit(concepts)

In [7]:
embeddingsModel = WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models")\
    .setInputCols("document", "token")\
    .setOutputCol("embeddings")

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]


In [8]:
doc2Chunk = Doc2Chunk().setInputCols("document").setOutputCol("chunk")

chunkEmbeddings = ChunkEmbeddings()\
    .setInputCols("chunk", "embeddings")\
    .setOutputCol("chunk_embeddings")

pipelineChunkEmbeddings = PipelineModel([doc2Chunk, chunkEmbeddings])

In [9]:
concepts_embedded = PipelineModel([pipelineModel, embeddingsModel, pipelineChunkEmbeddings]).transform(concepts)

In [10]:
concepts_embedded.write.mode("overwrite").save("data/rxnorm_concepts_embedded")

In [11]:
concepts_embedded = spark.read.load("data/rxnorm_concepts_embedded")

In [12]:
#Let's check embeddings coverage
concepts_embedded.selectExpr("STY_TTY","explode(embeddings) as embs")\
.selectExpr("STY_TTY","case when embs.metadata.isOOV=='false' then 1 else 0 end as coverage")\
.groupby("STY_TTY").agg(F.expr("avg(coverage) as cov")).orderBy("cov").toPandas()["cov"].mean()

0.9571541962682356

In [13]:
#word_distribution = concepts_embedded.selectExpr("explode(token.result) as word").groupby("word").count()
#word_distribution.orderBy("count",ascending=False).show(100, False)

In [14]:
#word_distribution.count()

In [15]:
#RxNorm Resolution
rxnorm_resolver_l1 = DocumentLogRegClassifierModel.pretrained("resolve_rxnorm_clinical_l1", "en", "clinical/models")\
    .setInputCols("token").setOutputCol("partition")
rxnorm_resolver_l2 = ResourceDownloader.downloadPipeline("resolve_rxnorm_clinical_l2", "en", "clinical/models")
rxnorm_resolver_l2.stages[-1].setInputCols("partition","token","chunk_embeddings")

resolve_rxnorm_clinical_l1 download started this may take some time.
Approximate size to download 7.5 MB
[OK!]
resolve_rxnorm_clinical_l2 download started this may take some time.
Approx size to download 276.1 MB
[OK!]


ChunkEntityResolverSelector_82f461d80b1e

In [16]:
fullPipeline = PipelineModel([rxnorm_resolver_l1, RecursivePipelineModel(rxnorm_resolver_l2)])

In [17]:
start = time.time()
transformed_full = fullPipeline.transform(concepts_embedded)

In [18]:
predicted = transformed_full.withColumn("prediction", F.expr("partition.result[0]")).cache()
metrics = predicted.withColumn("ok",F.expr("case when prediction==STY_TTY then 1 else 0 end"))\
                                   .groupby("STY_TTY").agg(F.expr("avg(ok) as recall"), F.expr("count(ok) as tr_cnt"))\
                                    .join(
predicted.withColumn("ok",F.expr("case when prediction==STY_TTY then 1 else 0 end"))\
                                   .groupby("prediction").agg(F.expr("avg(ok) as precision")),F.col("STY_TTY")==F.col("prediction")
).withColumn("f1", F.expr("2*precision*recall/(precision+recall)")).orderBy("f1")\
.selectExpr("STY_TTY","tr_cnt","round(precision,3) as train_precision","round(recall,3) as train_recall","round(f1, 3) as train_f1")

In [19]:
metrics.show(100, False)

+------------------------------------------------------+------+---------------+------------+--------+
|STY_TTY                                               |tr_cnt|train_precision|train_recall|train_f1|
+------------------------------------------------------+------+---------------+------------+--------+
|Manufactured Object                                   |1     |0.0            |0.0         |null    |
|Clinical Drug Semantic Clinical Drug                  |32    |0.0            |0.0         |null    |
|Clinical Drug Semantic branded drug                   |23    |0.2            |0.043       |0.071   |
|Clinical Drug Semantic clinical drug and form         |10    |0.237          |0.9         |0.375   |
|Biomedical or Dental Material                         |5     |0.4            |0.4         |0.4     |
|Immunologic Factor                                    |2     |0.5            |0.5         |0.5     |
|Clinical Drug Semantic clinical drug group            |19    |0.875          |0.3

In [20]:
with_alternatives = predicted\
    .withColumn("resolution",F.expr("split(rxnorm_code.metadata[0]['all_k_results'],':|:')"))

In [21]:
evaled = with_alternatives\
    .withColumn("good", F.expr("case when RXCUI=rxnorm_code.result[0] then 1 else 0 end"))\
    .withColumn("hat5", F.expr("case when array_contains(slice(resolution, 1, 5), RXCUI) then 1 else 0 end"))\
    .withColumn("hat10", F.expr("case when array_contains(slice(resolution, 1, 10), RXCUI) then 1 else 0 end"))\
    .withColumn("hat20", F.expr("case when array_contains(slice(resolution, 1, 20), RXCUI) then 1 else 0 end"))\
    .withColumn("hat30", F.expr("case when array_contains(slice(resolution, 1, 30), RXCUI) then 1 else 0 end"))\
    .withColumn("hat500", F.expr("case when array_contains(slice(resolution, 1, 500), RXCUI) then 1 else 0 end"))

In [22]:
evaled.groupby("STY_TTY").agg(
    F.mean("good"), 
    F.mean("hat5"), 
    F.mean("hat10"), 
    F.mean("hat20"), 
    F.mean("hat30"), 
    F.mean("hat500"), 
    F.count("good")).orderBy("count(good)", ascending=False)\
.selectExpr("STY_TTY",
            "round(`avg(good)`, 2) as good",
            "round(`avg(hat5)`, 2) as hat5",
            "round(`avg(hat10)`, 2) as hat10",
            "round(`avg(hat20)`, 2) as hat20",
            "round(`avg(hat30)`, 2) as hat30",
            "round(`avg(hat500)`, 2) as hat500",
            "`count(good)` as total")\
.show(100,False)

+------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|STY_TTY                                               |good|hat5|hat10|hat20|hat30|hat500|total|
+------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|Clinical Drug Clinical Drug                           |0.61|0.7 |0.7  |0.7  |0.7  |0.7   |69   |
|Pharmacologic Substance                               |0.78|0.78|0.78 |0.78 |0.78 |0.78  |46   |
|Medical Device                                        |0.73|0.86|0.95 |0.95 |0.95 |0.95  |37   |
|Clinical Drug Semantic Clinical Drug                  |0.06|0.06|0.06 |0.06 |0.06 |0.06  |32   |
|Clinical Drug Clinical drug name in abbreviated format|0.68|0.75|0.75 |0.75 |0.75 |0.75  |28   |
|Clinical Drug Semantic Drug Component                 |0.71|0.89|0.89 |0.89 |0.89 |0.89  |28   |
|Clinical Drug Semantic branded drug                   |0.0 |0.04|0.04 |0.04 |0.04 |0.04  |23   |
|Clinical Drug Seman

In [24]:
print(round((time.time()-start)/60, 2), "minutes")

0.26 minutes
