In [2]:
from sklearn.datasets import fetch_20newsgroups
import numpy as np

# Load the 20 Newsgroups train/test splits. random_state=42 is the library
# default; passing it explicitly documents that the shuffled order is reproducible.
train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42)
test = fetch_20newsgroups(subset='test', shuffle=True, random_state=42)

# Category names of the dataset
categories = train.target_names

# Train split.
# dtype=object keeps references to the original Python strings. A plain
# np.array(...) over a list of str builds a fixed-width '<U{max_len}' array
# that pads every document to the length of the longest one (large memory
# cost) and drops trailing '\x00' characters.
train_data = np.array(train.data, dtype=object)
train_target = np.array(train.target)
train_size = len(train_data)

# Test split, same representation as the train split
test_data = np.array(test.data, dtype=object)
test_target = np.array(test.target)
test_size = len(test_data)

# Log the split sizes
print(f'Dataset Train: {train_size} elements')
print(f'Dataset Test: {test_size} elements')

Dataset Train: 11314 elements
Dataset Test: 7532 elements


In [3]:
from sklearn.feature_extraction.text import CountVectorizer

# Bag-of-words features from the original dataset: token counts with English
# stop words removed. The vocabulary is learned on the training split only
# (fit_transform) and then reused unchanged to encode the test split (transform).
vectorizer = CountVectorizer(stop_words='english', analyzer='word')
train_processed_data = vectorizer.fit_transform(train_data)
test_processed_data = vectorizer.transform(test_data)

In [55]:
from sklearn.naive_bayes import MultinomialNB

# Fit a multinomial naive Bayes model on the count features.
# alpha=0.2 is the only non-default setting: the displayed repr of the fitted
# model lists just that parameter, so fit_prior=True / class_prior=None are the
# defaults (class priors are learned from the training labels).
classifier = MultinomialNB(fit_prior=True, class_prior=None, alpha=0.2)
classifier.fit(X=train_processed_data, y=train_target)

MultinomialNB(alpha=0.2)

In [56]:
# Evaluate the fitted model on the held-out test split
predictions = classifier.predict(test_processed_data)

# Accuracy: fraction of test documents whose predicted label equals the true label
accuracy = (predictions == test_target).mean()
print(f'Accuracy obtained: {accuracy}')

Accuracy obtained: 0.8096123207647371


In [57]:
# Inspect the fitted naive Bayes parameters.
# class_log_prior_: log prior of each class (one value per class).
# feature_log_prob_: per-class log probability of each vocabulary token,
#   shape (n_classes, n_features).
# These replace the `intercept_` / `coef_` aliases, which scikit-learn
# deprecated in 0.24 and removed in 1.1 (accessing them raises AttributeError).
print(classifier.class_log_prior_)
print(classifier.feature_log_prob_)
print(classifier.feature_log_prob_.shape)

[-3.16001007 -2.96389519 -2.95198016 -2.95367364 -2.97422231 -2.94860178
 -2.96218433 -2.94691686 -2.94020542 -2.94187906 -2.93686652 -2.94523477
 -2.95198016 -2.94691686 -2.94860178 -2.93853458 -3.0311772  -2.99874192
 -3.19175877 -3.40155099]
[[-10.44421355  -8.07513871 -13.21680227 ... -13.21680227 -13.21680227
  -13.21680227]
 [ -8.07977744  -8.7670937  -13.221441   ... -13.221441   -13.221441
  -13.221441  ]
 [ -8.98961558 -10.16977774 -14.10160337 ... -14.10160337 -14.10160337
  -14.10160337]
 ...
 [ -8.54364858  -6.71735188 -11.86387689 ... -13.65563636 -13.65563636
  -13.65563636]
 [ -8.46664766  -7.55462132 -13.41540755 ... -13.41540755 -13.41540755
  -13.41540755]
 [-10.0192389   -8.60941404 -13.06376133 ... -13.06376133 -13.06376133
  -13.06376133]]
(20, 129796)
