In [1]:
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
import numpy as np

# Display the installed TensorFlow version as this cell's output.
tf.__version__

'2.1.0'

In [2]:
def load_mnist():
    """Load MNIST and prepare it for training.

    Steps: fetch the train/test split from ``mnist.load_data()``, append a
    trailing axis of length 1 to the image arrays, rescale the images with
    ``normalization`` and one-hot encode the labels over 10 classes.

    Returns:
        Tuple ``(train_data, train_labels, test_data, test_labels)``.
    """
    (train_images, train_targets), (test_images, test_targets) = mnist.load_data()

    # Add a trailing axis of size 1 to each image array.
    train_images = train_images[..., np.newaxis]
    test_images = test_images[..., np.newaxis]

    train_images, test_images = normalization(train_images, test_images)

    # One-hot encode the integer class labels (10 classes).
    num_classes = 10
    train_targets = to_categorical(train_targets, num_classes)
    test_targets = to_categorical(test_targets, num_classes)

    return train_images, train_targets, test_images, test_targets

def normalization(train_data, test_data):
    """Cast both arrays to float32 and divide by 255.0.

    Args:
        train_data: array-like with an ``astype`` method (training images).
        test_data: array-like with an ``astype`` method (test images).

    Returns:
        Tuple ``(train_data, test_data)`` of rescaled float32 arrays; the
        inputs themselves are not modified.
    """
    def _rescale(pixels):
        # astype() returns a new array, so the caller's data is left intact.
        return pixels.astype(np.float32) / 255.0

    return _rescale(train_data), _rescale(test_data)

In [3]:
def loss_fn(model, images, labels):
    """Mean categorical cross-entropy of ``model`` on one batch.

    The model is called with ``training=True``.

    Args:
        model: callable accepting ``(images, training=...)`` and returning logits.
        images: batch of inputs passed to the model.
        labels: targets passed as ``y_true`` to categorical cross-entropy.

    Returns:
        Scalar tensor: the cross-entropy averaged with ``tf.reduce_mean``.
    """
    logits = model(images, training=True)
    # from_logits=True tells the loss that y_pred holds raw logits.
    per_example_loss = tf.keras.losses.categorical_crossentropy(
        y_true=labels, y_pred=logits, from_logits=True)
    return tf.reduce_mean(per_example_loss)

def accuracy_fn(model, images, labels):
    """Fraction of examples whose arg-max prediction matches the arg-max label.

    The model is called with ``training=False``.

    Args:
        model: callable accepting ``(images, training=...)`` and returning logits.
        images: batch of inputs passed to the model.
        labels: targets; the class is taken as the arg-max over the last axis.

    Returns:
        Scalar float32 tensor with the mean accuracy over the batch.
    """
    logits = model(images, training=False)
    predicted_class = tf.argmax(logits, axis=-1)
    true_class = tf.argmax(labels, axis=-1)
    # Boolean hit/miss mask, cast to float so it can be averaged.
    is_correct = tf.equal(predicted_class, true_class)
    return tf.reduce_mean(tf.cast(is_correct, dtype=tf.float32))

def grad(model, images, labels):
    """Gradients of ``loss_fn`` with respect to ``model.variables``.

    Args:
        model: model whose ``variables`` are differentiated against.
        images: batch of inputs forwarded to ``loss_fn``.
        labels: batch of targets forwarded to ``loss_fn``.

    Returns:
        List of gradients in the same order as ``model.variables``; callers
        must pair the result with that same list.
    """
    with tf.GradientTape() as tape:
        # The forward pass must run inside the tape so it is recorded.
        batch_loss = loss_fn(model, images, labels)
    gradients = tape.gradient(batch_loss, model.variables)
    return gradients

In [4]:
def flatten():
    """Return a new ``tf.keras.layers.Flatten`` layer."""
    layer = tf.keras.layers.Flatten()
    return layer

def Dense(label_dim, weight_init):
    """Return a new biased ``tf.keras.layers.Dense`` layer.

    Args:
        label_dim: number of output units of the layer.
        weight_init: initializer used for the layer's kernel.
    """
    layer_options = dict(
        units=label_dim,
        use_bias=True,
        kernel_initializer=weight_init,
    )
    return tf.keras.layers.Dense(**layer_options)

def relu():
    """Return a new ``Activation`` layer that applies ReLU."""
    activation = tf.keras.activations.relu
    return tf.keras.layers.Activation(activation)

In [5]:
class create_model(tf.keras.Model):
    """Fully connected classifier built on an inner ``tf.keras.Sequential``.

    Layer stack: Flatten -> 2 x (Dense(256) -> ReLU) -> Dense(label_dim).
    No activation is added after the final Dense layer.
    """

    def __init__(self, label_dim):
        """Build the layer stack.

        Args:
            label_dim: number of units of the final Dense layer.
        """
        super(create_model, self).__init__()
        weight_init = tf.keras.initializers.glorot_uniform()

        self.model = tf.keras.Sequential()
        self.model.add(flatten())

        # Two hidden blocks, each a 256-unit Dense layer followed by ReLU.
        for _ in range(2):
            self.model.add(Dense(256, weight_init))
            self.model.add(relu())

        self.model.add(Dense(label_dim, weight_init))

    def call(self, x, training=None, mask=None):
        """Run the forward pass.

        Args:
            x: input batch.
            training: whether the call happens in training mode; this option
                exists because the class subclasses ``tf.keras.Model``.
            mask: accepted for signature compatibility; not used.

        Returns:
            Output of the inner Sequential model.
        """
        # Forward the training flag so mode-dependent layers (e.g. Dropout,
        # BatchNormalization) behave correctly if they are ever added; it was
        # previously dropped here.
        return self.model(x, training=training)

In [6]:
train_data, train_labels, test_data, test_labels = load_mnist()

# Hyper-parameters.
learning_rate = 0.001
batch_size = 128

label_dim = 10

# Training pipeline: shuffle -> batch -> prefetch.
# prefetch goes last so that complete batches (not single examples) are
# prepared in the background while the current step is running.
# drop_remainder=True discards a final batch smaller than batch_size.
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels)).\
    shuffle(buffer_size=100000).\
    batch(batch_size, drop_remainder=True).\
    prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# Evaluation pipeline: the whole test set is served as one batch, so example
# order cannot change the measured accuracy and no shuffle is needed.
test_dataset = tf.data.Dataset.from_tensor_slices((test_data, test_labels)).\
    batch(len(test_data)).\
    prefetch(buffer_size=1)

In [10]:
network = create_model(label_dim)

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Loop-invariant values used only for the progress line.
steps_per_epoch = len(train_data)//batch_size
log_template = "[{:5d}]/[{:5d}] | train_loss: {:2.4f} | train_accuracy: {:2.4f} | test_accuracy: {:2.4f}"

# Single pass over the training dataset: one optimizer step per batch.
for idx, (train_input, train_label) in enumerate(train_dataset):
    # Pair each gradient with its variable; both sides use the variables list.
    grads = grad(network, train_input, train_label)
    grads_and_vars = zip(grads, network.variables)
    optimizer.apply_gradients(grads_and_vars=grads_and_vars)

    # Batch metrics are measured after the weight update has been applied.
    train_loss = loss_fn(network, train_input, train_label)
    train_accuracy = accuracy_fn(network, train_input, train_label)

    # Evaluate on the test dataset after every step; the value kept is the
    # accuracy of the last batch the dataset yields.
    for test_input, test_label in test_dataset:
        test_accuracy = accuracy_fn(network, test_input, test_label)

    print(log_template.format(idx, steps_per_epoch, train_loss, train_accuracy, test_accuracy))

[    0]/[  468] | train_loss: 2.0816 | train_accuracy: 0.3203 | test_accuracy: 0.2185
[    1]/[  468] | train_loss: 2.0510 | train_accuracy: 0.3984 | test_accuracy: 0.3208
[    2]/[  468] | train_loss: 1.9042 | train_accuracy: 0.3906 | test_accuracy: 0.3891
[    3]/[  468] | train_loss: 1.8139 | train_accuracy: 0.4922 | test_accuracy: 0.4644
[    4]/[  468] | train_loss: 1.8149 | train_accuracy: 0.5156 | test_accuracy: 0.5699
[    5]/[  468] | train_loss: 1.6218 | train_accuracy: 0.6484 | test_accuracy: 0.6738
[    6]/[  468] | train_loss: 1.4578 | train_accuracy: 0.7500 | test_accuracy: 0.7223
[    7]/[  468] | train_loss: 1.3399 | train_accuracy: 0.8203 | test_accuracy: 0.7370
[    8]/[  468] | train_loss: 1.2483 | train_accuracy: 0.6875 | test_accuracy: 0.7363
[    9]/[  468] | train_loss: 1.2308 | train_accuracy: 0.7266 | test_accuracy: 0.7621
[   10]/[  468] | train_loss: 1.0475 | train_accuracy: 0.8125 | test_accuracy: 0.7786
[   11]/[  468] | train_loss: 0.9551 | train_accuracy:

[   96]/[  468] | train_loss: 0.1985 | train_accuracy: 0.9375 | test_accuracy: 0.9214
[   97]/[  468] | train_loss: 0.1338 | train_accuracy: 0.9844 | test_accuracy: 0.9209
[   98]/[  468] | train_loss: 0.1462 | train_accuracy: 0.9609 | test_accuracy: 0.9186
[   99]/[  468] | train_loss: 0.2668 | train_accuracy: 0.9141 | test_accuracy: 0.9168
[  100]/[  468] | train_loss: 0.2305 | train_accuracy: 0.9219 | test_accuracy: 0.9178
[  101]/[  468] | train_loss: 0.1621 | train_accuracy: 0.9609 | test_accuracy: 0.9214
[  102]/[  468] | train_loss: 0.2293 | train_accuracy: 0.9219 | test_accuracy: 0.9216
[  103]/[  468] | train_loss: 0.1938 | train_accuracy: 0.9297 | test_accuracy: 0.9204
[  104]/[  468] | train_loss: 0.1952 | train_accuracy: 0.9453 | test_accuracy: 0.9186
[  105]/[  468] | train_loss: 0.3069 | train_accuracy: 0.9219 | test_accuracy: 0.9157
[  106]/[  468] | train_loss: 0.2392 | train_accuracy: 0.9375 | test_accuracy: 0.9135
[  107]/[  468] | train_loss: 0.2018 | train_accuracy:

[  192]/[  468] | train_loss: 0.1841 | train_accuracy: 0.9609 | test_accuracy: 0.9365
[  193]/[  468] | train_loss: 0.1927 | train_accuracy: 0.9375 | test_accuracy: 0.9368
[  194]/[  468] | train_loss: 0.2615 | train_accuracy: 0.9062 | test_accuracy: 0.9393
[  195]/[  468] | train_loss: 0.0828 | train_accuracy: 0.9844 | test_accuracy: 0.9403
[  196]/[  468] | train_loss: 0.2031 | train_accuracy: 0.9297 | test_accuracy: 0.9419
[  197]/[  468] | train_loss: 0.1621 | train_accuracy: 0.9688 | test_accuracy: 0.9435
[  198]/[  468] | train_loss: 0.2155 | train_accuracy: 0.9375 | test_accuracy: 0.9443
[  199]/[  468] | train_loss: 0.3267 | train_accuracy: 0.8906 | test_accuracy: 0.9452
[  200]/[  468] | train_loss: 0.1910 | train_accuracy: 0.9297 | test_accuracy: 0.9467
[  201]/[  468] | train_loss: 0.2309 | train_accuracy: 0.9453 | test_accuracy: 0.9461
[  202]/[  468] | train_loss: 0.2110 | train_accuracy: 0.9375 | test_accuracy: 0.9449
[  203]/[  468] | train_loss: 0.1075 | train_accuracy:

[  288]/[  468] | train_loss: 0.1913 | train_accuracy: 0.9219 | test_accuracy: 0.9537
[  289]/[  468] | train_loss: 0.1781 | train_accuracy: 0.9609 | test_accuracy: 0.9551
[  290]/[  468] | train_loss: 0.1791 | train_accuracy: 0.9297 | test_accuracy: 0.9524
[  291]/[  468] | train_loss: 0.0866 | train_accuracy: 0.9531 | test_accuracy: 0.9471
[  292]/[  468] | train_loss: 0.2196 | train_accuracy: 0.9297 | test_accuracy: 0.9429
[  293]/[  468] | train_loss: 0.0751 | train_accuracy: 0.9766 | test_accuracy: 0.9403
[  294]/[  468] | train_loss: 0.1792 | train_accuracy: 0.9453 | test_accuracy: 0.9398
[  295]/[  468] | train_loss: 0.2022 | train_accuracy: 0.9297 | test_accuracy: 0.9440
[  296]/[  468] | train_loss: 0.1500 | train_accuracy: 0.9453 | test_accuracy: 0.9472
[  297]/[  468] | train_loss: 0.1718 | train_accuracy: 0.9453 | test_accuracy: 0.9505
[  298]/[  468] | train_loss: 0.1283 | train_accuracy: 0.9766 | test_accuracy: 0.9515
[  299]/[  468] | train_loss: 0.0874 | train_accuracy:

[  384]/[  468] | train_loss: 0.0853 | train_accuracy: 0.9688 | test_accuracy: 0.9598
[  385]/[  468] | train_loss: 0.0542 | train_accuracy: 0.9922 | test_accuracy: 0.9589
[  386]/[  468] | train_loss: 0.1344 | train_accuracy: 0.9766 | test_accuracy: 0.9595
[  387]/[  468] | train_loss: 0.0553 | train_accuracy: 0.9922 | test_accuracy: 0.9579
[  388]/[  468] | train_loss: 0.1348 | train_accuracy: 0.9609 | test_accuracy: 0.9554
[  389]/[  468] | train_loss: 0.2329 | train_accuracy: 0.9297 | test_accuracy: 0.9541
[  390]/[  468] | train_loss: 0.1598 | train_accuracy: 0.9844 | test_accuracy: 0.9540
[  391]/[  468] | train_loss: 0.1533 | train_accuracy: 0.9766 | test_accuracy: 0.9548
[  392]/[  468] | train_loss: 0.1032 | train_accuracy: 0.9688 | test_accuracy: 0.9558
[  393]/[  468] | train_loss: 0.1559 | train_accuracy: 0.9609 | test_accuracy: 0.9571
[  394]/[  468] | train_loss: 0.1486 | train_accuracy: 0.9766 | test_accuracy: 0.9578
[  395]/[  468] | train_loss: 0.1938 | train_accuracy: