##### Copyright 2021 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/MendezJesus/python_to_web/blob/main/python_to_web.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
    Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/MendezJesus/python_to_web/blob/main/python_to_web.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />
    View source on GitHub</a>
  </td>
</table>


# Run a neural network in the browser

This notebook contains an example of training a neural network using TensorFlow and Python, then running it in the browser using TensorFlow.js. If you run this notebook in Colab using *Runtime -> Run all*, it will create a webpage that you can use interactively to classify images you upload through a user interface.

This notebook provides example code to do the following:

1. Train a neural network using Python to classify flowers
1. Save the model to disk
1. Convert the model to TensorFlow.js format
1. Create a webpage using a minimum amount of HTML and JavaScript to run the model
1. Serve the webpage from this notebook (instructions are also provided to serve the page locally)
1. Upload images through the UI and classify them with the model

# Setup

## Install TensorFlow.js and other libraries

In [None]:
!pip install tensorflowjs

## Imports

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pathlib
import json
import matplotlib

from tensorflow.keras import datasets, layers, models
from tensorflow.keras.models import Model
from google.colab import html
from IPython.display import HTML
from PIL import Image

# Train a model using Python

## Load the dataset


The first half of this tutorial trains a convolutional neural network to classify images of flowers, based on this [tutorial](https://www.tensorflow.org/tutorials/images/classification). The steps that are repeated here are only lightly commented, and you can refer to the original tutorial to learn more.

In [None]:
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url, 
                                   fname='flower_photos', 
                                   untar=True)
data_dir = pathlib.Path(data_dir)

There are 3670 total flower images.

In [None]:
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)

Define parameters for the loader.


In [None]:
batch_size = 32
img_height = 180
img_width = 180

It's good practice to use a validation split when developing your model. We will use 80% of the images for training, and 20% for validation.

In [None]:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

In [None]:
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
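
As a rough check on the split sizes, the arithmetic can be sketched in plain Python. This assumes the loader reserves `int(validation_split * image_count)` files for validation and uses the rest for training. The helper name `split_counts` is hypothetical and not part of TensorFlow.

```python
def split_counts(total, validation_split):
    """Return (training, validation) counts for a fractional split."""
    if not 0 < validation_split < 1:
        raise ValueError("validation_split must be between 0 and 1")
    # Assumption: the validation share is truncated to a whole number of files.
    num_val = int(validation_split * total)
    return total - num_val, num_val


# 3670 is the image count reported earlier in this notebook.
train_count, val_count = split_counts(3670, 0.2)
print(train_count, val_count)
```

If the counts reported by the loader differ from these, the truncation assumption above does not hold for your TensorFlow version.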

The flower dataset contains 5 different types of flowers.

In [None]:
class_names = train_ds.class_names
print(class_names)
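
The model's output index `i` corresponds to `class_names[i]`, so the label list used by the JavaScript code later in this notebook has to follow the same order. One way to keep the two in sync is to generate the JavaScript array from the Python list. The sketch below assumes the five directory names shown in the list; `to_js_array` is a hypothetical helper, not part of any library.

```python
import json

# Assumed value of `train_ds.class_names` for the flower dataset;
# replace it with the list printed by the cell above.
class_names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]


def to_js_array(names):
    """Render a Python list of labels as a JavaScript constant declaration."""
    # A JSON array of strings is also a valid JavaScript array literal.
    return "const CLASSES = " + json.dumps(names) + ";"


js_classes = to_js_array(class_names)
print(js_classes)
```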

## Visualize the data

Here are the first 9 images from the training dataset.

In [None]:
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

Display one image from this batch (a sunflower in this run) that you can classify later in the webpage.


In [None]:
plt.imshow(images[3].numpy().astype("uint8"))

## Standardize the data



Note: It is important that you preprocess images identically in Python and JavaScript. This notebook uses a Python preprocessing layer which is not yet available in TensorFlow.js. As a workaround, you will normalize the images in JavaScript manually, as shown later in this notebook.

In [None]:
normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
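
To see why the two sides must agree, here is a minimal sketch of the rescaling arithmetic in plain Python, next to a different convention that maps pixels to [-1, 1]. Both function names are hypothetical. Only the first matches the `Rescaling(1./255)` layer used above.

```python
def rescale_unit(value):
    """Map a pixel value from [0, 255] to [0, 1], like Rescaling(1./255)."""
    return value * (1.0 / 255)


def rescale_signed(value):
    """Map a pixel value from [0, 255] to [-1, 1] (a different convention)."""
    return value / 127.5 - 1.0


# The same pixel produces different model inputs under the two conventions.
for pixel in (0, 128, 255):
    print(pixel, round(rescale_unit(pixel), 3), round(rescale_signed(pixel), 3))
```

A model trained on inputs in [0, 1] and then fed inputs in [-1, 1] in the browser would see data unlike anything in its training set, so its predictions would be unreliable.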

## Train the model

In [None]:
num_classes = 5

model = tf.keras.Sequential([
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes, activation='softmax')
])

In [None]:
model.compile(
  optimizer='adam',
  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=False),
  metrics=['accuracy'])

In [None]:
model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=5
)

In [None]:
model.summary()
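
The shapes in the summary can be cross-checked by hand. The sketch below assumes the Keras defaults for the layers above: 3x3 convolutions with `'valid'` padding and stride 1, and 2x2 max pooling. The helper names are hypothetical.

```python
def conv_out(size, kernel=3):
    """Spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after non-overlapping max pooling."""
    return size // pool


size = 180  # img_height and img_width
for _ in range(3):  # three Conv2D + MaxPooling2D pairs
    size = pool_out(conv_out(size))

flat_features = size * size * 32           # Flatten: height * width * filters
dense_params = flat_features * 128 + 128   # weights plus biases of Dense(128)
print(size, flat_features, dense_params)
```

Most of the model's parameters sit in the first dense layer, which is also what dominates the size of the weight files downloaded by the browser.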

# Save and convert the model

Save the TensorFlow model.



In [None]:
model.save("tfjs/myModel.h5")

Convert an existing TensorFlow model to the TensorFlow.js web format. 

The conversion will produce two types of files:


1. model.json (the dataflow graph and weight manifest)
1. group1-shard\*of\* (collection of binary weight files)

In [None]:
!tensorflowjs_converter \
    --input_format=keras \
    tfjs/myModel.h5 \
    tfjs
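
To check which weight files a converted model refers to, you can read the manifest in `model.json`. This is a minimal sketch that runs against a stand-in file written to a temporary directory. The shard names are made up for illustration, and `list_weight_files` is a hypothetical helper. Point it at `tfjs/model.json` to inspect the real output.

```python
import json
import tempfile
from pathlib import Path


def list_weight_files(model_json_path):
    """Return the weight shard paths named in a model.json manifest."""
    with open(model_json_path, "r") as f:
        model_json = json.load(f)
    paths = []
    for group in model_json["weightsManifest"]:
        paths.extend(group["paths"])
    return paths


# Illustrative stand-in for a converted model directory.
with tempfile.TemporaryDirectory() as tmp:
    manifest = {
        "weightsManifest": [
            {"paths": ["group1-shard1of2.bin", "group1-shard2of2.bin"]}
        ]
    }
    model_json_path = Path(tmp) / "model.json"
    model_json_path.write_text(json.dumps(manifest))
    shard_names = list_weight_files(model_json_path)

print(shard_names)
```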

# Develop the Website

This notebook contains a small amount of HTML and JavaScript, stored as Python strings inside code cells below. This is close to the minimum amount of code necessary to run the model in the browser (along with some CSS for the UI). Running the code cells below will save this HTML and JavaScript to disk, then serve the page from inside this notebook.

## JavaScript

This code includes helper functions to prepare the data, classify images, and display the model's prediction in HTML elements that will be defined in the next block.

In [None]:
js_template = """
const CLASSES = [
  'Daisy',
  'Dandelion',
  'Rose',
  'Sunflower',
  'Tulips',
];

const MODEL_PATH = '<model_url>';
const IMAGE_SIZE = 180;
const TOPK_PREDICTIONS = 5;

let myModel = undefined;
const DEMO = async () => {
  status('Loading model...');

  myModel = await tf.loadLayersModel(MODEL_PATH);

  // Warmup the model. This isn't necessary, but makes the first prediction
  // faster. Call `dispose` to release the WebGL memory allocated for the return
  // value of `predict`.
  myModel.predict(tf.zeros([1, IMAGE_SIZE, IMAGE_SIZE, 3])).dispose();

  status('');

  // Make a prediction on the sunflower image in the `inputimg` element.
  const inputImage = document.getElementById('inputimg');
  if (inputImage.complete && inputImage.naturalHeight !== 0) {
    predict(inputImage);
    inputImage.style.display = '';
  } else {
    inputImage.onload = () => {
      predict(inputImage);
      inputImage.style.display = '';
    }
  }
};

/**
 * Uses an image to make a prediction using myModel.
 * @param imgElement is converted to a tensor and then normalized to make
 * prediction.
 */
async function predict(imgElement) {
  status('Predicting...');

  // The first start time includes the time it takes to extract the image
  // from the HTML and preprocess it, in addition to the predict() call.
  const startTime1 = performance.now();
  // The second start time excludes the extraction and preprocessing and
  // includes only the predict() call.
  let startTime2 = 0;
  const logits = tf.tidy(() => {
    // tf.browser.fromPixels() returns a Tensor from an image element.
    const img = tf.browser.fromPixels(imgElement).toFloat();

    // Normalize the image from [0, 255] to [0, 1], matching the Python Rescaling layer.
    const normalized = img.div(255.0);

    // Reshape to a single-element batch so we can pass it to predict.
    const batched = normalized.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);

    startTime2 = performance.now();
    // Make a prediction through myModel.
    return myModel.predict(batched);
  });

  // The model ends in a softmax layer, so `logits` already holds
  // probabilities. Take the top-k values and map them to class names.
  const {values, indices} = tf.topk(logits, TOPK_PREDICTIONS);
  const indicesArray = indices.arraySync();
  const valuesArray = values.arraySync();

  const classes = [];
  for (let i = 0; i < TOPK_PREDICTIONS; i++) {
    classes.push({
      className: CLASSES[indicesArray[0][i]],
      probability: valuesArray[0][i]
    })
  }

  const totalTime1 = performance.now() - startTime1;
  const totalTime2 = performance.now() - startTime2;
  status(`Done in ${Math.floor(totalTime1)} ms ` +
      `(not including preprocessing: ${Math.floor(totalTime2)} ms)`);

  // Show the classes in the DOM.
  showResults(imgElement, classes);
}

/**
 * UI
 * Maps flower class with corresponding probability.
 * @param classes contains flower class name and corresponding prediction.
 */
function showResults(imgElement, classes) {

  const predictionContainer = document.createElement('div');
  const imgContainer = document.createElement('div');
  imgContainer.appendChild(imgElement);
  predictionContainer.appendChild(imgContainer);

  const table = document.createElement("table");
  const tableBody = document.createElement("tbody");

  for (let i = 0; i < classes.length; i++) {
    var row = document.createElement("tr");
    var cell = document.createElement("td");
    var cell2 = document.createElement("td");
    var cellTextFlowerName = document.createTextNode(classes[i].className);
    var cellTextProbability = document.createTextNode(classes[i].probability.toFixed(3));
    cell.appendChild(cellTextFlowerName);
    cell2.appendChild(cellTextProbability);
    row.appendChild(cell);
    row.appendChild(cell2);
    tableBody.appendChild(row);
  }
  table.appendChild(tableBody);

  predictionContainer.appendChild(table);

  predictionsElement.insertBefore(
      predictionContainer, predictionsElement.firstChild);
}

const filesElement = document.getElementById('files');
filesElement.addEventListener('change', evt => {
  let files = evt.target.files;
  // Display thumbnails & issue call to predict each image.
  for (let i = 0, f; f = files[i]; i++) {
    // Only process image files (skip non image files)
    if (!f.type.match('image.*')) {
      continue;
    }
    let reader = new FileReader();
    const idx = i;
    // Closure to capture the file information.
    reader.onload = e => {
      // Fill the image & call predict.
      let img = document.createElement('img');
      img.src = e.target.result;
      img.width = IMAGE_SIZE;
      img.height = IMAGE_SIZE;
      img.onload = () => predict(img);
    };

    // Read in the image file as a data URL.
    reader.readAsDataURL(f);
  }
});

const DEMO_STATUS_ELEMENT = document.getElementById('status');
const status = msg => DEMO_STATUS_ELEMENT.innerText = msg;

const predictionsElement = document.getElementById('predictions');

DEMO();
"""

## HTML
The HTML below contains element IDs that will be populated by the JavaScript.

In [None]:
html_template = """
<!doctype html>

<head>
  <title>Your title</title>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
</head>

<body>
  <div class="tfjs-example-container">
      <h1>TensorFlow.js: Using a Keras model in the browser</h1>
      <p>Description</p>
      <p>
        This file is based on <a href="https://github.com/tensorflow/tfjs-examples/tree/master/mobilenet">https://github.com/tensorflow/tfjs-examples/tree/master/mobilenet</a>.
      </p>
      <p>Status</p>
      <div id="status"></div>
      <p>Model Output</p>
       <label for="files">Upload an image:</label>
       <input type="file" id="files" name="files[]" multiple />
      <p id="predictions"></p>
      <img style="display: none" id="inputimg" src="{sunflower_url}" crossorigin="anonymous" width=180 height=180 />
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>
    <script>{js_content}</script>
  </div>
</body>
"""

# Serve the model from this notebook

In [None]:
sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/models/tfjs/python_to_web/sunflower.jpg"

In [None]:
# Create a resource and make it available for http request.

def CreateModelResources(some_path, file=""):
  # Create resource for model.json
  model_path = some_path + file
  ref = html.create_resource(filepath = model_path, route = model_path)

  weight_path_list = []
  with open(model_path, 'r') as f:
    model_json = json.load(f)
    weightsManifest = model_json['weightsManifest']
    for weightGroup in weightsManifest:
      weight_path_list.extend(weightGroup['paths'])

  # Create resource for each weight path.
  refs = []
  for path in weight_path_list:
    weight_path = some_path + '/' + path
    with open(weight_path, 'rb') as f:
      weights = f.read()
      weight_path = html.create_resource(content=weights, route=weight_path)
      refs.append(weight_path)

  return ref, refs

In [None]:
(ref, refs) = CreateModelResources("/content/tfjs", "/model.json")

In [None]:
js_content_final = js_template.replace("<model_url>", ref.url)
html_content_final = html_template.format(sunflower_url=sunflower_url, js_content=js_content_final)
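
The cell above fills the JavaScript template with `str.replace` and the HTML template with `str.format`. One likely reason for the difference is that JavaScript is full of curly braces, which `str.format` treats as placeholders. The sketch below illustrates this with a made-up snippet and URL.

```python
# Made-up stand-ins for the templates used in this notebook.
js_snippet = (
    "const MODEL_PATH = '<model_url>';\n"
    "const run = () => { return MODEL_PATH; };"
)
html_page = "<script>{js_content}</script>"

# Plain substring replacement ignores the braces in the JavaScript.
js_final = js_snippet.replace("<model_url>", "https://example.com/model.json")

# Braces inside a *value* passed to format() are inserted unchanged.
page = html_page.format(js_content=js_final)

# Calling format() on the JavaScript itself fails, because the function body
# `{ return MODEL_PATH; }` is parsed as a replacement field.
try:
    js_snippet.format(model_url="x")
    format_failed = False
except (KeyError, ValueError, IndexError):
    format_failed = True

print(format_failed)
print(page)
```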

In [None]:
HTML(html_content_final)

In [None]:
# Prepare the HTML to run locally
# by replacing the model path with a local directory
import re
html_content_final = re.sub("const MODEL_PATH =.+", 
                            "const MODEL_PATH = 'model.json';", 
                            html_content_final)
with open('/content/tfjs/index.html', 'w') as f:
    f.write(html_content_final)
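
The substitution above depends on `.` not matching newlines by default, so `.+` consumes only the rest of the `MODEL_PATH` line. Here is a minimal sketch of that behavior on a made-up two-line snippet (the URL is illustrative):

```python
import re

# Made-up excerpt of the generated page.
page = (
    "const MODEL_PATH = 'https://example.com/abc/model.json';\n"
    "const IMAGE_SIZE = 180;"
)

# Without re.DOTALL, `.+` stops at the end of the line, so later lines survive.
local_page = re.sub("const MODEL_PATH =.+",
                    "const MODEL_PATH = 'model.json';",
                    page)
print(local_page)
```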

# Run locally

This example embeds the files within the Colab for convenience, but you can also serve the model locally on your computer by starting your own web server. To do so:

1. Run this notebook (either in Colab or on your computer) to create the saved model and index.html.

1. Download the converted model (including the weights and json) and index.html file from ```/content/tfjs/``` in Colab to your local machine. If you don't see that directory in the file browser, click the refresh icon.

1. Start a web server on your local machine. (If you simply open index.html in a browser, you may run into security protections that prevent the page from loading the model files.) To start a server, you can use the one built into Python. First, `cd` into your `tfjs` directory. Then, using Python 3, run this command in your terminal:

 `$ python3 -m http.server 8888`

1. Now, open a browser at the port you chose (e.g. point the URL to `localhost:8888`). Your webpage should appear. To debug in Chrome, open the JavaScript console via ```View -> Developer -> Developer Tools -> Console``` and check for errors.
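
If you prefer to start the server from a Python script rather than the command line, the standard library's `http.server` module can be used directly. This is a minimal sketch: `make_static_server` is a hypothetical helper, and the demonstration serves a throwaway directory on a free port so that it terminates on its own.

```python
import functools
import http.client
import http.server
import tempfile
import threading
from pathlib import Path


def make_static_server(directory, port=8888):
    """Build an HTTP server that serves files from `directory` on localhost."""
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=str(directory))
    return http.server.ThreadingHTTPServer(("127.0.0.1", port), handler)


# Demonstration: serve a temporary directory; port 0 asks the OS for a free port.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "index.html").write_text("<h1>hello</h1>")
    server = make_static_server(tmp, port=0)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()

    # Fetch the page once to confirm the server responds.
    conn = http.client.HTTPConnection(
        "127.0.0.1", server.server_address[1], timeout=5)
    conn.request("GET", "/index.html")
    body = conn.getresponse().read().decode("utf-8")
    conn.close()

    server.shutdown()
    server.server_close()

print(body)
```

To serve the downloaded model instead, build the server with `make_static_server("tfjs")` and call its `serve_forever()` method, which blocks until you interrupt it.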

# Next Steps
Explore another example of how to convert a SavedModel to TensorFlow.js in this [codelab](https://codelabs.developers.google.com/codelabs/tensorflowjs-convert-python-savedmodel). To learn more about image classification, check out this [tutorial](https://www.tensorflow.org/tutorials/images/classification).
You can learn more about converting models [here](https://www.tensorflow.org/js/tutorials/conversion/import_keras), and you can find an additional example of image classification in the browser [here](https://github.com/tensorflow/tfjs-examples/tree/master/mobilenet).
