In [91]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
import time
import os

In [132]:
def load_mnist() :
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    train_images = np.expand_dims(train_images, axis=-1)
    test_images = np.expand_dims(test_images, axis=-1)

    train_images, test_images = normalize(train_images, test_images)

    train_labels = to_categorical(train_labels, 10)
    test_labels = to_categorical(test_labels, 10) # [N,] -> [N, 10]
    
    return train_images, train_labels, test_images, test_labels

def normalize(train_images, test_images):
    train_images = train_images.astype(np.float32)/255.0
    test_images = test_images.astype(np.float32)/255.0
    
    return train_images, test_images

In [133]:
def flatten() :
    return tf.keras.layers.Flatten()

def dense(label_dim, initializer) :
    return tf.keras.layers.Dense(units=label_dim, use_bias=True, kernel_initializer=initializer)

def relu() :
    return tf.keras.layers.Activation(tf.keras.activations.relu)

def dropout(rate):
    return tf.keras.layers.Dropout(rate)

def batch_norm():
    return tf.keras.layers.BatchNormalization()

In [137]:
class CreateModel(tf.keras.Model):
    def __init__(self, label_dim):
        super(CreateModel, self).__init__()

        self.model = tf.keras.Sequential()
        self.model.add(flatten())

        initializer = tf.keras.initializers.he_uniform()
#         initializer = tf.keras.initializers.RandomNormal()
        for i in range(4):
            self.model.add(dense(512, initializer))
            self.model.add(batch_norm())
            self.model.add(relu())
#             self.model.add(dropout(rate=0.5))

        self.model.add(dense(label_dim, initializer))

    def __call__(self, x, training=None, mask=None):
        y = self.model(x)
        return y

In [138]:
def loss_fn(model, images, labels):
    logits = model(images, training=True)    
    loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_pred=logits, y_true=labels, from_logits=True))
    return loss

def acc_fn(model, images, labels):
    logits = model(images, training=False)
    is_equall = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
    acc = tf.reduce_mean(tf.cast(is_equall, tf.float32))
    return acc

def grad_fn(model, images, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(model, images, labels)
    return tape.gradient(loss, model.trainable_variables)

In [139]:
train_x, train_y, test_x, test_y = load_mnist()

lr = 0.001
batch_size = 128

training_epochs = 1
training_iters = len(train_x)//batch_size

label_dim = 10

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).shuffle(len(train_x)).batch(batch_size, drop_remainder=True).prefetch(batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y)).shuffle(len(test_x)).batch(len(test_x)).prefetch(len(test_x))

network = CreateModel(label_dim)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [140]:
start_time = time.time()

for epoch in range(1):
    for idx, (train_image, train_label) in enumerate(train_dataset):            
        grads = grad_fn(network, train_image, train_label)
        optimizer.apply_gradients(grads_and_vars=zip(grads, network.trainable_variables))

        train_loss = loss_fn(network, train_image, train_label)
        train_acc = acc_fn(network, train_image, train_label)

        for test_image, test_label in test_dataset:                
            test_acc = acc_fn(network, test_image, test_label)

        tf.summary.scalar(name='train_loss', data=train_loss, step=counter)
        tf.summary.scalar(name='train_acc', data=train_acc, step=counter)
        tf.summary.scalar(name='test_acc', data=test_acc, step=counter)

        print("epoch : {:2d} | {:3d}/{:3d} | time passed : {:4.4f}\n train_loss : {:.8f} | train_acc : {:.4f} | test_acc : {:.4f}".format(epoch+1, idx+1, training_iters, round(time.time()-start_time, 0), train_loss, train_acc, test_acc), end="\n\n")

epoch :  1 |   1/468 | time passed : 1.0000
 train_loss : 1.77174842 | train_acc : 0.4453 | test_acc : 0.2623

epoch :  1 |   2/468 | time passed : 2.0000
 train_loss : 1.53632891 | train_acc : 0.7188 | test_acc : 0.5561

epoch :  1 |   3/468 | time passed : 3.0000
 train_loss : 1.38182271 | train_acc : 0.7109 | test_acc : 0.6570

epoch :  1 |   4/468 | time passed : 3.0000
 train_loss : 1.07817650 | train_acc : 0.7109 | test_acc : 0.6902

epoch :  1 |   5/468 | time passed : 4.0000
 train_loss : 0.92769945 | train_acc : 0.7578 | test_acc : 0.7255

epoch :  1 |   6/468 | time passed : 5.0000
 train_loss : 0.63930744 | train_acc : 0.8281 | test_acc : 0.7345

epoch :  1 |   7/468 | time passed : 5.0000
 train_loss : 0.64495623 | train_acc : 0.8281 | test_acc : 0.7476

epoch :  1 |   8/468 | time passed : 6.0000
 train_loss : 0.62774467 | train_acc : 0.8203 | test_acc : 0.7753

epoch :  1 |   9/468 | time passed : 7.0000
 train_loss : 0.63348389 | train_acc : 0.8438 | test_acc : 0.7699

e

epoch :  1 |  75/468 | time passed : 58.0000
 train_loss : 0.26215035 | train_acc : 0.9219 | test_acc : 0.9308

epoch :  1 |  76/468 | time passed : 59.0000
 train_loss : 0.17986330 | train_acc : 0.9297 | test_acc : 0.9298

epoch :  1 |  77/468 | time passed : 60.0000
 train_loss : 0.27521828 | train_acc : 0.9375 | test_acc : 0.9337

epoch :  1 |  78/468 | time passed : 60.0000
 train_loss : 0.27460778 | train_acc : 0.9062 | test_acc : 0.9342

epoch :  1 |  79/468 | time passed : 61.0000
 train_loss : 0.21110541 | train_acc : 0.9297 | test_acc : 0.9342

epoch :  1 |  80/468 | time passed : 62.0000
 train_loss : 0.18238570 | train_acc : 0.9297 | test_acc : 0.9317

epoch :  1 |  81/468 | time passed : 63.0000
 train_loss : 0.26585668 | train_acc : 0.9375 | test_acc : 0.9354

epoch :  1 |  82/468 | time passed : 64.0000
 train_loss : 0.13445710 | train_acc : 0.9609 | test_acc : 0.9398

epoch :  1 |  83/468 | time passed : 65.0000
 train_loss : 0.19636586 | train_acc : 0.9453 | test_acc : 

epoch :  1 | 148/468 | time passed : 118.0000
 train_loss : 0.23713368 | train_acc : 0.9297 | test_acc : 0.9482

epoch :  1 | 149/468 | time passed : 119.0000
 train_loss : 0.13558771 | train_acc : 0.9453 | test_acc : 0.9485

epoch :  1 | 150/468 | time passed : 119.0000
 train_loss : 0.09962246 | train_acc : 0.9844 | test_acc : 0.9457

epoch :  1 | 151/468 | time passed : 120.0000
 train_loss : 0.13174029 | train_acc : 0.9609 | test_acc : 0.9451

epoch :  1 | 152/468 | time passed : 121.0000
 train_loss : 0.17289075 | train_acc : 0.9531 | test_acc : 0.9467

epoch :  1 | 153/468 | time passed : 122.0000
 train_loss : 0.10149488 | train_acc : 0.9609 | test_acc : 0.9477

epoch :  1 | 154/468 | time passed : 123.0000
 train_loss : 0.15675634 | train_acc : 0.9453 | test_acc : 0.9495

epoch :  1 | 155/468 | time passed : 123.0000
 train_loss : 0.19449580 | train_acc : 0.9453 | test_acc : 0.9521

epoch :  1 | 156/468 | time passed : 124.0000
 train_loss : 0.21338098 | train_acc : 0.9453 | te

epoch :  1 | 221/468 | time passed : 179.0000
 train_loss : 0.12787731 | train_acc : 0.9688 | test_acc : 0.9499

epoch :  1 | 222/468 | time passed : 179.0000
 train_loss : 0.14094058 | train_acc : 0.9688 | test_acc : 0.9520

epoch :  1 | 223/468 | time passed : 180.0000
 train_loss : 0.14615843 | train_acc : 0.9531 | test_acc : 0.9513

epoch :  1 | 224/468 | time passed : 181.0000
 train_loss : 0.07848596 | train_acc : 0.9844 | test_acc : 0.9487

epoch :  1 | 225/468 | time passed : 182.0000
 train_loss : 0.15264003 | train_acc : 0.9453 | test_acc : 0.9421

epoch :  1 | 226/468 | time passed : 182.0000
 train_loss : 0.15514678 | train_acc : 0.9609 | test_acc : 0.9404

epoch :  1 | 227/468 | time passed : 183.0000
 train_loss : 0.09512448 | train_acc : 0.9766 | test_acc : 0.9458

epoch :  1 | 228/468 | time passed : 183.0000
 train_loss : 0.13944964 | train_acc : 0.9766 | test_acc : 0.9498

epoch :  1 | 229/468 | time passed : 184.0000
 train_loss : 0.17793737 | train_acc : 0.9531 | te

epoch :  1 | 294/468 | time passed : 228.0000
 train_loss : 0.06853327 | train_acc : 0.9766 | test_acc : 0.9585

epoch :  1 | 295/468 | time passed : 229.0000
 train_loss : 0.06973559 | train_acc : 0.9766 | test_acc : 0.9553

epoch :  1 | 296/468 | time passed : 229.0000
 train_loss : 0.17803952 | train_acc : 0.9609 | test_acc : 0.9549

epoch :  1 | 297/468 | time passed : 230.0000
 train_loss : 0.11704922 | train_acc : 0.9766 | test_acc : 0.9546

epoch :  1 | 298/468 | time passed : 230.0000
 train_loss : 0.05215214 | train_acc : 0.9844 | test_acc : 0.9539

epoch :  1 | 299/468 | time passed : 231.0000
 train_loss : 0.04847595 | train_acc : 0.9844 | test_acc : 0.9591

epoch :  1 | 300/468 | time passed : 232.0000
 train_loss : 0.10055123 | train_acc : 0.9688 | test_acc : 0.9595

epoch :  1 | 301/468 | time passed : 233.0000
 train_loss : 0.08964846 | train_acc : 0.9609 | test_acc : 0.9586

epoch :  1 | 302/468 | time passed : 233.0000
 train_loss : 0.05957304 | train_acc : 0.9844 | te

epoch :  1 | 367/468 | time passed : 278.0000
 train_loss : 0.15065756 | train_acc : 0.9453 | test_acc : 0.9629

epoch :  1 | 368/468 | time passed : 279.0000
 train_loss : 0.15073659 | train_acc : 0.9453 | test_acc : 0.9645

epoch :  1 | 369/468 | time passed : 280.0000
 train_loss : 0.04078939 | train_acc : 0.9844 | test_acc : 0.9624

epoch :  1 | 370/468 | time passed : 280.0000
 train_loss : 0.18484649 | train_acc : 0.9766 | test_acc : 0.9600

epoch :  1 | 371/468 | time passed : 281.0000
 train_loss : 0.10221320 | train_acc : 0.9844 | test_acc : 0.9618

epoch :  1 | 372/468 | time passed : 282.0000
 train_loss : 0.10385568 | train_acc : 0.9609 | test_acc : 0.9623

epoch :  1 | 373/468 | time passed : 282.0000
 train_loss : 0.12346103 | train_acc : 0.9609 | test_acc : 0.9605

epoch :  1 | 374/468 | time passed : 283.0000
 train_loss : 0.16426972 | train_acc : 0.9766 | test_acc : 0.9604

epoch :  1 | 375/468 | time passed : 284.0000
 train_loss : 0.06992332 | train_acc : 0.9844 | te

epoch :  1 | 440/468 | time passed : 327.0000
 train_loss : 0.06928790 | train_acc : 0.9766 | test_acc : 0.9623

epoch :  1 | 441/468 | time passed : 328.0000
 train_loss : 0.02539056 | train_acc : 1.0000 | test_acc : 0.9610

epoch :  1 | 442/468 | time passed : 329.0000
 train_loss : 0.08963461 | train_acc : 0.9766 | test_acc : 0.9604

epoch :  1 | 443/468 | time passed : 329.0000
 train_loss : 0.09402244 | train_acc : 0.9688 | test_acc : 0.9636

epoch :  1 | 444/468 | time passed : 330.0000
 train_loss : 0.06850176 | train_acc : 0.9766 | test_acc : 0.9660

epoch :  1 | 445/468 | time passed : 331.0000
 train_loss : 0.06003134 | train_acc : 0.9766 | test_acc : 0.9639

epoch :  1 | 446/468 | time passed : 331.0000
 train_loss : 0.08623245 | train_acc : 0.9766 | test_acc : 0.9629

epoch :  1 | 447/468 | time passed : 332.0000
 train_loss : 0.08731760 | train_acc : 0.9766 | test_acc : 0.9609

epoch :  1 | 448/468 | time passed : 333.0000
 train_loss : 0.08389469 | train_acc : 0.9688 | te