In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import numpy

# Create the training and validation dataset iterators

In [None]:
# Build the training and validation iterators from the class sub-directories of DATASET_DIR.
#
# Only the training subset is augmented (flips, zoom, shifts, small rotations).
# Validation images must be left untouched; otherwise validation metrics are
# computed on randomly distorted images and fluctuate from epoch to epoch.
# Keras picks the validation subset by position within each class directory
# (not randomly), so two generators with the same `validation_split` see the
# same, disjoint training/validation split.
DATASET_DIR = "dataset"
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 8
VALIDATION_SPLIT = 0.15

# Augmenting generator, used for the training subset only.
image_generator = ImageDataGenerator(
    validation_split=VALIDATION_SPLIT,
    horizontal_flip=True,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    rotation_range=5
)
# Plain generator (no augmentation) for the validation subset.
validation_image_generator = ImageDataGenerator(validation_split=VALIDATION_SPLIT)

train_generator = image_generator.flow_from_directory(
    DATASET_DIR, subset="training", target_size=IMAGE_SIZE, batch_size=BATCH_SIZE
)
validation_generator = validation_image_generator.flow_from_directory(
    DATASET_DIR, subset="validation", target_size=IMAGE_SIZE, batch_size=BATCH_SIZE
)

In [None]:
# Show one image from the training iterator as a sanity check.
# Note: `next` advances the iterator, so this consumes one batch.
sample_batch = next(train_generator)
sample_images = sample_batch[0]
Image.fromarray(sample_images[0].astype(numpy.uint8))

In [None]:
# Show one image from the validation iterator as a sanity check.
# Note: `next` advances the iterator, so this consumes one batch.
validation_sample_batch = next(validation_generator)
validation_sample_images = validation_sample_batch[0]
Image.fromarray(validation_sample_images[0].astype(numpy.uint8))

# Create model
We're using MobileNet (v1) with a 0.5 width multiplier (`alpha=0.5`) and ImageNet weights, because we want to deploy on mobile.

In [None]:
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.applications.mobilenet import MobileNet
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import *


In [None]:
# MobileNet (v1, width multiplier alpha=0.5) pretrained on ImageNet, with a
# small classification head: dropout -> 8-unit ReLU -> softmax over the classes.
#
# The generators yield raw pixel values (0-255), but the ImageNet weights expect
# inputs scaled to [-1, 1], which is what `mobilenet.preprocess_input` does.
# Doing the scaling inside the model means the generators and preview cells
# keep working on raw pixels, and the exported model can be fed raw pixels.
# NOTE(review): the Lambda layer is serialized with the checkpoint; reloading
# the .hdf5 file may need the same Python/TF versions or `custom_objects` — TODO confirm.
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras.layers import Input, Lambda

image_input = Input(shape=(224, 224, 3), name="image")
scaled_input = Lambda(preprocess_input, name="mobilenet_preprocess")(image_input)

mobile = MobileNet(
    input_shape=(224, 224, 3),
    input_tensor=scaled_input,
    include_top=False,
    weights='imagenet',
    pooling='avg',
    alpha=0.5
)
output = Dropout(0.4)(mobile.output)
output = Dense(8, activation="relu")(output)
# One softmax unit per class sub-directory found by flow_from_directory,
# rather than a hard-coded 2.
output = Dense(train_generator.num_classes, activation='softmax')(output)

model = Model(inputs=image_input, outputs=output)
model.summary()

In [None]:
# Adam with the AMSGrad variant (default learning rate). Categorical
# cross-entropy pairs with the softmax output layer and expects one-hot labels
# (flow_from_directory's default class_mode is 'categorical').
optimizer = Adam(amsgrad=True)
model.compile(
    optimizer=optimizer,
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [None]:
# Multiply the learning rate by 0.2 after 3 epochs without improvement in
# validation loss, never going below 1e-5.
lr_on_plateau = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.2,
    patience=3,
    min_lr=1e-5,
    verbose=1,
)
# Write the model to croissant.hdf5, overwriting it only when validation loss improves.
best_checkpoint = ModelCheckpoint(
    filepath="croissant.hdf5",
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)
callbacks = [lr_on_plateau, best_checkpoint]

In [None]:
# Train for 50 epochs, validating after each one.
#
# `Model.fit` accepts Keras iterators directly. `fit_generator` is deprecated
# since TF 2.1 and absent from Keras 3.
# Step counts come from the iterators, so each epoch covers every training
# batch once and validation scores every validation batch exactly once.
# Hard-coded counts either re-count or skip validation images. If a count
# exceeds the number of batches, TF 2.2+ interrupts with
# "Your input ran out of data".
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=50,
    verbose=1,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=callbacks
)