In [36]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

In [37]:
RESCALE_FACTOR = 1./255
BATCH_SIZE = 32
EPOCHS = 3
IMAGE_RES = 224

In [38]:
(ds_train, ds_test), ds_info = tfds.load(
                                        'cats_vs_dogs',
                                        split = ('train[:80%]', 'train[80%:]'),
                                        with_info=True,
                                        as_supervised=True)

In [39]:
def format_image(image, label):
    image = tf.image.resize(image, (IMAGE_RES, IMAGE_RES))/RESCALE_FACTOR
    return image, label


In [40]:
n_examples = ds_info.splits['train'].num_examples
n_examples

23262

In [41]:
ds_train = ds_train.shuffle(n_examples//4).map(format_image).batch(BATCH_SIZE).prefetch(1)
ds_test = ds_test.map(format_image).batch(BATCH_SIZE).prefetch(1)

In [42]:
print(ds_train)
print(ds_test)

<PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>
<PrefetchDataset element_spec=(TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int64, name=None))>


In [43]:
URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
feature_extractor = hub.KerasLayer(URL, input_shape=(IMAGE_RES, IMAGE_RES, 3))
feature_extractor.trainable = False

In [44]:
model = tf.keras.Sequential([
                             feature_extractor,
                             tf.keras.layers.Dense(2)                           
])
model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 keras_layer_2 (KerasLayer)  (None, 1280)              2257984   
                                                                 
 dense_1 (Dense)             (None, 2)                 2562      
                                                                 
Total params: 2,260,546
Trainable params: 2,562
Non-trainable params: 2,257,984
_________________________________________________________________


In [45]:
model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

In [46]:
history = model.fit(ds_train,
                    epochs=EPOCHS,
                    validation_data=ds_test)

Epoch 1/3
 12/582 [..............................] - ETA: 9:20 - loss: 0.9568 - accuracy: 0.4974

Corrupt JPEG data: 99 extraneous bytes before marker 0xd9


 44/582 [=>............................] - ETA: 11:06 - loss: 0.7689 - accuracy: 0.5568



 54/582 [=>............................] - ETA: 11:43 - loss: 0.7485 - accuracy: 0.5723

Corrupt JPEG data: 396 extraneous bytes before marker 0xd9


126/582 [=====>........................] - ETA: 11:00 - loss: 0.6767 - accuracy: 0.6198

Corrupt JPEG data: 65 extraneous bytes before marker 0xd9




Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9




Corrupt JPEG data: 128 extraneous bytes before marker 0xd9




Corrupt JPEG data: 239 extraneous bytes before marker 0xd9




Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9




Corrupt JPEG data: 228 extraneous bytes before marker 0xd9




Corrupt JPEG data: 162 extraneous bytes before marker 0xd9
Corrupt JPEG data: 252 extraneous bytes before marker 0xd9
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
Corrupt JPEG data: 1403 extraneous bytes before marker 0xd9


Epoch 2/3
 12/582 [..............................] - ETA: 10:12 - loss: 0.5700 - accuracy: 0.6823

Corrupt JPEG data: 99 extraneous bytes before marker 0xd9


 44/582 [=>............................] - ETA: 9:36 - loss: 0.5226 - accuracy: 0.7379



 54/582 [=>............................] - ETA: 9:24 - loss: 0.5245 - accuracy: 0.7355

Corrupt JPEG data: 396 extraneous bytes before marker 0xd9


126/582 [=====>........................] - ETA: 8:02 - loss: 0.5323 - accuracy: 0.7282

Corrupt JPEG data: 65 extraneous bytes before marker 0xd9




Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9




Corrupt JPEG data: 128 extraneous bytes before marker 0xd9




Corrupt JPEG data: 239 extraneous bytes before marker 0xd9




Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9




Corrupt JPEG data: 228 extraneous bytes before marker 0xd9




Corrupt JPEG data: 162 extraneous bytes before marker 0xd9
Corrupt JPEG data: 252 extraneous bytes before marker 0xd9
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
Corrupt JPEG data: 1403 extraneous bytes before marker 0xd9


Epoch 3/3
 12/582 [..............................] - ETA: 12:09 - loss: 0.5193 - accuracy: 0.7552

Corrupt JPEG data: 99 extraneous bytes before marker 0xd9


 44/582 [=>............................] - ETA: 11:12 - loss: 0.5017 - accuracy: 0.7649



 54/582 [=>............................] - ETA: 11:04 - loss: 0.4985 - accuracy: 0.7610

Corrupt JPEG data: 396 extraneous bytes before marker 0xd9


126/582 [=====>........................] - ETA: 9:19 - loss: 0.5149 - accuracy: 0.7488

Corrupt JPEG data: 65 extraneous bytes before marker 0xd9




Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9




Corrupt JPEG data: 128 extraneous bytes before marker 0xd9




Corrupt JPEG data: 239 extraneous bytes before marker 0xd9




Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9




Corrupt JPEG data: 228 extraneous bytes before marker 0xd9




Corrupt JPEG data: 162 extraneous bytes before marker 0xd9
Corrupt JPEG data: 252 extraneous bytes before marker 0xd9
Corrupt JPEG data: 214 extraneous bytes before marker 0xd9
Corrupt JPEG data: 1403 extraneous bytes before marker 0xd9


