# Hello World: Subclassing and GradientTape edition

This example shows how to use Keras [Subclassing](https://www.tensorflow.org/guide/keras) in TensorFlow 2.0. We'll use a [GradientTape](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/GradientTape) to write a custom training loop. You can find more details about this imperative style, and how it compares to the symbolic one, in this [article](https://medium.com/tensorflow/what-are-symbolic-and-imperative-apis-in-tensorflow-2-0-dfccecb01021).

### Install the nightly build


In [1]:
!pip install tf-nightly-2.0-preview



In [0]:
import tensorflow as tf

In [3]:
print("You have version", tf.__version__)
assert int(tf.__version__.split(".")[0]) >= 2  # TensorFlow ≥ 2.0 required (compare the major version as a number, not as a string)

You have version 2.0.0-dev20190203


In [0]:
import numpy as np

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Flatten

### Load the dataset

In [0]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

x_train = x_train / 255
x_test = x_test / 255

### Batch and shuffle the data

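Next, we'll use `tf.data` to batch and shuffle our dataset. Notice the `buffer_size` argument to `shuffle`. Why is it necessary? Datasets are (potentially infinite) streams. Since a stream can't be shuffled all at once, `tf.data` keeps a buffer of `buffer_size` elements and draws from it at random. Setting the buffer to the size of the dataset, as we do here, gives a full shuffle.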

In [0]:
BATCH_SIZE = 128
BUFFER_SIZE = len(x_train)

mnist_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
mnist_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
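
To make the buffering idea concrete, here is a minimal plain-Python sketch of a shuffle buffer. It is an illustration only, not the `tf.data` implementation; the function name `buffered_shuffle` and its `seed` argument are made up for this example.

```python
import random

def buffered_shuffle(stream, buffer_size, seed=None):
    """Yield the items of `stream` in a locally shuffled order.

    Hypothetical helper: keeps at most `buffer_size` items in memory,
    so it also works on streams that are too long to hold at once.
    """
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            # Still filling the buffer, nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Stream exhausted: emit whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(buffered_shuffle(range(10), buffer_size=3, seed=0))
unshuffled = list(buffered_shuffle(range(5), buffer_size=1))
```

With a small buffer, an element can only move a limited distance from its original position, and a buffer of size 1 leaves the order unchanged. A buffer as large as the dataset holds every element before the first one is emitted, which is why `BUFFER_SIZE = len(x_train)` amounts to a full shuffle.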

### Define a model

Using this style feels like Object-Oriented Python + NumPy development. Initialize your layers in the constructor (`__init__`), then write your forward pass in the `call` method.

In [0]:
class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)
  
model = MyModel()

### Choose an optimizer and loss function

In [0]:
loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam()
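
As a rough illustration of what this loss measures: with `from_logits=False` the model outputs are treated as probabilities, the labels are integer class indices, and the loss for one example is the negative log of the probability assigned to the true class. The plain-Python sketch below averages that over a batch. The function name `mean_sparse_crossentropy` is hypothetical, and the sketch omits any numerical-stability handling, so it is not a substitute for the Keras loss.

```python
import math

def mean_sparse_crossentropy(labels, probabilities):
    """Average of -log(p[true_class]) over a batch (illustrative only)."""
    losses = [-math.log(p[y]) for y, p in zip(labels, probabilities)]
    return sum(losses) / len(losses)

labels = [1, 0]                       # integer class indices
probabilities = [[0.1, 0.7, 0.2],     # model is fairly confident in class 1
                 [0.5, 0.25, 0.25]]   # model is unsure about class 0
loss = mean_sparse_crossentropy(labels, probabilities)
```

The more probability the model puts on the correct class, the closer the loss gets to zero.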

### Choose metrics to measure loss and accuracy
These are stateful metric objects: each call updates a running total, and `result()` returns the value accumulated so far.

In [0]:
train_loss_metric = tf.keras.metrics.Mean(name='train_loss')
test_loss_metric = tf.keras.metrics.Mean(name='test_loss')

train_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
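
The accumulating behaviour can be sketched with a small plain-Python class. This is a hypothetical stand-in written for illustration (the class `RunningMean` and its method names are not part of Keras); it only mirrors the idea of updating state on every call and reading the aggregate later.

```python
class RunningMean:
    """Hypothetical stand-in for a stateful mean metric."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Accumulate one more observation (for example, one batch loss).
        self.total += value
        self.count += 1

    def result(self):
        # Mean of everything seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total = 0.0
        self.count = 0

metric = RunningMean()
for batch_loss in [0.5, 0.25, 0.75]:
    metric.update(batch_loss)
mean_after_three = metric.result()   # averages all three values

metric.reset()
metric.update(0.25)
mean_after_reset = metric.result()   # only the value seen after the reset
```

Unless the state is reset, the result keeps averaging over everything seen so far, which matters when the same metric is reused across several epochs.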

### Train the model using GradientTape

We'll use a [GradientTape](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/GradientTape), rather than the built-in `model.fit`.
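
As a quick refresher, here is a minimal sketch of what a tape does, separate from the model above: operations executed inside the `with` block are recorded, and `tape.gradient` then differentiates the result with respect to the watched variable. The variable `x` and the function `y = x * x` are chosen only for illustration.

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    # Trainable variables are watched automatically while the tape is active.
    y = x * x

# dy/dx = 2 * x, evaluated at x = 3.0
dy_dx = tape.gradient(y, x)
```

The training step below follows the same pattern, with the loss in place of `y` and the model's trainable variables in place of `x`.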

Also note that we could "compile" the next two functions into TensorFlow graphs by adding a [tf.function](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/function) decorator above each definition. That isn't necessary for our purposes here, just FYI.
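
For reference, decorating a function looks like the sketch below. The function `scaled_sum` is a made-up example rather than part of this notebook; the same decorator could be placed above `train_step` and `test_step`.

```python
import tensorflow as tf

@tf.function  # traces the Python function into a TensorFlow graph on first call
def scaled_sum(x, y):
    return tf.reduce_sum(x * 2.0 + y)

result = scaled_sum(tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0]))
```

The decorated function is called exactly like the plain Python version and returns tensors in the same way.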

In [0]:
def train_step(images, labels):
  with tf.GradientTape() as tape:    
    # Forward pass
    predictions = model(images)
    train_loss = loss_function(y_true=labels, y_pred=predictions)
  
  # Backward pass
  gradients = tape.gradient(train_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # Record results
  train_loss_metric(train_loss)
  train_accuracy_metric(labels, predictions)

In [0]:
def test_step(images, labels):
  predictions = model(images)
  test_loss = loss_function(y_true=labels, y_pred=predictions)
  
  # Record results
  test_loss_metric(test_loss)
  test_accuracy_metric(labels, predictions)

In [12]:
EPOCHS = 5

for epoch in range(EPOCHS):
  for images, labels in mnist_train:
    train_step(images, labels)
  
  for test_images, test_labels in mnist_test:
    test_step(test_images, test_labels)
  
  template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.2f}, Test loss: {:.4f}, Test accuracy: {:.2f}'
  print(template.format(epoch + 1,
                        train_loss_metric.result(),
                        train_accuracy_metric.result() * 100,
                        test_loss_metric.result(),
                        test_accuracy_metric.result() * 100))

Epoch 1, Loss: 0.3517, Accuracy: 90.32, Test loss: 0.1958, Test accuracy: 94.26
Epoch 2, Loss: 0.2572, Accuracy: 92.80, Test loss: 0.1678, Test accuracy: 94.99
Epoch 3, Loss: 0.2104, Accuracy: 94.08, Test loss: 0.1499, Test accuracy: 95.52
Epoch 4, Loss: 0.1804, Accuracy: 94.92, Test loss: 0.1373, Test accuracy: 95.89
Epoch 5, Loss: 0.1588, Accuracy: 95.52, Test loss: 0.1279, Test accuracy: 96.17


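Okay! As a next step, you can play with the model to see if you can increase the accuracy. Keep in mind that the metrics above are never reset, so the numbers printed for each epoch are running averages over all epochs so far. Calling `reset_states()` on each metric at the start of an epoch would report per-epoch values instead.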