In [3]:
import tensorflow_datasets as tfds

# cats_vs_dogs ships a single "train" split; carve it into the first 80% for
# training and the remaining 20% for validation. `as_supervised=True` yields
# (image, label) tuples instead of feature dicts.
ds_train, ds_val = tfds.load(
    "cats_vs_dogs",
    as_supervised=True,
    split=["train[:80%]", "train[80%:]"],
)


In [4]:
import tensorflow as tf
from tensorflow import keras

# Pipeline configuration: every image is resized to IMG_SIZE (height, width)
# before batching, and batches hold BATCH_SIZE examples.
IMG_SIZE = 160, 160
BATCH_SIZE = 32

def preprocess(image, label):
    """Resize one example to IMG_SIZE and apply MobileNetV2's input preprocessing.

    Args:
        image: image tensor from the ``(image, label)`` pairs produced by
            ``tfds.load(..., as_supervised=True)`` (presumably HxWx3 uint8 with a
            per-image H and W — confirm against the dataset card).
        label: class label; returned unchanged.

    Returns:
        ``(image, label)`` where the image has been resized to IMG_SIZE and passed
        through ``keras.applications.mobilenet_v2.preprocess_input``.
    """
    resized = tf.image.resize(image, IMG_SIZE)
    model_ready = keras.applications.mobilenet_v2.preprocess_input(resized)
    return model_ready, label

# Input pipelines.
# - `map` runs `preprocess` in parallel (AUTOTUNE picks the degree) instead of serially.
# - The training set is shuffled and re-shuffled every epoch, so the model does not
#   see the exact same batch sequence each epoch. Validation order doesn't matter.
# - `prefetch` overlaps preprocessing of the next batch with training on the current one.
# SHUFFLE_BUFFER trades shuffle quality for memory: each buffered element is a
# resized IMG_SIZE image.
AUTOTUNE = tf.data.AUTOTUNE
SHUFFLE_BUFFER = 1000

train_ds = (
    ds_train
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
val_ds = (
    ds_val
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)


In [5]:
# MobileNetV2 with ImageNet weights and without its classification top, used as a
# feature extractor. The input shape is derived from IMG_SIZE (the size `preprocess`
# resizes images to) so the two can't silently drift apart.
base_model = keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,),  # IMG_SIZE plus 3 color channels
    include_top=False,
    weights="imagenet",
)

# Freeze the backbone: only the head stacked on top of it receives weight updates.
base_model.trainable = False


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [6]:
from keras.models import Sequential
from keras.layers import Dense, Dropout, GlobalAveragePooling2D

# Binary classifier: frozen backbone + small trainable head.
# Layers are taken from the same `keras` namespace as `base_model`
# (`from tensorflow import keras`). Mixing it with a separately imported standalone
# `keras` can combine two different Keras implementations in one model.
model = keras.Sequential([
    base_model,                                   # frozen MobileNetV2 feature extractor
    keras.layers.GlobalAveragePooling2D(),        # average the spatial feature map to one vector per image
    keras.layers.Dense(128, activation="relu"),   # trainable hidden layer of the head
    keras.layers.Dropout(0.5),                    # randomly drop half the activations during training (regularization)
    keras.layers.Dense(1, activation="sigmoid"),  # probability of label 1 (presumably "dog" in tfds cats_vs_dogs — confirm)
])


In [8]:
# A single sigmoid output unit pairs with binary cross-entropy. Adam uses a small
# step size (1e-4) since only the head on top of the frozen backbone is trained.
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["accuracy"],
)


In [None]:
# Train for 3 epochs, evaluating on the held-out split after each epoch. The
# per-epoch train/validation metrics end up in `history.history`.
history = model.fit(
    train_ds,
    epochs=3,
    validation_data=val_ds,
)


Epoch 1/3
[1m582/582[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m535s[0m 899ms/step - accuracy: 0.9049 - loss: 0.2196 - val_accuracy: 0.9826 - val_loss: 0.0506
Epoch 2/3
[1m582/582[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 657ms/step - accuracy: 0.9789 - loss: 0.0582

In [None]:
import matplotlib.pyplot as plt  # `plt` was used without being imported -> NameError on a fresh kernel


def plot_train_val(history, metric):
    """Plot one metric's training and validation curves from a Keras ``History``.

    Args:
        history: object returned by ``model.fit(..., validation_data=...)``;
            ``history.history`` maps metric names to lists of per-epoch values.
        metric: base metric name such as ``"accuracy"`` or ``"loss"``; the matching
            ``"val_" + metric`` series is drawn on the same axes.
    """
    train_values = history.history[metric]
    val_values = history.history[f"val_{metric}"]
    epochs = range(1, len(train_values) + 1)  # 1-based epoch numbers on the x-axis
    pretty = metric.title()  # "accuracy" -> "Accuracy", "loss" -> "Loss"

    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, label=f"Train {pretty}")
    ax.plot(epochs, val_values, label=f"Validation {pretty}")
    ax.set(title=f"{pretty} per epoch", xlabel="Epoch", ylabel=pretty)
    ax.legend()
    plt.show()


plot_train_val(history, "accuracy")
plot_train_val(history, "loss")
