![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/MultiClassifierDL_train_multi_label_toxic_classifier.ipynb)


# Multi-label Text Classification of Toxic Comments using MultiClassifierDL

In [None]:
# Only run this cell when you are using Spark NLP on Google Colab
!wget http://setup.johnsnowlabs.com/colab.sh -O - | bash

Let's download our toxic comments datasets for training and testing:

In [None]:
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_train.snappy.parquet'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2702k  100 2702k    0     0  1720k      0  0:00:01  0:00:01 --:--:-- 1720k


In [None]:
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_test.snappy.parquet'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  289k  100  289k    0     0   254k      0  0:00:01  0:00:01 --:--:--  254k


In [None]:
import sparknlp

spark=sparknlp.start()
print("Spark NLP version")
sparknlp.version()

Spark NLP version


'4.3.1'

Let's read our toxic comments datasets:

In [None]:
trainDataset = spark.read.parquet("toxic_train.snappy.parquet").repartition(120)
testDataset = spark.read.parquet("toxic_test.snappy.parquet").repartition(10)

In [None]:
trainDataset.show(2)

+----------------+--------------------+-------+
|              id|                text| labels|
+----------------+--------------------+-------+
|e63f1cc4b0b9959f|EAT SHIT HORSE FA...|[toxic]|
|ed58abb40640f983|PN News\nYou mean...|[toxic]|
+----------------+--------------------+-------+
only showing top 2 rows



As you can see, the comments contain many newline characters, which we can clean up with `DocumentAssembler` by setting its cleanup mode to `shrink`:
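To make the effect concrete, here is a rough pure-Python approximation of what a whitespace-shrinking cleanup does to a single comment. It is only a sketch: `shrink_whitespace` is a hypothetical helper written for illustration, and the exact rules Spark NLP applies in `shrink` mode may differ.

```python
import re


def shrink_whitespace(text: str) -> str:
    """Collapse every run of whitespace (spaces, tabs, newlines) into one space.

    Hypothetical helper; an approximation, not Spark NLP's implementation.
    """
    return re.sub(r"\s+", " ", text).strip()


sample = "PN News\nYou mean\n\n  the same   thing"
print(shrink_whitespace(sample))  # PN News You mean the same thing
```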

In [None]:
print(trainDataset.cache().count())
print(testDataset.cache().count())

14620
1605


In [None]:
from pyspark.ml import Pipeline

from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *

In [None]:
# Let's use shrink to remove new lines in the comments
document = DocumentAssembler()\
  .setInputCol("text")\
  .setOutputCol("document")\
  .setCleanupMode("shrink")

# Here we use the state-of-the-art Universal Sentence Encoder model from TF Hub
embeddings = UniversalSentenceEncoder.pretrained() \
  .setInputCols(["document"])\
  .setOutputCol("sentence_embeddings")

# MultiClassifierDL is built with a Bidirectional GRU and CNNs in TensorFlow and supports up to 100 classes
# We train for only 5 epochs here; feel free to increase this for your own dataset
multiClassifier = MultiClassifierDLApproach()\
  .setInputCols("sentence_embeddings")\
  .setOutputCol("category")\
  .setLabelColumn("labels")\
  .setBatchSize(128)\
  .setMaxEpochs(5)\
  .setLr(1e-3)\
  .setThreshold(0.5)\
  .setShufflePerEpoch(False)\
  .setEnableOutputLogs(True)\
  .setValidationSplit(0.1)

pipeline = Pipeline(
    stages = [
        document,
        embeddings,
        multiClassifier
    ])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]
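The `setThreshold(0.5)` setting in the cell above decides which labels are emitted for each comment. A minimal sketch of the idea, assuming the classifier produces an independent score per label and keeps every label at or above the threshold. The scores below are invented for illustration, `labels_above_threshold` is a hypothetical helper, and whether the library's comparison is inclusive is an assumption:

```python
def labels_above_threshold(scores: dict, threshold: float = 0.5) -> list:
    """Return the labels whose score reaches the threshold (hypothetical helper)."""
    return [label for label, score in scores.items() if score >= threshold]


# Made-up per-label scores for one comment
scores = {"toxic": 0.93, "obscene": 0.71, "insult": 0.42, "threat": 0.03}

print(labels_above_threshold(scores))       # ['toxic', 'obscene']
print(labels_above_threshold(scores, 0.4))  # ['toxic', 'obscene', 'insult']
```

Lowering the threshold generally trades precision for recall, which is worth keeping in mind when reading the evaluation numbers later in the notebook.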


In [None]:
pipelineModel = pipeline.fit(trainDataset)

In [None]:
!ls -l ~/annotator_logs/

total 240
-rw-r--r-- 1 root root 456 20. Feb 17:41 ClassifierDLApproach_0375e3a8df00.log
-rw-r--r-- 1 root root 918 20. Feb 17:38 ClassifierDLApproach_6fdb8a569309.log
-rw-r--r-- 1 root root 446 20. Feb 15:55 ClassifierDLApproach_97ff5c76d735.log
-rw-r--r-- 1 root root 438 20. Feb 17:38 ClassifierMetrics_09bd6fa2ecf7.log
-rw-r--r-- 1 root root 317 10. Feb 16:54 ClassifierMetrics_17606bbb7d1f.log
-rw-r--r-- 1 root root 571 20. Feb 17:45 ClassifierMetrics_176ce729caa6.log
-rw-r--r-- 1 root root 313 10. Feb 16:54 ClassifierMetrics_1a6c515483ae.log
-rw-r--r-- 1 root root 441 20. Feb 17:38 ClassifierMetrics_1e0c8ea78e67.log
-rw-r--r-- 1 root root 323 10. Feb 16:54 ClassifierMetrics_2530315112a8.log
-rw-r--r-- 1 root root 566 20. Feb 17:45 ClassifierMetrics_26e8744dc78c.log
-rw-r--r-- 1 root root 565 20. Feb 17:45 ClassifierMetrics_284f041511fb.log
-rw-r--r-- 1 root root 445 20. Feb 17:38 ClassifierMetrics_2b7b458fc84d.log
-rw-r--r-- 1 root root 551 20. Feb 17:45 ClassifierMetrics_2fde2811a9

In [None]:
!cat ~/annotator_logs/MultiClassifierDLApproach_d670b2c2d0df.log


cat: /home/root/annotator_logs/MultiClassifierDLApproach_d670b2c2d0df.log: No such file or directory
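The log file name contains a run-specific suffix, so a hard-coded name will not match a new training run. One possible workaround, sketched here under the assumption that training logs are written to `~/annotator_logs` with the `MultiClassifierDLApproach_` prefix used in the command above, is to pick the most recently modified matching file. `latest_log` is a hypothetical helper, not part of Spark NLP:

```python
from pathlib import Path
from typing import Optional


def latest_log(log_dir: Path, prefix: str = "MultiClassifierDLApproach_") -> Optional[Path]:
    """Return the newest log file with the given prefix, or None if there is none."""
    if not log_dir.is_dir():
        return None
    candidates = sorted(log_dir.glob(f"{prefix}*.log"), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None


log_path = latest_log(Path.home() / "annotator_logs")
if log_path is None:
    print("No MultiClassifierDLApproach log found")
else:
    print(log_path.read_text())
```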


Let's save our trained multi-label classifier model to be loaded in our prediction pipeline:

In [None]:
pipelineModel.stages[-1].write().overwrite().save('tmp_multi_classifierDL_model')

## Load the saved model into a prediction pipeline

In [None]:
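# Use the same "shrink" cleanup as in training so inference sees identically preprocessed text
document = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")\
    .setCleanupMode("shrink")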

use = UniversalSentenceEncoder.pretrained() \
 .setInputCols(["document"])\
 .setOutputCol("sentence_embeddings")

multiClassifier = MultiClassifierDLModel.load("tmp_multi_classifierDL_model") \
  .setInputCols(["sentence_embeddings"])\
  .setOutputCol("category")\
  .setThreshold(0.5)

pipeline = Pipeline(
    stages = [
        document,
        use,
        multiClassifier
    ])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]


Let's now use our test dataset to evaluate our model:

In [None]:
# let's see our labels:
print(pipeline.fit(testDataset).stages[2].getClasses())

['toxic', 'severe_toxic', 'identity_hate', 'insult', 'obscene', 'threat']


In [None]:
preds = pipeline.fit(testDataset).transform(testDataset)


In [None]:
preds.select('labels','text',"category.result").show(2)

+----------------+--------------------+----------------+
|          labels|                text|          result|
+----------------+--------------------+----------------+
|         [toxic]|Vegan \n\nWhat in...|         [toxic]|
|[toxic, obscene]|Fight Club! F**k ...|[toxic, obscene]|
+----------------+--------------------+----------------+
only showing top 2 rows



In [None]:
preds_df = preds.select('labels', 'category.result').toPandas()

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score

mlb = MultiLabelBinarizer()

# Fit the binarizer once on the true labels, then reuse the same column order for the predictions
y_true = mlb.fit_transform(preds_df['labels'])
y_pred = mlb.transform(preds_df['result'])


print("Classification report: \n", (classification_report(y_true, y_pred)))
print("F1 micro averaging:",(f1_score(y_true, y_pred, average='micro')))
print("ROC: ",(roc_auc_score(y_true, y_pred, average="micro")))

Classification report: 
               precision    recall  f1-score   support

           0       0.53      0.35      0.42       127
           1       0.73      0.62      0.67       761
           2       0.79      0.67      0.73       824
           3       0.50      0.15      0.23       147
           4       0.73      0.38      0.50        50
           5       0.94      1.00      0.97      1504

   micro avg       0.84      0.77      0.80      3413
   macro avg       0.70      0.53      0.59      3413
weighted avg       0.82      0.77      0.78      3413
 samples avg       0.86      0.80      0.79      3413

F1 micro averaging: 0.802391537636057
ROC:  0.8437377009561553
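For readers unfamiliar with micro-averaging, the sketch below shows how a micro-averaged F1 can be computed directly from label sets: true positives, false positives, and false negatives are pooled across all labels and samples before precision and recall are calculated. The label sets are invented for illustration and `micro_f1` is a hypothetical helper, not the scikit-learn implementation:

```python
def micro_f1(true_labels, pred_labels) -> float:
    """Micro-averaged F1 over paired lists of label collections (hypothetical helper)."""
    tp = fp = fn = 0
    for truth, pred in zip(true_labels, pred_labels):
        truth, pred = set(truth), set(pred)
        tp += len(truth & pred)   # labels predicted and correct
        fp += len(pred - truth)   # labels predicted but wrong
        fn += len(truth - pred)   # labels missed
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Made-up example: 3 true positives, 1 false positive, 1 false negative
y_true_sets = [["toxic"], ["toxic", "obscene"], ["insult"]]
y_pred_sets = [["toxic"], ["toxic"], ["insult", "obscene"]]
print(micro_f1(y_true_sets, y_pred_sets))  # 0.75
```

Because counts are pooled, frequent labels dominate the micro average, which is why it can sit well above the macro average when rare labels are predicted poorly.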


In [None]:
preds.select("category.metadata").show(10, False)

+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|metadata                                                                                                                                                                                            

In [None]:
preds.select("category.metadata").printSchema()

root
 |-- metadata: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)
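The schema shows that each annotation's metadata is a map from string to string, so any numeric values have to be converted before they can be sorted or compared. A small sketch, assuming the map holds one score per label stored as text (an assumption, since the metadata output above is cut off); the sample values are invented and `parse_scores` is a hypothetical helper:

```python
def parse_scores(metadata: dict) -> dict:
    """Convert string values to floats, skipping entries that are not numeric."""
    scores = {}
    for key, value in metadata.items():
        try:
            scores[key] = float(value)
        except (TypeError, ValueError):
            continue  # ignore non-numeric metadata entries
    return scores


# Made-up metadata map for a single annotation
sample_metadata = {"toxic": "0.97", "obscene": "0.64", "threat": "0.01", "note": "n/a"}

scores = parse_scores(sample_metadata)
ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
print(ranked)  # [('toxic', 0.97), ('obscene', 0.64), ('threat', 0.01)]
```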

