# Chapter 3: Classification

## 1. Chapter Overview
**Goal:** In this chapter, we will tackle the Classification task. Unlike Regression (predicting a value), Classification is about predicting a category (class). We will use the famous **MNIST** dataset, which contains 70,000 images of handwritten digits.

**Key Concepts:**
* **Binary Classifiers:** Distinguishing between two classes (e.g., "Is this a 5?" vs. "Not a 5").
* **Performance Measures:** Accuracy is not enough. We need Confusion Matrices, Precision, Recall, and F1 Score.
* **The Precision/Recall Trade-off:** Understanding the balance between catching all positive cases and being correct when predicting positive.
* **Multiclass Classification:** Distinguishing between more than two classes (0, 1, ..., 9).
* **Error Analysis:** Analyzing where the model makes mistakes to improve it.

**Practical Skills:**
* Fetching datasets using `fetch_openml`.
* Training `SGDClassifier` and `RandomForestClassifier`.
* Using Cross-Validation to evaluate a classifier without touching the test set.
* Plotting ROC Curves.

In [None]:
# Setup
import sys
assert sys.version_info >= (3, 5)

import sklearn
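# Compare numeric version parts; comparing version strings is unreliable ("0.9" >= "0.20" is True)
assert tuple(int(part) for part in sklearn.__version__.split(".")[:2]) >= (0, 20)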

import numpy as np
import os

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

## 2. Theoretical Explanation

### 1. MNIST Dataset
MNIST is often called the "Hello World" of Machine Learning. It consists of 70,000 small images of digits handwritten by high school students and employees of the US Census Bureau. Each image is labeled with the digit it represents.

### 2. Binary vs. Multiclass
* **Binary Classifier:** Capable of distinguishing between just two classes (e.g., "5" and "Not-5").
* **Multiclass Classifier:** Capable of distinguishing between more than two classes (e.g., digits 0 through 9).

Some models (e.g., linear classifiers trained with SGD, or SVMs) are inherently binary classifiers, but they can be extended to multiclass classification using strategies like **OvR** (One-versus-the-Rest) or **OvO** (One-versus-One).
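
The difference between the two strategies is easiest to see by counting the binary classifiers each one needs. The sketch below is illustrative only: the names `ovr_tasks` and `ovo_tasks` are made up for this example and are not part of Scikit-Learn.

```python
from itertools import combinations

classes = list(range(10))  # the ten MNIST digit classes

# OvR: one binary "class k vs. everything else" classifier per class.
ovr_tasks = [(k, "rest") for k in classes]

# OvO: one binary classifier per unordered pair of classes.
ovo_tasks = list(combinations(classes, 2))

print("OvR classifiers:", len(ovr_tasks))
print("OvO classifiers:", len(ovo_tasks))  # N * (N - 1) / 2 pairs
```

OvO needs many more classifiers, but each one is trained only on the images of its two classes, which can help algorithms that scale poorly with training set size.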

### 3. Performance Metrics (Crucial!)
**Accuracy** is often a misleading metric for classifiers, especially with *skewed datasets*. For example, only about 10% of MNIST images are 5s, so a dummy classifier that always guesses "Not-5" reaches about 90% accuracy while detecting no 5s at all.
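
A minimal sketch of this trap, using a hypothetical label set (100 fives among 1,000 images) rather than the real MNIST labels:

```python
# Hypothetical skewed label set: 100 of 1,000 images are 5s.
y_true = [True] * 100 + [False] * 900

# A "classifier" that ignores its input and always answers "Not-5".
y_pred = [False] * len(y_true)

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
fives_found = sum(t and p for t, p in zip(y_true, y_pred))

print("Accuracy:", accuracy)        # high, despite learning nothing
print("5s detected:", fives_found)  # not a single positive is found
```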

* **Confusion Matrix:** A table showing the counts of True Positives (TP), True Negatives (TN), False Positives (FP), and False Negatives (FN).
* **Precision:** Accuracy of positive predictions.
$$ Precision = \frac{TP}{TP + FP} $$
* **Recall (Sensitivity):** Ratio of positive instances that are correctly detected.
$$ Recall = \frac{TP}{TP + FN} $$
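* **F1 Score:** The harmonic mean of Precision and Recall. It is high only when both are high.
$$ F_1 = 2 \times \frac{Precision \times Recall}{Precision + Recall} $$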
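
The metrics above can be worked through by hand. The helper below is a sketch with hypothetical counts (80 TP, 20 FP, 40 FN); `precision_recall_f1` is a name invented for this example, not a Scikit-Learn function.

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall, and F1 from confusion-matrix counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Hypothetical counts for a 5-detector.
precision, recall, f1 = precision_recall_f1(tp=80, fp=20, fn=40)
print(f"Precision: {precision:.3f}")
print(f"Recall:    {recall:.3f}")
print(f"F1 Score:  {f1:.3f}")
```

Because the harmonic mean gives more weight to the lower value, the F1 score here sits closer to the recall than a plain average would.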

### 4. ROC Curve
The Receiver Operating Characteristic (ROC) curve plots the **True Positive Rate (Recall)** against the **False Positive Rate** at every possible decision threshold. A purely random classifier follows the diagonal line; a good classifier stays as far away from that diagonal as possible, toward the top-left corner.
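
To make the construction concrete, here is a small pure-Python sketch that computes ROC points and the area under the curve for six hypothetical scores. The functions `roc_points` and `auc` are written for this illustration; in practice Scikit-Learn's `roc_curve` and `roc_auc_score` would typically be used instead.

```python
def roc_points(y_true, scores):
    """Return (FPR, TPR) pairs, one per candidate threshold, highest first."""
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    points = [(0.0, 0.0)]  # threshold above every score: nothing is predicted positive
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and t)
        fp = sum(1 for t, s in zip(y_true, scores) if s >= threshold and not t)
        points.append((fp / n_neg, tp / n_pos))
    return points

def auc(points):
    """Area under the curve, using the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2
               for (x0, y0), (x1, y1) in zip(points, points[1:]))

# Hypothetical labels (1 = positive) and classifier scores.
y_true = [1, 1, 0, 1, 0, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.2]

points = roc_points(y_true, scores)
area = auc(points)
print("ROC points:", points)
print("AUC:", round(area, 3))
```

A random classifier would score an AUC near 0.5 and a perfect one exactly 1.0, so the area summarizes how far the curve sits from the diagonal.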

## 3. Code Reproduction

We will start by loading the MNIST dataset and building a "5-detector".

In [None]:
from sklearn.datasets import fetch_openml

# Fetch MNIST dataset
# as_frame=False ensures we get numpy arrays (standard for image data in this book)
mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.keys()

In [None]:
X, y = mnist["data"], mnist["target"]
y = y.astype(np.uint8) # Convert labels from strings to integers

print("Data shape:", X.shape)
print("Target shape:", y.shape)

# Visualize one digit
some_digit = X[0]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap=mpl.cm.binary)
plt.axis("off")
plt.show()

print("Label:", y[0])

### Data Splitting
The MNIST dataset is already split into a training set (first 60,000 images) and a test set (last 10,000 images).

In [None]:
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

### Training a Binary Classifier (The "5-Detector")
We simplify the problem to only distinguish between two classes: "5" and "Not-5".

In [None]:
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

from sklearn.linear_model import SGDClassifier

# Stochastic Gradient Descent (SGD) classifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_5)

# Predict the digit we visualized earlier
sgd_clf.predict([some_digit])

### Performance Evaluation
We will calculate the Confusion Matrix, Precision, and Recall.

In [None]:
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Get out-of-fold predictions via cross-validation, so the test set stays untouched
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)

# Confusion Matrix
conf_mx = confusion_matrix(y_train_5, y_train_pred)
print("Confusion Matrix:\n", conf_mx)

# Precision and Recall
print("Precision:", precision_score(y_train_5, y_train_pred))
print("Recall:", recall_score(y_train_5, y_train_pred))
print("F1 Score:", f1_score(y_train_5, y_train_pred))

### Multiclass Classification
Scikit-Learn detects when you try to use a binary classification algorithm for a multiclass task and automatically runs OvR (One-versus-the-Rest) or OvO, depending on the algorithm.

In [None]:
# Retrain the same classifier on all ten digit labels (0-9) instead of the 5 / Not-5 target
sgd_clf.fit(X_train, y_train)
print("Predicted digit:", sgd_clf.predict([some_digit]))

# Check the decision function scores for all 10 classes
some_digit_scores = sgd_clf.decision_function([some_digit])
print("Scores for each class:", some_digit_scores)
# argmax returns a column index; map it back to a label through classes_
print("Predicted class:", sgd_clf.classes_[np.argmax(some_digit_scores)])

## 4. Step-by-Step Explanation

### 1. Data Fetching and Reshaping
* **Input:** `fetch_openml('mnist_784')`.
* **Process:** We download the dataset. `X` contains the pixel intensities (784 features per image, from 28x28 pixels). `y` contains the labels.
* **Output:** Arrays `X` (70000, 784) and `y` (70000,).

### 2. Binary Training
We create a target vector `y_train_5` which is `True` for all 5s and `False` for all other digits. The `SGDClassifier` shuffles the training data as it learns, so `random_state=42` ensures reproducible results. It fits a linear decision boundary (a hyperplane in the 784-dimensional pixel space) that separates the 5s from the non-5s as well as possible.
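
As a rough sketch of what such a linear decision boundary means at prediction time: the model computes a signed score for an instance and predicts the positive class when the score is above zero. The weights and bias below are hypothetical values for a 3-feature toy problem, not parameters learned from MNIST.

```python
def decision_score(weights, bias, x):
    """Signed score w . x + b; its sign tells which side of the hyperplane x lies on."""
    return sum(w * xi for w, xi in zip(weights, x)) + bias

def predict(weights, bias, x, threshold=0.0):
    """Predict the positive class when the score exceeds the threshold."""
    return decision_score(weights, bias, x) > threshold

# Hypothetical parameters (MNIST would need 784 weights, one per pixel).
weights = [0.5, -1.0, 2.0]
bias = -1.0

print(predict(weights, bias, [1.0, 0.0, 1.0]))  # score is 1.5
print(predict(weights, bias, [0.0, 2.0, 0.0]))  # score is -3.0
```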

### 3. Confusion Matrix Analysis
* **True Negatives (Top-Left):** Non-5s correctly classified as Non-5s.
* **False Positives (Top-Right):** Non-5s incorrectly classified as 5s.
* **False Negatives (Bottom-Left):** 5s incorrectly classified as Non-5s.
* **True Positives (Bottom-Right):** 5s correctly classified as 5s.

This matrix tells us *how* the model is failing, not just *that* it is failing.
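
The layout described above (rows are actual classes, columns are predicted classes) can be reproduced by hand. This is a sketch on eight hypothetical labels; `binary_confusion_matrix` is a helper written for this example and mirrors the layout Scikit-Learn's `confusion_matrix` uses for a binary problem.

```python
def binary_confusion_matrix(y_true, y_pred):
    """Return [[TN, FP], [FN, TP]]: rows are actual classes, columns are predicted."""
    matrix = [[0, 0], [0, 0]]
    for actual, predicted in zip(y_true, y_pred):
        matrix[int(actual)][int(predicted)] += 1
    return matrix

# Hypothetical 5-detector results: True means "5".
y_true = [True, True, True, False, False, False, False, False]
y_pred = [True, True, False, True, False, False, False, False]

matrix = binary_confusion_matrix(y_true, y_pred)
(tn, fp), (fn, tp) = matrix
print("Confusion matrix:", matrix)
print("TN, FP, FN, TP:", tn, fp, fn, tp)
```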

### 4. Multiclass Strategy (OvR)
When we run `sgd_clf.fit(X_train, y_train)` with 10 classes, Scikit-Learn actually trains 10 binary classifiers, one per digit: a 0-detector, a 1-detector, ..., and a 9-detector.

When you ask for a prediction, it gets the decision score from all 10 classifiers and picks the class with the highest score.
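
The selection step can be sketched in a few lines. The scores below are hypothetical values standing in for the output of the ten detectors; they are not real `decision_function` output.

```python
# Hypothetical decision scores, one per binary detector (index = digit).
scores = [-3.1, -8.4, -2.0, 0.7, -9.5, 2.3, -6.0, -7.2, -1.1, -4.8]
classes = list(range(10))

# Pick the detector that is most confident and return its class label.
best_index = max(range(len(scores)), key=lambda i: scores[i])
predicted_class = classes[best_index]

print("Predicted class:", predicted_class)
```

Looking the label up in `classes` rather than using the index directly matters whenever class labels are not simply 0 to N-1.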

## 5. Chapter Summary

* **Classification vs Regression:** Classification predicts categories; Regression predicts values.
* **MNIST:** The standard dataset for learning image classification basics.
* **Accuracy Trap:** Do not rely solely on accuracy for skewed datasets. Use the **Confusion Matrix**.
* **Precision vs Recall:**
    * Precision: "When it claims it's a 5, is it really a 5?"
    * Recall: "Did it find all the 5s?"
    * You can rarely have both at 100%: raising the decision threshold generally increases precision but reduces recall, and vice versa (Trade-off).
* **Multiclass:** Algorithms can handle multiple classes natively (Random Forest) or use OvR/OvO strategies (SGD, SVM).
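
The precision/recall trade-off summarized above can be illustrated by sweeping the decision threshold over a handful of hypothetical scores. This is a sketch under made-up data; `precision_recall_at` is a helper invented for this example.

```python
def precision_recall_at(y_true, scores, threshold):
    """Precision and recall when predicting positive for scores above the threshold."""
    y_pred = [s > threshold for s in scores]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    fp = sum(1 for t, p in zip(y_true, y_pred) if not t and p)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t and not p)
    # Precision is undefined with no positive predictions; reported as 0.0 here.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

# Hypothetical labels (1 = "5") and decision scores.
y_true = [1, 1, 0, 1, 0, 1, 0, 0]
scores = [0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25]

for threshold in (0.3, 0.5, 0.8):
    precision, recall = precision_recall_at(y_true, scores, threshold)
    print(f"threshold={threshold}: precision={precision:.2f}, recall={recall:.2f}")
```

In this toy data, each increase in the threshold buys precision at the cost of recall, which is the pattern to expect in most real classifiers as well.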