<a href="https://colab.research.google.com/github/akinoriosamura/tensorflow2.0-sample/blob/master/transfer_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

!pip install -q tensorflow-gpu==2.0.0-beta1
import tensorflow as tf

In [0]:
!pip install -q tensorflow_hub
import tensorflow_hub as hub

from tensorflow.keras import layers

In [0]:
classifier_url ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2" #@param {type:"string"}
IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))
])

In [4]:
data_root = tf.keras.utils.get_file(
  'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True
)

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


In [6]:
image_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
IMAGE_SHAPE = (224, 224)
image_data = image_gen.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE)

Found 3670 images belonging to 5 classes.


In [7]:
image_data

<keras_preprocessing.image.directory_iterator.DirectoryIterator at 0x7faf1a0f1860>

In [9]:
for image_batch, label_batch in image_data:
  print("Image batch shape: ", image_batch.shape)
  print("Label batch shape: ", label_batch.shape)
  break  

Image batch shape:  (32, 224, 224, 3)
Label batch shape:  (32, 5)


In [11]:
result_batch = classifier.predict(image_batch)
result_batch.shape

(32, 1001)

In [14]:
import numpy as np
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]

predicted_class_names


array(['pot', 'daisy', 'broom', 'bee', 'ant', 'picket fence', 'bee',
       'mushroom', 'cardoon', 'spider web', 'cardoon', 'hip',
       'picket fence', 'stone wall', 'broom', 'white stork', 'daisy',
       'orange', 'pot', 'picket fence', 'sea urchin', 'daisy',
       'tarantula', 'pot', 'spindle', 'daisy', 'ant', 'zucchini', 'vase',
       'bee', 'bell pepper', 'bee'], dtype='<U30')

In [15]:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

NameError: ignored

In [0]:
feature_extractor_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2" #@param {type:"string"}

feature_extractor_layer = hub.KerasLayer(feature_extractor_url,
                                         input_shape=(224,224,3))

In [17]:
feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)

(32, 1280)


In [0]:
feature_extractor_layer.trainable = False


In [19]:
model = tf.keras.Sequential([
  feature_extractor_layer,
  layers.Dense(image_data.num_classes, activation='softmax')
])

model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer_1 (KerasLayer)   (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________


In [20]:
predictions = model(image_batch)
predictions.shape


TensorShape([32, 5])

In [0]:
model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss='categorical_crossentropy',
  metrics=['acc'])

In [22]:
steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)
history = model.fit(image_data, epochs=2,
                    steps_per_epoch=steps_per_epoch,
                    )

Epoch 1/2


W0619 13:47:55.156842 140390981650304 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch 2/2


In [26]:
class_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
class_names

predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

predicted_label_batch

array(['Roses', 'Daisy', 'Dandelion', 'Tulips', 'Dandelion', 'Dandelion',
       'Tulips', 'Dandelion', 'Dandelion', 'Dandelion', 'Dandelion',
       'Tulips', 'Tulips', 'Daisy', 'Dandelion', 'Dandelion', 'Daisy',
       'Tulips', 'Sunflowers', 'Tulips', 'Sunflowers', 'Tulips',
       'Dandelion', 'Tulips', 'Dandelion', 'Sunflowers', 'Daisy',
       'Tulips', 'Tulips', 'Sunflowers', 'Tulips', 'Daisy'], dtype='<U10')