In [1]:
import numpy as np
import numpy.linalg as npl
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml

In [2]:
# Load MNIST data

def load_mnist_data_openml():
    # Returns X_train: (60000, 784), X_test: (10000, 784), scaled [0...1]
    # y_train: (60000,) 0..9 ints, y_test: (10000,)
    print("Downloading mnist...")
    data = fetch_openml('mnist_784', version=1, cache=True)
    print("Done")
    # data = fetch_mldata('MNIST original')
    x = data['data'].astype('float32')
    y = data["target"].astype('int64')
    # Create train-test split (as [Joachims, 2006])
    n_train = 60000
    x_train = x[:n_train]
    y_train = y[:n_train]
    x_test = x[n_train:]
    y_test = y[n_train:]
    return x_train, x_test, y_train, y_test

# Put the data into variables
xtr, xte, ytr, yte = load_mnist_data_openml()
xtr = np.array(xtr.transpose())
xte = np.array(xte.transpose())

# Center the Data
xtr = -128 + xtr
A = np.column_stack((xtr.transpose(), np.ones(60000)))
xte = -128 + xte

Downloading mnist...
Done


In [3]:
# Array containing -1 or 1 depending if the element matches the index
ys = np.array([np.array([1 if j == i else -1 for j in ytr]) for i in range(0, 10)])

# Array with the least square results
ws = np.array([npl.lstsq(A, y, rcond=None)[0] for y in ys])

In [4]:
# Get the number by finding the largest dot product.
def getNumber(squares, vector):
    maximum = np.dot(squares[0][:784], vector)
    index = 0
    for i in range(1, 10):
        d = np.dot(squares[i][:784], vector)
        if d > maximum:
            index = i
            maximum = d
            
    return index

In [5]:
# Print Tests
test_counter = 0
correct_counter = 0

for r in range(0, 10000):
    experimental = getNumber(ws, xte.transpose()[r])
    actual = np.array(yte)[r]
    print(f"result: {experimental}  actual: {actual}")
    if experimental == actual:
        correct_counter += 1
    test_counter += 1
    
print(f"{correct_counter / test_counter * 100}% Correct")

result: 7  actual: 7
result: 2  actual: 2
result: 1  actual: 1
result: 0  actual: 0
result: 4  actual: 4
result: 1  actual: 1
result: 4  actual: 4
result: 9  actual: 9
result: 5  actual: 5
result: 9  actual: 9
result: 0  actual: 0
result: 4  actual: 6
result: 9  actual: 9
result: 0  actual: 0
result: 1  actual: 1
result: 5  actual: 5
result: 9  actual: 9
result: 7  actual: 7
result: 2  actual: 3
result: 4  actual: 4
result: 9  actual: 9
result: 6  actual: 6
result: 6  actual: 6
result: 5  actual: 5
result: 4  actual: 4
result: 0  actual: 0
result: 7  actual: 7
result: 4  actual: 4
result: 0  actual: 0
result: 1  actual: 1
result: 3  actual: 3
result: 1  actual: 1
result: 3  actual: 3
result: 5  actual: 4
result: 7  actual: 7
result: 2  actual: 2
result: 7  actual: 7
result: 1  actual: 1
result: 1  actual: 2
result: 1  actual: 1
result: 1  actual: 1
result: 7  actual: 7
result: 4  actual: 4
result: 1  actual: 2
result: 3  actual: 3
result: 3  actual: 5
result: 3  actual: 1
result: 6  ac