This notebook shows how to use the pretrained Spark NLP assertion status models.

In [1]:
import sys
sys.path.append('../../')

from pyspark.sql import SparkSession
from pyspark.ml import Estimator, PipelineModel

from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *
from sparknlp.pretrained import ResourceDownloader

from pathlib import Path

if sys.version_info[0] < 3:
    from urllib import urlretrieve
else:
    from urllib.request import urlretrieve

In [2]:
# Local, single-threaded Spark session for this demo.
# `spark.jars` (plural) is the Spark property that puts extra jars on the
# driver/executor classpath; a key like `spark.jar` is not a Spark property
# and would be accepted but ignored. Like `spark.driver.memory`, it only
# takes effect when this call starts a new JVM: if a session already exists,
# `getOrCreate()` returns it with its original classpath and memory settings.
spark = (
    SparkSession.builder
    .appName("assertion-status")
    .master("local[1]")
    .config("spark.driver.memory", "4G")
    .config("spark.driver.maxResultSize", "2G")
    .config("spark.jars", "lib/sparknlp.jar")
    .getOrCreate()
)

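Create a small DataFrame for testing. Each row contains a sentence together with the `start` and `end` positions of the target phrase whose assertion status we want to classify.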

In [3]:
from pyspark.sql import Row

# Row factory with a fixed field order. Each row is a sentence plus the
# `start`/`end` positions of the phrase to classify.
# NOTE(review): presumably token indices of the target span — confirm the
# indexing convention the assertion models expect.
TestRow = Row('sentence', 'start', 'end')

test_rows = [
    TestRow('Sister with stomach cancer .', 2, 3),
    TestRow('A thallium stress test showed tachycardia and severe dyspnea', 5, 5),
    TestRow('Positive for shortness of breath, no cough', 2, 4),
    TestRow('Positive for shortness of breath, no cough', 7, 7),
]

test_data = spark.createDataFrame(test_rows)

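Create three pipelines, one per pretrained assertion model: a logistic-regression model and a bidirectional-LSTM model, both trained on the NegEx dataset, and a bidirectional-LSTM model trained on the i2b2 dataset. Note that downloading the models can take some time.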

In [4]:
import time

# Downloader used to fetch the pretrained models.
downloader = ResourceDownloader()

# PubMed word embeddings.
# NOTE(review): nothing in this cell reads this file directly; presumably the
# pretrained assertion models expect it in the working directory — confirm
# before removing.
embeddingsFile = 'PubMed-shuffle-win-2.bin'
embeddingsUrl = 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/PubMed-shuffle-win-2.bin'
# This may take a couple of minutes. Download to a temporary name and rename
# only on success, so an interrupted download never leaves a truncated file
# that the `is_file()` check would later mistake for a complete one.
if not Path(embeddingsFile).is_file():
    partialFile = embeddingsFile + '.part'
    urlretrieve(embeddingsUrl, partialFile)
    Path(partialFile).rename(embeddingsFile)

documentAssembler = (
    DocumentAssembler()
    .setInputCol("sentence")
    .setOutputCol("document")
)


def load_assertion_model(model_class, name, lang="en"):
    """Download a pretrained assertion model and wire up its columns.

    Parameters
    ----------
    model_class : Spark NLP model class to load (e.g. ``AssertionLogRegModel``).
    name : name of the pretrained model to download.
    lang : language code of the model (default ``"en"``).

    Returns
    -------
    The downloaded model, reading the ``"document"`` column and writing the
    ``"assertion"`` column.
    """
    return (
        downloader.downloadModel(model_class, name, lang)
        .setInputCols(["document"])
        .setOutputCol("assertion")
    )


# NOTE: if a download fails with `java.io.InvalidClassException ... local class
# incompatible: stream classdesc serialVersionUID = ...`, the model was most
# likely serialized with a different Spark version than the one running here.
# Align the local Spark version with the one the pretrained models target.

# Logistic-regression based assertion status, trained on the NegEx dataset.
assertion_fast_lg = load_assertion_model(AssertionLogRegModel, "as_fast_lg")
# Bidirectional-LSTM based assertion status, trained on the NegEx dataset.
assertion_fast_dl = load_assertion_model(AssertionDLModel, "as_fast_dl")
# Bidirectional-LSTM based assertion status, trained on the i2b2 dataset.
assertion_full_dl = load_assertion_model(AssertionDLModel, "as_full_dl")

finisher = (
    Finisher()
    .setInputCols(["assertion"])
    .setIncludeKeys(True)
)

# Every stage is already a Transformer (document assembler, pretrained models,
# finisher), so the stages are assembled directly into PipelineModels, which
# support `.transform()` with no `.fit()` step. A plain `Pipeline` is an
# Estimator and has no `transform` method.
pipeline_fast_lg = PipelineModel(stages=[documentAssembler, assertion_fast_lg, finisher])
pipeline_fast_dl = PipelineModel(stages=[documentAssembler, assertion_fast_dl, finisher])
pipeline_full_dl = PipelineModel(stages=[documentAssembler, assertion_full_dl, finisher])

Py4JJavaError: An error occurred while calling z:com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModel.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 1, localhost, executor driver): java.io.InvalidClassException: org.apache.spark.ml.PipelineStage; local class incompatible: stream classdesc serialVersionUID = 3275105016155696140, local class serialVersionUID = 8307843718209149723
	at java.io.ObjectStreamClass.initNonProxy(ObjectStreamClass.java:687)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2033)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1567)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:1966)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1561)
	at java.io.ObjectInputStream.readObject(ObjectInputStream.java:427)
	at org.apache.spark.util.Utils$.deserialize(Utils.scala:160)
	at org.apache.spark.SparkContext$$anonfun$objectFile$1$$anonfun$apply$13.apply(SparkContext.scala:1232)
	at org.apache.spark.SparkContext$$anonfun$objectFile$1$$anonfun$apply$13.apply(SparkContext.scala:1232)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:389)
	at scala.collection.Iterator$class.foreach(Iterator.scala:893)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at scala.collection.AbstractIterator.to(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1336)
	at org.apache.spark.rdd.RDD$$anonfun$take$1$$anonfun$29.apply(RDD.scala:1354)
	at org.apache.spark.rdd.RDD$$anonfun$take$1$$anonfun$29.apply(RDD.scala:1354)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1954)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1954)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:325)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1435)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1423)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1422)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1422)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:802)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1650)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1605)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1594)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:628)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1928)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1941)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1954)
	at org.apache.spark.rdd.RDD$$anonfun$take$1.apply(RDD.scala:1354)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
	at org.apache.spark.rdd.RDD.take(RDD.scala:1327)
	at org.apache.spark.rdd.RDD$$anonfun$first$1.apply(RDD.scala:1368)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:362)
	at org.apache.spark.rdd.RDD.first(RDD.scala:1367)
	at com.johnsnowlabs.nlp.serialization.StructFeature.deserializeObject(Feature.scala:113)
	at com.johnsnowlabs.nlp.serialization.Feature.deserialize(Feature.scala:44)
	at com.johnsnowlabs.nlp.FeaturesReader$$anonfun$load$1.apply(ParamsAndFeaturesReadable.scala:15)
	at com.johnsnowlabs.nlp.FeaturesReader$$anonfun$load$1.apply(ParamsAndFeaturesReadable.scala:14)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at com.johnsnowlabs.nlp.FeaturesReader.load(ParamsAndFeaturesReadable.scala:14)
	at com.johnsnowlabs.nlp.FeaturesReader.load(ParamsAndFeaturesReadable.scala:8)
	at com.johnsnowlabs.nlp.pretrained.ResourceDownloader$.downloadModel(ResourceDownloader.scala:101)
	at com.johnsnowlabs.nlp.pretrained.ResourceDownloader$.downloadModel(ResourceDownloader.scala:95)
	at com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader$.downloadModel(ResourceDownloader.scala:171)
	at com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModel(ResourceDownloader.scala)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.io.InvalidClassException: org.apache.spark.ml.PipelineStage; local class incompatible: stream classdesc serialVersionUID = 3275105016155696140, local class serialVersionUID = 8307843718209149723
	at java.io.ObjectStreamClass.initNonProxy(ObjectStreamClass.java:687)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1876)
	at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1745)
	at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2033)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1567)
	at java.io.ObjectInputStream.readArray(ObjectInputStream.java:1966)
	at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1561)
	at java.io.ObjectInputStream.readObject(ObjectInputStream.java:427)
	at org.apache.spark.util.Utils$.deserialize(Utils.scala:160)
	at org.apache.spark.SparkContext$$anonfun$objectFile$1$$anonfun$apply$13.apply(SparkContext.scala:1232)
	at org.apache.spark.SparkContext$$anonfun$objectFile$1$$anonfun$apply$13.apply(SparkContext.scala:1232)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:389)
	at scala.collection.Iterator$class.foreach(Iterator.scala:893)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at scala.collection.AbstractIterator.to(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1336)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1336)
	at org.apache.spark.rdd.RDD$$anonfun$take$1$$anonfun$29.apply(RDD.scala:1354)
	at org.apache.spark.rdd.RDD$$anonfun$take$1$$anonfun$29.apply(RDD.scala:1354)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1954)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1954)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:325)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


Now let's run each pipeline on the test data and look at the results.

In [None]:
# A `Pipeline` is an Estimator with no `transform` method, so it must be fit
# first (fitting a pipeline made only of Transformers simply wraps them in a
# PipelineModel). A `PipelineModel` is already a Transformer and is applied as-is.
pipelines = [
    ("as_fast_lg", pipeline_fast_lg),
    ("as_fast_dl", pipeline_fast_dl),
    ("as_full_dl", pipeline_full_dl),
]
for name, pipeline in pipelines:
    model = pipeline.fit(test_data) if isinstance(pipeline, Estimator) else pipeline
    print(name)
    # truncate=False so the finished assertion annotations are shown in full.
    model.transform(test_data).show(truncate=False)