![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)



[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/5.Text_Classification_with_ClassifierDL.ipynb)

# 5 Text Classification with ClassifierDL

**Relevant blogpost:** https://towardsdatascience.com/text-classification-in-spark-nlp-with-bert-and-universal-sentence-encoders-e644d618ca32

## Colab Setup

In [1]:
!wget https://setup.johnsnowlabs.com/colab.sh -O - | bash /dev/stdin -p 3.2.1 -s 4.0.1 -g

--2022-08-24 17:48:45--  https://setup.johnsnowlabs.com/colab.sh
Resolving setup.johnsnowlabs.com (setup.johnsnowlabs.com)... 51.158.130.125
Connecting to setup.johnsnowlabs.com (setup.johnsnowlabs.com)|51.158.130.125|:443... connected.
HTTP request sent, awaiting response... 302 Moved Temporarily
Location: https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh [following]
--2022-08-24 17:48:46--  https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp/master/scripts/colab_setup.sh
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1191 (1.2K) [text/plain]
Saving to: ‘STDOUT’


2022-08-24 17:48:46 (67.4 MB/s) - written to stdout [1191/1191]

Installing PySpark 3.2.1 and Spark NLP 4.0.1
setup Colab for PySpark 3.2.1 and Spark NLP

In [1]:
import sparknlp

from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
import pandas as pd
import os

spark = sparknlp.start(gpu = True)  # starts a GPU session; use sparknlp.start() to run on CPU

print("Spark NLP version", sparknlp.version())
print("Apache Spark version:", spark.version)

spark

Spark NLP version 4.0.1
Apache Spark version: 3.2.1


<b><h1><font color='darkred'>!!! WARNING !!! </font></h1></b>



**If you get an error related to "Java port not found 55", it is probably because the Colab memory cannot handle the model and the Spark session died. In that case, try a larger machine, or restart the kernel, then come back here and rerun.**

## ClassifierDL with Word Embeddings and Text Preprocessing

### Load Dataset

In [3]:
! wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_train.csv
! wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_test.csv

In [4]:
!ls -lt

total 24948
-rw-r--r-- 1 root root  1504408 Aug 24 17:51 news_category_test.csv
-rw-r--r-- 1 root root 24032125 Aug 24 17:51 news_category_train.csv
drwxr-xr-x 1 root root     4096 Aug 15 13:44 sample_data


In [5]:
trainDataset = spark.read \
      .option("header", True) \
      .csv("news_category_train.csv")

trainDataset.show(truncate=50)

+--------+--------------------------------------------------+
|category|                                       description|
+--------+--------------------------------------------------+
|Business| Short sellers, Wall Street's dwindling band of...|
|Business| Private investment firm Carlyle Group, which h...|
|Business| Soaring crude prices plus worries about the ec...|
|Business| Authorities have halted oil export flows from ...|
|Business| Tearaway world oil prices, toppling records an...|
|Business| Stocks ended slightly higher on Friday but sta...|
|Business| Assets of the nation's retail money market mut...|
|Business| Retail sales bounced back a bit in July, and n...|
|Business|" After earning a PH.D. in Sociology, Danny Baz...|
|Business| Short sellers, Wall Street's dwindling  band o...|
|Business| Soaring crude prices plus worries  about the e...|
|Business| OPEC can do nothing to douse scorching  oil pr...|
|Business| Non OPEC oil exporters should consider  increa...|
|Busines

In [6]:
trainDataset.count()

120000

In [7]:
from pyspark.sql.functions import col

trainDataset.groupBy("category") \
    .count() \
    .orderBy(col("count").desc()) \
    .show()

+--------+-----+
|category|count|
+--------+-----+
|   World|30000|
|Sci/Tech|30000|
|  Sports|30000|
|Business|30000|
+--------+-----+



In [8]:
from pyspark.sql.functions import col

testDataset = spark.read \
      .option("header", True) \
      .csv("news_category_test.csv")


testDataset.groupBy("category") \
      .count() \
      .orderBy(col("count").desc()) \
      .show()

+--------+-----+
|category|count|
+--------+-----+
|   World| 1900|
|Sci/Tech| 1900|
|  Sports| 1900|
|Business| 1900|
+--------+-----+



In [9]:
# if we want to split the dataset
'''
(trainingData, testData) = trainDataset.randomSplit([0.7, 0.3], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))
'''

'\n(trainingData, testData) = trainDataset.randomSplit([0.7, 0.3], seed = 100)\nprint("Training Dataset Count: " + str(trainingData.count()))\nprint("Test Dataset Count: " + str(testData.count()))\n'

In [10]:
document_assembler = DocumentAssembler() \
    .setInputCol("description") \
    .setOutputCol("document")
    
tokenizer = Tokenizer() \
    .setInputCols(["document"]) \
    .setOutputCol("token")
      
normalizer = Normalizer() \
    .setInputCols(["token"]) \
    .setOutputCol("normalized")

stopwords_cleaner = StopWordsCleaner()\
    .setInputCols("normalized")\
    .setOutputCol("cleanTokens")\
    .setCaseSensitive(False)

lemma = LemmatizerModel.pretrained('lemma_antbnc') \
    .setInputCols(["cleanTokens"]) \
    .setOutputCol("lemma")


lemma_antbnc download started this may take some time.
Approximate size to download 907.6 KB
[OK!]
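To make the first stages of this pipeline concrete, here is a minimal pure-Python sketch of what tokenizing, normalizing, and stop-word removal do to one sentence. It is an illustration under stated assumptions, not the Spark NLP implementation: the regular expressions and the tiny `STOP_WORDS` set are hypothetical stand-ins, and the lemmatization step is left out.

```python
import re

# Hypothetical, tiny stop-word list for illustration only;
# StopWordsCleaner ships with a much larger default list.
STOP_WORDS = {"the", "will", "be", "a", "an", "of", "to"}

def tokenize(text):
    # Split into word tokens and standalone punctuation marks.
    return re.findall(r"\w+|[^\w\s]", text)

def normalize(tokens):
    # Strip every non-letter character, then discard tokens left empty.
    cleaned = (re.sub(r"[^A-Za-z]", "", tok) for tok in tokens)
    return [tok for tok in cleaned if tok]

def remove_stop_words(tokens):
    # Case-insensitive match, in the spirit of setCaseSensitive(False).
    return [tok for tok in tokens if tok.lower() not in STOP_WORDS]

tokens = tokenize("the soccer games will be postponed.")
normalized = normalize(tokens)
clean_tokens = remove_stop_words(normalized)

print(tokens)
print(normalized)
print(clean_tokens)
```

Each stage consumes the output of the previous one, which is why every annotator's input column is set to the previous annotator's output column.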


### With GloVe 100d embeddings

In [11]:
glove_embeddings = WordEmbeddingsModel.pretrained("glove_100d", "en") \
      .setInputCols(["document",'lemma'])\
      .setOutputCol("embeddings")\
      .setCaseSensitive(False)

embeddingsSentence = SentenceEmbeddings() \
      .setInputCols(["document", "embeddings"]) \
      .setOutputCol("sentence_embeddings") \
      .setPoolingStrategy("AVERAGE")

classifierdl = ClassifierDLApproach()\
      .setInputCols(["sentence_embeddings"])\
      .setOutputCol("class")\
      .setLabelColumn("category")\
      .setMaxEpochs(3)\
      .setEnableOutputLogs(True)
      #.setOutputLogsPath('logs')

clf_pipeline = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            lemma,
            glove_embeddings,
            embeddingsSentence,
            classifierdl])

glove_100d download started this may take some time.
Approximate size to download 145.3 MB
[OK!]
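The `AVERAGE` pooling strategy set on `SentenceEmbeddings` can be pictured as an element-wise mean over the token vectors of a document. The sketch below shows that idea in plain Python; the helper `average_pool` and the 3-dimensional vectors are hypothetical (the `glove_100d` vectors have 100 dimensions), and it is not the library's own code.

```python
def average_pool(token_vectors):
    """Return the element-wise mean of equally sized token vectors."""
    if not token_vectors:
        raise ValueError("need at least one token vector")
    dim = len(token_vectors[0])
    count = len(token_vectors)
    return [sum(vec[i] for vec in token_vectors) / count for i in range(dim)]

# Hypothetical 3-dimensional vectors, one per token.
token_vectors = [
    [1.0, 0.0, 2.0],
    [3.0, 2.0, 0.0],
    [2.0, 4.0, 1.0],
]

sentence_vector = average_pool(token_vectors)
print(sentence_vector)
```

The result is a single fixed-size vector per document, which is the kind of input `ClassifierDLApproach` receives through the `sentence_embeddings` column.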


In [12]:
'''
default classifierDL params:

    maxEpochs -> 10,
    lr -> 5e-3f,
    dropout -> 0.5f,
    batchSize -> 64,
    enableOutputLogs -> false,
    verbose -> Verbose.Silent.id,
    validationSplit -> 0.0f,
    outputLogsPath -> ""
    
'''

'\ndefault classifierDL params:\n\n    maxEpochs -> 10,\n    lr -> 5e-3f,\n    dropout -> 0.5f,\n    batchSize -> 64,\n    enableOutputLogs -> false,\n    verbose -> Verbose.Silent.id,\n    validationSplit -> 0.0f,\n    outputLogsPath -> ""\n    \n'

In [13]:
%%time
# Train (about 3 min for 3 epochs)

clf_pipelineModel = clf_pipeline.fit(trainDataset)

CPU times: user 764 ms, sys: 87.8 ms, total: 852 ms
Wall time: 2min 25s


In [14]:
import os
log_file_name = os.listdir("/root/annotator_logs")[0]

with open("/root/annotator_logs/"+log_file_name, "r") as log_file :
    print(log_file.read())

Training started - epochs: 3 - learning_rate: 0.005 - batch_size: 64 - training_examples: 120000 - classes: 4
Epoch 0/3 - 7.25s - loss: 1638.4495 - acc: 0.8685 - batches: 1875
Epoch 1/3 - 4.73s - loss: 1618.0045 - acc: 0.8809083 - batches: 1875
Epoch 2/3 - 4.23s - loss: 1611.3066 - acc: 0.885825 - batches: 1875



In [15]:
# Get the predictions on the test set

preds = clf_pipelineModel.transform(testDataset)

In [16]:
preds.select('category','description',"class.result").show(10, truncate=80)

+--------+--------------------------------------------------------------------------------+----------+
|category|                                                                     description|    result|
+--------+--------------------------------------------------------------------------------+----------+
|Business|Unions representing workers at Turner   Newall say they are 'disappointed' af...|[Business]|
|Sci/Tech| TORONTO, Canada    A second team of rocketeers competing for the  #36;10 mil...|[Sci/Tech]|
|Sci/Tech| A company founded by a chemistry researcher at the University of Louisville ...|[Sci/Tech]|
|Sci/Tech| It's barely dawn when Mike Fitzpatrick starts his shift with a blur of color...|[Sci/Tech]|
|Sci/Tech| Southern California's smog fighting agency went after emissions of the bovin...|[Sci/Tech]|
|Sci/Tech|"The British Department for Education and Skills (DfES) recently launched a "...|[Sci/Tech]|
|Sci/Tech|"confessed author of the Netsky and Sasser viruses, is responsi

In [17]:
preds_df = preds.select('category','description',"class.result").toPandas()

# The result is an array since in Spark NLP you can have multiple sentences.
# Let's explode the array and get the item(s) inside of result column out
preds_df['result'] = preds_df['result'].apply(lambda x : x[0])


In [18]:
# We are going to use sklearn to evaluate the results on the test dataset
from sklearn.metrics import classification_report

print(classification_report(preds_df['category'], preds_df['result']))

              precision    recall  f1-score   support

    Business       0.83      0.83      0.83      1900
    Sci/Tech       0.84      0.86      0.85      1900
      Sports       0.93      0.97      0.95      1900
       World       0.92      0.86      0.89      1900

    accuracy                           0.88      7600
   macro avg       0.88      0.88      0.88      7600
weighted avg       0.88      0.88      0.88      7600
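For readers who want to see where the precision, recall, F1, and support columns come from, here is a small pure-Python sketch that computes them per class from two label lists. The helper `per_class_report` and the toy labels are hypothetical and are not taken from the test set; in practice `classification_report` is the tool to use.

```python
from collections import Counter

def per_class_report(y_true, y_pred):
    """Compute precision, recall, F1 and support for every label."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted label gets a false positive
            fn[truth] += 1  # true label gets a false negative
    report = {}
    for label in sorted(set(y_true) | set(y_pred)):
        p_den = tp[label] + fp[label]
        r_den = tp[label] + fn[label]
        precision = tp[label] / p_den if p_den else 0.0
        recall = tp[label] / r_den if r_den else 0.0
        f_den = precision + recall
        f1 = 2 * precision * recall / f_den if f_den else 0.0
        report[label] = {"precision": precision, "recall": recall,
                         "f1": f1, "support": r_den}
    return report

# Hypothetical toy labels for illustration only.
y_true = ["Sports", "Sports", "World", "Business", "World", "Sci/Tech"]
y_pred = ["Sports", "World", "World", "Business", "World", "Business"]

report = per_class_report(y_true, y_pred)
for label, metrics in report.items():
    print(label, metrics)
```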



## Getting predictions from the trained model

In [19]:
from sparknlp.base import LightPipeline

light_model = LightPipeline(clf_pipelineModel)

In [20]:
testDataset.select('description').take(2)

[Row(description="Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."),
 Row(description=' TORONTO, Canada    A second team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for privately funded suborbital space flight, has officially announced the first launch date for its manned rocket.')]

In [21]:
text='''
Fearing the fate of Italy, the centre-right government has threatened to be merciless with those who flout tough restrictions. 
As of Wednesday it will also include all shops being closed across Greece, with the exception of supermarkets. Banks, pharmacies, pet-stores, mobile phone stores, opticians, bakers, mini-markets, couriers and food delivery outlets are among the few that will also be allowed to remain open.
'''
result = light_model.annotate(text)

result['class']

['Business']

In [22]:
light_model.annotate('the soccer games will be postponed.')

{'lemma': ['soccer', 'game', 'postpone'],
 'document': ['the soccer games will be postponed.'],
 'normalized': ['the', 'soccer', 'games', 'will', 'be', 'postponed'],
 'sentence_embeddings': ['the soccer games will be postponed.'],
 'cleanTokens': ['soccer', 'games', 'postponed'],
 'token': ['the', 'soccer', 'games', 'will', 'be', 'postponed', '.'],
 'class': ['Sports'],
 'embeddings': ['soccer', 'game', 'postpone']}

# SentimentDL Classifier

See also: `https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/training/english/classification/SentimentDL_train_multiclass_sentiment_classifier.ipynb`

In [23]:
!wget -q https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/sentiment-corpus/aclimdb/aclimdb_train.csv
!wget -q https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/sentiment-corpus/aclimdb/aclimdb_test.csv

In [24]:
trainDataset = spark.read \
      .option("header", True) \
      .csv("aclimdb_train.csv")

trainDataset.show(truncate=30)

+------------------------------+--------+
|                          text|   label|
+------------------------------+--------+
|This is an Excellent little...|positive|
|The Sarah Silverman program...|positive|
|"Prom Night" is a title-onl...|negative|
|So often a band will get to...|positive|
|"Pet Sematary" is an adapta...|positive|
|I watched the film recently...|negative|
|Boy this movie had me foole...|negative|
|Checking the spoiler alert ...|negative|
|Despite its rather salaciou...|positive|
|Absolute masterpiece of a f...|positive|
|The tweedy professor-types ...|positive|
|A movie best summed up by t...|negative|
|Take young, pretty people, ...|negative|
|For months I've been hearin...|negative|
|"Batman: The Mystery of the...|positive|
|Well, it was funny in spots...|negative|
|I have seen the short movie...|positive|
|Brainless film about a good...|negative|
|Leave it to geniuses like V...|negative|
|Seven Pounds stars Will Smi...|positive|
+------------------------------+--

In [25]:
testDataset = spark.read \
      .option("header", True) \
      .csv("aclimdb_test.csv")

testDataset.show(truncate = 30)

+------------------------------+--------+
|                          text|   label|
+------------------------------+--------+
|The Second Woman is about t...|negative|
|In my opinion the directing...|positive|
|I am listening to Istanbul,...|positive|
|Before I speak my piece, I ...|negative|
|ManBearPig is a pretty funn...|positive|
|A buddy and I went to see t...|negative|
|It is incredible that there...|negative|
|Dire! Dismal! Awful! Laugha...|negative|
|HLOTS was an outstanding se...|positive|
|This is just one of those f...|negative|
|This movie had the potentia...|negative|
|The 80s were overrun by all...|positive|
|The tunes are the best aspe...|positive|
|Having recently seen Grindh...|positive|
|My favorite film this year....|positive|
|This movie just might make ...|positive|
|This is the worst movie I h...|negative|
|This was a nice film. It ha...|positive|
|I don't know, maybe I just ...|negative|
|After wasting 2 hours of my...|negative|
+------------------------------+--

In [26]:
# actual content is inside the text column
document = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained("tfhub_use_lg", "en") \
    .setInputCols("document") \
    .setOutputCol("sentence_embeddings")

# the classes/labels/categories are in the label column
sentimentdl = SentimentDLApproach()\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("class")\
    .setLabelColumn("label")\
    .setMaxEpochs(5)\
    .setEnableOutputLogs(True)

pipeline = Pipeline(
    stages = [
        document,
        use,
        sentimentdl
    ])


tfhub_use_lg download started this may take some time.
Approximate size to download 753.3 MB
[OK!]


In [27]:
%%time
pipelineModel = pipeline.fit(trainDataset)

CPU times: user 1.63 s, sys: 194 ms, total: 1.82 s
Wall time: 6min 21s


In [28]:
result = pipelineModel.transform(testDataset)

In [29]:
result_df = result.select('text','label',"class.result").toPandas()

In [30]:
result_df.head(10)

Unnamed: 0,text,label,result
0,The Second Woman is about the story of a myste...,negative,[positive]
1,"In my opinion the directing, editing, lighting...",positive,[positive]
2,"I am listening to Istanbul, intent, my eyes cl...",positive,[positive]
3,"Before I speak my piece, I would like to make ...",negative,[positive]
4,ManBearPig is a pretty funny episode of South ...,positive,[positive]
5,A buddy and I went to see this movie when it c...,negative,[negative]
6,It is incredible that there were two films wit...,negative,[negative]
7,Dire! Dismal! Awful! Laughable! Disappointing!...,negative,[negative]
8,"HLOTS was an outstanding series, its what NYPD...",positive,[positive]
9,This is just one of those films which cannot j...,negative,[negative]
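As a quick sanity check on the rows displayed above, the sketch below unwraps each single-item `result` list and compares it with `label`. The `(label, result)` pairs are copied from those ten rows; the helper `first_or_none` is a hypothetical convenience. Ten rows are far too few to say anything about the model's accuracy on the full test set.

```python
# (label, result) pairs copied from the ten rows displayed above.
rows = [
    ("negative", ["positive"]),
    ("positive", ["positive"]),
    ("positive", ["positive"]),
    ("negative", ["positive"]),
    ("positive", ["positive"]),
    ("negative", ["negative"]),
    ("negative", ["negative"]),
    ("negative", ["negative"]),
    ("positive", ["positive"]),
    ("negative", ["negative"]),
]

def first_or_none(result):
    # The result column holds a list; guard against an empty one.
    return result[0] if result else None

predictions = [first_or_none(result) for _, result in rows]
correct = sum(label == pred for (label, _), pred in zip(rows, predictions))
accuracy = correct / len(rows)

print(f"{correct} of {len(rows)} correct, accuracy = {accuracy}")
```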


# MultiLabel Classifier

See also: `https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/training/english/classification/MultiClassifierDL_train_multi_label_toxic_classifier.ipynb`

In [31]:
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_train.snappy.parquet'
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_test.snappy.parquet'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2702k  100 2702k    0     0  8606k      0 --:--:-- --:--:-- --:--:-- 8579k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  289k  100  289k    0     0  2415k      0 --:--:-- --:--:-- --:--:-- 2415k


In [2]:
trainDataset = spark.read.parquet("/content/toxic_train.snappy.parquet").repartition(120)
testDataset = spark.read.parquet("/content/toxic_test.snappy.parquet").repartition(10)

## Multilabel Classifier with Bert Embeddings

In [3]:
# Let's use shrink to remove new lines in the comments
document = DocumentAssembler()\
      .setInputCol("text")\
      .setOutputCol("document")\
      .setCleanupMode("shrink")

bert_sent = BertSentenceEmbeddings.pretrained('sent_small_bert_L8_512')\
      .setInputCols(["document"])\
      .setOutputCol("sentence_embeddings")

# We will use MultiClassifierDL built by using Bidirectional GRU and CNNs inside TensorFlow that supports up to 100 classes
# We will use only 5 Epochs but feel free to increase it on your own dataset
multiClassifier = MultiClassifierDLApproach()\
      .setInputCols("sentence_embeddings")\
      .setOutputCol("category")\
      .setLabelColumn("labels")\
      .setBatchSize(128)\
      .setMaxEpochs(5)\
      .setLr(1e-3)\
      .setThreshold(0.5)\
      .setShufflePerEpoch(False)\
      .setEnableOutputLogs(True)\
      .setValidationSplit(0.1)

pipeline = Pipeline(
    stages = [
        document,
        bert_sent,
        multiClassifier
    ])

sent_small_bert_L8_512 download started this may take some time.
Approximate size to download 149.1 MB
[OK!]
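The role of `setThreshold(0.5)` in a multi-label setting can be sketched as follows: every label gets its own score, and each label whose score reaches the threshold is emitted, so a text can receive zero, one, or several labels. The scores, the helper `apply_threshold`, and the use of `>=` rather than `>` are assumptions made for illustration, not the library's internals.

```python
def apply_threshold(scores, threshold=0.5):
    """Keep every label whose score reaches the threshold."""
    return [label for label, score in scores.items() if score >= threshold]

# Hypothetical per-label scores for a single comment.
scores = {
    "toxic": 0.91,
    "severe_toxic": 0.12,
    "obscene": 0.64,
    "insult": 0.48,
}

labels_default = apply_threshold(scores)                # threshold 0.5
labels_lenient = apply_threshold(scores, threshold=0.4)

print(labels_default)
print(labels_lenient)
```

Lowering the threshold tends to trade precision for recall, because borderline labels such as `insult` in this toy example start to be emitted.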


In [4]:
#! rm -r /root/annotator_logs

In [5]:
%%time
pipelineModel = pipeline.fit(trainDataset)

CPU times: user 681 ms, sys: 72.2 ms, total: 754 ms
Wall time: 2min 24s


In [6]:
!cat ~/annotator_logs/{multiClassifier.uid}.log

Training started - epochs: 5 - learning_rate: 0.001 - batch_size: 128 - training_examples: 13158 - classes: 6
Epoch 0/5 - 10.73s - loss: 0.33723187 - acc: 0.8611509 - val_loss: 0.30032548 - val_acc: 0.8673656 - val_f1: 0.8026087 - val_tpr: 0.75142246 - batches: 103
Epoch 1/5 - 1.09s - loss: 0.29930413 - acc: 0.8785859 - val_loss: 0.29365048 - val_acc: 0.86979216 - val_f1: 0.81163687 - val_tpr: 0.7816467 - batches: 103
Epoch 2/5 - 1.19s - loss: 0.2917258 - acc: 0.8815477 - val_loss: 0.2913404 - val_acc: 0.87299633 - val_f1: 0.8172958 - val_tpr: 0.7914678 - batches: 103
Epoch 3/5 - 1.08s - loss: 0.2864562 - acc: 0.8833028 - val_loss: 0.2905673 - val_acc: 0.8717741 - val_f1: 0.8149679 - val_tpr: 0.787109 - batches: 103
Epoch 4/5 - 1.10s - loss: 0.28217906 - acc: 0.88537127 - val_loss: 0.29022342 - val_acc: 0.8730394 - val_f1: 0.81671864 - val_tpr: 0.78860056 - batches: 103
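The `batches` value in both training logs of this notebook is consistent with dividing the number of training examples by the batch size and rounding up. The arithmetic check below assumes that a final partial batch counts as a batch; that assumption is inferred from the logged numbers, not from the library source.

```python
import math

def batches_per_epoch(n_examples, batch_size):
    # Round up so that a final, smaller batch is still counted.
    return math.ceil(n_examples / batch_size)

# Numbers taken from the two training logs in this notebook.
news_batches = batches_per_epoch(120000, 64)
toxic_batches = batches_per_epoch(13158, 128)

print(news_batches)
print(toxic_batches)
```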


In [7]:
preds = pipelineModel.transform(testDataset)

In [8]:
preds_df = preds.select('text','labels',"category.result").toPandas()

In [9]:
preds_df.head(10)

Unnamed: 0,text,labels,result
0,Vegan \n\nWhat in the hell does all that junk ...,[toxic],"[toxic, obscene]"
1,Fight Club! F**k Yeeaaaaahh!!!,"[toxic, obscene]","[toxic, obscene]"
2,"""\n\n Little quick on the trigger, ain't'cha b...",[toxic],[toxic]
3,Your user page indicates you're a left-wing li...,"[toxic, obscene, insult]",[toxic]
4,""" See all the many Google links, titled"""" Wik...",[toxic],[toxic]
5,"""\n\n LOL \n\nLOL. Seriously, """"BryanFromPalat...",[toxic],"[toxic, obscene]"
6,is it because it is of my naked mum having sex...,"[toxic, severe_toxic, obscene, insult, identit...","[toxic, severe_toxic, insult, obscene]"
7,")\na cowards site, that must stop changing thi...","[toxic, obscene]",[toxic]
8,"blow me, criticism IS constructive. \n\nBlow m...","[toxic, obscene, insult]","[toxic, insult, obscene]"
9,On account of the project deciding to ignore h...,"[toxic, obscene, insult]",[toxic]
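Because each row carries a set of labels, evaluation is usually done on label sets rather than on single values. The sketch below computes exact-match ratio and micro-averaged precision, recall, and F1 on the rows displayed above, skipping row 6 because its `labels` value is truncated in the output. Nine rows are only an illustration, not an estimate of test-set quality.

```python
# (labels, result) pairs copied from the rows above; row 6 is omitted
# because its labels column is truncated in the displayed output.
rows = [
    (["toxic"], ["toxic", "obscene"]),
    (["toxic", "obscene"], ["toxic", "obscene"]),
    (["toxic"], ["toxic"]),
    (["toxic", "obscene", "insult"], ["toxic"]),
    (["toxic"], ["toxic"]),
    (["toxic"], ["toxic", "obscene"]),
    (["toxic", "obscene"], ["toxic"]),
    (["toxic", "obscene", "insult"], ["toxic", "insult", "obscene"]),
    (["toxic", "obscene", "insult"], ["toxic"]),
]

tp = fp = fn = exact = 0
for truth, pred in rows:
    truth_set, pred_set = set(truth), set(pred)
    tp += len(truth_set & pred_set)   # labels predicted and correct
    fp += len(pred_set - truth_set)   # labels predicted but wrong
    fn += len(truth_set - pred_set)   # labels missed
    if truth_set == pred_set:         # order of labels does not matter
        exact += 1

micro_precision = tp / (tp + fp)
micro_recall = tp / (tp + fn)
micro_f1 = 2 * tp / (2 * tp + fp + fn)
exact_match_ratio = exact / len(rows)

print(micro_precision, micro_recall, micro_f1, exact_match_ratio)
```

Comparing sets rather than lists matters here: row 8 lists the same three labels in a different order and should count as an exact match.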


## Saving & loading back the trained model

In [10]:
pipelineModel.stages

[DocumentAssembler_c7c521bd305d,
 BERT_SENTENCE_EMBEDDINGS_3608c0d843af,
 MultiClassifierDLModel_0dbf61fd7dbf]

In [11]:
# Save the Multilabel Classifier Model
pipelineModel.stages[2].write().overwrite().save('MultilabelClfBert')

In [12]:
# Load the saved Multilabel Classifier Model back
MultilabelClfModel = MultiClassifierDLModel.load('MultilabelClfBert')

In [13]:
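# Generate a prediction Pipeline with the loaded Model.
# Every stage is already trained or pretrained, so fit() on a dummy
# DataFrame only assembles the PipelineModel; no training happens here.
ld_pipeline = Pipeline(stages=[document, bert_sent, MultilabelClfModel])
ld_pipeline_model = ld_pipeline.fit(spark.createDataFrame([['']]).toDF("text"))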

In [14]:
# Apply Model Transform to testData
ld_preds = ld_pipeline_model.transform(testDataset)

In [15]:
ld_preds_df = ld_preds.select('text','labels',"category.result").toPandas()


In [16]:
ld_preds_df.head(10)

Unnamed: 0,text,labels,result
0,Vegan \n\nWhat in the hell does all that junk ...,[toxic],"[toxic, obscene]"
1,Fight Club! F**k Yeeaaaaahh!!!,"[toxic, obscene]","[toxic, obscene]"
2,"""\n\n Little quick on the trigger, ain't'cha b...",[toxic],[toxic]
3,Your user page indicates you're a left-wing li...,"[toxic, obscene, insult]",[toxic]
4,""" See all the many Google links, titled"""" Wik...",[toxic],[toxic]
5,"""\n\n LOL \n\nLOL. Seriously, """"BryanFromPalat...",[toxic],"[toxic, obscene]"
6,is it because it is of my naked mum having sex...,"[toxic, severe_toxic, obscene, insult, identit...","[toxic, severe_toxic, insult, obscene]"
7,")\na cowards site, that must stop changing thi...","[toxic, obscene]",[toxic]
8,"blow me, criticism IS constructive. \n\nBlow m...","[toxic, obscene, insult]","[toxic, insult, obscene]"
9,On account of the project deciding to ignore h...,"[toxic, obscene, insult]",[toxic]
