In [1]:
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report


In [4]:
categories = ['comp.graphics', 'comp.sys.mac.hardware', 'talk.politics.misc']
data = fetch_20newsgroups(subset='all', categories=categories,
                          remove=('headers', 'footers', 'quotes'))

In [9]:
data.data[0][:100]

"\n\n\tYou don't know much about the fall of Diem's government in Vietnam.\n\tOr the traditional Indian pr"

In [10]:
y = np.array([1 if data.target_names[t] == 'talk.politics.misc' else 0
              for t in data.target])

In [12]:
vectorizer = CountVectorizer(binary=True, stop_words='english', max_features=2000)
X = vectorizer.fit_transform(data.data)
X = X.toarray()

n, V = X.shape   # n = docs, V = vocabulary size
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

In [19]:
class BernoulliNB:
    def fit(self, X, y, alpha =1.0):
        #alpha: laplace smoothing feature
        n, V = X.shape
        self.V = V

        self.phi_y = np.mean(y)

        X_y1 = X[y==1]
        self.phi_y1 = (alpha + np.sum(X_y1, axis=0)) / (V*alpha + np.sum(X_y1))

        X_y0 = X[y==0]
        self.phi_y0 = (alpha + np.sum(X_y0, axis=0)) / (V*alpha + np.sum(X_y0))

    def predict_log_prob(self, X):
        log_p_y1 = np.log(self.phi_y) + (X * np.log(self.phi_y1) + (1-X) * np.log(1 - self.phi_y1)).sum(axis=1)
        log_p_y0 = np.log(1 - self.phi_y) + (X * np.log(self.phi_y0) + (1-X) * np.log(1 - self.phi_y0)).sum(axis=1)

        return np.vstack([log_p_y0, log_p_y1]).T
    
    def predict(self, X):
        log_probs = self.predict_log_prob(X)
        return np.argmax(log_probs, axis = 1)


In [20]:
model = BernoulliNB()
model.fit(X_train, y_train, alpha=1.0)

In [21]:
y_pred = model.predict(X_test)

print("Accuracy:", accuracy_score(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred, target_names=["Non-Spam", "Spam"]))

Accuracy: 0.9557739557739557

Classification Report:
               precision    recall  f1-score   support

    Non-Spam       0.96      0.98      0.97       569
        Spam       0.95      0.90      0.92       245

    accuracy                           0.96       814
   macro avg       0.95      0.94      0.95       814
weighted avg       0.96      0.96      0.96       814

