In [44]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import length
from pyspark.ml.feature import (Tokenizer, StopWordsRemover, CountVectorizer, IDF, StringIndexer)
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import NaiveBayes
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [2]:
spark = SparkSession.builder.appName('nlp').getOrCreate()

In [3]:
data = spark.read.csv('smsspamcollection/SMSSpamCollection', inferSchema=True, sep='\t')

In [4]:
data.show()

+----+--------------------+
| _c0|                 _c1|
+----+--------------------+
| ham|Go until jurong p...|
| ham|Ok lar... Joking ...|
|spam|Free entry in 2 a...|
| ham|U dun say so earl...|
| ham|Nah I don't think...|
|spam|FreeMsg Hey there...|
| ham|Even my brother i...|
| ham|As per your reque...|
|spam|WINNER!! As a val...|
|spam|Had your mobile 1...|
| ham|I'm gonna be home...|
|spam|SIX chances to wi...|
|spam|URGENT! You have ...|
| ham|I've been searchi...|
| ham|I HAVE A DATE ON ...|
|spam|XXXMobileMovieClu...|
| ham|Oh k...i'm watchi...|
| ham|Eh u remember how...|
| ham|Fine if thats th...|
|spam|England v Macedon...|
+----+--------------------+
only showing top 20 rows



In [7]:
data = data.withColumnRenamed('_c0' ,  'class').withColumnRenamed('_c1', 'text')

In [8]:
data.show()

+-----+--------------------+
|class|                text|
+-----+--------------------+
|  ham|Go until jurong p...|
|  ham|Ok lar... Joking ...|
| spam|Free entry in 2 a...|
|  ham|U dun say so earl...|
|  ham|Nah I don't think...|
| spam|FreeMsg Hey there...|
|  ham|Even my brother i...|
|  ham|As per your reque...|
| spam|WINNER!! As a val...|
| spam|Had your mobile 1...|
|  ham|I'm gonna be home...|
| spam|SIX chances to wi...|
| spam|URGENT! You have ...|
|  ham|I've been searchi...|
|  ham|I HAVE A DATE ON ...|
| spam|XXXMobileMovieClu...|
|  ham|Oh k...i'm watchi...|
|  ham|Eh u remember how...|
|  ham|Fine if thats th...|
| spam|England v Macedon...|
+-----+--------------------+
only showing top 20 rows



In [10]:
data = data.withColumn('lenght', length(data['text']))

In [11]:
data.show()

+-----+--------------------+------+
|class|                text|lenght|
+-----+--------------------+------+
|  ham|Go until jurong p...|   111|
|  ham|Ok lar... Joking ...|    29|
| spam|Free entry in 2 a...|   155|
|  ham|U dun say so earl...|    49|
|  ham|Nah I don't think...|    61|
| spam|FreeMsg Hey there...|   147|
|  ham|Even my brother i...|    77|
|  ham|As per your reque...|   160|
| spam|WINNER!! As a val...|   157|
| spam|Had your mobile 1...|   154|
|  ham|I'm gonna be home...|   109|
| spam|SIX chances to wi...|   136|
| spam|URGENT! You have ...|   155|
|  ham|I've been searchi...|   196|
|  ham|I HAVE A DATE ON ...|    35|
| spam|XXXMobileMovieClu...|   149|
|  ham|Oh k...i'm watchi...|    26|
|  ham|Eh u remember how...|    81|
|  ham|Fine if thats th...|    56|
| spam|England v Macedon...|   155|
+-----+--------------------+------+
only showing top 20 rows



In [12]:
data.groupBy('class').mean().show()

+-----+-----------------+
|class|      avg(lenght)|
+-----+-----------------+
|  ham|71.45431945307645|
| spam|138.6706827309237|
+-----+-----------------+



In [22]:
tokenizer = Tokenizer(inputCol='text', outputCol='token_text')
stop_remover = StopWordsRemover(inputCol='token_text', outputCol='stop_token')
count_vec = CountVectorizer(inputCol='stop_token', outputCol='c_vec')
idf = IDF(inputCol='c_vec', outputCol='tf_idf')
ham_spam_to_numeric = StringIndexer(inputCol='class', outputCol='label')

In [32]:
clean_up = VectorAssembler(inputCols=['tf_idf', 'lenght'], outputCol='features')

In [33]:
nb = NaiveBayes()

In [34]:
data_prep_pipe = Pipeline(stages=[ham_spam_to_numeric, tokenizer, stop_remover, count_vec, idf, clean_up])

In [35]:
cleaner = data_prep_pipe.fit(data)

In [36]:
clean_data = cleaner.transform(data)

In [37]:
clean_data = clean_data.select('label', 'features')

In [38]:
clean_data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(13424,[7,11,31,6...|
|  0.0|(13424,[0,24,297,...|
|  1.0|(13424,[2,13,19,3...|
|  0.0|(13424,[0,70,80,1...|
|  0.0|(13424,[36,134,31...|
|  1.0|(13424,[10,60,139...|
|  0.0|(13424,[10,53,103...|
|  0.0|(13424,[125,184,4...|
|  1.0|(13424,[1,47,118,...|
|  1.0|(13424,[0,1,13,27...|
|  0.0|(13424,[18,43,120...|
|  1.0|(13424,[8,17,37,8...|
|  1.0|(13424,[13,30,47,...|
|  0.0|(13424,[39,96,217...|
|  0.0|(13424,[552,1697,...|
|  1.0|(13424,[30,109,11...|
|  0.0|(13424,[82,214,47...|
|  0.0|(13424,[0,2,49,13...|
|  0.0|(13424,[0,74,105,...|
|  1.0|(13424,[4,30,33,5...|
+-----+--------------------+
only showing top 20 rows



In [39]:
training, test = clean_data.randomSplit([0.7, 0.3])

In [40]:
spam_detector = nb.fit(training)

In [41]:
data.printSchema()

root
 |-- class: string (nullable = true)
 |-- text: string (nullable = true)
 |-- lenght: integer (nullable = true)



In [42]:
test_results = spam_detector.transform(test)

In [43]:
test_results.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(13424,[0,1,2,7,8...|[-808.50137400268...|[1.0,4.7743587773...|       0.0|
|  0.0|(13424,[0,1,2,13,...|[-609.03267880851...|[1.0,2.1600370834...|       0.0|
|  0.0|(13424,[0,1,4,50,...|[-840.83621648565...|[1.0,1.7777318434...|       0.0|
|  0.0|(13424,[0,1,5,20,...|[-807.48751958069...|[1.0,2.4970028012...|       0.0|
|  0.0|(13424,[0,1,7,8,1...|[-873.54615847006...|[1.0,3.5125720647...|       0.0|
|  0.0|(13424,[0,1,9,14,...|[-540.75319364805...|[1.0,6.2529974823...|       0.0|
|  0.0|(13424,[0,1,12,33...|[-447.02187973343...|[1.0,5.3636690266...|       0.0|
|  0.0|(13424,[0,1,14,18...|[-1372.0521132097...|[1.0,1.6562376671...|       0.0|
|  0.0|(13424,[0,1,20,27...|[-965.87076797947...|[1.0,1.2335736709...|       0.0|
|  0.0|(13424,[0

In [45]:
acc_eval = MulticlassClassificationEvaluator()

In [46]:
acc = acc_eval.evaluate(test_results)

In [47]:
print('ACC of NB Model: ')
print(acc)

ACC of NB Model: 
0.9186587005911563
