In [None]:
# TODO: imports, move to CNN section, replace MobileNetV2 with something smaller and simpler. 

# Instead of these images, let's try a dataset somewhat closer to the ImageNet dataset.
# The dataset is too large to reasonably redistribute as part of this repository, so you will
# have to download it separately.

# The dataset can be found here: 
# https://www.kaggle.com/alxmamaev/flowers-recognition/
# Root of the extracted dataset; flow_from_directory below looks for one
# sub-directory per entry in `classes` underneath it.
flower_dataset_directory = 'flowers_dataset/flowers/'

# Class labels, in the order used for the one-hot targets.
classes = [
    'daisy',
    'dandelion',
    'rose',
    'sunflower',
    'tulip',
]

# Larger inputs train more slowly but can be more accurate.
# 96x96 is the smallest input size MobileNetV2 allows.
image_size = 96
batch_size = 8
num_epochs = 30

# As above, start from the ImageNet-pretrained MobileNetV2 without its original
# classifier (include_top=False). Note that input_shape follows `image_size`,
# not MobileNetV2's 224x224 default.
base_model = MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(image_size, image_size, 3),
)

# Freeze every pretrained layer so that only the new head is trained.
# (Trainable flags take effect at compile time, so freezing before the head
# is attached is equivalent to freezing afterwards.)
for layer in base_model.layers:
    layer.trainable = False

# New head: global-average-pool the base model's feature map, then a softmax
# layer with one output per flower class.
old_top = GlobalAveragePooling2D()(base_model.output)
new_top = Dense(len(classes), activation='softmax')(old_top)
model = Model(inputs=base_model.input, outputs=new_top)

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

## Data pipeline and training.
##
## An earlier version of this cell rebuilt the generators every epoch so the
## 20% validation slice would "change each training round". Keras does not do
## that: with `validation_split`, `flow_from_directory` always reserves the same
## slice of each class's sorted file list. Every round therefore saw the
## identical validation set, and the rebuild only re-scanned the directory.
## Re-drawing the split each epoch would also be unsound, because images
## validated on in one epoch would have been trained on in an earlier one.
## So we build the generators once and train in a single call.
##
## There is deliberately no `rescale` here. `preprocess_input` already maps raw
## 0-255 pixels into the range the network expects, and ImageDataGenerator
## applies `rescale` *after* `preprocessing_function`. Rescaling by 1/255 as
## well squashed every input to almost zero.
## NOTE(review): assumes `preprocess_input` is MobileNetV2's, imported in an
## earlier cell -- confirm.
validation_fraction = 0.2

# Augmentation (shear / zoom / horizontal flip) is applied to training images only.
train_datagen = ImageDataGenerator(
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    preprocessing_function=preprocess_input,
    validation_split=validation_fraction)

# Validation images are preprocessed but never augmented, so the validation
# metrics reflect the images as they are. The same directory and split fraction
# as train_datagen select the complementary, disjoint 'validation' subset.
validation_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=validation_fraction)

train_generator = train_datagen.flow_from_directory(
    flower_dataset_directory,
    target_size=(image_size, image_size),
    batch_size=batch_size,
    class_mode='categorical',
    classes=classes,
    subset='training')

validation_generator = validation_datagen.flow_from_directory(
    flower_dataset_directory,
    target_size=(image_size, image_size),
    batch_size=batch_size,
    class_mode='categorical',
    classes=classes,
    subset='validation')

# Use ceiling division so the final partial batch is used rather than dropped.
# Floor division also gave 0 validation steps whenever there were fewer
# validation images than batch_size.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=(train_generator.samples + batch_size - 1) // batch_size,
    validation_data=validation_generator,
    validation_steps=(validation_generator.samples + batch_size - 1) // batch_size,
    epochs=num_epochs)

# Store one value per epoch as a flat list. Older Keras names the accuracy
# metric 'acc' and newer versions name it 'accuracy'; accept either, but keep
# the 'acc' / 'val_acc' keys that the plotting code reads.
epoch_metrics = history.history
accuracy_key = 'acc' if 'acc' in epoch_metrics else 'accuracy'
historical_data = {
    'acc': list(epoch_metrics[accuracy_key]),
    'val_acc': list(epoch_metrics['val_' + accuracy_key]),
    'loss': list(epoch_metrics['loss']),
    'val_loss': list(epoch_metrics['val_loss']),
}

# Plot per-epoch training vs. validation curves: accuracy on the left, loss on
# the right. Each line is labelled so the two series can be told apart.
figure = plt.figure()

for panel_index, (metric_key, metric_name) in enumerate(
        [('acc', 'accuracy'), ('loss', 'loss')], start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(historical_data[metric_key], label='train')
    plt.plot(historical_data['val_' + metric_key], label='validation')
    plt.title('model ' + metric_name)
    plt.ylabel(metric_name)
    plt.xlabel('epoch')
    plt.legend()

# One layout pass after both subplots exist is enough.
figure.tight_layout()
plt.show()