##### Copyright 2019 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Convolutional Neural Network (CNN)

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/images/cnn">
    <img src="https://www.tensorflow.org/images/tf_logo_32px.png" />
    View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/images/cnn.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
    Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/images/cnn.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />
    View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/images/cnn.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This tutorial demonstrates training a simple [Convolutional Neural Network](https://developers.google.com/machine-learning/glossary/#convolutional_neural_network) (CNN) to classify [CIFAR images](https://www.cs.toronto.edu/~kriz/cifar.html). Because this tutorial uses the [Keras Sequential API](https://www.tensorflow.org/guide/keras/overview), creating and training your model will take just a few lines of code.


### Import TensorFlow

In [1]:
import tensorflow as tf

from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
from ipywidgets import widgets
import numpy as np
import imageio
from skimage.transform import resize




In [2]:
# Restore the previously saved classifier from the 'default-cifar10-cnn' path
# (resolved relative to the notebook's working directory).
model = models.load_model('default-cifar10-cnn')





In [3]:
# Human-readable labels. The order matters: predict_image looks a label up by
# the argmax index of the model's prediction vector.
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

## Accept and predict on a user image

In [4]:
import imageio
from skimage.transform import resize

# Shared output area: anything printed or plotted by the click handler below
# is captured here, since a widget callback has no cell output of its own.
output = widgets.Output()

@output.capture()
def predict_image(sender):
    """Classify the JPEG named in the ``input_text`` box and report the result.

    Reads the text currently in the global ``input_text`` widget, appends
    ``'.jpg'``, loads that file, resizes it to 32x32, shows it, and prints the
    predicted label from ``class_names``.

    Args:
        sender: The widget that fired the click event (supplied by
            ``Button.on_click``); not used.

    Returns:
        None. All results are written to the ``output`` widget.
    """
    name = input_text.value.strip()
    # Clear the box right away so it is ready for the next image name.
    input_text.value = ''
    if not name:
        # Without this guard an empty box would look for a file named '.jpg'.
        print('Please enter an image name (without the .jpg extension).')
        return

    filename = name + '.jpg'
    try:
        img = imageio.imread(filename)
    except OSError as err:
        # Missing/unreadable file: show a short message instead of a traceback.
        print(f'Could not read {filename}: {err}')
        return

    # NOTE(review): assumes the model expects 3-channel RGB input - confirm
    # against the saved model. Grayscale images are replicated across three
    # channels and any extra channel (e.g. alpha) is dropped.
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    elif img.shape[-1] > 3:
        img = img[..., :3]

    # skimage's resize returns a float image; presumably this matches the
    # scaling used when the model was trained - TODO confirm.
    img = resize(img, (32, 32))
    plt.imshow(img)
    # Inside a widget callback the figure is not flushed automatically at
    # "end of cell", so it has to be shown explicitly to land in `output`.
    plt.show()

    # The model is given a batch, so add a leading batch axis of size 1.
    test = np.expand_dims(img, axis=0)
    pred = model.predict(test)[0]

    print(f'Given image {filename}')
    idx = pred.argmax()
    print(f'Prediction name {class_names[idx]}')

In [5]:
# Build the prompt: a text box for the image name and a button that runs the
# classifier. predict_image reads the module-level `input_text`, so the name
# must stay exactly this.
submit = widgets.Button(description='submit')
submit.on_click(predict_image)
input_text = widgets.Text()

# Show the form first, then the shared area that receives the handler's output.
form = widgets.VBox(children=[input_text, submit])
display(form)
display(output)

VBox(children=(Text(value=''), Button(description='submit', style=ButtonStyle())))

Output()

In [6]:
# Fresh prompt for another image. This rebinds the global `input_text`, which
# is the box predict_image reads from when any submit button is clicked.
input_text = widgets.Text()
submit = widgets.Button(description='submit')
submit.on_click(predict_image)

prompt_box = widgets.VBox(children=[input_text, submit])
display(prompt_box)
# `output` is the same shared widget as before; this adds another view of it.
display(output)

VBox(children=(Text(value=''), Button(description='submit', style=ButtonStyle())))



In [7]:
# Another text box + button pair wired to predict_image. `input_text` is
# rebound here, so the handler now reads from this newest text box.
submit = widgets.Button(description='submit')
input_text = widgets.Text()
submit.on_click(predict_image)

display(widgets.VBox(children=(input_text, submit)))
display(output)

VBox(children=(Text(value=''), Button(description='submit', style=ButtonStyle())))



In [8]:
# Prompt for one more image name; clicking the button calls predict_image,
# whose prints and plot are captured by the shared `output` widget.
input_text = widgets.Text()
submit = widgets.Button(description='submit')
submit.on_click(predict_image)

controls = [input_text, submit]
display(widgets.VBox(controls))
display(output)

VBox(children=(Text(value=''), Button(description='submit', style=ButtonStyle())))



In [9]:
# Final prompt cell: new text box (bound to the global name predict_image
# reads), new button, and another view of the shared output area.
input_text = widgets.Text()
submit = widgets.Button(description='submit')
submit.on_click(predict_image)

layout_box = widgets.VBox(children=[input_text, submit])
display(layout_box)
display(output)

VBox(children=(Text(value=''), Button(description='submit', style=ButtonStyle())))

