In [6]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import (GaussianNB, MultinomialNB)
from sklearn.metrics import classification_report
from time import perf_counter

from converters import convert_from_file

In [7]:
# Read the training images and labels from their IDX files, then collapse every
# axis after the first so each sample becomes a single flat feature row.
train_images = convert_from_file('train-images.idx3-ubyte')
train_labels = convert_from_file('train-labels.idx1-ubyte')
train_features = train_images.reshape(train_images.shape[0], -1)

# Hold out 10% of the training samples as a validation split; the fixed
# random_state keeps the split the same on every run.
x_train, x_val, y_train, y_val = train_test_split(
    train_features,
    train_labels,
    test_size=0.1,
    random_state=42,
)

# The test set is loaded and flattened the same way as the training images.
test_images = convert_from_file('t10k-images.idx3-ubyte')
y_test = convert_from_file('t10k-labels.idx1-ubyte')
x_test = test_images.reshape(test_images.shape[0], -1)

In [8]:
# Gaussian Naive Bayes baseline: only the call to fit() is timed, and accuracy
# is measured on the held-out validation split.
gauss_clf = GaussianNB()

fit_started = perf_counter()
gauss_clf.fit(x_train, y_train)
fit_seconds = perf_counter() - fit_started

val_accuracy = gauss_clf.score(x_val, y_val)
print('Gaussian Naive Bayes classifier achieved accuracy of %.2f%%' % (val_accuracy * 100))
print('Spent %.4f seconds' % fit_seconds)

Gaussian Naive Bayes classifier achieved accuracy of 56.40%
Spent 1.0687 seconds


In [9]:
# Multinomial Naive Bayes with Laplace (add-one) smoothing, i.e. alpha=1.0.
# As with the Gaussian model, only fit() is timed and accuracy is taken on the
# validation split.
multinomial_clf = MultinomialNB(alpha=1.0)

fit_started = perf_counter()
multinomial_clf.fit(x_train, y_train)
fit_seconds = perf_counter() - fit_started

val_accuracy = multinomial_clf.score(x_val, y_val)
print('Multinomial Naive Bayes classifier achieved accuracy of %.2f%%' % (val_accuracy * 100))
print('Spent %.4f seconds' % fit_seconds)

Multinomial Naive Bayes classifier achieved accuracy of 83.00%
Spent 0.5951 seconds


In [10]:
# The multinomial model scored higher on the validation split, so it is the one
# evaluated on the test set; print per-class precision/recall/F1.
predictions = multinomial_clf.predict(x_test)
test_report = classification_report(y_test, predictions)
print(test_report)

              precision    recall  f1-score   support

           0       0.92      0.93      0.93       980
           1       0.91      0.93      0.92      1135
           2       0.90      0.83      0.86      1032
           3       0.80      0.85      0.82      1010
           4       0.85      0.75      0.79       982
           5       0.86      0.66      0.75       892
           6       0.89      0.90      0.89       958
           7       0.94      0.84      0.88      1028
           8       0.66      0.80      0.73       974
           9       0.71      0.85      0.78      1009

    accuracy                           0.84     10000
   macro avg       0.84      0.83      0.84     10000
weighted avg       0.85      0.84      0.84     10000

