![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

# Spark NLP
### Multi-class Text Classification
#### By using ClassifierDL

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/jupyter/training/english/classification/ClassifierDL_Train_multi_class_news_category_classifier.ipynb)

In [1]:
import os

# Install java
! apt-get install -y openjdk-8-jdk-headless -qq > /dev/null
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]
! java -version

# Install pyspark
! pip install --ignore-installed -q pyspark==2.4.4

# Install Spark NLP
! pip install --ignore-installed -q spark-nlp==2.4.4

openjdk version "1.8.0_242"
OpenJDK Runtime Environment (build 1.8.0_242-8u242-b08-0ubuntu3~18.04-b08)
OpenJDK 64-Bit Server VM (build 25.242-b08, mixed mode)


`UniversalSentenceEncoder` needs a larger `spark.kryoserializer.buffer.max` than the default, so we create the SparkSession manually:

In [2]:
import sparknlp
from pyspark.sql import SparkSession

def start():
    builder = SparkSession.builder \
        .appName("Spark NLP") \
        .master("local[*]") \
        .config("spark.driver.memory", "8G") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.kryoserializer.buffer.max", "1000M") \
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.11:2.4.4")

    return builder.getOrCreate()


spark = start()

print("Spark NLP version:", sparknlp.version())
print("Apache Spark version:", spark.version)

Spark NLP version: 2.4.4
Apache Spark version: 2.4.4
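If you need the same session settings in more than one notebook, it can help to keep them in one place. The sketch below is a hypothetical helper (`spark_nlp_conf` and `build_session` are illustrative names, not part of Spark NLP). It reproduces the settings used above and assumes, as the `_2.11` suffix in the original coordinate suggests, that the artifact suffix has to match the Scala version your Spark build was compiled against.

```python
def spark_nlp_conf(version="2.4.4", scala="2.11", driver_memory="8G",
                   kryo_buffer_max="1000M"):
    """Hypothetical helper: the SparkSession settings used in this notebook as a dict."""
    return {
        "spark.driver.memory": driver_memory,
        "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
        # UniversalSentenceEncoder needs a larger Kryo buffer than the default
        "spark.kryoserializer.buffer.max": kryo_buffer_max,
        # Maven coordinate; the _<scala> suffix must match Spark's Scala version
        "spark.jars.packages": f"com.johnsnowlabs.nlp:spark-nlp_{scala}:{version}",
    }


def build_session(conf, app_name="Spark NLP", master="local[*]"):
    """Hypothetical helper: apply a settings dict to a SparkSession builder."""
    # Imported here so spark_nlp_conf can be used without pyspark installed
    from pyspark.sql import SparkSession
    builder = SparkSession.builder.appName(app_name).master(master)
    for key, value in conf.items():
        builder = builder.config(key, value)
    return builder.getOrCreate()
```

With this in place, `spark = build_session(spark_nlp_conf())` should be equivalent to `start()` above.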

Let's download the news category dataset to train our text classifier:

In [5]:
!wget -O news_category_train.csv https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_train.csv

--2020-03-16 16:34:57--  https://s3.amazonaws.com/auxdata.johnsnowlabs.com/public/resources/en/classifier-dl/news_Category/news_category_train.csv
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.216.133.85
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.216.133.85|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 24032125 (23M) [text/csv]
Saving to: ‘news_category_train.csv’


2020-03-16 16:35:01 (9.09 MB/s) - ‘news_category_train.csv’ saved [24032125/24032125]



In [6]:
!head news_category_train.csv

category,description
Business," Short sellers, Wall Street's dwindling band of ultra cynics, are seeing green again."
Business," Private investment firm Carlyle Group, which has a reputation for making well timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market."
Business, Soaring crude prices plus worries about the economy and the outlook for earnings are expected to hang over the stock market next week during the depth of the summer doldrums.
Business," Authorities have halted oil export flows from the main pipeline in southern Iraq after intelligence showed a rebel militia could strike infrastructure, an oil official said on Saturday."
Business," Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections."
Business," Stocks ended slightly higher on Friday but stayed near lows for the year as oil prices surged past  #36;

The text is in the `description` column and the labels are in the `category` column.
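Before training, it is worth checking how many rows each label has. Some descriptions contain commas and are wrapped in double quotes, so a real CSV parser is needed rather than splitting on `,`. The sketch below is a minimal, hypothetical helper (`label_counts` is not part of Spark NLP) that uses Python's standard `csv` module and assumes the file has the `category,description` header shown above.

```python
import csv
from collections import Counter


def label_counts(csv_file, label_col="category"):
    """Count rows per label in a CSV file object that has a header row."""
    # csv.DictReader respects the double quotes around descriptions that contain commas
    reader = csv.DictReader(csv_file)
    return Counter(row[label_col] for row in reader)
```

To use it on the downloaded file, open `news_category_train.csv` with `newline=""` (and, as an assumption, `encoding="utf-8"`) and pass the file object to `label_counts`; `.most_common()` on the result lists the classes by frequency.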

In [0]:
trainDataset = spark.read \
      .option("header", True) \
      .csv("news_category_train.csv")

In [4]:
trainDataset.show()

+--------+--------------------+
|category|         description|
+--------+--------------------+
|Business| Short sellers, W...|
|Business| Private investme...|
|Business| Soaring crude pr...|
|Business| Authorities have...|
|Business| Tearaway world o...|
|Business| Stocks ended sli...|
|Business| Assets of the na...|
|Business| Retail sales bou...|
|Business|" After earning a...|
|Business| Short sellers, W...|
|Business| Soaring crude pr...|
|Business| OPEC can do noth...|
|Business| Non OPEC oil exp...|
|Business| WASHINGTON/NEW Y...|
|Business| The dollar tumbl...|
|Business|If you think you ...|
|Business|The purchasing po...|
|Business|There is little c...|
|Business|The US trade defi...|
|Business|Oil giant Shell c...|
+--------+--------------------+
only showing top 20 rows



In [0]:
from pyspark.ml import Pipeline

from sparknlp.annotator import *
from sparknlp.common import *
from sparknlp.base import *

In [6]:
# the actual content is in the description column
document = DocumentAssembler()\
    .setInputCol("description")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.pretrained()\
    .setInputCols(["document"])\
    .setOutputCol("sentence_embeddings")

# the classes/labels/categories are in the category column
classifierdl = ClassifierDLApproach()\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("class")\
    .setLabelColumn("category")\
    .setMaxEpochs(5)\
    .setEnableOutputLogs(True)

pipeline = Pipeline(
    stages = [
        document,
        use,
        classifierdl
    ])

tfhub_use download started this may take some time.
[ | ]
[OK!]
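Each stage reads columns produced by the data or by an earlier stage: `description` becomes `document`, `document` becomes `sentence_embeddings`, and the classifier turns `sentence_embeddings` into `class`, using `category` as the label during `fit()`. The plain-Python sketch below checks that wiring. The `STAGES` table and `check_chain` function are hypothetical illustrations written for this notebook, not Spark NLP APIs.

```python
# Hypothetical description of the pipeline above: (stage, input columns, output column).
# "category" is the label column, read by ClassifierDLApproach only when fitting.
STAGES = [
    ("DocumentAssembler", ["description"], "document"),
    ("UniversalSentenceEncoder", ["document"], "sentence_embeddings"),
    ("ClassifierDLApproach", ["sentence_embeddings", "category"], "class"),
]


def check_chain(stages, source_columns):
    """Raise ValueError if a stage reads a column that nothing upstream provides."""
    available = set(source_columns)
    for name, inputs, output in stages:
        missing = [col for col in inputs if col not in available]
        if missing:
            raise ValueError(f"{name} is missing input column(s): {missing}")
        # The stage's output becomes available to every later stage
        available.add(output)
    return available
```

Running `check_chain(STAGES, ["category", "description"])` on the training columns succeeds, while passing only `["description"]` raises an error for the missing label column.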


In [0]:
pipelineModel = pipeline.fit(trainDataset)

In [8]:
!cd ~/annotator_logs && ls -l

total 4
-rw-r--r-- 1 root root 532 Mar 16 16:52 ClassifierDLApproach_b0a3535b0393.log


In [9]:
!cat ~/annotator_logs/ClassifierDLApproach_b0a3535b0393.log

Training started - total epochs: 5 - learning rate: 0.005 - batch size: 64 - training examples: 120000
Epoch 0/5 - 36.11660808%.2fs - loss: 1633.0354 - accuracy: 0.8799833 - batches: 1875
Epoch 1/5 - 36.381269208%.2fs - loss: 1602.1971 - accuracy: 0.8913083 - batches: 1875
Epoch 2/5 - 34.888333409%.2fs - loss: 1589.8265 - accuracy: 0.896175 - batches: 1875
Epoch 3/5 - 34.897568042%.2fs - loss: 1586.0842 - accuracy: 0.89944166 - batches: 1875
Epoch 4/5 - 34.750174157%.2fs - loss: 1582.5492 - accuracy: 0.9021583 - batches: 1875
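Each epoch processes 120000 / 64 = 1875 batches, which matches the `batches` field. The duration field carries a literal `%.2f` placeholder, apparently an unformatted format string in this release's logger. The number in front of it looks like the epoch time in seconds. To track loss and accuracy per epoch, here is a small regex-based parser. `parse_epochs` is a hypothetical helper, and it assumes the line layout shown above.

```python
import re

# Matches the per-epoch lines above; the optional "%.2f" group tolerates the
# unformatted placeholder that appears before the trailing "s".
EPOCH_RE = re.compile(
    r"Epoch (?P<epoch>\d+)/(?P<total>\d+) - (?P<seconds>[\d.]+)(?:%\.2f)?s"
    r" - loss: (?P<loss>[\d.]+) - accuracy: (?P<accuracy>[\d.]+)"
    r" - batches: (?P<batches>\d+)"
)


def parse_epochs(lines):
    """Return one dict per epoch line; other lines (e.g. the header) are skipped."""
    rows = []
    for line in lines:
        m = EPOCH_RE.search(line)
        if m:
            rows.append({
                "epoch": int(m["epoch"]),
                "seconds": float(m["seconds"]),
                "loss": float(m["loss"]),
                "accuracy": float(m["accuracy"]),
                "batches": int(m["batches"]),
            })
    return rows
```

Passing an open log file works because file objects iterate line by line, e.g. `parse_epochs(open(path))`.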


In [0]:
from pyspark.sql.types import StringType

dfTest = spark.createDataFrame([
    "Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
    "Scientists have discovered irregular lumps beneath the icy surface of Jupiter's largest moon, Ganymede. These irregular masses may be rock formations, supported by Ganymede's icy shell for billions of years..."
], StringType()).toDF("description")

In [0]:
prediction = pipelineModel.transform(dfTest)

In [24]:
prediction.select("class.result").show()

prediction.select("class.metadata").show(truncate=False)

+----------+
|    result|
+----------+
|[Business]|
|[Sci/Tech]|
+----------+

+-----------------------------------------------------------------------------------------------------------------+
|metadata                                                                                                         |
+-----------------------------------------------------------------------------------------------------------------+
|[[Sports -> 1.6808018E-6, Business -> 0.99998975, World -> 7.434875E-6, Sci/Tech -> 1.1887713E-6, sentence -> 0]]|
|[[Sports -> 8.665945E-15, Business -> 2.2782523E-14, World -> 8.526451E-15, Sci/Tech -> 1.0, sentence -> 0]]     |
+-----------------------------------------------------------------------------------------------------------------+
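In the `metadata` map, each class name is paired with a score. In both rows these scores appear to sum to roughly 1, and the predicted `result` is the class with the highest score. The extra `sentence` key is a sentence index, not a class. Annotation metadata values are strings, so they need converting before they can be compared. The sketch below recovers the predicted label from one such map. `top_label` is a hypothetical helper.

```python
def top_label(metadata):
    """Return (label, score) for the highest-scoring class in an annotation's metadata."""
    # Skip the non-class "sentence" key and parse the string scores as floats
    scores = {k: float(v) for k, v in metadata.items() if k != "sentence"}
    label = max(scores, key=scores.get)
    return label, scores[label]
```

With a Spark DataFrame, something like `top_label(prediction.select("class.metadata").first()["metadata"][0])` should work, because PySpark returns map columns as Python dicts.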

