Quadratic Discriminant Analysis to recognize the handwritten digits in the MNIST dataset

In [None]:
def load_mnist(path, kind='train'):
    """Load MNIST images and labels from gzipped IDX files in `path`.

    :param path: directory containing ``<kind>-labels-idx1-ubyte.gz`` and
        ``<kind>-images-idx3-ubyte.gz``
    :param kind: file-name prefix, e.g. ``'train'`` or ``'t10k'``
    :return: ``(images, labels)`` -- ``images`` is a uint8 array of shape
        ``(N, rows, cols)`` (rows/cols read from the IDX header), ``labels`` a
        uint8 array of shape ``(N,)``
    :raises ValueError: if a file does not carry the expected IDX magic number
        or is too short to hold its header
    """
    import os
    import gzip
    import numpy as np

    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        label_bytes = lbpath.read()
    with gzip.open(images_path, 'rb') as imgpath:
        image_bytes = imgpath.read()

    # IDX files begin with big-endian uint32 fields: a magic number
    # (2049 for label files, 2051 for image files) followed by the dimension
    # sizes -- [count] for labels, [count, rows, cols] for images.
    label_header = np.frombuffer(label_bytes, dtype='>u4', count=2)
    image_header = np.frombuffer(image_bytes, dtype='>u4', count=4)
    if label_header[0] != 2049:
        raise ValueError(
            f'{labels_path} is not an IDX label file (magic number {label_header[0]})'
        )
    if image_header[0] != 2051:
        raise ValueError(
            f'{images_path} is not an IDX image file (magic number {image_header[0]})'
        )

    labels = np.frombuffer(label_bytes, dtype=np.uint8, offset=8)
    rows, cols = int(image_header[2]), int(image_header[3])
    images = np.frombuffer(image_bytes, dtype=np.uint8,
                           offset=16).reshape(len(labels), rows, cols)

    return images, labels


In [None]:
#Imports
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import confusion_matrix


# Load the training and test splits from the local 'data' directory.
X_train, y_train = load_mnist('data', kind='train')
X_test, y_test = load_mnist('data', kind='t10k')

# Group the training images by label: digit -> list of that digit's images,
# in dataset order (keys are inserted in order of first appearance).
digit_data = {}
for image, label in zip(X_train, y_train):
    digit_data.setdefault(label, []).append(image)
        
        


In [None]:
# Pixel-wise mean and standard deviation image for every digit class.
mean_dict = {}
std_dict = {}
for digit, images in digit_data.items():
    # Stack the list once; calling np.mean and np.std on the list would
    # convert it to an array twice.
    stacked = np.asarray(images)
    mean_dict[digit] = stacked.mean(axis=0)
    std_dict[digit] = stacked.std(axis=0)
    

In [None]:
# Plot each digit's mean image next to its per-pixel standard deviation image.
for i in range(10):
    fig, ax = plt.subplots(1, 2)

    ax[0].title.set_text('Mean Digits of ' + str(i))
    ax[1].title.set_text('Standard Deviation Digits of ' + str(i))
    # Later cells replace these entries with flattened vectors; reshaping back
    # to the 28x28 MNIST image size keeps this cell re-runnable.
    # Both panels use the same grayscale colormap so they are visually comparable.
    ax[0].imshow(np.reshape(mean_dict[i], (28, 28)), cmap="gray")
    ax[1].imshow(np.reshape(std_dict[i], (28, 28)), cmap="gray")

In [None]:
# Quadratic weights for Quadratic Discriminant Analysis.
# Each class covariance is modelled as diagonal (pixels treated as independent):
#   Sigma_k = diag(std_k**2 + VARIANCE_SMOOTHING)
# The smoothing term keeps Sigma_k invertible when some pixels have zero variance.
VARIANCE_SMOOTHING = 0.1
w = {}
tmp_dict = {}  # digit -> regularised covariance matrix Sigma_k (used by later cells)
for i in range(10):
    std_dict[i] = std_dict[i].flatten()
    # Add the smoothing term exactly once to every pixel's variance. Applying it
    # inside a per-pixel loop would give early pixels far more smoothing than late ones.
    variances = std_dict[i] ** 2 + VARIANCE_SMOOTHING
    tmp_dict[i] = np.diag(variances)
    # w_k = -1/2 Sigma_k^{-1}; the inverse of a diagonal matrix is the diagonal
    # of reciprocals, so no general matrix inversion is needed.
    w[i] = np.diag(-0.5 / variances)


In [None]:
# Linear term of each class's discriminant: n_k = Sigma_k^{-1} mu_k.
n = {}
for i in range(10):
    # Later cells (the bias term and prediction) use the flattened means.
    mean_dict[i] = mean_dict[i].flatten()
    # Solve Sigma_k x = mu_k directly instead of forming the explicit inverse.
    n[i] = np.linalg.solve(tmp_dict[i], mean_dict[i])
print(n)

In [None]:
# Bias term of each class's discriminant:
#   b_k = -1/2 mu_k^T Sigma_k^{-1} mu_k - 1/2 ln|Sigma_k|
# NOTE(review): no ln P(class) prior term is included, i.e. all digits are
# treated as equally likely; add the log class frequency if that matters.
b = {}
for i in range(10):
    mu = mean_dict[i]
    cov = tmp_dict[i]
    # mu^T Sigma^{-1} mu via a linear solve rather than an explicit inverse.
    b[i] = -0.5 * (mu @ np.linalg.solve(cov, mu))
    # slogdet avoids overflow/underflow of the raw determinant of a large matrix.
    sign, logdet = np.linalg.slogdet(cov)
    b[i] = b[i] - 0.5 * logdet
print(b)


In [None]:
# Display the per-class bias terms once more.
print(f"{b}")

In [None]:
def find_accuracy(actual, predicted):
    """Return the fraction of positions where ``predicted`` equals ``actual``.

    :param actual: true labels
    :param predicted: predicted labels, same length as ``actual``
    :return: accuracy in the range [0, 1]
    :raises ValueError: if the inputs are empty or differ in length
    """
    if len(actual) != len(predicted):
        raise ValueError(
            f'length mismatch: {len(actual)} actual vs {len(predicted)} predicted labels'
        )
    if len(actual) == 0:
        raise ValueError('cannot compute accuracy of zero samples')
    correct = sum(1 for truth, guess in zip(actual, predicted) if truth == guess)
    return correct / len(actual)
    
def find_confusion_matrix(y, values):
    """Return the confusion matrix of predicted labels against true labels.

    :param y: actual (true) labels
    :param values: predicted labels, same length as ``y``
    :return: confusion matrix from ``sklearn.metrics.confusion_matrix``; entry
        ``[i, j]`` counts samples whose true label is the i-th label and whose
        predicted label is the j-th label (labels in sorted order)
    """
    return confusion_matrix(y, values)

def predict(x_test, y_test, w, n, bias=None):
    """Classify images with per-class quadratic discriminants and score them.

    For every image x (flattened to a vector) and class k the score is
    ``g_k(x) = x^T w[k] x + n[k]^T x + bias[k]``; the predicted label is the
    k with the largest score.

    :param x_test: array of test images; each image is flattened to a vector
    :param y_test: true labels for ``x_test``
    :param w: dict mapping class index 0..K-1 to its quadratic-term matrix
    :param n: dict mapping class index to its linear-term vector
    :param bias: optional dict mapping class index to its scalar bias term;
        when omitted, the notebook-level ``b`` is used (previous behaviour)
    :return: ``(accuracy, confusion)`` from ``find_accuracy`` and
        ``find_confusion_matrix``
    :raises ValueError: if ``x_test`` is empty
    """
    if bias is None:
        bias = b
    if len(x_test) == 0:
        raise ValueError('x_test is empty; nothing to classify')
    print('\nTesting mnist using Discriminant Analysis')
    # One row per image, as a float64 matrix, so each class is scored for all
    # images with a few matrix products instead of a Python loop per image.
    X = np.asarray(x_test, dtype=np.float64).reshape(len(x_test), -1)
    num_classes = len(w)
    scores = np.empty((X.shape[0], num_classes))
    for k in tqdm(range(num_classes), total=num_classes):
        # Row-wise x^T W_k x, plus the linear and bias terms.
        scores[:, k] = np.sum((X @ w[k]) * X, axis=1) + X @ n[k] + bias[k]
    values = np.argmax(scores, axis=1)
    print(values[:10])
    # Compare against the true labels (previously referenced an undefined `yhat`).
    accuracy = find_accuracy(y_test, values)
    confusion = find_confusion_matrix(y_test, values)

    return accuracy, confusion

In [None]:
# Run the discriminant classifier on the test split, then show accuracy and
# the confusion matrix on separate lines.
acc, conf = predict(X_test, y_test, w, n)
print(acc, conf, sep="\n")