In [1]:
import mxnet as mx
import numpy as np
import mnist
import logging

In [2]:
train_x, train_y,test_x, test_y = mnist.readmnist('data/mnist')

In [3]:
epochs = 50
batch_size = 100

In [None]:
#Setting iterators

train_iter = mx.io.NDArrayIter(data = train_x, label = train_y, batch_size = batch_size,data_name = 'data', 
                               label_name = 'softmax_label',last_batch_handle = "discard", shuffle = True)
data = mx.sym.Variable("data")
fc1 = mx.sym.FullyConnected(data = data, name ="fc1", num_hidden = 64, flatten = 1)
relu1 = mx.sym.Activation(data = fc1, act_type = "relu", name = 'relu1')
fc2 = mx.sym.FullyConnected(data = relu1, name = "fc2", num_hidden = 10, flatten = 1)
out = mx.sym.SoftmaxOutput(data = fc2, name = "softmax")

mod = mx.mod.Module(out,context = mx.cpu())
print(mod.data_names)
print(mod.label_names)

mod.bind(data_shapes = train_iter.provide_data, label_shapes = train_iter.provide_label, for_training = True)
mod.init_params(initializer = mx.init.Xavier(magnitude = 1.0))
mod.init_optimizer(optimizer = "sgd", optimizer_params=(('learning_rate',0.01),), force_init = False)
mod.fit(train_data = train_iter, eval_metric = "accuracy", num_epoch = epochs)
mod.score(eval_data = train_iter, eval_metric = ["acc","mse"])

In [None]:
test_iter = mx.io.NDArrayIter(data = test_x, label = test_y, batch_size = batch_size, shuffle = False, last_batch_handle = "discard")

total_correct = 0
batch_correct = 0
for preds,i, batch in mod.iter_predict(test_iter):
    label = batch.label[0].asnumpy().argmax(axis = 1)
    pred_label = preds[0].asnumpy().argmax(axis = 1)
    batch_correct = np.sum(label == pred_label)
    total_correct += batch_correct

print("Test Set accuracy {%2.2f}" %(total_correct/len(test_y)))

In [4]:
#Another way to do this
train_iter = mx.io.NDArrayIter(data = train_x, label = train_y, batch_size = batch_size,data_name = 'data', 
                               label_name = 'softmax_label',last_batch_handle = "discard", shuffle = True)
data = mx.sym.Variable("data")
fc1 = mx.sym.FullyConnected(data = data,name = "fc1", num_hidden = 64, flatten = 1)
relu1 = mx.sym.Activation(data = fc1, name = "relu1", act_type = "relu")
fc2 = mx.sym.FullyConnected(data = relu1, name = "fc2", num_hidden = 10)
softmax = mx.sym.SoftmaxOutput(data = fc2, name = 'softmax')

mod = mx.mod.Module(softmax)
mod.bind(data_shapes = train_iter.provide_data, label_shapes = train_iter.provide_label)
mod.init_params()
mod.init_optimizer(optimizer_params = {'learning_rate': 0.01, 'momentum' : 0.9})
metric = mx.metric.create("acc")

for i in range(epochs):
    for i_iter, batch in enumerate(train_iter):
        mod.forward(batch)
        mod.update_metric(metric, batch.label)
        mod.backward()
        mod.update()

    for name,val in metric.get_name_value():
        print('epoch %03d: %s = %f'%(i,name,val))
    metric.reset()
    train_iter.reset()

test_iter = mx.io.NDArrayIter(data = test_x, label = test_y, batch_size = batch_size, shuffle = False, last_batch_handle = "discard")

total_correct = 0
batch_correct = 0
for preds,i, batch in mod.iter_predict(test_iter):
    label = batch.label[0].asnumpy().argmax(axis = 1)
    pred_label = preds[0].asnumpy().argmax(axis = 1)
    batch_correct = np.sum(label == pred_label)
    total_correct += batch_correct

print("Test Set accuracy {%2.2f}" %(1.0*total_correct/len(test_y)))

epoch 000: accuracy = 0.900000
epoch 001: accuracy = 0.900000
epoch 002: accuracy = 0.900000
epoch 003: accuracy = 0.900000
epoch 004: accuracy = 0.900000
epoch 005: accuracy = 0.900002
epoch 006: accuracy = 0.900002
epoch 007: accuracy = 0.900005
epoch 008: accuracy = 0.900010
epoch 009: accuracy = 0.900013
epoch 010: accuracy = 0.900020
epoch 011: accuracy = 0.900030
epoch 012: accuracy = 0.900043
epoch 013: accuracy = 0.900055
epoch 014: accuracy = 0.900078
epoch 015: accuracy = 0.900113
epoch 016: accuracy = 0.900158
epoch 017: accuracy = 0.900182
epoch 018: accuracy = 0.900238
epoch 019: accuracy = 0.900312
epoch 020: accuracy = 0.900390
epoch 021: accuracy = 0.900463
epoch 022: accuracy = 0.900560
epoch 023: accuracy = 0.900675
epoch 024: accuracy = 0.900802
epoch 025: accuracy = 0.900925
epoch 026: accuracy = 0.901070
epoch 027: accuracy = 0.901217
epoch 028: accuracy = 0.901383
epoch 029: accuracy = 0.901513
epoch 030: accuracy = 0.901692
epoch 031: accuracy = 0.901897
epoch 03