In [1]:
# Please make sure you have Spark NLP 2.4.5 and Spark NLP Enterprise 2.4.6 installed.

In [2]:
import sys, time
import sparknlp_jsl
from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.pretrained import ResourceDownloader
import pyspark.sql.functions as F
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.feature import StringIndexerModel
from pyspark.ml.classification import OneVsRestModel

In [3]:
# Start the Spark session through the Spark NLP Enterprise helper.
# NOTE(review): "####" is a redacted placeholder (presumably the license
# secret) -- supply the real value at run time and never commit it.
spark = sparknlp_jsl.start(
    "####",
)

In [4]:
# Load the RxNorm sample (CSV with a header row) and lower-case the concept
# string column "STR" in place, so every later stage sees lower-cased text.
concepts = (
    spark.read
    .format("csv")
    .option("header", "true")
    .load("../../../data/resolution/rxnorm_sample.csv")
    .withColumn("STR", F.lower(F.col("STR")))
)

In [5]:
# Single-character strings handed to the tokenizer as split characters:
# quote, comma, slash, space, period, pipe, @, #, %, &, $, square brackets,
# parentheses, hyphen, semicolon and equals sign.
tokenizer_chars = list("',/ .|@#%&$[]()-;=")

In [6]:
# Stage 1: turn the raw "STR" column into "document" annotations.
docAssembler = (
    DocumentAssembler()
    .setInputCol("STR")
    .setOutputCol("document")
)

# Stage 2: tokenize each document into "token", passing tokenizer_chars as
# the tokenizer's split characters.
tokenizer = (
    Tokenizer()
    .setInputCols("document")
    .setOutputCol("token")
    .setSplitChars(tokenizer_chars)
)

# Fit both stages on the concepts table to obtain a reusable PipelineModel.
pipelineModel = Pipeline(stages=[docAssembler, tokenizer]).fit(concepts)

In [7]:
# Pretrained clinical word embeddings ("embeddings_clinical", English, from
# the "clinical/models" location). Reads "document" + "token" and writes the
# per-token vectors to "embeddings". The first call downloads the model
# (the captured output below reports roughly 1.6 GB).
embeddingsModel = (
    WordEmbeddingsModel
    .pretrained("embeddings_clinical", "en", "clinical/models")
    .setInputCols("document", "token")
    .setOutputCol("embeddings")
)

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]


In [8]:
# Produce a "chunk" annotation column from each "document".
doc2Chunk = (
    Doc2Chunk()
    .setInputCols("document")
    .setOutputCol("chunk")
)

# Tokenize the chunks into "chunk_token"; the tokenizer is fitted on the
# concepts table so it can be used as an already-fitted pipeline stage.
chunkTokenizer = (
    ChunkTokenizer()
    .setInputCols("chunk")
    .setOutputCol("chunk_token")
    .fit(concepts)
)

# Combine the "chunk" and "embeddings" columns into "chunk_embeddings".
chunkEmbeddings = (
    ChunkEmbeddings()
    .setInputCols("chunk", "embeddings")
    .setOutputCol("chunk_embeddings")
)

# Bundle the three chunk-level stages; every stage is a transformer or an
# already-fitted model, so a PipelineModel is built directly (no extra fit).
chunk_stages = [doc2Chunk, chunkTokenizer, chunkEmbeddings]
pipelineChunkEmbeddings = PipelineModel(chunk_stages)

In [9]:
# Chain tokenization -> word embeddings -> chunk embeddings into a single
# PipelineModel and apply it to the concepts table.
embedding_stages = [pipelineModel, embeddingsModel, pipelineChunkEmbeddings]
concepts_embedded = PipelineModel(embedding_stages).transform(concepts)

In [10]:
# Persist the embedded concepts to disk, replacing any previous run's output.
(
    concepts_embedded
    .write
    .mode("overwrite")
    .save("data/rxnorm_concepts_embedded")
)

In [11]:
# Re-bind concepts_embedded to the copy just saved on disk, so the cells
# below read the stored result instead of the in-memory pipeline output.
concepts_embedded = (
    spark.read
    .load("data/rxnorm_concepts_embedded")
)

In [12]:
# Let's check embeddings coverage.
# One row per token embedding; a token counts as covered when its metadata
# flag isOOV is 'false'. Coverage is averaged per STY_TTY group first, and
# the displayed value is the mean of those per-group averages.
(
    concepts_embedded
    .selectExpr("STY_TTY", "explode(embeddings) as embs")
    .selectExpr("STY_TTY", "case when embs.metadata.isOOV=='false' then 1 else 0 end as coverage")
    .groupby("STY_TTY")
    .agg(F.expr("avg(coverage) as cov"))
    .orderBy("cov")
    .toPandas()["cov"]
    .mean()
)

0.9471964951859947

In [13]:
#word_distribution = concepts_embedded.selectExpr("explode(token.result) as word").groupby("word").count()
#word_distribution.orderBy("count",ascending=False).show(100, False)

In [14]:
#word_distribution.count()

In [15]:
# RxNorm Resolution
# Pretrained ensemble resolver: reads "chunk_token" and "chunk_embeddings",
# writes its output to "rxnorm_result". Neighbours and alternatives are both
# set to 300.
rxnorm_resolution = (
    EnsembleEntityResolverModel
    .pretrained("ensembleresolve_rxnorm_clinical", "en", "clinical/models")
    .setInputCols("chunk_token", "chunk_embeddings")
    .setOutputCol("rxnorm_result")
    .setNeighbours(300)
    .setAlternatives(300)
)

ensembleresolve_rxnorm_clinical download started this may take some time.
Approximate size to download 783.3 MB
[OK!]


In [16]:
# Start the wall-clock timer; the elapsed time is printed in the last cell.
start = time.time()
# Apply the RxNorm resolver to the embedded concepts.
# NOTE(review): Spark DataFrame transformations are normally lazy, so most of
# the measured time is presumably spent in the later .show() action rather
# than on this line -- confirm before quoting the timing.
transformed_full = rxnorm_resolution.transform(concepts_embedded)

In [17]:
# Split the ':::'-separated 'all_k_results' metadata entry of the first
# rxnorm_result annotation into an array column named "resolution".
all_k_results_expr = F.expr("split(rxnorm_result.metadata[0]['all_k_results'],':::')")
with_alternatives = transformed_full.withColumn("resolution", all_k_results_expr)

In [21]:
# Score every concept against its gold code (RXCUI).
#   good   -> 1 when the top resolver result equals RXCUI, else 0
#   hat<k> -> 1 when RXCUI appears among the first k entries of "resolution"
evaled = with_alternatives.withColumn(
    "good",
    F.expr("case when RXCUI=rxnorm_result.result[0] then 1 else 0 end"),
)

# Each cut-off is used both in the column name and in the slice length, so
# the label can no longer drift from the computation (the "hat500" column
# was previously computed from only the first 100 entries).
# NOTE(review): Spark SQL's slice() is documented as not failing when the
# array is shorter than the requested length, so a cut-off larger than the
# number of returned candidates should mean "anywhere in the list" -- confirm.
for top_k in (5, 10, 20, 30, 500):
    evaled = evaled.withColumn(
        "hat{}".format(top_k),
        F.expr(
            "case when array_contains(slice(resolution, 1, {}), RXCUI) then 1 else 0 end".format(top_k)
        ),
    )

In [24]:
# Accuracy report per STY_TTY group: mean of every hit flag (rounded to two
# decimals) plus the group size, largest groups first, up to 100 rows shown
# without truncating cell contents.
metric_names = ["good", "hat5", "hat10", "hat20", "hat30", "hat500"]
mean_aggs = [F.mean(name) for name in metric_names]
rounded_exprs = ["round(`avg({0})`, 2) as {0}".format(name) for name in metric_names]

(
    evaled
    .groupby("STY_TTY")
    .agg(*(mean_aggs + [F.count("good")]))
    .orderBy(F.desc("count(good)"))
    .selectExpr("STY_TTY", *(rounded_exprs + ["`count(good)` as total"]))
    .show(100, False)
)

+------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|STY_TTY                                               |good|hat5|hat10|hat20|hat30|hat500|total|
+------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|Clinical Drug Clinical Drug                           |0.86|0.88|0.91 |0.91 |0.91 |0.91  |65   |
|Medical Device                                        |1.0 |1.0 |1.0  |1.0  |1.0  |1.0   |49   |
|Pharmacologic Substance                               |0.84|0.84|0.84 |0.84 |0.84 |0.84  |45   |
|Clinical Drug Semantic Drug Component                 |1.0 |1.0 |1.0  |1.0  |1.0  |1.0   |36   |
|Clinical Drug Semantic branded drug group             |0.83|0.83|0.83 |0.83 |0.83 |0.83  |36   |
|Clinical Drug Semantic Clinical Drug                  |0.94|0.94|0.94 |0.94 |0.94 |0.94  |32   |
|Clinical Drug Clinical drug name in abbreviated format|0.9 |0.9 |0.9  |0.9  |0.9  |0.9   |31   |
|Clinical Drug Seman

In [25]:
# Wall-clock time since `start` was set, in minutes rounded to two decimals.
elapsed_seconds = time.time() - start
elapsed_minutes = round(elapsed_seconds / 60, 2)
print(elapsed_minutes, "minutes")

2.99 minutes
