In [None]:
from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.base import BaseEstimator
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, precision_recall_curve, roc_curve, roc_auc_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler

# Image Classification Dataset

In [None]:
# Download the MNIST dataset (784 pixel features per image) from OpenML.
# as_frame=False makes fetch_openml return NumPy arrays. Recent scikit-learn
# versions otherwise return a pandas DataFrame, which breaks the positional
# indexing (X[0]) and .reshape(28, 28) used in the cells below.
mnist = fetch_openml('mnist_784', version=1, as_frame=False)

In [None]:
# List the fields available on the returned dataset object.
mnist.keys()

In [None]:
# Pull the feature matrix and the labels out of the dataset object.
X = mnist['data']
y = mnist['target']

In [None]:
# Take the first sample and view its flat pixel vector as a 28x28 image.
some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)

# Render it with the 'binary' colormap and hide the axes.
plt.imshow(some_digit_image, cmap='binary')
plt.axis('off')

In [None]:
# Label of the first sample.
y[0]

# Label Engineering 

In [None]:
# Cast the labels to unsigned 8-bit integers so they can be compared with
# integer literals (e.g. y_train == 5) later on.
y = y.astype(np.uint8)

# Train/Test Split

In [None]:
# The first 60,000 samples are used for training, the remainder for testing.
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

# Binary Label for 5s

In [None]:
# Targets for the binary "5 vs. not-5" task:
# True where the label equals 5, False everywhere else.
y_train_5 = y_train == 5
y_test_5 = y_test == 5

# SGD Binary Classifier for 5s

In [None]:
# Classifier trained with stochastic gradient descent; random_state is pinned
# so that repeated runs produce the same model.
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

In [None]:
# Ask the 5-detector about the sample previewed earlier (predict expects a
# 2-D batch, hence the wrapping list).
sgd_clf.predict([some_digit])

# Accuracy

In [None]:
# 3-fold cross-validated accuracy of the 5-detector.
cross_val_score(
    sgd_clf,
    X_train,
    y_train_5,
    cv=3,
    scoring='accuracy',
)

In [None]:
# Baseline for judging accuracy on a skewed target: a "classifier" that
# ignores its input and always answers "not 5".
class Never5Classifier(BaseEstimator):
    """Dummy estimator that predicts False ("not 5") for every sample."""

    def fit(self, X, y=None):
        """Do nothing (there is nothing to learn) and return self."""
        return self

    def predict(self, X):
        """Return a 1-D boolean array of False, one entry per row of X."""
        # Shape (n_samples,), matching the 1-D target it is scored against;
        # a (n_samples, 1) column is not the shape predictors should return.
        return np.zeros(len(X), dtype=bool)

In [None]:
# Baseline: cross-validated accuracy of the classifier that never says "5".
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring='accuracy')

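Accuracy is a misleading metric on skewed datasets: even a classifier that always predicts "not 5" scores highly, because most images are not 5s.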

# Confusion Matrix

In [None]:
# Out-of-fold predictions: each training sample is predicted by a model that
# did not see it during fitting.
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)

In [None]:
# Rows are the actual classes, columns are the predicted classes.
confusion_matrix(y_train_5, y_train_pred)

- Top row = actual negatives, negative-prediction column first: (TN, FP)

- Bottom row = actual positives, negative-prediction column first: (FN, TP)

- Maximize TN & TP (the diagonal); minimize FP & FN

- Precision = TP / (TP+FP)

- Recall = TP / (TP + FN)

# Precision: When it claims an image is a 5, how often is it correct


In [None]:
# Precision of the cross-validated SGD predictions.
precision_score(y_train_5, y_train_pred)


# Recall: How many 5's does it detect


In [None]:

# Recall of the cross-validated SGD predictions.
recall_score(y_train_5, y_train_pred)

# F1 Score: 2 * (Precision × Recall) / (Precision + Recall)


In [None]:

# F1 score of the cross-validated SGD predictions.
f1_score(y_train_5, y_train_pred)

- High Precision when classifying videos for kids
- High Recall when detecting shoplifters

- The precision/recall trade-off depends on the decision threshold
    - If the threshold is low: higher recall, lower precision
    - If the threshold is high: lower recall, (generally) higher precision
   

# Adjusting Threshold 

In [None]:
# Out-of-fold decision scores (rather than class predictions) for every
# training sample, so that the threshold can be chosen afterwards.
y_scores = cross_val_predict(
    sgd_clf, X_train, y_train_5, cv=3, method='decision_function'
)

# Precision and recall for every candidate threshold.
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)


In [None]:
def plot_precision_recall_curve(precisions, recalls, thresholds):
    """Plot precision and recall as functions of the decision threshold.

    precisions and recalls are expected to hold one more element than
    thresholds (as returned by precision_recall_curve); the final element of
    each is dropped so the three sequences line up.
    """
    plt.plot(thresholds, precisions[:-1], "b--", label="precision")
    # Use a different line style so the two curves can be told apart.
    plt.plot(thresholds, recalls[:-1], "g-", label="recall")
    plt.xlabel("threshold")
    # Without a legend the labels passed above are never displayed.
    plt.legend()
    
# Draw precision and recall against the decision threshold.
plot_precision_recall_curve(precisions, recalls, thresholds)

# 90% Precision?

In [None]:
# Find the lowest threshold whose precision reaches 90%.
# precisions has one more entry than thresholds (its final entry has no
# matching threshold), so only the entries that do have one are searched;
# otherwise argmax could return an index past the end of thresholds.
reaches_90_precision = precisions[:-1] >= 0.9
if not reaches_90_precision.any():
    raise ValueError("no threshold reaches 90% precision")
threshold_90_precision = thresholds[np.argmax(reaches_90_precision)]

# Classify as "5" only when the decision score clears that threshold.
y_train_pred_90 = y_scores >= threshold_90_precision

# Only a cell's last expression is displayed, so show the threshold,
# precision and recall together instead of silently discarding the first two.
(
    threshold_90_precision,
    precision_score(y_train_5, y_train_pred_90),
    recall_score(y_train_5, y_train_pred_90),
)

# ROC Curve: FP rate vs TP rate

In [None]:
# False-positive rate and true-positive rate of the SGD scores at every
# threshold. The thresholds are stored under their own name: reusing
# `thresholds` would overwrite the precision/recall thresholds, and re-running
# the 90%-precision cell afterwards would then index the wrong array.
fpr, tpr, thresholds_roc = roc_curve(y_train_5, y_scores)

def plot_roc_curve(fpr, tpr, label=None):
    """Plot a ROC curve (true-positive rate against false-positive rate).

    fpr, tpr: matching sequences of false-/true-positive rates.
    label: optional legend label for the curve.
    """
    plt.plot(fpr, tpr, linewidth=2, label=label)
    # Dashed diagonal from (0, 0) to (1, 1) as a visual reference line.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")

# Draw the ROC curve of the SGD classifier.
plot_roc_curve(fpr, tpr)

# Area Under the ROC Curve (AUC): the closer to 1, the better

In [None]:
# Area under the SGD classifier's ROC curve.
roc_auc_score(y_train_5, y_scores)


- Use the PR curve when the positive class is rare or when false positives matter more than false negatives
- Use the ROC curve otherwise

# Compare Random Forest to SGD via ROC Curve

In [None]:
# Random forest with a pinned seed; it is scored through predict_proba, so the
# cross-validated output is class probabilities.
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(
    forest_clf, X_train, y_train_5, cv=3, method='predict_proba'
)

In [None]:
# Use the last predict_proba column as the score (presumably the positive
# "is a 5" class; columns follow the classifier's class order).
y_scores_forest = y_probas_forest[:, -1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)

In [None]:
# Overlay both ROC curves: SGD as a dotted blue line, then the random forest
# through the helper (which also draws the dashed diagonal).
plt.plot(fpr, tpr, "b:", label="SGD")
plot_roc_curve(fpr_forest, tpr_forest, "Random Forest")
plt.legend(loc="lower right")

In [None]:
# Area under the random forest's ROC curve.
roc_auc_score(y_train_5, y_scores_forest)

In [None]:
# Out-of-fold class predictions of the random forest, used for the precision
# and recall figures that follow.
y_train_pred_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3)



In [None]:
# Precision of the cross-validated random forest predictions.
precision_score(y_train_5, y_train_pred_forest)


In [None]:
# Recall of the cross-validated random forest predictions.
recall_score(y_train_5, y_train_pred_forest)


- Random Forest is far better Precision/Recall/ROC wise

# Multiclass Classification

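- LR & SVM classifiers are strictly binary
    - However, the One-vs-Rest (One-vs-All) strategy turns them into a multiclass classifier by training one binary classifier per class (ideal for LR and most other algorithms)
    - One-vs-One needs many more classifiers, since it trains one for every pair of labels: N × (N − 1) / 2 for N classes (ideal for SVM)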

In [None]:
# Support vector classifier wrapped in an explicit one-vs-rest scheme,
# trained on the full multiclass labels and tried on the sample digit.
ovr_clf = OneVsRestClassifier(SVC())
ovr_clf.fit(X_train, y_train)
ovr_clf.predict([some_digit])

In [None]:
# Refit the SGD classifier on the multiclass labels (this replaces the binary
# 5-detector held in sgd_clf) and try it on the sample digit.
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])

In [None]:
# Fit the random forest on the multiclass labels and try it on the sample digit.
forest_clf.fit(X_train, y_train)
forest_clf.predict([some_digit])

In [None]:
# Standardize the training features; the pixels are cast to float64 before
# the scaler is fitted and applied.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))

In [None]:
# 3-fold cross-validated multiclass accuracy of SGD on the scaled features.
cross_val_score(
    sgd_clf,
    X_train_scaled,
    y_train,
    cv=3,
    scoring='accuracy',
)

In [None]:
# 3-fold cross-validated multiclass accuracy of the random forest on the
# scaled features.
cross_val_score(
    forest_clf,
    X_train_scaled,
    y_train,
    cv=3,
    scoring='accuracy',
)

In [None]:
# Error Analysis

# Out-of-fold class predictions, one per training sample. cross_val_predict is
# required here: cross_val_score returns only one score per fold, which cannot
# be compared against y_train in a confusion matrix.
y_train_pred = cross_val_predict(forest_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)

# Turn counts into per-class rates by dividing each row by the number of
# samples belonging to that actual class.
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums

# Zero the diagonal (correct predictions) so that only the errors show up.
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)

In [None]:
# Digits are being misclassified as 8, but actual 8s are being correctly classified 