![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)
# Spark NLP
## Multi-label Text Classification
### Toxic Comments
#### By using MultiClassifierDL

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/training/english/classification/MultiClassifierDL_train_multi_label_toxic_classifier.ipynb)

In [3]:
import os

# Install java
! apt-get update -qq
! apt-get install -y openjdk-8-jdk-headless -qq > /dev/null

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]
! java -version

# Install pyspark
! pip install --ignore-installed -q pyspark==2.4.6
! pip install --ignore-installed -q spark-nlp==2.6.0

openjdk version "1.8.0_265"
OpenJDK Runtime Environment (build 1.8.0_265-8u265-b01-0ubuntu2~18.04-b01)
OpenJDK 64-Bit Server VM (build 25.265-b01, mixed mode)
  Building wheel for pyspark (setup.py) ... done


Let's download the toxic comments datasets for training and testing:

In [None]:
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_train.snappy.parquet'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2702k  100 2702k    0     0  3117k      0 --:--:-- --:--:-- --:--:-- 3113k


In [None]:
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_test.snappy.parquet'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  289k  100  289k    0     0   185k      0  0:00:01  0:00:01 --:--:--  185k


In [None]:
import sparknlp

spark=sparknlp.start()
print("Spark NLP version")
sparknlp.version()


Spark NLP version


'2.6.0'

Let's read the toxic comments datasets:

In [None]:
trainDataset = spark.read.parquet("/content/toxic_train.snappy.parquet").repartition(120)
testDataset = spark.read.parquet("/content/toxic_test.snappy.parquet").repartition(10)

In [None]:
trainDataset.show(2)

+----------------+--------------------+-------+
|              id|                text| labels|
+----------------+--------------------+-------+
|e63f1cc4b0b9959f|EAT SHIT HORSE FA...|[toxic]|
|ed58abb40640f983|PN News
You mean ...|[toxic]|
+----------------+--------------------+-------+
only showing top 2 rows



As you can see, the comments contain many newlines, which we can clean up with the cleanup mode of `DocumentAssembler`.
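To make the effect concrete, here is a rough plain-Python approximation of what a whitespace-collapsing cleanup does to a comment. It is only an illustration: the helper name `shrink_whitespace` and the sample text are made up, and the exact rules Spark NLP applies in its cleanup modes may differ in detail.

```python
import re

def shrink_whitespace(text: str) -> str:
    """Collapse every run of whitespace (newlines, tabs, spaces) into one space.

    Hypothetical helper for illustration; not part of Spark NLP.
    """
    return re.sub(r"\s+", " ", text).strip()

# Made-up comment with embedded newlines, a tab, and repeated spaces
sample = "PN News\nYou mean\t the   other one?\n\n"
cleaned = shrink_whitespace(sample)
print(repr(cleaned))
```

In the pipeline itself no such helper is needed, since the cleanup is configured directly on the `DocumentAssembler` stage.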

In [None]:
print(trainDataset.cache().count())
print(testDataset.cache().count())

14620
1605


In [None]:
from pyspark.ml import Pipeline

from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *

In [None]:
# Let's use shrink to remove new lines in the comments
document = DocumentAssembler()\
  .setInputCol("text")\
  .setOutputCol("document")\
  .setCleanupMode("shrink")

# Here we use the state-of-the-art Universal Sentence Encoder model from TF Hub
embeddings = UniversalSentenceEncoder.pretrained() \
  .setInputCols(["document"])\
  .setOutputCol("sentence_embeddings")

# We will use MultiClassifierDL built by using Bidirectional GRU and CNNs inside TensorFlow that supports up to 100 classes
# We train for only 5 epochs here; feel free to increase this for your own dataset
multiClassifier = MultiClassifierDLApproach()\
  .setInputCols("sentence_embeddings")\
  .setOutputCol("category")\
  .setLabelColumn("labels")\
  .setBatchSize(128)\
  .setMaxEpochs(5)\
  .setLr(1e-3)\
  .setThreshold(0.5)\
  .setShufflePerEpoch(False)\
  .setEnableOutputLogs(True)\
  .setValidationSplit(0.1)

pipeline = Pipeline(
    stages = [
        document,
        embeddings,
        multiClassifier
    ])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]
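The `setThreshold(0.5)` setting in the cell above can be pictured as follows: in multi-label classification each class receives its own score, and every class whose score clears the threshold is emitted, so a comment can end up with zero, one, or several labels. The sketch below uses plain Python with made-up scores. `labels_above_threshold` is a hypothetical helper, not part of Spark NLP, and treating the comparison as inclusive (`>=`) is an assumption.

```python
# The six class names of the toxic comments dataset
CLASSES = ["toxic", "severe_toxic", "identity_hate", "insult", "obscene", "threat"]

def labels_above_threshold(scores, threshold=0.5):
    """Return the classes whose score reaches the threshold (hypothetical helper)."""
    return [name for name, score in zip(CLASSES, scores) if score >= threshold]

# Made-up per-class scores for a single comment
scores = [0.97, 0.12, 0.03, 0.41, 0.78, 0.01]

print(labels_above_threshold(scores))       # default threshold of 0.5
print(labels_above_threshold(scores, 0.4))  # a lower threshold admits more labels
```

Lowering the threshold generally trades precision for recall, which is worth keeping in mind when tuning it for rare classes.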


In [None]:
pipelineModel = pipeline.fit(trainDataset)

In [None]:
!ls -l ~/annotator_logs/

total 4
-rw-r--r-- 1 root root 885 Sep  2 16:56 MultiClassifierDLApproach_d670b2c2d0df.log


In [None]:
!cat ~/annotator_logs/MultiClassifierDLApproach_d670b2c2d0df.log


Training started - epochs: 5 - learning_rate: 0.001 - batch_size: 128 - training_examples: 13158 - classes: 6
Epoch 0/5 - 15.19s - loss: 0.38046357 - acc: 0.848714 - val_loss: 0.30129096 - val_acc: 0.871466 - val_f1: 0.81246215 - val_tpr: 0.77513814 - batches: 103
Epoch 1/5 - 5.51s - loss: 0.30138606 - acc: 0.87715614 - val_loss: 0.28858984 - val_acc: 0.8747491 - val_f1: 0.819081 - val_tpr: 0.789548 - batches: 103
Epoch 2/5 - 5.37s - loss: 0.29324576 - acc: 0.87968993 - val_loss: 0.28451642 - val_acc: 0.8766811 - val_f1: 0.82239383 - val_tpr: 0.79497665 - batches: 103
Epoch 3/5 - 5.38s - loss: 0.28977352 - acc: 0.88131446 - val_loss: 0.2825411 - val_acc: 0.87826157 - val_f1: 0.8243148 - val_tpr: 0.7951459 - batches: 103
Epoch 4/5 - 5.38s - loss: 0.2876302 - acc: 0.88208383 - val_loss: 0.28134403 - val_acc: 0.878595 - val_f1: 0.82474065 - val_tpr: 0.79545283 - batches: 103
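Because each epoch line in the log follows a regular `name: value` layout, the metrics can be pulled into Python for plotting or comparison. The sketch below parses two of the lines shown above with a regular expression; `parse_epoch_line` is a hypothetical helper, and the pattern assumes the log format stays exactly as printed here.

```python
import re

# Two epoch lines copied from the training log above
log_lines = [
    "Epoch 0/5 - 15.19s - loss: 0.38046357 - acc: 0.848714 - val_loss: 0.30129096 - val_acc: 0.871466 - val_f1: 0.81246215 - val_tpr: 0.77513814 - batches: 103",
    "Epoch 4/5 - 5.38s - loss: 0.2876302 - acc: 0.88208383 - val_loss: 0.28134403 - val_acc: 0.878595 - val_f1: 0.82474065 - val_tpr: 0.79545283 - batches: 103",
]

def parse_epoch_line(line):
    """Return (epoch_number, {metric_name: value}) for one log line."""
    epoch = int(re.match(r"Epoch (\d+)/", line).group(1))
    metrics = {name: float(value) for name, value in re.findall(r"(\w+): ([\d.]+)", line)}
    return epoch, metrics

history = dict(parse_epoch_line(line) for line in log_lines)
for epoch, metrics in sorted(history.items()):
    print(epoch, metrics["val_loss"], metrics["val_f1"])
```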


Let's save our trained multi-label classifier model to be loaded in our prediction pipeline:

In [None]:
pipelineModel.stages[-1].write().overwrite().save('/content/tmp_multi_classifierDL_model')

## Load the saved model

In [None]:
document = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained() \
 .setInputCols(["document"])\
 .setOutputCol("sentence_embeddings")

multiClassifier = MultiClassifierDLModel.load("/content/tmp_multi_classifierDL_model") \
  .setInputCols(["sentence_embeddings"])\
  .setOutputCol("category")\
  .setThreshold(0.5)

pipeline = Pipeline(
    stages = [
        document,
        use,
        multiClassifier
    ])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]


Let's now use the test dataset to evaluate our model:

In [None]:
# let's see our labels:
print(pipeline.fit(testDataset).stages[2].getClasses())

['toxic', 'severe_toxic', 'identity_hate', 'insult', 'obscene', 'threat']


In [None]:
preds = pipeline.fit(testDataset).transform(testDataset)


In [None]:
preds.select('labels','text',"category.result").show(2)

+----------------+--------------------+----------------+
|          labels|                text|          result|
+----------------+--------------------+----------------+
|         [toxic]|Vegan 

What in t...|         [toxic]|
|[toxic, obscene]|Fight Club! F**k ...|[toxic, obscene]|
+----------------+--------------------+----------------+
only showing top 2 rows



In [None]:
preds_df = preds.select('labels', 'category.result').toPandas()

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score

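# Fit the binarizer once on the true labels and reuse it for the predictions,
# so that both matrices share the same column order
mlb = MultiLabelBinarizer()

y_true = mlb.fit_transform(preds_df['labels'])
y_pred = mlb.transform(preds_df['result'])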


print("Classification report: \n", (classification_report(y_true, y_pred)))
print("F1 micro averaging:",(f1_score(y_true, y_pred, average='micro')))
print("ROC: ",(roc_auc_score(y_true, y_pred, average="micro")))

Classification report: 
               precision    recall  f1-score   support

           0       0.53      0.35      0.42       127
           1       0.73      0.62      0.67       761
           2       0.79      0.67      0.73       824
           3       0.50      0.15      0.23       147
           4       0.73      0.38      0.50        50
           5       0.94      1.00      0.97      1504

   micro avg       0.84      0.77      0.80      3413
   macro avg       0.70      0.53      0.59      3413
weighted avg       0.82      0.77      0.78      3413
 samples avg       0.86      0.80      0.79      3413

F1 micro averaging: 0.802391537636057
ROC:  0.8437377009561553
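For readers unfamiliar with micro averaging: it pools true positives, false positives, and false negatives over all classes and all rows before computing precision and recall, so frequent classes weigh more than rare ones. The following is a minimal plain-Python sketch of that idea on three made-up rows. `micro_f1` is a hypothetical helper written for illustration, not the scikit-learn implementation.

```python
def micro_f1(true_sets, pred_sets):
    """Micro-averaged F1 over rows of label sets (illustrative helper)."""
    tp = sum(len(t & p) for t, p in zip(true_sets, pred_sets))  # predicted and correct
    fp = sum(len(p - t) for t, p in zip(true_sets, pred_sets))  # predicted but wrong
    fn = sum(len(t - p) for t, p in zip(true_sets, pred_sets))  # missed labels
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Made-up gold labels and predictions for three comments
true_labels = [{"toxic"}, {"toxic", "obscene"}, {"toxic", "insult"}]
pred_labels = [{"toxic"}, {"toxic"}, {"toxic", "insult", "obscene"}]

print(micro_f1(true_labels, pred_labels))
```

Here 4 labels are predicted correctly, 1 is spurious, and 1 is missed, so precision and recall are both 4/5.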


In [None]:
preds.select("category.metadata").show(10, False)


In [None]:
preds.select("category.metadata").printSchema()

root
 |-- metadata: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)
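Since the metadata column is an array of string-to-string maps, any numeric values in it have to be converted before they can be compared or sorted. The sketch below shows one way to do that on a collected row, under the assumption that the maps use class names as keys and score strings as values; the sample data and the `top_scores` helper are hypothetical, so check the actual keys in your own output first.

```python
# Hypothetical metadata for a single row: a list of string -> string maps,
# mirroring the schema above (values may also be null)
row_metadata = [
    {"toxic": "0.98", "obscene": "0.81", "insult": "0.22", "note": None},
]

def top_scores(metadata, min_score=0.5):
    """Return (key, score) pairs with score >= min_score, highest first."""
    merged = {}
    for mapping in metadata:
        for key, value in mapping.items():
            try:
                merged[key] = float(value)
            except (TypeError, ValueError):
                continue  # skip null or non-numeric values
    kept = [(key, score) for key, score in merged.items() if score >= min_score]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)

print(top_scores(row_metadata))
```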

