Import Necessary Libraries

In [421]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from ANN import *

Load the MNIST Dataset

In [442]:
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, parser='auto')

Preprocessing (Formatting the data)

In [443]:
X = X.to_numpy() if hasattr(X, 'to_numpy') else np.array(X)
y = y.to_numpy() if hasattr(y, 'to_numpy') else np.array(y)
y = y.astype(int) # numpy can't implicit cast, so explicit cast here

# shuffles the contents
idx = np.random.choice(len(X), 5000, replace=False) # 5000 is sample size, change if necessary
X, y = X[idx], y[idx]
# normalizes the RGB scale for the black-white value
X = X / 255.0

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Model Preparation

In [522]:
# from scratch ANN model
inp = Layer(linear, [he_init, {"seed": 42}], 784) # initiate the layers
hid_1 = Layer(h_tan, [he_init, {"seed": 47}], 64)
hid_2 = Layer(h_tan, [he_init, {"seed": 45}], 64)
hid_3 = Layer(h_tan, [he_init, {"seed": 43}], 64)
hid_4 = Layer(h_tan, [he_init, {"seed": 48}], 64)
out = Layer(softmax, [he_init, {"seed": 44}], 10)
scratch_model = ANN(None, [hid_1], input=inp, output=out, error=CCE) # inititate the model
scratch_model.train(X_train, y_train, batch_size=32, l_rate=0.01, epoch=5, verb=0)

# sklearn library ANN model
mlp = MLPClassifier(hidden_layer_sizes=(64,), activation='identity', solver='sgd', 
                      max_iter=5, random_state=42, learning_rate_init=0.01, batch_size=32) # initiate the model

Model Testing and Comparison

In [523]:
preds = scratch_model.predict(X_test)
test_acc = accuracy_score(y_test, np.argmax(preds, axis=1))
print(f"Test accuracy: {test_acc:.4f}")

mlp.fit(X_train, y_train)
mlp_acc = mlp.score(X_test, y_test)
print(f"sklearn accuracy: {mlp_acc:.4f}")

Test accuracy: 0.8660
sklearn accuracy: 0.9110


