# Background
In this project we will build a spam filter! 

We'll use a classic dataset for this, the SMS Spam Collection from the UCI Machine Learning Repository: https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection


Let's start by importing the necessary libraries

In [42]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import (Tokenizer, RegexTokenizer, StopWordsRemover, NGram, HashingTF,
                                IDF, StringIndexer, CountVectorizer, VectorAssembler)
from pyspark.sql.functions import col, udf, length
from pyspark.sql.types import IntegerType
from pyspark.ml import Pipeline
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator,BinaryClassificationEvaluator

In [7]:
spark = SparkSession.builder.appName('nlp').getOrCreate()

Next we read in the data, which is in a tab-delimited format. Spark's read.csv function allows you to specify the separator, which in this case is \t. The file has no header row, so we set header=False.

In [8]:
data = spark.read.csv('/user/a208669/SMSSpamCollection',inferSchema=True,header=False,sep='\t')
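Spark does the parsing for us, but the file layout itself is simple. As a rough illustration only, here is how the same label-tab-message layout parses in plain Python. The two sample rows are made up for this sketch and are not taken from the dataset.

```python
import csv
import io

# Two made-up rows in the same layout as the file: label, tab, message text.
sample = "ham\tSee you at lunch\nspam\tYou won a prize, call now\n"

# delimiter="\t" plays the role of sep='\t' in spark.read.csv;
# QUOTE_NONE keeps any quote characters in a message as literal text.
reader = csv.reader(io.StringIO(sample), delimiter="\t", quoting=csv.QUOTE_NONE)
rows = [tuple(fields) for fields in reader]

for label, text in rows:
    print(label, "->", text)
```

Each row comes back as two strings, which matches the two string columns (_c0 and _c1) that Spark reports in the schema.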

In [9]:
data.printSchema()

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)



In [26]:
data.count()

5574

In [19]:
data.show()

+----+--------------------+
| _c0|                 _c1|
+----+--------------------+
| ham|Go until jurong p...|
| ham|Ok lar... Joking ...|
|spam|Free entry in 2 a...|
| ham|U dun say so earl...|
| ham|Nah I don't think...|
|spam|FreeMsg Hey there...|
| ham|Even my brother i...|
| ham|As per your reque...|
|spam|WINNER!! As a val...|
|spam|Had your mobile 1...|
| ham|I'm gonna be home...|
|spam|SIX chances to wi...|
|spam|URGENT! You have ...|
| ham|I've been searchi...|
| ham|I HAVE A DATE ON ...|
|spam|XXXMobileMovieClu...|
| ham|Oh k...i'm watchi...|
| ham|Eh u remember how...|
| ham|Fine if thats th...|
|spam|England v Macedon...|
+----+--------------------+
only showing top 20 rows



First, let's rename the fields to more descriptive names

In [10]:
data = data.withColumnRenamed('_c0','class').withColumnRenamed('_c1','text')

In [11]:
data.groupBy('class').count().show()

+-----+-----+
|class|count|
+-----+-----+
|  ham| 4827|
| spam|  747|
+-----+-----+



In [20]:
print('Percentage spam: {0:.2f}%'.format(747/5574*100))

Percentage spam: 13.40%
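The percentage above is computed from hard-coded counts. A small sketch of deriving it from the class counts instead, so that it stays correct if the data changes. The dictionary below simply copies the groupBy output; in the notebook you would collect the counts from the DataFrame.

```python
# Class counts copied from the groupBy('class').count() output.
class_counts = {"ham": 4827, "spam": 747}

total = sum(class_counts.values())
spam_pct = class_counts["spam"] / total * 100

print("Total messages: {0}".format(total))
print("Percentage spam: {0:.2f}%".format(spam_pct))
```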


We'll also create a new field containing the length of the text message

In [22]:
data = data.withColumn('length',length(data['text']))

In [23]:
data.show()

+-----+--------------------+------+
|class|                text|length|
+-----+--------------------+------+
|  ham|Go until jurong p...|   111|
|  ham|Ok lar... Joking ...|    29|
| spam|Free entry in 2 a...|   155|
|  ham|U dun say so earl...|    49|
|  ham|Nah I don't think...|    61|
| spam|FreeMsg Hey there...|   147|
|  ham|Even my brother i...|    77|
|  ham|As per your reque...|   160|
| spam|WINNER!! As a val...|   157|
| spam|Had your mobile 1...|   154|
|  ham|I'm gonna be home...|   109|
| spam|SIX chances to wi...|   136|
| spam|URGENT! You have ...|   155|
|  ham|I've been searchi...|   196|
|  ham|I HAVE A DATE ON ...|    35|
| spam|XXXMobileMovieClu...|   149|
|  ham|Oh k...i'm watchi...|    26|
|  ham|Eh u remember how...|    81|
|  ham|Fine if thats th...|    56|
| spam|England v Macedon...|   155|
+-----+--------------------+------+
only showing top 20 rows



Let's see if there is a difference in the length of spam vs ham messages

In [25]:
data.groupBy("class").mean().show()

+-----+-----------------+
|class|      avg(length)|
+-----+-----------------+
|  ham|71.45431945307645|
| spam|138.6706827309237|
+-----+-----------------+



The length field looks predictive: on average, spam messages are almost twice as long as ham messages. Next we'll set up a pipeline with the following stages:
* Tokenizer - Convert the message to lower case and split it by whitespace into a list of words
* StopWordsRemover - Stop words are words which should be excluded from the input, typically because the words appear frequently and don’t carry as much meaning. StopWordsRemover takes as input a sequence of strings (e.g. the output of a Tokenizer) and drops all the stop words from the input sequences. The list of stopwords is specified by the stopWords parameter. Default stop words for some languages are accessible by calling StopWordsRemover.loadDefaultStopWords(language)
* CountVectorizer - Count the number of times each word appears in the message
* IDF - Compute the Inverse Document Frequency (IDF) over the collection of messages
* StringIndexer - Convert the class field to a numeric field so that it can be used in the Naive Bayes model
* VectorAssembler - Assemble the feature columns into a single feature vector
* NaiveBayes - Specify that a Naive Bayes model will be fitted

In [27]:
tokenizer = Tokenizer(inputCol='text', outputCol='token_text')
stop_remove = StopWordsRemover(inputCol='token_text', outputCol='stop_token')
count_vec = CountVectorizer(inputCol='stop_token', outputCol='c_vec')
idf = IDF(inputCol='c_vec', outputCol='tf_idf')
ham_spam_to_numeric = StringIndexer(inputCol='class', outputCol='label')
assembler = VectorAssembler(inputCols=['tf_idf', 'length'], outputCol='features')
nb_model = NaiveBayes(featuresCol='features', labelCol='label')
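To make the text-processing stages more concrete, here is a plain Python approximation of what they compute on a toy corpus. This is a sketch and not Spark's implementation: the messages and the short stop-word list are made up, the vocabulary is sorted alphabetically for readability (CountVectorizer orders terms by frequency), and the IDF weight uses the smoothed formula log((m + 1) / (df + 1)) given in the Spark ML documentation, where m is the number of documents and df the number of documents containing the term.

```python
import math
from collections import Counter

# Made-up messages standing in for the 'text' column.
messages = [
    "Win a free prize now",
    "Are you free for lunch",
    "Call now to claim your prize",
]

# Tiny illustrative stop-word list; Spark's default English list is much longer.
STOP_WORDS = {"a", "are", "you", "for", "to", "your"}

def tokenize(text):
    # Roughly what Tokenizer does: lower-case, then split on whitespace.
    return text.lower().split()

def remove_stop_words(tokens):
    return [t for t in tokens if t not in STOP_WORDS]

docs = [remove_stop_words(tokenize(m)) for m in messages]

# Vocabulary and per-message term counts (the CountVectorizer step).
vocab = sorted({t for d in docs for t in d})
tf = [[Counter(d)[term] for term in vocab] for d in docs]

# Document frequency and smoothed IDF weight per vocabulary term.
n_docs = len(docs)
doc_freq = [sum(1 for row in tf if row[j] > 0) for j in range(len(vocab))]
idf = [math.log((n_docs + 1) / (df + 1)) for df in doc_freq]

# TF-IDF, then append the message length (the VectorAssembler step).
tf_idf = [[count * weight for count, weight in zip(row, idf)] for row in tf]
features = [row + [float(len(m))] for row, m in zip(tf_idf, messages)]

for term, df, weight in zip(vocab, doc_freq, idf):
    print("{0:>6}  df={1}  idf={2:.3f}".format(term, df, weight))
```

Words that occur in more messages (here free, now and prize) receive a smaller IDF weight than words that occur in only one, which is the down-weighting the IDF stage is there to provide.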

In [28]:
pipeline = Pipeline(stages=[ham_spam_to_numeric,
                            tokenizer,
                            stop_remove,
                            count_vec,
                            idf,
                            assembler,
                            nb_model])

First we split the data into a training and test set

In [29]:
train, test = data.randomSplit([0.7,0.3],seed=78442)
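Note that randomSplit treats the weights as sampling probabilities, so the resulting sizes are only approximately 70/30. Conceptually it behaves like the plain Python sketch below, where each row is assigned independently using a seeded random draw. The helper name random_split is hypothetical, and Spark's actual per-partition sampling differs in detail.

```python
import random

def random_split(rows, train_fraction, seed):
    """Assign each row to train or test with an independent seeded draw."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        if rng.random() < train_fraction:
            train.append(row)
        else:
            test.append(row)
    return train, test

rows = list(range(1000))  # stand-in for the rows of the DataFrame
train, test = random_split(rows, 0.7, seed=78442)

print("train: {0}, test: {1}".format(len(train), len(test)))
```

Fixing the seed makes the split reproducible, which is why the notebook passes seed=78442.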

Next we can fit the Naive Bayes model

In [30]:
spam_detector = pipeline.fit(train)

And see how well it performs on the test set

In [32]:
test_results = spam_detector.transform(test)

In [33]:
test_results.printSchema()

root
 |-- class: string (nullable = true)
 |-- text: string (nullable = true)
 |-- length: integer (nullable = true)
 |-- label: double (nullable = true)
 |-- token_text: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- stop_token: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- c_vec: vector (nullable = true)
 |-- tf_idf: vector (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = true)



In [60]:
test_results.select('class','label','prediction').show()

+-----+-----+----------+
|class|label|prediction|
+-----+-----+----------+
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       0.0|
|  ham|  0.0|       1.0|
+-----+-----+----------+
only showing top 20 rows



First we'll calculate the accuracy rate

In [37]:
acc_eval = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction",metricName="accuracy")


In [35]:
acc = acc_eval.evaluate(test_results)

In [40]:
print('Accuracy rate: {0:.2f}%'.format(acc*100))

Accuracy rate: 97.74%


The confusion matrix below shows that the model is not biased toward classifying all the SMSs as ham, which is indeed good news. Accuracy alone is not the best metric here, however, since it depends on the cut-off value applied to the predicted probabilities and can look high on an imbalanced dataset like this one.

In [48]:
test_results.crosstab('label','prediction').show()

+----------------+----+---+
|label_prediction| 0.0|1.0|
+----------------+----+---+
|             1.0|  19|188|
|             0.0|1453| 19|
+----------------+----+---+
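The counts in the confusion matrix are enough to reproduce the accuracy and to look at per-class metrics. The sketch below copies the four cells from the crosstab output, taking label 1.0 as spam (StringIndexer assigned 0.0 to ham, the more frequent class) and compares the model against a baseline that always predicts ham.

```python
# Cells copied from the crosstab output: rows are labels, columns are predictions.
tn, fp = 1453, 19   # label 0.0 (ham):  predicted ham, predicted spam
fn, tp = 19, 188    # label 1.0 (spam): predicted ham, predicted spam

total = tn + fp + fn + tp

accuracy = (tp + tn) / total
precision = tp / (tp + fp)        # of messages flagged as spam, how many are spam
recall = tp / (tp + fn)           # of actual spam messages, how many were flagged
always_ham = (tn + fp) / total    # accuracy of predicting ham for every message

print("Accuracy:   {0:.2f}%".format(accuracy * 100))
print("Precision:  {0:.2f}%".format(precision * 100))
print("Recall:     {0:.2f}%".format(recall * 100))
print("Always-ham: {0:.2f}%".format(always_ham * 100))
```

The accuracy matches the 97.74% reported by the evaluator, and comparing it with the always-ham baseline shows how much of that figure is due to class imbalance alone.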



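Next we'll calculate the area under the ROC curve. Note that the evaluator below is given the hard 0/1 prediction column as its score, so the resulting AUC reflects a single operating point rather than the full ranking contained in the rawPrediction or probability columns.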

In [51]:
acc_eval2 = BinaryClassificationEvaluator(rawPredictionCol='prediction')
auc = acc_eval2.evaluate(test_results)

In [53]:
print('AUC: {0:.2f}%'.format(auc*100))

AUC: 94.77%
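Because the evaluator above was given the hard prediction column, the ROC curve consists of just three points: (0, 0), the single (FPR, TPR) point of the classifier, and (1, 1). The area under that piecewise-linear curve simplifies to (TPR + 1 - FPR) / 2. As a check, the sketch below recomputes it from the confusion matrix counts shown earlier.

```python
# Counts copied from the crosstab output (label 1.0 = spam).
tn, fp = 1453, 19
fn, tp = 19, 188

tpr = tp / (tp + fn)   # true positive rate (recall on spam)
fpr = fp / (fp + tn)   # false positive rate (ham flagged as spam)

# Trapezoid areas under (0, 0) -> (fpr, tpr) -> (1, 1).
auc_single_point = fpr * tpr / 2 + (1 - fpr) * (tpr + 1) / 2

print("TPR: {0:.4f}, FPR: {1:.4f}".format(tpr, fpr))
print("AUC: {0:.2f}%".format(auc_single_point * 100))
```

This reproduces the 94.77% above. An AUC computed from the rawPrediction or probability column would use the full ranking of messages and would generally give a different value.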


Not bad at all for a very rudimentary model!