# Fashion MNIST with `tf.keras` from Scratch

This example walks through creating, training, and validating a
TensorFlow `tf.keras` model, saving it as an HDF5 `.h5` file, and converting it
to the Core ML `.mlmodel` format with the `tfcoreml` converter. For more
examples, refer to the `test_tf_2x.py` file.
 
Note: 

- This notebook was tested with the following dependencies:

```
tensorflow==2.0.0
coremltools==3.1
tfcoreml==1.1
```

- Models from TensorFlow 2.0+ are supported only with `minimum_ios_deployment_target>='13'`.
  You can also use `coremltools.converters.tensorflow.convert()`
  instead of `tfcoreml.convert()` to convert your model.

In [1]:
import tensorflow as tf
import numpy as np
import tfcoreml

print(tf.__version__)

W1101 14:00:52.328081 4735601984 __init__.py:74] TensorFlow version 2.0.0 detected. Last version known to be fully compatible is 1.14.0 .


2.0.0


In [2]:
# prepare fashion_mnist dataset
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0
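
The division by `255.0` rescales the 8-bit pixel intensities (0 to 255) into floats between 0.0 and 1.0. A minimal sketch of that step in plain Python, so it runs without TensorFlow; the helper name `scale_pixels` and the tiny sample image are made up for illustration:

```python
def scale_pixels(image, max_value=255.0):
    """Return a copy of a 2-D image with every pixel divided by max_value."""
    return [[pixel / max_value for pixel in row] for row in image]

# A tiny 2x3 stand-in for a 28x28 Fashion MNIST image (values are invented).
sample_image = [[0, 51, 102],
                [153, 204, 255]]

scaled = scale_pixels(sample_image)
print(scaled)
```

Because the network is trained on values in this range, inputs passed to the model at prediction time would presumably need the same scaling.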

In [3]:
# create a simple model using tf.keras
keras_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
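
To get a feel for the size of this network, the number of trainable parameters can be worked out by hand from the layer sizes in the cell above. The following is a rough sketch in plain Python; the helper `dense_param_count` is a hypothetical name, not part of `tf.keras`:

```python
# Layer sizes taken from the model definition above.
input_units = 28 * 28   # Flatten turns each 28x28 image into 784 values
hidden_units = 128
output_units = 10

def dense_param_count(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

hidden_params = dense_param_count(input_units, hidden_units)
output_params = dense_param_count(hidden_units, output_units)
total_params = hidden_params + output_params

print(hidden_params, output_params, total_params)
```

The `Flatten` layer has no weights of its own, so only the two `Dense` layers contribute to the total.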

In [4]:
# train and evaluate the keras model
keras_model.compile(optimizer='adam',
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])

keras_model.fit(train_images, train_labels, epochs=10)
test_loss, test_acc = keras_model.evaluate(test_images, test_labels, verbose=2)

print('\nTest accuracy:', test_acc)

Train on 60000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
10000/1 - 0s - loss: 0.2264 - accuracy: 0.8833

Test accuracy: 0.8833
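
The `sparse_categorical_crossentropy` loss used above takes integer class labels directly, with no one-hot encoding. For a single sample it is the negative log of the probability the model assigned to the true class. The sketch below illustrates that idea and the accuracy metric in plain Python; the function names and sample probabilities are invented for illustration, and the Keras implementation includes numerical safeguards that are not shown here:

```python
import math

def sparse_categorical_crossentropy(probabilities, label):
    """Per-sample loss: negative log of the probability given to the true class."""
    return -math.log(probabilities[label])

def mean_loss_and_accuracy(batch_probabilities, labels):
    """Average the per-sample loss and count how often argmax matches the label."""
    losses = [sparse_categorical_crossentropy(p, y)
              for p, y in zip(batch_probabilities, labels)]
    hits = [max(range(len(p)), key=p.__getitem__) == y
            for p, y in zip(batch_probabilities, labels)]
    return sum(losses) / len(losses), sum(hits) / len(hits)

# Invented softmax outputs for two samples over three classes.
batch_probabilities = [[0.1, 0.7, 0.2],
                       [0.8, 0.1, 0.1]]
labels = [1, 2]  # integer class ids, as in train_labels / test_labels

loss, accuracy = mean_loss_and_accuracy(batch_probabilities, labels)
print(f"loss={loss:.4f} accuracy={accuracy:.2f}")
```

A confident wrong prediction (the second sample here) contributes far more to the loss than a correct one, which is why loss and accuracy can move somewhat independently during training.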


In [5]:
# save the tf.keras model as .h5 model file
model_file = './mnist_fashion_model.h5'
keras_model.save(model_file)

!ls mnist_fashion_model.h5

mnist_fashion_model.h5


In [6]:
# convert this model to Core ML format
model = tfcoreml.convert(tf_model_path=model_file,
                         input_name_shape_dict={'flatten_input': (1, 28, 28)},
                         output_feature_names=['Identity'],
                         minimum_ios_deployment_target='13')
model.save('./mnist_fashion_model.mlmodel')

!ls mnist_fashion_model.mlmodel

0 assert nodes deleted
['sequential/dense_1/BiasAdd/ReadVariableOp/resource:0', 'sequential/dense/MatMul/ReadVariableOp:0', 'sequential/dense/BiasAdd/ReadVariableOp:0', 'sequential/flatten/Reshape/shape:0', 'sequential/dense/BiasAdd/ReadVariableOp/resource:0', 'sequential/dense/MatMul/ReadVariableOp/resource:0', 'sequential/dense_1/MatMul/ReadVariableOp/resource:0', 'sequential/dense_1/BiasAdd/ReadVariableOp:0', 'sequential/dense_1/MatMul/ReadVariableOp:0']
4 nodes deleted
0 nodes deleted
0 nodes deleted
[Op Fusion] fuse_bias_add() deleted 4 nodes.
2 identity nodes deleted
2 disconnected nodes deleted
[SSAConverter] Converting function main ...
[SSAConverter] [1/7] Converting op type: 'Placeholder', name: 'flatten_input', output_shape: (1, 28, 28).
[SSAConverter] [2/7] Converting op type: 'Const', name: 'sequential/flatten/Reshape/shape', output_shape: (2,).
[SSAConverter] [3/7] Converting op type: 'Reshape', name: 'sequential/flatten/Reshape', output_shape: (1, 784).
[SSAConverter] [4

In [7]:
# run predictions with a fake image as input
fake_image = np.random.rand(1, 28, 28)

keras_predictions = keras_model.predict(fake_image)
print(keras_predictions[:10])

coreml_predictions = model.predict({'flatten_input': fake_image})['Identity']
print(coreml_predictions[:10])

assert np.allclose(keras_predictions, coreml_predictions)

[[1.5719648e-09 1.7905072e-09 5.9817944e-07 8.1820750e-10 9.6943937e-09
  5.0254831e-20 1.5249961e-07 6.2053448e-17 9.9999928e-01 1.0400648e-15]]
[[1.57196778e-09 1.79050730e-09 5.98181146e-07 8.18209001e-10
  9.69441238e-09 5.02548314e-20 1.52499751e-07 6.20534484e-17
  9.99999285e-01 1.04006487e-15]]
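
The two printed arrays differ in their trailing digits, yet the `np.allclose` check passes because it compares elementwise within a tolerance: by default `|a - b| <= atol + rtol * |b|` with `rtol=1e-05` and `atol=1e-08`. As a rough illustration, the sketch below re-implements that rule in plain Python and applies it to the values copied from the output above; `all_close` is a stand-in written for this example, not the NumPy function itself:

```python
# Values copied from the printed output above (already rounded for display).
keras_row = [1.5719648e-09, 1.7905072e-09, 5.9817944e-07, 8.1820750e-10,
             9.6943937e-09, 5.0254831e-20, 1.5249961e-07, 6.2053448e-17,
             9.9999928e-01, 1.0400648e-15]
coreml_row = [1.57196778e-09, 1.79050730e-09, 5.98181146e-07, 8.18209001e-10,
              9.69441238e-09, 5.02548314e-20, 1.52499751e-07, 6.20534484e-17,
              9.99999285e-01, 1.04006487e-15]

def all_close(a, b, rtol=1e-05, atol=1e-08):
    """Same elementwise rule as np.allclose: |a - b| <= atol + rtol * |b|."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

# Largest elementwise gap between the two sets of predictions.
max_abs_diff = max(abs(x - y) for x, y in zip(keras_row, coreml_row))

# Index of the highest-probability class according to each model.
keras_class = max(range(len(keras_row)), key=keras_row.__getitem__)
coreml_class = max(range(len(coreml_row)), key=coreml_row.__getitem__)

print(all_close(keras_row, coreml_row), max_abs_diff, keras_class, coreml_class)
```

If a converted model only agrees with the original to a looser precision, passing explicit `rtol` and `atol` arguments to `np.allclose` makes the accepted error visible in the notebook instead of relying on the defaults.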
