![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/14.Transformers_for_Token_Classification_in_Spark_NLP.ipynb)

# Transformers for Token Classification in Spark NLP

**BertForTokenClassification** loads BERT models with a token classification head on top (a linear layer over the hidden-states output), e.g. for Named Entity Recognition (NER) tasks.
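To make "a linear layer on top of the hidden-states output" concrete, here is a toy sketch in plain Python. The hidden states, weights, bias and three-label set are made-up numbers chosen for illustration, not values from any Spark NLP model. In a real model the hidden states come from the transformer encoder and the weights are learned during fine-tuning. The helper name `token_classification_head` is hypothetical.

```python
from typing import List, Sequence


def token_classification_head(
    hidden_states: Sequence[Sequence[float]],  # one vector per token: (seq_len, hidden)
    weights: Sequence[Sequence[float]],        # (hidden, num_labels)
    bias: Sequence[float],                     # (num_labels,)
    labels: Sequence[str],                     # label name for each output column
) -> List[str]:
    """Apply one linear layer to every token's hidden state and keep the top label."""
    predictions = []
    for h in hidden_states:
        # logits[j] = sum_k h[k] * W[k][j] + b[j]
        logits = [
            sum(h[k] * weights[k][j] for k in range(len(h))) + bias[j]
            for j in range(len(labels))
        ]
        best = max(range(len(logits)), key=lambda j: logits[j])
        predictions.append(labels[best])
    return predictions


# Toy, hand-picked numbers (hidden size 2, three labels); not from a real model.
labels = ["O", "B-PER", "I-PER"]
weights = [[2.0, 0.0, 0.0],
           [0.0, 2.0, 1.0]]
bias = [0.0, 0.0, 0.5]
hidden_states = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.2]]

print(token_classification_head(hidden_states, weights, bias, labels))
# ['O', 'B-PER', 'I-PER']
```

In this sketch every token is scored independently and the label with the highest logit wins, which is why the output has exactly one tag per token.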

Pretrained models can be loaded with the `pretrained()` method of the companion object. If no name is provided, the default model is `"bert_base_token_classifier_conll03"`. <br/><br/>


### **Here are the transformer-based token classification models available in Spark NLP**
<br/>

| Title                                                                                                                        | Name                                          | Language   |
|:-----------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------|:-----------|
| BERT Token Classification - NER CoNLL (bert_base_token_classifier_conll03)                                                   | bert_base_token_classifier_conll03            | en         |
| BERT Token Classification - NER OntoNotes (bert_base_token_classifier_ontonote)                                              | bert_base_token_classifier_ontonote           | en         |
| BERT Token Classification Large - NER CoNLL (bert_large_token_classifier_conll03)                                            | bert_large_token_classifier_conll03           | en         |
| BERT Token Classification Large - NER OntoNotes (bert_large_token_classifier_ontonote)                                       | bert_large_token_classifier_ontonote          | en         |
| BERT Token Classification - ParsBERT for Persian Language Understanding (bert_token_classifier_parsbert_armanner)            | bert_token_classifier_parsbert_armanner       | fa         |
| BERT Token Classification - ParsBERT for Persian Language Understanding (bert_token_classifier_parsbert_ner)                 | bert_token_classifier_parsbert_ner            | fa         |
| BERT Token Classification - ParsBERT for Persian Language Understanding (bert_token_classifier_parsbert_peymaner)            | bert_token_classifier_parsbert_peymaner       | fa         |
| BERT Token Classification - BETO Spanish Language Understanding (bert_token_classifier_spanish_ner)                          | bert_token_classifier_spanish_ner             | es         |
| BERT Token Classification - Swedish Language Understanding (bert_token_classifier_swedish_ner)                               | bert_token_classifier_swedish_ner             | sv         |
| BERT Token Classification - Turkish Language Understanding (bert_token_classifier_turkish_ner)                               | bert_token_classifier_turkish_ner             | tr         |
| DistilBERT Token Classification - NER CoNLL (distilbert_base_token_classifier_conll03)                                       | distilbert_base_token_classifier_conll03      | en         |
| DistilBERT Token Classification - NER OntoNotes (distilbert_base_token_classifier_ontonotes)                                 | distilbert_base_token_classifier_ontonotes    | en         |
| DistilBERT Token Classification - DistilbertNER for Persian Language Understanding (distilbert_token_classifier_persian_ner) | distilbert_token_classifier_persian_ner       | fa         |
| BERT Token Classification - Few-NERD (bert_base_token_classifier_few_nerd)                                                   | bert_base_token_classifier_few_nerd           | en         |
| DistilBERT Token Classification - Few-NERD (distilbert_base_token_classifier_few_nerd)                                       | distilbert_base_token_classifier_few_nerd     | en         |
| Named Entity Recognition for Japanese (BertForTokenClassification)                                                           | bert_token_classifier_ner_ud_gsd              | ja         |
| Detect PHI for Deidentification (BertForTokenClassifier)                                                                     | bert_token_classifier_ner_deid                | en         |
| Detect Clinical Entities (BertForTokenClassifier)                                                                            | bert_token_classifier_ner_jsl                 | en         |
| Detect Drug Chemicals (BertForTokenClassifier)                                                                               | bert_token_classifier_ner_drugs               | en         |
| Detect Clinical Entities (Slim version, BertForTokenClassifier)                                                              | bert_token_classifier_ner_jsl_slim            | en         |
| ALBERT Token Classification Base - NER CoNLL (albert_base_token_classifier_conll03)                                          | albert_base_token_classifier_conll03          | en         |
| ALBERT Token Classification Large - NER CoNLL (albert_large_token_classifier_conll03)                                        | albert_large_token_classifier_conll03         | en         |
| ALBERT Token Classification XLarge - NER CoNLL (albert_xlarge_token_classifier_conll03)                                      | albert_xlarge_token_classifier_conll03        | en         |
| DistilRoBERTa Token Classification - NER OntoNotes (distilroberta_base_token_classifier_ontonotes)                           | distilroberta_base_token_classifier_ontonotes | en         |
| RoBERTa Token Classification Base - NER CoNLL (roberta_base_token_classifier_conll03)                                        | roberta_base_token_classifier_conll03         | en         |
| RoBERTa Token Classification Base - NER OntoNotes (roberta_base_token_classifier_ontonotes)                                  | roberta_base_token_classifier_ontonotes       | en         |
| RoBERTa Token Classification Large - NER CoNLL (roberta_large_token_classifier_conll03)                                      | roberta_large_token_classifier_conll03        | en         |
| RoBERTa Token Classification Large - NER OntoNotes (roberta_large_token_classifier_ontonotes)                                | roberta_large_token_classifier_ontonotes      | en         |
| RoBERTa Token Classification For Persian (roberta_token_classifier_zwnj_base_ner)                                            | roberta_token_classifier_zwnj_base_ner        | fa         |
| XLM-RoBERTa Token Classification Base - NER XTREME (xlm_roberta_token_classifier_ner_40_lang)                                | xlm_roberta_token_classifier_ner_40_lang      | xx         |
| XLNet Token Classification Base - NER CoNLL (xlnet_base_token_classifier_conll03)                                            | xlnet_base_token_classifier_conll03           | en         |
| XLNet Token Classification Large - NER CoNLL (xlnet_large_token_classifier_conll03)                                          | xlnet_large_token_classifier_conll03          | en         |
| Detect Adverse Drug Events (BertForTokenClassification)                                                                      | bert_token_classifier_ner_ade                 | en         |
| Detect Anatomical Regions (BertForTokenClassification)                                                                       | bert_token_classifier_ner_anatomy             | en         |
| Detect Bacterial Species (BertForTokenClassification)                                                                        | bert_token_classifier_ner_bacteria            | en         |
| XLM-RoBERTa Token Classification Base - NER CoNLL (xlm_roberta_base_token_classifier_conll03)                                | xlm_roberta_base_token_classifier_conll03     | en         |
| XLM-RoBERTa Token Classification Base - NER OntoNotes (xlm_roberta_base_token_classifier_ontonotes)                          | xlm_roberta_base_token_classifier_ontonotes   | en         |
| Longformer Token Classification Base - NER CoNLL (longformer_base_token_classifier_conll03)                                  | longformer_base_token_classifier_conll03      | en         |
| Longformer Token Classification Large - NER CoNLL (longformer_large_token_classifier_conll03)                                | longformer_large_token_classifier_conll03     | en         |
| Detect Chemicals in Medical text (BertForTokenClassification)                                                                | bert_token_classifier_ner_chemicals           | en         |
| Detect Chemical Compounds and Genes (BertForTokenClassifier)                                                                 | bert_token_classifier_ner_chemprot            | en         |
| Detect Cancer Genetics (BertForTokenClassification)                                                                          | bert_token_classifier_ner_bionlp              | en         |
| Detect Cellular/Molecular Biology Entities (BertForTokenClassification)                                                      | bert_token_classifier_ner_cellular            | en         |
| Detect concepts in drug development trials (BertForTokenClassification)                                                      | bert_token_classifier_drug_development_trials | en         |

## Colab Setup

In [None]:
! pip install -q pyspark==3.2.0 spark-nlp

In [None]:
import sparknlp

spark = sparknlp.start(spark32=True)

from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import pandas as pd


print("Spark NLP version", sparknlp.version())
print("Apache Spark version:", spark.version)

spark

Spark NLP version 3.4.0
Apache Spark version: 3.2.0


## BertForTokenClassification Pipeline

Now, let's create a Spark NLP pipeline with the `bert_base_token_classifier_conll03` model and check the results. <br/>

In [None]:
document_assembler = DocumentAssembler() \
    .setInputCol('text') \
    .setOutputCol('document')

tokenizer = Tokenizer() \
    .setInputCols(['document']) \
    .setOutputCol('token')

tokenClassifier = BertForTokenClassification \
    .pretrained('bert_base_token_classifier_conll03', 'en') \
    .setInputCols(['token', 'document']) \
    .setOutputCol('ner') \
    .setCaseSensitive(True) \
    .setMaxSentenceLength(512)

# since output column is IOB/IOB2 style, NerConverter can extract entities
ner_converter = NerConverter() \
    .setInputCols(['document', 'token', 'ner']) \
    .setOutputCol('entities')

pipeline = Pipeline(stages=[
    document_assembler, 
    tokenizer,
    tokenClassifier,
    ner_converter
])

example = spark.createDataFrame([['My name is John Parker! I live in New York and I am a member of the New York Road Runners.']]).toDF("text")
model = pipeline.fit(example)
result = model.transform(example)

bert_base_token_classifier_conll03 download started this may take some time.
Approximate size to download 385.4 MB
[OK!]


In [None]:
model.stages

[DocumentAssembler_cdbcfd158e49,
 REGEX_TOKENIZER_83aa852cb2bf,
 BERT_FOR_TOKEN_CLASSIFICATION_675a6a750b89,
 NerConverter_e24449920ccc]

We can check the classes of the `bert_base_token_classifier_conll03` model with the `getClasses()` method:

In [None]:
tokenClassifier.getClasses()

['B-LOC', 'I-ORG', 'I-MISC', 'I-LOC', 'I-PER', 'B-MISC', 'B-ORG', 'O', 'B-PER']
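Each IOB tag is a `B-` (beginning) or `I-` (inside) prefix plus an entity type, with `O` for tokens outside any entity. A small, hypothetical helper `entity_types` can strip the prefixes to list the entity types a model predicts. The class list below is copied from the `getClasses()` output above.

```python
from typing import Iterable, List


def entity_types(classes: Iterable[str]) -> List[str]:
    """Strip the IOB prefix (B-/I-) from each tag and return the distinct entity types."""
    types = set()
    for tag in classes:
        if tag == "O":  # 'O' marks tokens outside any entity
            continue
        _, _, entity = tag.partition("-")  # 'B-MISC' -> 'MISC'
        types.add(entity)
    return sorted(types)


conll_classes = ['B-LOC', 'I-ORG', 'I-MISC', 'I-LOC', 'I-PER', 'B-MISC', 'B-ORG', 'O', 'B-PER']
print(entity_types(conll_classes))  # ['LOC', 'MISC', 'ORG', 'PER']
```

Because `partition` splits only on the first `-`, multi-word types such as `WORK_OF_ART` in the OntoNotes label set are kept intact.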

In [None]:
result.columns

['text', 'document', 'token', 'ner', 'entities']

In [None]:
result.printSchema()

root
 |-- text: string (nullable = true)
 |-- document: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- token: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 

Checking the NER labels of each token

In [None]:
# zip each token with the NER tag at the same position, then explode to one row per token
result_df = result.select(F.explode(F.arrays_zip(result.token.result, result.ner.result)).alias("cols"))\
                  .select(F.expr("cols['0']").alias("token"),
                          F.expr("cols['1']").alias("ner_label"))

result_df.show(50, truncate=100)

+-------+---------+
|  token|ner_label|
+-------+---------+
|     My|        O|
|   name|        O|
|     is|        O|
|   John|    B-PER|
| Parker|    I-PER|
|      !|        O|
|      I|        O|
|   live|        O|
|     in|        O|
|    New|    B-LOC|
|   York|    I-LOC|
|    and|        O|
|      I|        O|
|     am|        O|
|      a|        O|
| member|        O|
|     of|        O|
|    the|        O|
|    New|    B-ORG|
|   York|    I-ORG|
|   Road|    I-ORG|
|Runners|    I-ORG|
|      .|        O|
+-------+---------+
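The `arrays_zip` + `explode` pattern above turns parallel arrays into one row per position. A rough plain-Python analogue is `itertools.zip_longest`: like Spark's `arrays_zip`, it pairs elements by index and fills positions missing from a shorter array with a null (`None` here). The arrays below are a hand-copied excerpt of the token, NER and chunk results for the example sentence. They show why only arrays that line up token by token belong in the same zip: the chunk array is indexed by chunk, not by token.

```python
from itertools import zip_longest

# Excerpt: the first five tokens, their NER tags, and the first chunk.
tokens = ["My", "name", "is", "John", "Parker"]
ner_tags = ["O", "O", "O", "B-PER", "I-PER"]
chunks = ["John Parker"]

# Token-aligned arrays zip cleanly: one (token, tag) pair per token.
token_rows = list(zip_longest(tokens, ner_tags))

# Adding the chunk array misaligns it: the chunk sits next to "My",
# and every later position gets None (Spark would give null).
mixed_rows = list(zip_longest(tokens, ner_tags, chunks))

print(token_rows[3])  # ('John', 'B-PER')
print(mixed_rows[0])  # ('My', 'O', 'John Parker')
print(mixed_rows[3])  # ('John', 'B-PER', None)
```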



Inspecting the chunks

In [None]:
result_df_1 = result.select(F.explode(F.arrays_zip(result.entities.result, result.entities.begin, result.entities.end, result.entities.metadata)).alias("col"))\
                    .select(F.expr("col['0']").alias("entities"),
                            F.expr("col['1']").alias("begin"),
                            F.expr("col['2']").alias("end"),
                            F.expr("col['3']['entity']").alias("ner_label"))
result_df_1.show(50, truncate=False)

+---------------------+-----+---+---------+
|entities             |begin|end|ner_label|
+---------------------+-----+---+---------+
|John Parker          |11   |21 |PER      |
|New York             |34   |41 |LOC      |
|New York Road Runners|68   |88 |ORG      |
+---------------------+-----+---+---------+
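`NerConverter` merges consecutive IOB/IOB2 tags into chunks. The sketch below is a simplified, hypothetical re-implementation of that idea in plain Python, not the annotator's actual code: a `B-` tag opens a chunk, an `I-` tag of the same type extends it, and anything else closes it. The regex tokenizer is an assumption that happens to reproduce the 23 tokens of this example sentence, and the tags are copied from the token-level table above. Offsets use an inclusive `end`, as in the chunk table above.

```python
import re
from typing import List, Tuple


def iob_to_chunks(text: str, labels: List[str]) -> List[Tuple[str, int, int, str]]:
    """Group IOB2 token tags into (chunk, begin, end, entity) tuples with inclusive `end`."""
    # Simple tokenizer (assumption): runs of word characters or single punctuation marks.
    spans = [(m.start(), m.end() - 1) for m in re.finditer(r"\w+|[^\w\s]", text)]
    assert len(spans) == len(labels), "one label per token expected"

    chunks = []
    current = None  # [begin, end, entity] of the chunk being built
    for (begin, end), label in zip(spans, labels):
        tag, _, entity = label.partition("-")
        if tag == "I" and current is not None and current[2] == entity:
            current[1] = end  # same entity continues: extend the open chunk
            continue
        if current is not None:  # any other tag closes the open chunk
            chunks.append(current)
            current = None
        if tag in ("B", "I"):  # B- (or a stray I-) starts a new chunk
            current = [begin, end, entity]
    if current is not None:
        chunks.append(current)
    return [(text[b:e + 1], b, e, ent) for b, e, ent in chunks]


text = "My name is John Parker! I live in New York and I am a member of the New York Road Runners."
labels = ["O", "O", "O", "B-PER", "I-PER", "O", "O", "O", "O", "B-LOC", "I-LOC",
          "O", "O", "O", "O", "O", "O", "O", "B-ORG", "I-ORG", "I-ORG", "I-ORG", "O"]

for chunk in iob_to_chunks(text, labels):
    print(chunk)
# ('John Parker', 11, 21, 'PER')
# ('New York', 34, 41, 'LOC')
# ('New York Road Runners', 68, 88, 'ORG')
```

The printed tuples match the chunks, offsets and labels in the table above.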



## BertForTokenClassification Using LightPipeline

Now, we will use the `bert_large_token_classifier_ontonote` model with a LightPipeline and run `fullAnnotate` on sample text.

In [None]:
document_assembler = DocumentAssembler() \
    .setInputCol('text') \
    .setOutputCol('document')

tokenizer = Tokenizer() \
    .setInputCols(['document']) \
    .setOutputCol('token')

tokenClassifier = BertForTokenClassification \
    .pretrained('bert_large_token_classifier_ontonote', 'en') \
    .setInputCols(['token', 'document']) \
    .setOutputCol('ner') \
    .setCaseSensitive(True) \
    .setMaxSentenceLength(512)

# since output column is IOB/IOB2 style, NerConverter can extract entities
ner_converter = NerConverter() \
    .setInputCols(['document', 'token', 'ner']) \
    .setOutputCol('entities')

pipeline = Pipeline(stages=[
    document_assembler, 
    tokenizer,
    tokenClassifier,
    ner_converter
])

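# no stage learns from the data, so fitting on an empty DataFrame is enough to build the PipelineModel
empty_df = spark.createDataFrame([['']]).toDF("text")
model = pipeline.fit(empty_df)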

bert_large_token_classifier_ontonote download started this may take some time.
Approximate size to download 1.2 GB
[OK!]


In [None]:
light_model = LightPipeline(model)
light_result = light_model.fullAnnotate("Steven Rothery is the original guitarist and the longest continuous member of the British rock band Marillion.")[0]

In [None]:
light_result

{'document': [Annotation(document, 0, 109, Steven Rothery is the original guitarist and the longest continuous member of the British rock band Marillion., {})],
 'entities': [Annotation(chunk, 0, 13, Steven Rothery, {'entity': 'PERSON', 'sentence': '0', 'chunk': '0'}),
  Annotation(chunk, 82, 88, British, {'entity': 'NORP', 'sentence': '0', 'chunk': '1'}),
  Annotation(chunk, 100, 108, Marillion, {'entity': 'ORG', 'sentence': '0', 'chunk': '2'})],
 'ner': [Annotation(named_entity, 0, 5, B-PERSON, {'Some(I-CARDINAL)': '1.3950217E-6', 'Some(B-TIME)': '8.392712E-6', 'Some(I-LOC)': '2.4527526E-6', 'Some(B-GPE)': '1.4145754E-4', 'Some(B-WORK_OF_ART)': '1.2767659E-5', 'Some(I-GPE)': '7.0996966E-6', 'Some(B-LANGUAGE)': '3.469728E-6', 'Some(I-ORDINAL)': '7.2051887E-7', 'Some(B-FAC)': '3.322305E-6', 'Some(I-PRODUCT)': '3.2988514E-6', 'Some(B-NORP)': '7.93385E-6', 'Some(I-EVENT)': '2.5540264E-6', 'Some(B-ORG)': '2.3466631E-5', 'Some(O)': '0.0033311455', 'Some(B-ORDINAL)': '3.699E-6', 'Some(I-MON
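In the `fullAnnotate` output, each `ner` annotation's metadata maps every label to a score stored as a string, with keys that appear to be wrapped as `Some(<label>)`. A small, hypothetical helper `label_scores` can unwrap those keys and convert the values to floats; keys not in that form are skipped. The metadata dictionary below uses made-up placeholder scores, because the printed output above is cut off before the full score map.

```python
from typing import Dict


def label_scores(metadata: Dict[str, str]) -> Dict[str, float]:
    """Map 'Some(<label>)' keys to '<label>' and parse their string scores as floats."""
    prefix = "Some("
    scores = {}
    for key, value in metadata.items():
        if key.startswith(prefix) and key.endswith(")"):
            scores[key[len(prefix):-1]] = float(value)
    return scores  # keys not in 'Some(...)' form are ignored


# Placeholder scores for illustration only; not model output.
metadata = {"Some(B-PERSON)": "0.98", "Some(O)": "3.3E-3", "Some(B-GPE)": "1.2E-4"}

scores = label_scores(metadata)
top_label = max(scores, key=scores.get)
print(top_label, scores[top_label])  # B-PERSON 0.98
```

With real output, the same helper could be applied to `light_result["ner"][0].metadata`.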

Let's check the classes that the `bert_large_token_classifier_ontonote` model can predict:

In [None]:
tokenClassifier.getClasses()

['I-TIME',
 'B-PERSON',
 'B-GPE',
 'B-LAW',
 'B-NORP',
 'B-LOC',
 'I-ORG',
 'I-QUANTITY',
 'B-DATE',
 'B-PRODUCT',
 'B-FAC',
 'I-DATE',
 'I-WORK_OF_ART',
 'B-TIME',
 'B-QUANTITY',
 'I-PERCENT',
 'I-LAW',
 'I-GPE',
 'I-NORP',
 'I-ORDINAL',
 'I-EVENT',
 'I-LOC',
 'B-EVENT',
 'I-FAC',
 'B-ORDINAL',
 'B-LANGUAGE',
 'B-MONEY',
 'B-PERCENT',
 'I-LANGUAGE',
 'B-ORG',
 'I-MONEY',
 'I-PRODUCT',
 'O',
 'B-WORK_OF_ART',
 'I-CARDINAL',
 'I-PERSON',
 'B-CARDINAL']

In [None]:
light_result.keys()

dict_keys(['document', 'token', 'ner', 'entities'])

Checking the NER labels of each token

In [None]:
tokens = []
ner_labels = []

# pair each token annotation with the NER annotation at the same position
for token_ann, ner_ann in zip(light_result["token"], light_result["ner"]):
    tokens.append(token_ann.result)
    ner_labels.append(ner_ann.result)

result_df = pd.DataFrame({"tokens": tokens, "ner_labels": ner_labels})
result_df.head(20)

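|   | tokens     | ner_labels |
|--:|:-----------|:-----------|
| 0 | Steven     | B-PERSON   |
| 1 | Rothery    | I-PERSON   |
| 2 | is         | O          |
| 3 | the        | O          |
| 4 | original   | O          |
| 5 | guitarist  | O          |
| 6 | and        | O          |
| 7 | the        | O          |
| 8 | longest    | O          |
| 9 | continuous | O          |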


Let's check the chunk results

In [None]:
chunks = []
begin = []
end = []
ner_label = []

# one row per chunk produced by NerConverter
for chunk in light_result["entities"]:
    chunks.append(chunk.result)
    begin.append(chunk.begin)
    end.append(chunk.end)
    ner_label.append(chunk.metadata["entity"])

result_df = pd.DataFrame({"chunks": chunks, "begin": begin, "end": end, "ner_label": ner_label})
result_df.head(20)

|   | chunks         | begin | end | ner_label |
|--:|:---------------|------:|----:|:----------|
| 0 | Steven Rothery | 0     | 13  | PERSON    |
| 1 | British        | 82    | 88  | NORP      |
| 2 | Marillion      | 100   | 108 | ORG       |
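The `begin` and `end` values in these results are inclusive character offsets ("Steven Rothery" spans 0 to 13, i.e. 14 characters), so the matching Python slice is `text[begin:end + 1]`. The short check below uses the sample sentence and the (chunk, begin, end) values from the table above. It is a sanity check written for this notebook, not part of the Spark NLP API.

```python
text = ("Steven Rothery is the original guitarist and the longest continuous "
        "member of the British rock band Marillion.")

# (chunk, begin, end) values copied from the chunk table above.
entities = [("Steven Rothery", 0, 13), ("British", 82, 88), ("Marillion", 100, 108)]

for chunk, begin, end in entities:
    # `end` is inclusive, so the slice must stop at end + 1.
    assert text[begin:end + 1] == chunk, (chunk, text[begin:end + 1])

# The document annotation spans 0..109, i.e. up to the last character index.
print(len(text) - 1)  # 109
```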
