In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.model_selection import train_test_split
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

In [6]:
X, y = fetch_openml('mnist_784', return_X_y=True)

In [7]:
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [8]:
X_train = X_train/255.0
X_test = X_test/255.0

In [10]:
#default is rbf
svm = SVC(C=4)
svm.fit(X_train, y_train)

SVC(C=4)

In [14]:
svm.score(X_train, y_train)

0.9987809523809524

In [11]:
svm.score(X_test, y_test)

0.9826857142857143

In [12]:
svm_linear = SVC(kernel='linear')
svm_linear.fit(X_train, y_train)

SVC(kernel='linear')

In [15]:
svm_linear.score(X_train, y_train)

0.9731238095238095

In [13]:
svm_linear.score(X_test, y_test)

0.9344571428571429

In [9]:
parameters = {'C':[1, 10, 100], 
             'gamma': [1e-2, 1e-3, 1e-4]}

# instantiate a model 
svc_grid_search = SVC(kernel="rbf")

# create a classifier to perform grid search
clf = GridSearchCV(svc_grid_search, param_grid=parameters, scoring='accuracy', n_jobs=-1)

# fit
clf.fit(X_train, y_train)

KeyboardInterrupt: 

In [None]:
cv_results = pd.DataFrame(clf.cv_results_)
cv_results

In [None]:
# code adapted from https://github.com/akshayr89/MNSIST_Handwritten_Digit_Recognition-SVM/blob/master/MNIST_Handwritten_Digit_Recognition-SVM.ipynb
# converting C to numeric type for plotting on x-axis
cv_results['param_C'] = cv_results['param_C'].astype('int')

# # plotting
plt.figure(figsize=(16,6))

# subplot 1/3
plt.subplot(131)
gamma_01 = cv_results[cv_results['param_gamma']==0.01]

plt.plot(gamma_01["param_C"], gamma_01["mean_test_score"])
plt.plot(gamma_01["param_C"], gamma_01["mean_train_score"])
plt.xlabel('C')
plt.ylabel('Accuracy')
plt.title("Gamma=0.01")
plt.ylim([0.60, 1])
plt.legend(['test accuracy', 'train accuracy'], loc='lower right')
plt.xscale('log')

# subplot 2/3
plt.subplot(132)
gamma_001 = cv_results[cv_results['param_gamma']==0.001]

plt.plot(gamma_001["param_C"], gamma_001["mean_test_score"])
plt.plot(gamma_001["param_C"], gamma_001["mean_train_score"])
plt.xlabel('C')
plt.ylabel('Accuracy')
plt.title("Gamma=0.001")
plt.ylim([0.60, 1])
plt.legend(['test accuracy', 'train accuracy'], loc='lower right')
plt.xscale('log')


# subplot 3/3
plt.subplot(133)
gamma_0001 = cv_results[cv_results['param_gamma']==0.0001]

plt.plot(gamma_0001["param_C"], gamma_0001["mean_test_score"])
plt.plot(gamma_0001["param_C"], gamma_0001["mean_train_score"])
plt.xlabel('C')
plt.ylabel('Accuracy')
plt.title("Gamma=0.0001")
plt.ylim([0.60, 1])
plt.legend(['test accuracy', 'train accuracy'], loc='lower right')
plt.xscale('log')

plt.show()

In [None]:
# Mean and Standard Deviation of training scores 
mean_training = np.mean(training_scores, axis=1) 
Standard_Deviation_training = np.std(training_scores, axis=1) 
  
# Mean and Standard Deviation of testing scores 
mean_testing = np.mean(testing_scores, axis=1) 
Standard_Deviation_testing = np.std(testing_scores, axis=1) 
  
# dotted blue line is for training scores and green line is for cross-validation score 
plt.plot(sizes, mean_training, '--', color="b",  label="Training score") 
plt.plot(sizes, mean_testing, color="g", label="Cross-validation score") 
  
# Drawing plot 
plt.title("LEARNING CURVE FOR KNN Classifier") 
plt.xlabel("Training Set Size"), plt.ylabel("Accuracy Score"), plt.legend(loc="best") 
plt.tight_layout() 
plt.show()

In [3]:
np.logspace(-4, -2, 100)

array([0.0001    , 0.00010476, 0.00010975, 0.00011498, 0.00012045,
       0.00012619, 0.00013219, 0.00013849, 0.00014508, 0.00015199,
       0.00015923, 0.00016681, 0.00017475, 0.00018307, 0.00019179,
       0.00020092, 0.00021049, 0.00022051, 0.00023101, 0.00024201,
       0.00025354, 0.00026561, 0.00027826, 0.00029151, 0.00030539,
       0.00031993, 0.00033516, 0.00035112, 0.00036784, 0.00038535,
       0.0004037 , 0.00042292, 0.00044306, 0.00046416, 0.00048626,
       0.00050941, 0.00053367, 0.00055908, 0.0005857 , 0.00061359,
       0.00064281, 0.00067342, 0.00070548, 0.00073907, 0.00077426,
       0.00081113, 0.00084975, 0.00089022, 0.0009326 , 0.00097701,
       0.00102353, 0.00107227, 0.00112332, 0.00117681, 0.00123285,
       0.00129155, 0.00135305, 0.00141747, 0.00148497, 0.00155568,
       0.00162975, 0.00170735, 0.00178865, 0.00187382, 0.00196304,
       0.00205651, 0.00215443, 0.00225702, 0.00236449, 0.00247708,
       0.00259502, 0.00271859, 0.00284804, 0.00298365, 0.00312