## References

* [MNIST](http://yann.lecun.com/exdb/mnist/)
* [Deep MNIST for experts](https://www.tensorflow.org/get_started/mnist/pros)

In [None]:
import numpy as np
import tensorflow as tf

from tensorflow.examples.tutorials.mnist.input_data import read_data_sets as load_data

# Fix the NumPy and TensorFlow seeds so that repeated runs are comparable.
np.random.seed(0)
# NOTE(review): tf.set_random_seed is documented as a graph-level seed for the
# current default graph; graphs created or reset later presumably need it
# applied again -- confirm against the TensorFlow documentation.
tf.set_random_seed(0)

In [None]:
# Geometry of the problem: square images with IMAGE_SIZE pixels per side and
# CLASS_COUNT target classes.
IMAGE_SIZE = 28
CLASS_COUNT = 10

# Load the dataset from the 'data' directory with one-hot encoded labels.
data = load_data('data', one_hot=True)

# Report how many examples each split contains.
for template, subset in (
    ('Training: {}', data.train),
    ('Validation: {}', data.validation),
    ('Testing: {}', data.test),
):
    print(template.format(subset.num_examples))

In [None]:
# Reset TensorFlow's default graph before (re)defining the classes below.
tf.reset_default_graph()

class Model:
    """Convolutional classifier: two conv/pool layers, a dense layer with
    dropout, and a linear read-out followed by a softmax.

    Layer sizes that depend on the problem geometry are derived from the
    module-level ``IMAGE_SIZE`` and ``CLASS_COUNT`` constants rather than
    being hard-coded, so the model stays consistent with the placeholders
    that are built from the same constants.

    Attributes:
        x: The input tensor given to the constructor; it is reshaped to
            ``[-1, IMAGE_SIZE, IMAGE_SIZE, 1]`` before the first convolution.
        y: The label tensor given to the constructor (stored, not used here).
        keep: Float placeholder passed as the second argument of
            ``tf.nn.dropout``; it has no default and must be fed on each run.
        y_logit: Output of the final linear layer, ``CLASS_COUNT`` columns.
        y_score: Softmax of ``y_logit``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

        # Each pooling layer uses a stride of 2 with 'SAME' padding, which
        # halves the spatial side rounding up. Two such layers are applied.
        pooled_size = (IMAGE_SIZE + 1) // 2
        pooled_size = (pooled_size + 1) // 2
        flat_size = pooled_size * pooled_size * 64

        with tf.name_scope('model'):
            x_image = tf.reshape(x, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])

            with tf.name_scope('layer1'):
                # 5x5 convolution, 1 input channel -> 32 feature maps.
                W = self._create_weight([5, 5, 1, 32])
                b = self._create_bias([32])
                h = tf.nn.relu(self._create_convolution(x_image, W) + b)
                h = self._create_pooling(h)

            with tf.name_scope('layer2'):
                # 5x5 convolution, 32 -> 64 feature maps, then flatten.
                W = self._create_weight([5, 5, 32, 64])
                b = self._create_bias([64])
                h = tf.nn.relu(self._create_convolution(h, W) + b)
                h = self._create_pooling(h)
                h = tf.reshape(h, [-1, flat_size])

            with tf.name_scope('layer3'):
                # Fully connected layer with 1024 units followed by dropout.
                W = self._create_weight([flat_size, 1024])
                b = self._create_bias([1024])
                h = tf.nn.relu(tf.matmul(h, W) + b)
                self.keep = tf.placeholder(tf.float32)
                h = tf.nn.dropout(h, self.keep)

            with tf.name_scope('layer4'):
                # Linear read-out with one column per class.
                W = self._create_weight([1024, CLASS_COUNT])
                b = self._create_bias([CLASS_COUNT])
                self.y_logit = tf.matmul(h, W) + b
                self.y_score = tf.nn.softmax(self.y_logit)

    def _create_weight(self, shape):
        """Return a variable of ``shape`` drawn from a truncated normal
        distribution with standard deviation 0.1."""
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

    def _create_bias(self, shape):
        """Return a variable of ``shape`` initialised to the constant 0.1."""
        return tf.Variable(tf.constant(0.1, shape=shape))

    def _create_convolution(self, x, W):
        """Convolve ``x`` with the filters ``W`` using unit strides and
        'SAME' padding."""
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

    def _create_pooling(self, x):
        """Apply 2x2 max pooling with stride 2 and 'SAME' padding."""
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

class Objective:
    """Training objective for a model: mean softmax cross-entropy minimised
    with Adam.

    Keyword arguments are forwarded unchanged to ``tf.train.AdamOptimizer``.

    Attributes:
        loss: Mean of the per-example cross-entropy between ``model.y`` and
            ``model.y_logit``; the op is named 'loss'.
        train: The op returned by the optimizer's ``minimize`` for ``loss``.
    """

    def __init__(self, model, **arguments):
        with tf.name_scope('objective'):
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                labels=model.y,
                logits=model.y_logit,
            )
            self.loss = tf.reduce_mean(cross_entropy, name='loss')
            self.train = tf.train.AdamOptimizer(**arguments).minimize(self.loss)

class Experiment:
    """Runs the training loop for a model and reports its accuracy."""

    def run(self, data, model, objective, session, batch_size=2**4, step_count=2**9):
        """Train for ``step_count`` mini-batches and assess on the test split.

        On every 100th step (starting with step 0) the accuracy on the
        current training batch is printed *before* that batch is used for
        a training step.

        Args:
            data: Object exposing ``train.next_batch(batch_size)`` and
                ``test.images`` / ``test.labels``.
            model: Object exposing the ``x``, ``y``, ``keep`` and ``y_score``
                tensors.
            objective: Object exposing the ``train`` op.
            session: Session used to run the ops.
            batch_size: Number of examples requested per training step.
            step_count: Number of training steps.

        Returns:
            The statistics dictionary produced by ``_assess`` for the test
            images and labels.
        """
        for i in range(step_count):
            images, labels = data.train.next_batch(batch_size)
            if i % 100 == 0:
                stats = self._assess(images, labels, model, session)
                print('{:10}: {}'.format(i, stats))
            # Train with a dropout keep value of 0.5.
            feed = {model.x: images, model.y: labels, model.keep: 0.5}
            session.run(objective.train, feed_dict=feed)
        return self._assess(data.test.images, data.test.labels, model, session)

    def _assess(self, images, labels, model, session):
        """Return ``{'Accuracy': fraction}`` for the given examples.

        The predicted class is the arg-max of ``model.y_score`` and the true
        class is the arg-max of ``labels`` along axis 1. Dropout is disabled
        by feeding a keep value of 1.0.
        """
        feed = {model.x: images, model.y: labels, model.keep: 1.0}
        y_score = session.run(model.y_score, feed_dict=feed)
        y_predicted = np.argmax(y_score, axis=1)
        y_test = np.argmax(labels, axis=1)
        return {
            # The mean of the boolean match mask is always computed in
            # floating point, whereas dividing the integer match count by
            # the integer example count floor-divides under Python 2.
            'Accuracy': np.mean(y_test == y_predicted),
        }

In [None]:
# Build the model, the objective and the initializer in a dedicated graph.
graph = tf.Graph()
with graph.as_default():
    # The TensorFlow seed is a property of a graph. This graph is new, so a
    # seed set earlier on another default graph does not carry over; seed it
    # here so that weight initialisation and dropout are reproducible.
    tf.set_random_seed(0)
    model = Model(
        tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE * IMAGE_SIZE], name='input'),
        tf.placeholder(tf.float32, shape=[None, CLASS_COUNT], name='output'),
    )
    objective = Objective(model)
    initialize = tf.global_variables_initializer()

# Train and evaluate; the graph is also written to /tmp/model as a summary.
with tf.Session(graph=graph) as session:
    with tf.summary.FileWriter('/tmp/model', graph) as writer:
        session.run(initialize)
        experiment = Experiment()
        stats = experiment.run(data, model, objective, session)

print('{:>10}: {}'.format('Test', stats))