In [2]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
tf.enable_eager_execution()

In [3]:
def load_mnist():
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
    train_data = np.expand_dims(train_data, axis = -1)
    test_data = np.expand_dims(test_data, axis = -1)
    
    train_data, test_data = normalize(train_data, test_data)
    
    train_labels = to_categorical(train_labels, 10)
    test_labels = to_categorical(test_labels, 10)
    return train_data, train_labels, test_data, test_labels

def normalize(train_data, test_data):
    train_data = train_data.astype(np.float32)/255.0
    test_data = test_data.astype(np.float32)/255.0
    
    return train_data, test_data

In [4]:
def loss_fn(model, images, labels):
    logits = model(images, training = True)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits, labels))
    return loss

def accuracy_fn(model, images, labels):
    logits = model(images, training = False)
    prediction = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
    accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))
    return accuracy

def grad(model, images, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(model, images, labels)
    return tape.gradient(loss, model.variables)

In [31]:
def flatten():
    return tf.keras.layers.Flatten()

def dense(label_dim, weight_init):
    return tf.keras.layers.Dense(units = label_dim, use_bias= True, \
                                 kernel_initializer = weight_init)

def relu():
    return tf.keras.layers.Activation(tf.keras.activations.relu)

def dropout(rate):
    return tf.keras.layers.Dropout(rate)

def batch_norm():
    return tf.keras.layers.BatchNormalization()


In [27]:
class create_model(tf.keras.Model):
    def __init__(self, label_dim):
        super(create_model, self).__init__()
        weight_init = tf.keras.initializers.glorot_normal()
        
        self.model = tf.keras.Sequential()
        self.model.add(flatten())
        
        for i in range(2):
            self.model.add(dense(256, weight_init))
            self.model.add(relu())
            self.model.add(dropout(rate = 0.5))
        
        self.model.add(dense(label_dim, weight_init))
        
    def call(self, x, training = None, mask = None):
        x = self.model(x)
        
        return x

In [21]:
train_x, train_y, test_x, test_y = load_mnist()

learning_rate = 0.001
batch_size = 128

training_epochs = 1
training_iterations = len(train_x)//batch_size

label_dim = 10

train_flag = True

In [28]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).\
        shuffle(buffer_size = 100000).\
        prefetch(buffer_size = batch_size).\
        batch(batch_size).\
        repeat()

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y)).\
        shuffle(buffer_size = 100000).\
        prefetch(buffer_size = len(test_x)).\
        batch(len(test_x)).\
        repeat()

train_iterator = train_dataset.make_one_shot_iterator()
test_iterator = test_dataset.make_one_shot_iterator()


In [29]:
network = create_model(label_dim)

optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)


In [30]:
for epoch in range(training_epochs):
    for idx in range(training_iterations):
        train_input, train_label = train_iterator.get_next()
        
        grads = grad(network, train_input, train_label)
        optimizer.apply_gradients(grads_and_vars = zip(grads, network.variables))
        
        train_loss = loss_fn(network, train_input, train_label)
        train_accuracy = accuracy_fn(network, train_input, train_label)
        
        test_input, test_label = test_iterator.get_next()
        
        test_accuracy = accuracy_fn(network, test_input, test_label)
        
        print("Epoch: [%2d] [%5d/%5d], train_loss: %.8f, train_accuracy: %.4f, test_Accuracy: %.4f"\
             % (epoch, idx, training_iterations, train_loss, train_accuracy, test_accuracy))


Epoch: [ 0] [    0/  468], train_loss: -4.81101131, train_accuracy: 0.1250, test_Accuracy: 0.0853
Epoch: [ 0] [    1/  468], train_loss: -12.21676254, train_accuracy: 0.0859, test_Accuracy: 0.0807
Epoch: [ 0] [    2/  468], train_loss: -19.94591331, train_accuracy: 0.0703, test_Accuracy: 0.0821
Epoch: [ 0] [    3/  468], train_loss: -29.86602402, train_accuracy: 0.1016, test_Accuracy: 0.0904
Epoch: [ 0] [    4/  468], train_loss: -44.56247330, train_accuracy: 0.1016, test_Accuracy: 0.0998
Epoch: [ 0] [    5/  468], train_loss: -57.31587219, train_accuracy: 0.0859, test_Accuracy: 0.1033
Epoch: [ 0] [    6/  468], train_loss: -78.05364990, train_accuracy: 0.1094, test_Accuracy: 0.1024
Epoch: [ 0] [    7/  468], train_loss: -97.53433228, train_accuracy: 0.0625, test_Accuracy: 0.1008
Epoch: [ 0] [    8/  468], train_loss: -124.20829773, train_accuracy: 0.0781, test_Accuracy: 0.1011
Epoch: [ 0] [    9/  468], train_loss: -153.40046692, train_accuracy: 0.0938, test_Accuracy: 0.1010
Epoch: [ 

Epoch: [ 0] [   82/  468], train_loss: -138773.15625000, train_accuracy: 0.1016, test_Accuracy: 0.1010
Epoch: [ 0] [   83/  468], train_loss: -143422.34375000, train_accuracy: 0.1016, test_Accuracy: 0.1010
Epoch: [ 0] [   84/  468], train_loss: -157746.81250000, train_accuracy: 0.1094, test_Accuracy: 0.1010
Epoch: [ 0] [   85/  468], train_loss: -169514.18750000, train_accuracy: 0.1562, test_Accuracy: 0.1010
Epoch: [ 0] [   86/  468], train_loss: -172959.18750000, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [   87/  468], train_loss: -173855.64062500, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [   88/  468], train_loss: -183854.75000000, train_accuracy: 0.0703, test_Accuracy: 0.1010
Epoch: [ 0] [   89/  468], train_loss: -183619.46875000, train_accuracy: 0.1641, test_Accuracy: 0.1010
Epoch: [ 0] [   90/  468], train_loss: -208520.32812500, train_accuracy: 0.1016, test_Accuracy: 0.1010
Epoch: [ 0] [   91/  468], train_loss: -211673.21875000, train_accuracy: 

Epoch: [ 0] [  162/  468], train_loss: -1870792.37500000, train_accuracy: 0.1094, test_Accuracy: 0.1010
Epoch: [ 0] [  163/  468], train_loss: -1914674.75000000, train_accuracy: 0.1250, test_Accuracy: 0.1010
Epoch: [ 0] [  164/  468], train_loss: -1880255.50000000, train_accuracy: 0.1172, test_Accuracy: 0.1010
Epoch: [ 0] [  165/  468], train_loss: -2119336.50000000, train_accuracy: 0.0938, test_Accuracy: 0.1010
Epoch: [ 0] [  166/  468], train_loss: -2089560.87500000, train_accuracy: 0.0938, test_Accuracy: 0.1010
Epoch: [ 0] [  167/  468], train_loss: -2230002.25000000, train_accuracy: 0.0703, test_Accuracy: 0.1010
Epoch: [ 0] [  168/  468], train_loss: -2261762.75000000, train_accuracy: 0.1484, test_Accuracy: 0.1010
Epoch: [ 0] [  169/  468], train_loss: -2258359.50000000, train_accuracy: 0.1094, test_Accuracy: 0.1010
Epoch: [ 0] [  170/  468], train_loss: -2283712.50000000, train_accuracy: 0.1094, test_Accuracy: 0.1010
Epoch: [ 0] [  171/  468], train_loss: -2207588.50000000, train_

Epoch: [ 0] [  242/  468], train_loss: -8093262.50000000, train_accuracy: 0.0781, test_Accuracy: 0.1010
Epoch: [ 0] [  243/  468], train_loss: -8205162.00000000, train_accuracy: 0.1641, test_Accuracy: 0.1010
Epoch: [ 0] [  244/  468], train_loss: -8423259.00000000, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [  245/  468], train_loss: -8851444.00000000, train_accuracy: 0.0938, test_Accuracy: 0.1010
Epoch: [ 0] [  246/  468], train_loss: -8710902.00000000, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [  247/  468], train_loss: -8485716.00000000, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [  248/  468], train_loss: -8955540.00000000, train_accuracy: 0.1172, test_Accuracy: 0.1010
Epoch: [ 0] [  249/  468], train_loss: -9243786.00000000, train_accuracy: 0.0547, test_Accuracy: 0.1010
Epoch: [ 0] [  250/  468], train_loss: -9125780.00000000, train_accuracy: 0.1406, test_Accuracy: 0.1010
Epoch: [ 0] [  251/  468], train_loss: -9572522.00000000, train_

Epoch: [ 0] [  322/  468], train_loss: -21668486.00000000, train_accuracy: 0.0781, test_Accuracy: 0.1010
Epoch: [ 0] [  323/  468], train_loss: -22195286.00000000, train_accuracy: 0.1172, test_Accuracy: 0.1010
Epoch: [ 0] [  324/  468], train_loss: -21626132.00000000, train_accuracy: 0.1016, test_Accuracy: 0.1010
Epoch: [ 0] [  325/  468], train_loss: -23186536.00000000, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [  326/  468], train_loss: -23693040.00000000, train_accuracy: 0.1328, test_Accuracy: 0.1010
Epoch: [ 0] [  327/  468], train_loss: -23672840.00000000, train_accuracy: 0.0781, test_Accuracy: 0.1010
Epoch: [ 0] [  328/  468], train_loss: -24569696.00000000, train_accuracy: 0.1250, test_Accuracy: 0.1010
Epoch: [ 0] [  329/  468], train_loss: -23978496.00000000, train_accuracy: 0.1172, test_Accuracy: 0.1010
Epoch: [ 0] [  330/  468], train_loss: -22283712.00000000, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [  331/  468], train_loss: -25268094.000000

Epoch: [ 0] [  402/  468], train_loss: -49087312.00000000, train_accuracy: 0.1250, test_Accuracy: 0.1010
Epoch: [ 0] [  403/  468], train_loss: -48570904.00000000, train_accuracy: 0.0859, test_Accuracy: 0.1010
Epoch: [ 0] [  404/  468], train_loss: -50088328.00000000, train_accuracy: 0.1016, test_Accuracy: 0.1010
Epoch: [ 0] [  405/  468], train_loss: -48455488.00000000, train_accuracy: 0.1484, test_Accuracy: 0.1010
Epoch: [ 0] [  406/  468], train_loss: -47477600.00000000, train_accuracy: 0.0781, test_Accuracy: 0.1010
Epoch: [ 0] [  407/  468], train_loss: -49269456.00000000, train_accuracy: 0.1172, test_Accuracy: 0.1010
Epoch: [ 0] [  408/  468], train_loss: -48019020.00000000, train_accuracy: 0.0703, test_Accuracy: 0.1010
Epoch: [ 0] [  409/  468], train_loss: -48608832.00000000, train_accuracy: 0.0703, test_Accuracy: 0.1010
Epoch: [ 0] [  410/  468], train_loss: -48025544.00000000, train_accuracy: 0.0703, test_Accuracy: 0.1010
Epoch: [ 0] [  411/  468], train_loss: -51018888.000000

In [25]:
print("test_Accuracy: %.4f" % (test_accuracy))

test_Accuracy: 0.0958
