In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Fetch the 'mnist_784' dataset from OpenML and cast features/labels to numeric types
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, parser='auto')
X, y = X.astype(float), y.astype(int)

# Hold out 20% of the samples for testing; the fixed seed keeps the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Switch to plain numpy arrays so that later indexing is purely positional
X_train, X_test = np.array(X_train), np.array(X_test)
y_train, y_test = np.array(y_train), np.array(y_test)

# Standardize features: statistics are fitted on the training split only and
# then re-used unchanged for the test split
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Append a constant column of ones so the bias is learned as one extra weight
X_train_bias = np.hstack([X_train, np.ones((len(X_train), 1))])
X_test_bias = np.hstack([X_test, np.ones((len(X_test), 1))])


class BinaryPerceptron:
    """Linear binary classifier trained with the perceptron update rule.

    Labels are expected to be encoded as +1 / -1.  The model keeps a single
    weight vector and no separate bias term; callers that want a bias append
    a constant feature to their inputs.
    """

    def __init__(self, n_features):
        """Start from an all-zero weight vector of length ``n_features``."""
        self.weights = np.zeros(n_features)

    def predict(self, X):
        """Return the sign of the linear score ``X . weights`` for each row.

        The result is +1 or -1, or 0 for rows whose score is exactly zero
        (which includes every row while the weights are still all zero).
        """
        return np.sign(np.dot(X, self.weights))

    def train(self, X, y, epochs=10):
        """Run ``epochs`` passes of the perceptron rule over the samples.

        Args:
            X: Array-like of shape (n_samples, n_features).
            y: Array-like of n_samples labels, each +1 or -1.
            epochs: Number of full passes over the data.

        Samples are visited in order.  Whenever ``y_i * (w . x_i) <= 0`` the
        sample counts as an error and ``y_i * x_i`` is added to the weights.
        After the last pass the fraction of samples that triggered an update
        during that pass is printed.  An empty training set is a no-op.
        """
        # Coerce to ndarrays so indexing below is always positional, even when
        # the caller passes lists or pandas objects.
        X = np.asarray(X)
        y = np.asarray(y)

        n_samples = X.shape[0]
        if n_samples == 0:
            # Nothing to learn from; also avoids dividing by zero below.
            return

        for epoch in range(epochs):
            errors = 0

            for i in range(n_samples):
                x_i = X[i]
                y_i = y[i]

                # A non-positive margin means the sample is misclassified or
                # sits exactly on the decision boundary.
                if y_i * np.dot(self.weights, x_i) <= 0:
                    self.weights += y_i * x_i
                    errors += 1

            error_rate = errors / n_samples
            if epoch == epochs - 1:
                print(f"Final epoch error: {error_rate:.4f}")


class OneVsAllPerceptron:
    """Multi-class classifier made of one binary perceptron per class.

    Classifier ``k`` is trained to tell class ``k`` (label +1) apart from
    every other class (label -1).  A sample is assigned to the class whose
    classifier produces the largest raw linear score.
    """

    def __init__(self, n_classes, n_features, epochs=10):
        """Create ``n_classes`` untrained perceptrons over ``n_features`` inputs."""
        self.classifiers = [BinaryPerceptron(n_features) for _ in range(n_classes)]
        self.n_classes = n_classes
        self.epochs = epochs

    def train(self, X, y):
        """Train every per-class perceptron for ``self.epochs`` rounds.

        In each round every classifier gets exactly one pass over ``X`` with
        labels recoded to +1 (its own class) / -1 (any other class).
        """
        for epoch in range(self.epochs):
            print(f"Epoch {epoch + 1}/{self.epochs}")

            for class_id in range(self.n_classes):
                # One-vs-rest targets for the class currently being trained
                is_target = y == class_id
                one_vs_rest = np.where(is_target, 1, -1)

                self.classifiers[class_id].train(X, one_vs_rest, epochs=1)

    def predict(self, X):
        """Return, for each row of ``X``, the index of the best-scoring class."""
        n_rows = X.shape[0]
        class_scores = np.empty((n_rows, self.n_classes))

        # Column k holds the raw linear score of the k-th classifier
        for class_id in range(self.n_classes):
            weight_vector = self.classifiers[class_id].weights
            class_scores[:, class_id] = np.dot(X, weight_vector)

        return class_scores.argmax(axis=1)

    def evaluate(self, X, y):
        """Return the error rate (1 - accuracy) of the predictions on ``X``."""
        hits = self.predict(X) == y
        return 1 - np.mean(hits)

# Build and fit a one-vs-all perceptron on the bias-augmented training data
num_classes = 10  # one class per MNIST digit
perceptron = OneVsAllPerceptron(
    n_classes=num_classes,
    n_features=X_train_bias.shape[1],
    epochs=10,
)
perceptron.train(X_train_bias, y_train)

# Misclassification rate on the training split and on the held-out split
train_error = perceptron.evaluate(X_train_bias, y_train)
test_error = perceptron.evaluate(X_test_bias, y_test)

print(f"Final Training Error: {train_error:.4f}")
print(f"Final Test Error: {test_error:.4f}")

Epoch 1/10
Final epoch error: 0.0324
Final epoch error: 0.0318
Final epoch error: 0.0526
Final epoch error: 0.0596
Final epoch error: 0.0486
Final epoch error: 0.0631
Final epoch error: 0.0394
Final epoch error: 0.0474
Final epoch error: 0.0770
Final epoch error: 0.0741
Epoch 2/10
Final epoch error: 0.0179
Final epoch error: 0.0169
Final epoch error: 0.0351
Final epoch error: 0.0414
Final epoch error: 0.0297
Final epoch error: 0.0440
Final epoch error: 0.0243
Final epoch error: 0.0314
Final epoch error: 0.0629
Final epoch error: 0.0592
Epoch 3/10
Final epoch error: 0.0161
Final epoch error: 0.0163
Final epoch error: 0.0332
Final epoch error: 0.0390
Final epoch error: 0.0278
Final epoch error: 0.0398
Final epoch error: 0.0232
Final epoch error: 0.0280
Final epoch error: 0.0612
Final epoch error: 0.0558
Epoch 4/10
Final epoch error: 0.0148
Final epoch error: 0.0145
Final epoch error: 0.0319
Final epoch error: 0.0387
Final epoch error: 0.0260
Final epoch error: 0.0401
Final epoch error: 0