In [3]:
import gzip
import warnings

from mnist import *
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict
from sklearn.multiclass import OneVsOneClassifier

# Suppress all warnings so library notices do not clutter the cell outputs.
# NOTE(review): this blanket filter also hides any convergence or deprecation
# warnings raised by later cells - consider narrowing it.
warnings.simplefilter("ignore")

# NOTE(review): absolute, machine-specific paths - adjust them to run elsewhere.
training_set_path = "D:\\Projects\\ml-experiments\\datasets\\mnist\\train-images-idx3-ubyte.gz"
labels_path = "D:\\Projects\\ml-experiments\\datasets\\mnist\\train-labels-idx1-ubyte.gz"

# Open each archive in a context manager so the handle is closed as soon as
# parsing is done instead of staying open for the life of the kernel.
# Assumes parse_idx reads the stream eagerly - TODO confirm.
with gzip.open(training_set_path) as f_train:
    training_set = parse_idx(f_train)
with gzip.open(labels_path) as f_labels:
    labels = parse_idx(f_labels)

# Flatten every image into a single feature row. The row count is taken from
# the data itself rather than hard-coded, so a different-sized IDX file also
# works; for the MNIST training set this is still (60000, 784).
training_set_tr = training_set.reshape((len(training_set), -1))

**A utility function to reuse throughout the experiment**

In [6]:
def get_random_digit(training_set, labels, digit):
    """Return one randomly chosen sample whose label equals ``digit``.

    Args:
        training_set: Indexable collection of samples, aligned with ``labels``.
        labels: NumPy array holding one label per sample.
        digit: Label value to look for.

    Returns:
        The entry of ``training_set`` at a uniformly drawn matching position.

    Raises:
        ValueError: If no entry of ``labels`` equals ``digit``.
    """
    indexes = np.where(labels == digit)[0]
    if len(indexes) == 0:
        raise ValueError(f"no sample labelled {digit!r} found in labels")
    # randint's upper bound is exclusive, so pass len(indexes) itself:
    # subtracting 1 would make the last match unreachable and would fail
    # outright when there is exactly one match.
    return training_set[indexes[np.random.randint(0, len(indexes))]]

**Scikit-learn is smart enough to detect when you try to use a binary classification algorithm such as SGD on a multiclass classification task (when the labels are not binary) and automatically runs the OvA (one-versus-all) strategy, which trains N binary classifiers, one for each class. The exception is SVM, for which it runs OvO (one-versus-one), which trains $\frac{N(N-1)}{2}$ binary classifiers, one for every pair of classes (0 vs 1, 0 vs 2, 1 vs 2, etc.)**

In [8]:
# Build the seeded SGD classifier and train it on the flattened images in one
# expression; fit() returns the estimator, so the name is bound to the trained model.
sgd_classifier = SGDClassifier(random_state=77).fit(training_set_tr, labels)

# Pick a random image labelled 7 and see which class the model assigns to it.
seven = get_random_digit(training_set_tr, labels, 7)
print(f"The digit is:{sgd_classifier.predict([seven])}")

The digit is:[7]


**Get the classifier to return the decision scores for each class rather than a prediction. The class with the highest score is the one used for the prediction**

In [9]:
# Ask the trained SGD classifier for its raw decision scores on the sampled
# image; the sample is wrapped in a list so it is passed as a batch of one.
scores = sgd_classifier.decision_function([seven])
# Show the scores so they can be compared with the predicted class.
print(f"The decision scores for the digit are: {scores}")

The decision scores for the digit are: [[-396080.43900624 -891881.56024155 -513824.30984111 -123079.68175448
  -446437.14777247 -396935.24971248 -973280.38077838   60582.85570627
  -279833.19815629 -311486.91345771]]


**You can also force Scikit-Learn to use the OvO strategy with the SGDClassifier**

In [10]:
# Wrap the SGD estimator in a one-versus-one meta-classifier and train it on
# the same data; fit() returns the wrapper, so construction and training chain.
ovo = OneVsOneClassifier(sgd_classifier).fit(training_set_tr, labels)
# Classify the same sampled image as before, this time with the OvO ensemble.
print("OvO: The digit is:", ovo.predict([seven]))

OvO: The digit is: [7]


**The Random Forest algorithm can also be used for classification (besides regression, via RandomForestRegressor). It is a multiclass algorithm, so there is no need for the OvA or OvO strategies**

In [11]:
# Seed the forest so that this cell, and the cross-validation that reuses
# rnd_forest, give the same result on every Restart & Run All. The seed
# matches the one given to the SGD classifier.
rnd_forest = RandomForestClassifier(random_state=77)
rnd_forest.fit(training_set_tr, labels)
# Show both the predicted class and the per-class probabilities for the
# sampled image.
print(f"Random Forest: The digit is:{rnd_forest.predict([seven])}")
print(f"Random Forest: Probabilities:{rnd_forest.predict_proba([seven])}")

Random Forest: The digit is:[7]
Random Forest: Probabilities:[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]


**Evaluate SGD Classifier vs Random Forest based on confusion matrix**

In [13]:
# Cross-validated predictions (3 folds) for each model over the training set.
sgd_predictions = cross_val_predict(sgd_classifier, training_set_tr, labels, cv=3)
rnd_forest_predictions = cross_val_predict(rnd_forest, training_set_tr, labels, cv=3)

# Emit both confusion matrices in a single call; the newline separator and the
# empty string reproduce the blank line between the two reports.
print(
    "SGD Classifier Confusion Matrix:",
    confusion_matrix(labels, sgd_predictions),
    "",
    "Random Forest Classifier Confusion Matrix:",
    confusion_matrix(labels, rnd_forest_predictions),
    sep="\n",
)

SGD Classifier Confusion Matrix:
[[5671    3   22   11   30   31   34   10  104    7]
 [   1 6447   38   24    6   23   10   11  159   23]
 [ 107   77 4673  189  101   41  123   98  509   40]
 [  56   33  106 5254   14  158   23   54  349   84]
 [  24   24   23    6 4823   10   59   31  262  580]
 [ 128   33   36  271   81 4140   75   20  454  183]
 [  80   25   49    8  107  133 5404   11   96    5]
 [  32   24   33   34   68   10    4 5599   76  385]
 [  87  132   37  160   32  199   36   30 4965  173]
 [  39   41   24   85  136   55    2  325  313 4929]]

Random Forest Classifier Confusion Matrix:
[[5813    3   23   12    6   16   20    2   25    3]
 [   1 6613   42   18   15    5   11   13   12   12]
 [  55   38 5612   46   44   13   27   48   62   13]
 [  30   20  147 5626   14  115    3   44   90   42]
 [  20   19   24   16 5539    7   33   17   34  133]
 [  42   18   21  179   26 4989   52   11   54   29]
 [  48   15   22    9   28   71 5699    1   25    0]
 [   9   44   97   25

**Random Forest generally seems better - it has higher values on the main diagonal, i.e. more correctly classified digits!**