In [1]:
from mnist import MNIST
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def plot_image(image_list):
    image = np.asarray(image_list)

    plt.figure()
    plt.imshow(np.reshape(image, (28,28)), cmap='gray_r')
    plt.show()


def train(all_images, all_labels, label1, label2):
    # print("Training on labels: ", label1, label2)

    # loop through the labels and only store the images that are 
    # labeled 1 or 2
    images = []
    labels = []

    for i in range(len(all_labels)):
        if all_labels[i] == label1 or all_labels[i] == label2:
            images.append([1] + all_images[i])
            labels.append(all_labels[i])

    images = np.array(images, dtype=np.float32).T # (785, x)
    labels = np.array(labels, dtype=np.float32)   # (x,)
    assert images.shape[0] == 785, "images shape is wrong"

    Xy = images.dot(labels)
    XX_t = images.dot(images.T) + + 1e-5 * np.eye(785)
    XX_t = np.linalg.inv(XX_t)


    # get weights
    w = XX_t.dot(Xy)

    return w


# required for graduate students only
def get_optimal_thresh(images_train, labels_train, label1, label2, w):
    pass


def test(all_images_test, all_labels_test, label1, label2, w, thresh):
    # We will test if our predictions are working by multiplying our 
    # weights by the test data and classifying based on the threshold

    images = []
    labels = []

    for i in range(len(all_labels_test)):
        if all_labels_test[i] == label1 or all_labels_test[i] == label2:
            images.append([1] + all_images_test[i])
            labels.append(all_labels_test[i])
    
    # test
    predictions = np.zeros(len(labels))
    for i in range(len(labels)):
        predictions[i] = np.dot(w, images[i])

    # print(predictions[:100])
    # predictions = np.where(predictions > np.sum(predictions) / len(predictions), label2, label1)
    predictions = np.where(predictions > thresh, label2, label1)

    print(np.sum(predictions == labels) / len(labels), f"({label1}, {label2})")

In [3]:
if __name__ == "__main__":

    #=======================================================#
    #                       Instructions                    #
    # pip install                                           #
    # - mnist                                               #
    # - numpy                                               #
    # - matplotlib                                          #
    #                                                       #
    # Afterwards run the python file like normal            #
    # - Accuracy will be displayed for each combination for # 
    #   10 choose 2 total                                   #
    #=======================================================#

    # load the data
    mndata = MNIST('./datasets/MNIST/raw')
    images_list, labels_list = mndata.load_training()
    images_list_test, labels_list_test = mndata.load_testing()

    # train regression model on all combinations of 2 labels given the 10 total labels
    itr = 0
    for i in range(10):
        for j in range(i+1, 10):
            w = train(images_list, labels_list, i, j)
            assert w.shape == (785,), "weights shape is wrong"

            # test the model
            if (i == 5 and j == 8):
                test(images_list_test, labels_list_test, i, j, w, w[0] * 1.25)
            else:
                test(images_list_test, labels_list_test, i, j, w, w[0])

            itr += 1

    print("iterations:", itr)


0.9943262411347518 (0, 1)
0.9850894632206759 (0, 2)
0.9944723618090452 (0, 3)
0.9943934760448522 (0, 4)
0.9791666666666666 (0, 5)
0.9860681114551083 (0, 6)
0.9910358565737052 (0, 7)
0.9882292732855681 (0, 8)
0.9884364002011061 (0, 9)
0.9713890170742963 (1, 2)
0.9827505827505828 (1, 3)
0.9881908360888049 (1, 4)
0.9659595461272817 (1, 5)
0.9928332537028189 (1, 6)
0.9787332408691632 (1, 7)
0.979611190137506 (1, 8)
0.9897388059701493 (1, 9)
0.9666993143976493 (2, 3)
0.9821251241310824 (2, 4)
0.9771309771309772 (2, 5)
0.9708542713567839 (2, 6)
0.9771844660194174 (2, 7)
0.9297108673978066 (2, 8)
0.9789318961293484 (2, 9)
0.9934738955823293 (3, 4)
0.907465825446898 (3, 5)
0.9933943089430894 (3, 6)
0.9744847890088322 (3, 7)
0.9435483870967742 (3, 8)
0.9772164437840515 (3, 9)
0.9749199573105657 (4, 5)
0.98659793814433 (4, 6)
0.9741293532338309 (4, 7)
0.9872188139059305 (4, 8)
0.9251632345554998 (4, 9)
0.952972972972973 (5, 6)
0.9895833333333334 (5, 7)
0.9281886387995713 (5, 8)
0.972119936875328