In [1]:
import tensorflow as tf
import numpy as np

tf.__version__

'2.1.0'

In [2]:
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

In [3]:
import os

def load(model, checkpoint_dir):
    print(" [*] Reading checkpoints... ")
    
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        checkpoint = tf.train.Checkpoint(dnn=model)
        checkpoint.restore(save_path=checkpoint_dir+ckpt_name)
        counter = int(ckpt_name.split('-')[1])
        print(" [*] Success to read {} ".format(ckpt_name))
        return True, counter
    else:
        print(" [*] Failed to find a checkpoint ")
        return False, 0
    
def check_folder(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)
    return dir

In [4]:
def load_mnist():
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
    train_data = np.expand_dims(train_data, axis=-1)
    # 하나의 배열로 생각하고 가장 끝에 원소를 추가
    test_data = np.expand_dims(test_data, axis=-1)
    
    train_labels = to_categorical(train_labels, 10)
    # one-hot cording
    test_labels = to_categorical(test_labels, 10)
    
    return train_data, train_labels, test_data, test_labels

def normalize(train_data, test_data):
    train_data = train_data.astype(np.float32) / 255.0
    test_data = test_data.astype(np.float32)/ 255.0
    return train_data, test_data

In [5]:
def loss_fn(model, images, labels):
    logits = model(images, training=False)
    return tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_pred=logits, y_true=labels, from_logits=True))

def accuracy_fn(model, images, labels):
    logits = model(images, training=False)
    prediction = tf.equal(tf.argmax(logits, -1), tf.argmax(labels, -1))
    accuracy = tf.reduce_mean(tf.cast(prediction, dtype=tf.float32))
    return accuracy

def grad(model, images, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(model, images, labels)
    return tape.gradient(loss, model.variables)

In [6]:
# layer를 1차원으로 만들어준다 생성
def flatten():
    return tf.keras.layers.Flatten()

def dense(label_dim, weight_init):
    # units : output space의 크기
    # activation : activation 함수, 디폴트는 linear
    # use_bias : bias의 사용
    # kernel_initializer : weight matrix 초기값
    return tf.keras.layers.Dense(units=label_dim, use_bias=True, kernel_initializer=weight_init)

def sigmoid():
    return tf.keras.layers.Activation(tf.keras.activations.sigmoid)

In [7]:
class create_model_class(tf.keras.Model):
    def __init__(self, label_dim):
        super(create_model_class, self).__init__()
        weight_init = tf.keras.initializers.RandomNormal()
        
        self.model = tf.keras.Sequential()
        self.model.add(flatten())
        # 배열처럼 쌓인다.
        # block 선도 처럼 직렬로 연결됨
        
        for i in range(2):
            self.model.add(dense(256, weight_init))
            self.model.add(sigmoid())
            
        self.model.add(dense(label_dim, weight_init))
    
    def call(self, x, training=None, mask=None):
        x = self.model(x)
        
        return x

In [8]:
train_x, train_y, test_x, test_y = load_mnist()

learning_rate = 0.001
batch_size = 128

training_epochs = 1
training_iterations = len(train_x) // batch_size

label_dim = 10

train_flag = True

train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).\
    shuffle(buffer_size=100000).\
    prefetch(buffer_size=batch_size).\
    batch(batch_size, drop_remainder=True)
# shuffle은 섞는거, buffer_size는 섞는 구간 정해줌, 원래 크기보다 크면 전체를 섞어줌
# prefetch 현재 element들로 학습하는 동안, 이 후의 batch를 준비해 놓는 것
# 잊지말자, 1 epoch = batch_size * 1 iteration
# batch drop_reminder batch 크기보다 작은 마지막 batch를 무시할 것인지 말것인지 결정

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y)).\
    shuffle(buffer_size=10000).\
    prefetch(buffer_size=batch_size).\
    batch(batch_size, drop_remainder=True)

In [9]:
# model
network = create_model_class(label_dim)

# training
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

checkpoint_dir = 'checkpoints'
logs_dir = 'logs'

model_dir = 'nn_softmax'

checkpoint_dir = 'd:\\Machinelearning\\' + checkpoint_dir
check_folder(checkpoint_dir)
checkpoint_prefix = checkpoint_dir + '\\' + model_dir
logs_dir = 'd:\\Machinelearning\\' + logs_dir + '\\' + model_dir

In [11]:
from time import time

if train_flag:
    
    checkpoint = tf.train.Checkpoint(dnn=network)
    
    summary_writer = tf.summary.create_file_writer(logdir=logs_dir)
    start_time = time()
    
    could_load, checkpoint_counter = load(network, checkpoint_dir)
    
    if could_load:
        start_epoch = (int)(checkpoint_counter / training_interations)
        counter = checkpoint_counter
        print(" [*] Load SUCCESS ")
    else:
        start_epoch = 0
        start_interation = 0
        counter = 0
        print(" [!] Load failed... ")
        
    with summary_writer.as_default():
        for epoch in range(start_epoch, training_epochs):
            for idx, (train_input, train_label) in enumerate(train_dataset):
                grads = grad(network, train_input, train_label)
                optimizer.apply_gradients(grads_and_vars=zip(grads, network.variables))
                
                train_loss = loss_fn(network, train_input, train_label)
                train_accuracy = accuracy_fn(network, train_input, train_label)
                
                for test_input, test_label in test_dataset:
                    test_accuracy = accuracy_fn(network, test_input, test_label)
                    
                tf.summary.scalar(name='tarin_loss', data=train_loss, step=counter)
                tf.summary.scalar(name='train_accuracy', data=train_accuracy, step=counter)
                tf.summary.scalar(name='test_accuracy', data=test_accuracy, step=counter)
                
                print("Epoch: [{:2d}] [{:5d}/{:5d}] | time: {:4.4f} | train_loss: {:.8f} | train_accuracy: {:.4f} | test_accuracy: {:.4f}"\
                     .format(epoch, idx, training_iterations, time()-start_time, train_loss, train_accuracy, test_accuracy))
                counter += 1
        checkpoint.save(file_prefix=checkpoint_prefix + '-{}'.format(counter))
        
else:
    _, _ = load(network, checkpoint_dir)
    for test_input, test_label in test_dataset:
        test_accuracy = accuracy_fn(network, test_input, test_label)
        
    print("test_Accuracy: {:.4f}".format(test_accuracy))

 [*] Reading checkpoints... 
 [*] Failed to find a checkpoint 
 [!] Load failed... 
Epoch: [ 0] [    0/  468] | time: 1.6810 | train_loss: 2.24112129 | train_accuracy: 0.1328 | test_accuracy: 0.1406
Epoch: [ 0] [    1/  468] | time: 2.7520 | train_loss: 2.23041153 | train_accuracy: 0.2266 | test_accuracy: 0.1797
Epoch: [ 0] [    2/  468] | time: 3.8420 | train_loss: 2.19766283 | train_accuracy: 0.2422 | test_accuracy: 0.2500
Epoch: [ 0] [    3/  468] | time: 5.3420 | train_loss: 2.16500783 | train_accuracy: 0.3906 | test_accuracy: 0.2734
Epoch: [ 0] [    4/  468] | time: 6.4210 | train_loss: 2.14562702 | train_accuracy: 0.3906 | test_accuracy: 0.3203
Epoch: [ 0] [    5/  468] | time: 7.5250 | train_loss: 2.14068818 | train_accuracy: 0.3203 | test_accuracy: 0.2500
Epoch: [ 0] [    6/  468] | time: 8.6110 | train_loss: 2.11958408 | train_accuracy: 0.2969 | test_accuracy: 0.2031
Epoch: [ 0] [    7/  468] | time: 10.5150 | train_loss: 2.02165341 | train_accuracy: 0.3047 | test_accuracy: 0.

Epoch: [ 0] [   70/  468] | time: 88.6132 | train_loss: 0.71597028 | train_accuracy: 0.8438 | test_accuracy: 0.7891
Epoch: [ 0] [   71/  468] | time: 89.7162 | train_loss: 0.76806736 | train_accuracy: 0.8281 | test_accuracy: 0.8594
Epoch: [ 0] [   72/  468] | time: 90.8233 | train_loss: 0.75777328 | train_accuracy: 0.7891 | test_accuracy: 0.8359
Epoch: [ 0] [   73/  468] | time: 91.9763 | train_loss: 0.72300124 | train_accuracy: 0.8438 | test_accuracy: 0.8125
Epoch: [ 0] [   74/  468] | time: 93.0883 | train_loss: 0.67805886 | train_accuracy: 0.8828 | test_accuracy: 0.8281
Epoch: [ 0] [   75/  468] | time: 94.1893 | train_loss: 0.61455917 | train_accuracy: 0.8672 | test_accuracy: 0.8750
Epoch: [ 0] [   76/  468] | time: 95.2976 | train_loss: 0.69657749 | train_accuracy: 0.8672 | test_accuracy: 0.8047
Epoch: [ 0] [   77/  468] | time: 96.3837 | train_loss: 0.63949329 | train_accuracy: 0.8672 | test_accuracy: 0.8125
Epoch: [ 0] [   78/  468] | time: 97.5106 | train_loss: 0.74750990 | tra

Epoch: [ 0] [  141/  468] | time: 170.1518 | train_loss: 0.42441791 | train_accuracy: 0.9297 | test_accuracy: 0.9141
Epoch: [ 0] [  142/  468] | time: 171.4398 | train_loss: 0.37559736 | train_accuracy: 0.9453 | test_accuracy: 0.8672
Epoch: [ 0] [  143/  468] | time: 172.5238 | train_loss: 0.40396613 | train_accuracy: 0.8828 | test_accuracy: 0.9297
Epoch: [ 0] [  144/  468] | time: 173.7408 | train_loss: 0.38410112 | train_accuracy: 0.9141 | test_accuracy: 0.8984
Epoch: [ 0] [  145/  468] | time: 174.9128 | train_loss: 0.54498327 | train_accuracy: 0.8359 | test_accuracy: 0.8828
Epoch: [ 0] [  146/  468] | time: 176.1088 | train_loss: 0.56278294 | train_accuracy: 0.8516 | test_accuracy: 0.8984
Epoch: [ 0] [  147/  468] | time: 177.2398 | train_loss: 0.42673126 | train_accuracy: 0.8906 | test_accuracy: 0.8828
Epoch: [ 0] [  148/  468] | time: 178.4368 | train_loss: 0.38396591 | train_accuracy: 0.9219 | test_accuracy: 0.8906
Epoch: [ 0] [  149/  468] | time: 179.6099 | train_loss: 0.48686

Epoch: [ 0] [  212/  468] | time: 251.6185 | train_loss: 0.48762575 | train_accuracy: 0.8516 | test_accuracy: 0.8906
Epoch: [ 0] [  213/  468] | time: 252.7225 | train_loss: 0.29188910 | train_accuracy: 0.9297 | test_accuracy: 0.8594
Epoch: [ 0] [  214/  468] | time: 253.8205 | train_loss: 0.38125497 | train_accuracy: 0.8750 | test_accuracy: 0.8906
Epoch: [ 0] [  215/  468] | time: 254.8635 | train_loss: 0.34542200 | train_accuracy: 0.9297 | test_accuracy: 0.9531
Epoch: [ 0] [  216/  468] | time: 256.0297 | train_loss: 0.39146650 | train_accuracy: 0.8750 | test_accuracy: 0.8906
Epoch: [ 0] [  217/  468] | time: 257.1057 | train_loss: 0.32603866 | train_accuracy: 0.8906 | test_accuracy: 0.8828
Epoch: [ 0] [  218/  468] | time: 258.1517 | train_loss: 0.28759891 | train_accuracy: 0.9141 | test_accuracy: 0.9297
Epoch: [ 0] [  219/  468] | time: 259.3517 | train_loss: 0.37282801 | train_accuracy: 0.8828 | test_accuracy: 0.8906
Epoch: [ 0] [  220/  468] | time: 260.4077 | train_loss: 0.36330

Epoch: [ 0] [  283/  468] | time: 329.8429 | train_loss: 0.44554579 | train_accuracy: 0.8906 | test_accuracy: 0.9531
Epoch: [ 0] [  284/  468] | time: 330.8799 | train_loss: 0.49981791 | train_accuracy: 0.8672 | test_accuracy: 0.8672
Epoch: [ 0] [  285/  468] | time: 332.0399 | train_loss: 0.29656613 | train_accuracy: 0.9062 | test_accuracy: 0.8906
Epoch: [ 0] [  286/  468] | time: 333.0869 | train_loss: 0.38525945 | train_accuracy: 0.8828 | test_accuracy: 0.8828
Epoch: [ 0] [  287/  468] | time: 334.1679 | train_loss: 0.37957951 | train_accuracy: 0.8906 | test_accuracy: 0.8906
Epoch: [ 0] [  288/  468] | time: 335.2929 | train_loss: 0.35056940 | train_accuracy: 0.9062 | test_accuracy: 0.8906
Epoch: [ 0] [  289/  468] | time: 336.4199 | train_loss: 0.34607807 | train_accuracy: 0.8828 | test_accuracy: 0.8750
Epoch: [ 0] [  290/  468] | time: 337.4839 | train_loss: 0.40383524 | train_accuracy: 0.9141 | test_accuracy: 0.8984
Epoch: [ 0] [  291/  468] | time: 338.5849 | train_loss: 0.36170

Epoch: [ 0] [  354/  468] | time: 407.7395 | train_loss: 0.35866719 | train_accuracy: 0.8672 | test_accuracy: 0.9062
Epoch: [ 0] [  355/  468] | time: 408.8405 | train_loss: 0.41286635 | train_accuracy: 0.8906 | test_accuracy: 0.8906
Epoch: [ 0] [  356/  468] | time: 409.9485 | train_loss: 0.21213350 | train_accuracy: 0.9609 | test_accuracy: 0.9375
Epoch: [ 0] [  357/  468] | time: 411.0825 | train_loss: 0.36545098 | train_accuracy: 0.8984 | test_accuracy: 0.9375
Epoch: [ 0] [  358/  468] | time: 412.1305 | train_loss: 0.23092356 | train_accuracy: 0.9531 | test_accuracy: 0.9141
Epoch: [ 0] [  359/  468] | time: 413.2295 | train_loss: 0.28185731 | train_accuracy: 0.9219 | test_accuracy: 0.8906
Epoch: [ 0] [  360/  468] | time: 414.3935 | train_loss: 0.30731401 | train_accuracy: 0.9141 | test_accuracy: 0.8750
Epoch: [ 0] [  361/  468] | time: 415.4375 | train_loss: 0.27368671 | train_accuracy: 0.9141 | test_accuracy: 0.8750
Epoch: [ 0] [  362/  468] | time: 416.5295 | train_loss: 0.28899

Epoch: [ 0] [  425/  468] | time: 485.0643 | train_loss: 0.39019483 | train_accuracy: 0.8750 | test_accuracy: 0.9375
Epoch: [ 0] [  426/  468] | time: 486.1453 | train_loss: 0.33535972 | train_accuracy: 0.9141 | test_accuracy: 0.9141
Epoch: [ 0] [  427/  468] | time: 487.1983 | train_loss: 0.29937276 | train_accuracy: 0.9297 | test_accuracy: 0.9141
Epoch: [ 0] [  428/  468] | time: 488.3173 | train_loss: 0.20499852 | train_accuracy: 0.9688 | test_accuracy: 0.8906
Epoch: [ 0] [  429/  468] | time: 489.4983 | train_loss: 0.30430305 | train_accuracy: 0.9219 | test_accuracy: 0.9219
Epoch: [ 0] [  430/  468] | time: 490.5453 | train_loss: 0.20672737 | train_accuracy: 0.9531 | test_accuracy: 0.9531
Epoch: [ 0] [  431/  468] | time: 491.7823 | train_loss: 0.23801279 | train_accuracy: 0.9453 | test_accuracy: 0.8906
Epoch: [ 0] [  432/  468] | time: 493.0403 | train_loss: 0.29824728 | train_accuracy: 0.9141 | test_accuracy: 0.8984
Epoch: [ 0] [  433/  468] | time: 494.1093 | train_loss: 0.31839