In [1]:
import warnings
# NOTE(review): this silences *every* warning for the whole kernel, including
# TensorFlow deprecation notices. It runs before the tensorflow import below, so
# import-time warnings are hidden too. Consider narrowing it to specific
# categories/modules.
warnings.filterwarnings('ignore')

from tqdm import tqdm
import tensorflow as tf

In [2]:
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

In [3]:
# Fetch the MNIST train/test splits (image arrays + integer labels).
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel intensities from [0, 255] to [0, 1], then append a trailing
# channel axis so Conv2D (channels-last by default) sees single-channel images.
x_train = (x_train / 255.0)[..., tf.newaxis]
x_test = (x_test / 255.0)[..., tf.newaxis]

In [4]:
# Input-pipeline knobs: shuffle buffer for training data and batch size for both splits.
BUFFER_SIZE = 10000
BATCH_SIZE = 32

# Training examples are shuffled through a BUFFER_SIZE-element buffer before
# batching; the test split is batched in its original order.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

In [5]:
class MyModel(Model):
    """Small image classifier: one ReLU convolution, global average pooling, softmax head.

    The final Dense layer applies softmax, so the model outputs class
    probabilities (pair it with a loss that uses ``from_logits=False``).

    Args:
        num_classes: number of output classes (default 10).
        filters: number of Conv2D filters (default 32).
        kernel_size: Conv2D kernel size (default 10).
        **kwargs: forwarded unchanged to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_classes=10, filters=32, kernel_size=10, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = Conv2D(filters, kernel_size, activation='relu')
        # Global average pooling reduces each feature map to a single value,
        # giving a (batch, filters) vector without a Flatten/Dense stack.
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.d2 = Dense(num_classes, activation='softmax')

    def call(self, x):
        """Map a batch of images ``x`` to per-class probabilities."""
        x = self.conv1(x)
        x = self.gap(x)
        return self.d2(x)


model = MyModel()

In [6]:
# Adam with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()
# Integer labels vs. softmax probabilities: from_logits=False is the
# documented default, spelled out here because the model head applies softmax.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

In [7]:
# Streaming metrics: each call accumulates a batch, .result() reads the aggregate.
# Mean batch loss per split.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')

# Accuracy of argmax predictions against integer labels, per split.
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [8]:
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a batch and update the training metrics.

    Computes ``loss_object`` on the model's predictions, back-propagates
    through ``model.trainable_variables``, applies the update with
    ``optimizer``, then feeds the batch loss and predictions into
    ``train_loss`` / ``train_accuracy``.
    """
    with tf.GradientTape() as tape:
        # training=True so layers with distinct train/inference behaviour
        # (Dropout, BatchNormalization) run in training mode if ever added.
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

In [9]:
@tf.function
def test_step(images, labels):
    """Evaluate the model on one batch and update the test metrics (no weight update)."""
    # training=False so any train-only behaviour (Dropout, BatchNorm updates)
    # is disabled during evaluation.
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

In [11]:
EPOCHS = 5

for epoch in range(EPOCHS):
    # The metric objects accumulate across calls. Clear them so each printed
    # line describes this epoch only, not a running average since epoch 1.
    # (On TensorFlow < 2.5 the method is named `reset_states`.)
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_state()

    for images, labels in tqdm(train_ds, desc=f'epoch {epoch + 1} train'):
        train_step(images, labels)

    for test_images, test_labels in tqdm(test_ds, desc=f'epoch {epoch + 1} test'):
        test_step(test_images, test_labels)

    template = 'Epoch : {}, Loss : {}, Accuracy : {}, Test Loss : {}, Test Accuracy : {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100,
                          test_loss.result(),
                          test_accuracy.result() * 100))

1875it [00:13, 142.25it/s]
313it [00:00, 599.16it/s]
0it [00:00, ?it/s]

Epoch : 1, Loss : 0.5705838203430176, Accuracy : 85.13833618164062, Test Loss : 0.42119812965393066, Test Accuracy : 89.3316650390625


611it [00:04, 138.15it/s]

KeyboardInterrupt: 