![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/training/english/classification/MultiClassifierDL_train_multi_label_E2E_challenge_classifier.ipynb)


# Multi-label Text Classification: E2E Challenge using MultiClassifierDL

In [None]:
# Only run this cell when you are using Spark NLP on Google Colab
!wget http://setup.johnsnowlabs.com/colab.sh -O - | bash

Let's download the E2E challenge dataset for training and testing:

In [None]:
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/e2e_challenge/e2e_train.snappy.parquet'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1922k  100 1922k    0     0  1337k      0  0:00:01  0:00:01 --:--:-- 1337k


In [None]:
import sparknlp

spark = sparknlp.start()
print("Spark NLP version")
sparknlp.version()

Spark NLP version


'4.3.1'

Let's read the E2E dataset and split it 90/10 into training and test sets:

In [None]:
trainDataset, testDataset = spark.read.parquet("e2e_train.snappy.parquet") \
  .randomSplit([0.9, 0.1], seed = 12345)

In [None]:
trainDataset.show(2)

+--------------------+--------------------+
|                 ref|              labels|
+--------------------+--------------------+
|'Bibimbap House' ...|[name[Bibimbap Ho...|
|'Browns Cambridge...|[name[Browns Camb...|
+--------------------+--------------------+
only showing top 2 rows



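Each row pairs a restaurant description (`ref`) with an array of attribute labels (`labels`) such as `name[Bibimbap House]`. A single description carries several of these labels, which is what makes this a multi-label task. In the pipeline below, `DocumentAssembler` turns the raw `ref` text into Spark NLP documents.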
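Every label shown above follows an `attribute[value]` pattern. The small sketch below splits such a label into its two parts, which is handy when you want to inspect results per attribute. It assumes every label in the dataset follows the format seen in the output; `parse_label` is a hypothetical helper, not part of Spark NLP.

```python
import re

# Assumed label format: "attribute[value]", e.g. "name[Bibimbap House]"
LABEL_PATTERN = re.compile(r"^(?P<attr>[^\[]+)\[(?P<value>.*)\]$")

def parse_label(label):
    """Split an 'attribute[value]' label into an (attribute, value) tuple."""
    match = LABEL_PATTERN.match(label)
    if match is None:
        raise ValueError(f"Unexpected label format: {label!r}")
    return match.group("attr"), match.group("value")

print(parse_label("name[Bibimbap House]"))
print(parse_label("customer rating[5 out of 5]"))
```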

In [None]:
print(trainDataset.cache().count())
print(testDataset.cache().count())

37762
4299
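The two counts are not an exact 90/10 split: 4,299 / (37,762 + 4,299) = 4,299 / 42,061 ≈ 10.2%. `randomSplit` assigns rows randomly rather than cutting at a fixed index, so the sizes only approximate the requested weights. The stdlib simulation below mimics that idea by assigning each row independently; it is an illustration only and will not reproduce Spark's exact split for the same seed. `simulate_random_split` is a hypothetical helper.

```python
import random

def simulate_random_split(n_rows, weights, seed):
    """Assign each row independently to a split with probability proportional to its weight."""
    total = sum(weights)
    # Cumulative boundaries, e.g. [0.9, 1.0] for weights [0.9, 0.1]
    bounds = []
    acc = 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    rng = random.Random(seed)
    counts = [0] * len(weights)
    for _ in range(n_rows):
        r = rng.random()
        for i, b in enumerate(bounds):
            if r < b:
                counts[i] += 1
                break
        else:
            counts[-1] += 1  # guard against floating-point rounding at the top end
    return counts

n = 37762 + 4299  # total number of rows reported above
train, test = simulate_random_split(n, [0.9, 0.1], seed=12345)
print(train, test, round(test / n, 3))
```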


In [None]:
from pyspark.ml import Pipeline

from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *

In [None]:
# The actual text is in a column named ref
document = DocumentAssembler()\
  .setInputCol("ref")\
  .setOutputCol("document")

# Here we use the state-of-the-art Universal Sentence Encoder model from TF Hub
embeddings = UniversalSentenceEncoder.pretrained() \
  .setInputCols(["document"])\
  .setOutputCol("sentence_embeddings")

# MultiClassifierDL is built with a Bidirectional GRU and CNNs inside TensorFlow and supports up to 100 classes
# We train for only 5 epochs here; feel free to increase this for your own dataset
multiClassifier = MultiClassifierDLApproach()\
  .setInputCols("sentence_embeddings")\
  .setOutputCol("category")\
  .setLabelColumn("labels")\
  .setBatchSize(128)\
  .setMaxEpochs(5)\
  .setLr(1e-3)\
  .setThreshold(0.5)\
  .setShufflePerEpoch(False)\
  .setEnableOutputLogs(True)\
  .setValidationSplit(0.1)

pipeline = Pipeline(
    stages = [
        document,
        embeddings,
        multiClassifier
    ])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]
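`setThreshold(0.5)` is what lets one description receive several labels: each class gets its own score, and every class whose score clears the threshold is kept. The plain-Python sketch below illustrates that decision rule. The helper name, the hypothetical scores, and the use of `>=` rather than `>` are assumptions for illustration, not Spark NLP's internal code.

```python
def apply_threshold(scores, threshold=0.5):
    """Return every label whose score reaches the threshold, highest score first.

    `scores` maps label -> independent per-label score (assumed sigmoid-like, in [0, 1]).
    """
    kept = [(label, s) for label, s in scores.items() if s >= threshold]
    kept.sort(key=lambda pair: pair[1], reverse=True)
    return [label for label, _ in kept]

# Hypothetical scores for one description; unlike softmax outputs, they need not sum to 1
scores = {"name[Alimentum]": 0.97, "eatType[pub]": 0.62, "food[French]": 0.41, "near[Avalon]": 0.08}
print(apply_threshold(scores))        # two labels pass 0.5
print(apply_threshold(scores, 0.3))   # a lower threshold also keeps food[French]
```

Lowering the threshold trades precision for recall, so it is worth tuning on the validation data rather than leaving it at 0.5 by default.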


In [None]:
pipelineModel = pipeline.fit(trainDataset)

In [None]:
!ls -l ~/annotator_logs/

total 216
-rw-r--r-- 1 root root 456 20. Feb 17:41 ClassifierDLApproach_0375e3a8df00.log
-rw-r--r-- 1 root root 918 20. Feb 17:38 ClassifierDLApproach_6fdb8a569309.log
-rw-r--r-- 1 root root 446 20. Feb 15:55 ClassifierDLApproach_97ff5c76d735.log
-rw-r--r-- 1 root root 438 20. Feb 17:38 ClassifierMetrics_09bd6fa2ecf7.log
-rw-r--r-- 1 root root 317 10. Feb 16:54 ClassifierMetrics_17606bbb7d1f.log
-rw-r--r-- 1 root root 571 20. Feb 17:45 ClassifierMetrics_176ce729caa6.log
-rw-r--r-- 1 root root 313 10. Feb 16:54 ClassifierMetrics_1a6c515483ae.log
-rw-r--r-- 1 root root 441 20. Feb 17:38 ClassifierMetrics_1e0c8ea78e67.log
-rw-r--r-- 1 root root 323 10. Feb 16:54 ClassifierMetrics_2530315112a8.log
-rw-r--r-- 1 root root 566 20. Feb 17:45 ClassifierMetrics_26e8744dc78c.log
-rw-r--r-- 1 root root 565 20. Feb 17:45 ClassifierMetrics_284f041511fb.log
-rw-r--r-- 1 root root 445 20. Feb 17:38 ClassifierMetrics_2b7b458fc84d.log
-rw-r--r-- 1 root root 551 20. Feb 17:45 ClassifierMetrics_2fde2811a9
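Log file names end in a run-specific UID, so a name copied from one run will not exist after the next training run. The sketch below picks the most recently modified `MultiClassifierDLApproach_*.log` in `~/annotator_logs` (the directory listed above); `latest_log` is a hypothetical helper, and the file-name prefix is assumed from the naming pattern in this notebook.

```python
import glob
import os

def latest_log(log_dir, prefix="MultiClassifierDLApproach_"):
    """Return the most recently modified '<prefix>*.log' file in log_dir, or None."""
    pattern = os.path.join(os.path.expanduser(log_dir), prefix + "*.log")
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)

path = latest_log("~/annotator_logs")
if path is None:
    print("No MultiClassifierDLApproach log found")
else:
    with open(path) as f:
        print(f.read())
```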

In [None]:
# Log names end with a run-specific UID, so use a glob instead of a name copied from an earlier run
!cat ~/annotator_logs/MultiClassifierDLApproach_*.log


Let's save the trained multi-label classifier (the last stage of the fitted pipeline) so we can load it into our prediction pipeline:

In [None]:
pipelineModel.stages[-1].write().overwrite().save('tmp_multi_classifierDL_model')

## Load the saved model into a prediction pipeline

In [None]:
document = DocumentAssembler()\
    .setInputCol("ref")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained() \
  .setInputCols(["document"])\
  .setOutputCol("sentence_embeddings")

multiClassifier = MultiClassifierDLModel.load("tmp_multi_classifierDL_model") \
  .setInputCols(["sentence_embeddings"])\
  .setOutputCol("category")\
  .setThreshold(0.5)

pipeline = Pipeline(
    stages = [
        document,
        use,
        multiClassifier
    ])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]


Let's now use our test dataset to evaluate the model:

In [None]:
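# Let's see our labels. All stages are already trained, so fit() only assembles a PipelineModel;
# fit once and reuse the result instead of fitting twice
classes = pipeline.fit(testDataset).stages[2].getClasses()
print(classes)
print(len(classes))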

['name[Bibimbap House]', 'name[Wildwood]', 'name[Cotto]', 'name[Clowns]', 'near[Burger King]', 'name[The Dumpling Tree]', 'name[The Vaults]', 'near[Crowne Plaza Hotel]', 'name[The Golden Palace]', 'name[The Rice Boat]', 'customer rating[high]', 'near[Avalon]', 'name[Alimentum]', 'near[The Bakers]', 'name[The Waterman]', 'near[Ranch]', 'name[The Olive Grove]', 'name[The Eagle]', 'name[The Wrestlers]', 'eatType[restaurant]', 'near[All Bar One]', 'customer rating[low]', 'near[Café Sicilia]', 'near[Yippee Noodle Bar]', 'food[Indian]', 'eatType[pub]', 'name[Green Man]', 'name[Strada]', 'near[Café Adriatic]', 'eatType[coffee shop]', 'name[Loch Fyne]', 'customer rating[5 out of 5]', 'near[Express by Holiday Inn]', 'food[French]', 'name[The Mill]', 'food[Japanese]', 'name[Travellers Rest Beefeater]', 'name[The Plough]', 'name[Cocum]', 'near[The Six Bells]', 'name[The Phoenix]', 'priceRange[cheap]', 'name[Midsummer House]', 'near[Rainbow Vegetarian Café]', 'near[The Rice Boat]', 'customer ratin
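Since every class name starts with its attribute, counting classes per attribute gives a quick picture of what the model has to distinguish (many `name[...]` and `near[...]` values, only a few `eatType[...]` values). The sketch below runs on a handful of class names copied from the output above; applied to the full list returned by `getClasses()`, it gives the complete breakdown. `count_by_attribute` is a hypothetical helper.

```python
from collections import Counter

def count_by_attribute(classes):
    """Count how many learned labels belong to each attribute (the text before '[')."""
    return Counter(label.split("[", 1)[0] for label in classes)

# A few of the class names printed above
sample = ["name[Bibimbap House]", "name[Wildwood]", "near[Burger King]",
          "customer rating[high]", "eatType[restaurant]", "eatType[pub]", "food[Indian]"]
print(count_by_attribute(sample).most_common())
```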

In [None]:
preds = pipeline.fit(testDataset).transform(testDataset)


In [None]:
preds.select('labels', 'ref', 'category.result').show(2)

+--------------------+--------------------+--------------------+
|              labels|                 ref|              result|
+--------------------+--------------------+--------------------+
|[name[Alimentum],...|1 out of 5 stars ...|[name[Alimentum],...|
|[name[The Punter]...|1 star budget, fa...|[near[Café Sicili...|
+--------------------+--------------------+--------------------+
only showing top 2 rows
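The second row above shows a partial miss: the gold labels start with `name[The Punter]`, while the prediction starts with `near[Café Sicilia]`. A per-row comparison makes such errors explicit. The sketch below reports missed labels (false negatives), extra labels (false positives), exact match, and Jaccard overlap for one row; the example row is hypothetical, shaped like the `labels` and `result` columns, and `compare_labels` is a hypothetical helper.

```python
def compare_labels(true_labels, predicted_labels):
    """Compare one row's gold labels with its predicted labels."""
    gold, pred = set(true_labels), set(predicted_labels)
    union = gold | pred
    return {
        "missed": sorted(gold - pred),   # false negatives
        "extra": sorted(pred - gold),    # false positives
        "exact_match": gold == pred,
        "jaccard": len(gold & pred) / len(union) if union else 1.0,
    }

# Hypothetical row: two labels agree, one is swapped
row = compare_labels(
    ["name[The Punter]", "food[Indian]", "priceRange[cheap]"],
    ["near[Café Sicilia]", "food[Indian]", "priceRange[cheap]"],
)
print(row)
```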



In [None]:
preds_df = preds.select('labels', 'category.result').toPandas()

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score

# Fit the binarizer once on gold AND predicted labels so that
# y_true and y_pred share the same column (class) order
mlb = MultiLabelBinarizer()
mlb.fit(list(preds_df['labels']) + list(preds_df['result']))

y_true = mlb.transform(preds_df['labels'])
y_pred = mlb.transform(preds_df['result'])

print("Classification report: \n", classification_report(y_true, y_pred))
print("F1 micro averaging:", f1_score(y_true, y_pred, average='micro'))
# With hard 0/1 predictions instead of scores, this ROC AUC is only a rough summary
print("ROC AUC (micro):", roc_auc_score(y_true, y_pred, average="micro"))

Classification report: 
               precision    recall  f1-score   support

           0       0.88      0.84      0.86       790
           1       0.86      0.87      0.86      1774
           2       0.68      0.06      0.11       431
           3       0.70      0.12      0.20       422
           4       0.72      0.25      0.37       525
           5       0.78      0.37      0.50       592
           6       0.68      0.18      0.29       421
           7       0.72      0.21      0.32       512
           8       0.99      0.95      0.97      1043
           9       0.97      0.88      0.92       660
          10       0.84      0.52      0.64       306
          11       0.81      0.62      0.70       932
          12       0.83      0.80      0.81      1777
          13       0.95      0.90      0.92       292
          14       0.94      0.50      0.66       411
          15       0.93      0.81      0.86       599
          16       0.90      0.73      0.80       564
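Micro averaging pools true positives (TP), false positives (FP), and false negatives (FN) over every (row, class) cell before computing F1, so F1 = 2·TP / (2·TP + FP + FN), which equals 2PR / (P + R) with P = TP / (TP + FP) and R = TP / (TP + FN). The pure-Python sketch below shows that computation, plus why both label matrices must be binarized against one shared class order: a binarizer fitted separately on gold and predicted labels can produce mismatched columns. It is a minimal illustration on toy labels `a`, `b`, `c`, not scikit-learn's implementation; `binarize` and `micro_f1` are hypothetical helpers.

```python
def binarize(rows, classes):
    """One-hot encode each row's label list against one shared, fixed class order."""
    index = {label: i for i, label in enumerate(classes)}
    matrix = []
    for labels in rows:
        vec = [0] * len(classes)
        for label in labels:
            vec[index[label]] = 1  # KeyError if a label is missing from `classes`
        matrix.append(vec)
    return matrix

def micro_f1(y_true, y_pred):
    """Pool TP/FP/FN over every (row, class) cell, then F1 = 2TP / (2TP + FP + FN)."""
    tp = fp = fn = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            if t and p:
                tp += 1
            elif p:
                fp += 1
            elif t:
                fn += 1
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

gold = [["a", "b"], ["b"]]
pred = [["a"], ["b", "c"]]
# Build the class order from gold AND predicted labels so both matrices have the same columns
classes = sorted({label for row in gold + pred for label in row})
y_true = binarize(gold, classes)
y_pred = binarize(pred, classes)
print(classes, round(micro_f1(y_true, y_pred), 4))  # 2 TP, 1 FP, 1 FN -> 4/6
```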


In [None]:
preds.select("category.metadata").show(10)

+--------------------+
|            metadata|
+--------------------+
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
|[{name[Alimentum]...|
+--------------------+
only showing top 10 rows



In [None]:
preds.select("category.metadata").printSchema()

root
 |-- metadata: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)
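The schema shows that each annotation's `metadata` is a map from string to string, and the first key in every row is a class name such as `name[Alimentum]`, so the per-class scores arrive as strings. The sketch below converts one such map to floats and returns the top-k labels. It assumes the map may also hold non-score bookkeeping entries (the `sentence` key in `ignore_keys` is an assumption; check your own output), and the example map is hypothetical. In PySpark, `preds.select("category.metadata").first()["metadata"][0]` should give the first annotation's map as a Python dict to pass in.

```python
def top_scores(metadata, k=3, ignore_keys=frozenset({"sentence"})):
    """Return the k highest (label, score) pairs from one label -> score-string map.

    `ignore_keys` is an assumption: it drops bookkeeping entries that are not class scores.
    """
    scored = []
    for key, value in metadata.items():
        if key in ignore_keys:
            continue
        try:
            scored.append((key, float(value)))
        except (TypeError, ValueError):
            continue  # skip anything that is not numeric
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]

# Hypothetical metadata map for one annotation
meta = {"name[Alimentum]": "0.98", "eatType[pub]": "0.61", "food[French]": "0.12", "sentence": "0"}
print(top_scores(meta, k=2))
```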

