In [23]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

tf.__version__

'2.1.0'

In [24]:
(train_data, train_labels), (test_data, test_labels) = mnist.load_data()

train_data = np.expand_dims(train_data, -1)
test_data = np.expand_dims(test_data, -1)

train_data = train_data.astype(np.float32) / 255.
test_data = test_data.astype(np.float32) / 255.

train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)

In [58]:
def loss_fn(model, images, labels):
    logits = model(images, training=True)
    return tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_pred=logits, y_true=labels, from_logits=True))

def accuracy_fn(model, images, labels):
    logits = model(images, training=False)
    prediction = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
    return tf.reduce_mean(tf.cast(prediction, dtype=tf.float32))

def grad(model, images, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(model, images, labels)
    return tape.gradient(loss, model.variables)

In [59]:
def flatten():
    return tf.keras.layers.Flatten()

def dense(label_dim, weight_init):
    return tf.keras.layers.Dense(units=label_dim, use_bias=True, kernel_initializer=weight_init)

def relu():
    return tf.keras.layers.Activation(tf.keras.activations.relu)

def dropout(rate):
    return tf.keras.layers.Dropout(rate)

In [60]:
class model(tf.keras.Model):
    def __init__(self, label_dim):
        super(model, self).__init__()
        weight_init = tf.keras.initializers.glorot_uniform()
        
        self.model = tf.keras.Sequential()
        self.model.add(flatten())
        
        for i in range(4):
            self.model.add(dense(512, weight_init))
            self.model.add(relu())
            self.model.add(dropout(rate=0.5))
            
        self.model.add(dense(label_dim, weight_init))
    def call(self, x, training=None, mask=None):
        x = self.model(x)
        return x

In [61]:
learning_rate = 0.001
batch_size = 128

label_dim = 10

train_dataset = tf.data.Dataset.from_tensor_slices((
    train_data, train_labels
)). shuffle(buffer_size=100000).\
    prefetch(buffer_size=batch_size).\
    batch(batch_size, drop_remainder=True).\
    repeat(5)

test_dataset = tf.data.Dataset.from_tensor_slices((
    test_data, test_labels
)). shuffle(buffer_size=100000).\
    prefetch(buffer_size=batch_size).\
    batch(len(test_data)).\
    repeat(5)

In [63]:
network = model(label_dim)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

iterations = len(train_data) // batch_size

print(iterations)

for idx, (train_x, train_label) in enumerate(train_dataset):
    grads = grad(network, train_x, train_label)
    optimizer.apply_gradients(grads_and_vars=zip(grads, network.variables))
    
    train_loss = loss_fn(network, train_x, train_label)
    train_accuracy = accuracy_fn(network, train_x, train_label)
    
    for test_x, test_label in test_dataset:
        test_accuracy = accuracy_fn(network, test_x, test_label)
    
    print('[{:5d}/{:5d}] train_loss: {:2.4f} | train_accuracy: {:2.4f} | test_accuracy: {:2.4f}'.format(idx, iterations, train_loss, train_accuracy, test_accuracy))

468
[    0/  468] train_loss: 2.3321 | train_accuracy: 0.1562 | test_accuracy: 0.1189
[    1/  468] train_loss: 2.3211 | train_accuracy: 0.2812 | test_accuracy: 0.2396
[    2/  468] train_loss: 2.3134 | train_accuracy: 0.3281 | test_accuracy: 0.2856
[    3/  468] train_loss: 2.2850 | train_accuracy: 0.3047 | test_accuracy: 0.3261
[    4/  468] train_loss: 2.2796 | train_accuracy: 0.4297 | test_accuracy: 0.3541
[    5/  468] train_loss: 2.1566 | train_accuracy: 0.4375 | test_accuracy: 0.4070
[    6/  468] train_loss: 2.2448 | train_accuracy: 0.5000 | test_accuracy: 0.4513
[    7/  468] train_loss: 2.2406 | train_accuracy: 0.4844 | test_accuracy: 0.4851
[    8/  468] train_loss: 2.2364 | train_accuracy: 0.4453 | test_accuracy: 0.5027
[    9/  468] train_loss: 2.2290 | train_accuracy: 0.5625 | test_accuracy: 0.4989
[   10/  468] train_loss: 2.1757 | train_accuracy: 0.4922 | test_accuracy: 0.5141
[   11/  468] train_loss: 2.0926 | train_accuracy: 0.5859 | test_accuracy: 0.5340
[   12/  468

[  100/  468] train_loss: 0.3567 | train_accuracy: 0.9375 | test_accuracy: 0.8963
[  101/  468] train_loss: 0.4539 | train_accuracy: 0.8906 | test_accuracy: 0.9002
[  102/  468] train_loss: 0.5126 | train_accuracy: 0.8906 | test_accuracy: 0.9039
[  103/  468] train_loss: 0.4886 | train_accuracy: 0.8984 | test_accuracy: 0.9038
[  104/  468] train_loss: 0.5827 | train_accuracy: 0.8984 | test_accuracy: 0.9047
[  105/  468] train_loss: 0.4299 | train_accuracy: 0.9062 | test_accuracy: 0.9048
[  106/  468] train_loss: 0.4808 | train_accuracy: 0.9062 | test_accuracy: 0.9051
[  107/  468] train_loss: 0.4471 | train_accuracy: 0.9375 | test_accuracy: 0.9069
[  108/  468] train_loss: 0.6375 | train_accuracy: 0.9141 | test_accuracy: 0.9093
[  109/  468] train_loss: 0.3774 | train_accuracy: 0.9375 | test_accuracy: 0.9093
[  110/  468] train_loss: 0.4992 | train_accuracy: 0.8828 | test_accuracy: 0.9097
[  111/  468] train_loss: 0.5875 | train_accuracy: 0.9062 | test_accuracy: 0.9061
[  112/  468] tr

[  200/  468] train_loss: 0.2922 | train_accuracy: 0.9609 | test_accuracy: 0.9356
[  201/  468] train_loss: 0.2536 | train_accuracy: 0.9531 | test_accuracy: 0.9345
[  202/  468] train_loss: 0.3375 | train_accuracy: 0.9141 | test_accuracy: 0.9343
[  203/  468] train_loss: 0.2630 | train_accuracy: 0.9688 | test_accuracy: 0.9353
[  204/  468] train_loss: 0.3966 | train_accuracy: 0.9062 | test_accuracy: 0.9376
[  205/  468] train_loss: 0.1863 | train_accuracy: 0.9375 | test_accuracy: 0.9361
[  206/  468] train_loss: 0.2942 | train_accuracy: 0.9531 | test_accuracy: 0.9335
[  207/  468] train_loss: 0.4989 | train_accuracy: 0.8906 | test_accuracy: 0.9307
[  208/  468] train_loss: 0.3160 | train_accuracy: 0.9531 | test_accuracy: 0.9268
[  209/  468] train_loss: 0.2426 | train_accuracy: 0.9688 | test_accuracy: 0.9237
[  210/  468] train_loss: 0.3659 | train_accuracy: 0.9219 | test_accuracy: 0.9226
[  211/  468] train_loss: 0.3595 | train_accuracy: 0.9297 | test_accuracy: 0.9242
[  212/  468] tr

[  300/  468] train_loss: 0.4497 | train_accuracy: 0.9297 | test_accuracy: 0.9418
[  301/  468] train_loss: 0.1952 | train_accuracy: 0.9688 | test_accuracy: 0.9424
[  302/  468] train_loss: 0.4779 | train_accuracy: 0.8906 | test_accuracy: 0.9440
[  303/  468] train_loss: 0.3298 | train_accuracy: 0.9453 | test_accuracy: 0.9447
[  304/  468] train_loss: 0.2841 | train_accuracy: 0.9453 | test_accuracy: 0.9457
[  305/  468] train_loss: 0.1736 | train_accuracy: 0.9688 | test_accuracy: 0.9470
[  306/  468] train_loss: 0.2182 | train_accuracy: 0.9375 | test_accuracy: 0.9484
[  307/  468] train_loss: 0.3386 | train_accuracy: 0.9141 | test_accuracy: 0.9489
[  308/  468] train_loss: 0.1802 | train_accuracy: 0.9766 | test_accuracy: 0.9504
[  309/  468] train_loss: 0.2875 | train_accuracy: 0.9375 | test_accuracy: 0.9504
[  310/  468] train_loss: 0.4193 | train_accuracy: 0.9453 | test_accuracy: 0.9509
[  311/  468] train_loss: 0.2073 | train_accuracy: 0.9688 | test_accuracy: 0.9510
[  312/  468] tr

[  400/  468] train_loss: 0.2291 | train_accuracy: 0.9531 | test_accuracy: 0.9553
[  401/  468] train_loss: 0.2759 | train_accuracy: 0.9375 | test_accuracy: 0.9552
[  402/  468] train_loss: 0.2973 | train_accuracy: 0.9453 | test_accuracy: 0.9550
[  403/  468] train_loss: 0.1826 | train_accuracy: 0.9688 | test_accuracy: 0.9557
[  404/  468] train_loss: 0.2337 | train_accuracy: 0.9375 | test_accuracy: 0.9551
[  405/  468] train_loss: 0.2627 | train_accuracy: 0.9688 | test_accuracy: 0.9547
[  406/  468] train_loss: 0.1689 | train_accuracy: 0.9844 | test_accuracy: 0.9541
[  407/  468] train_loss: 0.1985 | train_accuracy: 0.9453 | test_accuracy: 0.9540
[  408/  468] train_loss: 0.2900 | train_accuracy: 0.9531 | test_accuracy: 0.9541
[  409/  468] train_loss: 0.2810 | train_accuracy: 0.9609 | test_accuracy: 0.9547
[  410/  468] train_loss: 0.2188 | train_accuracy: 0.9609 | test_accuracy: 0.9546
[  411/  468] train_loss: 0.2895 | train_accuracy: 0.9609 | test_accuracy: 0.9543
[  412/  468] tr

[  500/  468] train_loss: 0.2105 | train_accuracy: 0.9688 | test_accuracy: 0.9556
[  501/  468] train_loss: 0.1199 | train_accuracy: 0.9844 | test_accuracy: 0.9545
[  502/  468] train_loss: 0.2029 | train_accuracy: 0.9609 | test_accuracy: 0.9529
[  503/  468] train_loss: 0.4400 | train_accuracy: 0.8906 | test_accuracy: 0.9520
[  504/  468] train_loss: 0.1824 | train_accuracy: 0.9844 | test_accuracy: 0.9523
[  505/  468] train_loss: 0.1747 | train_accuracy: 0.9531 | test_accuracy: 0.9533
[  506/  468] train_loss: 0.2677 | train_accuracy: 0.9531 | test_accuracy: 0.9543
[  507/  468] train_loss: 0.2422 | train_accuracy: 0.9453 | test_accuracy: 0.9557
[  508/  468] train_loss: 0.1747 | train_accuracy: 0.9922 | test_accuracy: 0.9557
[  509/  468] train_loss: 0.1673 | train_accuracy: 0.9844 | test_accuracy: 0.9569
[  510/  468] train_loss: 0.2098 | train_accuracy: 0.9531 | test_accuracy: 0.9563
[  511/  468] train_loss: 0.2242 | train_accuracy: 0.9531 | test_accuracy: 0.9571
[  512/  468] tr

KeyboardInterrupt: 