## ALTERNATIVE APPROACH TO K-MEANS SQUARE FOR MNIST DATA

### PROBLEM
With k-means we are not able to obtain clusters for certain digits, even with higher values of k.

### HYPOTHESIZED SOLUTION
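Form one cluster per digit class, using the mean of all training-set points of that class as the cluster center. (As implemented below, these centers are used directly to classify test points; no further k-means iterations are run.)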

In [10]:
import numpy as np
import pandas as pd
import datetime
from matplotlib import pyplot as plt
%matplotlib inline

In [11]:
# Read the training CSV into a NumPy array. The split cell treats column 0 as
# the label and the remaining columns as pixel features.
ds = pd.read_csv('./train.csv')
data = ds.to_numpy()

print(data.shape)
print(data)

(42000, 785)
[[1 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 ...
 [7 0 0 ... 0 0 0]
 [6 0 0 ... 0 0 0]
 [9 0 0 ... 0 0 0]]


In [12]:
# Deterministic hold-out: first 80% of rows for training, last 20% for testing.
# Column 0 is the label; columns 1.. are the pixel features.
split = int(0.8 * data.shape[0])
X_train, y_train = data[:split, 1:], data[:split, 0]
X_test, y_test = data[split:, 1:], data[split:, 0]

In [13]:
def hellinger_dist(x1, x2):
    """Return sqrt(0.5 * sum((sqrt(x1) - sqrt(x2)) ** 2)).

    This is the Hellinger distance formula. The textbook definition assumes
    x1 and x2 are probability distributions (each summing to 1). No
    normalization is applied here, so on raw pixel intensities the result is
    an unnormalized variant. Negative entries make np.sqrt return nan.
    """
    root_gap = np.sqrt(x1) - np.sqrt(x2)
    return np.sqrt(np.sum(root_gap ** 2) / 2)

def formClusters(X_data, y_data):
    """Group samples by label and compute each group's mean as its center.

    Parameters
    ----------
    X_data : array with a ``.shape`` attribute, indexed by row
        One sample per row.
    y_data : indexable sequence of labels
        Aligned with the rows of ``X_data``.

    Returns
    -------
    dict
        Maps each label, in order of first appearance, to
        ``{'center': np.mean(rows, axis=0), 'points': rows}``, where ``rows``
        is the list of that label's samples.
    """
    groups = {}
    for row_idx in range(X_data.shape[0]):
        label = y_data[row_idx]
        # A label seen for the first time gets a fresh entry; 'center' is filled below.
        entry = groups.setdefault(label, {'center': [], 'points': []})
        entry['points'].append(X_data[row_idx])

    for entry in groups.values():
        entry['center'] = np.mean(entry['points'], axis=0)
    return groups

def findClusterModified(x, cluster_map=None, dist=None):
    """Return the label of the cluster whose center is nearest to ``x``.

    Parameters
    ----------
    x : array-like
        A single sample.
    cluster_map : dict, optional
        Maps ``label -> {'center': ..., ...}``, as built by ``formClusters``.
        Defaults to the notebook-level ``clusters`` variable, which is what
        this function previously read implicitly. Pass it explicitly to avoid
        depending on hidden notebook state.
    dist : callable, optional
        ``dist(a, b) -> float``. Defaults to ``hellinger_dist``.

    Returns
    -------
    The key of the nearest cluster. On ties, the first key in
    ``cluster_map`` iteration order wins.

    Raises
    ------
    ValueError
        If ``cluster_map`` is empty.
    """
    if cluster_map is None:
        cluster_map = clusters  # backward-compatible fallback to the global
    if dist is None:
        dist = hellinger_dist
    if not cluster_map:
        raise ValueError("cluster_map is empty; build it with formClusters first")
    # min() returns the first minimal item, matching the tie-break of the
    # former stable sort while avoiding the sort itself.
    return min(cluster_map, key=lambda label: dist(x, cluster_map[label]['center']))

In [14]:
# One cluster per label; each center is the mean of that label's training rows.
clusters = formClusters(X_data=X_train, y_data=y_train)

In [15]:
# Evaluate on the held-out rows: assign each sample to the nearest class-mean
# center and count how often that matches its true label.
correct = 0
incorrect = 0
start = datetime.datetime.now()

for sample, true_label in zip(X_test, y_test):
    if findClusterModified(sample) == true_label:
        correct += 1
    else:
        incorrect += 1

end = datetime.datetime.now()
total = correct + incorrect
# Guard against an empty test split instead of raising ZeroDivisionError.
accuracy = (correct / total) * 100 if total else float('nan')
print('Accuracy for Modified K-Means Square on MNIST Data: ', accuracy)
# total_seconds() includes days and sub-second precision; timedelta.seconds does not.
print('Time Taken: ', (end - start).total_seconds(), 'seconds')

Accuracy for Modified K-Means Square on MNIST Data:  69.13095238095238
Time Taken:  2 seconds
