# MNIST handwritten digit classification with support vector machines

In this notebook, we'll use [support vector machines](http://scikit-learn.org/stable/modules/svm.html#svm-classification) to classify MNIST digits using scikit-learn.

First, the necessary imports.

In [None]:
%matplotlib inline

from time import time
import numpy as np
from sklearn import svm

import matplotlib.pyplot as plt
import seaborn as sns

Then we load the MNIST data. The first time this runs, the dataset is downloaded, which can take a while.

In [None]:
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print()
print('MNIST data loaded: train:', len(X_train), 'test:', len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test:', X_test.shape)
print('y_test:', y_test.shape)
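The SVM classifiers in scikit-learn expect a 2-D array of shape `(n_samples, n_features)`, so each 28x28 image has to be flattened into a 784-element vector before training or prediction, which is what the `reshape(-1,28*28)` calls in the SVM cells do. Here is a minimal sketch on a small random stand-in array (not the real MNIST data), assuming the same `uint8` image layout that `mnist.load_data()` returns:

```python
import numpy as np

# Stand-in for X_train: 5 random 28x28 grayscale images with pixel values 0..255,
# using the same uint8 dtype and (n, 28, 28) layout as the MNIST arrays.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)

# Flatten each image into a 784-dimensional row vector; -1 lets NumPy
# infer the number of samples.
flat = images.reshape(-1, 28 * 28)
print(flat.shape)  # (5, 784)

# Row i of the flattened array holds the pixels of image i in row-major order.
assert np.array_equal(flat[2], images[2].ravel())
```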

## Linear SVM
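For reference, with its default settings (`penalty='l2'`, `loss='squared_hinge'`, `C=1.0`), scikit-learn's `LinearSVC` fits each binary classifier by minimizing, up to the solver's tolerance, an L2-regularized squared hinge loss. Here $x_i$ is a flattened image and $y_i$ is $+1$ for the target digit and $-1$ otherwise. This summarizes the scikit-learn documentation rather than anything stated in this notebook, and it glosses over how the underlying liblinear solver handles the intercept $b$:

$$
\min_{w,\,b}\; \frac{1}{2}\lVert w\rVert^2 \;+\; C\sum_{i=1}^{n} \max\bigl(0,\; 1 - y_i\,(w^\top x_i + b)\bigr)^2,
\qquad y_i \in \{-1, +1\}
$$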

Let's first train a linear SVM on a subset of the training data (the first 5000 images), flattening each 28x28 image into a 784-dimensional vector:

In [None]:
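t0 = time()
# LinearSVC trains one linear classifier per digit (one-vs-rest);
# max_iter is raised from the default of 1000 to help the solver converge
clf = svm.LinearSVC(max_iter=5000)
# Flatten the first 5000 training images to shape (5000, 784)
clf.fit(X_train[:5000,:,:].reshape(-1,28*28), y_train[:5000])
print('Time elapsed: %.2fs' % (time()-t0))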
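By default `LinearSVC` handles the ten digit classes with a one-vs-rest scheme: it learns one weight vector and one intercept per class and predicts the class whose linear score is largest. The sketch below checks this on scikit-learn's small bundled 8x8 `load_digits` dataset, a stand-in chosen so the example runs without downloading MNIST; the train/test split sizes and the `[0, 1]` scaling are illustrative assumptions, not part of the notebook.

```python
import numpy as np
from sklearn import svm
from sklearn.datasets import load_digits

# Bundled digits dataset: 1797 images of 8x8 pixels, already flattened to 64 features.
X, y = load_digits(return_X_y=True)
X = X / 16.0  # pixel values are 0..16; scale to [0, 1] to help convergence

clf = svm.LinearSVC(max_iter=5000)
clf.fit(X[:1000], y[:1000])

# One weight vector and one intercept per class (one-vs-rest).
print(clf.coef_.shape)       # (10, 64)
print(clf.intercept_.shape)  # (10,)

# Class scores are X @ coef_.T + intercept_; predict() picks the largest one.
scores = X[1000:] @ clf.coef_.T + clf.intercept_
manual_pred = clf.classes_[np.argmax(scores, axis=1)]
assert np.array_equal(manual_pred, clf.predict(X[1000:]))
```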

In [None]:
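def accuracy(pred):
    """Return the fraction of correct predictions and a boolean error mask,
    comparing pred against the first len(pred) test labels."""
    plen = len(pred)
    errors = pred != y_test[:plen]
    nerrors = np.sum(errors)
    return (plen - nerrors) / plen, errors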

predictions = clf.predict(X_test.reshape(-1,28*28))
acc, _ = accuracy(predictions)
print('Predicted', len(predictions), 'digits with accuracy:', acc)
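The `accuracy` helper computes the fraction of predictions that match the first `len(pred)` test labels and also returns a boolean error mask. The sketch below reproduces that logic on made-up toy labels (not MNIST results) and checks it against `sklearn.metrics.accuracy_score`; `np.flatnonzero` on the mask gives the indices of the misclassified samples, which is convenient for inspecting or plotting them.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Toy labels and predictions (illustrative values only).
y_true = np.array([7, 2, 1, 0, 4, 1, 4, 9])
pred = np.array([7, 2, 1, 0, 9, 1, 4, 8])

errors = pred != y_true                 # boolean mask of misclassified samples
acc = (len(pred) - np.sum(errors)) / len(pred)

print(acc)                              # 0.75
print(np.flatnonzero(errors))           # [4 7]
assert acc == accuracy_score(y_true, pred)
```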