<a href="https://colab.research.google.com/github/Itsuki-Hamano123/ML_DEMO_UI/blob/master/gradio-app/image_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the TF Hub and Gradio packages used by the cells below.
# NOTE(review): versions are unpinned. This notebook uses the gr.inputs / gr.outputs API and the
# install log below shows gradio 1.1.6 was installed — consider pinning versions for reproducibility.
%pip install tensorflow-hub
%pip install gradio

Collecting gradio
[?25l  Downloading https://files.pythonhosted.org/packages/7d/66/a9f7fcc853b64d5b432b7cca9bd0f6a09c56778c80203c7bbe7a97234838/gradio-1.1.6-py3-none-any.whl (973kB)
[K     |████████████████████████████████| 983kB 2.8MB/s 
Collecting paramiko
[?25l  Downloading https://files.pythonhosted.org/packages/06/1e/1e08baaaf6c3d3df1459fd85f0e7d2d6aa916f33958f151ee1ecc9800971/paramiko-2.7.1-py2.py3-none-any.whl (206kB)
[K     |████████████████████████████████| 215kB 17.4MB/s 
Collecting analytics-python
  Downloading https://files.pythonhosted.org/packages/d3/37/c49d052f88655cd96445c36979fb63f69ef859e167eaff5706ca7c8a8ee3/analytics_python-1.2.9-py2.py3-none-any.whl
Collecting cryptography>=2.5
[?25l  Downloading https://files.pythonhosted.org/packages/ba/91/84a29d6a27fd6dfc21f475704c4d2053d58ed7a4033c2b0ce1b4ca4d03d9/cryptography-3.0-cp35-abi3-manylinux2010_x86_64.whl (2.7MB)
[K     |████████████████████████████████| 2.7MB 18.7MB/s 
[?25hCollecting pynacl>=1.0.1
[?25l  Do

In [2]:
import numpy as np
import PIL.Image as Image
import requests

import gradio as gr
from scipy.special import softmax
import tensorflow as tf
import tensorflow_hub as hub

## 画像分類モデルのデモ用WebUI作成

### モデルの用意

In [3]:
%%time
# TensorFlow Hub URL of the image-classification model (exposed as a Colab form field via #@param).
classifier_fetch_url = 'https://tfhub.dev/google/imagenet/mobilenet_v2_035_224/classification/4' #@param{type:'string'}
# Per-image input shape (without batch axis) used to build the classifier; also reused for the
# Gradio image inputs. Presumably (height, width, channels) — TODO confirm.
input_image_shape=(224,224,3)
def load_hub_keras(url, input_image_shape):
  """Wrap a TF Hub module in a Keras Sequential model and build it.

  Args:
    url: TF Hub handle of the model to load.
    input_image_shape: shape of a single input image, without the batch axis.

  Returns:
    A built ``tf.keras.Sequential`` whose only layer is a ``hub.KerasLayer``.
  """
  hub_layer = hub.KerasLayer(url)
  model = tf.keras.Sequential([hub_layer])
  # Prepend an unknown (None) batch dimension to the per-image shape.
  batched_shape = [None, *input_image_shape]
  model.build(input_shape=batched_shape)
  return model

# URL of the ImageNet class-label text file (one label per line); Colab form field via #@param.
label_fetch_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt' #@param{type:'string'}
def fetch_label(url):
  """Download a newline-separated label file and return its labels as a list.

  Only trailing line terminators are dropped, so every label keeps its
  position (and therefore its class index).

  Args:
    url: URL of a plain-text file with one label per line.

  Returns:
    list of label strings, in file order.

  Raises:
    requests.HTTPError: if the server answers with an error status, rather
      than silently parsing an error page as labels.
  """
  response = requests.get(url, timeout=30)
  response.raise_for_status()
  # rstrip + splitlines drops only the trailing newline(s) and also handles
  # '\r\n'. The previous `labels.remove('')` raised ValueError when the file
  # had no trailing newline, and removed the *first* blank line (shifting all
  # later indices) if a blank line appeared inside the file.
  return response.text.rstrip('\r\n').splitlines()


# Build the classifier and fetch its labels; both are read as globals by classify_image_fn.
classifier = load_hub_keras(classifier_fetch_url, input_image_shape)
labels = fetch_label(label_fetch_url)

CPU times: user 2.37 s, sys: 253 ms, total: 2.62 s
Wall time: 3.06 s


### gr.Interfaceの引数を定義

In [4]:
# Gradio input/output components for the classification demo.
# NOTE(review): gr.inputs / gr.outputs come from the old Gradio API (the install log shows
# gradio 1.1.6); newer Gradio releases may not provide them — verify before upgrading.
# Input: an image component configured with the classifier's input_image_shape.
input_def = gr.inputs.Image(shape=input_image_shape)
# Output: a label component showing the 5 highest-probability classes.
output_def = gr.outputs.Label(num_top_classes=5)

def classify_image_fn(input_image):
  """Classify one image and return a {label: probability} mapping for Gradio.

  Reads the module-level ``classifier`` model and ``labels`` list.

  Args:
    input_image: a single image array as delivered by the Gradio Image input
      (presumably uint8 RGB of shape ``input_image_shape`` — TODO confirm).

  Returns:
    dict mapping each label to the softmax probability of the corresponding
    model output; gr.outputs.Label selects the top classes from it.
  """
  # Add a batch axis and divide by 255 (presumably mapping 0-255 pixels to [0, 1]).
  input_tensor = np.expand_dims(input_image, 0) / 255
  # Raw scores for the single image in the batch; softmax turns them into probabilities.
  scores = classifier.predict(input_tensor)[0]
  predict_proba = softmax(scores)
  # Pair every label with its probability. The previous `range(len(labels)-1)`
  # silently dropped the last class.
  # NOTE(review): repeated label names collapse to one key (the later index wins).
  return {label: float(prob) for label, prob in zip(labels, predict_proba)}

### Webページ起動

In [5]:
# Launch the classification demo web page. With debug=False, errors are not shown in the
# notebook (the launch message in the output says to set debug=True for that).
gr.Interface(fn=classify_image_fn, inputs=input_def, outputs=output_def).launch(debug=False)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on External URL: https://57281.gradio.app
Interface loading below...


(<gradio.networking.serve_files_in_background.<locals>.HTTPServer at 0x7efdbd9c4ba8>,
 'http://127.0.0.1:7860/',
 'https://57281.gradio.app')

## 物体検出モデルのデモ用WebUI作成

今回はObject Detection API with TensorFlow 2の可視化ユーティリティ（`visualization_utils`）を使用し、検出モデル自体はTensorFlow Hubから読み込む<br>
[https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2.md)

In [7]:
!git clone https://github.com/tensorflow/models.git
%cd models/research
# Compile protos.
!protoc object_detection/protos/*.proto --python_out=.
# Install TensorFlow Object Detection API.
%cp object_detection/packages/tf2/setup.py .
!python -m pip install .
%cd /content

Cloning into 'models'...
remote: Enumerating objects: 29, done.[K
remote: Counting objects: 100% (29/29), done.[K
remote: Compressing objects: 100% (29/29), done.[K
remote: Total 41964 (delta 11), reused 18 (delta 0), pack-reused 41935[K
Receiving objects: 100% (41964/41964), 549.27 MiB | 33.38 MiB/s, done.
Resolving deltas: 100% (28509/28509), done.
/content/models/research
Processing /content/models/research
Collecting avro-python3
  Downloading https://files.pythonhosted.org/packages/b2/5a/819537be46d65a01f8b8c6046ed05603fb9ef88c663b8cca840263788d58/avro-python3-1.10.0.tar.gz
Collecting apache-beam
[?25l  Downloading https://files.pythonhosted.org/packages/56/f1/7fcfbff3d3eed7895f10b358844b6e8ed21b230666aabd09d842cd725363/apache_beam-2.23.0-cp36-cp36m-manylinux2010_x86_64.whl (8.3MB)
[K     |████████████████████████████████| 8.3MB 6.5MB/s 
Collecting tf-slim
[?25l  Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a

In [8]:
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder

### モデル&COCOラベルの用意
以下コードの参考元：[https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/inference_from_saved_model_tf2_colab.ipynb](https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/inference_from_saved_model_tf2_colab.ipynb)

In [9]:
%%time
# TensorFlow Hub URL of the object-detection model (Colab form field via #@param).
detector_hub_url = 'https://tfhub.dev/tensorflow/ssd_mobilenet_v2/2' #@param{type:'string'}
# Load the detector; later cells call it directly as detector(batch).
detector = hub.load(detector_hub_url)

# COCO label map, hard-coded as (category id, display name) pairs and expanded
# into the {id: {'id': id, 'name': name}} form used for drawing detections.
# Ids follow the 90-id COCO numbering, so some ids (e.g. 12, 26) are unused.
category_index = {
    cat_id: {'id': cat_id, 'name': cat_name}
    for cat_id, cat_name in (
        (1, 'person'),
        (2, 'bicycle'),
        (3, 'car'),
        (4, 'motorcycle'),
        (5, 'airplane'),
        (6, 'bus'),
        (7, 'train'),
        (8, 'truck'),
        (9, 'boat'),
        (10, 'traffic light'),
        (11, 'fire hydrant'),
        (13, 'stop sign'),
        (14, 'parking meter'),
        (15, 'bench'),
        (16, 'bird'),
        (17, 'cat'),
        (18, 'dog'),
        (19, 'horse'),
        (20, 'sheep'),
        (21, 'cow'),
        (22, 'elephant'),
        (23, 'bear'),
        (24, 'zebra'),
        (25, 'giraffe'),
        (27, 'backpack'),
        (28, 'umbrella'),
        (31, 'handbag'),
        (32, 'tie'),
        (33, 'suitcase'),
        (34, 'frisbee'),
        (35, 'skis'),
        (36, 'snowboard'),
        (37, 'sports ball'),
        (38, 'kite'),
        (39, 'baseball bat'),
        (40, 'baseball glove'),
        (41, 'skateboard'),
        (42, 'surfboard'),
        (43, 'tennis racket'),
        (44, 'bottle'),
        (46, 'wine glass'),
        (47, 'cup'),
        (48, 'fork'),
        (49, 'knife'),
        (50, 'spoon'),
        (51, 'bowl'),
        (52, 'banana'),
        (53, 'apple'),
        (54, 'sandwich'),
        (55, 'orange'),
        (56, 'broccoli'),
        (57, 'carrot'),
        (58, 'hot dog'),
        (59, 'pizza'),
        (60, 'donut'),
        (61, 'cake'),
        (62, 'chair'),
        (63, 'couch'),
        (64, 'potted plant'),
        (65, 'bed'),
        (67, 'dining table'),
        (70, 'toilet'),
        (72, 'tv'),
        (73, 'laptop'),
        (74, 'mouse'),
        (75, 'remote'),
        (76, 'keyboard'),
        (77, 'cell phone'),
        (78, 'microwave'),
        (79, 'oven'),
        (80, 'toaster'),
        (81, 'sink'),
        (82, 'refrigerator'),
        (84, 'book'),
        (85, 'clock'),
        (86, 'vase'),
        (87, 'scissors'),
        (88, 'teddy bear'),
        (89, 'hair drier'),
        (90, 'toothbrush'),
    )
}

CPU times: user 13.3 s, sys: 1.04 s, total: 14.3 s
Wall time: 14.5 s


### gr.Interfaceの引数を定義

In [11]:
# Gradio input/output components for the detection demo (these rebind the
# input_def/output_def names used by the classification demo).
# NOTE(review): the input reuses the classifier's input_image_shape (224, 224, 3), so uploads are
# presumably resized to that size before detection — confirm this is intended for the SSD detector.
input_def = gr.inputs.Image(shape=input_image_shape)
# Output: displays the annotated image returned by output_predict_image.
output_def = gr.outputs.Image()

def output_predict_image(input_image):
  """Run the object detector on one image and draw its detections on a copy.

  Reads the module-level ``detector``, ``viz_utils`` and ``category_index``.

  Args:
    input_image: a single image array as delivered by the Gradio Image input
      (assumed HxWx3 uint8 — TODO confirm against the detector's requirements).

  Returns:
    A copy of ``input_image`` with boxes and class names drawn for detections
    passing the ``min_score_thresh=.45`` threshold.
  """
  # The detector works on batches, so wrap the single image in a batch of one.
  detections = detector(np.expand_dims(input_image, 0))

  # Take the first (only) batch element of each output as numpy arrays.
  boxes = detections['detection_boxes'][0].numpy()
  class_ids = detections['detection_classes'][0].numpy().astype(np.int32)
  scores = detections['detection_scores'][0].numpy()

  # Draw on a copy so the caller's array is left untouched.
  canvas = input_image.copy()
  return viz_utils.visualize_boxes_and_labels_on_image_array(
      canvas,
      boxes,
      class_ids,
      scores,
      category_index,
      use_normalized_coordinates=True,
      max_boxes_to_draw=1000,
      min_score_thresh=.45,  # score threshold for drawing a box
      agnostic_mode=False)

### Webページ起動

In [12]:
# Launch the object-detection demo web page. With debug=False, errors are not shown in the
# notebook (the launch message in the output says to set debug=True for that).
gr.Interface(fn=output_predict_image, inputs=input_def, outputs=output_def).launch(debug=False)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on External URL: https://36208.gradio.app
Interface loading below...


(<gradio.networking.serve_files_in_background.<locals>.HTTPServer at 0x7efdb29c1470>,
 'http://127.0.0.1:7861/',
 'https://36208.gradio.app')

In [15]:
import matplotlib.pyplot as plt

# Prepare a sample image for testing the detector locally (outside the Gradio UI).
# Target size passed to PIL's Image.resize, i.e. (width, height).
IMAGE_SHAPE = (320, 320)
# Download the sample image; get_file returns a local file path, opened on the next line.
grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
# NOTE(review): `grace_hopper` is rebound here from file path -> PIL image -> ndarray.
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)

# Convert to a numpy array and add a batch axis so it can be passed to the detector.
grace_hopper = np.array(grace_hopper)
input_tensor = np.expand_dims(grace_hopper, 0)
input_tensor.shape

(1, 320, 320, 3)

In [16]:
# Run the detector on the sample batch; boxes are drawn on a copy of the image.
detections = detector(input_tensor)
image_np_with_detections = grace_hopper.copy()

# Same drawing call as output_predict_image, but with a lower score threshold (.30 instead of .45).
img_show = viz_utils.visualize_boxes_and_labels_on_image_array(
                          image_np_with_detections,
                          detections['detection_boxes'][0].numpy(),
                          detections['detection_classes'][0].numpy().astype(np.int32),
                          detections['detection_scores'][0].numpy(),
                          category_index,
                          use_normalized_coordinates=True,
                          max_boxes_to_draw=1000,
                          min_score_thresh=.30,
                          agnostic_mode=False)
# Display the annotated image inline.
plt.imshow(img_show)

<matplotlib.image.AxesImage at 0x7efdbc731a58>