In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

In [29]:
mnist = tf.keras.datasets.mnist

In [30]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [31]:
x_train.shape

(60000, 28, 28)

channel情報を含む次元を追加
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers/Conv2D

In [32]:
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

In [33]:
x_train.shape

(60000, 28, 28, 1)

In [34]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

In [19]:
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        
        #Flattenで1次元に平滑化される
        #ref) https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers/Flatten
        # now: model.output_shape == (None, 64, 32, 32)
        #model.add(Flatten())
        # now: model.output_shape == (None, 65536)
        
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')
    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

In [44]:
model = MyModel()

In [50]:
model.summary()

Model: "my_model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            multiple                  320       
_________________________________________________________________
flatten_2 (Flatten)          multiple                  0         
_________________________________________________________________
dense_4 (Dense)              multiple                  2769024   
_________________________________________________________________
dense_5 (Dense)              multiple                  1290      
Total params: 2,770,634
Trainable params: 2,770,634
Non-trainable params: 0
_________________________________________________________________


In [45]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy() 
#ラベルクラスが2つ以上ある場合は、このクロスエントロピー損失関数を使用してください。らしい。one-hot表現で欲しいときはCategoricalCrossentropy
optimizer = tf.keras.optimizers.Adam()

In [46]:
train_loss = tf.keras.metrics.Mean(name='train_loss') #名前の通り加重平均を計算するためのfunction
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [47]:
@tf.function #Define and Run
def train_step(image, label):
    with tf.GradientTape() as tape:
        predictions = model(image)
        loss = loss_object(label, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(label, predictions)
# tf.GradientTape()で勾配を計算するための情報を記録しているっぽい
# ここで実行した式が記憶され、後続のtape.gradientで勾配が計算できる

# apply_gradientsで勾配の更新が走る ref) https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/optimizers/Adam#apply_gradients
# 引数に grads_and_vars: List of (gradient, variable) pairs.をとる

In [48]:
@tf.function
def test_step(image, label):
    predictions = model(image)
    t_loss = loss_object(label, predictions)
    test_loss(t_loss)
    test_accuracy(label, predictions)

In [49]:
EPOCHS = 5
for epoch in range(EPOCHS):
    for image, label in train_ds:
        train_step(image, label)
    for test_image, test_label in test_ds:
        test_step(test_image, test_label)
    
    template = 'Epoch {}, Loss {}, Accuracy {}, Test Loss {}, Test Accuracy {}'
    print(template.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        test_loss.result(),
        test_accuracy.result() * 100
    ))

Epoch 1, Loss 0.13764896988868713, Accuracy 95.8949966430664, Test Loss 0.06724967807531357, Test Accuracy 97.88999938964844
Epoch 2, Loss 0.08862052112817764, Accuracy 97.35083770751953, Test Loss 0.06288466602563858, Test Accuracy 97.98500061035156
Epoch 3, Loss 0.06527643650770187, Accuracy 98.04777526855469, Test Loss 0.06169414147734642, Test Accuracy 98.1199951171875
Epoch 4, Loss 0.05190098285675049, Accuracy 98.44708251953125, Test Loss 0.060695841908454895, Test Accuracy 98.13999938964844
Epoch 5, Loss 0.04331650584936142, Accuracy 98.6989974975586, Test Loss 0.06159958615899086, Test Accuracy 98.1719970703125
