In [1]:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
tf.enable_eager_execution()
tfe = tf.contrib.eager

In [8]:
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)


Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [16]:
# Hyperparameters
lr = 0.001
num_steps = 1000
bs = 128
display_step = 100

# Network parameters
n_h1 = 256
n_h2 = 256
n_out = 10
n_in = 784



In [17]:
# using TF Dataset to split data in to batches

dataset = tf.data.Dataset.from_tensor_slices((mnist.train.images, mnist.train.labels))
dataset = dataset.repeat().batch(bs).prefetch(bs)
dataset_iter = tfe.Iterator(dataset)

In [18]:
## Define the network, to use eager API and tf.layers API together, 
## we must instantiate a tfe.Network class as follows

class mlp_eager(tfe.Network):
    def __init__(self):
        # Define Each layer
        super(mlp_eager, self).__init__()
        
        self.layer1 = self.track_layer(tf.layers.Dense(n_h1, activation=tf.nn.relu))
        
        self.layer2 = self.track_layer(tf.layers.Dense(n_h2, activation=tf.nn.relu))
        
        self.out_layer = self.track_layer(tf.layers.Dense(n_out))
    
    def call(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return self.out_layer(x)

network = mlp_eager()
        


Please inherit from `tf.keras.Model`, and see its documentation for details. `tf.keras.Model` should be a drop-in replacement for `tfe.Network` in most cases, but note that `track_layer` is no longer necessary or supported. Instead, `Layer` instances are tracked on attribute assignment (see the section of `tf.keras.Model`'s documentation on subclassing). Since the output of `track_layer` is often assigned to an attribute anyway, most code can be ported by simply removing the `track_layer` calls.

`tf.keras.Model` works with all TensorFlow `Layer` instances, including those from `tf.layers`, but switching to the `tf.keras.layers` versions along with the migration to `tf.keras.Model` is recommended, since it will preserve variable names. Feel free to import it with an alias to avoid excess typing :).


In [19]:
def loss_fn(model_fn, inputs, labels):
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=model_fn(inputs), labels=labels))

def accuracy_fn(model_fn, inputs, labels):
    preds = tf.nn.softmax(model_fn(inputs))
    correct_preds = tf.equal(tf.argmax(preds, 1), labels)
    return tf.reduce_mean(tf.cast(correct_preds, tf.float32))

optimizer = tf.train.AdamOptimizer(learning_rate=lr)

grad = tfe.implicit_gradients(loss_fn)

In [21]:
## Training

avg_loss = 0.0
avg_acc = 0.0

for step in range(num_steps):
    
    ## Iterate through dataset
    d = dataset_iter.next()
    
    batch_x = d[0]; batch_y = tf.cast(d[1], dtype=tf.int64)
    
    batch_loss = loss_fn(network, batch_x, batch_y)
    avg_loss += batch_loss
    
    batch_accuracy = accuracy_fn(network, batch_x, batch_y)
    avg_acc += batch_accuracy
    
    if(step == 0):
        print("Initial Loss: ", avg_loss.numpy())
        
    optimizer.apply_gradients(grad(network, batch_x, batch_y))
    
    if(step+1) % display_step == 0:
        avg_loss /= display_step
        avg_acc /= display_step
        
        print("Step: ", step+1, "loss: ", avg_loss.numpy(), "accuracy: ", avg_acc.numpy())
        avg_loss = 0.0
        avg_acc = 0.0

Initial Loss:  tf.Tensor(0.08449602, shape=(), dtype=float32)
Step:  100 loss:  0.07309364 accuracy:  0.9778125
Step:  200 loss:  0.069960386 accuracy:  0.97867185
Step:  300 loss:  0.05830607 accuracy:  0.9815625
Step:  400 loss:  0.057883106 accuracy:  0.9825
Step:  500 loss:  0.05129766 accuracy:  0.9853906
Step:  600 loss:  0.047039997 accuracy:  0.9853125
Step:  700 loss:  0.04275153 accuracy:  0.98726565
Step:  800 loss:  0.040562183 accuracy:  0.98859376
Step:  900 loss:  0.03719801 accuracy:  0.9886719
Step:  1000 loss:  0.031995494 accuracy:  0.9907031


In [23]:
## Evaluation of model
testX, testY = mnist.test.images, mnist.test.labels

test_acc = accuracy_fn(network, testX, testY)
print("Test set Accuracy: ", test_acc.numpy())

Test set Accuracy:  0.9779
