![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/dl-ner/NER_CoNLL2003_training_using_DeBertaEmbeddings.ipynb)

# NER Model Development with DeBertaEmbeddings Based on the CoNLL 2003 Dataset
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, and Weizhu Chen. It builds on Google’s BERT model, released in 2018, and Facebook’s RoBERTa model, released in 2019. Compared to RoBERTa-Large, a DeBERTa model trained on half of the training data performs consistently better on a wide range of NLP tasks, improving MNLI by +0.9% (90.2% vs. 91.1%), SQuAD v2.0 by +2.3% (88.4% vs. 90.7%), and RACE by +3.6% (83.2% vs. 86.8%).

In [None]:
! wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash

Installing PySpark 3.2.3 and Spark NLP 5.2.0
setup Colab for PySpark 3.2.3 and Spark NLP 5.2.0


In [None]:
import sparknlp
import pyspark.sql.functions as F
from sparknlp.annotator import *
from sparknlp.base import *
from sparknlp.pretrained import PretrainedPipeline
from pyspark.ml import Pipeline

# For GPU training, use sparknlp.start(gpu=True) instead
spark = sparknlp.start()

print("Spark NLP version", sparknlp.version())
print("Apache Spark version:", spark.version)

Spark NLP version 5.2.0
Apache Spark version: 3.2.3


In [None]:
# Download the CoNLL 2003 training set (eng.train) and development set (eng.testa)
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/src/test/resources/conll2003/eng.train
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/src/test/resources/conll2003/eng.testa

In [None]:
from sparknlp.training import CoNLL

training_data = CoNLL().readDataset(spark, './eng.train')
testing_data = CoNLL().readDataset(spark, './eng.testa')
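
For readers unfamiliar with the file layout that `CoNLL().readDataset` consumes, the sketch below parses a few lines by hand in plain Python. It assumes the usual four-column CoNLL 2003 layout (token, POS tag, chunk tag, NER tag), with blank lines between sentences and `-DOCSTART-` lines between documents. The `SAMPLE` string and the `parse_conll` helper are hypothetical illustrations, not part of Spark NLP, and the chunk-column values in the sample are only illustrative.

```python
from typing import List, Tuple

# Hypothetical sample in the four-column layout: token POS chunk NER.
# Chunk-column values here are illustrative only.
SAMPLE = """-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O

Peter NNP B-NP B-PER
Blackburn NNP I-NP I-PER
"""

Row = Tuple[str, str, str]  # (token, pos, ner_label)


def parse_conll(text: str) -> List[List[Row]]:
    """Split CoNLL-style text into sentences of (token, pos, ner_label) rows."""
    sentences: List[List[Row]] = []
    current: List[Row] = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("-DOCSTART-"):
            # A blank line or a document marker closes the sentence in progress.
            if current:
                sentences.append(current)
                current = []
            continue
        token, pos, _chunk, ner = line.split()
        current.append((token, pos, ner))
    if current:
        sentences.append(current)
    return sentences


sentences = parse_conll(SAMPLE)
print(len(sentences), "sentences")
print(sentences[1])
```

Spark NLP's reader additionally wraps each column in annotation structs (`document`, `sentence`, `token`, `pos`, `label`), which is why the cells below select `.result` fields.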

In [None]:
print(f"(Train count: {training_data.count()} Test count: {testing_data.count()})")

(Train count: 14041 Test count: 3250)


In [None]:
training_data.select(
    F.explode(F.arrays_zip('token', 'pos', 'label')).alias("cols")
).select(
    F.col("cols.token.result").alias("token"),
    F.col("cols.pos.result").alias("pos"),
    F.col("cols.label.result").alias("ner_label")
).show(truncate=50)

+----------+---+---------+
|     token|pos|ner_label|
+----------+---+---------+
|        EU|NNP|    B-ORG|
|   rejects|VBZ|        O|
|    German| JJ|   B-MISC|
|      call| NN|        O|
|        to| TO|        O|
|   boycott| VB|        O|
|   British| JJ|   B-MISC|
|      lamb| NN|        O|
|         .|  .|        O|
|     Peter|NNP|    B-PER|
| Blackburn|NNP|    I-PER|
|  BRUSSELS|NNP|    B-LOC|
|1996-08-22| CD|        O|
|       The| DT|        O|
|  European|NNP|    B-ORG|
|Commission|NNP|    I-ORG|
|      said|VBD|        O|
|        on| IN|        O|
|  Thursday|NNP|        O|
|        it|PRP|        O|
+----------+---+---------+
only showing top 20 rows



## 1. Create Spark NLP train pipeline

In [None]:
embeddings = DeBertaEmbeddings.pretrained("deberta_v3_base", "en") \
      .setInputCols("document", "token") \
      .setOutputCol("embeddings")

nerTagger = NerDLApproach()\
      .setInputCols(["sentence", "token", "embeddings"])\
      .setLabelColumn("label")\
      .setOutputCol("ner")\
      .setMaxEpochs(2)\
      .setLr(0.002)\
      .setBatchSize(16)\
      .setRandomSeed(0)\
      .setVerbose(1)\
      .setValidationSplit(0.15)

ner_converter = NerConverter() \
    .setInputCols(['document', 'token', 'ner']) \
    .setOutputCol('ner_chunk')

ner_pipeline = Pipeline(stages=[
      embeddings,
      nerTagger,
      ner_converter
])

deberta_v3_base download started this may take some time.
Approximate size to download 415 MB
[OK!]
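
The last pipeline stage, `NerConverter`, turns per-token IOB tags into entity chunks. As a rough illustration of that idea (not Spark NLP's actual implementation), the plain-Python sketch below merges `B-`/`I-` tagged tokens into `(text, type)` pairs. The `iob_to_chunks` helper is hypothetical, and its handling of a stray `I-` tag (treating it as the start of a new chunk) is an assumption made for this example.

```python
from typing import List, Tuple


def iob_to_chunks(tokens: List[str], tags: List[str]) -> List[Tuple[str, str]]:
    """Merge IOB2-tagged tokens into (chunk_text, entity_type) pairs."""
    chunks: List[Tuple[str, str]] = []
    words: List[str] = []
    etype = ""
    for token, tag in zip(tokens, tags):
        prefix, _, label = tag.partition("-")
        if prefix == "I" and words and label == etype:
            # Continuation of the chunk that is currently open.
            words.append(token)
            continue
        if words:
            # Any other tag closes the open chunk.
            chunks.append((" ".join(words), etype))
            words, etype = [], ""
        if prefix in ("B", "I"):
            # Assumption: a stray I- tag also opens a new chunk.
            words, etype = [token], label
    if words:
        chunks.append((" ".join(words), etype))
    return chunks


tokens = ["West", "Indian", "all-rounder", "Phil", "Simmons", "took", "four"]
tags = ["B-MISC", "I-MISC", "O", "B-PER", "I-PER", "O", "O"]
print(iob_to_chunks(tokens, tags))
```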


## 2. Train model

In [None]:
%%time
ner_model = ner_pipeline.fit(training_data.limit(5000).repartition(1))

CPU times: user 8.48 s, sys: 1.18 s, total: 9.66 s
Wall time: 37min 13s


In [None]:
predictions = ner_model.transform(testing_data.limit(1000))

In [None]:
preds_df = predictions.select(
    F.explode(F.arrays_zip('token', 'label', 'ner')).alias("cols")
).select(
    F.col("cols.token.result").alias("token"),
    F.col("cols.label.result").alias("ground_truth"),
    F.col("cols.ner.result").alias("prediction")
)

preds_df.show(truncate=50)

+--------------+------------+----------+
|         token|ground_truth|prediction|
+--------------+------------+----------+
|       CRICKET|           O|         O|
|             -|           O|         O|
|LEICESTERSHIRE|       B-ORG|     B-LOC|
|          TAKE|           O|         O|
|          OVER|           O|         O|
|            AT|           O|         O|
|           TOP|           O|         O|
|         AFTER|           O|         O|
|       INNINGS|           O|     B-LOC|
|       VICTORY|           O|         O|
|             .|           O|         O|
|        LONDON|       B-LOC|     B-LOC|
|    1996-08-30|           O|         O|
|          West|      B-MISC|    B-MISC|
|        Indian|      I-MISC|    I-MISC|
|   all-rounder|           O|         O|
|          Phil|       B-PER|     B-PER|
|       Simmons|       I-PER|     I-PER|
|          took|           O|         O|
|          four|           O|         O|
+--------------+------------+----------+
only showing top 20 rows
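
A quick way to see where the model goes wrong is to keep only the rows where the prediction differs from the ground truth. The plain-Python sketch below does this for the first twelve rows of the table above, copied by hand. On the full data the same filter could be applied to `preds_df` directly, but that variant is not shown here.

```python
from collections import Counter

# (token, ground_truth, prediction) rows copied from the table above.
rows = [
    ("CRICKET", "O", "O"),
    ("-", "O", "O"),
    ("LEICESTERSHIRE", "B-ORG", "B-LOC"),
    ("TAKE", "O", "O"),
    ("OVER", "O", "O"),
    ("AT", "O", "O"),
    ("TOP", "O", "O"),
    ("AFTER", "O", "O"),
    ("INNINGS", "O", "B-LOC"),
    ("VICTORY", "O", "O"),
    (".", "O", "O"),
    ("LONDON", "B-LOC", "B-LOC"),
]

# Keep only the mismatches, then count each (ground_truth, prediction) pair.
errors = [(tok, gold, pred) for tok, gold, pred in rows if gold != pred]
confusions = Counter((gold, pred) for _, gold, pred in errors)

for tok, gold, pred in errors:
    print(f"{tok}: expected {gold}, predicted {pred}")
print(confusions.most_common())
```

Both mistakes in this sample involve upper-case headline tokens tagged as locations.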

## 3. Benchmark

In [None]:
from sklearn.metrics import classification_report

preds_df_pd = preds_df.toPandas()
print(classification_report(preds_df_pd['ground_truth'], preds_df_pd['prediction']))

              precision    recall  f1-score   support

       B-LOC       0.78      0.96      0.86       559
      B-MISC       0.79      0.66      0.72       190
       B-ORG       0.81      0.65      0.72       355
       B-PER       0.97      0.98      0.97       654
       I-LOC       0.74      0.70      0.72        69
      I-MISC       0.77      0.44      0.56        93
       I-ORG       0.66      0.82      0.73       181
       I-PER       0.97      0.98      0.97       443
           O       1.00      0.99      1.00     11589

    accuracy                           0.97     14133
   macro avg       0.83      0.80      0.81     14133
weighted avg       0.97      0.97      0.97     14133
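
The report above is token-level, and the `O` class accounts for 11589 of the 14133 tokens, so the 0.97 accuracy is dominated by non-entity tokens. CoNLL-style NER evaluation is conventionally entity-level: a predicted entity counts only if its span and type both match exactly. The sketch below is a minimal plain-Python version of that idea, run on the 20 tag pairs shown in the predictions table. The `extract_spans` and `entity_prf` helpers are hypothetical, and treating the rows as one flat sequence (ignoring sentence boundaries) is a simplification.

```python
from typing import List, Set, Tuple

Span = Tuple[int, int, str]  # (start, end_exclusive, entity_type)


def extract_spans(tags: List[str]) -> Set[Span]:
    """Collect entity spans from a sequence of IOB2 tags."""
    spans: Set[Span] = set()
    start, etype = None, None
    for i, tag in enumerate(tags + ["O"]):  # trailing "O" flushes the last span
        prefix, _, label = tag.partition("-")
        if prefix == "I" and start is not None and label == etype:
            continue  # still inside the open span
        if start is not None:
            spans.add((start, i, etype))
            start, etype = None, None
        if prefix in ("B", "I"):
            start, etype = i, label
    return spans


def entity_prf(gold: List[str], pred: List[str]) -> Tuple[float, float, float]:
    """Exact-match entity-level precision, recall and F1."""
    g, p = extract_spans(gold), extract_spans(pred)
    tp = len(g & p)
    precision = tp / len(p) if p else 0.0
    recall = tp / len(g) if g else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return precision, recall, f1


# Tags for the 20 rows shown in the predictions table.
gold = ["O", "O", "B-ORG", "O", "O", "O", "O", "O", "O", "O", "O",
        "B-LOC", "O", "B-MISC", "I-MISC", "O", "B-PER", "I-PER", "O", "O"]
pred = ["O", "O", "B-LOC", "O", "O", "O", "O", "O", "B-LOC", "O", "O",
        "B-LOC", "O", "B-MISC", "I-MISC", "O", "B-PER", "I-PER", "O", "O"]

precision, recall, f1 = entity_prf(gold, pred)
print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
```

On this sample, 18 of 20 tokens are correct, yet only 3 of the 5 predicted entities match a gold entity, which illustrates how much stricter entity-level scoring is.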

