![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

# Spark NLP
### Multi-class Text Classification
#### By using ClassifierDL

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/training/english/classification/ClassifierDL_Train_and_Evaluate.ipynb)

Only run this block if you are inside Google Colab; otherwise, skip it.

In [None]:
# This is only to setup PySpark and Spark NLP on Colab
!wget http://setup.johnsnowlabs.com/colab.sh -O - | bash

In this notebook, we want to follow the training logs as they are produced, so we start the session with `real_time_output=True`.

In [5]:
import sparknlp

spark = sparknlp.start(real_time_output=True)

print("Spark NLP version: ", sparknlp.version())
print("Apache Spark version: ", spark.version)

Spark NLP version:  4.1.0
Apache Spark version:  3.2.1


Let's download the news category dataset to train our text classifier.

In [6]:
!wget -O news_category_train.csv https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_train.csv

--2022-09-23 17:48:38--  https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_train.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.160.208
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.160.208|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24032125 (23M) [text/csv]
Saving to: ‘news_category_train.csv’


2022-09-23 17:48:38 (102 MB/s) - ‘news_category_train.csv’ saved [24032125/24032125]



In [7]:
!wget -O news_category_test.csv https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_test.csv

--2022-09-23 17:48:38--  https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_test.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.160.208
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.160.208|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1504408 (1.4M) [text/csv]
Saving to: ‘news_category_test.csv’


2022-09-23 17:48:39 (27.1 MB/s) - ‘news_category_test.csv’ saved [1504408/1504408]



In [8]:
!head news_category_train.csv

category,description
Business," Short sellers, Wall Street's dwindling band of ultra cynics, are seeing green again."
Business," Private investment firm Carlyle Group, which has a reputation for making well timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market."
Business, Soaring crude prices plus worries about the economy and the outlook for earnings are expected to hang over the stock market next week during the depth of the summer doldrums.
Business," Authorities have halted oil export flows from the main pipeline in southern Iraq after intelligence showed a rebel militia could strike infrastructure, an oil official said on Saturday."
Business," Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections."
Business," Stocks ended slightly higher on Friday but stayed near lows for the year as oil prices surged past  #36;

The text is in the `description` column and the labels are in the `category` column.
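Some descriptions contain commas, which is why those fields are wrapped in double quotes in the file. As a quick sanity check outside Spark, the sketch below parses the first three data rows from the `head` output above with Python's standard `csv` module. It is only an illustration: the notebook itself reads the file with Spark's CSV reader, whose default quote character is also `"`.

```python
import csv
import io

# First three data rows copied from the `head` output above
sample = '''category,description
Business," Short sellers, Wall Street's dwindling band of ultra cynics, are seeing green again."
Business," Private investment firm Carlyle Group, which has a reputation for making well timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market."
Business, Soaring crude prices plus worries about the economy and the outlook for earnings are expected to hang over the stock market next week during the depth of the summer doldrums.
'''

# DictReader takes its keys from the header row; commas inside quoted fields
# stay part of the field instead of starting a new column
rows = list(csv.DictReader(io.StringIO(sample)))

for row in rows:
    print(row["category"], "|", row["description"][:40])
```

Every row comes back with exactly two fields, even though the first two descriptions contain several commas.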

In [9]:
trainDataset = spark.read \
      .option("header", True) \
      .csv("news_category_train.csv")

In [10]:
trainDataset.show()

+--------+--------------------+
|category|         description|
+--------+--------------------+
|Business| Short sellers, W...|
|Business| Private investme...|
|Business| Soaring crude pr...|
|Business| Authorities have...|
|Business| Tearaway world o...|
|Business| Stocks ended sli...|
|Business| Assets of the na...|
|Business| Retail sales bou...|
|Business|" After earning a...|
|Business| Short sellers, W...|
|Business| Soaring crude pr...|
|Business| OPEC can do noth...|
|Business| Non OPEC oil exp...|
|Business| WASHINGTON/NEW Y...|
|Business| The dollar tumbl...|
|Business|If you think you ...|
|Business|The purchasing po...|
|Business|There is little c...|
|Business|The US trade defi...|
|Business|Oil giant Shell c...|
+--------+--------------------+
only showing top 20 rows



In [11]:
trainDataset.count()

120000

In [12]:
from pyspark.ml import Pipeline

from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *

# Prepare the Test Dataset for Evaluation

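Let's evaluate our ClassifierDL model during training on a test dataset it has never seen, then save the model and load it into a new pipeline. Because ClassifierDL consumes sentence embeddings rather than raw text, the test dataset must already contain the `sentence_embeddings` column. To do this, we first prepare the test dataset and save it as a Parquet file, as shown below: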

In [13]:
news_test_dataset = spark.read \
      .option("header", True) \
      .csv("news_category_test.csv")

In [14]:
document = DocumentAssembler()\
    .setInputCol("description")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained() \
    .setInputCols(["document"])\
    .setOutputCol("sentence_embeddings")

pipeline = Pipeline(stages = [document,use])

test_dataset = pipeline.fit(news_test_dataset).transform(news_test_dataset)

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
Download done! Loading the resource.
[OK!]


In [15]:
test_dataset.show(2)

+--------+--------------------+--------------------+--------------------+
|category|         description|            document| sentence_embeddings|
+--------+--------------------+--------------------+--------------------+
|Business|Unions representi...|[{document, 0, 12...|[{sentence_embedd...|
|Sci/Tech| TORONTO, Canada ...|[{document, 0, 22...|[{sentence_embedd...|
+--------+--------------------+--------------------+--------------------+
only showing top 2 rows



Now that our test dataset has the required embeddings, we save it as Parquet and use it while training our ClassifierDL model.

In [16]:
# mode("overwrite") lets this cell be re-run without failing on an existing path
test_dataset.write.mode("overwrite").parquet("./test_news.parquet")

Now let's train the model, using a validation split and the test dataset above for evaluation.

In [17]:
# Train on the USE sentence embeddings; 20% of the training rows are held out for validation
classifierdl = ClassifierDLApproach()\
  .setInputCols(["sentence_embeddings"])\
  .setOutputCol("class")\
  .setLabelColumn("category")\
  .setMaxEpochs(5)\
  .setEnableOutputLogs(True) \
  .setEvaluationLogExtended(True) \
  .setValidationSplit(0.2) \
  .setTestDataset("./test_news.parquet")

pipeline = Pipeline(
    stages = [
        document,
        use,
        classifierdl
    ])
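`setValidationSplit(0.2)` holds out 20% of the training rows for validation. Using the 120,000 rows counted earlier and the batch size of 64 reported in the training log, the expected sizes can be worked out by hand; the plain-Python sketch below just does that arithmetic. The batch size is taken from the log, not set explicitly in this notebook.

```python
total_rows = 120_000      # trainDataset.count() from the cell above
validation_split = 0.2    # value passed to setValidationSplit
batch_size = 64           # batch size reported in the training log

# Rows held out for validation vs. rows used for training
validation_examples = round(total_rows * validation_split)
training_examples = total_rows - validation_examples

# 96,000 / 64 divides evenly, so there is no partial final batch here
batches_per_epoch, leftover = divmod(training_examples, batch_size)

print(training_examples, validation_examples, batches_per_epoch, leftover)
```

These numbers line up with the `training_examples: 96000`, `validation examples = 24000` and `batches: 1500` values in the log.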

In [18]:
pipelineModel = pipeline.fit(trainDataset)

Training started - epochs: 5 - learning_rate: 0.005 - batch_size: 64 - training_examples: 96000 - classes: 4
Epoch 1/5 - 21.56s - loss: 1294.569 - acc: 0.8790208 - batches: 1500
Quality on validation dataset (20.0%), validation examples = 24000
time to finish evaluation: 1.38s
label      tp	 fp	 fn	 prec	 rec	 f1
Sci/Tech   5158	 945	 807	 0.8451581	 0.8647108	 0.85482264
Business   5003	 952	 1004	 0.8401343	 0.83286166	 0.83648217
Sports     5901	 270	 147	 0.956247	 0.9756944	 0.9658728
World      5235	 536	 745	 0.90712184	 0.87541807	 0.89098805
tp: 21297 fp: 2703 fn: 2703 labels: 4
Macro-average	 prec: 0.88716537, rec: 0.88717127, f1: 0.8871684
Micro-average	 prec: 0.887375, recall: 0.887375, f1: 0.887375
Quality on test dataset: 
time to finish evaluation: 0.35s
label      tp	 fp	 fn	 prec	 rec	 f1
Sci/Tech   1658	 322	 242	 0.83737373	 0.87263155	 0.8546392
Business   1569	 306	 331	 0.8368	 0.82578945	 0.8312583
Sports     1840	 81	 60	 0.9578345	 0.96842104	 0.9630987
World  
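With `setEvaluationLogExtended(True)` the log reports per-label `tp`, `fp` and `fn` counts. The sketch below recomputes precision, recall and F1 from the epoch-1 validation counts, plus the macro and micro averages. Two observations come out of the arithmetic rather than from any documentation: the logged macro F1 matches the harmonic mean of macro precision and macro recall (not the mean of per-label F1 scores), and micro precision equals micro recall because, in single-label classification, every misclassified example counts once as a false positive and once as a false negative.

```python
# Per-label (tp, fp, fn) copied from the validation section of the epoch-1 log
counts = {
    "Sci/Tech": (5158, 945, 807),
    "Business": (5003, 952, 1004),
    "Sports": (5901, 270, 147),
    "World": (5235, 536, 745),
}

def prf(tp, fp, fn):
    """Return precision, recall and F1 for a single label."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

per_label = {label: prf(*c) for label, c in counts.items()}

# Macro average: unweighted mean of the per-label precision and recall
macro_precision = sum(p for p, _, _ in per_label.values()) / len(per_label)
macro_recall = sum(r for _, r, _ in per_label.values()) / len(per_label)
# Harmonic mean of the two macro scores, which reproduces the logged macro f1
macro_f1 = 2 * macro_precision * macro_recall / (macro_precision + macro_recall)

# Micro average: pool tp/fp/fn over all labels first
tp = sum(c[0] for c in counts.values())
fp = sum(c[1] for c in counts.values())
fn = sum(c[2] for c in counts.values())
micro_precision = tp / (tp + fp)
micro_recall = tp / (tp + fn)

for label, (p, r, f) in per_label.items():
    print(f"{label:<10} prec={p:.4f} rec={r:.4f} f1={f:.4f}")
print(f"macro prec={macro_precision:.6f} rec={macro_recall:.6f} f1={macro_f1:.6f}")
print(f"micro prec={micro_precision:.6f} rec={micro_recall:.6f}")
```

The same function can be applied to the test-dataset counts in the log.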

# How to use the already trained ClassifierDL pipeline or its model

There are two ways to reuse what we trained: the entire pipeline or just the model.

Let's see how to save the entire pipeline, load it, and run predictions with it.

## Save and load pre-trained ClassifierDL pipeline

In [19]:
# The free tier of Google Colab has limited memory, so saving and loading the whole
# pipeline is not practical in this notebook. You can do this locally or on a machine with more memory:

# from pyspark.ml import PipelineModel
# pipelineModel.save("./classifierdl_pipeline")
# loadedPipeline = PipelineModel.load("./classifierdl_pipeline")
# loadedPipeline.transform(YOUR_DATAFRAME)

## Save and load pre-trained ClassifierDL model

In [20]:
# Use a dbfs:/ or hdfs:/ path instead if you are saving to a distributed file system
pipelineModel.stages[-1].write().overwrite().save('./tmp_classifierDL_model')



Now let's use our pre-trained ClassifierDLModel in a new pipeline:

In [21]:

# Load the saved model into a new pipeline for prediction
document = DocumentAssembler()\
    .setInputCol("description")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained() \
    .setInputCols(["document"])\
    .setOutputCol("sentence_embeddings")

classifierdl = ClassifierDLModel.load("./tmp_classifierDL_model") \
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("class")

pipeline = Pipeline(
    stages = [
        document,
        use,
        classifierdl
    ])


tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]


Now let's create a small DataFrame of new text and run it through the whole pipeline to get predictions:

In [22]:
from pyspark.sql.types import StringType

dfTest = spark.createDataFrame([
    "Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
    "Scientists have discovered irregular lumps beneath the icy surface of Jupiter's largest moon, Ganymede. These irregular masses may be rock formations, supported by Ganymede's icy shell for billions of years..."
], StringType()).toDF("description")

In [23]:
prediction = pipeline.fit(dfTest).transform(dfTest)

In [24]:
prediction.select("class.result").show()

prediction.select("class.metadata").show(truncate=False)

+----------+
|    result|
+----------+
|[Business]|
|[Sci/Tech]|
+----------+

+-----------------------------------------------------------------------------------------------------------------+
|metadata                                                                                                         |
+-----------------------------------------------------------------------------------------------------------------+
|[{Sports -> 2.753349E-6, Business -> 0.99998844, World -> 6.6571633E-6, Sci/Tech -> 2.1566113E-6, sentence -> 0}]|
|[{Sports -> 1.4710765E-14, Business -> 1.1435716E-13, World -> 2.8883496E-13, Sci/Tech -> 1.0, sentence -> 0}]   |
+-----------------------------------------------------------------------------------------------------------------+
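In both rows, the label in `result` is the class with the highest score in `metadata`. The plain-Python sketch below takes the first row's metadata from the output above, drops the `sentence` key (which appears to be the sentence index rather than a class), and picks the highest-scoring label. It assumes the map values arrive as strings once collected into Python, since annotation metadata is a string-to-string map.

```python
# Metadata of the first prediction, copied from the output above
metadata = {
    "Sports": "2.753349E-6",
    "Business": "0.99998844",
    "World": "6.6571633E-6",
    "Sci/Tech": "2.1566113E-6",
    "sentence": "0",
}

# Keep only the class scores and convert them to floats
scores = {label: float(v) for label, v in metadata.items() if label != "sentence"}

# The label with the largest score is the predicted class
predicted = max(scores, key=scores.get)
total = sum(scores.values())

print(predicted, round(total, 6))
```

The scores also add up to roughly 1, so they can be read as per-class confidences for that row.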

