# Pyspark Spam Classification

In [1]:
!pip3 install pyspark

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


##Import and Connect SparkSession


In [2]:
import numpy as np 
import pandas as pd 
import os
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.sql.functions import monotonically_increasing_id 
from pyspark.sql.functions import regexp_replace
from pyspark.ml.feature import Tokenizer
from pyspark.ml.feature import StopWordsRemover, HashingTF, IDF

spark = SparkSession.builder \
        .master('local[*]') \
        .appName('first_spark_application') \
        .getOrCreate()

In [3]:
# text_message = spark.read.csv("../input/spam-text-message-classification/SPAM text message 20170820 - Data.csv",header=True)
text_message = spark.read.csv("SPAM text message 20170820 - Data.csv",header=True,inferSchema=True)
text_message.show(truncate=False)

+--------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Category|Message                                                                                                                                                                                             |
+--------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|ham     |Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...                                                                                     |
|ham     |Ok lar... Joking wif u oni...                                                                                                                                 

In [4]:
text_message.dtypes

[('Category', 'string'), ('Message', 'string')]

## Data Preparation

In [5]:
text_message.filter('Message IS NULL').count()

0

In [6]:
text_message.filter('Category IS NULL').count()

0

In [7]:
text_message.select('Category').distinct().show(truncate=False)

+--------------------------------------------------------------------------------------------------------------------------+
|Category                                                                                                                  |
+--------------------------------------------------------------------------------------------------------------------------+
|ham\tHI BABE UAWAKE?FEELLIKW SHIT.JUSTFOUND OUT VIA ALETTER THATMUM GOTMARRIED 4thNOV.BEHIND OURBACKS  FUCKINNICE!SELFISH|
|ham                                                                                                                       |
|spam                                                                                                                      |
|ham\tYeah                                                                                                                 |
+--------------------------------------------------------------------------------------------------------------------------+


In [8]:
text_message = text_message.filter( (text_message.Category  == "ham") | (text_message.Category  == "spam") ) 

In [9]:
# change Category to one hot label
tag = StringIndexer(inputCol = 'Category' , outputCol = 'label' )
tag = tag.fit(text_message)
msg_one_hot = tag.transform(text_message)

In [10]:
# delete category column and add index
msg_one_hot = msg_one_hot.drop('Category')
msg_one_hot = msg_one_hot.select("*").withColumn("id", monotonically_increasing_id())
msg_one_hot.show()

+--------------------+-----+---+
|             Message|label| id|
+--------------------+-----+---+
|Go until jurong p...|  0.0|  0|
|Ok lar... Joking ...|  0.0|  1|
|Free entry in 2 a...|  1.0|  2|
|U dun say so earl...|  0.0|  3|
|Nah I don't think...|  0.0|  4|
|FreeMsg Hey there...|  1.0|  5|
|Even my brother i...|  0.0|  6|
|As per your reque...|  0.0|  7|
|WINNER!! As a val...|  1.0|  8|
|Had your mobile 1...|  1.0|  9|
|I'm gonna be home...|  0.0| 10|
|SIX chances to wi...|  1.0| 11|
|URGENT! You have ...|  1.0| 12|
|I've been searchi...|  0.0| 13|
|I HAVE A DATE ON ...|  0.0| 14|
|XXXMobileMovieClu...|  1.0| 15|
|Oh k...i'm watchi...|  0.0| 16|
|Eh u remember how...|  0.0| 17|
|Fine if thats th...|  0.0| 18|
|England v Macedon...|  1.0| 19|
+--------------------+-----+---+
only showing top 20 rows



In [11]:
#Remove punctuation and numbers
wrangled = msg_one_hot.withColumn('message', regexp_replace(msg_one_hot.Message, '[_():;,.!?\\-\\d]', ' '))
# wrangled = wrangled.drop('Message')
wrangled = wrangled.withColumn('message', regexp_replace(wrangled.message, ' +', ' '))
wrangled.show(4,truncate=False)

+-------------------------------------------------------------------------------------------------------------------------------+-----+---+
|message                                                                                                                        |label|id |
+-------------------------------------------------------------------------------------------------------------------------------+-----+---+
|Go until jurong point crazy Available only in bugis n great world la e buffet Cine there got amore wat                         |0.0  |0  |
|Ok lar Joking wif u oni                                                                                                        |0.0  |1  |
|Free entry in a wkly comp to win FA Cup final tkts st May Text FA to to receive entry question std txt rate T&C's apply over 's|1.0  |2  |
|U dun say so early hor U c already then say                                                                                    |0.0  |3  |
+-------------------

In [12]:
# Tokenize Message
wrangled = Tokenizer(inputCol='message', outputCol='words').transform(wrangled)
wrangled.show(4, truncate=False)

+-------------------------------------------------------------------------------------------------------------------------------+-----+---+------------------------------------------------------------------------------------------------------------------------------------------------------------+
|message                                                                                                                        |label|id |words                                                                                                                                                       |
+-------------------------------------------------------------------------------------------------------------------------------+-----+---+------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Go until jurong point crazy Available only in bugis n great world la e buffet Cine there got amore wat      

In [13]:
# Remove stop words.
# print(StopWordsRemover().getStopWords())
wrangled = StopWordsRemover(inputCol="words", outputCol="terms").transform(wrangled)
wrangled.select("words","terms").show(4, truncate=False)

+------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+
|words                                                                                                                                                       |terms                                                                                                                              |
+------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------+
|[go, until, jurong, point, crazy, available, only, in, bugis, n, great, world, la, e, buffet, cine, there, got, amore, wat]   

In [14]:
# Apply the hashing trick
wrangled = HashingTF(inputCol="terms", outputCol="hash", numFeatures=1024).transform(wrangled)
wrangled.select("words","hash").show(4, truncate=False)

+------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|words                                                                                                                                                       |hash                                                                                                                                                                |
+------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[go, until, jurong, point, 

In [15]:
# Convert hashed symbols to TF-IDF
tf_idf = IDF(inputCol="hash", outputCol="features").fit(wrangled).transform(wrangled)
tf_idf.select("hash","features").show(4, truncate=False)

+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|hash                                                                                                                                                                |features                                                                                                                                                                                                                      

##Training Model

In [16]:
# Split the data into training and testing sets
msg_train, msg_test = tf_idf.randomSplit([0.8,0.2], seed = 42)

### LogisticRegression

In [17]:
from pyspark.ml.classification import LogisticRegression
# Fit a Logistic Regression model to the training data
logistic = LogisticRegression(regParam=0.2).fit(msg_train)

# Make predictions on the testing data
pred_log = logistic.transform(msg_test)

# Create a confusion matrix, comparing predictions to known labels
pred_log.groupBy('label', 'prediction').count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0|  104|
|  0.0|       1.0|    2|
|  1.0|       0.0|   42|
|  0.0|       0.0|  923|
+-----+----------+-----+



In [18]:
# Calculate the elements of the confusion matrix
TN = pred_log.filter('prediction = 0 AND label = prediction').count()
TP = pred_log.filter('prediction = 1 AND label = prediction').count()
FN = pred_log.filter('prediction = 0 AND label != prediction').count()
FP = pred_log.filter('prediction = 1 AND label != prediction').count()

# Accuracy measures the proportion of correct predictions
acc_log = (TN + TP) / (TN + TP + FN + FP)
acc_log = float("{:.2f}".format(acc_log))
print(acc_log)

0.96


In [19]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator

# Calculate precision and recall
precision_log = float("{:.2f}".format(TP/(TP+FP)))
recall_log = float("{:.2f}".format(TP/(TP+FN)))
F1_log = float("{:.2f}".format(2*(precision_log*recall_log)/(precision_log+recall_log)))
print('precision = {:.2f}\nrecall    = {:.2f}\nF1-score = {:.2f}'.format(precision_log, recall_log,F1_log))

# Find weighted precision
multi_evaluator = MulticlassClassificationEvaluator()
weighted_precision = multi_evaluator.evaluate(pred_log, {multi_evaluator.metricName: "weightedPrecision"})
print('weighted precision = {:.2f}'.format(weighted_precision))

# Find AUC
binary_evaluator = BinaryClassificationEvaluator()
auc = binary_evaluator.evaluate(pred_log, {binary_evaluator.metricName: "areaUnderROC"})
print('AUC = {:.2f}'.format(auc))

precision = 0.98
recall    = 0.71
F1-score = 0.82
weighted precision = 0.96
AUC = 0.99


In [20]:
from pyspark.ml.classification import DecisionTreeClassifier
# Fit a Decision Tree model to the training data
dtree = DecisionTreeClassifier().fit(msg_train)

# Make predictions on the testing data
pred_dtree = dtree.transform(msg_test)

# Create a confusion matrix, comparing predictions to known labels
pred_dtree.groupBy('label', 'prediction').count().show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0|   75|
|  0.0|       1.0|   14|
|  1.0|       0.0|   71|
|  0.0|       0.0|  911|
+-----+----------+-----+



In [21]:
# Calculate the elements of the confusion matrix
TN = pred_dtree.filter('prediction = 0 AND label = prediction').count()
TP = pred_dtree.filter('prediction = 1 AND label = prediction').count()
FN = pred_dtree.filter('prediction = 0 AND label != prediction').count()
FP = pred_dtree.filter('prediction = 1 AND label != prediction').count()

# Accuracy measures the proportion of correct predictions
acc_dtree = float("{:.2f}".format((TN + TP) / (TN + TP + FN + FP)))
print(acc_dtree)

0.92


In [22]:
# Calculate precision and recall
precision_dtree = float("{:.2f}".format(TP/(TP+FP)))
recall_dtree = float("{:.2f}".format(TP/(TP+FN)))
F1_dtree = float("{:.2f}".format(2*(precision_dtree*recall_dtree)/(precision_dtree+recall_dtree)))
print('precision = {:.2f}\nrecall    = {:.2f}\nF1-score = {:.2f}'.format(precision_dtree, recall_dtree,F1_dtree))

precision = 0.84
recall    = 0.51
F1-score = 0.63


In [23]:
eval_col = ["Model name","Accuracy","Precision","Recall","F1-score"]
eval_models=[("Logistic Regression",acc_log,precision_log,recall_log,F1_log),
             ("Decision Tree",acc_dtree,precision_dtree,recall_dtree,F1_dtree)]
evalDF = spark.createDataFrame(data=eval_models, schema = eval_col)
evalDF.printSchema()
evalDF.show(truncate=False)

root
 |-- Model name: string (nullable = true)
 |-- Accuracy: double (nullable = true)
 |-- Precision: double (nullable = true)
 |-- Recall: double (nullable = true)
 |-- F1-score: double (nullable = true)

+-------------------+--------+---------+------+--------+
|Model name         |Accuracy|Precision|Recall|F1-score|
+-------------------+--------+---------+------+--------+
|Logistic Regression|0.96    |0.98     |0.71  |0.82    |
|Decision Tree      |0.92    |0.84     |0.51  |0.63    |
+-------------------+--------+---------+------+--------+

