In [3]:
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_hub as hub

def define_dataset():
    """Load TFDS ``cats_vs_dogs`` and slice its ``'train'`` split 80/20.

    Returns:
        A 4-tuple ``(train_data, val_data, info, num_examples)``:
        the first 80% and last 20% of the ``'train'`` split as datasets of
        ``(image, label)`` pairs (``as_supervised=True``), the dataset info
        object, and the example count of the *whole* ``'train'`` split
        (not only the 80% training slice).
    """
    split_spec = ['train[:80%]', 'train[80%:]']
    datasets, info = tfds.load(
        'cats_vs_dogs',
        split=split_spec,
        as_supervised=True,
        with_info=True,
    )
    train_ds, val_ds = datasets
    total = info.splits['train'].num_examples
    return train_ds, val_ds, info, total

# Training hyper-parameters.
EPOCHS = 5   # maximum number of passes over the training data
BATCH = 40   # examples per batch
RES = 224    # images are resized to RES x RES pixels
# TF Hub module used as the image backbone. Another module that was tried is
# "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/5".
_URL = "https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/5"

def define_model():
    """Build and compile a 2-class classifier on top of the TF Hub module at ``_URL``.

    The hub module is wrapped in ``hub.KerasLayer`` (default trainability)
    and followed by a 2-unit softmax head. The model is compiled with Adam and
    sparse categorical cross-entropy, so labels are integer class ids, and its
    summary is printed.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        # Take the input size from RES so it always matches the size that
        # preprocess_data() resizes images to, instead of a separate literal.
        # NOTE(review): the traceback recorded with this cell reports a
        # "conv3d" layer expecting 5-D input. An ImageNet ResNet module would
        # presumably use 2-D convolutions only, so confirm that _URL (and any
        # locally cached copy of it) resolves to the intended image module.
        hub.KerasLayer(_URL, input_shape=(RES, RES, 3)),
        tf.keras.layers.Dense(2, activation='softmax'),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
    )
    model.summary()
    return model

def preprocess_data(train=None, val=None, shuffle_buffer=None):
    """Turn the raw train/validation datasets into batched, prefetched pipelines.

    Each image is resized to ``RES x RES`` and divided by 255.0 (presumably
    mapping 8-bit pixel values into [0, 1]); labels pass through unchanged.
    Only the training pipeline is shuffled.

    Args:
        train: dataset of ``(image, label)`` pairs. Defaults to the
            module-level ``train_data`` (looked up at call time).
        val: dataset of ``(image, label)`` pairs. Defaults to the
            module-level ``val_data``.
        shuffle_buffer: shuffle buffer size for the training data. Defaults
            to ``num_examples // 4`` using the module-level ``num_examples``.

    Returns:
        ``(train_batch, val_batch)``, both batched by ``BATCH`` and prefetched.
    """
    if train is None:
        train = train_data
    if val is None:
        val = val_data
    if shuffle_buffer is None:
        shuffle_buffer = num_examples // 4

    def normalize(image, label):
        # Resize before batching so every element of a batch has one shape.
        return tf.image.resize(image, (RES, RES)) / 255.0, label

    train_batch = (
        train.shuffle(shuffle_buffer)
        .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH)
        .prefetch(tf.data.AUTOTUNE)
    )
    val_batch = (
        val.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH)
        .prefetch(tf.data.AUTOTUNE)
    )
    return train_batch, val_batch

train_data, val_data, info, num_examples = define_dataset()
m = define_model()
train_batch, val_batch = preprocess_data()

# Stop when the monitored metric (val_loss by default) fails to improve for one
# epoch. Roll back to the best epoch's weights rather than keeping those from
# the final, worse epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=1, restore_best_weights=True)
history = m.fit(
    train_batch,
    validation_data=val_batch,
    callbacks=[early_stopping],  # Model.fit takes a list of callbacks
    verbose=2,
    epochs=EPOCHS,
)

def summarize_diagnostics(hist):
    """Plot training/validation accuracy and loss curves in two side-by-side panels.

    Args:
        hist: the object returned by ``Model.fit``; ``hist.history`` must hold
            per-epoch lists under ``'accuracy'``, ``'val_accuracy'``,
            ``'loss'`` and ``'val_loss'``.
    """
    metrics = hist.history
    # One row of two panels, so make the figure wide rather than tall.
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(metrics['accuracy'], label='Training Accuracy')
    plt.plot(metrics['val_accuracy'], label='Val Accuracy')
    plt.title('Summary ResNet: Accuracy')
    # Accuracy usually rises, so keep the legend away from the top-right end.
    plt.legend(loc='lower right')

    plt.subplot(1, 2, 2)
    plt.plot(metrics['loss'], label='Training Loss')
    plt.plot(metrics['val_loss'], label='Validation Loss')
    plt.title('Summary ResNet: Loss')
    plt.legend(loc='upper right')

    plt.show()


summarize_diagnostics(history)


ValueError: Input 0 of layer "conv3d" is incompatible with the layer: expected min_ndim=5, found ndim=4. Full shape received: (None, 224, 224, 3)