# Building a Spam Filter with Naive Bayes
***
## Introduction 

In this project, we'll build a spam filter for SMS messages using the multinomial Naive Bayes algorithm, and we'll also train a Logistic Regression model to compare the two approaches. Our goal is a program that classifies new messages with an accuracy greater than 90%, meaning that more than 90% of new messages should be correctly classified as spam or ham (non-spam).

To train our algorithm, we'll use a dataset of 5,572 SMS messages that humans have already classified as spam or non-spam. The dataset can be downloaded from [The UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/sms+spam+collection). Training amounts to teaching the algorithm to classify messages the way a person would: from these human-labeled examples, it estimates how likely each word is to appear in spam messages and in ham messages. To classify a new message, it combines those word probabilities with the overall share of spam and ham and picks the more probable label.

## Importing Packages And Exploring The Dataset

In [46]:
```python
import pandas as pd
import re

sms_spam = pd.read_csv('SMSSpamCollection', sep='\t', header=None, names=['Label', 'SMS'])

print(sms_spam.shape)
sms_spam.head()
```

```
(5572, 2)
```

| | Label | SMS |
|---|---|---|
| 0 | ham | Go until jurong point, crazy.. Available only ... |
| 1 | ham | Ok lar... Joking wif u oni... |
| 2 | spam | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | ham | U dun say so early hor... U c already then say... |
| 4 | ham | Nah I don't think he goes to usf, he lives aro... |


Above is the head of our dataset, i.e., its first five rows. The **Label** column contains either **ham** or **spam**. In the context of this analysis, **ham** refers to messages that aren't spam, and **spam** refers to unwanted messages. The **SMS** column holds the text of each message; SMS stands for **Short Message Service**, in other words, a text message.

Next, we will look at how many messages are considered spam and how many are not.

In [47]:
```python
sms_spam['Label'].value_counts(normalize=True)
```

```
ham     0.865937
spam    0.134063
Name: Label, dtype: float64
```

In [48]:
```python
sms_spam['Label'].value_counts(normalize=False)
```

```
ham     4825
spam     747
Name: Label, dtype: int64
```

Approximately 86.6% of messages are ham messages or not spam, and roughly 13.4% are considered spam messages. There are 5,572 messages, with 4,825 being ham and 747 being spam.
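These proportions also set a baseline worth keeping in mind: a filter that labels every message as ham, without looking at its text, would already be correct about 86.6% of the time. A useful model needs to beat that figure, which is part of why the 90% target is a meaningful bar. As a quick sketch, the hypothetical helper `majority_baseline_accuracy` below computes this baseline from a label column rebuilt from the class counts above.

```python
import pandas as pd


def majority_baseline_accuracy(labels: pd.Series) -> float:
    """Accuracy of a classifier that always predicts the most common label."""
    return labels.value_counts(normalize=True).max()


# Rebuild a label column from the class counts reported above
toy_labels = pd.Series(['ham'] * 4825 + ['spam'] * 747)
print(majority_baseline_accuracy(toy_labels))  # about 0.866
```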

## Training And Test Set

We're now going to split our dataset into a training and a test set, where the training set accounts for 80% of the data, and the test set for the remaining 20%.

In [49]:
```python
# Shuffle the rows reproducibly before splitting
data_randomized = sms_spam.sample(frac=1, random_state=1)

# Row index at which the first 80% of the data ends
training_test_index = round(len(data_randomized) * 0.8)

training_set = data_randomized[:training_test_index].reset_index(drop=True)
test_set = data_randomized[training_test_index:].reset_index(drop=True)

print(training_set.shape)
print(test_set.shape)
```

```
(4458, 2)
(1114, 2)
```


We'll now analyze the percentage of spam and ham messages in the training and test sets. We expect the percentages to be close to what we have in the full dataset, where about 87% of the messages are ham, and the remaining 13% are spam.

In [50]:
```python
training_set['Label'].value_counts(normalize=True)
```

```
ham     0.86541
spam    0.13459
Name: Label, dtype: float64
```

In [51]:
```python
test_set['Label'].value_counts(normalize=True)
```

```
ham     0.868043
spam    0.131957
Name: Label, dtype: float64
```

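The proportions in both sets are close to those of the full dataset (about 86.5% and 86.8% ham, respectively), so the split preserves the class balance. We'll now move on to cleaning the dataset.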

## Data Cleaning

To calculate all the probabilities the algorithm requires, we first need to clean the data and bring it into a format from which we can easily extract all the information we need.

Essentially, we want to bring data to this format:

![Alt text](SpamFilterPhoto.png)

In the new format at the bottom of the graphic, the **SMS** column has been replaced by word columns. Each of these columns corresponds to one unique word from the vocabulary, and the number in each row tells how many times that word appears in the corresponding SMS message. For example, in the first row, the words **secret** and **prize** appear twice in the SMS message.

The top of the graphic shows that the original messages mix uppercase and lowercase letters. We convert every letter to lowercase so that variants such as **SECRET**, **Secret**, and **secret** are counted as the same word; otherwise the algorithm would treat them as unrelated words and split the evidence between them. We also remove punctuation such as **!** and **?**, because the algorithm works only with words and how often they occur in spam and ham messages.
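A minimal sketch of these two cleaning rules follows. The helper `clean_message` is hypothetical (the notebook applies the same steps directly to DataFrame columns): it replaces every non-word character with a space, lowercases the text, and splits it into words.

```python
import re


def clean_message(message: str) -> list[str]:
    """Lowercase a message, drop punctuation, and split it into words."""
    # \W matches any character that is not a letter, digit, or underscore
    without_punctuation = re.sub(r'\W', ' ', message)
    return without_punctuation.lower().split()


print(clean_message('WINNER!! This is the secret code to unlock the money: C3421.'))
# ['winner', 'this', 'is', 'the', 'secret', 'code', 'to', 'unlock', 'the', 'money', 'c3421']
```

Note that an apostrophe is also a non-word character, so "There's" becomes the two tokens `there` and `s`, just as in the cleaned training set.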

## Letter Case And Punctuation

We'll begin by removing all the punctuation and bringing every letter to lowercase.

In [52]:
```python
training_set.head()
```

| | Label | SMS |
|---|---|---|
| 0 | ham | Yep, by the pretty sculpture |
| 1 | ham | Yes, princess. Are you going to make me moan? |
| 2 | ham | Welp apparently he retired |
| 3 | ham | Havent. |
| 4 | ham | I forgot 2 ask ü all smth.. There's a card on ... |


In [53]:
```python
# Replace every non-word character (punctuation) with a space, then lowercase.
# regex=True is required: since pandas 2.0, str.replace treats the pattern as a literal by default.
training_set['SMS'] = training_set['SMS'].str.replace(r'\W', ' ', regex=True)
training_set['SMS'] = training_set['SMS'].str.lower()
training_set.head()
```

| | Label | SMS |
|---|---|---|
| 0 | ham | yep by the pretty sculpture |
| 1 | ham | yes princess are you going to make me moan |
| 2 | ham | welp apparently he retired |
| 3 | ham | havent |
| 4 | ham | i forgot 2 ask ü all smth there s a card on ... |


## Creating The Vocabulary

Let's now move to creating the vocabulary, which in this context means a list with all the unique words in our training set.

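In [54]:
```python
# Split each cleaned message into a list of words
training_set['SMS'] = training_set['SMS'].str.split()

vocabulary = []
for sms in training_set['SMS']:
    for word in sms:
        vocabulary.append(word)

# Keep only the unique words
vocabulary = list(set(vocabulary))
```

In [55]:
```python
len(vocabulary)
```

```
7783
```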

The vocabulary contains 7,783 unique words from the training set.

## The Final Training Set

Now we'll build a dictionary of per-message word counts and convert it into the DataFrame we need.

In [56]:
```python
# For every vocabulary word, a list with one count per SMS (initially all 0)
word_counts_per_sms = {unique_word: [0] * len(training_set['SMS']) for unique_word in vocabulary}

for index, sms in enumerate(training_set['SMS']):
    for word in sms:
        word_counts_per_sms[word][index] += 1
```

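In [57]:
```python
word_counts = pd.DataFrame(word_counts_per_sms)
word_counts.head()
```

| | luv | ah | asshole | chocolate | imagine | mums | andros | 2find | nigpun | no1 | ... | permission | beendropping | trav | mns | hospital | image | gas | tmorow | 81303 | replys150 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |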


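In [58]:
```python
training_set_clean = pd.concat([training_set, word_counts], axis=1)
training_set_clean.head()
```

| | Label | SMS | luv | ah | asshole | chocolate | imagine | mums | andros | 2find | ... | permission | beendropping | trav | mns | hospital | image | gas | tmorow | 81303 | replys150 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ham | [yep, by, the, pretty, sculpture] | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | ham | [yes, princess, are, you, going, to, make, me,... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | ham | [welp, apparently, he, retired] | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | ham | [havent] | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | ham | [i, forgot, 2, ask, ü, all, smth, there, s, a,... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |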


Each column after **Label** and **SMS** corresponds to one word in the vocabulary, and each cell shows how many times that word occurs in the corresponding SMS. Only the first five rows and a handful of the 7,783 word columns are displayed above.
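To see the same transformation on something small, here is a sketch with three made-up, already-cleaned messages (hypothetical data, not from the SMS dataset). It builds the vocabulary and the per-message word counts, using `collections.Counter` in place of the nested loop above; the `toy_` names are chosen so they don't overwrite the notebook's real variables.

```python
from collections import Counter

import pandas as pd

# Hypothetical messages, already lowercased, stripped of punctuation, and split
toy_messages = [
    ['secret', 'prize', 'secret', 'prize', 'prize'],
    ['secret', 'party', 'at', 'my', 'place'],
    ['okay', 'see', 'you'],
]

# Vocabulary: every unique word across all messages (sorted for a stable order)
toy_vocabulary = sorted({word for message in toy_messages for word in message})

# One row per message, one column per vocabulary word; absent words count as 0
toy_counts = pd.DataFrame(
    [dict(Counter(message)) for message in toy_messages],
    columns=toy_vocabulary,
).fillna(0).astype(int)

print(toy_counts)
```

Sorting the vocabulary keeps the column order the same between runs; `list(set(...))`, as used in the notebook, yields an arbitrary order, which is why the real columns start with words such as **luv** and **ah**.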

## Calculating Constants First

We're now done with cleaning the training set, and we can begin creating the spam filter. The Naive Bayes algorithm will need to answer these two probability questions to be able to classify new messages:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
  <mtable displaystyle="true">
    <mlabeledtr>
      <mtd>
        <mtext>(1)</mtext>
      </mtd>
      <mtd>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <mi>S</mi>
        <mi>p</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <msub>
          <mi>w</mi>
          <mn>1</mn>
        </msub>
        <mo>,</mo>
        <msub>
          <mi>w</mi>
          <mn>2</mn>
        </msub>
        <mo>,</mo>
        <mo>.</mo>
        <mo>.</mo>
        <mo>.</mo>
        <mo>,</mo>
        <msub>
          <mi>w</mi>
          <mi>n</mi>
        </msub>
        <mo stretchy="false">)</mo>
        <mo>&#x221D;</mo>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <mi>S</mi>
        <mi>p</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
        <mo>&#x22C5;</mo>
        <munderover>
          <mo data-mjx-texclass="OP">&#x220F;</mo>
          <mrow data-mjx-texclass="ORD">
            <mi>i</mi>
            <mo>=</mo>
            <mn>1</mn>
          </mrow>
          <mrow data-mjx-texclass="ORD">
            <mi>n</mi>
          </mrow>
        </munderover>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mi>w</mi>
          <mi>i</mi>
        </msub>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <mi>S</mi>
        <mi>p</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
      </mtd>
    </mlabeledtr>
  </mtable>
</math>

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
  <mtable displaystyle="true">
    <mlabeledtr>
      <mtd>
        <mtext>(2)</mtext>
      </mtd>
      <mtd>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <mi>H</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <msub>
          <mi>w</mi>
          <mn>1</mn>
        </msub>
        <mo>,</mo>
        <msub>
          <mi>w</mi>
          <mn>2</mn>
        </msub>
        <mo>,</mo>
        <mo>.</mo>
        <mo>.</mo>
        <mo>.</mo>
        <mo>,</mo>
        <msub>
          <mi>w</mi>
          <mi>n</mi>
        </msub>
        <mo stretchy="false">)</mo>
        <mo>&#x221D;</mo>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <mi>H</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
        <mo>&#x22C5;</mo>
        <munderover>
          <mo data-mjx-texclass="OP">&#x220F;</mo>
          <mrow data-mjx-texclass="ORD">
            <mi>i</mi>
            <mo>=</mo>
            <mn>1</mn>
          </mrow>
          <mrow data-mjx-texclass="ORD">
            <mi>n</mi>
          </mrow>
        </munderover>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mi>w</mi>
          <mi>i</mi>
        </msub>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <mi>H</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
      </mtd>
    </mlabeledtr>
  </mtable>
</math>

Also, to calculate P(w<sub>i</sub>|Spam) and P(w<sub>i</sub>|Ham) inside the formulas above, we'll need to use these equations:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
  <mtable displaystyle="true">
    <mlabeledtr>
      <mtd>
        <mtext>(3)</mtext>
      </mtd>
      <mtd>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mi>w</mi>
          <mi>i</mi>
        </msub>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <mi>S</mi>
        <mi>p</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
        <mo>=</mo>
        <mfrac>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <msub>
                  <mi>w</mi>
                  <mi>i</mi>
                </msub>
                <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
                <mi>S</mi>
                <mi>p</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
          </mrow>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>S</mi>
                <mi>p</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
            <mo>&#x22C5;</mo>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>V</mi>
                <mi>o</mi>
                <mi>c</mi>
                <mi>a</mi>
                <mi>b</mi>
                <mi>u</mi>
                <mi>l</mi>
                <mi>a</mi>
                <mi>r</mi>
                <mi>y</mi>
              </mrow>
            </msub>
          </mrow>
        </mfrac>
      </mtd>
    </mlabeledtr>
  </mtable>
</math>

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
  <mtable displaystyle="true">
    <mlabeledtr>
      <mtd>
        <mtext>(4)</mtext>
      </mtd>
      <mtd>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mi>w</mi>
          <mi>i</mi>
        </msub>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <mi>H</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
        <mo>=</mo>
        <mfrac>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <msub>
                  <mi>w</mi>
                  <mi>i</mi>
                </msub>
                <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
                <mi>H</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
          </mrow>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>H</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
            <mo>&#x22C5;</mo>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>V</mi>
                <mi>o</mi>
                <mi>c</mi>
                <mi>a</mi>
                <mi>b</mi>
                <mi>u</mi>
                <mi>l</mi>
                <mi>a</mi>
                <mi>r</mi>
                <mi>y</mi>
              </mrow>
            </msub>
          </mrow>
        </mfrac>
      </mtd>
    </mlabeledtr>
  </mtable>
</math>

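When a new message comes in, the algorithm uses equations (1) and (2) to decide, based on the words w<sub>1</sub>, ..., w<sub>n</sub> in the SMS, whether it should be classified as spam or ham. The ∝ symbol means "proportional to": by Bayes' theorem, each posterior also has the denominator P(w<sub>1</sub>, ..., w<sub>n</sub>), but that denominator is the same for both classes, so we can drop it and still compare the two values. The "naive" assumption that words are independent given the class is what lets us write the likelihood of the whole message as a product of per-word probabilities.

To evaluate equations (1) and (2), the algorithm first needs the probability of each word given each class, P(w<sub>i</sub>|Spam) and P(w<sub>i</sub>|Ham), which equations (3) and (4) provide. In these equations, N<sub>w<sub>i</sub>|Spam</sub> is the number of times the word w<sub>i</sub> occurs across all spam messages in the training set, N<sub>Spam</sub> is the total number of words in spam messages, N<sub>Vocabulary</sub> is the number of unique words in the **Vocabulary** (the list of unique words in the training set), and α is a smoothing parameter. The Ham quantities are defined the same way.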
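To make equations (1) through (4) concrete, here is a small worked sketch with made-up counts (not taken from the SMS dataset). It applies equations (3) and (4) with α = 1 to a two-word message and then multiplies by the class priors as in equations (1) and (2). The helper names `word_likelihood` and `unnormalized_posterior` are hypothetical.

```python
# Hypothetical training statistics (not from the SMS dataset)
toy_prior = {'spam': 0.5, 'ham': 0.5}        # P(Spam), P(Ham)
toy_total_words = {'spam': 20, 'ham': 30}     # N_Spam, N_Ham
toy_vocabulary_size = 10                      # N_Vocabulary
toy_word_counts = {                           # N_{w_i|Spam}, N_{w_i|Ham}
    'spam': {'win': 4, 'money': 3},
    'ham': {'win': 0, 'money': 1},
}
smoothing = 1                                 # alpha (Laplace smoothing)


def word_likelihood(word, label):
    """Equations (3) and (4): smoothed P(word | label)."""
    count = toy_word_counts[label].get(word, 0)
    return (count + smoothing) / (toy_total_words[label] + smoothing * toy_vocabulary_size)


def unnormalized_posterior(words, label):
    """Equations (1) and (2): P(label) times the product of P(w_i | label)."""
    score = toy_prior[label]
    for word in words:
        score *= word_likelihood(word, label)
    return score


toy_message = ['win', 'money']
scores = {label: unnormalized_posterior(toy_message, label) for label in ('spam', 'ham')}
print(scores)                       # spam ≈ 1/90, ham ≈ 1/1600
print(max(scores, key=scores.get))  # spam
```

For spam, (4 + 1)/(20 + 10) · (3 + 1)/(20 + 10) · 0.5 = 1/90; for ham, (0 + 1)/(30 + 10) · (1 + 1)/(30 + 10) · 0.5 = 1/1600, so the message is labeled spam. The scores don't sum to 1 because the shared denominator was dropped.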

Some of the terms in the four equations above will have the same value for every new message. We can calculate the value of these terms once and avoid doing the computations again when a new message comes in. Below, we'll use our training set to calculate:

- P(Spam) and P(Ham)
- N<sub>Spam</sub>, N<sub>Ham</sub>, N<sub>Vocabulary</sub>

We'll also use Laplace smoothing and set <math xmlns="http://www.w3.org/1998/Math/MathML">
  <mi>&#x3B1;</mi>
  <mo>=</mo>
  <mn>1</mn>
</math>
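To see why α matters, the sketch below compares the estimate for a word that never appears in the ham training messages with α = 0 and α = 1, using hypothetical counts. Without smoothing the estimate is exactly zero, and a single such word would zero out the whole product in equation (2), no matter how ham-like the rest of the message is. The function name `smoothed_likelihood` is our own.

```python
def smoothed_likelihood(word_count, class_word_total, vocabulary_size, alpha):
    """Additive smoothing estimate of P(w_i | class); alpha=1 is Laplace smoothing."""
    return (word_count + alpha) / (class_word_total + alpha * vocabulary_size)


# Hypothetical numbers: a word never seen among 1,000 ham words, 500-word vocabulary
for alpha_value in (0, 1):
    estimate = smoothed_likelihood(0, 1000, 500, alpha_value)
    print(f"alpha={alpha_value}: P(w|Ham) = {estimate:.6f}")
# alpha=0: P(w|Ham) = 0.000000
# alpha=1: P(w|Ham) = 0.000667
```

Because α is added to every word's count and α · N<sub>Vocabulary</sub> to the denominator, the smoothed estimates for a class still sum to 1 over the whole vocabulary.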

In [59]:
```python
# Split the cleaned training set by class
spam_messages = training_set_clean[training_set_clean['Label'] == 'spam']
ham_messages = training_set_clean[training_set_clean['Label'] == 'ham']

# P(Spam) and P(Ham)
p_spam = len(spam_messages) / len(training_set_clean)
p_ham = len(ham_messages) / len(training_set_clean)

# N_Spam: total number of words across all spam messages
n_words_per_spam_message = spam_messages['SMS'].apply(len)
n_spam = n_words_per_spam_message.sum()

# N_Ham: total number of words across all ham messages
n_words_per_ham_message = ham_messages['SMS'].apply(len)
n_ham = n_words_per_ham_message.sum()

# N_Vocabulary
n_vocabulary = len(vocabulary)

# Laplace smoothing
alpha = 1
```

## Calculating Parameters

Now that we have calculated the constant terms above, we can move on to calculating the parameters <math xmlns="http://www.w3.org/1998/Math/MathML">
  <mi>P</mi>
  <mo stretchy="false">(</mo>
  <msub>
    <mi>w</mi>
    <mi>i</mi>
  </msub>
  <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
  <mi>S</mi>
  <mi>p</mi>
  <mi>a</mi>
  <mi>m</mi>
  <mo stretchy="false">)</mo>
</math>
and 
<math xmlns="http://www.w3.org/1998/Math/MathML">
  <mi>P</mi>
  <mo stretchy="false">(</mo>
  <msub>
    <mi>w</mi>
    <mi>i</mi>
  </msub>
  <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
  <mi>H</mi>
  <mi>a</mi>
  <mi>m</mi>
  <mo stretchy="false">)</mo>
</math>.
Each parameter is a conditional probability associated with one word in the vocabulary.

The parameters are calculated using the formulas:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
  <mtable displaystyle="true">
    <mlabeledtr>
      <mtd>
        <mtext>(5)</mtext>
      </mtd>
      <mtd>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mi>w</mi>
          <mi>i</mi>
        </msub>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <mi>S</mi>
        <mi>p</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
        <mo>=</mo>
        <mfrac>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <msub>
                  <mi>w</mi>
                  <mi>i</mi>
                </msub>
                <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
                <mi>S</mi>
                <mi>p</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
          </mrow>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>S</mi>
                <mi>p</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
            <mo>&#x22C5;</mo>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>V</mi>
                <mi>o</mi>
                <mi>c</mi>
                <mi>a</mi>
                <mi>b</mi>
                <mi>u</mi>
                <mi>l</mi>
                <mi>a</mi>
                <mi>r</mi>
                <mi>y</mi>
              </mrow>
            </msub>
          </mrow>
        </mfrac>
      </mtd>
    </mlabeledtr>
  </mtable>
</math>

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
  <mtable displaystyle="true">
    <mlabeledtr>
      <mtd>
        <mtext>(6)</mtext>
      </mtd>
      <mtd>
        <mi>P</mi>
        <mo stretchy="false">(</mo>
        <msub>
          <mi>w</mi>
          <mi>i</mi>
        </msub>
        <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
        <mi>H</mi>
        <mi>a</mi>
        <mi>m</mi>
        <mo stretchy="false">)</mo>
        <mo>=</mo>
        <mfrac>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <msub>
                  <mi>w</mi>
                  <mi>i</mi>
                </msub>
                <mo data-mjx-texclass="ORD" stretchy="false">|</mo>
                <mi>H</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
          </mrow>
          <mrow>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>H</mi>
                <mi>a</mi>
                <mi>m</mi>
              </mrow>
            </msub>
            <mo>+</mo>
            <mi>&#x3B1;</mi>
            <mo>&#x22C5;</mo>
            <msub>
              <mi>N</mi>
              <mrow data-mjx-texclass="ORD">
                <mi>V</mi>
                <mi>o</mi>
                <mi>c</mi>
                <mi>a</mi>
                <mi>b</mi>
                <mi>u</mi>
                <mi>l</mi>
                <mi>a</mi>
                <mi>r</mi>
                <mi>y</mi>
              </mrow>
            </msub>
          </mrow>
        </mfrac>
      </mtd>
    </mlabeledtr>
  </mtable>
</math>

In [60]:
```python
parameters_spam = {unique_word: 0 for unique_word in vocabulary}
parameters_ham = {unique_word: 0 for unique_word in vocabulary}

for word in vocabulary:
    # Equation (5): smoothed P(w_i|Spam)
    n_word_given_spam = spam_messages[word].sum()
    p_word_given_spam = (n_word_given_spam + alpha) / (n_spam + alpha * n_vocabulary)
    parameters_spam[word] = p_word_given_spam

    # Equation (6): smoothed P(w_i|Ham)
    n_word_given_ham = ham_messages[word].sum()
    p_word_given_ham = (n_word_given_ham + alpha) / (n_ham + alpha * n_vocabulary)
    parameters_ham[word] = p_word_given_ham
```
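The loop above calls `.sum()` once per vocabulary word. A possible vectorized alternative, sketched here on a tiny hypothetical word-count table, sums every word column at once and applies equations (5) and (6) to the resulting Series. It takes N<sub>class</sub> as the total of all word counts, which equals the sum of message lengths when every word in the messages has its own column. The helper `word_parameters` is our own; with the real data, `word_parameters(spam_messages[vocabulary], n_vocabulary)` should reproduce `parameters_spam`, although we haven't run it here.

```python
import pandas as pd


def word_parameters(class_counts: pd.DataFrame, vocabulary_size: int, alpha: float = 1.0) -> dict:
    """Smoothed P(w_i | class) for every word column of a word-count table."""
    word_totals = class_counts.sum()        # N_{w_i|class} for each word
    class_word_total = word_totals.sum()    # N_class: all word occurrences in the class
    probabilities = (word_totals + alpha) / (class_word_total + alpha * vocabulary_size)
    return probabilities.to_dict()


# Hypothetical spam rows of a word-count table (one row per message)
toy_spam_counts = pd.DataFrame({'win': [1, 2], 'prize': [1, 0], 'hello': [0, 0]})
toy_parameters = word_parameters(toy_spam_counts, vocabulary_size=3)
print(toy_parameters)  # win: 4/7, prize: 2/7, hello: 1/7
```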

## Classifying A New Message

Now that we have all of our parameters calculated, we can start creating the spam filter. The spam filter can be understood as a function that:

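- Takes as input a new message (w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>).
- Calculates P(Spam|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>) and P(Ham|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>).
- Compares the values of P(Spam|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>) and P(Ham|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>), and:
    - If P(Ham|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>) > P(Spam|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>), then the message is classified as ham.
    - If P(Ham|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>) < P(Spam|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>), then the message is classified as spam.
    - If P(Ham|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>) = P(Spam|w<sub>1</sub>, w<sub>2</sub>, ..., w<sub>n</sub>), then the algorithm may request human help.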
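Multiplying many probabilities that are each well below 1 shrinks the scores quickly (the two short messages classified below already score around 1e-25), and for long enough messages the direct products underflow to 0.0 for both classes, producing a spurious tie. A common remedy, not used in this notebook, is to compare sums of logarithms instead; because the logarithm is monotonic, the comparison yields the same label whenever the products are representable. The sketch below assumes parameter dictionaries shaped like `parameters_spam` and `parameters_ham`; the function `classify_log_space` and the toy values are hypothetical.

```python
import math


def classify_log_space(words, prior_spam, prior_ham, params_spam, params_ham):
    """Compare log P(Spam) + sum(log P(w_i|Spam)) with the same score for Ham."""
    log_spam = math.log(prior_spam)
    log_ham = math.log(prior_ham)
    for word in words:
        # Words outside the vocabulary are skipped, as in the product version
        if word in params_spam:
            log_spam += math.log(params_spam[word])
        if word in params_ham:
            log_ham += math.log(params_ham[word])
    if log_ham > log_spam:
        return 'ham'
    if log_spam > log_ham:
        return 'spam'
    return 'needs human classification'


# Hypothetical parameters for a four-word vocabulary (each dict sums to 1)
toy_spam_params = {'win': 0.5, 'cash': 0.4, 'lunch': 0.05, 'later': 0.05}
toy_ham_params = {'win': 0.1, 'cash': 0.1, 'lunch': 0.4, 'later': 0.4}

print(classify_log_space(['win', 'cash'], 0.13, 0.87, toy_spam_params, toy_ham_params))  # spam

# For a very long message, the direct products underflow to 0.0 and tie...
long_message = ['lunch'] * 4000
naive_spam = 0.13 * math.prod(toy_spam_params[w] for w in long_message)
naive_ham = 0.87 * math.prod(toy_ham_params[w] for w in long_message)
print(naive_spam, naive_ham)  # 0.0 0.0

# ...while the log-space scores still separate the classes
print(classify_log_space(long_message, 0.13, 0.87, toy_spam_params, toy_ham_params))  # ham
```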

In [61]:
```python
def classify(message):
    '''
    message: a string
    Prints the two (unnormalized) scores and the predicted label.
    '''
    # Clean the message the same way as the training set
    message = re.sub(r'\W', ' ', message)
    message = message.lower().split()

    p_spam_given_message = p_spam
    p_ham_given_message = p_ham

    for word in message:
        # Words that are not in the vocabulary are ignored
        if word in parameters_spam:
            p_spam_given_message *= parameters_spam[word]

        if word in parameters_ham:
            p_ham_given_message *= parameters_ham[word]

    print('P(Spam|message):', p_spam_given_message)
    print('P(Ham|message):', p_ham_given_message)

    if p_ham_given_message > p_spam_given_message:
        print('Label: Ham')
    elif p_ham_given_message < p_spam_given_message:
        print('Label: Spam')
    else:
        print('Equal probabilities, have a human classify this!')
```

In [62]:
```python
classify('WINNER!! This is the secret code to unlock the money: C3421.')
```

```
P(Spam|message): 1.3481290211300841e-25
P(Ham|message): 1.9368049028589875e-27
Label: Spam
```

In [63]:
```python
classify("Sounds good, Tom, then see u there")
```

```
P(Spam|message): 2.4372375665888117e-25
P(Ham|message): 3.687530435009238e-21
Label: Ham
```


## Measuring The Spam Filter's Accuracy

The results above look promising, but let's see how well the filter does on our test set, which has 1,114 messages. We'll start by writing a function that returns classification labels instead of printing them.

In [64]:
```python
def classify_test_set(message):
    '''
    message: a string
    Returns 'ham', 'spam', or 'needs human classification'.
    '''
    # Clean the message the same way as the training set
    message = re.sub(r'\W', ' ', message)
    message = message.lower().split()

    p_spam_given_message = p_spam
    p_ham_given_message = p_ham

    for word in message:
        if word in parameters_spam:
            p_spam_given_message *= parameters_spam[word]

        if word in parameters_ham:
            p_ham_given_message *= parameters_ham[word]

    if p_ham_given_message > p_spam_given_message:
        return 'ham'
    elif p_spam_given_message > p_ham_given_message:
        return 'spam'
    else:
        return 'needs human classification'
```

Now that we have a function that returns labels instead of printing them, we can use it to create a new column in our test set.

In [65]:
```python
test_set['predicted'] = test_set['SMS'].apply(classify_test_set)
test_set.head()
```

| | Label | SMS | predicted |
|---|---|---|---|
| 0 | ham | Later i guess. I needa do mcat study too. | ham |
| 1 | ham | But i haf enuff space got like 4 mb... | ham |
| 2 | spam | Had your mobile 10 mths? Update to latest Oran... | spam |
| 3 | ham | All sounds good. Fingers . Makes it difficult ... | ham |
| 4 | ham | All done, all handed in. Don't know if mega sh... | ham |


Now we'll measure the filter's accuracy on the test set by counting how many predicted labels match the true labels.

In [66]:
```python
correct = 0
total = test_set.shape[0]

for _, row in test_set.iterrows():
    if row['Label'] == row['predicted']:
        correct += 1

print('Correct:', correct)
print('Incorrect:', total - correct)
print('Accuracy:', correct / total)
```

```
Correct: 1100
Incorrect: 14
Accuracy: 0.9874326750448833
```
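The same count can be written as a single vectorized comparison instead of a row-by-row loop. A minimal sketch on a hypothetical four-row frame that uses the same `Label` and `predicted` column names as `test_set`:

```python
import pandas as pd

# Hypothetical labels and predictions, with the same column names as test_set
toy_results = pd.DataFrame({
    'Label': ['ham', 'spam', 'ham', 'spam'],
    'predicted': ['ham', 'spam', 'spam', 'spam'],
})

# Element-wise comparison gives a boolean Series; its mean is the accuracy
matches = toy_results['Label'] == toy_results['predicted']
print('Correct:', matches.sum())       # Correct: 3
print('Incorrect:', (~matches).sum())  # Incorrect: 1
print('Accuracy:', matches.mean())     # Accuracy: 0.75
```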


The accuracy is about 98.74%: on 1,114 messages it hadn't seen during training, the spam filter classified 1,100 correctly. That comfortably beats our 90% goal, but we'll also train a Logistic Regression model and compare it against Naive Bayes.

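In [67]:
```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# training_set['SMS'] holds lists of cleaned words, so join each list into a string
train_text = training_set['SMS'].apply(' '.join)
# test_set['SMS'] still holds raw strings; clean them the same way as the training set
# (calling ' '.join on a raw string would put a space between every character)
test_text = test_set['SMS'].str.replace(r'\W', ' ', regex=True).str.lower()

# Bag-of-words counts; the vocabulary is learned from the training text only
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_text)
X_test = vectorizer.transform(test_text)

# Encode the labels as integers (ham -> 0, spam -> 1)
label_encoder = LabelEncoder()
y_train = label_encoder.fit_transform(training_set['Label'])
y_test = label_encoder.transform(test_set['Label'])

log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)

y_pred = log_reg.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Logistic Regression Model Accuracy: {accuracy}")
```

Be careful with the text columns here: `training_set['SMS']` holds lists of cleaned words, but `test_set['SMS']` still holds the raw messages. Applying `' '.join` to a raw string puts a space between every character (`' '.join('hello')` returns `'h e l l o'`), and `CountVectorizer`'s default token pattern ignores single-character tokens, so every test row ends up empty and the model predicts ham for every message. The accuracy recorded for this comparison, `Logistic Regression Model Accuracy: 0.8680430879712747`, came from joining the test messages that way, and it is exactly the share of ham messages in the test set (0.868043): the score of a classifier that always answers "ham". That roughly 87% figure therefore says nothing about how well Logistic Regression filters spam. With the test messages cleaned like the training messages, as in the cell above, the comparison needs to be re-run before either model can be declared the winner.

Naive Bayes is nonetheless a natural fit for this task. It assumes that the words in a message are conditionally independent given the class label ("ham" or "spam"), an assumption that rarely holds exactly for real text but still tends to work well for text classification while keeping training fast and simple. Its 98.74% test accuracy clears our goal on its own.

## Conclusion

In this project, we built a spam filter for SMS messages using the multinomial Naive Bayes algorithm and set up Logistic Regression as a comparison. The Naive Bayes filter reached an accuracy of 98.74% on the test set, well above our initial goal of 90%. The roughly 87% recorded for Logistic Regression matched the test set's share of ham messages because the test messages were not preprocessed like the training messages, so the comparison between the two models remains open until it is re-run with consistent preprocessing.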