In [None]:
import tensorflow as tf

In [None]:
IMG_SIZE = 224    # side length of the square images fed to MobileNetV2
BATCH_SIZE = 32   # images per batch yielded by the data generators

DATA_DIR_PATH = 'dataset/PetImages'  # root directory read by flow_from_directory
CATEGORIES = ['Dog', 'Cat']          # class names; expected to match the sub-directory names under DATA_DIR_PATH

SEED = 123  # set to an arbitrary value
VALIDATION_SPLIT = 0.2  # fraction (0-1) of the images held out for the validation subset

# SEED on its own is only passed to the data generators. Seeding TensorFlow's global RNG as
# well makes the random parts of the model (new-layer weight initialisation, dropout)
# repeatable across Restart & Run All.
tf.random.set_seed(SEED)

In [None]:
# Random geometric augmentation applied to every batch this generator produces.
_augmentation = dict(
    rotation_range=10,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Pixel values are multiplied by 1/255. validation_split lets flow_from_directory(subset=...)
# split the image files into 'training' and 'validation' subsets.
# NOTE(review): ImageDataGenerator is deprecated in recent TF releases in favour of
# tf.keras.utils.image_dataset_from_directory + preprocessing layers.
image_get = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1/255,
    validation_split=VALIDATION_SPLIT,
    **_augmentation
)

In [None]:
# Validation images must not be randomly rotated/shifted/zoomed/flipped, or the validation
# metrics measure performance on distorted data. This second generator applies only the same
# pixel rescaling. It uses the same validation_split, so the 'validation' subset it yields
# is the same set of files the augmenting generator holds out.
valid_image_get = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=image_get.rescale,
    validation_split=VALIDATION_SPLIT,
)


def _flow_subset(datagen, subset):
    """Return a shuffled, one-hot-labelled batch iterator over one subset of DATA_DIR_PATH.

    Args:
        datagen: the ImageDataGenerator whose preprocessing/augmentation is applied.
        subset: 'training' or 'validation'.
    """
    return datagen.flow_from_directory(
        directory=DATA_DIR_PATH,
        target_size=(IMG_SIZE, IMG_SIZE),
        color_mode='rgb',
        # Fix the class-index order to CATEGORIES (Dog=0, Cat=1). Without this, Keras orders
        # classes alphabetically (Cat=0, Dog=1), which silently disagrees with CATEGORIES
        # when predictions are decoded.
        classes=CATEGORIES,
        class_mode='categorical',
        batch_size=BATCH_SIZE,
        seed=SEED,
        shuffle=True,
        subset=subset,
    )


train_gen = _flow_subset(image_get, 'training')          # augmented
valid_gen = _flow_subset(valid_image_get, 'validation')  # rescaled only

In [None]:
def my_gen(gen, max_consecutive_failures=100):
    """Yield ``(x, y)`` batches from ``gen``, skipping batches that raise while loading.

    Presumably this exists so that a few unreadable image files (which make the whole batch
    raise) do not abort ``model.fit``. A failing batch is dropped and the next one is
    requested.

    Args:
        gen: an iterator of ``(x, y)`` pairs, such as the ``flow_from_directory`` iterators.
        max_consecutive_failures: give up after this many failed batches in a row, so a
            systematic problem (wrong path, missing image library, ...) surfaces as an error
            instead of an endless silent loop.

    Yields:
        ``(x, y)`` tuples as produced by ``gen``. Stops if ``gen`` is exhausted.

    Raises:
        RuntimeError: if ``max_consecutive_failures`` batches in a row fail.
    """
    failures = 0
    while True:
        try:
            x, y = next(gen)
        except StopIteration:
            # A finite iterator is done; retrying would spin forever.
            return
        except Exception as err:
            # Only ordinary errors are skipped. BaseExceptions such as KeyboardInterrupt and
            # GeneratorExit must propagate so the kernel can be interrupted and the generator
            # closed.
            failures += 1
            if failures >= max_consecutive_failures:
                raise RuntimeError(
                    f'{failures} consecutive batches failed to load; last error: {err!r}'
                ) from err
            continue
        failures = 0
        # Kept outside the try: a GeneratorExit thrown in at this yield must not be caught.
        yield x, y

In [None]:
# MobileNetV2 with ImageNet weights and without its classification top, used as a feature
# extractor. Its weights are frozen, so only the layers added on top are trained.
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)
base_model.trainable = False

In [None]:
# Classifier: frozen MobileNetV2 features -> global average pool -> dropout -> softmax.
# The image generators rescale pixels to [0, 1] (rescale=1/255). The ImageNet MobileNetV2
# weights, however, expect inputs in [-1, 1], the range produced by
# tf.keras.applications.mobilenet_v2.preprocess_input. The Rescaling layer maps
# [0, 1] -> [-1, 1] (x * 2 - 1). Because it lives inside the model, the saved model still
# takes the same [0, 1] input as the generators provide.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    tf.keras.layers.Rescaling(scale=2.0, offset=-1.0),
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.2),
    # One softmax unit per class, matching the one-hot labels from class_mode='categorical'.
    tf.keras.layers.Dense(len(CATEGORIES), activation='softmax'),
])

In [None]:
# Adam with a small step size (1e-4) for the trainable head; base_model.trainable is False.
# Categorical cross-entropy pairs with the softmax output and one-hot labels.
head_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    loss='categorical_crossentropy',
    optimizer=head_optimizer,
    metrics=['accuracy'],
)

In [None]:
EPOCHS = 10

# my_gen never stops on its own, so Keras needs an explicit number of batches per epoch.
# len(<flow_from_directory iterator>) is the number of batches in one full pass over that
# subset (ceil(samples / batch_size)). The previous value, BATCH_SIZE, was unrelated:
# it limited each epoch to 32 training batches (~1k images) and evaluated only
# 32 validation batches. Expect each epoch to take correspondingly longer now.
history = model.fit(
    my_gen(train_gen),
    steps_per_epoch=len(train_gen),
    validation_data=my_gen(valid_gen),
    validation_steps=len(valid_gen),
    epochs=EPOCHS,
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Per-epoch training vs. validation accuracy recorded by model.fit.
fig, ax = plt.subplots()
for metric_key in ('accuracy', 'val_accuracy'):
    ax.plot(history.history[metric_key])
ax.set_title('Model Accuracy')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Epoch')
ax.legend(['Train', 'Validation'], loc='upper left')
plt.show()

In [None]:
# Persist the trained model. The .h5 extension presumably selects Keras' legacy HDF5 format;
# verify it against the installed TF/Keras version (newer Keras prefers '.keras').
model.save(filepath='model.h5')