# Building a Spam Filter with Naive Bayes

In this guided project, we're going to study the practical side of the Naive Bayes algorithm by building a spam filter for SMS messages.

Our first task is to "teach" the computer how to classify messages. To do that, we'll use the multinomial Naive Bayes algorithm along with a dataset of 5,572 SMS messages that are already classified by humans.

The dataset was put together by Tiago A. Almeida and José María Gómez Hidalgo, and it can be downloaded from [The UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/sms+spam+collection). You can also download the dataset directly [from this link](https://dq-content.s3.amazonaws.com/433/SMSSpamCollection). The data collection process is described in more detail on [this page](http://www.dt.fee.unicamp.br/~tiago/smsspamcollection/#composition), where you can also find some of the authors' papers.

## Getting the data

In [2]:
import pandas as pd

sms = pd.read_csv('SMSSpamCollection', 
                  sep='\t' , 
                  header=None, 
                  names=['Label', 'SMS'])

In [3]:
sms.shape

(5572, 2)

In [4]:
sms.head()

|  | Label | SMS |
|---|---|---|
| 0 | ham | Go until jurong point, crazy.. Available only ... |
| 1 | ham | Ok lar... Joking wif u oni... |
| 2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | ham | U dun say so early hor... U c already then say... |
| 4 | ham | Nah I don't think he goes to usf, he lives aro... |


In [5]:
sms['Label'].value_counts(normalize=True)*100

ham     86.593683
spam    13.406317
Name: Label, dtype: float64

The dataset contains 5,572 messages: about 86.6% are labeled as ham (non-spam) and about 13.4% as spam.

## Training and testing set

When creating software (a spam filter is software), a good rule of thumb is that designing the test comes before creating the software. If we write the software first, then it's tempting to come up with a biased test just to make sure the software passes it.

Once our spam filter is done, we'll need to test how well it classifies new messages. To do that, we're first going to split our dataset into two sets:

* A **training set**, which we'll use to "train" the computer to classify messages.
* A **test set**, which we'll use to test how well the spam filter classifies new messages.

We're going to keep 80% of our dataset for training, and 20% for testing.

In [6]:
# randomize the dataset
sample = sms.sample(frac=1, random_state=1)

# Calculate index for split
training_test_index = round(len(sample) * 0.8)

# Training/Test split
training_set = sample[:training_test_index].reset_index(drop=True)
test_set = sample[training_test_index:].reset_index(drop=True)

print(training_set.shape)
print(test_set.shape)

(4458, 2)
(1114, 2)


In [7]:
# percentage of spam and non-spam in each set
print(training_set['Label'].value_counts(normalize=True)*100)
print(test_set['Label'].value_counts(normalize=True)*100)

ham     86.54105
spam    13.45895
Name: Label, dtype: float64
ham     86.804309
spam    13.195691
Name: Label, dtype: float64


## Data cleaning

In [8]:
training_set.head()

|  | Label | SMS |
|---|---|---|
| 0 | ham | Yep, by the pretty sculpture |
| 1 | ham | Yes, princess. Are you going to make me moan? |
| 2 | ham | Welp apparently he retired |
| 3 | ham | Havent. |
| 4 | ham | I forgot 2 ask ü all smth.. There's a card on ... |


In [9]:
# Replace non-word characters (punctuation) with spaces and convert to lowercase
training_set['SMS'] = training_set['SMS'].str.replace(r'\W', ' ', regex=True).str.lower()
training_set.head()

|  | Label | SMS |
|---|---|---|
| 0 | ham | yep by the pretty sculpture |
| 1 | ham | yes princess are you going to make me moan |
| 2 | ham | welp apparently he retired |
| 3 | ham | havent |
| 4 | ham | i forgot 2 ask ü all smth there s a card on ... |


## Creating the vocabulary

In [10]:
training_set['SMS'] = training_set['SMS'].str.split()

vocabulary = []
for message in training_set['SMS']:
    for word in message:
        vocabulary.append(word)
        
vocabulary = list(set(vocabulary))

In [11]:
len(vocabulary)

7783

There are 7,783 unique words in our vocabulary.

## Data transformation

In [12]:
word_count_per_sms = {word : [0]*len(training_set['SMS']) for word in vocabulary}

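# Count word occurrences per message (the loop variable is named `message`
# so it doesn't overwrite the `sms` DataFrame loaded earlier)
for index, message in enumerate(training_set['SMS']):
    for word in message:
        word_count_per_sms[word][index] += 1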

In [13]:
word_counts = pd.DataFrame(word_count_per_sms)

In [14]:
word_counts.head()

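|  | 0 | 00 | 000 | 000pes | 008704050406 | 0089 | 01223585334 | 02 | 0207 | 02072069400 | ... | zindgi | zoe | zogtorius | zouk | zyada | é | ú1 | ü | 〨ud | 鈥 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 |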
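
To make the transformation concrete, here is a minimal sketch of the same counting logic on two made-up messages. The `toy_messages` list and the other `toy_` names are assumptions for illustration, not data from the collection. Each vocabulary word becomes a key, and its list holds how often the word occurs in each message.

```python
# Two hypothetical, already cleaned and tokenized messages (not from the dataset)
toy_messages = [
    ['secret', 'prize', 'claim', 'secret'],
    ['see', 'you', 'soon'],
]

# Unique words across all messages, sorted only to make the output stable
toy_vocabulary = sorted({word for message in toy_messages for word in message})

# One list of per-message counts for every word in the vocabulary
toy_counts = {word: [0] * len(toy_messages) for word in toy_vocabulary}
for index, message in enumerate(toy_messages):
    for word in message:
        toy_counts[word][index] += 1

print(toy_vocabulary)
print(toy_counts['secret'])  # [2, 0]
```

Passing a dictionary of this shape to `pd.DataFrame` yields one column per word and one row per message, which is the layout of `word_counts`.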


In [15]:
word_counts.columns[1000:1230]

Index(['another', 'ans', 'ansr', 'answer', 'answered', 'answerin', 'answering',
       'answers', 'answr', 'antelope',
       ...
       'backdoor', 'bad', 'badass', 'badrith', 'bag', 'bags', 'bahamas',
       'baig', 'bailiff', 'bajarangabali'],
      dtype='object', length=230)

In [16]:
clean_train_set = pd.concat([training_set, word_counts], 
                           axis=1)

clean_train_set.head()

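|  | Label | SMS | 0 | 00 | 000 | 000pes | 008704050406 | 0089 | 01223585334 | 02 | ... | zindgi | zoe | zogtorius | zouk | zyada | é | ú1 | ü | 〨ud | 鈥 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ham | [yep, by, the, pretty, sculpture] | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | ham | [yes, princess, are, you, going, to, make, me,... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | ham | [welp, apparently, he, retired] | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | ham | [havent] | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | ham | [i, forgot, 2, ask, ü, all, smth, there, s, a,... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 |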


## Calculating the constants

In [17]:
spam_sms = clean_train_set[clean_train_set['Label']=='spam']
ham_sms = clean_train_set[clean_train_set['Label']=='ham']

In [18]:
# P(spam) and P(ham)
p_spam = len(spam_sms)/len(clean_train_set)
p_ham = len(ham_sms)/len(clean_train_set)

In [19]:
# N_spam, N_ham and N_vocab
words_spam = spam_sms['SMS'].apply(len)
N_spam = words_spam.sum() #total number of words in spam

words_ham = ham_sms['SMS'].apply(len)
N_ham = words_ham.sum() #total number of words in non-spam

N_vocab = len(vocabulary)

In [20]:
# Laplace/additive smoothing
alpha = 1

## Calculating parameters

We have 7,783 words in our vocabulary, which means we'll need to calculate a total of 15,566 probabilities. For each word, we need to calculate both P(wi|Spam) and P(wi|Ham).

The fact that we calculate so many values before even beginning the classification of new messages makes the Naive Bayes algorithm very fast (especially compared to other algorithms). When a new message comes in, most of the needed computations are already done, which enables the algorithm to almost instantly classify the new message.
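
For reference, the parameters computed in the next cell can be written as follows, using the constants defined earlier. This is only a restatement of the code's arithmetic, where $N_{w_i|Spam}$ and $N_{w_i|Ham}$ (the code's `N_w_spam` and `N_w_ham`) count how often the word $w_i$ appears in spam and ham messages, and $N_{Vocabulary}$ is `N_vocab`:

$$
P(w_i|Spam) = \frac{N_{w_i|Spam} + \alpha}{N_{Spam} + \alpha \cdot N_{Vocabulary}}
$$

$$
P(w_i|Ham) = \frac{N_{w_i|Ham} + \alpha}{N_{Ham} + \alpha \cdot N_{Vocabulary}}
$$

With $\alpha = 1$, a word that never appears in spam messages still gets a small non-zero probability, so a single such word can't zero out the whole product during classification.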

In [21]:
#initialize the parameters
p_w_spam = {str(word):0 for word in vocabulary}
p_w_ham = {str(word):0 for word in vocabulary}

for word in vocabulary:
    #number of times w occurs in spam sms
    N_w_spam = spam_sms[word].sum()
    
    #probability of w given spam
    w_given_spam = (N_w_spam + alpha)/(N_spam + (alpha*N_vocab))
    p_w_spam[word] = w_given_spam
    
    #number of times w occurs in non-spam sms
    N_w_ham = ham_sms[word].sum()
    
    #probability of w given non-spam
    w_given_ham = (N_w_ham + alpha)/(N_ham + (alpha*N_vocab))
    p_w_ham[word] = w_given_ham

## Classifying a new message

Now that we have all our parameters calculated, we can start creating the spam filter. The spam filter can be understood as a function that:

* Takes in as input a new message (w1, w2, ..., wn).
* Calculates P(Spam|w1, w2, ..., wn) and P(Ham|w1, w2, ..., wn).
* Compares the values of P(Spam|w1, w2, ..., wn) and P(Ham|w1, w2, ..., wn), and:
    * If P(Ham|w1, w2, ..., wn) > P(Spam|w1, w2, ..., wn), then the message is classified as ham.
    * If P(Ham|w1, w2, ..., wn) < P(Spam|w1, w2, ..., wn), then the message is classified as spam.
    * If P(Ham|w1, w2, ..., wn) = P(Spam|w1, w2, ..., wn), then the algorithm may request human help.
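
Under the Naive Bayes independence assumption, the two quantities being compared are proportional to the products below. The shared normalizing constant is dropped because it doesn't affect which of the two is larger. This is the form the `classify` function computes:

$$
P(Spam|w_1, w_2, \dots, w_n) \propto P(Spam) \cdot \prod_{i=1}^{n} P(w_i|Spam)
$$

$$
P(Ham|w_1, w_2, \dots, w_n) \propto P(Ham) \cdot \prod_{i=1}^{n} P(w_i|Ham)
$$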

In [22]:
import re

def classify(message):

    message = re.sub(r'\W', ' ', message)
    message = message.lower()
    message = message.split()


    p_spam_given_message = p_spam
    p_ham_given_message = p_ham
    
    for word in message:
        if word in p_w_spam:
            p_spam_given_message *= p_w_spam[word]
            
        if word in p_w_ham:
            p_ham_given_message *= p_w_ham[word]

    print('P(Spam|message):', p_spam_given_message)
    print('P(Ham|message):', p_ham_given_message)

    if p_ham_given_message > p_spam_given_message:
        print('Label: Ham')
    elif p_ham_given_message < p_spam_given_message:
        print('Label: Spam')
    else:
        print('Equal probabilities, have a human classify this!')

Now, let's test our classifier.

In [23]:
classify('WINNER!! This is the secret code to unlock the money: C3421.')

P(Spam|message): 1.3481290211300841e-25
P(Ham|message): 1.9368049028589875e-27
Label: Spam


In [24]:
classify("Sounds good, Tom, then see u there")

P(Spam|message): 2.4372375665888117e-25
P(Ham|message): 3.687530435009238e-21
Label: Ham
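
The values printed above are already very small, and because every additional word multiplies in another small number, long messages can push the product toward zero in floating point. A common workaround, which is not what the function above does, is to add log-probabilities instead of multiplying probabilities. Below is a minimal sketch under assumed toy parameters: `toy_p_spam`, `toy_p_ham`, `toy_p_w_spam`, `toy_p_w_ham`, and `classify_log` are hypothetical names and values for illustration only, not the fitted parameters.

```python
import math
import re

# Hypothetical priors and word probabilities (illustrative values, not the fitted ones)
toy_p_spam, toy_p_ham = 0.13, 0.87
toy_p_w_spam = {'win': 0.02, 'money': 0.03, 'see': 0.001, 'you': 0.005}
toy_p_w_ham = {'win': 0.001, 'money': 0.002, 'see': 0.01, 'you': 0.03}

def classify_log(message):
    """Return a label by comparing summed log-probabilities."""
    words = re.sub(r'\W', ' ', message).lower().split()

    log_spam = math.log(toy_p_spam)
    log_ham = math.log(toy_p_ham)
    for word in words:
        # Words outside the vocabulary are skipped, as in the function above
        if word in toy_p_w_spam:
            log_spam += math.log(toy_p_w_spam[word])
        if word in toy_p_w_ham:
            log_ham += math.log(toy_p_w_ham[word])

    if log_ham > log_spam:
        return 'ham'
    if log_spam > log_ham:
        return 'spam'
    return 'needs human classification'

print(classify_log('Win money!!'))
print(classify_log('See you'))
```

Because the logarithm is monotonically increasing, comparing the sums gives the same label as comparing the products whenever the products don't underflow.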


## Accuracy on test set

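Let's modify our function so that it returns the predicted label instead of printing it. That way, we can apply it to every message in the test set.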

In [27]:
def classify_test_set(message):

    message = re.sub(r'\W', ' ', message)
    message = message.lower()
    message = message.split()

    p_spam_given_message = p_spam
    p_ham_given_message = p_ham

    for word in message:
        if word in p_w_spam:
            p_spam_given_message *= p_w_spam[word]

        if word in p_w_ham:
            p_ham_given_message *= p_w_ham[word]

    if p_ham_given_message > p_spam_given_message:
        return 'ham'
    elif p_spam_given_message > p_ham_given_message:
        return 'spam'
    else:
        return 'needs human classification'

In [28]:
test_set['predicted'] = test_set['SMS'].apply(classify_test_set)
test_set.head()

|  | Label | SMS | predicted |
|---|---|---|---|
| 0 | ham | Later i guess. I needa do mcat study too. | ham |
| 1 | ham | But i haf enuff space got like 4 mb... | ham |
| 2 | spam | Had your mobile 10 mths? Update to latest Oran... | spam |
| 3 | ham | All sounds good. Fingers . Makes it difficult ... | ham |
| 4 | ham | All done, all handed in. Don't know if mega sh... | ham |


In [29]:
#measure accuracy
correct = 0
total = len(test_set)

for row in test_set.iterrows():
    row = row[1]
    if row['Label'] == row['predicted']:
        correct += 1
        
print('Correct : {}'.format(correct))
print('Total : {}'.format(total))
print('Accuracy : {}'.format(correct/total))

Correct : 1100
Total : 1114
Accuracy : 0.9874326750448833


The accuracy is about 98.7%: our classifier looked at 1,114 messages it had not seen during training and labeled 1,100 of them correctly.
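
Since the classes are imbalanced, it can help to compare this figure with a trivial baseline that labels every message as ham. The sketch below is illustrative only: `accuracy` and `majority_baseline` are hypothetical helper names, and the label lists are made up rather than taken from the test set.

```python
from collections import Counter

def accuracy(labels, predictions):
    """Share of positions where the predicted label matches the true label."""
    matches = sum(1 for label, predicted in zip(labels, predictions)
                  if label == predicted)
    return matches / len(labels)

def majority_baseline(labels):
    """Accuracy obtained by always predicting the most frequent label."""
    most_common_count = Counter(labels).most_common(1)[0][1]
    return most_common_count / len(labels)

# Made-up labels for illustration (not from the dataset)
toy_labels = ['ham', 'ham', 'ham', 'spam', 'ham', 'spam', 'ham', 'ham']
toy_predictions = ['ham', 'ham', 'ham', 'spam', 'ham', 'ham', 'ham', 'ham']

print(accuracy(toy_labels, toy_predictions))  # 0.875
print(majority_baseline(toy_labels))          # 0.75
```

In our test set, about 86.8% of the messages are ham, so always predicting ham would already score roughly 0.868. The filter's 0.987 is well above that baseline.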