In [3]:
import tensorflow as tf
import numpy as np

In [4]:
# Load MNIST with the tf.contrib helper; per the log below it downloads and
# extracts the raw files into a relative MNIST-data directory.
mnist = tf.contrib.learn.datasets.load_dataset("mnist")

# Image arrays are used exactly as the loader returns them.
train_data = mnist.train.images
eval_data = mnist.test.images

# Labels are copied into int32 numpy arrays for the train and test splits.
train_labels, eval_labels = (
    np.asarray(split.labels, dtype=np.int32) for split in (mnist.train, mnist.test)
)

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST-data\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST-data\train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST-data\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST-data\t10k-labels-idx1-ubyte.gz


In [10]:
def cnn_model_fn(features, labels, mode, params=None):
    """Model function for a small convolutional MNIST classifier (tf.estimator API).

    Network: conv 5x5x32 (relu, 'same' padding) -> max-pool 2x2 -> conv 5x5x64
    (relu, 'same' padding) -> max-pool 2x2 -> dense 1024 (relu) -> dropout ->
    dense 10 (logits).

    Args:
        features: dict whose 'x' entry is reshaped to [-1, 28, 28, 1], so each
            example must hold 28*28 = 784 values (presumably flattened grayscale
            MNIST pixels -- confirm against the input_fn).
        labels: class ids; cast to int32 and one-hot encoded with depth 10.
            Not used in PREDICT mode.
        mode: a tf.estimator.ModeKeys value (PREDICT, TRAIN or EVAL).
        params: optional hyperparameter dict. Recognised keys:
            'learning_rate' -- gradient-descent step size (default 0.001);
            'dropout_rate'  -- dropout rate applied after the 1024-unit dense
                               layer during TRAIN (default 0.4).
            tf.estimator.Estimator forwards its own `params` argument here, so an
            Estimator built without params reproduces the original model exactly.

    Returns:
        tf.estimator.EstimatorSpec: predictions for PREDICT; loss and train_op for
        TRAIN; loss and an 'accuracy' metric for EVAL.
    """
    params = params or {}
    learning_rate = params.get('learning_rate', 0.001)
    dropout_rate = params.get('dropout_rate', 0.4)

    # Keep layer creation order unchanged: tf.layers derives default variable
    # names from creation order, and existing checkpoints depend on them.
    input_layer = tf.reshape(features['x'], [-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(inputs=input_layer, kernel_size=[5, 5], filters=32,
                             padding='same', activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    conv2 = tf.layers.conv2d(inputs=pool1, kernel_size=[5, 5], filters=64,
                             padding='same', activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    # Two stride-2 pools shrink 28x28 to 7x7; conv2 produced 64 channels.
    pool2_flat = tf.reshape(pool2, shape=[-1, 7 * 7 * 64])
    dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
    # Dropout is only active while training; EVAL/PREDICT use the full layer.
    dropout = tf.layers.dropout(inputs=dense, rate=dropout_rate,
                                training=(mode == tf.estimator.ModeKeys.TRAIN))

    logits = tf.layers.dense(inputs=dropout, units=10)

    predictions = {
        # The op name 'softmax_tensor' is referenced from outside the model
        # (e.g. by a LoggingTensorHook), so it must not change.
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
        'classes': tf.argmax(logits, 1),
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=10)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # EVAL mode.
    eval_metric_ops = {
        'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions['classes'])
    }
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

    

In [11]:
# Checkpoint/summary directory for the Estimator.
# - None (the default) makes TensorFlow create a fresh temporary directory (the
#   '_model_dir' shown in the log below), so trained weights do not survive the
#   kernel session.
# - A path (e.g. 'mnist_convnet_model') persists checkpoints, but then a re-run
#   resumes training from the latest checkpoint instead of starting fresh.
MODEL_DIR = None
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir=MODEL_DIR)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_keep_checkpoint_max': 5, '_model_dir': 'C:\\Users\\Lei\\AppData\\Local\\Temp\\tmpddwbth89', '_log_step_count_steps': 100, '_tf_random_seed': 1, '_keep_checkpoint_every_n_hours': 10000, '_session_config': None, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None}


In [14]:
# Log the softmax output, looked up by its graph op name 'softmax_tensor', every
# 50 training iterations.
# NOTE(review): this prints the whole probability matrix of a batch each time,
# which floods the cell output -- consider a larger interval.
tensors_to_log = dict(probabilities='softmax_tensor')
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log,
    every_n_iter=50,
)


In [15]:
# Training schedule: shuffled mini-batches drawn from the training set.
# num_epochs=None repeats the data indefinitely, so TRAIN_STEPS bounds training.
BATCH_SIZE = 100
TRAIN_STEPS = 1000

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': train_data},
    y=train_labels,
    batch_size=BATCH_SIZE,
    num_epochs=None,
    shuffle=True,
)
mnist_classifier.train(input_fn=train_input_fn, steps=TRAIN_STEPS, hooks=[logging_hook])

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from C:\Users\Lei\AppData\Local\Temp\tmpddwbth89\model.ckpt-1
INFO:tensorflow:Saving checkpoints for 2 into C:\Users\Lei\AppData\Local\Temp\tmpddwbth89\model.ckpt.
INFO:tensorflow:probabilities = [[ 0.10028942  0.11734748  0.0772321   0.10601557  0.10153314  0.08791868
   0.09964505  0.10405999  0.09567519  0.11028352]
 [ 0.08816717  0.11949042  0.07310519  0.11460417  0.11355447  0.08904001
   0.10081957  0.09456273  0.09762391  0.10903242]
 [ 0.09757876  0.10460256  0.09653636  0.10370924  0.1016323   0.09289508
   0.10508886  0.08787038  0.09847593  0.11161055]
 [ 0.0933327   0.10582371  0.09098361  0.10037117  0.10529879  0.10135241
   0.10406583  0.09127274  0.09887052  0.10862858]
 [ 0.08825591  0.12377109  0.09758469  0.09127928  0.09734066  0.09394201
   0.10547535  0.09049757  0.10100821  0.11084523]
 [ 0.0882958   0.12051293  0.0908253   0.09515576  0.1042937   0.09761348
   0.09615774  0.0989043

KeyboardInterrupt: 