# In [5]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import numpy as np
import time
import matplotlib.pyplot as plt

# Split the TFDS 'train' split 80/20 into training and validation slices.
TRAIN_SPLIT = 'train[:80%]'
VAL_SPLIT = 'train[80%:]'

(train_data, val_data), meta = tfds.load(
    'cats_vs_dogs',
    split=[TRAIN_SPLIT, VAL_SPLIT],
    with_info=True,
    as_supervised=True,  # yields (image, label) tuples instead of dicts
)

# Count only the training slice: the full 'train' split also contains the 20%
# held out for validation. Used to size the shuffle buffer.
num_train = meta.splits[TRAIN_SPLIT].num_examples
BATCH = 32
EPOCHS = 6

def define_model(arch='inception'):
    """Build and compile a frozen TF-Hub feature extractor plus a 2-way classifier.

    Args:
        arch: Which hub backbone to use. 'inception' (InceptionV3, 299x299 input;
            the original behavior) or 'mobilenet' (MobileNetV2, 224x224 input).

    Returns:
        (model, RES): the compiled Keras model and the square input side length
        the hub module is built for (inputs are (RES, RES, 3)).

    Raises:
        ValueError: if ``arch`` is not a supported name.
    """
    hub_models = {
        'inception': (299, "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"),
        # feature_vector (not classification) so the Dense head sits on embeddings.
        'mobilenet': (224, "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"),
    }
    try:
        res, url = hub_models[arch]
    except KeyError:
        raise ValueError(
            "arch must be one of %s, got %r" % (sorted(hub_models), arch)) from None

    model = tf.keras.Sequential([
        hub.KerasLayer(url, input_shape=(res, res, 3)),
        # softmax, not sigmoid: SparseCategoricalCrossentropy (from_logits=False)
        # expects a probability distribution over the 2 classes.
        tf.keras.layers.Dense(2, activation='softmax'),
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
        optimizer=tf.keras.optimizers.Adam(),
    )
    return model, res
# Build the classifier configured in define_model(); RES is the square input
# side length its hub module expects, reused by the preprocessing pipeline.
model, RES = define_model()

def preprocess_batches(res=None, batch_size=None):
    """Build batched, prefetched tf.data pipelines for training and validation.

    Each image is resized to (res, res) and scaled from [0, 255] to [0, 1].
    The training pipeline is shuffled; the validation pipeline is not.

    Args:
        res: Target square side length; defaults to the global ``RES``. Pass a
            different value to build pipelines for another hub backbone.
        batch_size: Batch size; defaults to the global ``BATCH``.

    Returns:
        (train_batch, val_batch) datasets of (images, labels) batches.
    """
    res = RES if res is None else res
    batch_size = BATCH if batch_size is None else batch_size

    def normalize(image, label):
        # tf.image.resize returns float32, so dividing by 255 maps to [0, 1].
        return tf.image.resize(image, (res, res)) / 255.0, label

    train_batch = (
        train_data
        .shuffle(num_train // 4)
        .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    val_batch = (
        val_data
        .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return train_batch, val_batch

train_batch, val_batch = preprocess_batches()

# Keep the History object: summarize_model_diagnostics() reads its per-epoch
# 'loss', 'accuracy', 'val_loss' and 'val_accuracy' curves.
history = model.fit(
    train_batch,
    validation_data=val_batch,
    epochs=EPOCHS,
    verbose=2,
)

def summarize_model_diagnostics(history, history_mob,
                                name='Inception', name_mob='MobileNet'):
    """Plot training/validation accuracy and loss of two runs side by side.

    Left panel: accuracy; right panel: loss. Each run contributes a training
    curve and a validation curve per panel.

    Args:
        history: Keras ``History`` of the first model; its ``.history`` dict
            must contain 'loss', 'accuracy', 'val_loss', 'val_accuracy'.
        history_mob: ``History`` of the second model, same keys.
        name: Legend label for ``history``.
        name_mob: Legend label for ``history_mob``.
    """
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))

    # (axes, History key, panel title, legend location). Accuracy curves rise
    # toward the top, so their legend goes in the lower right.
    panels = [
        (ax[0], 'accuracy', 'Accuracy', 'lower right'),
        (ax[1], 'loss', 'Loss', 'upper right'),
    ]
    # (History, legend label, train-curve color, val-curve color)
    runs = [
        (history, name, 'blue', 'r'),
        (history_mob, name_mob, 'orange', 'green'),
    ]
    for axis, metric, title, loc in panels:
        for run, label, train_color, val_color in runs:
            axis.plot(run.history[metric],
                      label=f'{metric} {label}', color=train_color)
            axis.plot(run.history[f'val_{metric}'],
                      label=f'val {metric} {label}', color=val_color)
        axis.legend(loc=loc)
        axis.set_title(title)

    # Figure-level title: plt.title() would overwrite the current axes' title
    # (the 'Loss' panel) instead of labelling the whole figure.
    fig.suptitle('Training & Validation Performance on cats_vs_dogs')
    fig.tight_layout(rect=(0, 0, 1, 0.96))  # leave room for the suptitle
    plt.show()

from collections import Counter

# Take one batch from the *preprocessed* validation pipeline: raw examples from
# train_data are unbatched, un-resized and unnormalized, so the model cannot
# consume them directly (and they were seen during training).
test_data, test_labels = next(iter(val_batch.take(1)))
test_data, test_labels = test_data.numpy(), test_labels.numpy()

# Predictions of the model built by define_model() above; predicts has shape
# (batch, 2). argmax over the class axis gives one label per image, even
# when the batch has a single element.
predicts = model.predict(test_data, verbose=2)
predicted_label = np.argmax(predicts, axis=-1)




# (notebook cell output) KeyboardInterrupt: