In [1]:
import math
import random
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [8]:
class NaiveBayesClassifier:
    """Two-class ('spam' / 'not_spam') Naive Bayes text classifier.

    Emails are tokenised with ``str.split()`` (whitespace-separated,
    case-sensitive). Word likelihoods use add-one (Laplace) smoothing over the
    vocabulary shared by both classes, and scoring is done in log space so
    long emails do not underflow to zero.

    Attributes:
        word_counts: ``{label: {word: occurrences}}`` accumulated by ``train``.
        p_spam: fraction of the most recent training batch labelled 'spam'.
        p_non_spam: fraction of the most recent training batch labelled
            'not_spam'.

    Note:
        Calling ``train`` more than once keeps adding to ``word_counts`` while
        the priors are recomputed from the latest batch only; create a fresh
        instance to retrain from scratch.
    """

    def __init__(self):
        # Both classes are pre-seeded so predict() can index them even before
        # any training data with that label has been seen.
        self.word_counts = {'spam': {}, 'not_spam': {}}
        self.p_spam = 0
        self.p_non_spam = 0

    def train(self, data):
        """Update word counts and class priors from labelled emails.

        Args:
            data: sized collection of ``(email_text, label)`` pairs. Labels
                other than 'spam' / 'not_spam' get their own word-count bucket
                and count toward the prior denominator, but are never
                predicted.

        Raises:
            ValueError: if ``data`` is empty (priors would be undefined).
        """
        total_docs = len(data)
        if total_docs == 0:
            raise ValueError(
                "training data must contain at least one (email, label) pair"
            )

        total_spam = 0
        total_non_spam = 0

        for email, label in data:
            if label == 'spam':
                total_spam += 1
            elif label == 'not_spam':
                total_non_spam += 1

            # setdefault creates the bucket on first sight of a new label.
            counts = self.word_counts.setdefault(label, {})
            for word in email.split():
                counts[word] = counts.get(word, 0) + 1

        self.p_spam = total_spam / total_docs
        self.p_non_spam = total_non_spam / total_docs

    @staticmethod
    def _log_prior(probability):
        """Return log(probability), mapping a zero prior to -inf instead of raising."""
        return math.log(probability) if probability > 0 else float('-inf')

    def predict(self, email):
        """Classify ``email`` as 'spam' or 'not_spam'.

        Args:
            email: raw email text; tokenised with ``str.split()``.

        Returns:
            'spam' if the spam log-score is strictly greater than the
            not-spam log-score, otherwise 'not_spam' (ties and an untrained
            model therefore yield 'not_spam').
        """
        spam_counts = self.word_counts['spam']
        ham_counts = self.word_counts['not_spam']

        # Laplace smoothing must add the size of the *shared* vocabulary to
        # each class total; using each class's own vocabulary size would make
        # an unseen word cheaper in whichever class has fewer distinct words.
        vocabulary_size = len(spam_counts.keys() | ham_counts.keys())

        # Loop-invariant denominators, computed once instead of per word.
        spam_denominator = sum(spam_counts.values()) + vocabulary_size
        ham_denominator = sum(ham_counts.values()) + vocabulary_size

        # Work in log space: a product of many small probabilities underflows
        # to 0.0 for both classes, which would force every long email to
        # 'not_spam'.
        spam_score = self._log_prior(self.p_spam)
        ham_score = self._log_prior(self.p_non_spam)

        # With an empty vocabulary (untrained model) the denominators are 0,
        # so only the priors are compared.
        if vocabulary_size:
            for word in email.split():
                spam_score += math.log(
                    (spam_counts.get(word, 0) + 1) / spam_denominator
                )
                ham_score += math.log(
                    (ham_counts.get(word, 0) + 1) / ham_denominator
                )

        return 'spam' if spam_score > ham_score else 'not_spam'

In [9]:
# Tiny hand-labelled corpus of (email text, label) pairs: spam first, then not_spam.
emails = [
    (text, "spam")
    for text in ("win free money now", "low price for valued customer")
] + [
    (text, "not_spam")
    for text in ("meet me at noon", "this is your captain speaking")
]

# Fit a fresh classifier on the corpus above.
classifier = NaiveBayesClassifier()
classifier.train(emails)

# Classify an unseen message: "free" and "money" occur only in the spam
# examples, while "offer" appears in neither class.
test_email = "free money offer"
print(classifier.predict(test_email))  # Expected output: 'spam'


spam
