# MXNet MNIST MLP

Dimension: 784x10<br>
Training size: 60,000<br>
Test size: 10,000

In [1]:
import mxnet as mx

In [2]:
import numpy as np
from functools import reduce
import datetime

### Load data

In [3]:
from mnist import MNIST
#you can find python-mnist source code on https://github.com/sorki/python-mnist

# Point the loader at the directory that holds the MNIST files (change for data path).
datahandler = MNIST('../mnist_data_loader')
# Keep each split exactly as the loader returns it; later cells read element 0
# as the images and element 1 as the labels.
train_data, test_data = datahandler.load_training(), datahandler.load_testing()

In [4]:
# Element 0 of each split holds the images; turn both splits into NumPy arrays
# so they can be handed to the MXNet iterators further down.
train_image_array, test_image_array = (
    np.asarray(split[0]) for split in (train_data, test_data)
)

In [5]:
# One-hot encode the training labels: row k gets a 1 in the column named by the
# k-th label and 0 everywhere else. NumPy integer-array indexing fills every
# row in one step instead of building one Python list per sample in a loop.
train_labels = np.asarray(train_data[1], dtype=int)
train_label_array = np.zeros((train_labels.size, 10), dtype=int)
train_label_array[np.arange(train_labels.size), train_labels] = 1

# Plain list-of-lists form, kept for any code that still reads it.
train_label_list = train_label_array.tolist()

In [6]:
# One-hot encode the test labels: row k gets a 1 in the column named by the
# k-th label and 0 everywhere else. NumPy integer-array indexing fills every
# row in one step instead of building one Python list per sample in a loop.
test_labels = np.asarray(test_data[1], dtype=int)
test_label_array = np.zeros((test_labels.size, 10), dtype=int)
test_label_array[np.arange(test_labels.size), test_labels] = 1

# Plain list-of-lists form, kept for any code that still reads it.
test_label_list = test_label_array.tolist()

### Create iterators for training and testing

In [7]:
# Training iterator: pairs the images with their one-hot labels and serves
# them in batches of 10 samples.
train_iter = mx.io.NDArrayIter(data=train_image_array, label=train_label_array, batch_size=10)

In [8]:
# Prediction iterator: images only (no labels are passed), one sample per batch.
test_iter = mx.io.NDArrayIter(data=test_image_array, batch_size=1)

### Create and train model

In [9]:
# Layer widths used by the two fully connected layers below.
hidden_units = 784
num_outputs = 10

# Network graph: input -> fully connected -> sigmoid -> fully connected -> softmax output.
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, num_hidden=hidden_units)
act1 = mx.symbol.Activation(data=fc1, act_type="sigmoid")
fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=num_outputs)
mlp = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')

In [10]:
# Feed-forward model wrapping the MLP symbol, configured with num_epoch=10 and
# a learning rate of 0.01.
model = mx.model.FeedForward(symbol=mlp, num_epoch=10, learning_rate=.01)

In [11]:
# Time the training run: take a timestamp, fit on the training iterator, then
# report the wall-clock difference.
t = datetime.datetime.now()
model.fit(X=train_iter)
elapsed = datetime.datetime.now() - t
print("Tiempo de ejecución: {}".format(elapsed))

Tiempo de ejecución: 0:04:06.063847


### Make predictions

In [12]:
# Run the trained model over the test iterator and keep its outputs in `p`.
p = model.predict(X=test_iter)

In [17]:
# Check prediction dimensions.
prediction_shape = p.shape
print("Dimensión predicción: {0}".format(prediction_shape))

Dimensión predicción: (10000, 10)


# Metrics

### Confusion matrix
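Rows are the actual labels and columns are the predicted labels. The accuracy, recall, and precision below are all computed from this matrix.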

In [18]:
# Confusion matrix: rows are the actual class, columns the predicted class.
num_classes = test_label_array.shape[1]

# The predicted class is the highest-scoring column of each prediction row
# (np.argmax returns the first maximum on ties). The actual class is the
# position of the 1 in the matching one-hot label row. Only as many label rows
# as there are predictions are used.
predicted_labels = np.argmax(p, axis=1)
actual_labels = np.argmax(test_label_array[:len(p)], axis=1)

conf_mtx = np.zeros([num_classes, num_classes])
# np.add.at accumulates repeated (actual, predicted) pairs; a plain
# fancy-indexed `+= 1` would count each distinct pair only once.
np.add.at(conf_mtx, (actual_labels, predicted_labels), 1)

# Print counts in plain notation rather than scientific notation.
np.set_printoptions(suppress=True)
print(conf_mtx)

[[  965.     0.     0.     0.     0.     3.     5.     1.     5.     1.]
 [    0.  1123.     0.     3.     0.     1.     1.     0.     7.     0.]
 [   10.     4.   947.    22.     5.     1.     2.    10.    27.     4.]
 [    2.     1.     8.   966.     1.     6.     0.     6.    17.     3.]
 [    2.     2.    10.     1.   897.     1.     7.     0.     6.    56.]
 [   11.     5.     3.    31.     4.   808.    15.     1.     6.     8.]
 [   13.     4.     3.     0.    10.    11.   912.     1.     4.     0.]
 [    1.    19.    21.     9.     4.     0.     0.   921.     2.    51.]
 [    3.     5.     3.    28.     9.     5.    10.     6.   875.    30.]
 [    3.     7.     1.    14.    15.     1.     1.     3.     6.   958.]]


### Accuracy

In [19]:
# Accuracy = correctly classified samples (the diagonal of the confusion
# matrix) divided by all classified samples. The denominator is read from the
# matrix itself instead of a hard-coded 10000, so the value stays correct if
# the number of test samples changes.
acc = np.trace(conf_mtx) / conf_mtx.sum()

In [20]:
# Report the overall accuracy with five decimal places.
accuracy_report = 'Accuracy: {:.5f}'.format(acc)
print(accuracy_report)

Accuracy: 0.93720


### Recall

In [21]:
def recall(i):
    """Return the recall of class ``i``.

    Divides the diagonal entry of row ``i`` of the global ``conf_mtx`` by the
    sum of that row.
    """
    row = conf_mtx[i]
    return row[i] / row.sum()

In [22]:
# Per-class recall for the ten digit classes, followed by their unweighted mean.
recall_values = [recall(digit) for digit in range(10)]
for digit, rcl in enumerate(recall_values):
    print('Recall {}: {:.5f}'.format(digit, rcl))
recall_sum = sum(recall_values)
print()
print('Recall mean: {:.5f}'.format(recall_sum/10))

Recall 0: 0.98469
Recall 1: 0.98943
Recall 2: 0.91764
Recall 3: 0.95644
Recall 4: 0.91344
Recall 5: 0.90583
Recall 6: 0.95198
Recall 7: 0.89591
Recall 8: 0.89836
Recall 9: 0.94945

Recall mean: 0.93632


### Precision

In [23]:
def precision(i):
    """Return the precision of class ``i``.

    Divides the diagonal entry of column ``i`` of the global ``conf_mtx`` by
    the sum of that column.
    """
    column = conf_mtx[:, i]
    return column[i] / column.sum()

In [24]:
# Per-class precision for the ten digit classes, followed by their unweighted mean.
precision_values = [precision(digit) for digit in range(10)]
for digit, label_precision in enumerate(precision_values):
    print('Precision {}: {:.5f}'.format(digit, label_precision))
precision_sum = sum(precision_values)
print()
print('Precision mean: {:.5f}'.format(precision_sum/10))

Precision 0: 0.95545
Precision 1: 0.95983
Precision 2: 0.95080
Precision 3: 0.89944
Precision 4: 0.94921
Precision 5: 0.96535
Precision 6: 0.95698
Precision 7: 0.97050
Precision 8: 0.91623
Precision 9: 0.86229

Precision mean: 0.93861
