In [1]:
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow import keras

In [24]:
import numpy as np

In [15]:
import tensorflow as tf

In [25]:
DATA_URL = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz'

path = tf.keras.utils.get_file('mnist.npz', DATA_URL)
with np.load(path) as data:
    x_train = data['x_train']
    y_train = data['y_train']
    x_examples = data['x_test']
    x_labels = data['y_test']

There are various ways to construct ANNs in Tensorflow. 

Here we compare SEQUENTIAL formatting with FUNCTIONAL API formatting

In [28]:
# SEQUENTIAL
seq_model = keras.Sequential([
        Flatten(input_shape=(28,28)),
        Dense(128, activation='relu'),
        Dense(10, activation='softmax')
    ])


In [29]:
# FUNCTIONAL API
input = Input(shape=(28,28))
x = Flatten()(input)
x = Dense(128, activation="relu")(x)
predictions = Dense(10, activation="softmax")(x)
func_model = Model(inputs=input, outputs=predictions)

Note that with a functional API:
- you explicitly define an input layer. The MNIST dataset is 28 by 28 image.
- define the layers of the model, but don't put into a list. 
- The return value of each function is passed into the next function using parentheses. 

In [30]:
func_model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

In [32]:
history = func_model.fit(x_train, y_train, epochs=1)




In [34]:
 history.history['accuracy'][-1]

0.862666666507721

# Extra
We can make model construction a single function to save time creating models

In [35]:
def build_functional_model():
    input_layer = Input(shape=(28,28))
    flatten_layer = Flatten()(input_layer)
    first_dense = Dense(128, activation="relu")(flatten_layer)
    output_layer = Dense(10, activation="softmax")(first_dense)
    func_model = Model(inputs=input_layer, outputs=output_layer)
    return func_model
    

In [36]:
func_model2 = build_functional_model()