# Chapter 13 - Naive Bayes Spam Filter

In [1]:
import re
from collections import Counter, defaultdict
import math, random, re

In [2]:
def tokenize(message):
    """Return the set of distinct lowercase tokens found in ``message``.

    A token is a maximal run of the characters a-z, 0-9 and apostrophe
    (matched after lowercasing), so punctuation and hyphens act as
    separators: ``'R2-D2'`` yields ``{'r2', 'd2'}``.
    """
    lowered = message.lower()
    return {token for token in re.findall("[a-z0-9']+", lowered)}

In [3]:
# Demo: hyphens and punctuation split tokens; digits are kept.
tokenize(message='The quick brown fox jumped over R2-D2, C3PO and 1729 lazy dogs!!!!!')

{'1729',
 'and',
 'brown',
 'c3po',
 'd2',
 'dogs',
 'fox',
 'jumped',
 'lazy',
 'over',
 'quick',
 'r2',
 'the'}

In [4]:
def count_words(training_set):
    """Tally, per token, how many spam and non-spam messages contain it.

    ``training_set`` is an iterable of ``(message, is_spam)`` pairs. Returns a
    ``defaultdict`` mapping each token to a two-element list
    ``[spam_count, non_spam_count]``. Because ``tokenize`` returns a set, a
    token is counted at most once per message.
    """
    tallies = defaultdict(lambda: [0, 0])
    for text, spam_flag in training_set:
        slot = 0 if spam_flag else 1  # index 0 = spam, index 1 = non-spam
        for token in tokenize(text):
            tallies[token][slot] += 1
    return tallies

In [5]:
def word_probabilities(counts, total_spams, total_non_spams, k=0.5):
    """Convert per-word message counts into smoothed conditional probabilities.

    ``counts`` maps each word to ``[spam_count, non_spam_count]``. Each count
    is smoothed with pseudocount ``k`` as ``(count + k) / (total + 2 * k)``.

    Returns a list of ``(word, p(word | spam), p(word | not spam))`` triples,
    in the iteration order of ``counts``.
    """
    spam_denominator = total_spams + 2 * k
    non_spam_denominator = total_non_spams + 2 * k
    triples = []
    for word, pair in counts.items():
        spam_count, non_spam_count = pair
        triples.append((word,
                        (spam_count + k) / spam_denominator,
                        (non_spam_count + k) / non_spam_denominator))
    return triples

In [6]:
def spam_probability_from_logs(log_prob_if_spam, log_prob_if_not_spam):
    """Return exp(a) / (exp(a) + exp(b)), where a and b are the two log-likelihoods.

    Under equal class priors this is P(spam | message). It is evaluated as a
    logistic function of the difference b - a. That gives the same ratio
    algebraically, but never forms exp(a) or exp(b) directly. Those
    exponentials underflow to 0.0 once a log-likelihood drops below roughly
    -745, which long messages reach easily, and 0.0 / (0.0 + 0.0) would
    raise ZeroDivisionError.
    """
    diff = log_prob_if_not_spam - log_prob_if_spam
    if diff >= 0:
        # exp(-diff) lies in [0, 1], so nothing can overflow here
        ratio = math.exp(-diff)
        return ratio / (1.0 + ratio)
    # diff < 0, so exp(diff) lies in (0, 1)
    return 1.0 / (1.0 + math.exp(diff))


def spam_probability(word_probs, message):
    """Estimate the probability that ``message`` is spam.

    ``word_probs`` is a list of ``(word, p(word | spam), p(word | not spam))``
    triples, for example as produced by ``word_probabilities``. Every
    vocabulary word contributes to the naive Bayes likelihoods:

    - ``log p`` if the word appears in the message;
    - ``log(1 - p)`` if it does not.

    Message words missing from ``word_probs`` are ignored. Spam and non-spam
    are given equal priors.

    Every probability must lie strictly between 0 and 1; otherwise
    ``math.log`` raises ValueError.

    Returns a float in [0, 1].
    """
    message_words = tokenize(message)
    log_prob_if_spam = log_prob_if_not_spam = 0.0

    for word, prob_if_spam, prob_if_not_spam in word_probs:
        if word in message_words:
            # word present: add the log probability of seeing it
            log_prob_if_spam += math.log(prob_if_spam)
            log_prob_if_not_spam += math.log(prob_if_not_spam)
        else:
            # word absent: add the log probability of _not_ seeing it
            log_prob_if_spam += math.log(1.0 - prob_if_spam)
            log_prob_if_not_spam += math.log(1.0 - prob_if_not_spam)

    # Combine in log space; exponentiating each sum separately can underflow
    # both to 0.0 and divide zero by zero.
    return spam_probability_from_logs(log_prob_if_spam, log_prob_if_not_spam)

In [7]:
class NaiveBayesClassifier:
    """Naive Bayes spam classifier over per-message token sets.

    ``k`` is the pseudocount used when converting word counts into
    probabilities. A positive ``k`` keeps every estimate strictly between
    0 and 1. ``spam_probability`` needs that, because it takes the log of
    both ``p`` and ``1 - p``.
    """

    def __init__(self, k=0.5):
        self.k = k
        # (word, p(word | spam), p(word | not spam)) triples; empty until train()
        self.word_probs = []

    def train(self, training_set):
        """Fit word probabilities from ``(message, is_spam)`` pairs.

        ``training_set`` may be any iterable, including a generator. It is
        materialized once because it is traversed twice below and its length
        is needed.
        """
        training_set = list(training_set)

        # count spam and non-spam messages
        num_spams = sum(is_spam for _, is_spam in training_set)
        num_non_spams = len(training_set) - num_spams

        # run training data through our "pipeline"
        word_counts = count_words(training_set)
        self.word_probs = word_probabilities(word_counts,
                                             num_spams,
                                             num_non_spams,
                                             self.k)

    def classify(self, message):
        """Return the estimated probability that ``message`` is spam."""
        return spam_probability(self.word_probs, message)

In [8]:
def p_spam_given_word(word_prob):
    """Return P(spam | word) for one ``(word, p(w | spam), p(w | not spam))`` triple.

    Assumes equal prior probabilities for spam and non-spam.
    """
    _, if_spam, if_not_spam = word_prob
    total = if_spam + if_not_spam
    return if_spam / total

In [9]:
def split_data(data, prob):
    """Randomly partition ``data`` into two lists.

    Each row goes to the first list when ``random.random() < prob`` and to
    the second list otherwise, so the lists hold roughly ``prob`` and
    ``1 - prob`` of the rows. Exactly one ``random.random()`` draw is
    consumed per row, in order.

    Returns a 2-tuple ``(first, second)``. Row order is preserved within each
    list.
    """
    first, second = [], []
    for row in data:
        target = first if random.random() < prob else second
        target.append(row)
    return first, second

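We'll use data from the [SMS Spam Collection Data Set](https://archive.ics.uci.edu/ml/datasets/sms+spam+collection). Each line of the file holds a label and the message text, separated by a tab; messages labelled `spam` form the positive class.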

In [10]:
# Load (message, is_spam) pairs. Each line of the corpus is "<label>\t<message>",
# and is_spam is True exactly when the label is "spam".
# NOTE(review): assumes the corpus file is UTF-8 encoded (it contains '£') — confirm.
data = []
with open('../data/sms-spam/SMSSpamCollection', 'rt', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # tolerate blank lines instead of failing to unpack them
            continue
        # maxsplit=1: a tab inside the message text must not break unpacking
        label, msg = line.split('\t', 1)
        data.append((msg, label == 'spam'))

In [11]:
# hold out roughly a quarter of the messages for testing
train_data, test_data = split_data(data, prob=0.75)

In [12]:
# sizes of the (training, test) partitions
tuple(map(len, (train_data, test_data)))

(4131, 1443)

In [13]:
classifier = NaiveBayesClassifier(k=0.5)  # 0.5 is the constructor's default pseudocount
classifier.train(training_set=train_data)

In [14]:
# an innocuous sentence should score as very unlikely spam
classifier.classify(message='The quick brown fox jumped over R2-D2, C3PO and 1729 lazy dogs!!!!!')

1.0775136697655685e-06

In [15]:
# classic spam vocabulary should score high
classifier.classify(message='Claim your prize!!! cheap viagra!!! live girls!!')

0.9388618227701127

In [16]:
# a personal, informal message should score very low
classifier.classify(message='Hi honey. I will be late 4 dinner. love u')

1.3040583780535663e-10

In [17]:
# Score every held-out message. counts is keyed by (actually_spam, predicted_spam)
# with a 0.5 threshold; spam_probs keeps (message, is_spam, spam_prob) for the
# error analysis below.
counts, spam_probs = Counter(), []
for message, is_spam in test_data:
    spam_prob = classifier.classify(message)
    spam_probs.append((message, is_spam, spam_prob))
    counts[is_spam, spam_prob > 0.5] += 1

In [28]:
# number of test messages, and how many of them are spam
len(test_data), sum(label for _, label in test_data)

(1443, 203)

In [18]:
# Unpack the confusion matrix; counts keys are (actually_spam, predicted_spam).
tp, fp = counts[True, True], counts[False, True]
tn, fn = counts[False, False], counts[True, False]
tp, fp, tn, fn

(186, 0, 1240, 17)

In [19]:
# precision: of the messages flagged as spam, the fraction that really are spam
tp / (fp + tp)

1.0

In [20]:
# recall: of the actual spam messages, the fraction the classifier flagged
tp / (fn + tp)

0.916256157635468

In [21]:
spam_probs[:] = sorted(spam_probs, key=lambda entry: entry[2])  # ascending spam probability, in place

In [22]:
# spammiest hams: the five non-spam messages with the highest spam probability
[entry for entry in spam_probs if not entry[1]][-5:]

[('Save yourself the stress. If the person has a dorm account, just send your account details and the money will be sent to you.',
  False,
  0.0005594627763162863),
 ('"Keep ur problems in ur heart, b\'coz nobody will fight for u. Only u &amp; u have to fight for ur self &amp; win the battle. -VIVEKANAND- G 9t.. SD..',
  False,
  0.0011252708587743464),
 ('Enjoy the showers of possessiveness poured on u by ur loved ones, bcoz in this world of lies, it is a golden gift to be loved truly..',
  False,
  0.001239572799324235),
 ('We are pleased to inform that your application for Airtel Broadband is processed successfully. Your installation will happen within 3 days.',
  False,
  0.0021193917161031715),
 ('Hi Chachi tried calling u now unable to reach u .. Pl give me a missed cal once u c tiz msg  Kanagu',
  False,
  0.004207961863632256)]

In [23]:
# hammiest spams: the five spam messages with the lowest spam probability
[entry for entry in spam_probs if entry[1]][:5]

[("Do you ever notice that when you're driving, anyone going slower than you is an idiot and everyone driving faster than you is a maniac?",
  True,
  8.331605444750579e-13),
 ('Did you hear about the new "Divorce Barbie"? It comes with all of Ken\'s stuff!',
  True,
  1.5479767138222634e-09),
 ('Hello. We need some posh birds and chaps to user trial prods for champneys. Can i put you down? I need your address and dob asap. Ta r',
  True,
  7.477951002996785e-08),
 ('For sale - arsenal dartboard. Good condition but no doubles or trebles!',
  True,
  3.124560373960676e-06),
 ("Oh my god! I've found your number again! I'm so glad, text me back xafter this msgs cst std ntwk chg £1.50",
  True,
  4.030970628060635e-06)]

In [24]:
words = list(classifier.word_probs)
words.sort(key=p_spam_given_word)  # ascending P(spam | word)

In [25]:
# spammiest words: the 20 with the highest P(spam | word)
words[-20:]

[('sae', 0.03211009174311927, 0.00013935340022296544),
 ('rate', 0.03394495412844037, 0.00013935340022296544),
 ('800', 0.03394495412844037, 0.00013935340022296544),
 ('000', 0.03577981651376147, 0.00013935340022296544),
 ('mob', 0.03761467889908257, 0.00013935340022296544),
 ('code', 0.03944954128440367, 0.00013935340022296544),
 ('ringtone', 0.04128440366972477, 0.00013935340022296544),
 ('150ppm', 0.04128440366972477, 0.00013935340022296544),
 ('collection', 0.04128440366972477, 0.00013935340022296544),
 ('awarded', 0.04862385321100918, 0.00013935340022296544),
 ('1000', 0.05412844036697248, 0.00013935340022296544),
 ('cs', 0.05412844036697248, 0.00013935340022296544),
 ('tone', 0.05779816513761468, 0.00013935340022296544),
 ('500', 0.05963302752293578, 0.00013935340022296544),
 ('18', 0.06697247706422019, 0.00013935340022296544),
 ('guaranteed', 0.06697247706422019, 0.00013935340022296544),
 ('won', 0.0908256880733945, 0.00013935340022296544),
 ('prize', 0.10550458715596331, 0.0001

In [26]:
# hammiest words: the 20 with the lowest P(spam | word)
words[:20]

[('gt', 0.0009174311926605505, 0.04891304347826087),
 ('lt', 0.0009174311926605505, 0.048634336677814936),
 ("i'll", 0.0009174311926605505, 0.036371237458193977),
 ('he', 0.0009174311926605505, 0.029124860646599776),
 ('later', 0.0009174311926605505, 0.028846153846153848),
 ('lor', 0.0009174311926605505, 0.028288740245261984),
 ('da', 0.0009174311926605505, 0.0266164994425864),
 ('oh', 0.0009174311926605505, 0.022157190635451504),
 ('she', 0.0009174311926605505, 0.021321070234113712),
 ('wat', 0.0009174311926605505, 0.020206243032329988),
 ('doing', 0.0009174311926605505, 0.01797658862876254),
 ('ask', 0.0009174311926605505, 0.017697881828316612),
 ('lol', 0.0009174311926605505, 0.015468227424749164),
 ('morning', 0.0009174311926605505, 0.015468227424749164),
 ('said', 0.0009174311926605505, 0.015468227424749164),
 ('gud', 0.0009174311926605505, 0.014910813823857302),
 ('sure', 0.0009174311926605505, 0.014632107023411372),
 ('anything', 0.0009174311926605505, 0.014632107023411372),
 ('