In [1]:
import os, sys, time

# Make local checkouts of spark-nlp and spark-nlp-internal importable. Set SPARKNLP_REPOS_DIR
# to point at your own checkout directory instead of editing this hardcoded default.
REPOS_DIR = os.environ.get("SPARKNLP_REPOS_DIR", "/home/fernandrez/JSL/repos")
for _repo in ("spark-nlp", "spark-nlp-internal"):
    _repo_path = os.path.join(REPOS_DIR, _repo, "python")
    # Guard so re-running this cell does not append the same path again.
    if _repo_path not in sys.path:
        sys.path.append(_repo_path)

In [2]:
import sys, time

from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.pretrained import ResourceDownloader
import pyspark.sql.functions as F
#from pyspark.sql.types import StructType, StructField, StringType
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import HashingTF, IDF
from pyspark.ml.feature import StringIndexerModel
from pyspark.ml.classification import OneVsRestModel

In [3]:
# Load the RxNorm concept sample and lower-case the concept string (STR), which is the
# DocumentAssembler input used for tokenization.
# NOTE(review): `spark` (a SparkSession) is not created in any visible cell; it must already exist in the kernel.
concepts = (
    spark.read.format("csv")
    .option("header", "true")
    .load("../../../../data/resolution/rxnorm_sample.csv")
    .withColumn("STR", F.expr("lower(STR)"))
)

In [4]:
# Characters handed to the Tokenizer's setSplitChars, one single-character string each.
tokenizer_chars = list("',/ .|@#%&$[]()-;=")

In [5]:
# Raw concept string (STR) -> "document" annotation -> "token" annotations.
docAssembler = (
    DocumentAssembler()
    .setInputCol("STR")
    .setOutputCol("document")
)

tokenizer = (
    Tokenizer()
    .setInputCols("document")
    .setOutputCol("token")
    .setSplitChars(tokenizer_chars)
)

# Fit on the concepts so both stages become a reusable PipelineModel.
pipelineModel = Pipeline(stages=[docAssembler, tokenizer]).fit(concepts)

In [6]:
# N-grams over the tokens (cumulative mode enabled, parts joined with "_"),
# then converted back into token annotations in "ngram_token".
ngrammer = (
    NGramGenerator()
    .setInputCols(["token"])
    .setOutputCol("ngram")
    .setEnableCumulative(True)
    .setDelimiter("_")
)

ngramToken = (
    Chunk2Token()
    .setInputCols("ngram")
    .setOutputCol("ngram_token")
)

# The stages are wrapped directly in a PipelineModel (no fitting step).
pipelineNgrams = PipelineModel(stages=[ngrammer, ngramToken])

In [7]:
# Pretrained "embeddings_icdoem_2ng" word embeddings from clinical/models,
# applied to the document + token annotations.
embeddingsModel = (
    WordEmbeddingsModel
    .pretrained("embeddings_icdoem_2ng", "en", "clinical/models")
    .setInputCols("document", "token")
    .setOutputCol("embeddings")
)

embeddings_icdoem_2ng download started this may take some time.
Approximate size to download 10.9 GB
[OK!]


In [8]:
# Turn each whole document into a chunk, then build chunk-level embeddings
# from that chunk and the word embeddings.
doc2Chunk = (
    Doc2Chunk()
    .setInputCols("document")
    .setOutputCol("chunk")
)

chunkEmbeddings = (
    ChunkEmbeddings()
    .setInputCols("chunk", "embeddings")
    .setOutputCol("chunk_embeddings")
)

pipelineChunkEmbeddings = PipelineModel(stages=[doc2Chunk, chunkEmbeddings])

In [9]:
# Tokenization, n-grams, word embeddings and chunk embeddings over every concept.
concepts_embedded = PipelineModel(stages=[
    pipelineModel,
    pipelineNgrams,
    embeddingsModel,
    pipelineChunkEmbeddings,
]).transform(concepts)

In [10]:
# Persist the embedded concepts (replacing any previous copy) so they can be reloaded instead of recomputed.
concepts_embedded.write.save("data/rxnorm_concepts_embedded", mode="overwrite")

In [11]:
# Reload the persisted result (Spark's default data source format) so downstream cells start from storage.
concepts_embedded = spark.read.load(path="data/rxnorm_concepts_embedded")

In [12]:
# Embeddings coverage: a token counts as covered when its embedding metadata has isOOV == 'false'.
# Coverage is averaged per STY_TTY class and then averaged across classes (a macro average, so
# small classes weigh as much as large ones). The result's order is irrelevant to the mean,
# so no sort is done.
coverage_by_class = (
    concepts_embedded
    .selectExpr("STY_TTY", "explode(embeddings) as embs")
    .selectExpr("STY_TTY", "case when embs.metadata.isOOV=='false' then 1 else 0 end as coverage")
    .groupby("STY_TTY")
    .agg(F.expr("avg(coverage) as cov"))
)
coverage_by_class.toPandas()["cov"].mean()

0.4221108830768571

In [13]:
# Frequency of every token across all concepts, shown most frequent first.
word_distribution = (
    concepts_embedded
    .selectExpr("explode(token.result) as word")
    .groupby("word")
    .count()
)
word_distribution.orderBy(F.desc("count")).show(n=100, truncate=False)

+---------------+-----+
|word           |count|
+---------------+-----+
|0              |479  |
|mg             |336  |
|1              |248  |
|5              |204  |
|-              |180  |
|2              |175  |
|.              |162  |
|ml             |134  |
|oral           |133  |
|3              |113  |
|,              |86   |
|)              |69   |
|(              |69   |
|in             |63   |
|product        |62   |
|tablet         |61   |
|4              |60   |
|g              |57   |
|6              |56   |
|topical        |47   |
|7              |43   |
|8              |39   |
|acid           |32   |
|solution       |26   |
|9              |26   |
|capsule        |25   |
|mcg            |23   |
|release        |19   |
|powder         |18   |
|hydrochloride  |18   |
|vitamin        |16   |
|"              |16   |
|extended       |16   |
|gram           |15   |
|potassium      |14   |
|liquid         |14   |
|injectable     |14   |
|chloride       |13   |
|iu             

In [14]:
# Number of distinct tokens seen in the concept sample.
n_distinct_words = word_distribution.count()
n_distinct_words

958

In [15]:
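# TODO: the first layer is not yet available through the Pretrained framework (work in progress, targeted for version 2.4.2).
# NOTE(review): the commented-out model names below are SNOMED models (resolve_snomed_l1_*), not RxNorm; confirm before enabling.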
#model_idx = ResourceDownloader.downloadModel("StringIndexerModel", "resolve_snomed_l1_idx_icdoem_2ng", "en", "clinical/models")
#model_tfidf = ResourceDownloader.downloadPipeline("resolve_snomed_l1_tfidf_icdoem_2ng", "en", "clinical/models")
#model_ovrlrc = ResourceDownloader.downloadModel("OneVsRestModel", "resolve_snomed_l1_ovrlrc_icdoem_2ng", "en", "clinical/models")

In [16]:
# Locally saved label indexer; its `labels` become the first-layer classifier's label set.
sidx = StringIndexerModel.read().load("_models/rxnorm_indexer")

In [17]:
# First layer: document classifier over the n-gram tokens, writing its prediction to "partition".
# The vectorizer and classifier are read from local model paths; chunk merging is turned off on the fitted model.
layer1_approach = (
    DocumentLogRegClassifierApproach()
    .setInputCols("ngram_token")
    .setOutputCol("partition")
    .setLabels(sidx.labels)
    .setVectorizationModelPath("_models/rxnorm_tfidfer")
    .setClassificationModelPath("_models/rxnorm_ovrlrc")
)
layer1 = layer1_approach.fit(concepts_embedded).setMergeChunks(False)

In [19]:
# Second layer comes from the Pretrained framework: download the RxNorm level-2 pipeline
# and wrap it in a RecursivePipelineModel.
l2_pretrained_pipeline = ResourceDownloader.downloadPipeline("resolve_rxnorm_l2_icdoem_2ng", "en", "clinical/models")
layer_2 = RecursivePipelineModel(l2_pretrained_pipeline)

resolve_rxnorm_l2_icdoem_2ng download started this may take some time.
Approx size to download 514.4 MB
[OK!]


In [20]:
# Chain both layers. Each stage is already a model, so fit() is not expected to train anything new.
fullPipeline = Pipeline(stages=[layer1, layer_2]).fit(concepts_embedded)

In [21]:
# Start of a wall-clock window that a later cell reports. Spark transformations are lazy,
# so transform() only builds the plan; the real work happens in the actions that follow.
start = time.time()
transformed_full = fullPipeline.transform(dataset=concepts_embedded)

In [33]:
# Layer-1 metrics on the training data. The predicted partition is the first `partition`
# annotation; a row is correct ("ok") when it equals the gold STY_TTY.
predicted = transformed_full.withColumn("prediction", F.expr("partition.result[0]")).cache()
scored = predicted.withColumn("ok", F.expr("case when prediction==STY_TTY then 1 else 0 end"))

# Recall per gold class (grouped by STY_TTY); precision per predicted class (grouped by prediction).
recall_by_class = scored.groupby("STY_TTY").agg(F.expr("avg(ok) as recall"), F.expr("count(ok) as tr_cnt"))
precision_by_class = scored.groupby("prediction").agg(F.expr("avg(ok) as precision"))

# Left join: a gold class that was never predicted has no precision row. An inner join would
# silently drop it (hiding the worst classes), so its precision is reported as 0 instead.
# F1 is 0 when precision + recall == 0 rather than dividing by zero.
metrics = (
    recall_by_class
    .join(precision_by_class, F.col("STY_TTY") == F.col("prediction"), "left")
    .withColumn("precision", F.coalesce(F.col("precision"), F.lit(0.0)))
    .withColumn("f1", F.expr(
        "case when precision + recall > 0 then 2*precision*recall/(precision+recall) else 0.0 end"))
    .orderBy("f1")
    .selectExpr("STY_TTY", "tr_cnt",
                "round(precision,3) as train_precision",
                "round(recall,3) as train_recall",
                "round(f1, 3) as train_f1")
)

In [34]:
# Display up to 100 classes, worst F1 first, without truncating long class names.
metrics.show(n=100, truncate=False)

+------------------------------------------------------+------+---------------+------------+--------+
|STY_TTY                                               |tr_cnt|train_precision|train_recall|train_f1|
+------------------------------------------------------+------+---------------+------------+--------+
|Manufactured Object                                   |3     |1.0            |0.333       |0.5     |
|Indicator, Reagent, or Diagnostic Aid                 |2     |0.5            |0.5         |0.5     |
|Clinical Drug Semantic Drug Component                 |23    |0.447          |0.913       |0.6     |
|Biomedical or Dental Material                         |5     |0.5            |0.8         |0.615   |
|Organic Chemical                                      |8     |0.583          |0.875       |0.7     |
|Pharmacologic Substance                               |46    |0.673          |0.804       |0.733   |
|Clinical Drug Semantic Clinical Drug                  |39    |0.667          |0.9

In [25]:
# Parse the resolver's candidate list into an array column "resolution".
# Takes the all_k_results metadata string of the first rxnorm_code annotation, drops its first
# and last character (substring from position 2, length - 2), and splits the rest on the
# literal "],[" (regex-escaped inside the SQL string).
# NOTE(review): assumes all_k_results looks like "[c1],[c2],...,[cn]"; TODO confirm against the
# resolver's actual metadata format.
with_alternatives = predicted\
    .withColumn("resolution",F.expr("split(substring(rxnorm_code.metadata[0]['all_k_results'],2,length(rxnorm_code.metadata[0]['all_k_results'])-2),'\\\\],\\\\[')"))

In [28]:
# Resolution quality against the gold RXCUI:
#   good -> the top-1 resolved code (rxnorm_code.result[0]) equals RXCUI
#   hatK -> RXCUI is among the first K entries of `resolution`
# The K cutoffs are listed once instead of copy-pasting one withColumn per cutoff.
HAT_KS = (5, 10, 20, 30, 500)

evaled = with_alternatives.withColumn("good", F.expr("case when RXCUI=rxnorm_code.result[0] then 1 else 0 end"))
for k in HAT_KS:
    evaled = evaled.withColumn(
        "hat{}".format(k),
        F.expr("case when array_contains(slice(resolution, 1, {}), RXCUI) then 1 else 0 end".format(k)),
    )

In [29]:
# Per-class resolution accuracy (top-1 and hit@K, rounded to 2 decimals), largest classes first.
# Columns are aliased explicitly instead of relying on Spark's auto-generated names such as `avg(good)`.
eval_columns = ["good", "hat5", "hat10", "hat20", "hat30", "hat500"]
(
    evaled
    .groupby("STY_TTY")
    .agg(*[F.round(F.mean(c), 2).alias(c) for c in eval_columns],
         F.count("good").alias("total"))
    .orderBy("total", ascending=False)
    .show(100, False)
)

+------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|STY_TTY                                               |good|hat5|hat10|hat20|hat30|hat500|total|
+------------------------------------------------------+----+----+-----+-----+-----+------+-----+
|Clinical Drug Clinical Drug                           |0.55|0.85|0.93 |0.95 |0.95 |0.95  |86   |
|Pharmacologic Substance                               |0.74|0.78|0.78 |0.78 |0.8  |0.8   |46   |
|Clinical Drug Semantic Clinical Drug                  |0.28|0.67|0.85 |0.92 |0.92 |0.92  |39   |
|Clinical Drug Semantic branded drug group             |0.34|0.45|0.45 |0.48 |0.48 |0.48  |29   |
|Medical Device                                        |0.43|0.71|0.79 |0.86 |0.86 |0.86  |28   |
|Clinical Drug Clinical drug name in abbreviated format|0.44|0.78|0.78 |0.93 |0.93 |0.93  |27   |
|Clinical Drug Semantic Drug Component                 |0.17|0.57|0.61 |0.83 |0.91 |0.91  |23   |
|Clinical Drug Seman

In [30]:
# Wall-clock minutes since `start` was set. This covers every cell run (and any idle time)
# in between, not only the pipeline itself.
elapsed_minutes = (time.time() - start) / 60
print(round(elapsed_minutes, 2), "minutes")

2.64 minutes
