### Environment Setup

In [None]:
import datetime
import shutil

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.applications.mobilenet import MobileNet, preprocess_input

%load_ext tensorboard

### Ingest, Image Preprocessing and Augmentation

In [None]:
# Dataset location, input image geometry and batch size.
train_data_dir = "./dataset/mushie_image_data/"
img_width = img_height = 224
batch_size = 40

# Class sub-directory names. flow_from_directory assigns each class the index of
# its position in this list, so label 0 = 'poisonous' and label 1 = 'edible'.
# The model has a single output node: an output of 0 means poisonous, 1 means
# edible. Reverse this list to flip that encoding.
classes = ['poisonous', 'edible']
num_classes = len(classes)

In [None]:
# Build training and validation generators from the same directory tree.
# Keras chooses each subset by position within each class folder's file list,
# so giving both generators the same validation_split keeps the subsets disjoint.
val_split = 0.2

# Random augmentation is applied to the training subset only.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=val_split)

# Validation images get the same MobileNet preprocessing but no random
# augmentation, so validation loss/accuracy are measured on unaltered images
# and are comparable from one epoch to the next.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=val_split)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    classes=classes,
    class_mode='binary',
    subset='training')

validation_generator = val_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    classes=classes,
    class_mode='binary',
    subset='validation')

In [None]:
### Modeling Setup

In [None]:
# Build the model on top of a pre-trained base via transfer learning:
# the base is frozen and only the newly added classification head is trained.
def model_maker(base_model, img_width, img_height):
    """Attach a single-output binary classification head to a frozen base model.

    Every layer of ``base_model`` is frozen, so only the new head
    (global average pooling -> Dense(64, relu) -> Dropout(0.5) ->
    Dense(1, sigmoid)) is trained.

    Args:
        base_model: Keras model used as a feature extractor (e.g. MobileNet
            built with ``include_top=False``).
        img_width: Input image width in pixels.
        img_height: Input image height in pixels.

    Returns:
        A ``Model`` mapping a batch of (img_height, img_width, 3) images to one
        sigmoid output per image, read as the probability of label index 1.
    """
    # Freeze the whole pre-trained base; only the head below is trainable.
    for layer in base_model.layers:
        layer.trainable = False

    # Keras image tensors are (height, width, channels); this matches the
    # generators' target_size=(img_height, img_width).
    inputs = Input(shape=(img_height, img_width, 3))
    # training=False keeps the base's BatchNormalization layers in inference
    # mode, which stays correct even if the base is later unfrozen to fine-tune.
    features = base_model(inputs, training=False)
    features = GlobalAveragePooling2D()(features)
    features = Dense(64, activation='relu')(features)
    features = Dropout(0.5)(features)
    # Sigmoid, not softmax: softmax over a single unit always outputs 1.0.
    # Sigmoid gives a probability suitable for binary_crossentropy.
    predictions = Dense(1, activation='sigmoid')(features)

    return Model(inputs=inputs, outputs=predictions)

In [None]:
# Instantiate the pre-trained base: MobileNet without its classifier top.
# Keras image inputs are (height, width, channels), the same order as the
# generators' target_size=(img_height, img_width).
mobile_net = MobileNet(include_top=False, input_shape=(img_height, img_width, 3))

In [None]:
# Wrap the MobileNet base with the custom classification head.
model = model_maker(base_model=mobile_net,
                    img_width=img_width,
                    img_height=img_height)

In [None]:
# Optionally continue training a previously saved model instead of the freshly
# built one. Leave as None to train the new model; set to a path such as
# './mushie_model.h5' to load that model instead.
resume_from = None
if resume_from is not None:
    model = load_model(resume_from)

##### Optional Callbacks

In [None]:
# TensorBoard logging: each run writes to a timestamped directory under logs/fit/.

# Clear logs from previous runs so the dashboard shows only this run.
# shutil.rmtree (rather than the shell-only `!rm -rf`) also works on Windows;
# ignore_errors=True makes it a no-op when ./logs/ does not exist yet.
shutil.rmtree("./logs/", ignore_errors=True)
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# histogram_freq=1 records weight histograms every epoch.
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [None]:
# Early stopping: watch validation loss and stop once it has gone 4 epochs
# without improving; when training stops early, the weights from the best
# epoch are restored.
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                      mode='auto',
                                      patience=4,
                                      baseline=None,
                                      restore_best_weights=True)

### Training

In [None]:
# Configure the model for training (fitting happens in a separate cell):
# binary cross-entropy for the single-unit output, Adam with learning rate
# 0.001, and accuracy reported under the 'acc' metric name.
adam_optimizer = tf.keras.optimizers.Adam(0.001)
model.compile(optimizer=adam_optimizer,
              loss='binary_crossentropy',
              metrics=['acc'])

In [None]:
epochs = 5

# len(generator) is the number of batches per pass, counting a final partial
# batch, so every image is used each epoch. `samples // batch_size` would drop
# the leftover images and give 0 steps for a subset smaller than batch_size.
history = model.fit(train_generator,
                    epochs=epochs,
                    steps_per_epoch=len(train_generator),
                    validation_data=validation_generator,
                    validation_steps=len(validation_generator),
                    callbacks=[tensorboard_cb, es])

### Tensorboard Evaluation

In [None]:
# Open the TensorBoard dashboard inline, reading the run directories written under logs/fit.
%tensorboard --logdir logs/fit