In [1]:
import os

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
import pyspark.sql.functions as F

import sparknlp
from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *
from sparknlp.training import CoNLL

In [2]:
# Create the Spark session through Spark NLP's start helper.
# `spark` is used by later cells to read the CoNLL files and is stopped at the end.
spark = sparknlp.start()
# Alternative kept for reference: request the GPU build instead.
#spark = sparknlp.start(gpu=True)

In [3]:
# Record the library versions this notebook was executed with, so the
# saved outputs can be tied to a concrete Spark NLP / Spark combination.
nlp_version = sparknlp.version()
spark_version = spark.version

print("Spark NLP version: ", nlp_version)
print("Apache Spark version: ", spark_version)

Spark NLP version:  2.4.3
Apache Spark version:  2.4.5


In [4]:
# Read the CoNLL training file into a Spark DataFrame.
# The directory defaults to the location used in the original run but can be
# overridden with the CONLL_DATA_DIR environment variable, so the notebook can
# run on another machine without editing this cell.
conll_dir = os.environ.get("CONLL_DATA_DIR", "/home/aminmoradi/sddm_project/CoNLL_parser")
training_data = CoNLL().readDataset(spark, os.path.join(conll_dir, "conll.train"))

In [5]:
# Preview the first 5 rows to check which columns the CoNLL reader produced.
training_data.show(5)

+----------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|            text|            document|            sentence|               token|                 pos|               label|
+----------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|admission date :|[[document, 0, 15...|[[document, 0, 15...|[[token, 0, 8, ad...|[[pos, 0, 8, NN, ...|[[named_entity, 0...|
|      2018-10-25|[[document, 0, 9,...|[[document, 0, 9,...|[[token, 0, 9, 20...|[[pos, 0, 9, NN, ...|[[named_entity, 0...|
|discharge date :|[[document, 0, 15...|[[document, 0, 15...|[[token, 0, 8, di...|[[pos, 0, 8, NN, ...|[[named_entity, 0...|
|      2018-10-31|[[document, 0, 9,...|[[document, 0, 9,...|[[token, 0, 9, 20...|[[pos, 0, 9, NN, ...|[[named_entity, 0...|
| date of birth :|[[document, 0, 14...|[[document, 0, 14...|[[token, 0, 3, da...|[[pos, 0, 3, NN, ...|[[named_entity, 0...|
+-------

In [6]:
# Number of rows in the training DataFrame.
training_data.count()

16311

In [7]:
# Read the CoNLL test file into a Spark DataFrame.
# The directory defaults to the location used in the original run but can be
# overridden with the CONLL_DATA_DIR environment variable, so the notebook can
# run on another machine without editing this cell.
conll_dir = os.environ.get("CONLL_DATA_DIR", "/home/aminmoradi/sddm_project/CoNLL_parser")
test_data = CoNLL().readDataset(spark, os.path.join(conll_dir, "conll.test"))

In [8]:
# Number of rows in the test DataFrame.
# NOTE(review): the recorded run shows more test rows (27568) than training
# rows (16311) — confirm the train/test files are not swapped.
test_data.count()

27568

In [9]:
# Load the pretrained GloVe word-embeddings model and wire it to the
# `sentence` / `token` columns; its output column `glove` is what the NER
# trainer consumes.
# The model name and language are pinned explicitly instead of relying on the
# library default: the recorded run downloaded glove_100d, and the model saved
# at the end of the notebook is named after that embedding size.
GloVe_embeddings = (
    WordEmbeddingsModel.pretrained("glove_100d", "en")
    .setInputCols(["sentence", "token"])
    .setOutputCol("glove")
    .setCaseSensitive(False)
)

glove_100d download started this may take some time.
Approximate size to download 145.3 MB
[OK!]


In [14]:
# Pre-compute GloVe embeddings for the test set and persist them as parquet;
# the NerDLApproach configuration points setTestDataset at this same path.
test_data_glove = GloVe_embeddings.transform(test_data)
# mode("overwrite") keeps this cell re-runnable: with the default save mode the
# write raises an error as soon as the output path exists from a previous run.
test_data_glove.write.mode("overwrite").parquet("test_withGloveEmbeds.parquet")

In [15]:
# Configure the NerDL trainer. Nothing is trained here; training happens when
# the pipeline containing this stage is fit.
nerTagger = (
    NerDLApproach()
    # Column wiring: reads sentence/token/glove, learns from `label`, writes `ner`.
    .setInputCols(["sentence", "token", "glove"])
    .setLabelColumn("label")
    .setOutputCol("ner")
    # Optimisation settings.
    .setMaxEpochs(10)
    .setLr(0.001)
    .setPo(0.005)  # presumably the learning-rate decay coefficient — TODO confirm against the Spark NLP docs
    .setBatchSize(8)
    .setRandomSeed(0)
    # Logging and evaluation settings.
    .setVerbose(2)
    .setValidationSplit(0.25)
    .setEvaluationLogExtended(True)
    .setEnableOutputLogs(True)
    .setIncludeConfidence(True)
    # Parquet written by the embeddings cell (test set with the `glove` column).
    .setTestDataset("test_withGloveEmbeds.parquet")
)


# Two-stage pipeline: the embeddings stage runs first so that the NER trainer
# finds the `glove` column it lists among its inputs.
pipeline_stages = [GloVe_embeddings, nerTagger]
pipeline = Pipeline(stages=pipeline_stages)

In [16]:
# Fit the pipeline (embeddings stage + NerDLApproach) on the training DataFrame;
# the result is the fitted pipeline model used for prediction and saving below.
ner_model = pipeline.fit(training_data)

In [17]:
# Run the fitted pipeline on the test DataFrame; per the recorded output this
# adds the `glove` and `ner` columns next to the original CoNLL columns.
predictions = ner_model.transform(test_data)
# Quick visual check of the first 3 rows.
predictions.show(3)

+----------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|            text|            document|            sentence|               token|                 pos|               label|               glove|                 ner|
+----------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|admission date :|[[document, 0, 15...|[[document, 0, 15...|[[token, 0, 8, ad...|[[pos, 0, 8, NN, ...|[[named_entity, 0...|[[word_embeddings...|[[named_entity, 0...|
|      2012-10-31|[[document, 0, 9,...|[[document, 0, 9,...|[[token, 0, 9, 20...|[[pos, 0, 9, NN, ...|[[named_entity, 0...|[[word_embeddings...|[[named_entity, 0...|
|discharge date :|[[document, 0, 15...|[[document, 0, 15...|[[token, 0, 8, di...|[[pos, 0, 8, NN, ...|[[named_entity, 0...|[[word_embeddings...|[[named_entity, 0...|
+---

In [18]:
# Build a token-level table comparing gold labels with predicted tags.
# Step 1: zip the three parallel result arrays and explode them so that each
# token becomes its own row, held in a struct column named `cols`.
zipped_tokens = predictions.select(
    F.explode(
        F.arrays_zip('token.result', 'label.result', 'ner.result')
    ).alias("cols")
)

# Step 2: pull the struct fields out by position ('0', '1', '2') and give them
# readable column names.
token_comparison = zipped_tokens.select(
    F.expr("cols['0']").alias("token"),
    F.expr("cols['1']").alias("ground_truth"),
    F.expr("cols['2']").alias("prediction"),
)

# Step 3: print the first 100 rows without truncating long tokens.
token_comparison.show(100, truncate=False)

+------------------------+------------+-----------+
|token                   |ground_truth|prediction |
+------------------------+------------+-----------+
|admission               |O           |O          |
|date                    |O           |O          |
|:                       |O           |O          |
|2012-10-31              |O           |O          |
|discharge               |O           |O          |
|date                    |O           |O          |
|:                       |O           |O          |
|2012-11-07              |O           |O          |
|date                    |O           |O          |
|of                      |O           |O          |
|birth                   |O           |O          |
|:                       |O           |O          |
|1941-03-23              |O           |O          |
|sex                     |O           |O          |
|:                       |O           |O          |
|m                       |O           |O          |
|service    

In [19]:
# List the fitted pipeline's stages; per the recorded output, index 0 is the
# word-embeddings model and index 1 is the trained NerDLModel.
ner_model.stages

[WORD_EMBEDDINGS_MODEL_48cffc8b9a76, NerDLModel_38add6240ea3]

In [20]:
# Persist only the trained NER stage (index 1 of the fitted pipeline), not the
# embeddings stage; an existing save at the target path is overwritten.
trained_ner_stage = ner_model.stages[1]
trained_ner_stage.write().overwrite().save('NER_DL_GloVe_100d')

In [21]:
# Shut down the Spark session at the end of the notebook.
spark.stop()