![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/open-source-nlp/05.1.Text_Classification_Examples_in_SparkML_SparkNLP.ipynb)

# Text Classification with Spark NLP

In [None]:
! pip install -q pyspark==3.3.0 spark-nlp==5.0.0

In [2]:
import os
import sys

import sparknlp

from sparknlp.base import *
from sparknlp.common import *
from sparknlp.annotator import *

from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

import pandas as pd

spark = sparknlp.start()

print("Spark NLP version: ", sparknlp.version())
print("Apache Spark version: ", spark.version)

spark

Spark NLP version:  5.0.0
Apache Spark version:  3.3.0


In [None]:
! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_train.csv
! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Public/data/news_category_test.csv

In [4]:
# newsDF = spark.read.parquet("data/news_category.parquet") >> if it is a parquet

newsDF = spark.read \
      .option("header", True) \
      .csv("news_category_train.csv")

newsDF.show(truncate=50)

+--------+--------------------------------------------------+
|category|                                       description|
+--------+--------------------------------------------------+
|Business| Short sellers, Wall Street's dwindling band of...|
|Business| Private investment firm Carlyle Group, which h...|
|Business| Soaring crude prices plus worries about the ec...|
|Business| Authorities have halted oil export flows from ...|
|Business| Tearaway world oil prices, toppling records an...|
|Business| Stocks ended slightly higher on Friday but sta...|
|Business| Assets of the nation's retail money market mut...|
|Business| Retail sales bounced back a bit in July, and n...|
|Business|" After earning a PH.D. in Sociology, Danny Baz...|
|Business| Short sellers, Wall Street's dwindling  band o...|
|Business| Soaring crude prices plus worries  about the e...|
|Business| OPEC can do nothing to douse scorching  oil pr...|
|Business| Non OPEC oil exporters should consider  increa...|
|Busines

In [5]:
newsDF.take(2)

[Row(category='Business', description=" Short sellers, Wall Street's dwindling band of ultra cynics, are seeing green again."),
 Row(category='Business', description=' Private investment firm Carlyle Group, which has a reputation for making well timed and occasionally controversial plays in the defense industry, has quietly placed its bets on another part of the market.')]

In [6]:
from pyspark.sql.functions import col

newsDF.groupBy("category") \
    .count() \
    .orderBy(col("count").desc()) \
    .show()

+--------+-----+
|category|count|
+--------+-----+
|   World|30000|
|Sci/Tech|30000|
|  Sports|30000|
|Business|30000|
+--------+-----+



## Building Classification Pipeline

### LogReg with CountVectorizer

Tokenizer: Tokenization

stopwordsRemover: Remove Stop Words

countVectors: Count vectors (“document-term vectors”)

In [7]:
from pyspark.ml.feature import CountVectorizer, HashingTF, IDF, OneHotEncoder, StringIndexer, VectorAssembler, SQLTransformer

In [8]:
%%time

document_assembler = DocumentAssembler() \
      .setInputCol("description") \
      .setOutputCol("document")

tokenizer = Tokenizer() \
      .setInputCols(["document"]) \
      .setOutputCol("token")

normalizer = Normalizer() \
      .setInputCols(["token"]) \
      .setOutputCol("normalized")

stopwords_cleaner = StopWordsCleaner()\
      .setInputCols("normalized")\
      .setOutputCol("cleanTokens")\
      .setCaseSensitive(False)

stemmer = Stemmer() \
      .setInputCols(["cleanTokens"]) \
      .setOutputCol("stem")

finisher = Finisher() \
      .setInputCols(["stem"]) \
      .setOutputCols(["token_features"]) \
      .setOutputAsArray(True) \
      .setCleanAnnotations(False)

countVectors = CountVectorizer(inputCol="token_features", outputCol="features", vocabSize=10000, minDF=5)

label_stringIdx = StringIndexer(inputCol = "category", outputCol = "label")

nlp_pipeline = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            stemmer,
            finisher,
            countVectors,
            label_stringIdx])

nlp_model = nlp_pipeline.fit(newsDF)

processed = nlp_model.transform(newsDF)

processed.count()

CPU times: user 358 ms, sys: 67.8 ms, total: 426 ms
Wall time: 51.3 s


120000

In [9]:
processed.select('description','token_features').show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                       description|                                    token_features|
+--------------------------------------------------+--------------------------------------------------+
| Short sellers, Wall Street's dwindling band of...|[short, seller, wall, street, dwindl, band, ult...|
| Private investment firm Carlyle Group, which h...|[privat, invest, firm, carlyl, group, reput, ma...|
| Soaring crude prices plus worries about the ec...|[soar, crude, price, plu, worri, economi, outlo...|
| Authorities have halted oil export flows from ...|[author, halt, oil, export, flow, main, pipelin...|
| Tearaway world oil prices, toppling records an...|[tearawai, world, oil, price, toppl, record, st...|
| Stocks ended slightly higher on Friday but sta...|[stock, end, slightli, higher, fridai, staye, n...|
| Assets of the nation's retail money market mut...|[asset, nati

In [10]:
processed.select('token_features').take(2)

[Row(token_features=['short', 'seller', 'wall', 'street', 'dwindl', 'band', 'ultra', 'cynic', 'see', 'green']),
 Row(token_features=['privat', 'invest', 'firm', 'carlyl', 'group', 'reput', 'make', 'well', 'time', 'occasion', 'controversi', 'plai', 'defens', 'industri', 'quietli', 'place', 'bet', 'anoth', 'part', 'market'])]

In [11]:
processed.select('features').take(2)

[Row(features=SparseVector(10000, {241: 1.0, 384: 1.0, 467: 1.0, 743: 1.0, 838: 1.0, 2228: 1.0, 3676: 1.0, 6152: 1.0, 6233: 1.0})),
 Row(features=SparseVector(10000, {26: 1.0, 38: 1.0, 46: 1.0, 68: 1.0, 117: 1.0, 155: 1.0, 182: 1.0, 197: 1.0, 246: 1.0, 304: 1.0, 320: 1.0, 407: 1.0, 428: 1.0, 621: 1.0, 868: 1.0, 2361: 1.0, 2824: 1.0, 2863: 1.0, 6834: 1.0}))]

In [12]:
processed.select('description','features','label').show()

+--------------------+--------------------+-----+
|         description|            features|label|
+--------------------+--------------------+-----+
| Short sellers, W...|(10000,[241,384,4...|  0.0|
| Private investme...|(10000,[26,38,46,...|  0.0|
| Soaring crude pr...|(10000,[15,28,46,...|  0.0|
| Authorities have...|(10000,[0,32,35,4...|  0.0|
| Tearaway world o...|(10000,[1,2,11,28...|  0.0|
| Stocks ended sli...|(10000,[3,13,14,2...|  0.0|
| Assets of the na...|(10000,[0,4,10,15...|  0.0|
| Retail sales bou...|(10000,[0,1,10,15...|  0.0|
|" After earning a...|(10000,[98,99,125...|  0.0|
| Short sellers, W...|(10000,[241,384,4...|  0.0|
| Soaring crude pr...|(10000,[15,28,46,...|  0.0|
| OPEC can do noth...|(10000,[0,24,28,2...|  0.0|
| Non OPEC oil exp...|(10000,[0,21,28,3...|  0.0|
| WASHINGTON/NEW Y...|(10000,[2,4,13,14...|  0.0|
| The dollar tumbl...|(10000,[2,14,72,1...|  0.0|
|If you think you ...|(10000,[74,77,143...|  0.0|
|The purchasing po...|(10000,[46,54,167...|  0.0|


In [13]:
# set seed for reproducibility
(trainingData, testData) = processed.randomSplit([0.7, 0.3], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 83915
Test Dataset Count: 36085


In [14]:
trainingData.printSchema()

root
 |-- category: string (nullable = true)
 |-- description: string (nullable = true)
 |-- document: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |    |-- value: string (valueContainsNull = true)
 |    |    |-- embeddings: array (nullable = true)
 |    |    |    |-- element: float (containsNull = false)
 |-- token: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- annotatorType: string (nullable = true)
 |    |    |-- begin: integer (nullable = false)
 |    |    |-- end: integer (nullable = false)
 |    |    |-- result: string (nullable = true)
 |    |    |-- metadata: map (nullable = true)
 |    |    |    |-- key: string
 |    |    |   

In [15]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0)

lrModel = lr.fit(trainingData)

predictions = lrModel.transform(testData)

predictions.filter(predictions['prediction'] == 0) \
    .select("description","category","probability","label","prediction") \
    .orderBy("probability", ascending=False) \
    .show(n = 10, truncate = 30)

+------------------------------+--------+------------------------------+-----+----------+
|                   description|category|                   probability|label|prediction|
+------------------------------+--------+------------------------------+-----+----------+
|" U.S. blue chips declined ...|Business|[0.9969052069602881,0.00125...|  0.0|       0.0|
|" General Motors Corp. &lt;...|Business|[0.9953865621086483,0.00163...|  0.0|       0.0|
| The dollar paused on Tuesd...|Business|[0.9944763517606333,9.20409...|  0.0|       0.0|
|" Stocks slipped on Tuesday...|Business|[0.994329539223382,0.001814...|  0.0|       0.0|
| There is more to corporate...|Business|[0.9940249051491369,0.00453...|  0.0|       0.0|
| Consumer prices rose by a ...|   World|[0.9934177979655902,0.00127...|  3.0|       0.0|
|" Sears, Roebuck and Co. &l...|Business|[0.9930054244880596,0.00293...|  0.0|       0.0|
|" U.S. blue chip stocks fel...|Business|[0.9926442525545816,0.00241...|  0.0|       0.0|
|" U.S. st

In [16]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")

evaluator.evaluate(predictions)

0.8985162543484673

In [17]:
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
y_true = predictions.select("label")
y_true = y_true.toPandas()

y_pred = predictions.select("prediction")
y_pred = y_pred.toPandas()

In [18]:
y_pred.prediction.value_counts()

2.0    9269
1.0    9089
0.0    8986
3.0    8741
Name: prediction, dtype: int64

In [19]:
cnf_matrix = confusion_matrix(list(y_true.label.astype(int)), list(y_pred.prediction.astype(int)))
cnf_matrix

array([[7820,  895,   95,  303],
       [ 674, 7793,   89,  314],
       [  55,   86, 8787,   95],
       [ 437,  315,  298, 8029]])

In [20]:
print(classification_report(y_true.label, y_pred.prediction))
print(accuracy_score(y_true.label, y_pred.prediction))

              precision    recall  f1-score   support

         0.0       0.87      0.86      0.86      9113
         1.0       0.86      0.88      0.87      8870
         2.0       0.95      0.97      0.96      9023
         3.0       0.92      0.88      0.90      9079

    accuracy                           0.90     36085
   macro avg       0.90      0.90      0.90     36085
weighted avg       0.90      0.90      0.90     36085

0.8986836635721214


### LogReg with TFIDF

In [21]:
from pyspark.ml.feature import HashingTF, IDF

hashingTF = HashingTF(inputCol="token_features", outputCol="rawFeatures", numFeatures=10000)

idf = IDF(inputCol="rawFeatures", outputCol="features", minDocFreq=5) #minDocFreq: remove sparse terms

nlp_pipeline_tf = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            stemmer,
            finisher,
            hashingTF,
            idf,
            label_stringIdx])

nlp_model_tf = nlp_pipeline_tf.fit(newsDF)

processed_tf = nlp_model_tf.transform(newsDF)

processed_tf.count()


120000

In [22]:
# set seed for reproducibility
processed_tf.select('description','features','label').show()

+--------------------+--------------------+-----+
|         description|            features|label|
+--------------------+--------------------+-----+
| Short sellers, W...|(10000,[551,621,6...|  0.0|
| Private investme...|(10000,[157,831,9...|  0.0|
| Soaring crude pr...|(10000,[793,1738,...|  0.0|
| Authorities have...|(10000,[1548,1611...|  0.0|
| Tearaway world o...|(10000,[323,585,1...|  0.0|
| Stocks ended sli...|(10000,[453,609,6...|  0.0|
| Assets of the na...|(10000,[258,444,1...|  0.0|
| Retail sales bou...|(10000,[14,585,19...|  0.0|
|" After earning a...|(10000,[114,796,1...|  0.0|
| Short sellers, W...|(10000,[551,621,6...|  0.0|
| Soaring crude pr...|(10000,[793,1738,...|  0.0|
| OPEC can do noth...|(10000,[298,616,9...|  0.0|
| Non OPEC oil exp...|(10000,[616,1063,...|  0.0|
| WASHINGTON/NEW Y...|(10000,[360,832,1...|  0.0|
| The dollar tumbl...|(10000,[419,949,1...|  0.0|
|If you think you ...|(10000,[1041,2059...|  0.0|
|The purchasing po...|(10000,[901,2198,...|  0.0|


In [23]:
(trainingData, testData) = processed_tf.randomSplit([0.7, 0.3], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 83915
Test Dataset Count: 36085


In [24]:
lrModel_tf = lr.fit(trainingData)

predictions_tf = lrModel_tf.transform(testData)

predictions_tf.select("description","category","probability","label","prediction") \
    .orderBy("probability", ascending=False) \
    .show(n = 10, truncate = 30)


+------------------------------+--------+------------------------------+-----+----------+
|                   description|category|                   probability|label|prediction|
+------------------------------+--------+------------------------------+-----+----------+
|" U.S. blue chip stocks fel...|Business|[0.9962587979479481,0.00118...|  0.0|       0.0|
|" Anadarko Petroleum Corp. ...|Business|[0.9941348956384896,0.00264...|  0.0|       0.0|
|" Stocks slipped on Tuesday...|Business|[0.9934262092957592,0.00197...|  0.0|       0.0|
|" Sears, Roebuck   Co. &lt;...|Business|[0.9922108515948802,0.00485...|  0.0|       0.0|
| Consumer prices rose by a ...|   World|[0.991409853187127,0.002043...|  3.0|       0.0|
| A sharp drop in oil prices...|   World|[0.9909956415078902,0.00194...|  3.0|       0.0|
| Federal Reserve policy mak...|   World|[0.9906819727506818,0.00334...|  3.0|       0.0|
|" Sears, Roebuck and Co. &l...|Business|[0.9904372420530858,0.00655...|  0.0|       0.0|
|" U.S. bl

In [25]:
y_true = predictions_tf.select("label")
y_true = y_true.toPandas()

y_pred = predictions_tf.select("prediction")
y_pred = y_pred.toPandas()

print(classification_report(y_true.label, y_pred.prediction))
print(accuracy_score(y_true.label, y_pred.prediction))

              precision    recall  f1-score   support

         0.0       0.86      0.84      0.85      9113
         1.0       0.85      0.86      0.86      8870
         2.0       0.93      0.97      0.95      9023
         3.0       0.91      0.88      0.89      9079

    accuracy                           0.89     36085
   macro avg       0.89      0.89      0.89     36085
weighted avg       0.89      0.89      0.89     36085

0.8875155881945407


### Random Forest with TFIDF

In [26]:
from pyspark.ml.classification import RandomForestClassifier

rf = RandomForestClassifier(labelCol="label", \
                            featuresCol="features", \
                            numTrees = 100, \
                            maxDepth = 4, \
                            maxBins = 32)

# Train model with Training Data
rfModel = rf.fit(trainingData)
predictions_rf = rfModel.transform(testData)


In [27]:
predictions_rf.select("description","category","probability","label","prediction") \
    .orderBy("probability", ascending=False) \
    .show(n = 10, truncate = 30)

+------------------------------+--------+------------------------------+-----+----------+
|                   description|category|                   probability|label|prediction|
+------------------------------+--------+------------------------------+-----+----------+
|" U.S. investment bank Merr...|Business|[0.39434943638706244,0.2189...|  0.0|       0.0|
| Japan's Nikkei average ros...|Business|[0.38889384355621304,0.2208...|  0.0|       0.0|
|" PeopleSoft Inc. &lt;A HRE...|Business|[0.38183169493923463,0.2400...|  0.0|       0.0|
|" Gilead Sciences Inc. &lt;...|Business|[0.37914489443669314,0.2281...|  0.0|       0.0|
|Ryanair, the Irish no-frill...|Business|[0.37898461894558694,0.2192...|  0.0|       0.0|
| Genentech Inc.  on Wednesd...|Business|[0.37694646177115687,0.2286...|  0.0|       0.0|
|" Goldman Sachs Group Inc. ...|Business|[0.3767215901372718,0.22562...|  0.0|       0.0|
|US investment bank Lehman B...|Business|[0.3760576885735656,0.22644...|  0.0|       0.0|
|US invest

In [28]:
y_true = predictions_rf.select("label")
y_true = y_true.toPandas()

y_pred = predictions_rf.select("prediction")
y_pred = y_pred.toPandas()

print(classification_report(y_true.label, y_pred.prediction))
print(accuracy_score(y_true.label, y_pred.prediction))

              precision    recall  f1-score   support

         0.0       0.79      0.64      0.71      9113
         1.0       0.61      0.77      0.68      8870
         2.0       0.81      0.85      0.83      9023
         3.0       0.83      0.74      0.78      9079

    accuracy                           0.75     36085
   macro avg       0.76      0.75      0.75     36085
weighted avg       0.76      0.75      0.75     36085

0.7498129416655119


## LogReg with Spark NLP Glove Word Embeddings

In [29]:
document_assembler = DocumentAssembler() \
      .setInputCol("description") \
      .setOutputCol("document")

tokenizer = Tokenizer() \
      .setInputCols(["document"]) \
      .setOutputCol("token")

normalizer = Normalizer() \
      .setInputCols(["token"]) \
      .setOutputCol("normalized")

stopwords_cleaner = StopWordsCleaner()\
      .setInputCols("normalized")\
      .setOutputCol("cleanTokens")\
      .setCaseSensitive(False)

glove_embeddings = WordEmbeddingsModel().pretrained() \
      .setInputCols(["document",'cleanTokens'])\
      .setOutputCol("embeddings")\
      .setCaseSensitive(False)

embeddingsSentence = SentenceEmbeddings() \
      .setInputCols(["document", "embeddings"]) \
      .setOutputCol("sentence_embeddings") \
      .setPoolingStrategy("AVERAGE")

embeddings_finisher = EmbeddingsFinisher() \
      .setInputCols(["sentence_embeddings"]) \
      .setOutputCols(["finished_sentence_embeddings"]) \
      .setOutputAsVector(True)\
      .setCleanAnnotations(False)

explodeVectors = SQLTransformer(statement=
      "SELECT EXPLODE(finished_sentence_embeddings) AS features, * FROM __THIS__")

label_stringIdx = StringIndexer(inputCol = "category", outputCol = "label")


nlp_pipeline_w2v = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            glove_embeddings,
            embeddingsSentence,
            embeddings_finisher,
            explodeVectors,
            label_stringIdx])

nlp_model_w2v = nlp_pipeline_w2v.fit(newsDF)

processed_w2v = nlp_model_w2v.transform(newsDF)

processed_w2v.count()


glove_100d download started this may take some time.
Approximate size to download 145.3 MB
[OK!]


120000

In [30]:
processed_w2v.columns

['features',
 'category',
 'description',
 'document',
 'token',
 'normalized',
 'cleanTokens',
 'embeddings',
 'sentence_embeddings',
 'finished_sentence_embeddings',
 'label']

In [31]:
processed_w2v.show(5)

+--------------------+--------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------------------+-----+
|            features|category|         description|            document|               token|          normalized|         cleanTokens|          embeddings| sentence_embeddings|finished_sentence_embeddings|label|
+--------------------+--------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+----------------------------+-----+
|[-0.1556767076253...|Business| Short sellers, W...|[{document, 0, 84...|[{token, 1, 5, Sh...|[{token, 1, 5, Sh...|[{token, 1, 5, Sh...|[{word_embeddings...|[{sentence_embedd...|        [[-0.155676707625...|  0.0|
|[-0.0144653050228...|Business| Private investme...|[{document, 0, 20...|[{token, 1, 7, Pr...|[{token, 1, 7, Pr...|[{token, 1, 7, Pr...|[{word_e

In [32]:
processed_w2v.select('finished_sentence_embeddings').take(1)

[Row(finished_sentence_embeddings=[DenseVector([-0.1557, 0.196, 0.1099, -0.3089, 0.16, 0.1672, -0.4649, -0.1101, -0.053, -0.1551, 0.0327, 0.0772, 0.1494, -0.1865, 0.1155, -0.0597, 0.0234, -0.0451, 0.2361, -0.0089, 0.3358, 0.0444, 0.0088, -0.1453, 0.2289, 0.0914, -0.1665, -0.3726, 0.1892, 0.121, 0.1993, -0.0239, -0.1346, 0.1159, 0.2086, 0.1285, 0.068, 0.1372, 0.3153, -0.1934, 0.0257, -0.226, -0.0984, 0.1139, 0.1413, -0.3743, 0.072, 0.1403, 0.251, -0.3106, 0.1709, -0.0697, -0.0554, 0.5123, -0.1873, -1.7784, 0.0295, 0.1014, 0.9268, 0.2129, -0.1354, 0.5739, -0.0679, 0.461, 0.4216, 0.0225, 0.4456, -0.2462, 0.1411, -0.3258, 0.0025, 0.0114, -0.3895, -0.1106, -0.261, 0.0147, 0.0781, 0.1268, -0.2042, -0.2278, 0.5096, 0.1539, -0.3515, -0.0102, -0.7003, -0.3872, -0.1668, -0.2405, -0.0766, 0.1396, -0.0592, -0.1568, -0.1606, -0.1371, -0.684, -0.2549, -0.1541, 0.1536, 0.2715, 0.3342])])]

In [33]:
# IF SQLTransformer IS NOT USED INSIDE THE PIPELINE, WE CAN EXPLODE OUTSIDE
from pyspark.sql.functions import explode

# processed_w2v= processed_w2v.withColumn("features", explode(processed_w2v.finished_sentence_embeddings))

In [34]:
processed_w2v.select("features").take(1)

[Row(features=DenseVector([-0.1557, 0.196, 0.1099, -0.3089, 0.16, 0.1672, -0.4649, -0.1101, -0.053, -0.1551, 0.0327, 0.0772, 0.1494, -0.1865, 0.1155, -0.0597, 0.0234, -0.0451, 0.2361, -0.0089, 0.3358, 0.0444, 0.0088, -0.1453, 0.2289, 0.0914, -0.1665, -0.3726, 0.1892, 0.121, 0.1993, -0.0239, -0.1346, 0.1159, 0.2086, 0.1285, 0.068, 0.1372, 0.3153, -0.1934, 0.0257, -0.226, -0.0984, 0.1139, 0.1413, -0.3743, 0.072, 0.1403, 0.251, -0.3106, 0.1709, -0.0697, -0.0554, 0.5123, -0.1873, -1.7784, 0.0295, 0.1014, 0.9268, 0.2129, -0.1354, 0.5739, -0.0679, 0.461, 0.4216, 0.0225, 0.4456, -0.2462, 0.1411, -0.3258, 0.0025, 0.0114, -0.3895, -0.1106, -0.261, 0.0147, 0.0781, 0.1268, -0.2042, -0.2278, 0.5096, 0.1539, -0.3515, -0.0102, -0.7003, -0.3872, -0.1668, -0.2405, -0.0766, 0.1396, -0.0592, -0.1568, -0.1606, -0.1371, -0.684, -0.2549, -0.1541, 0.1536, 0.2715, 0.3342]))]

In [35]:
processed_w2v.select("features").take(1)

[Row(features=DenseVector([-0.1557, 0.196, 0.1099, -0.3089, 0.16, 0.1672, -0.4649, -0.1101, -0.053, -0.1551, 0.0327, 0.0772, 0.1494, -0.1865, 0.1155, -0.0597, 0.0234, -0.0451, 0.2361, -0.0089, 0.3358, 0.0444, 0.0088, -0.1453, 0.2289, 0.0914, -0.1665, -0.3726, 0.1892, 0.121, 0.1993, -0.0239, -0.1346, 0.1159, 0.2086, 0.1285, 0.068, 0.1372, 0.3153, -0.1934, 0.0257, -0.226, -0.0984, 0.1139, 0.1413, -0.3743, 0.072, 0.1403, 0.251, -0.3106, 0.1709, -0.0697, -0.0554, 0.5123, -0.1873, -1.7784, 0.0295, 0.1014, 0.9268, 0.2129, -0.1354, 0.5739, -0.0679, 0.461, 0.4216, 0.0225, 0.4456, -0.2462, 0.1411, -0.3258, 0.0025, 0.0114, -0.3895, -0.1106, -0.261, 0.0147, 0.0781, 0.1268, -0.2042, -0.2278, 0.5096, 0.1539, -0.3515, -0.0102, -0.7003, -0.3872, -0.1668, -0.2405, -0.0766, 0.1396, -0.0592, -0.1568, -0.1606, -0.1371, -0.684, -0.2549, -0.1541, 0.1536, 0.2715, 0.3342]))]

In [36]:
processed_w2v.select('description','features','label').show()


+--------------------+--------------------+-----+
|         description|            features|label|
+--------------------+--------------------+-----+
| Short sellers, W...|[-0.1556767076253...|  0.0|
| Private investme...|[-0.0144653050228...|  0.0|
| Soaring crude pr...|[0.10348732769489...|  0.0|
| Authorities have...|[-0.0355810523033...|  0.0|
| Tearaway world o...|[0.00647281948477...|  0.0|
| Stocks ended sli...|[0.20069395005702...|  0.0|
| Assets of the na...|[0.38012433052062...|  0.0|
| Retail sales bou...|[0.20352847874164...|  0.0|
|" After earning a...|[0.13536226749420...|  0.0|
| Short sellers, W...|[-0.1556767076253...|  0.0|
| Soaring crude pr...|[0.10348732769489...|  0.0|
| OPEC can do noth...|[0.20307321846485...|  0.0|
| Non OPEC oil exp...|[0.09010648727416...|  0.0|
| WASHINGTON/NEW Y...|[0.10887209326028...|  0.0|
| The dollar tumbl...|[0.05723679438233...|  0.0|
|If you think you ...|[0.11463439464569...|  0.0|
|The purchasing po...|[0.05890964344143...|  0.0|


In [37]:
# set seed for reproducibility
(trainingData, testData) = processed_w2v.randomSplit([0.7, 0.3], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 83915
Test Dataset Count: 36085


In [38]:
from pyspark.sql.functions import udf

@udf("long")
def num_nonzeros(v):
    return v.numNonzeros()

testData = testData.where(num_nonzeros("features") != 0)

In [39]:
lrModel_w2v = lr.fit(trainingData)

In [40]:
predictions_w2v = lrModel_w2v.transform(testData)

predictions_w2v.select("description","category","probability","label","prediction") \
    .orderBy("probability", ascending=False) \
    .show(n = 10, truncate = 30)


+------------------------------+--------+------------------------------+-----+----------+
|                   description|category|                   probability|label|prediction|
+------------------------------+--------+------------------------------+-----+----------+
| Brokerage Bear Stearns Com...|Business|[0.9892404374132758,0.00767...|  0.0|       0.0|
| Stocks fell on Monday, wit...|Business|[0.9855914603655052,0.00997...|  0.0|       0.0|
| ChevronTexaco Corp., the N...|Business|[0.9824069757843499,0.01250...|  0.0|       0.0|
|The steel tubing company re...|Business|[0.9814641269307299,0.01515...|  0.0|       0.0|
|Hutchison Whampoa said it w...|Business|[0.9813978140836207,0.01716...|  0.0|       0.0|
|  Shares of rival retailers...|Business|[0.9813376426629606,0.01243...|  0.0|       0.0|
| Tokyo stocks opened lower ...|Business|[0.9813017652925261,0.00686...|  0.0|       0.0|
|The London Stock Exchange (...|Business|[0.9809459968203622,0.01502...|  0.0|       0.0|
| Brokerag

In [41]:
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import pandas as pd

y_true = predictions_w2v.select("label")
y_true = y_true.toPandas()

y_pred = predictions_w2v.select("prediction")
y_pred = y_pred.toPandas()

print(classification_report(y_true.label, y_pred.prediction))
print(accuracy_score(y_true.label, y_pred.prediction))

              precision    recall  f1-score   support

         0.0       0.83      0.82      0.82      8999
         1.0       0.83      0.82      0.82      8961
         2.0       0.93      0.96      0.94      9086
         3.0       0.88      0.86      0.87      9039

    accuracy                           0.87     36085
   macro avg       0.86      0.87      0.86     36085
weighted avg       0.86      0.87      0.86     36085

0.8653179991686296


In [42]:
processed_w2v.select('description','cleanTokens.result').show(truncate=50)

+--------------------------------------------------+--------------------------------------------------+
|                                       description|                                            result|
+--------------------------------------------------+--------------------------------------------------+
| Short sellers, Wall Street's dwindling band of...|[Short, sellers, Wall, Streets, dwindling, band...|
| Private investment firm Carlyle Group, which h...|[Private, investment, firm, Carlyle, Group, rep...|
| Soaring crude prices plus worries about the ec...|[Soaring, crude, prices, plus, worries, economy...|
| Authorities have halted oil export flows from ...|[Authorities, halted, oil, export, flows, main,...|
| Tearaway world oil prices, toppling records an...|[Tearaway, world, oil, prices, toppling, record...|
| Stocks ended slightly higher on Friday but sta...|[Stocks, ended, slightly, higher, Friday, staye...|
| Assets of the nation's retail money market mut...|[Assets, nat

## LogReg with Spark NLP Bert Embeddings

In [8]:
document_assembler = DocumentAssembler() \
      .setInputCol("description") \
      .setOutputCol("document")

tokenizer = Tokenizer() \
      .setInputCols(["document"]) \
      .setOutputCol("token")

normalizer = Normalizer() \
      .setInputCols(["token"]) \
      .setOutputCol("normalized")

stopwords_cleaner = StopWordsCleaner()\
      .setInputCols("normalized")\
      .setOutputCol("cleanTokens")\
      .setCaseSensitive(False)

bert_embeddings = BertEmbeddings.pretrained('bert_base_cased', 'en') \
      .setInputCols(["document",'cleanTokens'])\
      .setOutputCol("bert")\
      .setCaseSensitive(False)\

embeddingsSentence = SentenceEmbeddings() \
      .setInputCols(["document", "bert"]) \
      .setOutputCol("sentence_embeddings") \
      .setPoolingStrategy("AVERAGE")

embeddings_finisher = EmbeddingsFinisher() \
      .setInputCols(["sentence_embeddings"]) \
      .setOutputCols(["finished_sentence_embeddings"]) \
      .setOutputAsVector(True)\
      .setCleanAnnotations(False)

label_stringIdx = StringIndexer(inputCol = "category", outputCol = "label")


nlp_pipeline_bert = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            bert_embeddings,
            embeddingsSentence,
            embeddings_finisher,
            label_stringIdx])



bert_base_cased download started this may take some time.
Approximate size to download 384.9 MB
[OK!]


In [15]:
%%time
limited_df = newsDF.limit(10000)

nlp_model_bert = nlp_pipeline_bert.fit(limited_df)

processed_bert = nlp_model_bert.transform(limited_df)

processed_bert.count()

CPU times: user 16.6 s, sys: 2.09 s, total: 18.7 s
Wall time: 52min 42s


10000

In [16]:
from pyspark.sql.functions import explode

processed_bert = processed_bert.withColumn("features", explode(processed_bert.finished_sentence_embeddings))

processed_bert.select('description','features','label').show()


+--------------------+--------------------+-----+
|         description|            features|label|
+--------------------+--------------------+-----+
| Short sellers, W...|[-0.0012149482499...|  2.0|
| Private investme...|[0.13144019246101...|  2.0|
| Soaring crude pr...|[-0.1905521601438...|  2.0|
| Authorities have...|[0.06882479041814...|  2.0|
| Tearaway world o...|[-0.1174716278910...|  2.0|
| Stocks ended sli...|[-0.0321817845106...|  2.0|
| Assets of the na...|[-0.2906664013862...|  2.0|
| Retail sales bou...|[-0.0385283492505...|  2.0|
|" After earning a...|[-0.0362812504172...|  2.0|
| Short sellers, W...|[-0.0012149482499...|  2.0|
| Soaring crude pr...|[-0.1905521601438...|  2.0|
| OPEC can do noth...|[-0.1431127935647...|  2.0|
| Non OPEC oil exp...|[0.01600192859768...|  2.0|
| WASHINGTON/NEW Y...|[0.14494347572326...|  2.0|
| The dollar tumbl...|[-0.1958881020545...|  2.0|
|If you think you ...|[0.27292791008949...|  2.0|
|The purchasing po...|[0.00386757543310...|  2.0|


In [17]:
# set seed for reproducibility
(trainingData, testData) = processed_bert.randomSplit([0.7, 0.3], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 7033
Test Dataset Count: 2967


In [18]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=20, regParam=0.3, elasticNetParam=0)

lrModel = lr.fit(trainingData)


In [19]:
from pyspark.sql.functions import udf

@udf("long")
def num_nonzeros(v):
    return v.numNonzeros()

testData = testData.where(num_nonzeros("features") != 0)

In [20]:
predictions = lrModel.transform(testData)

predictions.select("description","category","probability","label","prediction") \
    .orderBy("probability", ascending=False) \
    .show(n = 10, truncate = 30)


+------------------------------+--------+------------------------------+-----+----------+
|                   description|category|                   probability|label|prediction|
+------------------------------+--------+------------------------------+-----+----------+
|Wise Solutions has released...|Sci/Tech|[0.9972450764044677,2.61175...|  0.0|       0.0|
|Microsoft has a massive pat...|Sci/Tech|[0.9969462975287583,0.00172...|  0.0|       0.0|
|Microsoft Corp. has publish...|Sci/Tech|[0.9969021631884314,4.95202...|  0.0|       0.0|
|A worm that has the capabil...|Sci/Tech|[0.9962258919592463,9.54592...|  0.0|       0.0|
|Microsoft has made availabl...|Sci/Tech|[0.9960325986626795,5.09907...|  0.0|       0.0|
|Release makes use of techno...|Sci/Tech|[0.9960054398557852,3.32463...|  0.0|       0.0|
|Macromedia hopes to boost u...|Sci/Tech|[0.9957959611167083,0.00123...|  0.0|       0.0|
|Software giant releases jus...|Sci/Tech|[0.9954924952585451,3.43349...|  0.0|       0.0|
| Sleepyca

In [21]:
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import pandas as pd

df = predictions.select('description','category','label','prediction').toPandas()

print(classification_report(df.label, df.prediction))
print(accuracy_score(df.label, df.prediction))

              precision    recall  f1-score   support

         0.0       0.84      0.84      0.84       796
         1.0       0.87      0.83      0.85       726
         2.0       0.81      0.80      0.81       738
         3.0       0.89      0.95      0.92       707

    accuracy                           0.85      2967
   macro avg       0.85      0.85      0.85      2967
weighted avg       0.85      0.85      0.85      2967

0.8533872598584429


## LogReg with ELMO Embeddings

In [8]:
%%time

document_assembler = DocumentAssembler() \
      .setInputCol("description") \
      .setOutputCol("document")

tokenizer = Tokenizer() \
      .setInputCols(["document"]) \
      .setOutputCol("token")

normalizer = Normalizer() \
      .setInputCols(["token"]) \
      .setOutputCol("normalized")

stopwords_cleaner = StopWordsCleaner()\
      .setInputCols("normalized")\
      .setOutputCol("cleanTokens")\
      .setCaseSensitive(False)

elmo_embeddings = ElmoEmbeddings.pretrained()\
      .setPoolingLayer("word_emb")\
      .setInputCols(["document",'cleanTokens'])\
      .setOutputCol("elmo")

embeddingsSentence = SentenceEmbeddings() \
      .setInputCols(["document", "elmo"]) \
      .setOutputCol("sentence_embeddings") \
      .setPoolingStrategy("AVERAGE")

embeddings_finisher = EmbeddingsFinisher() \
      .setInputCols(["sentence_embeddings"]) \
      .setOutputCols(["finished_sentence_embeddings"]) \
      .setOutputAsVector(True)\
      .setCleanAnnotations(False)

label_stringIdx = StringIndexer(inputCol = "category", outputCol = "label")


nlp_pipeline_elmo = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            elmo_embeddings,
            embeddingsSentence,
            embeddings_finisher,
            label_stringIdx])

nlp_model_elmo = nlp_pipeline_elmo.fit(newsDF)

processed_elmo = nlp_model_elmo.transform(newsDF)

processed_elmo.count()


elmo download started this may take some time.
Approximate size to download 334.1 MB
[OK!]
CPU times: user 355 ms, sys: 65.1 ms, total: 420 ms
Wall time: 43 s


120000

In [9]:
(trainingData, testData) = newsDF.randomSplit([0.7, 0.3], seed = 100)

In [10]:
processed_trainingData = nlp_model_elmo.transform(trainingData)

processed_trainingData.count()

83915

In [11]:
processed_testData = nlp_model_elmo.transform(testData)

processed_testData.count()

36085

In [12]:
processed_trainingData.columns

['category',
 'description',
 'document',
 'token',
 'normalized',
 'cleanTokens',
 'elmo',
 'sentence_embeddings',
 'finished_sentence_embeddings',
 'label']

In [13]:
from pyspark.sql.functions import explode

processed_testData= processed_testData.withColumn("features", explode(processed_testData.finished_sentence_embeddings))

processed_trainingData= processed_trainingData.withColumn("features", explode(processed_trainingData.finished_sentence_embeddings))

In [14]:
from pyspark.sql.functions import udf

@udf("long")
def num_nonzeros(v):
    return v.numNonzeros()

processed_testData = processed_testData.where(num_nonzeros("features") != 0)

In [15]:
%%time

from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=20, regParam=0.3, elasticNetParam=0)

lrModel = lr.fit(processed_trainingData)


CPU times: user 13.8 s, sys: 1.33 s, total: 15.1 s
Wall time: 43min 40s


In [16]:
processed_trainingData.columns

['category',
 'description',
 'document',
 'token',
 'normalized',
 'cleanTokens',
 'elmo',
 'sentence_embeddings',
 'finished_sentence_embeddings',
 'label',
 'features']

In [17]:
predictions = lrModel.transform(processed_testData)

predictions.select("description","category","probability","label","prediction") \
    .orderBy("probability", ascending=False) \
    .show(10, truncate = 30)

+------------------------------+--------+------------------------------+-----+----------+
|                   description|category|                   probability|label|prediction|
+------------------------------+--------+------------------------------+-----+----------+
|" Exxon Mobil Corp. &lt;A H...|Business|[0.9942678089128859,0.00369...|  0.0|       0.0|
|" Exxon Mobil Corp. &lt;A H...|Business|[0.9942678089128859,0.00369...|  0.0|       0.0|
| Falling oil prices and str...|   World|[0.9929108385820713,0.00659...|  3.0|       0.0|
| Discount retailer Dollar G...|Business|[0.9923224927163741,0.00580...|  0.0|       0.0|
|" Halliburton Co. &lt;A HRE...|Business|[0.9916542041897508,0.00488...|  0.0|       0.0|
| Kmart Holding Corporation,...|Business|[0.9910315092004528,0.00713...|  0.0|       0.0|
|US stocks gained on optimis...|Business|[0.9905003947189976,0.00683...|  0.0|       0.0|
| A bankruptcy judge gave US...|   World|[0.990093559326872,0.005094...|  3.0|       0.0|
|The conve

In [18]:
df = predictions.select('description','category','label','prediction').toPandas()

In [19]:
df.shape

(36085, 4)

In [20]:
df.head()

Unnamed: 0,description,category,label,prediction
0,A Colorado assistant store manager at Costco...,Business,0.0,0.0
1,A group led by privately held Colony Capital...,Business,0.0,0.0
2,A group of technology companies Tuesday rene...,Business,0.0,0.0
3,"AMP Ltd., Australia #39;s largest life insur...",Business,0.0,0.0
4,"About 8,000 employees of the federal tax age...",Business,0.0,0.0


In [21]:
from sklearn.metrics import classification_report, accuracy_score

print(classification_report(df.label, df.prediction))
print(accuracy_score(df.label, df.prediction))

              precision    recall  f1-score   support

         0.0       0.83      0.82      0.83      9113
         1.0       0.82      0.82      0.82      8870
         2.0       0.93      0.96      0.94      9023
         3.0       0.88      0.87      0.88      9079

    accuracy                           0.87     36085
   macro avg       0.87      0.87      0.87     36085
weighted avg       0.87      0.87      0.87     36085

0.8673132880698351


## LogReg with Universal Sentence Encoder

In [22]:
useEmbeddings = UniversalSentenceEncoder.pretrained()\
      .setInputCols("document")\
      .setOutputCol("use_embeddings")

tfhub_use download started this may take some time.
Approximate size to download 923.7 MB
[OK!]


In [23]:
document_assembler = DocumentAssembler() \
      .setInputCol("description") \
      .setOutputCol("document")

loaded_useEmbeddings = UniversalSentenceEncoder.load('/root/cache_pretrained/tfhub_use_en_2.4.0_2.4_1587136330099')\
      .setInputCols("document")\
      .setOutputCol("use_embeddings")

embeddings_finisher = EmbeddingsFinisher() \
      .setInputCols(["use_embeddings"]) \
      .setOutputCols(["finished_use_embeddings"]) \
      .setOutputAsVector(True)\
      .setCleanAnnotations(False)

label_stringIdx = StringIndexer(inputCol = "category", outputCol = "label")

use_pipeline = Pipeline(
      stages=[
        document_assembler,
        loaded_useEmbeddings,
        embeddings_finisher,
        label_stringIdx]
      )

use_df = use_pipeline.fit(newsDF).transform(newsDF)

In [24]:
use_df.select('finished_use_embeddings').show(3)

+-----------------------+
|finished_use_embeddings|
+-----------------------+
|   [[0.0441501475870...|
|   [[0.0844451859593...|
|   [[0.0426647253334...|
+-----------------------+
only showing top 3 rows



In [25]:
from pyspark.sql.functions import explode

use_df= use_df.withColumn("features", explode(use_df.finished_use_embeddings))

In [26]:
use_df.show(2)

+--------+--------------------+--------------------+--------------------+-----------------------+-----+--------------------+
|category|         description|            document|      use_embeddings|finished_use_embeddings|label|            features|
+--------+--------------------+--------------------+--------------------+-----------------------+-----+--------------------+
|Business| Short sellers, W...|[{document, 0, 84...|[{sentence_embedd...|   [[0.0441501475870...|  0.0|[0.04415014758706...|
|Business| Private investme...|[{document, 0, 20...|[{sentence_embedd...|   [[0.0844451859593...|  0.0|[0.08444518595933...|
+--------+--------------------+--------------------+--------------------+-----------------------+-----+--------------------+
only showing top 2 rows



In [27]:
# set seed for reproducibility
(trainingData, testData) = use_df.randomSplit([0.7, 0.3], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 83915
Test Dataset Count: 36085


In [28]:
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import pandas as pd

from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=20, regParam=0.3, elasticNetParam=0)

lrModel = lr.fit(trainingData)

predictions = lrModel.transform(testData)

predictions.filter(predictions['prediction'] == 0) \
    .select("description","category","probability","label","prediction") \
    .orderBy("probability", ascending=False) \
    .show(n = 10, truncate = 30)


+------------------------------+--------+------------------------------+-----+----------+
|                   description|category|                   probability|label|prediction|
+------------------------------+--------+------------------------------+-----+----------+
|Amid talk of a possible liq...|Business|[0.983377407597205,0.006717...|  0.0|       0.0|
|" U.S. investment bank Merr...|Business|[0.9823928340543989,0.00874...|  0.0|       0.0|
| Safeway Inc. , the third l...|Business|[0.9822568234049623,0.00963...|  0.0|       0.0|
|" Stock futures pointed to ...|Business|[0.981956069993429,0.008847...|  0.0|       0.0|
|Prudential Financial Inc., ...|Business|[0.9809333303052401,0.01135...|  0.0|       0.0|
| Wall Street was expected t...|Business|[0.9808441091350967,0.00762...|  0.0|       0.0|
|Financial services company ...|Business|[0.9801851472566468,0.00980...|  0.0|       0.0|
|" U.S. investment bank Morg...|Business|[0.9801017321319347,0.00483...|  0.0|       0.0|
| The U.S.

In [29]:
df = predictions.select('description','category','label','prediction').toPandas()
#df['result'] = df['result'].apply(lambda x: x[0])

In [30]:
df.head()

Unnamed: 0,description,category,label,prediction
0,A Colorado assistant store manager at Costco...,Business,0.0,0.0
1,A group led by privately held Colony Capital...,Business,0.0,0.0
2,A group of technology companies Tuesday rene...,Business,0.0,0.0
3,"AMP Ltd., Australia #39;s largest life insur...",Business,0.0,0.0
4,"About 8,000 employees of the federal tax age...",Business,0.0,0.0


In [31]:
print(classification_report(df.label, df.prediction))
print(accuracy_score(df.label, df.prediction))

              precision    recall  f1-score   support

         0.0       0.84      0.83      0.83      9113
         1.0       0.84      0.84      0.84      8870
         2.0       0.95      0.97      0.96      9023
         3.0       0.90      0.88      0.89      9079

    accuracy                           0.88     36085
   macro avg       0.88      0.88      0.88     36085
weighted avg       0.88      0.88      0.88     36085

0.8827490647083276


### train on entire dataset

In [32]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=20, regParam=0.3, elasticNetParam=0)

lrModel = lr.fit(use_df)

In [33]:
test_df = spark.read.csv("./news_category_test.csv",header=True)

In [34]:
test_df = use_pipeline.fit(test_df).transform(test_df)

In [35]:
test_df= test_df.withColumn("features", explode(test_df.finished_use_embeddings))

In [36]:
test_df.show(2)

+--------+--------------------+--------------------+--------------------+-----------------------+-----+--------------------+
|category|         description|            document|      use_embeddings|finished_use_embeddings|label|            features|
+--------+--------------------+--------------------+--------------------+-----------------------+-----+--------------------+
|Business|Unions representi...|[{document, 0, 12...|[{sentence_embedd...|   [[0.0129975611343...|  0.0|[0.01299756113439...|
|Sci/Tech| TORONTO, Canada ...|[{document, 0, 22...|[{sentence_embedd...|   [[0.0019998808857...|  1.0|[0.00199988088570...|
+--------+--------------------+--------------------+--------------------+-----------------------+-----+--------------------+
only showing top 2 rows



In [37]:
from pyspark.sql.functions import col

test_df.groupBy("category","label") \
    .count() \
    .orderBy(col("count").desc()) \
    .show()

+--------+-----+-----+
|category|label|count|
+--------+-----+-----+
|Sci/Tech|  1.0| 1900|
|  Sports|  2.0| 1900|
|   World|  3.0| 1900|
|Business|  0.0| 1900|
+--------+-----+-----+



In [38]:
predictions = lrModel.transform(test_df)

In [39]:
df = predictions.select('description','category','label','prediction').toPandas()

In [40]:
df['label'] = df.category.replace({'World':3.0,
                    'Sports':2.0,
                    'Business':0.0,
                    'Sci/Tech':1.0})

In [41]:
df.head()

Unnamed: 0,description,category,label,prediction
0,Unions representing workers at Turner Newall...,Business,0.0,0.0
1,"TORONTO, Canada A second team of rocketeer...",Sci/Tech,1.0,1.0
2,A company founded by a chemistry researcher a...,Sci/Tech,1.0,1.0
3,It's barely dawn when Mike Fitzpatrick starts...,Sci/Tech,1.0,1.0
4,Southern California's smog fighting agency we...,Sci/Tech,1.0,0.0


In [42]:
from sklearn.metrics import classification_report, accuracy_score

print(classification_report(df.label, df.prediction))
print(accuracy_score(df.label, df.prediction))

              precision    recall  f1-score   support

         0.0       0.83      0.83      0.83      1900
         1.0       0.84      0.85      0.85      1900
         2.0       0.95      0.97      0.96      1900
         3.0       0.90      0.87      0.89      1900

    accuracy                           0.88      7600
   macro avg       0.88      0.88      0.88      7600
weighted avg       0.88      0.88      0.88      7600

0.8801315789473684


# ClassifierDL

In [43]:
# actual content is inside description column
document = DocumentAssembler()\
    .setInputCol("description")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.load('/root/cache_pretrained/tfhub_use_en_2.4.0_2.4_1587136330099')\
      .setInputCols("document")\
      .setOutputCol("sentence_embeddings")

# the classes/labels/categories are in category column
classsifierdl = ClassifierDLApproach()\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("class")\
    .setLabelColumn("category")\
    .setMaxEpochs(5)\
    .setEnableOutputLogs(True)

pipeline = Pipeline(
    stages = [
        document,
        use,
        classsifierdl
    ])

In [44]:
# set seed for reproducibility
(trainingData, testData) = newsDF.randomSplit([0.7, 0.3], seed = 100)
print("Training Dataset Count: " + str(trainingData.count()))
print("Test Dataset Count: " + str(testData.count()))

Training Dataset Count: 83915
Test Dataset Count: 36085


In [45]:
pipelineModel = pipeline.fit(trainingData)

In [46]:
pipelineModel.stages[2].write().overwrite().save('classifierDL_model_5e')

In [47]:
from sklearn.metrics import classification_report, accuracy_score

df = pipelineModel.transform(testData).select('category','description',"class.result").toPandas()

df['result'] = df['result'].apply(lambda x: x[0])

print(classification_report(df.category, df.result))
print(accuracy_score(df.category, df.result))

              precision    recall  f1-score   support

    Business       0.84      0.86      0.85      9113
    Sci/Tech       0.85      0.86      0.86      8870
      Sports       0.95      0.98      0.97      9023
       World       0.93      0.87      0.90      9079

    accuracy                           0.89     36085
   macro avg       0.90      0.89      0.89     36085
weighted avg       0.90      0.89      0.89     36085

0.8946376610780102


## Loading the trained classifier from disk

In [48]:
import sparknlp
sparknlp.__path__

['/usr/local/lib/python3.10/dist-packages/sparknlp']

In [49]:
trainDataset = spark.read \
      .option("header", True) \
      .csv("news_category_train.csv")

In [50]:
trainDataset.count()

120000

In [51]:
trainingData.count()

83915

In [52]:
document = DocumentAssembler()\
    .setInputCol("description")\
    .setOutputCol("document")

use = UniversalSentenceEncoder.load('/root/cache_pretrained/tfhub_use_en_2.4.0_2.4_1587136330099')\
      .setInputCols("document")\
      .setOutputCol("sentence_embeddings")

classsifierdlmodel = ClassifierDLModel.load('classifierDL_model_5e')

pipeline = Pipeline(
    stages = [
        document,
        use,
        classsifierdlmodel
    ])

In [53]:
pipeline.fit(testData.limit(10)).transform(testData.limit(10)).select('category','description',"class.result").show(10, truncate=50)

+--------+--------------------------------------------------+----------+
|category|                                       description|    result|
+--------+--------------------------------------------------+----------+
|Business|  A Colorado assistant store manager at Costco ...|[Business]|
|Business|  A group led by privately held Colony Capital ...|[Business]|
|Business|  A group of technology companies Tuesday renew...|[Business]|
|Business|  AMP Ltd., Australia #39;s largest life insure...|[Business]|
|Business|  About 8,000 employees of the federal tax agen...|[Business]|
|Business|  After winning a battle to keep his job earlie...|[Business]|
|Business|  Air Canada creditors including a General Elec...|[Business]|
|Business|  Americans paid their credit card bills on tim...|[Business]|
|Business|  Andrew Mohl (left), chief executive of AMP Lt...|[Business]|
|Business|  BEIJING, Sept.12 -- A senior United States tr...|   [World]|
+--------+-----------------------------------------

In [54]:
lm = LightPipeline(pipeline.fit(spark.createDataFrame([[""]]).toDF("text")))
lm.annotate('In its first two years, the UK dedicated card companies have surge')

{'document': ['In its first two years, the UK dedicated card companies have surge'],
 'sentence_embeddings': ['In its first two years, the UK dedicated card companies have surge'],
 'class': ['Sci/Tech']}

In [55]:
text='''
Fearing the fate of Italy, the centre-right government has threatened to be merciless with those who flout tough restrictions. As of Wednesday it will also include all shops being closed across Greece, with the exception of supermarkets. Banks, pharmacies, pet-stores, mobile phone stores, opticians, bakers, mini-markets, couriers and food delivery outlets are among the few that will also be allowed to remain open.
'''

In [56]:
lm = LightPipeline(pipeline.fit(spark.createDataFrame([[""]]).toDF("text")))

lm.annotate(text)

{'document': ['\nFearing the fate of Italy, the centre-right government has threatened to be merciless with those who flout tough restrictions. As of Wednesday it will also include all shops being closed across Greece, with the exception of supermarkets. Banks, pharmacies, pet-stores, mobile phone stores, opticians, bakers, mini-markets, couriers and food delivery outlets are among the few that will also be allowed to remain open.\n'],
 'sentence_embeddings': ['\nFearing the fate of Italy, the centre-right government has threatened to be merciless with those who flout tough restrictions. As of Wednesday it will also include all shops being closed across Greece, with the exception of supermarkets. Banks, pharmacies, pet-stores, mobile phone stores, opticians, bakers, mini-markets, couriers and food delivery outlets are among the few that will also be allowed to remain open.\n'],
 'class': ['Business']}

# Classifier DL + Glove + Basic text processing

In [57]:
document = DocumentAssembler()\
    .setInputCol("description")\
    .setOutputCol("document")

tokenizer = Tokenizer() \
      .setInputCols(["document"]) \
      .setOutputCol("token")

lemma = LemmatizerModel.pretrained('lemma_antbnc') \
      .setInputCols(["token"]) \
      .setOutputCol("lemma")

glove_embeddings = WordEmbeddingsModel().pretrained() \
      .setInputCols(["document",'lemma'])\
      .setOutputCol("embeddings")\
      .setCaseSensitive(False)

lemma_pipeline = Pipeline(
    stages=[document,
            tokenizer,
            lemma,
            glove_embeddings])

lemma_antbnc download started this may take some time.
Approximate size to download 907.6 KB
[OK!]
glove_100d download started this may take some time.
Approximate size to download 145.3 MB
[OK!]


In [58]:
lemma_pipeline.fit(trainingData.limit(1000)).transform(trainingData.limit(1000)).show(truncate=30)

+--------+------------------------------+------------------------------+------------------------------+------------------------------+------------------------------+
|category|                   description|                      document|                         token|                         lemma|                    embeddings|
+--------+------------------------------+------------------------------+------------------------------+------------------------------+------------------------------+
|Business|    The credit rating of th...|[{document, 0, 164,     The...|[{token, 4, 6, The, {senten...|[{token, 4, 6, The, {senten...|[{word_embeddings, 4, 6, Th...|
|Business|  ''The Oprah Winfrey Show ...|[{document, 0, 131,   ''The...|[{token, 2, 3, '', {sentenc...|[{token, 2, 3, '', {sentenc...|[{word_embeddings, 2, 3, ''...|
|Business|  A  $120 million fine levi...|[{document, 0, 278,   A  $1...|[{token, 2, 2, A, {sentence...|[{token, 2, 2, A, {sentence...|[{word_embeddings, 2, 2, A,...|
|Bus

In [59]:
document_assembler = DocumentAssembler() \
      .setInputCol("description") \
      .setOutputCol("document")

tokenizer = Tokenizer() \
      .setInputCols(["document"]) \
      .setOutputCol("token")

normalizer = Normalizer() \
      .setInputCols(["token"]) \
      .setOutputCol("normalized")

stopwords_cleaner = StopWordsCleaner()\
      .setInputCols("normalized")\
      .setOutputCol("cleanTokens")\
      .setCaseSensitive(False)

lemma = LemmatizerModel.pretrained('lemma_antbnc') \
      .setInputCols(["cleanTokens"]) \
      .setOutputCol("lemma")

glove_embeddings = WordEmbeddingsModel().pretrained() \
      .setInputCols(["document",'lemma'])\
      .setOutputCol("embeddings")\
      .setCaseSensitive(False)

embeddingsSentence = SentenceEmbeddings() \
      .setInputCols(["document", "embeddings"]) \
      .setOutputCol("sentence_embeddings") \
      .setPoolingStrategy("AVERAGE")

classsifierdl = ClassifierDLApproach()\
      .setInputCols(["sentence_embeddings"])\
      .setOutputCol("class")\
      .setLabelColumn("category")\
      .setMaxEpochs(5)\
      .setEnableOutputLogs(True)

clf_pipeline = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            lemma,
            glove_embeddings,
            embeddingsSentence,
            classsifierdl])

lemma_antbnc download started this may take some time.
Approximate size to download 907.6 KB
[OK!]
glove_100d download started this may take some time.
Approximate size to download 145.3 MB
[OK!]


In [60]:
!rm -rf classifier_dl_pipeline_glove

In [61]:
clf_pipeline.save('classifier_dl_pipeline_glove')

In [62]:
clf_pipelineModel = clf_pipeline.fit(trainingData)

In [63]:
df = clf_pipelineModel.transform(testData).select('category','description',"class.result").toPandas()

df.head()

Unnamed: 0,category,description,result
0,Business,A Colorado assistant store manager at Costco...,[Business]
1,Business,A group led by privately held Colony Capital...,[Business]
2,Business,A group of technology companies Tuesday rene...,[Business]
3,Business,"AMP Ltd., Australia #39;s largest life insur...",[Business]
4,Business,"About 8,000 employees of the federal tax age...",[Business]


In [64]:
from sklearn.metrics import classification_report, accuracy_score

df['result'] = df['result'].apply(lambda x: x[0])

print(classification_report(df.category, df.result))

print(accuracy_score(df.category, df.result))

              precision    recall  f1-score   support

    Business       0.85      0.83      0.84      9113
    Sci/Tech       0.83      0.86      0.85      8870
      Sports       0.95      0.97      0.96      9023
       World       0.91      0.88      0.89      9079

    accuracy                           0.88     36085
   macro avg       0.88      0.88      0.88     36085
weighted avg       0.88      0.88      0.88     36085

0.88363585977553


In [65]:
import pandas as pd

In [66]:
news_df = newsDF.toPandas()

In [67]:
news_df.head()

Unnamed: 0,category,description
0,Business,"Short sellers, Wall Street's dwindling band o..."
1,Business,"Private investment firm Carlyle Group, which ..."
2,Business,Soaring crude prices plus worries about the e...
3,Business,Authorities have halted oil export flows from...
4,Business,"Tearaway world oil prices, toppling records a..."


In [68]:
news_df.to_csv('news_dataset.csv', index=False)

In [69]:
document_assembler = DocumentAssembler() \
      .setInputCol("description") \
      .setOutputCol("document")

tokenizer = Tokenizer() \
      .setInputCols(["document"]) \
      .setOutputCol("token")

normalizer = Normalizer() \
      .setInputCols(["token"]) \
      .setOutputCol("normalized")

stopwords_cleaner = StopWordsCleaner()\
      .setInputCols("normalized")\
      .setOutputCol("cleanTokens")\
      .setCaseSensitive(False)

lemma = LemmatizerModel.pretrained('lemma_antbnc') \
      .setInputCols(["cleanTokens"]) \
      .setOutputCol("lemma")

glove_embeddings = WordEmbeddingsModel().pretrained() \
      .setInputCols(["document",'lemma'])\
      .setOutputCol("embeddings")\
      .setCaseSensitive(False)

txt_pipeline = Pipeline(
    stages=[document_assembler,
            tokenizer,
            normalizer,
            stopwords_cleaner,
            lemma,
            glove_embeddings,
            embeddingsSentence])

lemma_antbnc download started this may take some time.
Approximate size to download 907.6 KB
[OK!]
glove_100d download started this may take some time.
Approximate size to download 145.3 MB
[OK!]


In [70]:
txt_pipelineModel = txt_pipeline.fit(testData.limit(1))

In [71]:
txt_pipelineModel.save('text_prep_pipeline_glove')