# Strawberry Classification

<Initial explanation>

# 1. Imports
<Explain imports here>

In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from keras import Model
from keras import Sequential
from keras import backend as K
from keras import layers
from keras.applications.mobilenet import MobileNet
from keras.applications.resnet import ResNet50
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image

# 2. Data Normalization
<Explain normalization here>

In [None]:
Data normalization code goes here.

# 3. Test / Training / Validation Set Formation
<Explain the splitting of the data into the three sets>

In [None]:
Data split code goes here.

# 4. Convolutional Layers
<Information about each of the architectures goes here>
VGG16
ResNet50
MobileNet
etc

## VGG16
Further explanation of the model. Explain the attachment of the SDN at the bottom and why freezing the weights is important / what it achieves

In [None]:
# Build VGG16 model
def vgg_build_model(fine_tune_from='block5_conv1'):
    """Build a binary classifier on top of an ImageNet-pretrained VGG16 base.

    The convolutional base is loaded without its original classifier and a
    small dense head is stacked on top of it.

    Args:
        fine_tune_from: Name of the first VGG16 layer to leave trainable.
            That layer and every layer after it are fine-tuned; all earlier
            layers are frozen. Pass ``None`` to freeze the whole base.
            NOTE(review): the original compared against 'block_conv1', which
            is not a VGG16 layer name, so nothing was ever unfrozen;
            'block5_conv1' is assumed to be the intended layer -- confirm.

    Returns:
        An uncompiled ``Sequential`` model that outputs a single sigmoid
        probability per image.
    """
    conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

    # Walk the base in order: everything before `fine_tune_from` stays frozen,
    # everything from that layer onwards is trainable.
    trainable = False
    for layer in conv_base.layers:
        if fine_tune_from is not None and layer.name == fine_tune_from:
            trainable = True
        layer.trainable = trainable

    m = Sequential()
    m.add(conv_base)

    # Custom Shallow Dense Network
    m.add(layers.Flatten())
    m.add(layers.Dense(256, activation='relu'))
    m.add(layers.Dropout(0.5))
    # One sigmoid unit: the generators use class_mode='binary' and the model
    # is compiled with binary_crossentropy. 'heaviside' is not a Keras
    # activation (and a step function would not be trainable anyway).
    m.add(layers.Dense(1, activation='sigmoid'))
    return m

## ResNet
Further explanation of the model. Explain the attachment of the SDN at the bottom and why freezing the weights is important / what it achieves

In [None]:
# Build ResNet50 model
def res_build_model():
    """Build a binary classifier on top of a frozen, ImageNet-pretrained ResNet50.

    Every layer of the convolutional base is frozen, so only the dense head
    added here is trained.

    Returns:
        An uncompiled ``Sequential`` model that outputs a single sigmoid
        probability per image.
    """
    conv_base = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    for layer in conv_base.layers:
        layer.trainable = False

    m = Sequential()
    m.add(conv_base)

    # Custom Shallow Dense Network
    m.add(layers.Flatten())
    m.add(layers.Dense(256, activation='relu'))
    m.add(layers.Dropout(0.5))
    # One sigmoid unit: the generators use class_mode='binary' and the model
    # is compiled with binary_crossentropy, so a 5-unit output cannot match
    # the labels. 'heaviside' is also not a Keras activation.
    m.add(layers.Dense(1, activation='sigmoid'))
    return m

## MobileNet
Further explanation of the model. Explain the attachment of the SDN at the bottom and why freezing the weights is important / what it achieves

In [None]:
# Build MobileNet model
def mobile_build_model():
    """Build a binary classifier on top of a frozen, ImageNet-pretrained MobileNet.

    Every layer of the convolutional base is frozen, so only the dense head
    added here is trained.

    Returns:
        An uncompiled ``Sequential`` model that outputs a single sigmoid
        probability per image.
    """
    conv_base = MobileNet(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    for layer in conv_base.layers:
        layer.trainable = False

    m = Sequential()
    m.add(conv_base)

    # Custom Shallow Dense Network
    m.add(layers.Flatten())
    m.add(layers.Dense(256, activation='relu'))
    m.add(layers.Dropout(0.5))
    # One sigmoid unit: the generators use class_mode='binary' and the model
    # is compiled with binary_crossentropy, so a 5-unit output cannot match
    # the labels. 'heaviside' is also not a Keras activation.
    m.add(layers.Dense(1, activation='sigmoid'))
    return m

# 5. Data Augmentation
<Explain the augmentation phase, augmentation params, generators, what each param does, etc>

In [None]:
# Set up generators train_gen and val_gen
# Training images are randomly perturbed on the fly; validation images are
# passed through unchanged so validation metrics reflect the real data.
# NOTE(review): neither generator rescales pixel values -- presumably that is
# meant to happen in the normalization step; confirm.
train_datagen = image.ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest')

val_datagen = image.ImageDataGenerator()

# `__file__` is not defined inside a notebook kernel, so fall back to the
# current working directory when this runs as a notebook cell.
try:
    dir_path = os.path.dirname(os.path.realpath(__file__))
except NameError:
    dir_path = os.getcwd()
train_dir = os.path.join(dir_path, 'train')
val_dir = os.path.join(dir_path, 'validation')

# class_mode='binary' yields one 0/1 label per image, taken from the
# sub-directory the image is stored in.
train_gen = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary')

val_gen = val_datagen.flow_from_directory(
    val_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary')

# 6. Callback
<Explain what it does, how it works, why it's important for training>

In [None]:
# Establish a callback for overfitting detection
# EarlyStopping was used without ever being imported; import it here so the
# cell is self-contained.
from keras.callbacks import EarlyStopping

# Halt training once validation accuracy stops increasing. No `patience` is
# given, so the Keras default applies.
es = EarlyStopping(monitor='val_acc', mode='max', verbose=0)
callback_list = [es]

# 7. Training
<Explain how training works via generator>

## VGG16: Initializing the Model
<Explanation text about initializing and compiling the VGG16 model. Explain what the loss function and metrics are>

In [None]:
# VGG16 Model
# Instantiate the VGG16-based network and configure it for training:
# RMSprop optimizer, binary cross-entropy loss, accuracy reported as 'acc'.
model = vgg_build_model()
model.compile(
    loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['acc'],
)

## VGG16: Training the Model
<This is where VGG16 is trained. Explain what the various hyperparams are and what they do (steps_per_epoch, epochs, etc) and why we chose them this way>

In [None]:
# Fit on batches drawn from train_gen, validating on val_gen after each
# epoch, then write the trained model to disk.
history = model.fit_generator(
    train_gen,
    epochs=10,
    steps_per_epoch=100,
    callbacks=callback_list,
    validation_data=val_gen,
    validation_steps=32,
)
model.save('vgg1_strawberry.h5')

## VGG16: Plotting the Results
<This is where the training results from the VGG16 run are plotted with PyPlot. Explain the significance of the results and what the code does>

In [None]:
# Pull the per-epoch curves recorded during training
acc, val_acc = history.history['acc'], history.history['val_acc']
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs = range(len(acc))

# First figure: accuracy (dots = training, line = validation)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('VGG16 Training and Validation accuracy')
plt.legend()

# Second figure: loss, drawn on a fresh canvas
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('VGG16 Training and Validation loss')
plt.legend()

plt.show()

## ResNet50: Initializing the Model
<Explanation text about initializing and compiling the ResNet50 model. Explain what the loss function and metrics are>

In [None]:
# ResNet50 Model
# Instantiate the ResNet50-based network and configure it for training:
# RMSprop optimizer, binary cross-entropy loss, accuracy reported as 'acc'.
model = res_build_model()
model.compile(
    loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['acc'],
)

## ResNet50: Training the Model
<This is where ResNet50 is trained. Explain what the various hyperparams are and what they do (steps_per_epoch, epochs, etc) and why we chose them this way>

In [None]:
# Fit on batches drawn from train_gen, validating on val_gen after each
# epoch, then write the trained model to disk.
history = model.fit_generator(
    train_gen,
    epochs=10,
    steps_per_epoch=100,
    callbacks=callback_list,
    validation_data=val_gen,
    validation_steps=32,
)
model.save('res_strawberry.h5')

## ResNet50: Plotting the Results
<This is where the training results from the ResNet50 run are plotted with PyPlot. Explain the significance of the results and what the code does>

In [None]:
# Pull the per-epoch curves recorded during training
acc, val_acc = history.history['acc'], history.history['val_acc']
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs = range(len(acc))

# First figure: accuracy (dots = training, line = validation)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('ResNet Training and Validation accuracy')
plt.legend()

# Second figure: loss, drawn on a fresh canvas
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('ResNet Training and Validation loss')
plt.legend()

plt.show()

## MobileNet: Initializing the Model
<Explanation text about initializing and compiling the MobileNet model. Explain what the loss function and metrics are>

In [None]:
# MobileNet Model
# Instantiate the MobileNet-based network and configure it for training:
# RMSprop optimizer, binary cross-entropy loss, accuracy reported as 'acc'.
model = mobile_build_model()
model.compile(
    loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['acc'],
)

## MobileNet: Training the Model
<This is where MobileNet is trained. Explain what the various hyperparams are and what they do (steps_per_epoch, epochs, etc) and why we chose them this way>

In [None]:
# Fit on batches drawn from train_gen, validating on val_gen after each
# epoch, then write the trained model to disk.
history = model.fit_generator(
    train_gen,
    epochs=10,
    steps_per_epoch=100,
    callbacks=callback_list,
    validation_data=val_gen,
    validation_steps=32,
)
model.save('mobile_strawberry.h5')

## MobileNet: Plotting the Results
<This is where the training results from the MobileNet run are plotted with PyPlot. Explain the significance of the results and what the code does>

In [None]:
# Pull the per-epoch curves recorded during training
acc, val_acc = history.history['acc'], history.history['val_acc']
loss, val_loss = history.history['loss'], history.history['val_loss']
epochs = range(len(acc))

# First figure: accuracy (dots = training, line = validation)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('MobileNet Training and Validation accuracy')
plt.legend()

# Second figure: loss, drawn on a fresh canvas
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('MobileNet Training and Validation loss')
plt.legend()

plt.show()

# Conclusion
<Closing remarks, observations, intuition, etc>