The [TF-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim) library provides common abstractions which enable us to define models quickly and concisely, while keeping the model architecture transparent and its hyperparameters explicit. Let's construct models by TF-Slim and see how concisely and beautiful they are.

In [1]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import gzip
import os
from scipy import ndimage
from six.moves import urllib

  from ._conv import register_converters as _register_converters


In [2]:
image_size = 28
num_channels = 1
pixel_depth = 255
num_labels = 10
validation_size = 5000

In [3]:
source_url = 'http://yann.lecun.com/exdb/mnist/'
data_dir = './mnist'

def maybe_download(filename):
    if not tf.gfile.Exists(data_dir):
        tf.gfile.MakeDirs(data_dir)
    file_path = os.path.join(data_dir, filename)
    
    if not tf.gfile.Exists(file_path):
        file_path, _ = urllib.request.urlretrieve(source_url + filename, file_path)
        
        with tf.gfile.Gfile(file_path) as f:
            size = f.size()
        print('Successfully download', filename, size, 'bytes.')
    return file_path

def extract_data(filename, num_images):
    """"Extracting images into a 4D tensor - [image_index,y,x,channels]"""
    print('Extracting', filename)
    with gzip.open(filename) as f:
        f.read(16)
        buf = f.read(image_size * image_size * num_images * num_channels)
        data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
        data = (data-(pixel_depth/2.0))/pixel_depth
        data = data.reshape(num_images,image_size,image_size,num_channels)
        data = np.reshape(data,[num_images,-1])
        
    return data

def extract_labels(filename,num_images):
    print('Ectracting', filename)
    with gzip.open(filename) as f:
        f.read(8)
        buf = f.read(1* num_images)
        labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
        num_labels_data = len(labels)
        one_hot_encoding = np.zeros((num_labels_data, num_labels))
        one_hot_encoding[np.arange(num_labels_data),labels] = 1
        one_hot_encoding = np.reshape(one_hot_encoding, [-1,num_labels])
        
    return one_hot_encoding

def expand_training_data(images,labels):
    expanded_images = []
    expanded_labels = []
    
    j = 0
    for x,y in zip(images,labels):
        j += 1
        if j%100 == 0:
            print('Expanding data: %03d / %03d' % (j, np.size(images,0)))
            
        expanded_images.append(x)
        expanded_labels.append(y)
        
        background_value = np.median(x)
        image = np.reshape(x,(-1,28))
        
        for i in range(4):
            angle = np.random.randint(-15,15,1)
            new_image = ndimage.rotate(image,angle,reshape=False,cval=background_value)
            
            shift = np.random.randint(-2,2,2)
            new_image2 = ndimage.shift(new_image, shift, cval=background_value)
            
            expanded_images.append(np.reshape(new_image2,784))
            expanded_labels.append(y)
        
        expanded_train_total_data = np.concatenate((expanded_images,expanded_labels),axis=1)
        np.random.shuffle(expanded_train_total_data)
        
        return expanded_train_total_data
    
def prepare_MNIST_data(use_data_augmentation=True):
    train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
    train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
    test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
    test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
    
    train_data = extract_data(train_data_filename, 60000)
    train_labels = extract_labels(train_labels_filename, 60000)
    test_data = extract_data(test_data_filename, 10000)
    test_labels = extract_labels(test_labels_filename, 10000)

    validation_data = train_data[:validation_size,:]
    validation_labels = train_labels[:validation_size,:]
    train_data = train_data[validation_size:,:]
    train_labels = train_labels[validation_size:,:]
    
    if use_data_augmentation:
        train_total_data = expand_training_data(train_data, train_labels)
    else:
        train_total_data = np.concatenate((train_data,train_labels),axis=1)
        
    train_size = train_total_data.shape[0]
    return train_total_data,train_size, validation_data, validation_labels, test_data, test_labels
    

In [4]:
def CNN(inputs,is_training=True):
    batch_norm_params = {'is_training':is_training,'decay':0.9,'updates_collections':None}
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        normalizer_fn=slim.batch_norm,
                       normalizer_params=batch_norm_params):
        x = tf.reshape(inputs,[-1,28,28,1])
        net = slim.conv2d(x,32,[5,5],scope='conv1')
        net = slim.max_pool2d(net,[2,2],scope='pool1')
        net = slim.conv2d(net,64,[5,5],scope='conv2')
        net = slim.max_pool2d(net,[2,2],scope='pool2')
        net = slim.flatten(net,scope='flatten3')
        
        net = slim.fully_connected(net,1024,scope='fc3')
        net = slim.dropout(net, is_training=is_training,scope='dropout3')
        outputs = slim.fully_connected(net,10,activation_fn=None,normalizer_fn=None,scope='fco')
        
    return outputs
        
        

In [5]:
model_dir = 'model/mnist_model.ckpt'
logs_dir = 'logs/train'

display_step = 100
validation_step = 500

def train():
    epochs = 1000
    batch_size = 32
    num_labels = 10
    
    train_total_data, train_size, validation_data,validation_labels,test_data, test_labels = prepare_MNIST_data(True)
    
    is_training = tf.placeholder(tf.bool, name='MDOE')
    
    x = tf.placeholder(tf.float32,[None,784])
    y_ = tf.placeholder(tf.float32,[None,10])
    
    y = CNN(x)
    
    with tf.name_scope('loss'):
        loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=y,labels=y_)
        
    tf.summary.scalar('loss',loss)
    
    with tf.name_scope('optimizer'):
        batch = tf.Variable(0)
        learning_rate = tf.train.exponential_decay(1e-4,
                                                   batch*batch_size, # current index into the dataset
                                                  train_size, # decay step
                                                  0.95, # decay rate
                                                  staircase=True)
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss,global_step=batch)
        
    tf.summary.scalar('learning_rate', learning_rate)
    
    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        
    tf.summary.scalar('accuracy',accuracy)
    
    merged_summary_op = tf.summary.merge_all()
    
    saver = tf.train.Saver()
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer(), feed_dict={is_training:True})
    
    total_batch = int(train_size/batch_size)
    
    if not tf.gfile.Exists(logs_dir):
        tf.gfile.MakeDirs(logs_dir)
        
    if not tf.gfile.Exists(model_dir):
        tf.gfile.MakeDirs(model_dir)
    
    summary_writer = tf.summary.FileWriter(logs_dir, graph=tf.get_default_graph())
    
    max_acc = 0
    
    for epoch in range(epochs):
        np.random.shuffle(train_total_data)
        train_data_ = train_total_data[:, :-num_labels]
        train_labels_ = train_total_data[:, -num_labels:]
        
        for i in range(total_batch):
            offset = (i*batch_size)%(train_size)
            batch_x = train_data_[offset:(offset+batch_size), :]
            batch_y = train_labels_[offset:(offset+batch_size), :]
            
            _, train_accuracy, summary = sess.run([train_step,accuracy, merged_summary_op],feed_dict={x:batch_x,y_:batch_y,is_training:True})
            
            summary_writer.add_summary(summary, epoch*total_batch+1)
            
            if i%display_step == 0:
                print('Epoch: %04d, batch_index: %4d/%4d, training_accuracy: %.5f'%(epoch+1, i, total_batch, train_accuracy))
                
            if i%validation_step == 0:
                validation_accuracy = sess.run(accuracy,feed_dict={x:validation_data,y_:validation_labels,is_training:False})
                print('Epoch: %04d, batch_index: %4d/%4d, validation_accuracy: %.5f'%(epoch+1, i, total_batch, validation_accuracy))
                
            if validation_accuracy > max_acc:
                max_acc = validation_accuracy
                save_path = saver.save(sess, model_dir)
    print('Optimizationn finished!')
    
#     saver.restore(sess, model_dir)
    
    test_size = test_labels.shape[0]
    acc_buffer= []
    
    y_final = sess.run(y, feed_dict={x:test_data,y_:test_labels, is_training:False})
    correct_prediction = tf.equal(tf.argmax(y_final,1),tf.argmax(test_labels,1))
    acc_buffer.append(tf.reduce_mean(tf.cast(correct_prediction,tf.float32)))
    
if __name__=='__main__':
    train()

Extracting ./mnist/train-images-idx3-ubyte.gz
Ectracting ./mnist/train-labels-idx1-ubyte.gz
Extracting ./mnist/t10k-images-idx3-ubyte.gz
Ectracting ./mnist/t10k-labels-idx1-ubyte.gz
Optimizationn finished!
