In [1]:
from sklearn.datasets import fetch_20newsgroups
import numpy as np

# Load the predefined train and test subsets of the 20 Newsgroups corpus
train = fetch_20newsgroups(subset='train', shuffle=True)
test = fetch_20newsgroups(subset='test', shuffle=True)

# Names of the target categories
categories = train.target_names

# Train dataset as numpy arrays.
# dtype=object keeps every document as a regular Python str. Without it numpy
# builds a fixed-width unicode array in which each entry is padded to the
# length of the longest document, which wastes memory on a text corpus.
train_data = np.array(train.data, dtype=object)
train_target = np.array(train.target)
train_size = len(train_data)

# Test dataset as numpy arrays (same dtype reasoning as for the train split)
test_data = np.array(test.data, dtype=object)
test_target = np.array(test.target)
test_size = len(test_data)

# Report how many documents each split contains
print(f'Dataset Train: {train_size} elements')
print(f'Dataset Test: {test_size} elements')

Dataset Train: 11314 elements
Dataset Test: 7532 elements


In [52]:
from sklearn.feature_extraction.text import CountVectorizer

# Turn the raw documents into word-count features. The vocabulary is learned
# from the training documents only and then reused, unchanged, on the test set.
vectorizer = CountVectorizer(stop_words='english', analyzer='word')
train_processed_data, test_processed_data = (
    vectorizer.fit_transform(train_data),  # evaluated first: fits the vocabulary
    vectorizer.transform(test_data),       # evaluated second: uses the fitted vocabulary
)

In [53]:
from sklearn.naive_bayes import MultinomialNB

# Fit a multinomial naive Bayes model (smoothing alpha=1) on the training
# counts; `fit` returns the estimator itself, so it can be bound directly.
classifier = MultinomialNB(alpha=1).fit(train_processed_data, train_target)
classifier

MultinomialNB(alpha=1)

In [72]:
# Predict a category for every document of the test split
predictions = classifier.predict(test_processed_data)

# Accuracy is the fraction of predictions equal to the true label, i.e. the
# mean of the element-wise boolean comparison.
is_correct = predictions == test_target
accuracy = is_correct.mean()

print(f'Accuracy obtained: {accuracy}')

Accuracy obtained: 0.8023101433882103


In [67]:
# Inspect the fitted parameters of the naive Bayes model.
# `intercept_` and `coef_` were deprecated on MultinomialNB and later removed
# from scikit-learn, so reading them fails on current releases. The supported
# attributes are used instead:
#   class_log_prior_  - one log prior per class (what `intercept_` exposed)
#   feature_log_prob_ - log probability of each feature given a class,
#                       one row per class (what `coef_` exposed here)
print(classifier.class_log_prior_)
print(classifier.feature_log_prob_)
print(classifier.feature_log_prob_.shape)

[-3.16001007 -2.96389519 -2.95198016 -2.95367364 -2.97422231 -2.94860178
 -2.96218433 -2.94691686 -2.94020542 -2.94187906 -2.93686652 -2.94523477
 -2.95198016 -2.94691686 -2.94860178 -2.93853458 -3.0311772  -2.99874192
 -3.19175877 -3.40155099]
[[-10.88622592  -8.71717222 -12.27252028 ... -12.27252028 -12.27252028
  -12.27252028]
 [ -8.71956011  -9.38453641 -12.27490817 ... -12.27490817 -12.27490817
  -12.27490817]
 [ -9.29511395 -10.4235792  -12.82147448 ... -12.82147448 -12.82147448
  -12.82147448]
 ...
 [ -8.99556976  -7.18921149 -11.8287831  ... -12.52193028 -12.52193028
  -12.52193028]
 [ -9.01225522  -8.11687117 -12.37955105 ... -12.37955105 -12.37955105
  -12.37955105]
 [-10.58731645  -9.30638261 -12.19675436 ... -12.19675436 -12.19675436
  -12.19675436]]
(20, 129796)
