This notebook shows how to use the pretrained assertion status models.

In [1]:
import sys
sys.path.append('../../')

from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel

from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *
from sparknlp.pretrained import ResourceDownloader

from pathlib import Path

if sys.version_info[0] < 3:
    from urllib import urlretrieve
else:
    from urllib.request import urlretrieve

In [2]:
# Start (or reuse) a single-threaded local Spark session for this demo.
# The Spark property that lists extra jars is "spark.jars" (plural); the
# previously used "spark.jar" key is not a Spark setting, so the Spark NLP
# jar path given here would not have taken effect.
spark = (
    SparkSession.builder
    .appName("assertion-status")
    .master("local[1]")
    .config("spark.driver.memory", "4G")
    .config("spark.driver.maxResultSize", "2G")
    .config("spark.jars", "lib/sparknlp.jar")
    .getOrCreate()
)

Create some data for testing purposes

In [3]:
from pyspark.sql import Row
# Row factory fixing the column names of the test DataFrame.
R = Row('sentence', 'start', 'end')

# One (sentence, start, end) tuple per example.
# NOTE(review): start/end presumably mark the target span inside the
# sentence -- confirm against the assertion model's expected input.
_examples = [
    ('Sister with stomach cancer .', 2, 3),
    ('A thallium stress test showed tachycardia and severe dyspnea', 5, 5),
    ('Positive for shortness of breath, no cough', 2, 4),
    ('Positive for shortness of breath, no cough', 7, 7),
]

test_data = spark.createDataFrame([R(*fields) for fields in _examples])

Create one pipeline for each type of assertion classification algorithm. Downloading the models can take some time.

In [4]:
import time

# instantiate the downloader
downloader = ResourceDownloader()

# Reads the "sentence" column of the input DataFrame and writes "document".
documentAssembler = (
    DocumentAssembler()
    .setInputCol("sentence")
    .setOutputCol("document")
)

# download logistic regression based assertion status trained on negex dataset
assertion_fast_lg = (
    downloader.downloadModel(AssertionLogRegModel, "as_fast_lg", "en")
    .setInputCols(["document"])
    .setOutputCol("assertion")
)

# download bidirectional lstm based assertion status trained on negex dataset
assertion_fast_dl = (
    downloader.downloadModel(AssertionDLModel, "as_fast_dl", "en")
    .setInputCols(["document"])
    .setOutputCol("assertion")
)

# download bidirectional lstm based assertion status trained on i2b2 dataset
assertion_full_dl = (
    downloader.downloadModel(AssertionDLModel, "as_full_dl", "en")
    .setInputCols(["document"])
    .setOutputCol("assertion")
)

# Reads the "assertion" column produced by the models above.
finisher = (
    Finisher()
    .setInputCols(["assertion"])
    .setIncludeKeys(True)
)

# All stages are already-trained transformers (nothing needs fitting), so each
# pipeline is built directly as a PipelineModel. An unfitted Pipeline is an
# estimator and has no .transform(), which is how these objects are used.
pipeline_fast_lg = PipelineModel(stages=[documentAssembler, assertion_fast_lg, finisher])
pipeline_fast_dl = PipelineModel(stages=[documentAssembler, assertion_fast_dl, finisher])
pipeline_full_dl = PipelineModel(stages=[documentAssembler, assertion_full_dl, finisher])

Py4JJavaError: An error occurred while calling z:com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModel.
: java.net.SocketException: Connection reset
	at java.net.SocketInputStream.read(SocketInputStream.java:210)
	at java.net.SocketInputStream.read(SocketInputStream.java:141)
	at sun.security.ssl.InputRecord.readFully(InputRecord.java:465)
	at sun.security.ssl.InputRecord.readV3Record(InputRecord.java:593)
	at sun.security.ssl.InputRecord.read(InputRecord.java:532)
	at sun.security.ssl.SSLSocketImpl.readRecord(SSLSocketImpl.java:983)
	at sun.security.ssl.SSLSocketImpl.readDataRecord(SSLSocketImpl.java:940)
	at sun.security.ssl.AppInputStream.read(AppInputStream.java:105)
	at org.apache.http.impl.io.SessionInputBufferImpl.streamRead(SessionInputBufferImpl.java:139)
	at org.apache.http.impl.io.SessionInputBufferImpl.read(SessionInputBufferImpl.java:200)
	at org.apache.http.impl.io.ContentLengthInputStream.read(ContentLengthInputStream.java:178)
	at org.apache.http.conn.EofSensorInputStream.read(EofSensorInputStream.java:137)
	at com.amazonaws.internal.SdkFilterInputStream.read(SdkFilterInputStream.java:82)
	at com.amazonaws.event.ProgressInputStream.read(ProgressInputStream.java:180)
	at com.amazonaws.internal.SdkFilterInputStream.read(SdkFilterInputStream.java:82)
	at com.amazonaws.services.s3.internal.S3AbortableInputStream.read(S3AbortableInputStream.java:125)
	at com.amazonaws.internal.SdkFilterInputStream.read(SdkFilterInputStream.java:82)
	at com.amazonaws.internal.SdkFilterInputStream.read(SdkFilterInputStream.java:82)
	at com.amazonaws.internal.SdkFilterInputStream.read(SdkFilterInputStream.java:82)
	at com.amazonaws.event.ProgressInputStream.read(ProgressInputStream.java:180)
	at com.amazonaws.internal.SdkFilterInputStream.read(SdkFilterInputStream.java:82)
	at com.amazonaws.util.LengthCheckInputStream.read(LengthCheckInputStream.java:107)
	at com.amazonaws.internal.SdkFilterInputStream.read(SdkFilterInputStream.java:82)
	at java.io.FilterInputStream.read(FilterInputStream.java:107)
	at org.apache.commons.io.IOUtils.copyLarge(IOUtils.java:1792)
	at org.apache.commons.io.IOUtils.copyLarge(IOUtils.java:1769)
	at org.apache.commons.io.IOUtils.copy(IOUtils.java:1744)
	at org.apache.commons.io.FileUtils.copyInputStreamToFile(FileUtils.java:1512)
	at com.johnsnowlabs.nlp.pretrained.S3ResourceDownloader$$anonfun$download$1.apply(S3ResourceDownloader.scala:100)
	at com.johnsnowlabs.nlp.pretrained.S3ResourceDownloader$$anonfun$download$1.apply(S3ResourceDownloader.scala:86)
	at scala.Option.flatMap(Option.scala:171)
	at com.johnsnowlabs.nlp.pretrained.S3ResourceDownloader.download(S3ResourceDownloader.scala:85)
	at com.johnsnowlabs.nlp.pretrained.ResourceDownloader$.downloadResource(ResourceDownloader.scala:84)
	at com.johnsnowlabs.nlp.pretrained.ResourceDownloader$.downloadModel(ResourceDownloader.scala:100)
	at com.johnsnowlabs.nlp.pretrained.ResourceDownloader$.downloadModel(ResourceDownloader.scala:95)
	at com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader$.downloadModel(ResourceDownloader.scala:171)
	at com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModel(ResourceDownloader.scala)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:748)


Now let's use these pipelines and see the results

In [None]:
# Run the test sentences through each pipeline in turn and print the
# resulting table (logistic regression, fast DL, then full DL).
for assertion_pipeline in (pipeline_fast_lg, pipeline_fast_dl, pipeline_full_dl):
    assertion_pipeline.transform(test_data).show()