# Building a Simple Spam Filter with spark.mllib

Spam is everywhere, but what if you [don't like spam](https://www.youtube.com/watch?v=anwy2MPT5RE)? Fortunately, you can use the power of machine learning to detect and filter out spam messages. In this example we are going to build a very simple classifier that decides whether a message is spam or legitimate.

## Training Data

We are going to start with two text files as our training data. Each line in a file contains one message, and the messages were manually sorted into a spam file and a non-spam ("ham") file. (For a production-ready spam filter and an actual Spark use case, you would of course want to use much larger training sets.)
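The labelling step that the Spark code performs later can be sketched in plain Python, assuming UTF-8 text files with one message per line. The helper `load_labelled_messages` and the two temporary files are hypothetical stand-ins for the files in `../.assets/data/spam-filter/`. The label convention (1 = spam, 0 = legitimate) matches the one used later in this notebook.

```python
import os
import tempfile


def load_labelled_messages(spam_path, ham_path):
    """Hypothetical helper: read one message per line, label spam 1 and ham 0."""
    labelled = []
    for path, label in ((spam_path, 1), (ham_path, 0)):
        with open(path, encoding="utf-8") as f:
            for line in f:
                # Drop only the trailing newline; keep the message text as-is.
                labelled.append((label, line.rstrip("\n")))
    return labelled


# Tiny temporary files standing in for the real spam and ham files.
with tempfile.TemporaryDirectory() as tmp:
    spam_path = os.path.join(tmp, "spam")
    ham_path = os.path.join(tmp, "ham")
    with open(spam_path, "w", encoding="utf-8") as f:
        f.write("You have 1 new message. Please call 08712400200.\n")
    with open(ham_path, "w", encoding="utf-8") as f:
        f.write("Rofl. Its true to its name\nAnything lor. Juz both of us lor.\n")
    data = load_labelled_messages(spam_path, ham_path)

print(data)
```

In Spark, the same result comes from one `textFile` call per file, a `map` that attaches the label, and a `union` of the two RDDs.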

In [1]:
ls ../.assets/data/spam-filter/

ham   spam


In [2]:
!head ../.assets/data/spam-filter/ham

Rofl. Its true to its name
The guy did some bitching but I acted like i'd be interested in buying something else next week and he gave it to us for free
Pity, * was in mood for that. So...any other suggestions?
Will ü b going to esplanade fr home?
Huh y lei...
Why don't you wait 'til at least wednesday to see if you get your .
Ard 6 like dat lor.
Ok lor... Sony ericsson salesman... I ask shuhui then she say quite gd 2 use so i considering...
Get me out of this dump heap. My mom decided to come to lowes. BORING.
Anything lor. Juz both of us lor.


In [3]:
!head ../.assets/data/spam-filter/spam

You have 1 new message. Please call 08712400200.
Urgent! Please call 09061743811 from landline. Your ABTA complimentary 4* Tenerife Holiday or £5000 cash await collection SAE T&Cs Box 326 CW25WX 150ppm
Dear 0776xxxxxxx U've been invited to XCHAT. This is our final attempt to contact u! Txt CHAT to 86688 150p/MsgrcvdHG/Suite342/2Lands/Row/W1J6HL LDN 18yrs 
U 447801259231 have a secret admirer who is looking 2 make contact with U-find out who they R*reveal who thinks UR so special-call on 09058094597
Congrats! 2 mobile 3G Videophones R yours. call 09061744553 now! videochat wid ur mates, play java games, Dload polyH music, noline rentl. bx420. ip4. 5we. 150pm
PRIVATE! Your 2003 Account Statement for 07815296484 shows 800 un-redeemed S.I.M. points. Call 08718738001 Identifier Code 41782 Expires 18/11/04 
Do you want a new video handset? 750 anytime any network mins? Half Price Line Rental? Camcorder? Reply or call 08000930705 for delivery tomorrow
Money i have won wining number 946 wot do

## Machine Learning Workflow

### Data Import

We start by creating a `SparkContext` and using it to load the text files into RDDs.

In [4]:
import findspark
findspark.init()
import pyspark

In [5]:
conf = pyspark.SparkConf().setAppName("IDontLikeSpam") 
spark_context = pyspark.SparkContext(conf=conf)

In [6]:
spam = spark_context.textFile("../.assets/data/spam-filter/spam")
legit = spark_context.textFile("../.assets/data/spam-filter/ham")

In [7]:
spam.take(5)

['You have 1 new message. Please call 08712400200.',
 'Urgent! Please call 09061743811 from landline. Your ABTA complimentary 4* Tenerife Holiday or £5000 cash await collection SAE T&Cs Box 326 CW25WX 150ppm',
 "Dear 0776xxxxxxx U've been invited to XCHAT. This is our final attempt to contact u! Txt CHAT to 86688 150p/MsgrcvdHG/Suite342/2Lands/Row/W1J6HL LDN 18yrs ",
 'U 447801259231 have a secret admirer who is looking 2 make contact with U-find out who they R*reveal who thinks UR so special-call on 09058094597',
 'Congrats! 2 mobile 3G Videophones R yours. call 09061744553 now! videochat wid ur mates, play java games, Dload polyH music, noline rentl. bx420. ip4. 5we. 150pm']

In [27]:
from pyspark.mllib.regression import LabeledPoint


In [28]:
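# LabeledPoint expects numeric features, not raw text, so for now we keep
# (label, message) pairs: 1 = spam, 0 = legitimate ("ham").
labelled_spam = spam.map(lambda message: (1, message))
labelled_legit = legit.map(lambda message: (0, message))
messages = labelled_spam.union(labelled_legit)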

### Featurization

Each line of the file is a message that we classify. In order for a machine learning algorithm to "understand" a message, it is necessary to map it to a **feature vector**, that is, a set of numbers that characterizes the message and ideally allows us to clearly distinguish between different types of messages.
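The simplest such mapping is a word-count vector over a fixed vocabulary. The sketch below is plain Python, not Spark. It assumes a tiny hand-picked vocabulary and whitespace tokenization; `bag_of_words` is a hypothetical helper for illustration only.

```python
def bag_of_words(message, vocabulary):
    """Map a message to a vector of word counts over a fixed vocabulary."""
    index = {word: i for i, word in enumerate(vocabulary)}
    counts = [0] * len(vocabulary)
    # Lower-case and split on whitespace; words outside the vocabulary are ignored.
    for word in message.lower().split():
        if word in index:
            counts[index[word]] += 1
    return counts


# Hypothetical vocabulary: slot 0 = "call", 1 = "free", 2 = "now", 3 = "lor", 4 = "home".
vocabulary = ["call", "free", "now", "lor", "home"]
print(bag_of_words("Call now or call later for free", vocabulary))
```

Every message becomes a vector of the same length, which is what a classifier needs. With a realistic vocabulary, however, that length grows with the number of distinct words.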

To do this, we apply a method that is frequently used in text mining: **[TF-IDF](https://spark.apache.org/docs/1.2.0/mllib-feature-extraction.html)**, short for _term frequency-inverse document frequency_. It is a numerical statistic intended to reflect how important a word is to a document in a collection of documents (here, messages). Roughly speaking, it measures how much information about a message a term carries: terms that are very frequent across all messages ("the", "a", "it", ...) receive little weight.
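As a worked example of the weighting: the Spark MLlib feature-extraction guide defines a smoothed inverse document frequency, IDF(t) = log((m + 1) / (DF(t) + 1)). Here m is the number of documents and DF(t) is the number of documents containing term t. The TF-IDF value is the raw term count multiplied by IDF(t). The plain-Python sketch below follows that formula on three made-up messages. It is an illustration under those assumptions, not Spark's implementation, which for example also offers a `minDocFreq` cut-off.

```python
import math
from collections import Counter


def idf_weights(documents):
    """Smoothed IDF: log((m + 1) / (df(t) + 1)) for every term t."""
    m = len(documents)
    df = Counter()
    for doc in documents:
        df.update(set(doc))  # count each term at most once per document
    return {term: math.log((m + 1) / (count + 1)) for term, count in df.items()}


def tf_idf(doc, idf):
    """Raw term count in the document times the term's IDF weight."""
    tf = Counter(doc)
    return {term: count * idf.get(term, 0.0) for term, count in tf.items()}


# Made-up messages; "you" occurs in all three, "prize" in only one.
docs = [msg.lower().split() for msg in [
    "you can call now to claim a prize",
    "are you coming home now",
    "call me when you get home",
]]
idf = idf_weights(docs)
print(tf_idf(docs[0], idf))
```

A term that occurs in every message, such as "you" here, gets log(4 / 4) = 0 and so contributes nothing. The rarer "prize" (log 2) outweighs "call" (log(4/3)).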

Since many different words can occur in our messages, counting them in a table of all words would result in very large feature vectors. This is problematic in several ways, not least because it is computationally expensive. A common trick is therefore to send each word through a **hash function** that maps it to a slot in a fixed-length table, where the frequency of the term (and of all other terms with the same hash index) is stored. Collisions (two words being mapped to the same slot in the counting table) are unavoidable, but if they are rare enough, the resulting vector can still characterize the message well.
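The hashing trick can be sketched in a few lines of plain Python. This sketch uses `zlib.crc32` as a deterministic stand-in hash. It is an assumption for illustration and not necessarily the hash function Spark applies internally. Counts are kept in a sparse dictionary (slot index to count), similar in spirit to the `SparseVector` objects Spark returns. `hashing_tf` is a hypothetical helper.

```python
import zlib


def hashing_tf(words, num_features):
    """Count words in a fixed-size table, using hash(word) mod num_features as
    the slot. Words that hash to the same slot share (and add up) one count."""
    counts = {}
    for word in words:
        index = zlib.crc32(word.encode("utf-8")) % num_features
        counts[index] = counts.get(index, 0) + 1
    return counts


words = "call now claim your free prize".split()  # six distinct words
small = hashing_tf(words, 4)      # only 4 slots, so collisions are guaranteed
large = hashing_tf(words, 2**20)  # 1048576 slots, the size seen in the output below
print(small)
print(large)
```

With six distinct words and only four slots, at least two words must share a slot, but the total count is preserved. A larger table makes collisions rarer without ever storing a vocabulary.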

In [29]:
from pyspark.mllib.feature import HashingTF

In [30]:
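tf = HashingTF()  # default size: 2**20 = 1048576 features
# RDD.map passes each (label, message) pair as a single argument; hash the message text.
messages_tf = messages.map(lambda labelled_message: tf.transform(labelled_message[1].split(" ")))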

In [31]:
messages_tf.take(5)


In [14]:
messages_tf.cache()

PythonRDD[7] at RDD at PythonRDD.scala:48

In [17]:
from pyspark.mllib.feature import IDF

In [19]:
messages_idf = IDF().fit(messages_tf)
messages_tfidf = messages_idf.transform(messages_tf)

In [20]:
messages_tfidf.take(5)

[SparseVector(1048576, {167518: 2.7483, 194694: 4.1487, 678178: 4.4517, 706364: 2.537, 749211: 5.8535, 808351: 7.9329, 877522: 4.4672, 1016101: 3.0966}),
 SparseVector(1048576, {20318: 6.1411, 71008: 5.535, 75033: 5.6303, 167518: 2.7483, 208919: 7.9329, 291134: 5.5815, 339611: 5.4072, 340940: 4.6942, 392222: 5.7357, 434090: 6.1411, 553276: 5.448, 593153: 6.0611, 664797: 5.8535, 673956: 5.2588, 678178: 4.4517, 698689: 7.0166, 699526: 6.5466, 714899: 4.0309, 746732: 5.6303, 753414: 7.5274, 896657: 7.2398, 914751: 2.7568, 1017725: 3.1086}),
 SparseVector(1048576, {0: 2.574, 99605: 5.6303, 142838: 6.8343, 154253: 3.6781, 214801: 5.8535, 288871: 6.8343, 344448: 7.0166, 410315: 4.2822, 432395: 3.8139, 451532: 7.0166, 471671: 5.918, 523216: 6.2282, 578619: 3.8554, 598496: 7.2398, 617454: 2.1293, 640647: 4.8648, 796920: 5.448, 820228: 4.4065, 995308: 7.5274, 1001627: 4.4065, 1047659: 6.1411}),
 SparseVector(1048576, {27527: 2.7045, 41083: 7.5274, 117277: 6.6801, 204505: 13.0481, 204835: 2.9806

In [None]:
spark_context.stop()  # only one SparkContext can be active at a time, so stop this one first

In [None]:
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.feature import HashingTF, IDF
from pyspark.mllib.classification import LogisticRegressionWithSGD

from contextlib import contextmanager
from pyspark import SparkContext, SparkConf

@contextmanager
def use_spark_context(appName):
    conf = SparkConf().setAppName(appName) 
    spark_context = SparkContext(conf=conf)

    try:
        print("starting ", appName)
        yield spark_context
    finally:
        spark_context.stop()
        print("stopping ", appName)
        
with use_spark_context(appName="SpamFilter") as sc:
    spam = sc.textFile("../.assets/data/spam-filter/spam")
    normal = sc.textFile("../.assets/data/spam-filter/ham")
    
    # Create a HashingTF instance to map email text to vectors of 10,000 features.
    tf = HashingTF(numFeatures=10000)
    
    # Each email is split into words, and each word is mapped to one feature. 
    spamFeatures = spam.map(lambda email: tf.transform(email.split(" ")))
    normalFeatures = normal.map(lambda email: tf.transform(email.split(" ")))
    
    # Create LabeledPoint datasets for positive (spam) and negative (normal) examples.
    positiveExamples = spamFeatures.map(lambda features: LabeledPoint(1, features))
    negativeExamples = normalFeatures.map(lambda features: LabeledPoint(0, features))
    trainingData = positiveExamples.union(negativeExamples)
    trainingData.cache() # Cache since Logistic Regression is an iterative algorithm.
    
    # Run Logistic Regression using the SGD algorithm.
    model = LogisticRegressionWithSGD.train(trainingData)
    
    # Test on a positive example (spam) and a negative one (normal). We first apply
    # the same HashingTF feature transformation to get vectors, then apply the model. 
    posTest = tf.transform("O M G GET cheap stuff by sending money to ...".split(" "))
    negTest = tf.transform("Hi Dad, I started studying Spark the other ...".split(" "))
    print("Prediction for positive test example: %g" % model.predict(posTest))
    print("Prediction for negative test example: %g" % model.predict(negTest))  
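`model.predict` is expected to return one of the training labels, 1 (spam) or 0 (legitimate). The plain-Python sketch below shows the binary logistic-regression decision rule behind such a prediction. The features are combined into a weighted sum plus an intercept (the margin), passed through the sigmoid, and the resulting probability is compared with a threshold. As far as we understand, spark.mllib's `LogisticRegressionModel` uses a default threshold of 0.5 in this way. The weights below are hypothetical and are not the ones `LogisticRegressionWithSGD` learns on this dataset.

```python
import math


def predict(weights, intercept, features, threshold=0.5):
    """Binary logistic regression: sigmoid of the margin, then threshold it.
    weights and features are sparse dicts mapping feature index -> value."""
    margin = intercept + sum(weights.get(i, 0.0) * value for i, value in features.items())
    probability = 1.0 / (1.0 + math.exp(-margin))
    return 1 if probability > threshold else 0


# Hypothetical weights over a tiny hashed feature space (not learned from the data).
weights = {0: 2.0, 1: -1.5, 2: 0.5}
spam_like = {0: 3.0, 2: 1.0}  # margin = 6.5
ham_like = {1: 2.0}           # margin = -3.0
print(predict(weights, 0.0, spam_like), predict(weights, 0.0, ham_like))
```

A positive margin pushes the probability above 0.5 and yields label 1. A negative margin yields label 0.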




---
_This notebook is licensed under a [Creative Commons Attribution 4.0 International License (CC BY 4.0)](https://creativecommons.org/licenses/by/4.0/). Copyright © 2018 [Point 8 GmbH](https://point-8.de)_