In [37]:
pip install scikit-learn

Note: you may need to restart the kernel to use updated packages.


In [38]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.datasets import fetch_20newsgroups

In [39]:
# Load the 20 Newsgroups corpus with subset='all', asking the loader to strip
# headers, footers and quoted text from every post.
data = fetch_20newsgroups(
    subset='all',
    remove=('headers', 'footers', 'quotes'),
)

# Raw documents and their integer class ids, unpacked from the loader result.
emails, labels = data.data, data.target

In [40]:
# Hold out 20% of the documents for evaluation; the fixed random_state keeps
# the split identical across runs.
X_train, X_test, y_train, y_test = train_test_split(
    emails,
    labels,
    test_size=0.2,
    random_state=42,
)

In [41]:
# Two-step pipeline, both steps with default settings: CountVectorizer on the
# raw text, followed by a MultinomialNB classifier.
model = make_pipeline(
    CountVectorizer(),
    MultinomialNB(),
)

In [42]:
# Fit the whole pipeline on the training split. Kept as the cell's last
# expression so the fitted pipeline's repr is displayed.
model.fit(X_train, y=y_train)

Pipeline(steps=[('countvectorizer', CountVectorizer()),
                ('multinomialnb', MultinomialNB())])

In [44]:
# Predict a class id for every held-out document.
predictions = model.predict(
    X_test,
)

In [45]:
# Compare held-out labels with the predictions: overall accuracy plus the
# confusion matrix.
accuracy = accuracy_score(y_true=y_test, y_pred=predictions)
conf_matrix = confusion_matrix(y_true=y_test, y_pred=predictions)

In [46]:
# Report both metrics in one call; the newline separator puts each item on its
# own line, exactly like three consecutive print() calls.
print(
    f"Accuracy: {accuracy}",
    "Confusion Matrix:",
    conf_matrix,
    sep="\n",
)

Accuracy: 0.6175066312997347
Confusion Matrix:
[[ 38   1   0   0   0   0   0   1   5   0   0   1   0   1   3  75   2  13
    8   3]
 [  1 151   0   9   1   6   1   1   6   0   1   8   0   2   5   6   1   2
    1   0]
 [  1  50   8  48   5  51   0   0   7   0   0  13   2   0   1   5   1   0
    3   0]
 [  0  16   0 134   6  10   2   0   1   0   0   5   2   1   0   4   0   1
    1   0]
 [  2  15   1  16 119   2   2   1  12   0   0  13   3   5   2   7   0   2
    3   0]
 [  0  29   0   5   0 171   0   0   2   0   0   0   0   1   1   4   0   1
    1   0]
 [  0  12   0  26   5   3 103   4   2   0   1  11   6   3   4   5   1   3
    4   0]
 [  1   3   0   0   0   1   2 124  10   0   1   8   3   0   4   8   1  12
   18   0]
 [  0   2   0   0   0   1   4   8  98   1   1   5   0   1   3  11   3  17
   13   0]
 [  0   2   0   0   0   0   0   0  10 142   6   2   0   1   0  23   0  15
   10   0]
 [  1   1   0   0   0   0   0   0   5   0 159   2   0   2   0  13   0   5
   10   0]
 [  0   5   1   0 