# Stochastic Gradient Descent

In this module, we classify the handwritten digits of the MNIST dataset with a logistic regression classifier trained by stochastic gradient descent (SGD), using scikit-learn's SGDClassifier.
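For reference, here is the binary logistic regression objective and the per-sample SGD step, written with labels $y_i \in \{0, 1\}$. This is a sketch of the standard formulation, not a transcription of scikit-learn's code: $\alpha$ corresponds to SGDClassifier's `alpha` with the default L2 penalty, and $\eta$ is the learning rate. For the ten MNIST digits, scikit-learn combines ten such binary classifiers in a one-versus-all scheme.

$$
p_i = \sigma(\mathbf{w}^\top \mathbf{x}_i + b), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$

$$
E(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^{n} \Big[-y_i \log p_i - (1 - y_i)\log(1 - p_i)\Big] + \frac{\alpha}{2}\lVert \mathbf{w} \rVert_2^2
$$

$$
\mathbf{w} \leftarrow \mathbf{w} - \eta\big[(p_i - y_i)\,\mathbf{x}_i + \alpha\,\mathbf{w}\big], \qquad b \leftarrow b - \eta\,(p_i - y_i)
$$

The factor $(p_i - y_i)$ is the derivative of the log loss with respect to the score $\mathbf{w}^\top \mathbf{x}_i + b$, and the intercept $b$ is not penalized.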

First, we import SGDClassifier from sklearn.linear_model and the evaluation functions (accuracy_score, confusion_matrix, classification_report) from sklearn.metrics. Conveniently, Keras ships the MNIST dataset in keras.datasets.mnist, so no manual download is needed.

In [0]:
from keras.datasets import mnist
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

Using TensorFlow backend.


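Now, we load the data into four variables: the training and test images (x_train, x_test) and their digit labels (y_train, y_test). MNIST contains 60,000 training images and 10,000 test images, each a 28×28 grayscale array.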

In [0]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

SGDClassifier expects a 2-D array of shape (n_samples, n_features), so we flatten each 28×28 image into a vector of 784 pixel values.

In [0]:
# Flatten each 28x28 image into a row of 784 pixel values.
image_vector = 28*28
x_train = x_train.reshape(x_train.shape[0], image_vector)
x_test = x_test.reshape(x_test.shape[0], image_vector)
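The reshape above can be checked on a small stand-in array. This NumPy sketch uses five random 28×28 uint8 "images" in place of MNIST (the names `images`, `flat` and `flat_scaled` are illustrative). Passing `-1` lets NumPy infer the 784 columns. The optional rescaling to [0, 1] is an assumption, not a step in the original notebook; scikit-learn's documentation notes that SGD is sensitive to feature scaling.

```python
import numpy as np

# Stand-in for MNIST: 5 random 28x28 uint8 "images" with values 0-255.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Flatten each image into one row: shape (n_samples, 28*28).
# -1 tells NumPy to infer the second dimension (784 here).
flat = images.reshape(images.shape[0], -1)

# Optional (not in the original notebook): rescale pixels to [0, 1] as float32.
flat_scaled = flat.astype(np.float32) / 255.0

print(flat.shape)  # (5, 784)
```

Because NumPy arrays are row-major by default, each flattened row is the image's 28 pixel rows laid end to end.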

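We now create the model: an SGDClassifier with loss="log", which makes it a logistic regression classifier trained by stochastic gradient descent. max_iter=100 caps training at 100 passes (epochs) over the training data.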

In [0]:
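# loss="log" selects logistic regression. scikit-learn 1.1 renamed it to "log_loss"
# and 1.3 removed "log", so use loss="log_loss" on current versions.
model = SGDClassifier(loss="log", max_iter=100)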
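To make concrete what loss="log" optimizes, here is a from-scratch NumPy sketch of binary logistic regression trained with plain SGD on toy data. It is a simplification, not scikit-learn's implementation: it handles only two classes, uses a constant learning rate `eta` (a hypothetical choice) instead of the default `learning_rate='optimal'` schedule, and the function name `sgd_logistic_regression` is illustrative. `alpha` plays the same role as SGDClassifier's `alpha` with the default L2 penalty.

```python
import numpy as np

def sgd_logistic_regression(X, y, alpha=1e-4, eta=0.1, epochs=20, seed=0):
    """Binary logistic regression (labels 0/1) trained by per-sample SGD.

    Minimizes mean log loss + (alpha / 2) * ||w||^2 with a constant step size.
    """
    rng = np.random.default_rng(seed)
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    b = 0.0
    for _ in range(epochs):
        for i in rng.permutation(n_samples):  # reshuffle every epoch
            p = 1.0 / (1.0 + np.exp(-(X[i] @ w + b)))  # predicted P(y = 1)
            grad = p - y[i]  # derivative of log loss w.r.t. the score
            w -= eta * (grad * X[i] + alpha * w)  # L2 term shrinks w
            b -= eta * grad  # intercept is not penalized
    return w, b

# Toy data: two well-separated Gaussian blobs in 2-D.
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(-2, 1, size=(100, 2)), rng.normal(2, 1, size=(100, 2))])
y = np.array([0] * 100 + [1] * 100)

w, b = sgd_logistic_regression(X, y)
pred = (X @ w + b > 0).astype(int)  # score > 0 means P(y = 1) > 0.5
train_acc = (pred == y).mean()
print(train_acc)
```

SGDClassifier applies the same idea once per digit (one-versus-all) and adds refinements such as its adaptive learning-rate schedule.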

Now we train the model on the training set.

In [0]:
model.fit(x_train, y_train)



SGDClassifier(alpha=0.0001, average=False, class_weight=None,
       early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
       l1_ratio=0.15, learning_rate='optimal', loss='log', max_iter=100,
       n_iter=None, n_iter_no_change=5, n_jobs=None, penalty='l2',
       power_t=0.5, random_state=None, shuffle=True, tol=None,
       validation_fraction=0.1, verbose=0, warm_start=False)

Next, we use the trained model to predict a digit for each of the 10,000 test images.

In [0]:
y_pred = model.predict(x_test)
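For a multiclass linear model like this one, `predict` returns, for each sample, the class whose one-versus-all classifier gives the highest decision score. This NumPy sketch reproduces that rule with made-up weights for 3 classes and 2 features; `W`, `b` and `X_new` are hypothetical stand-ins for a fitted model's `coef_`, `intercept_` and input data.

```python
import numpy as np

# Hypothetical one-vs-all weights: one row per class, one column per feature.
W = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [-1.0, -1.0]])
b = np.array([0.0, 0.0, 0.5])

X_new = np.array([[2.0, 0.5],
                  [0.1, 3.0],
                  [-1.0, -2.0]])

scores = X_new @ W.T + b        # decision scores, shape (n_samples, n_classes)
y_pred = scores.argmax(axis=1)  # pick the highest-scoring class per sample
print(y_pred)  # [0 1 2]
```

With the fitted model, `model.classes_[model.decision_function(x_test).argmax(axis=1)]` applies the same rule and matches `model.predict(x_test)`.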

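Finally, we evaluate the predictions against the true test labels with three measures: overall accuracy, the confusion matrix (rows are true digits, columns are predicted digits), and a per-class report of precision, recall and F1-score.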
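All of these measures can be derived from the confusion matrix. As a check, this NumPy sketch recomputes accuracy, per-class precision and recall, and F1 from the confusion matrix printed in this notebook's output (rows are true digits, columns are predicted digits, matching scikit-learn's `confusion_matrix` convention).

```python
import numpy as np

# Confusion matrix copied from this notebook's output.
cm = np.array([
    [ 967,    0,   0,   4,   0,   0,   4,   2,   2,   1],
    [   1, 1119,   3,   2,   0,   1,   4,   1,   1,   3],
    [  25,   18, 901,  16,   8,   6,  29,  15,  10,   4],
    [   9,    0,  23, 899,   3,  32,   6,  13,  10,  15],
    [  10,    3,   2,   4, 879,   0,  14,   4,   7,  59],
    [  22,    5,   2,  37,   7, 758,  28,   4,  15,  14],
    [  11,    3,   2,   2,   3,   8, 926,   1,   1,   1],
    [   6,    8,  21,   2,   7,   2,   1, 934,   3,  44],
    [  51,   45,  18,  34,  12, 135,  30,  13, 576,  60],
    [  16,    6,   1,  10,  44,  11,   1,  32,   3, 885],
])

accuracy = np.trace(cm) / cm.sum()        # correct predictions / all predictions
precision = np.diag(cm) / cm.sum(axis=0)  # column sum: all samples predicted as k
recall = np.diag(cm) / cm.sum(axis=1)     # row sum: all true samples of k (support)
f1 = 2 * precision * recall / (precision + recall)

print(round(float(accuracy), 4))      # 0.8844
print(round(float(precision[0]), 2))  # 0.86
print(round(float(recall[8]), 2))     # 0.59
```

Recall for digit 8 (576 of 974 test images) is by far the lowest; the matrix shows 135 of those 8s predicted as 5.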

In [0]:
print("Accuracy:\n", accuracy_score(y_test, y_pred))
print("Confusion Matrix =\n", confusion_matrix(y_test, y_pred))
print("Classification Report =\n", classification_report(y_test, y_pred))

Accuracy:
 0.8844
Confusion Matrix =
 [[ 967    0    0    4    0    0    4    2    2    1]
 [   1 1119    3    2    0    1    4    1    1    3]
 [  25   18  901   16    8    6   29   15   10    4]
 [   9    0   23  899    3   32    6   13   10   15]
 [  10    3    2    4  879    0   14    4    7   59]
 [  22    5    2   37    7  758   28    4   15   14]
 [  11    3    2    2    3    8  926    1    1    1]
 [   6    8   21    2    7    2    1  934    3   44]
 [  51   45   18   34   12  135   30   13  576   60]
 [  16    6    1   10   44   11    1   32    3  885]]
Classification Report =
               precision    recall  f1-score   support

           0       0.86      0.99      0.92       980
           1       0.93      0.99      0.96      1135
           2       0.93      0.87      0.90      1032
           3       0.89      0.89      0.89      1010
           4       0.91      0.90      0.90       982
           5       0.80      0.85      0.82       892
           6       0.89    