In [183]:
from collections import Counter

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the labelled e-mail corpus; later cells read its 'text' and 'label' columns.
dataset_path = 'spam_ham_dataset.csv'
df = pd.read_csv(dataset_path)

In [185]:
# Hold out 20% of the e-mails for evaluation; the fixed seed keeps the split
# reproducible between runs.
train_data, test_data, train_labels, test_labels = train_test_split(
    df['text'],
    df['label'],
    test_size=0.2,
    random_state=42,
)

In [162]:
def preprocess(text):
    """Turn raw text into a list of lowercase, letters-only tokens.

    The text is lowercased, every character that is neither a letter nor
    whitespace is dropped (not replaced by a space, so "re-send" becomes
    "resend"), and the remainder is split on runs of whitespace.

    Parameters
    ----------
    text : str
        Raw e-mail text.

    Returns
    -------
    list of str
        The tokens in their original order; empty if nothing survives.
    """
    kept_chars = (ch for ch in text.lower() if ch.isspace() or ch.isalpha())
    return ''.join(kept_chars).split()

In [163]:
# Tokenise every training e-mail.  NOTE: this rebinds train_data from raw
# strings to token lists, so running this cell a second time fails
# (preprocess calls .lower() on its argument, which a list does not have).
train_data = train_data.map(preprocess)

In [164]:
# Class priors P(spam) and P(ham), estimated as the share of training labels
# belonging to each class.
label_counts = train_labels.value_counts()
total_count = len(train_labels)

spam_count = label_counts['spam']
ham_count = label_counts['ham']

p_spam = spam_count / total_count
p_ham = ham_count / total_count

In [165]:
# Flatten the tokenised training e-mails into one token list; its distinct
# entries form the vocabulary used for additive smoothing.
train_words = [word for email in train_data for word in email]

vocabulary = set(train_words)

# Tally how often each token occurs within each class using a single pass over
# that class's e-mails, instead of rescanning every e-mail once per vocabulary
# word (which grows with vocabulary size times corpus size).
spam_token_counts = Counter(
    word for email in train_data[train_labels == 'spam'] for word in email
)
ham_token_counts = Counter(
    word for email in train_data[train_labels == 'ham'] for word in email
)

# Total number of tokens (repeats included) observed in each class.
spam_word_count = sum(spam_token_counts.values())
ham_word_count = sum(ham_token_counts.values())

alpha = 1  # additive (Laplace) smoothing constant

# The smoothed denominators do not depend on the word, so compute them once.
spam_denominator = spam_word_count + alpha * len(vocabulary)
ham_denominator = ham_word_count + alpha * len(vocabulary)

# P(word | class) = (count(word, class) + alpha) / (tokens in class + alpha * |V|).
# A Counter yields 0 for a word never seen in that class, so such words get
# the pure smoothing mass alpha / denominator.
word_probs_spam = {
    word: (spam_token_counts[word] + alpha) / spam_denominator
    for word in vocabulary
}
word_probs_ham = {
    word: (ham_token_counts[word] + alpha) / ham_denominator
    for word in vocabulary
}

In [167]:
def classify(email):
    """Label a raw e-mail as 'spam' or 'ham' by comparing log-posteriors.

    The text is tokenised with ``preprocess``.  Each class score starts at the
    log of its prior (``p_spam`` / ``p_ham``) and adds the log-likelihood of
    every token, looked up in ``word_probs_spam`` / ``word_probs_ham``.  A
    token missing from a table contributes the smoothed probability of an
    unseen word, ``alpha / (class token total + alpha * len(vocabulary))``.

    Parameters
    ----------
    email : str
        Raw e-mail text.

    Returns
    -------
    str
        'spam' when the spam score is strictly greater than the ham score,
        otherwise 'ham' (so ties go to 'ham').
    """
    def token_log_likelihood(token, table, class_total):
        # Fall back to the smoothing mass for tokens absent from the table.
        if token not in table:
            return np.log(alpha / (class_total + alpha * len(vocabulary)))
        return np.log(table[token])

    spam_score = np.log(p_spam)
    ham_score = np.log(p_ham)

    # Work in log space so the product of many small probabilities does not
    # underflow; terms are accumulated token by token, in order.
    for token in preprocess(email):
        spam_score += token_log_likelihood(token, word_probs_spam, spam_word_count)
        ham_score += token_log_likelihood(token, word_probs_ham, ham_word_count)

    return 'spam' if spam_score > ham_score else 'ham'

In [187]:
# Predict a label for every held-out e-mail.  test_data still holds raw text,
# which is what classify expects (it tokenises its input itself).  Computing
# the predictions here keeps this cell runnable on a fresh kernel.
predicted_labels = test_data.apply(classify)

# Fraction of held-out e-mails whose predicted label matches the true label.
accuracy = (predicted_labels == test_labels).mean()

print('Accuracy:', accuracy)

Accuracy: 0.975


In [171]:
# Sanity check on a hand-written, ordinary work e-mail.
ham_example = '''Subject: Meeting on Friday

Hi all,

Just a reminder that we have a meeting scheduled for Friday at 10am in the conference room. Please be on time and come prepared with your updates.

Thanks,
John'''
classify(ham_example)


'ham'

In [188]:
# Sanity check on a short lottery-style message.
# NOTE(review): "congratualations" is misspelled; it is kept as-is because it
# is the input being classified. Presumably that token is absent from the
# training vocabulary and is scored through the unseen-word fallback; verify.
spam_example = 'congratualations! you won a lottery'
classify(spam_example)

'spam'