# In [None]:
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from scipy.ndimage import shift
import matplotlib.pyplot as plt
import seaborn as sns

# Fix NumPy's global RNG seed so repeated runs are reproducible
np.random.seed(42)

# Download MNIST as plain arrays: pixel rows in X, digit labels (cast to uint8) in y
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
X = mnist["data"]
y = mnist["target"].astype(np.uint8)

# First 60,000 rows are used for training, the remainder for testing
X_train, y_train = X[:60000], y[:60000]
X_test, y_test = X[60000:], y[60000:]

# Divide pixel values by 255 to rescale them before training
X_train_scaled = X_train / 255.0
X_test_scaled = X_test / 255.0

# Function to shift image by x, y pixels
def shift_image(image, x, y):
    """Translate a flattened 28x28 image and return it flattened again.

    The image is viewed as a 28x28 grid and passed to ``scipy.ndimage.shift``
    with the offset ``[y, x]`` (``y`` along rows, ``x`` along columns) and
    that function's default settings.
    """
    grid = image.reshape(28, 28)
    moved = shift(grid, [y, x])
    return moved.flatten()

# Augment training set (shift by 1 pixel: right, left, down, up)
# Note: For demonstration, use a subset of 12,000 images to manage memory
subset_size = 12000  # Full set (60,000) creates 300,000 samples, which is memory-intensive
X_train_subset = X_train_scaled[:subset_size]
y_train_subset = y_train[:subset_size]

# (dx, dy) offsets applied to every image in the subset
shift_offsets = [(1, 0), (-1, 0), (0, 1), (0, -1)]

# Collect one batch per offset and concatenate once at the end. Stacking a
# single row at a time re-copies the whole growing array on every iteration.
# The loop variables are named dx/dy so they do not overwrite the
# module-level label array `y`.
augmented_batches = [X_train_subset]
for dx, dy in shift_offsets:
    augmented_batches.append(
        np.array([shift_image(image, dx, dy) for image in X_train_subset])
    )
X_train_augmented = np.concatenate(augmented_batches)

# Row order is: originals, then each shifted copy of the subset in turn, so
# the labels are simply the subset labels repeated once per batch.
y_train_augmented = np.tile(y_train_subset, len(shift_offsets) + 1)

print("Augmented Training Set Shape:", X_train_augmented.shape)  # (60000, 784) for subset

# Train SGD Classifier on augmented data
sgd_clf = SGDClassifier(loss='hinge', alpha=0.001, learning_rate='adaptive', eta0=0.1,
                        max_iter=1000, tol=1e-4, random_state=42)
sgd_clf.fit(X_train_augmented, y_train_augmented)

# Evaluate on the scaled test set (the model was fitted on scaled pixels)
y_pred_aug_sgd = sgd_clf.predict(X_test_scaled)
accuracy_aug_sgd = accuracy_score(y_test, y_pred_aug_sgd)
print("Augmented SGD Test Accuracy:", accuracy_aug_sgd)

# Confusion matrix and classification report
cm_aug_sgd = confusion_matrix(y_test, y_pred_aug_sgd)
# The report must be computed before it is printed; printing the name without
# assigning it first raises NameError.
classification_aug_sgd = classification_report(y_test, y_pred_aug_sgd)
print("Augmented SGD Confusion Matrix:\n", cm_aug_sgd)
print("\nAugmented SGD Classification Report:\n", classification_aug_sgd)

# Visualize original vs. augmented confusion matrices
if "cm_orig_sgd" not in globals():
    # No baseline matrix is defined in this script, so plotting it would raise
    # NameError. Train the same SGD configuration on the un-augmented subset
    # so both heatmaps differ only in the augmentation. A `cm_orig_sgd`
    # already provided (e.g. by an earlier notebook cell) is left untouched.
    sgd_clf_orig = SGDClassifier(loss='hinge', alpha=0.001, learning_rate='adaptive', eta0=0.1,
                                 max_iter=1000, tol=1e-4, random_state=42)
    sgd_clf_orig.fit(X_train_subset, y_train_subset)
    cm_orig_sgd = confusion_matrix(y_test, sgd_clf_orig.predict(X_test_scaled))

# Two heatmaps side by side: baseline on the left, augmented model on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
sns.heatmap(cm_orig_sgd, annot=True, fmt='d', cmap='Blues', ax=ax1)
ax1.set_title('Original SGD Confusion Matrix')
ax1.set_xlabel('Predicted')
ax1.set_ylabel('True')
sns.heatmap(cm_aug_sgd, annot=True, fmt='d', cmap='Blues', ax=ax2)
ax2.set_title('Augmented SGD Confusion Matrix')
ax2.set_xlabel('Predicted')
ax2.set_ylabel('True')
plt.tight_layout()
# Save before show() so the written file contains the rendered figure
plt.savefig('confusion_matrices_comparison.png')
plt.show()

# Visualize misclassified images for error patterns
# Boolean mask over the test set: True wherever the prediction is wrong
errors = y_pred_aug_sgd != y_test
error_images, error_labels, error_preds = (
    X_test[errors],
    y_test[errors],
    y_pred_aug_sgd[errors],
)

# (true digit, predicted digit) confusions to illustrate: 2->8, 4->9, 8->5
patterns = [(2, 8), (4, 9), (8, 5)]
for actual, pred in patterns:
    matches_pattern = (error_labels == actual) & (error_preds == pred)
    idx = np.where(matches_pattern)[0]
    if len(idx) == 0:
        # No test image was misclassified this way; nothing to draw
        continue

    # Show only the first example of this confusion
    example = error_images[idx[0]].reshape(28, 28)
    plt.figure(figsize=(3, 3))
    plt.imshow(example, cmap='gray')
    plt.title(f'True: {actual}, Predicted: {pred}')
    plt.axis('off')
    plt.savefig(f'misclassified_{actual}_to_{pred}.png')
    plt.show()

# Compare with Random Forest Baseline
# Fit a 100-tree forest on the full scaled training set (no augmentation)
rf_clf = RandomForestClassifier(n_estimators=100, random_state=42)
rf_clf.fit(X_train_scaled, y_train)

# Score the forest on the scaled test set
y_pred_rf = rf_clf.predict(X_test_scaled)
accuracy_rf = accuracy_score(y_test, y_pred_rf)
cm_rf = confusion_matrix(y_test, y_pred_rf)

print("\nRandom Forest Test Accuracy:", accuracy_rf)
print("Random Forest Confusion Matrix:\n", cm_rf)