In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense,Flatten,Conv2D
from tensorflow.keras import Model

In [2]:
tf.__version__

'2.2.0-rc3'

In [32]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [33]:
x_test = x_test[..., tf.newaxis]
x_train = x_train[...,tf.newaxis]

In [34]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

In [40]:
class Mymodel(Model):
    def __init__(self):
        super(Mymodel,self).__init__()
        self.conv1 = Conv2D(32,3,activation = 'relu')
        self.flatten = Flatten()
        self.d1 = Dense(128,activation='relu')
        self.d2 = Dense(10)
    def call(self,x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

In [41]:
model = Mymodel()

In [22]:
loss_object  =tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

In [23]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [45]:
@tf.function
def train_step(images,labels):
    with tf.GradientTape() as tape:
        predictions = model(images,training=True)
        loss = loss_object(labels,predictions)
    gradients = tape.gradient(loss,model.trainable_variables)
    optimizer.apply_gradients(zip(gradients,model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(labels,predictions)

In [26]:
@tf.function
def test_step(images,labels):
    predictions = model(images,training = False)
    loss = loss_object(labels,predictions)
    
    test_loss(loss)
    test_accuracy(labels,predictions)

In [46]:
EPOCHS = 5

for epoch in range(EPOCHS):
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
    
    for images,labels in train_ds:
        train_step(images,labels)
    for images,labels in test_ds:
        test_step(images,labels)
        
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1,
                        train_loss.result(),
                        train_accuracy.result() * 100,
                        test_loss.result(),
                        test_accuracy.result() * 100))

Epoch 1, Loss: 0.04459976404905319, Accuracy: 98.58999633789062, Test Loss: 0.054880641400814056, Test Accuracy: 98.0999984741211
Epoch 2, Loss: 0.024253780022263527, Accuracy: 99.25333404541016, Test Loss: 0.0494452603161335, Test Accuracy: 98.50999450683594
Epoch 3, Loss: 0.013512328267097473, Accuracy: 99.57167053222656, Test Loss: 0.05269540473818779, Test Accuracy: 98.48999786376953
Epoch 4, Loss: 0.011303810402750969, Accuracy: 99.63500213623047, Test Loss: 0.05632008612155914, Test Accuracy: 98.5199966430664
Epoch 5, Loss: 0.007076573558151722, Accuracy: 99.77166748046875, Test Loss: 0.07598726451396942, Test Accuracy: 98.11000061035156
