![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/Spark_NLP_Udemy_MOOC/Open_Source/20.03.MultiClassifierDL.ipynb)

# **MultiClassifierDL**

This notebook covers the different parameters and usages of `MultiClassifierDL`. This annotator classifies inputs that may belong to more than one class at the same time.


**📖 Learning Objectives:**

1. Understand how the `MultiClassifierDL` algorithm works.

2. Understand how `MultiClassifierDL` follows a supervised approach that builds upon sentence embeddings extracted from the text.

3. Become comfortable using the different parameters of the annotator.


**🔗 Helpful Links:**

- Documentation : [MultiClassifierDL](https://nlp.johnsnowlabs.com/docs/en/annotators#multiclassifierdl)

- Python Docs : [MultiClassifierDLApproach](https://nlp.johnsnowlabs.com/api/python/reference/autosummary/sparknlp/annotator/classifier_dl/multi_classifier_dl/index.html#sparknlp.annotator.classifier_dl.multi_classifier_dl.MultiClassifierDLApproach), [MultiClassifierDLModel](https://nlp.johnsnowlabs.com/api/python/reference/autosummary/sparknlp/annotator/classifier_dl/multi_classifier_dl/index.html#sparknlp.annotator.classifier_dl.multi_classifier_dl.MultiClassifierDLModel)

- Scala Docs : [MultiClassifierDLApproach](https://nlp.johnsnowlabs.com/api/com/johnsnowlabs/nlp/annotators/classifier/dl/MultiClassifierDLApproach.html), [MultiClassifierDLModel](https://nlp.johnsnowlabs.com/api/com/johnsnowlabs/nlp/annotators/classifier/dl/MultiClassifierDLModel)

- For extended examples of usage, see the [Spark NLP Workshop repository](https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/5.Text_Classification_with_ClassifierDL.ipynb).

## **📜 Background**

`Multi-label classification` is a variant of the text classification problem where multiple nonexclusive labels may be assigned to each instance. 

Classification is a predictive modeling problem in which a class label is predicted for a given input example. Typically, a classification task involves predicting a single label.

In some cases, class labels or class membership are not mutually exclusive. These tasks are referred to as multiple label classification, or multi-label classification for short.
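To make the distinction concrete, a multi-label target is often represented as a multi-hot vector: one 0/1 entry per class, with several entries allowed to be 1 at once. The sketch below is a generic illustration, not Spark NLP's internal encoding (which this notebook does not show). The class list is borrowed from the toxic-comment model used later, and `multi_hot` is a hypothetical helper.

```python
from typing import List, Sequence

# Class list taken from the toxic-comment model used later in this notebook
CLASSES = ["toxic", "severe_toxic", "identity_hate", "insult", "obscene", "threat"]


def multi_hot(labels: Sequence[str], classes: Sequence[str] = CLASSES) -> List[int]:
    """Encode a set of labels as a 0/1 vector with one slot per class (hypothetical helper)."""
    unknown = set(labels) - set(classes)
    if unknown:
        raise ValueError(f"Unknown labels: {sorted(unknown)}")
    return [1 if c in labels else 0 for c in classes]


# Single-label example: exactly one slot is set
print(multi_hot(["toxic"]))                        # [1, 0, 0, 0, 0, 0]
# Multi-label example: several slots can be set at the same time
print(multi_hot(["toxic", "obscene", "insult"]))   # [1, 0, 0, 1, 1, 0]
```

In a single-label (multi-class) problem exactly one entry would be 1; in the multi-label case any number of entries, including none, can be 1.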

<br/>

Various Sentence Embeddings can be used as an input, such as the state-of-the-art [UniversalSentenceEncoder](https://nlp.johnsnowlabs.com/docs/en/transformers#universalsentenceencoder), [BertSentenceEmbeddings](https://nlp.johnsnowlabs.com/docs/en/transformers#bertsentenceembeddings) or [SentenceEmbeddings](https://nlp.johnsnowlabs.com/docs/en/annotators#sentenceembeddings).

In this notebook, the aim is to perform multi-label classification of toxic comments in order to demonstrate the use and effectiveness of the deep-learning-based `MultiClassifierDL` annotator.



## **🎬 Colab Setup**

In [None]:
! pip install -q pyspark==3.3.0  spark-nlp==4.2.4

In [None]:
import sparknlp
import pandas as pd
import pyspark.sql.functions as F

from pyspark.ml import Pipeline, PipelineModel
from sparknlp.annotator import *
from sparknlp.base import DocumentAssembler, Finisher, LightPipeline

# Start a Spark session with Spark NLP loaded
spark = sparknlp.start()

print("Spark NLP version", sparknlp.version())
print("Apache Spark version:", spark.version)

spark

Spark NLP version 4.2.4
Apache Spark version: 3.3.0


## **🖨️ Input/Output Annotation Types**

- Input: `SENTENCE_EMBEDDINGS`

- Output: `CATEGORY`

## **🔎 Parameters**

### **`MultiClassifierDLApproach`**

- `setBatchSize`: (int) number of samples used in one iteration of training (Default: `64`).

- `setLr`: (float) learning rate; controls the step size taken by the optimizer when updating the model weights (Default: `5e-3`).

- `setThreshold`: (float) minimum score a label must reach to be included in the predictions (Default: `0.5`).

- `setMaxEpochs`: (int) number of times the model is trained on the entire dataset (Default: `10`).

- `setShufflePerEpoch`: (Boolean) whether to shuffle the training data on each epoch (Default: `False`).

- `setValidationSplit`: (float) proportion of the training dataset held out to validate the model after each epoch (Default: `0.0`).

- `setVerbose`: (int) level of detail printed during training (Default: `0`).

- `setEnableOutputLogs`: (Boolean) whether to write training logs to the annotator's log folder (Default: `False`).

- `setEvaluationLogExtended`: (Boolean) whether to output extended evaluation metrics during training.

- `setLabelColumn`: (str) name of the column that contains the labels for each document.

- `setOutputLogsPath`: (str) directory in which to save the training and evaluation logs.

- `setRandomSeed`: (int) random seed used during training.

- `setTestDataset`: (str) path to a Parquet file of a test dataset.

- `setConfigProtoBytes`: ConfigProto from TensorFlow, serialized into a byte array.

### **`MultiClassifierDLModel`**


- `setConfigProtoBytes`: ConfigProto from TensorFlow, serialized into a byte array.

- `setThreshold`: (float) minimum threshold for each label to be accepted (Default: `0.5`).

- `setDatasetParams`: sets the dataset parameters.

## **💻 MultiClassifierDLModel**

`MultiClassifierDLModel` is an annotator for multi-label text classification.

Multi-label classification is a supervised learning problem in which each input can be assigned any number of labels at once, rather than exactly one. The `MultiClassifierDLModel` annotator in Spark NLP is designed for this setting: each label receives its own score, and every label whose score passes the threshold is returned. For single-label (multi-class) problems, Spark NLP provides `ClassifierDL` instead.

## Using a 💎 Model From the John Snow Labs Models Hub

Instead of training a model from scratch, this example uses a pretrained model from the John Snow Labs Models Hub.

The model's name is [Toxic Comment Classification](https://nlp.johnsnowlabs.com/2021/01/21/multiclassifierdl_use_toxic_sm_en.html). This model automatically detects identity hate, insult, obscene, severe toxic, threat, or toxic content in social media (SM) comments using the out-of-the-box Spark NLP `MultiClassifierDL`.

[UniversalSentenceEncoder](https://nlp.johnsnowlabs.com/docs/en/transformers#universalsentenceencoder) was used to train this particular model, so the same embeddings must be used at inference time (not, for example, `BertSentenceEmbeddings`).

In [None]:
document = DocumentAssembler()\
  .setInputCol("text")\
  .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained() \
  .setInputCols(["document"])\
  .setOutputCol("use_embeddings")

docClassifier = MultiClassifierDLModel.pretrained("multiclassifierdl_use_toxic_sm") \
  .setInputCols(["use_embeddings"])\
  .setOutputCol("category")

nlpPipeline = Pipeline(stages=[document, use, docClassifier])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]
multiclassifierdl_use_toxic_sm download started this may take some time.
Approximate size to download 11.6 MB
[OK!]


In [None]:
data = spark.createDataFrame([["""F*ck You Shut the fuck up you whiney ass little bitch I'll do all the shit I want and you can go fuck yourself you jive ass turkey fucking cunt!"""]]).toDF("text")
result = nlpPipeline.fit(data).transform(data)

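The pipeline uses the pretrained model `multiclassifierdl_use_toxic_sm`. Because every stage is already trained, `fit()` does not train anything; it only wraps the stages into a `PipelineModel`. Let us look at the predictions for this text.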

In [None]:
result.select(F.explode(F.arrays_zip(result.document.result, result.category.result)).alias("cols")) \
.select(F.expr("cols['0']").alias("document"),
        F.expr("cols['1']").alias("category")).show(truncate = False)

+-----------------------------------------------------------------------------------------------------------------------------------------------+-------------+
|document                                                                                                                                       |category     |
+-----------------------------------------------------------------------------------------------------------------------------------------------+-------------+
|F*ck You Shut the fuck up you whiney ass little bitch I'll do all the shit I want and you can go fuck yourself you jive ass turkey fucking cunt|toxic        |
|null                                                                                                                                           |severe_toxic |
|null                                                                                                                                           |identity_hate|
|null                                   

### **Parameters**

**`.extractParamMap`** will help us see the parameters, their definitions, default and current values.

In [None]:
docClassifier.extractParamMap()

{Param(parent='MultiClassifierDLModel_519f532b6628', name='engine', doc='Deep Learning engine used for this model'): 'tensorflow',
 Param(parent='MultiClassifierDLModel_519f532b6628', name='lazyAnnotator', doc='Whether this AnnotatorModel acts as lazy in RecursivePipelines'): False,
 Param(parent='MultiClassifierDLModel_519f532b6628', name='threshold', doc='The minimum threshold for each label to be accepted. Default is 0.5'): 0.5,
 Param(parent='MultiClassifierDLModel_519f532b6628', name='classes', doc='get the tags used to trained this MultiClassifierDLModel'): ['toxic',
  'severe_toxic',
  'identity_hate',
  'insult',
  'obscene',
  'threat'],
 Param(parent='MultiClassifierDLModel_519f532b6628', name='inputCols', doc='previous annotations columns, if renamed'): ['use_embeddings'],
 Param(parent='MultiClassifierDLModel_519f532b6628', name='outputCol', doc='output annotation column. can be left default.'): 'category',
 Param(parent='MultiClassifierDLModel_519f532b6628', name='storageR

In [None]:
docClassifier.getClasses()

['toxic', 'severe_toxic', 'identity_hate', 'insult', 'obscene', 'threat']

In [None]:
docClassifier.getEngine()

'tensorflow'

In [None]:
docClassifier.getThreshold()

0.5

In [None]:
docClassifier.getStorageRef()

'tfhub_use'

### **`setThreshold`** 

This parameter sets the minimum score a label needs in order to be returned. Here the threshold is raised from the default value of `0.5` to `0.7`.

This way, labels with lower confidence are excluded from the predictions.


In [None]:
document = DocumentAssembler()\
  .setInputCol("text")\
  .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained() \
  .setInputCols(["document"])\
  .setOutputCol("use_embeddings")

docClassifier = MultiClassifierDLModel.pretrained("multiclassifierdl_use_toxic_sm") \
  .setInputCols(["use_embeddings"])\
  .setOutputCol("category")\
  .setThreshold(0.7)

nlpPipeline = Pipeline(stages=[document, use, docClassifier])

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]
multiclassifierdl_use_toxic_sm download started this may take some time.
Approximate size to download 11.6 MB
[OK!]


In [None]:
data = spark.createDataFrame([["""F*ck You Shut the fuck up you whiney ass little bitch I'll do all the shit I want and you can go fuck yourself you jive ass turkey fucking cunt!"""]]).toDF("text")
result = nlpPipeline.fit(data).transform(data)

In [None]:
result.select(F.explode(F.arrays_zip(result.document.result, result.category.result)).alias("cols")) \
.select(F.expr("cols['0']").alias("document"),
        F.expr("cols['1']").alias("category")).show(truncate = False)

+------------------------------------------------------------------------------------------------------------------------------------------------+------------+
|document                                                                                                                                        |category    |
+------------------------------------------------------------------------------------------------------------------------------------------------+------------+
|F*ck You Shut the fuck up you whiney ass little bitch I'll do all the shit I want and you can go fuck yourself you jive ass turkey fucking cunt!|toxic       |
|null                                                                                                                                            |severe_toxic|
|null                                                                                                                                            |insult      |
|null                                   

Increasing the threshold removed **identity_hate** from the previous predictions list.

## **💻 MultiClassifierDLApproach**

`MultiClassifierDLApproach` is used for model training.

### Load the **Training** and **Testing Datasets**


In [None]:
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_train.snappy.parquet'
!curl -O 'https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/toxic_comments/toxic_test.snappy.parquet'

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2702k  100 2702k    0     0  1266k      0  0:00:02  0:00:02 --:--:-- 1266k
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  289k  100  289k    0     0   186k      0  0:00:01  0:00:01 --:--:--  186k


In [None]:
trainDataset = spark.read.parquet("./toxic_train.snappy.parquet").repartition(120)
testDataset = spark.read.parquet("./toxic_test.snappy.parquet").repartition(10)

In [None]:
trainDataset.show(10, truncate = 100)

+----------------+----------------------------------------------------------------------------------------------------+-------------------------------------------------------------+
|              id|                                                                                                text|                                                       labels|
+----------------+----------------------------------------------------------------------------------------------------+-------------------------------------------------------------+
|e63f1cc4b0b9959f|                                                                              EAT SHIT HORSE FACE!!!|                                                      [toxic]|
|ed58abb40640f983|PN News\nYou mean the people here actually care about the talk page of a low importance-rated art...|                                                      [toxic]|
|a1237f726b5f5d89|Dude.\n\nPlace the following in the large text box:\n\nPlease stop rapin

In [None]:
trainDataset.count()

14620

In [None]:
testDataset.show(10, truncate = 100)

+----------------+----------------------------------------------------------------------------------------------------+-----------------------------------------------------+
|              id|                                                                                                text|                                               labels|
+----------------+----------------------------------------------------------------------------------------------------+-----------------------------------------------------+
|47d256dea1223d39|                 Vegan \n\nWhat in the hell does all that junk have to do with photos? 68.54.163.153|                                              [toxic]|
|5e0dea75de819976|                                                                      Fight Club! F**k Yeeaaaaahh!!!|                                     [toxic, obscene]|
|2f84caf5fd45353c|"\n\n Little quick on the trigger, ain't'cha bud? \n\nYou know, if you're not even going to give ...|           

In [None]:
testDataset.count()

1605

## **💻 Train a Model with `BertSentenceEmbeddings`**


A model is trained with `MultiClassifierDLApproach` using the following parameters:

*   `setMaxEpochs`: 5
*   `setLr`: 1e-3
*   `setThreshold`: 0.5
*   `setValidationSplit`: 0.1
*   `setBatchSize`: 128
*   `setShufflePerEpoch`: False
*   `setEnableOutputLogs`: True



In [None]:
document = DocumentAssembler()\
              .setInputCol("text")\
              .setOutputCol("document")\
              .setCleanupMode("shrink")

bert_sent = BertSentenceEmbeddings.pretrained('sent_small_bert_L8_512')\
              .setInputCols(["document"])\
              .setOutputCol("sentence_embeddings")

multiClassifier = MultiClassifierDLApproach()\
              .setInputCols("sentence_embeddings")\
              .setOutputCol("category")\
              .setLabelColumn("labels")\
              .setBatchSize(128)\
              .setMaxEpochs(5)\
              .setLr(1e-3)\
              .setThreshold(0.5)\
              .setShufflePerEpoch(False)\
              .setEnableOutputLogs(True)\
              .setValidationSplit(0.1)

pipeline = Pipeline(stages = [document,
                              bert_sent,
                              multiClassifier])

sent_small_bert_L8_512 download started this may take some time.
Approximate size to download 149.1 MB
[OK!]


In [None]:
%%time
pipelineModel = pipeline.fit(trainDataset)

CPU times: user 6.45 s, sys: 706 ms, total: 7.16 s
Wall time: 18min 46s


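Although only **5** epochs were run with a fairly large batch size, fitting the pipeline still took over **18** minutes. The training log below shows that each epoch took only about 2 seconds, so most of the wall time is presumably spent computing the BERT sentence embeddings rather than training the classifier. The final training accuracy in the log is **88.5 %**, which leaves room for improvement.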



In [None]:
!cat ~/annotator_logs/{multiClassifier.uid}.log

Training started - epochs: 5 - learning_rate: 0.001 - batch_size: 128 - training_examples: 13158 - classes: 6
Epoch 0/5 - 3.78s - loss: 0.33723184 - acc: 0.8611509 - batches: 103
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 1/5 - 2.08s - loss: 0.29930413 - acc: 0.8785859 - batches: 103
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 2/5 - 1.93s - loss: 0.2917258 - acc: 0.8815477 - batches: 103
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 3/5 - 1.92s - loss: 0.2864562 - acc: 0.8833028 - batches: 103
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 4/5 - 1.89s - loss: 0.28217906 - acc: 0.88537127 - batches: 103
Quality on validation dataset (10.0%), validation examples = 1462 


In [None]:
preds = pipelineModel.transform(testDataset)

In [None]:
preds_df = preds.select('text','labels',"category.result").toPandas()

In [None]:
preds_df.head(10)

Unnamed: 0,text,labels,result
0,Vegan \n\nWhat in the hell does all that junk ...,[toxic],"[toxic, obscene]"
1,Fight Club! F**k Yeeaaaaahh!!!,"[toxic, obscene]","[toxic, obscene]"
2,"""\n\n Little quick on the trigger, ain't'cha b...",[toxic],[toxic]
3,Your user page indicates you're a left-wing li...,"[toxic, obscene, insult]",[toxic]
4,""" See all the many Google links, titled"""" Wik...",[toxic],[toxic]
5,"""\n\n LOL \n\nLOL. Seriously, """"BryanFromPalat...",[toxic],"[toxic, obscene]"
6,is it because it is of my naked mum having sex...,"[toxic, severe_toxic, obscene, insult, identit...","[toxic, severe_toxic, insult, obscene]"
7,")\na cowards site, that must stop changing thi...","[toxic, obscene]",[toxic]
8,"blow me, criticism IS constructive. \n\nBlow m...","[toxic, obscene, insult]","[toxic, insult, obscene]"
9,On account of the project deciding to ignore h...,"[toxic, obscene, insult]",[toxic]
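The training log reports a single accuracy figure computed by Spark NLP. To compare predictions with the gold labels directly, two common multi-label measures are subset accuracy (the predicted label set must match the gold set exactly) and the Jaccard score (overlap divided by union, averaged over rows). The sketch below computes both in plain Python on the rows of the table above whose label lists are not truncated (row 6 is skipped). It is not how Spark NLP computes its metric, and the helper names are hypothetical; the same functions could be applied to `set(row.labels)` and `set(row.result)` for every row of `preds_df`.

```python
from typing import List, Sequence, Set, Tuple

# (gold labels, predicted labels) copied from the table above, row 6 omitted
ROWS: List[Tuple[Set[str], Set[str]]] = [
    ({"toxic"}, {"toxic", "obscene"}),
    ({"toxic", "obscene"}, {"toxic", "obscene"}),
    ({"toxic"}, {"toxic"}),
    ({"toxic", "obscene", "insult"}, {"toxic"}),
    ({"toxic"}, {"toxic"}),
    ({"toxic"}, {"toxic", "obscene"}),
    ({"toxic", "obscene"}, {"toxic"}),
    ({"toxic", "obscene", "insult"}, {"toxic", "insult", "obscene"}),
    ({"toxic", "obscene", "insult"}, {"toxic"}),
]


def subset_accuracy(rows: Sequence[Tuple[Set[str], Set[str]]]) -> float:
    """Fraction of rows whose predicted label set equals the gold label set."""
    return sum(gold == pred for gold, pred in rows) / len(rows)


def mean_jaccard(rows: Sequence[Tuple[Set[str], Set[str]]]) -> float:
    """Average of |gold & pred| / |gold | pred| over all rows."""
    scores = []
    for gold, pred in rows:
        union = gold | pred
        # Two empty sets are treated as a perfect match
        scores.append(len(gold & pred) / len(union) if union else 1.0)
    return sum(scores) / len(scores)


print(f"subset accuracy: {subset_accuracy(ROWS):.3f}")
print(f"mean Jaccard:    {mean_jaccard(ROWS):.3f}")
```

Rows 0 and 5 count as errors under subset accuracy even though the prediction contains the gold label; the Jaccard score gives them partial credit.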


## **🔎 Parameters in Detail**

### **`setBatchSize`**

In deep learning, batch size refers to the number of training examples used in one iteration of gradient descent. During training, the training data is divided into batches, and each batch is fed through the neural network to compute the loss and update the weights. 

`setBatchSize` is a hyperparameter that can be tuned to achieve the best performance. Increasing the batch size can reduce the overall training time, but it may also require more memory and computational resources. Conversely, reducing the batch size may increase the training time but it can reduce the memory usage and allow the model to fit into memory.

By default, `setBatchSize` is set to `64` in `MultiClassifierDLApproach`. This is a reasonable starting point, but the best value depends on the dataset size and the available memory, so it is worth tuning.
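The number of batches per epoch follows directly from the batch size. Assuming the last, partial batch is kept (an assumption, though it is consistent with the training logs in this notebook: 13,158 training examples give 103 batches at size 128 and 412 batches at size 32), the count is a ceiling division. `batches_per_epoch` is a hypothetical helper:

```python
def batches_per_epoch(n_examples: int, batch_size: int) -> int:
    """Number of batches needed to cover every example once, keeping a final partial batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Integer ceiling division: avoids floating-point rounding
    return -(-n_examples // batch_size)


# Training examples reported in the logs of this notebook
n_train = 13158
for size in (128, 64, 32):
    print(size, batches_per_epoch(n_train, size))
```

Smaller batches mean more weight updates per epoch (412 instead of 103 here), at the cost of more steps and usually longer training.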



### **`setLr`**

Lr stands for **Learning Rate**. It is a hyperparameter that determines the step size at which the model weights are updated during the optimization process in model training.

During training, the model adjusts its weights to minimize the loss function, which measures the difference between the predicted and actual labels for each input sentence. The learning rate determines the size of the weight update at each iteration. A larger learning rate results in larger weight updates, while a smaller learning rate results in smaller weight updates.

`setLr` parameter can significantly affect the training performance and the convergence speed of the model. A learning rate that is too high can cause the model to overshoot the optimal weights and fail to converge, while a learning rate that is too low can result in slow convergence or getting stuck in a local minimum.

The optimal value for the `setLr` parameter depends on the specific task and the characteristics of the dataset. It is often necessary to experiment with different learning rates to find the optimal value.
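The effect described above can be seen on a toy problem. The sketch below runs plain gradient descent on the one-dimensional loss f(w) = w², whose gradient is 2w and whose minimum is at w = 0. It is a generic illustration, not Spark NLP's optimizer, and `descend` is a hypothetical helper.

```python
def descend(lr: float, w0: float = 1.0, steps: int = 20) -> float:
    """Run gradient descent on f(w) = w**2 and return the final weight."""
    w = w0
    for _ in range(steps):
        grad = 2 * w        # derivative of w**2
        w = w - lr * grad   # each update multiplies w by (1 - 2 * lr)
    return w


for lr in (0.01, 0.4, 1.1):
    print(f"lr={lr}: |w| after 20 steps = {abs(descend(lr)):.4f}")
```

With `lr=0.01` the weight has barely moved toward 0 after 20 steps (slow convergence), `lr=0.4` converges almost immediately, and `lr=1.1` overshoots on every step so |w| grows instead of shrinking.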

### **`setThreshold`**

`setThreshold` is used to set the classification threshold for the predicted probabilities of each class.

In multi-label classification, an input can belong to several classes at once. After training, the model predicts a separate probability for each label, representing its confidence that the input belongs to that label; these probabilities are not required to sum to 1.

`setThreshold` sets the cutoff applied to these probabilities: every label whose probability is above the threshold is included in the predictions, and the remaining labels are left out. This converts the per-label probabilities into the final set of predicted labels.
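A minimal sketch of this thresholding step, assuming each label gets an independent score. The scores below are made up for illustration (they are not real model outputs), `select_labels` is a hypothetical helper, and treating a score exactly equal to the threshold as accepted is also an assumption of this sketch.

```python
from typing import Dict, List


def select_labels(scores: Dict[str, float], threshold: float = 0.5) -> List[str]:
    """Return every label whose score is at or above the threshold, in input order."""
    return [label for label, score in scores.items() if score >= threshold]


# Illustrative, made-up per-label scores for a single comment (they need not sum to 1)
scores = {
    "toxic": 0.97,
    "severe_toxic": 0.81,
    "identity_hate": 0.55,
    "insult": 0.74,
    "obscene": 0.42,
    "threat": 0.03,
}

print(select_labels(scores, 0.5))  # ['toxic', 'severe_toxic', 'identity_hate', 'insult']
print(select_labels(scores, 0.7))  # ['toxic', 'severe_toxic', 'insult']
```

Raising the threshold from 0.5 to 0.7 drops the lower-scoring `identity_hate` label, which mirrors the behavior seen earlier with the pretrained model; lowering it would admit more labels.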


### **`setMaxEpochs`**

`setMaxEpochs` specifies the maximum number of training epochs to run during the training process, after which the training will stop regardless of whether the model has converged or not.

An epoch is a complete iteration over the training data, which means that the model sees every example in the training data exactly once during that epoch. 

During training, the model parameters are updated based on the error between the predicted output and the true output for each example in the training data. As the number of epochs increases, the model has more opportunities to adjust its parameters and reduce its error on the training data.

Setting a larger number of epochs may improve the performance of the model, but it can also increase the training time and the risk of overfitting the model to the training data. Conversely, setting a smaller number of epochs may result in a faster training time but a suboptimal model performance.


### **`setShufflePerEpoch`**

`setShufflePerEpoch` is a parameter that specifies whether to shuffle the data at the beginning of each epoch during training of a machine learning model. 

When `setShufflePerEpoch` is set to true, the data is shuffled at the beginning of each epoch to randomize the order in which the data is presented to the machine learning algorithm. This can help prevent overfitting and improve the generalization of the model.
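The sketch below illustrates what shuffling per epoch means for the order of examples: without shuffling, every epoch sees exactly the same batches; with shuffling, the batch composition changes from epoch to epoch. It is a generic illustration with a hypothetical `epoch_batches` helper, not Spark NLP's internal data loader.

```python
import random
from typing import List, Sequence


def epoch_batches(data: Sequence[int], batch_size: int, epochs: int,
                  shuffle: bool, seed: int = 0) -> List[List[List[int]]]:
    """Return the batches seen in each epoch, optionally reshuffling before every epoch."""
    rng = random.Random(seed)
    order = list(data)
    all_epochs = []
    for _ in range(epochs):
        if shuffle:
            rng.shuffle(order)  # new random order at the start of each epoch
        # Slicing copies the items, so later shuffles do not alter earlier batches
        batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
        all_epochs.append(batches)
    return all_epochs


data = list(range(8))
fixed = epoch_batches(data, batch_size=4, epochs=3, shuffle=False)
shuffled = epoch_batches(data, batch_size=4, epochs=3, shuffle=True, seed=42)
print(fixed)
print(shuffled)
```

Every epoch still contains each example exactly once; only the grouping into batches and their order change.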


### **`setValidationSplit`**

`setValidationSplit` specifies the fraction of the training dataset to hold out for validation during model training.

The model is trained on the remaining data and evaluated on the held-out portion after each epoch. With the default value of `0.0`, no validation set is used.
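For the runs in this notebook, the numbers in the training logs line up with a simple split of the 14,620 training rows: 10% (1,462) held out for validation and the remaining 13,158 used for training. The sketch below reproduces that arithmetic; rounding the validation count down is an assumption about how the split is computed, and `split_counts` is a hypothetical helper.

```python
from typing import Tuple


def split_counts(n_rows: int, validation_split: float) -> Tuple[int, int]:
    """Return (training_examples, validation_examples) for a given validation fraction."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be in [0.0, 1.0)")
    n_val = int(n_rows * validation_split)  # assumption: the validation count is rounded down
    return n_rows - n_val, n_val


print(split_counts(14620, 0.1))  # (13158, 1462)
```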

### **`setVerbose`**

`setVerbose` controls the amount of information displayed during the training process. This information can include metrics such as loss, accuracy, and other performance indicators calculated on the training data during each epoch.

The parameter takes an integer verbosity level (Default: `0`) rather than a Boolean; different levels change how much of this information is printed.

### **`setEnableOutputLogs`**

`setEnableOutputLogs` enables or disables writing training information to a log file.

When `setEnableOutputLogs` is `True`, the annotator logs the training progress, such as the loss, accuracy, and number of batches for each epoch. If no output path is set, the log is written to `~/annotator_logs/`, which is where the `!cat` commands in this notebook read it from.

This information can be useful for monitoring the training process and evaluating the performance of the model.
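Because the log is plain text, the per-epoch metrics can be extracted for plotting or comparison. The sketch below assumes the line format printed in this notebook (`Epoch 0/5 - 3.78s - loss: ... - acc: ... - batches: ...`), which may change between Spark NLP versions; `parse_epoch_lines` is a hypothetical helper, and the sample text is copied from the first training log in this notebook.

```python
import re
from typing import List

# Matches lines such as: "Epoch 0/5 - 3.78s - loss: 0.33723184 - acc: 0.8611509 - batches: 103"
EPOCH_RE = re.compile(
    r"Epoch (?P<epoch>\d+)/(?P<total>\d+) - (?P<secs>[\d.]+)s - "
    r"loss: (?P<loss>[\d.]+) - acc: (?P<acc>[\d.]+) - batches: (?P<batches>\d+)"
)


def parse_epoch_lines(log_text: str) -> List[dict]:
    """Extract epoch number, loss and accuracy from every matching log line."""
    rows = []
    for match in EPOCH_RE.finditer(log_text):
        rows.append({
            "epoch": int(match["epoch"]),
            "loss": float(match["loss"]),
            "acc": float(match["acc"]),
        })
    return rows


sample_log = """Training started - epochs: 5 - learning_rate: 0.001 - batch_size: 128 - training_examples: 13158 - classes: 6
Epoch 0/5 - 3.78s - loss: 0.33723184 - acc: 0.8611509 - batches: 103
Quality on validation dataset (10.0%), validation examples = 1462
Epoch 4/5 - 1.89s - loss: 0.28217906 - acc: 0.88537127 - batches: 103
"""

for row in parse_epoch_lines(sample_log):
    print(row)
```

In the notebook itself, the text could instead be read from the log file, e.g. `open(os.path.expanduser(f"~/annotator_logs/{multiClassifier.uid}.log")).read()`.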


### **`setEvaluationLogExtended`**

`setEvaluationLogExtended` enables or disables the extended evaluation log during training.

During training, the model is evaluated on the validation set after each epoch.

If `setEvaluationLogExtended` is set to `True`, the evaluation log also contains more detailed information, such as per-label precision, recall, and F1-score.

### **`setLabelColumn`**

`setLabelColumn` specifies the name of the column in the input DataFrame that contains the labels.

When `setLabelColumn` is given the name of a valid column, `MultiClassifierDLApproach` uses the values in that column as the labels for the training data. For multi-label training, each row holds an array of label strings, as in the `labels` column of the toxic comments dataset (for example, `[toxic, obscene]`).

### **`setOutputLogsPath`**

`setOutputLogsPath` specifies a directory where the model writes its logs during the training process.

When `setEnableOutputLogs(True)` is combined with `setOutputLogsPath("path/to/logs")`, the training progress is written to log files in the specified directory.

This can be useful for keeping track of the training progress of multiple models, comparing the performance of different models, and identifying potential issues or errors during the training process.

### **`setRandomSeed`**

`setRandomSeed` specifies the random seed used by the model during training.

When you train a model, the initialization of the model's weights and biases can have a significant impact on the model's performance. To ensure that your model's initialization is reproducible, you can set a random seed that is used by the model during training. 

By setting a random seed, you can ensure that your model is initialized in the same way each time you train it, which can help you to reproduce your results and debug any issues that you may encounter.
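The idea can be illustrated with Python's own random number generator: two generators created with the same seed produce identical "initial weights", while a different seed gives different values. This is a generic illustration only; how Spark NLP applies the seed internally is not shown in this notebook, and `init_weights` is a hypothetical helper.

```python
import random
from typing import List


def init_weights(n: int, seed: int) -> List[float]:
    """Draw n small random weights from a seeded generator (illustrative only)."""
    rng = random.Random(seed)
    return [rng.uniform(-0.1, 0.1) for _ in range(n)]


run_a = init_weights(5, seed=42)
run_b = init_weights(5, seed=42)
run_c = init_weights(5, seed=7)

print(run_a == run_b)  # True: same seed, same initialization
print(run_a == run_c)  # a different seed gives different values
```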

### **`setTestDataset`**

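`setTestDataset` sets the path to a test dataset in Parquet format. The test data is not used to update the model's weights; if it is set, it is used to compute evaluation statistics during training. Because the annotator consumes sentence embeddings, the Parquet file should already contain the embeddings column, i.e., it should be produced by running the same preprocessing stages (document assembler and embeddings) over the raw test data.

It is important to evaluate the model on a separate test dataset to ensure that it is able to generalize well to new, unseen data.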

### **`setConfigProtoBytes`**

`setConfigProtoBytes` allows users to set the TensorFlow configuration proto bytes for the underlying neural network used in the model.

The TensorFlow configuration proto bytes specify the configuration settings for the TensorFlow runtime, which can include options such as memory allocation, CPU/GPU usage, and other performance-related settings. 

### 💾 **Saving & Loading Back the Trained Model**

In [None]:
pipelineModel.stages

[DocumentAssembler_e79ff164cdc8,
 BERT_SENTENCE_EMBEDDINGS_3608c0d843af,
 MultiClassifierDLModel_fe82c46e669b]

In [None]:
# Save the Multilabel Classifier Model
pipelineModel.stages[2].write().overwrite().save('MultilabelClfBert')

In [None]:
# Load back the saved Multilabel Classifier Model
MultilabelClfModel = MultiClassifierDLModel.load('MultilabelClfBert')

In [None]:
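# Generate a prediction pipeline with the loaded model.
# All stages are already trained, so fit() on a dummy DataFrame only assembles them into a PipelineModel.
ld_pipeline = Pipeline(stages=[document, bert_sent, MultilabelClfModel])
ld_pipeline_model = ld_pipeline.fit(spark.createDataFrame([['']]).toDF("text"))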

In [None]:
# Apply Model Transform to testData
ld_preds = ld_pipeline_model.transform(testDataset)

In [None]:
ld_preds_df = ld_preds.select('text','labels',"category.result").toPandas()

In [None]:
ld_preds_df.head(10)

Unnamed: 0,text,labels,result
0,Vegan \n\nWhat in the hell does all that junk ...,[toxic],"[toxic, obscene]"
1,Fight Club! F**k Yeeaaaaahh!!!,"[toxic, obscene]","[toxic, obscene]"
2,"""\n\n Little quick on the trigger, ain't'cha b...",[toxic],[toxic]
3,Your user page indicates you're a left-wing li...,"[toxic, obscene, insult]",[toxic]
4,""" See all the many Google links, titled"""" Wik...",[toxic],[toxic]
5,"""\n\n LOL \n\nLOL. Seriously, """"BryanFromPalat...",[toxic],"[toxic, obscene]"
6,is it because it is of my naked mum having sex...,"[toxic, severe_toxic, obscene, insult, identit...","[toxic, severe_toxic, insult, obscene]"
7,")\na cowards site, that must stop changing thi...","[toxic, obscene]",[toxic]
8,"blow me, criticism IS constructive. \n\nBlow m...","[toxic, obscene, insult]","[toxic, insult, obscene]"
9,On account of the project deciding to ignore h...,"[toxic, obscene, insult]",[toxic]


## **💻 Retrain the Model with Different Parameters - 1**

Retraining a deep learning model means searching again for the set of weights and biases that minimizes the loss function, which measures how well the model predicts the correct output for the input data. In practice, this search is computationally intensive and requires many iterations.

When a model is retrained with new hyperparameters, training restarts from scratch with freshly initialized weights. Adjusting hyperparameters such as the learning rate, the batch size, or the number of epochs can improve performance on a specific task or dataset, although an improvement is not guaranteed.


**`setMaxEpochs`** was increased from **5** to **12**, and **`setBatchSize`** was decreased from **128** to **32**. 

The training accuracy reported in the log rose from **88.5 %** to **91.6 %**.

In [None]:
multiClassifier = MultiClassifierDLApproach()\
              .setInputCols("sentence_embeddings")\
              .setOutputCol("category")\
              .setLabelColumn("labels")\
              .setBatchSize(32)\
              .setMaxEpochs(12)\
              .setLr(1e-3)\
              .setThreshold(0.5)\
              .setShufflePerEpoch(False)\
              .setEnableOutputLogs(True)\
              .setValidationSplit(0.1)

pipeline = Pipeline(stages = [document,
                              bert_sent,
                              multiClassifier])

In [None]:
%%time
pipelineModel = pipeline.fit(trainDataset)

CPU times: user 7.93 s, sys: 701 ms, total: 8.63 s
Wall time: 25min 13s


The training log shows how the model's metrics improve after every epoch.

In [None]:
!cat ~/annotator_logs/{multiClassifier.uid}.log

Training started - epochs: 12 - learning_rate: 0.001 - batch_size: 32 - training_examples: 13158 - classes: 6
Epoch 0/12 - 6.78s - loss: 0.3167296 - acc: 0.86285996 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 1/12 - 4.68s - loss: 0.291983 - acc: 0.8727441 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 2/12 - 4.58s - loss: 0.28258383 - acc: 0.87811726 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 3/12 - 4.73s - loss: 0.2742937 - acc: 0.8813112 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 4/12 - 4.55s - loss: 0.26593325 - acc: 0.88559043 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 5/12 - 4.55s - loss: 0.2573048 - acc: 0.8897209 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 6/12 - 4.57s - loss: 0.24843219 - acc: 0.89404637 - batches: 412
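Rather than reading the log by eye, the per-epoch values can be extracted with a regular expression. This is a small sketch that assumes the line format shown above stays stable; `parse_epochs` is a hypothetical helper, not part of Spark NLP, and the embedded sample is copied from the log above.

```python
import re

# Matches per-epoch summary lines such as
# "Epoch 0/12 - 6.78s - loss: 0.3167296 - acc: 0.86285996 - batches: 412"
EPOCH_LINE = re.compile(
    r"Epoch (\d+)/(\d+) - ([\d.]+)s - loss: ([\d.]+) - acc: ([\d.]+) - batches: (\d+)"
)


def parse_epochs(log_text: str) -> list[dict]:
    """Return epoch index, loss and accuracy for every epoch summary line in the log."""
    return [
        {"epoch": int(m.group(1)), "loss": float(m.group(4)), "acc": float(m.group(5))}
        for m in EPOCH_LINE.finditer(log_text)
    ]


SAMPLE_LOG = """\
Training started - epochs: 12 - learning_rate: 0.001 - batch_size: 32 - training_examples: 13158 - classes: 6
Epoch 0/12 - 6.78s - loss: 0.3167296 - acc: 0.86285996 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462
Epoch 1/12 - 4.68s - loss: 0.291983 - acc: 0.8727441 - batches: 412
Epoch 2/12 - 4.58s - loss: 0.28258383 - acc: 0.87811726 - batches: 412
Epoch 3/12 - 4.73s - loss: 0.2742937 - acc: 0.8813112 - batches: 412
Epoch 4/12 - 4.55s - loss: 0.26593325 - acc: 0.88559043 - batches: 412
Epoch 5/12 - 4.55s - loss: 0.2573048 - acc: 0.8897209 - batches: 412
Epoch 6/12 - 4.57s - loss: 0.24843219 - acc: 0.89404637 - batches: 412
"""

epochs = parse_epochs(SAMPLE_LOG)
for e in epochs:
    print(f"epoch {e['epoch']:>2}  loss {e['loss']:.4f}  acc {e['acc']:.4f}")
```

In the notebook itself, the text could come from `open(os.path.expanduser(f"~/annotator_logs/{multiClassifier.uid}.log")).read()`. These are the per-epoch `loss` and `acc` values, not the validation-set metrics.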

### 💾 **Saving the Trained Model**

The trained model can be saved locally and loaded again later to generate predictions.

In [None]:
# Save the Multilabel Classifier Model
pipelineModel.stages[2].write().overwrite().save('MultilabelClfBert_2')

In [None]:
# Load back the saved Multilabel Classifier Model
MultilabelClfModel = MultiClassifierDLModel.load('MultilabelClfBert_2')

In [None]:
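# Generate prediction Pipeline with loaded Model
# Every stage is already trained, so fitting on a dummy one-row DataFrame only assembles them into a PipelineModel
ld_pipeline = Pipeline(stages=[document, bert_sent, MultilabelClfModel])
ld_pipeline_model = ld_pipeline.fit(spark.createDataFrame([['']]).toDF("text"))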

In [None]:
# Apply Model Transform to testData
ld_preds = ld_pipeline_model.transform(testDataset)

In [None]:
ld_preds_df = ld_preds.select('text','labels',"category.result").toPandas()

In [None]:
ld_preds_df.head(10)

| | text | labels | result |
|---|---|---|---|
| 0 | Vegan \n\nWhat in the hell does all that junk ... | [toxic] | [toxic] |
| 1 | Fight Club! F**k Yeeaaaaahh!!! | [toxic, obscene] | [toxic, obscene] |
| 2 | "\n\n Little quick on the trigger, ain't'cha b... | [toxic] | [obscene] |
| 3 | Your user page indicates you're a left-wing li... | [toxic, obscene, insult] | [toxic] |
| 4 | " See all the many Google links, titled"" Wik... | [toxic] | [toxic] |
| 5 | "\n\n LOL \n\nLOL. Seriously, ""BryanFromPalat... | [toxic] | [toxic, obscene] |
| 6 | is it because it is of my naked mum having sex... | [toxic, severe_toxic, obscene, insult, identit... | [toxic, severe_toxic, insult, obscene] |
| 7 | )\na cowards site, that must stop changing thi... | [toxic, obscene] | [toxic] |
| 8 | blow me, criticism IS constructive. \n\nBlow m... | [toxic, obscene, insult] | [toxic, obscene] |
| 9 | On account of the project deciding to ignore h... | [toxic, obscene, insult] | [toxic] |
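Many predictions above are partially right, which a single accuracy number hides. A minimal sketch of two common multi-label measures, the exact-match ratio and micro-averaged precision/recall/F1, computed on the rows shown; `multilabel_scores` is a hypothetical helper, and row 6 is skipped because its gold label list is truncated in the display.

```python
def multilabel_scores(gold_rows, pred_rows):
    """Exact-match ratio plus micro-averaged precision, recall and F1 over label sets."""
    tp = fp = fn = exact = 0
    for gold, pred in zip(gold_rows, pred_rows):
        gold, pred = set(gold), set(pred)
        tp += len(gold & pred)       # labels predicted and correct
        fp += len(pred - gold)       # labels predicted but not in the gold set
        fn += len(gold - pred)       # gold labels the model missed
        exact += int(gold == pred)   # the whole label set matches
    n = len(gold_rows)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"exact_match": exact / n if n else 0.0, "precision": precision,
            "recall": recall, "f1": f1}


# Rows 0-5 and 7-9 of the table above (row 6 skipped: its gold labels are truncated)
gold = [["toxic"], ["toxic", "obscene"], ["toxic"], ["toxic", "obscene", "insult"],
        ["toxic"], ["toxic"], ["toxic", "obscene"], ["toxic", "obscene", "insult"],
        ["toxic", "obscene", "insult"]]
pred = [["toxic"], ["toxic", "obscene"], ["obscene"], ["toxic"],
        ["toxic"], ["toxic", "obscene"], ["toxic"], ["toxic", "obscene"],
        ["toxic"]]

print(multilabel_scores(gold, pred))
```

On the full test set, the same function could be called as `multilabel_scores(ld_preds_df["labels"], ld_preds_df["result"])`. Nine rows are far too few to compare models, so these numbers only illustrate the calculation, and they need not match the accuracy figures quoted in this notebook, which may be computed differently.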


## **💻 Retrain the Model with Different Parameters - 2**

This time, the learning rate (`setLr`) was lowered to **1e-4**, one tenth of the previous value of 1e-3, while the batch size (32) and the number of epochs (12) were kept. The new model's accuracy decreased from **88.5 %** to **87.9 %**.
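Why can a smaller learning rate hurt? In gradient-based training, each update moves the weights by an amount proportional to the learning rate, so with the same budget of updates the model may not get as close to a minimum. The toy below uses plain gradient descent on a one-parameter quadratic loss; this is an illustrative assumption, not the optimizer or loss that `MultiClassifierDLApproach` uses, and `gradient_descent` is a hypothetical helper.

```python
def gradient_descent(lr: float, steps: int, w0: float = 0.0) -> float:
    """Minimize the toy loss f(w) = (w - 3)^2 with plain gradient descent; return the final loss."""
    w = w0
    for _ in range(steps):
        grad = 2.0 * (w - 3.0)  # derivative of (w - 3)^2
        w -= lr * grad          # the step size is proportional to the learning rate
    return (w - 3.0) ** 2


steps = 412 * 12  # same number of updates as these runs: 412 batches x 12 epochs
loss_high = gradient_descent(1e-3, steps)
loss_low = gradient_descent(1e-4, steps)
print(f"lr=1e-3 -> final loss {loss_high:.6f}")
print(f"lr=1e-4 -> final loss {loss_low:.6f}")
```

This is consistent with the two training logs, where the per-epoch loss stays higher with 1e-4 (0.2875 at epoch 6, versus 0.2484 with 1e-3). A smaller learning rate is not always worse, though; it can help when training with a larger one is unstable.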

In [None]:
multiClassifier = MultiClassifierDLApproach()\
              .setInputCols("sentence_embeddings")\
              .setOutputCol("category")\
              .setLabelColumn("labels")\
              .setBatchSize(32)\
              .setMaxEpochs(12)\
              .setLr(1e-4)\
              .setThreshold(0.5)\
              .setShufflePerEpoch(False)\
              .setEnableOutputLogs(True)\
              .setValidationSplit(0.1)

pipeline = Pipeline(stages = [document,
                              bert_sent,
                              multiClassifier])

In [None]:
%%time
pipelineModel = pipeline.fit(trainDataset)

CPU times: user 7.72 s, sys: 748 ms, total: 8.47 s
Wall time: 24min 48s


In [None]:
!cat ~/annotator_logs/{multiClassifier.uid}.log

Training started - epochs: 12 - learning_rate: 1.0E-4 - batch_size: 32 - training_examples: 13158 - classes: 6
Epoch 0/12 - 6.46s - loss: 0.35708442 - acc: 0.84854025 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 1/12 - 4.59s - loss: 0.3079388 - acc: 0.86855 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 2/12 - 4.75s - loss: 0.29906648 - acc: 0.8720389 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 3/12 - 4.57s - loss: 0.29448482 - acc: 0.8734325 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 4/12 - 4.51s - loss: 0.2914722 - acc: 0.8747889 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 5/12 - 4.60s - loss: 0.28924784 - acc: 0.87547344 - batches: 412
Quality on validation dataset (10.0%), validation examples = 1462 
Epoch 6/12 - 4.50s - loss: 0.28749964 - acc: 0.8764237 - batches: 412

### 💾 **Saving the Trained Model**

In [None]:
# Save the Multilabel Classifier Model
pipelineModel.stages[2].write().overwrite().save('MultilabelClfBert_3')

In [None]:
# Load back the saved Multilabel Classifier Model
MultilabelClfModel = MultiClassifierDLModel.load('MultilabelClfBert_3')

In [None]:
# Generate prediction Pipeline with loaded Model 
ld_pipeline = Pipeline(stages=[document, bert_sent, MultilabelClfModel])
ld_pipeline_model = ld_pipeline.fit(spark.createDataFrame([['']]).toDF("text"))

In [None]:
# Apply Model Transform to testData
ld_preds = ld_pipeline_model.transform(testDataset)

In [None]:
ld_preds_df = ld_preds.select('text','labels',"category.result").toPandas()

In [None]:
ld_preds_df.head(10)

| | text | labels | result |
|---|---|---|---|
| 0 | Vegan \n\nWhat in the hell does all that junk ... | [toxic] | [toxic, obscene] |
| 1 | Fight Club! F**k Yeeaaaaahh!!! | [toxic, obscene] | [toxic, obscene] |
| 2 | "\n\n Little quick on the trigger, ain't'cha b... | [toxic] | [toxic] |
| 3 | Your user page indicates you're a left-wing li... | [toxic, obscene, insult] | [toxic] |
| 4 | " See all the many Google links, titled"" Wik... | [toxic] | [toxic] |
| 5 | "\n\n LOL \n\nLOL. Seriously, ""BryanFromPalat... | [toxic] | [toxic, obscene] |
| 6 | is it because it is of my naked mum having sex... | [toxic, severe_toxic, obscene, insult, identit... | [toxic, severe_toxic, insult, obscene] |
| 7 | )\na cowards site, that must stop changing thi... | [toxic, obscene] | [toxic] |
| 8 | blow me, criticism IS constructive. \n\nBlow m... | [toxic, obscene, insult] | [toxic, insult, obscene] |
| 9 | On account of the project deciding to ignore h... | [toxic, obscene, insult] | [toxic] |
