This neural network classifies images of tomato pests based on [Huang and Chuang's dataset](https://data.mendeley.com/datasets/s62zm6djd2/1). Three experimental setups are compared: a control setup that uses the original dataset, one that uses the dataset with basic augmentation, and one that uses a dataset with GAN augmentation. The model is a MobileNet pretrained on ImageNet and fine-tuned via transfer learning. MobileNet was chosen so that the model can be deployed efficiently in the field for edge computing.

Please set up the datasets prior to running the notebook:

* [Original Dataset](https://www.kaggle.com/datasets/aprilryan/original)
* [Dataset with Basic Augmentation](https://www.kaggle.com/datasets/aprilryan/pyaugm)
* [Dataset with Generative Augmentation](https://www.kaggle.com/datasets/aprilryan/tomatogan)

@author Jose Enrique R. Lopez<br />
@date-created 8 November 2020

# Import Libraries

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Dense, Flatten, BatchNormalization, Conv2D, MaxPool2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import imagenet_utils
from keras.applications.vgg16 import VGG16
from keras.models import Model
from keras.layers import Dense
from keras.layers import Flatten
from sklearn.metrics import confusion_matrix
import itertools
import os
import shutil
import random
import glob
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
%matplotlib inline

# Initialize experiment components

In [2]:
# Experimental treatments (modes): integer codes selecting which dataset variant a run uses.
ORIGINAL, AUGMENTED, GAN = range(3)

In [3]:
# Base MobileNet network built for 128x128 RGB inputs, matching the target_size
# used by the data generators. All other constructor arguments keep their defaults.
input_dims = (128, 128, 3)
mobile = tf.keras.applications.mobilenet.MobileNet(input_shape=input_dims)


# Utility Functions

In [4]:
def plot_confusion_matrix(cm, classes,
                          normalize=True,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print and plot a confusion matrix.

    Args:
        cm: Square confusion matrix (rows are true labels, columns are
            predicted labels).
        classes: Tick labels, one per row/column of ``cm``.
        normalize: When True (the default) each row is divided by its sum, so
            cells show the fraction of that true class. Pass ``False`` to
            show the matrix as given.
        title: Title drawn above the plot.
        cmap: Matplotlib colormap used for the heat map.
    """
    cm = np.asarray(cm)

    # Normalise BEFORE drawing so the heat-map colours, the cell text and the
    # text-contrast threshold all refer to the same values.
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has an all-zero row; dividing by 1
        # keeps that row at 0 instead of producing NaN and a runtime warning.
        row_totals[row_totals == 0] = 1
        cm = cm.astype('float') / row_totals
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts are printed as integers; fractions with two decimals.
    cell_format = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'

    # Cells darker than half the maximum get white text for legibility.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], cell_format),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

In [5]:
def plotImages(images_arr):
    """Show up to the first ten images of *images_arr* side by side in one row.

    Every panel that receives an image has its axes hidden; the figure is
    rendered immediately with ``plt.show()``.
    """
    _, axes_row = plt.subplots(nrows=1, ncols=10, figsize=(20, 20))
    for picture, panel in zip(images_arr, axes_row.flatten()):
        panel.imshow(picture)
        panel.axis('off')
    plt.tight_layout()
    plt.show()

In [6]:
# Class labels, passed to flow_from_directory(classes=...) and reused as the
# confusion-matrix tick labels, so their order must stay fixed.
classes = [
    'SE',
    'SL',
    'TU',
    'BA',
    'MP',
    'HA',
]

# Retrieve Data

In [7]:
def preprocess(image):
    """Randomly rotate *image* by a quarter turn, then apply MobileNet input scaling.

    Used as the ``preprocessing_function`` of the training ``ImageDataGenerator``.

    Args:
        image: A single image array; the rotation acts on its first two axes.

    Returns:
        The rotated image after
        ``tf.keras.applications.mobilenet.preprocess_input``.
    """
    # k is drawn from {-1, 0, 1}: a quarter turn either way, or no rotation.
    rotated = np.rot90(image, np.random.choice([-1, 0, 1]))
    # Scale the ROTATED array; scaling `image` instead would silently discard
    # the rotation augmentation.
    # NOTE(review): a quarter turn swaps height and width, so this assumes
    # square inputs (the generators use target_size=(128, 128)).
    return tf.keras.applications.mobilenet.preprocess_input(rotated)

In [18]:
mode = ORIGINAL

# (train, valid, test) directories for each experimental treatment.
# NOTE(review): the notebook intro links the original dataset as
# "aprilryan/original"; confirm it is really mounted at ../input/tomatopests.
dataset_dirs = {
    ORIGINAL: ('../input/tomatopests/train',
               '../input/tomatopests/valid',
               '../input/tomatopests/test'),
    AUGMENTED: ('../input/pyaugm/train',
                '../input/pyaugm/valid',
                '../input/pyaugm/test'),
    GAN: ('../input/tomatogan/train',
          '../input/tomatogan/valid',
          '../input/tomatogan/test'),
}

# Fail loudly on an unknown mode instead of leaving the paths undefined (or,
# in a notebook, stale from an earlier run of this cell).
if mode not in dataset_dirs:
    raise ValueError(f"Unknown mode {mode!r}; expected ORIGINAL, AUGMENTED or GAN")
train_path, valid_path, test_path = dataset_dirs[mode]

# Training images get on-the-fly augmentation on top of `preprocess`.
train_batches = ImageDataGenerator(
    preprocessing_function=preprocess,
    brightness_range=[0.7, 1.3],
    zoom_range=[0.8, 1],
    channel_shift_range=10.,
    horizontal_flip=True,
    vertical_flip=True,
).flow_from_directory(directory=train_path, target_size=(128, 128), classes=classes, batch_size=16)

# Validation and test images only receive the MobileNet input scaling.
valid_batches = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet.preprocess_input,
).flow_from_directory(directory=valid_path, target_size=(128, 128), classes=classes, batch_size=16)

# shuffle=False keeps the prediction order aligned with test_batches.classes,
# which the evaluation cell uses as y_true.
test_batches = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet.preprocess_input,
).flow_from_directory(directory=test_path, target_size=(128, 128), classes=classes, batch_size=16, shuffle=False)

In [9]:
# Pull one batch from the training generator and preview it.
imgs, labels = next(train_batches)
# NOTE(review): plotImages draws at most 10 panels, so only part of a
# 16-image batch is shown. The images have already been through `preprocess`,
# so the displayed colours may look off — confirm.
plotImages(imgs)
print(labels)

# Build MobileNet Model

In [12]:
# Use the output of MobileNet's sixth-from-last layer as the feature tensor and
# attach a fresh 6-unit softmax classification head.
# NOTE(review): which layer sits at index -6 depends on the installed Keras
# version — check model.summary() below.
x = mobile.layers[-6].output
output = Dense(units=6, activation='softmax')(x)
model = Model(inputs=mobile.input, outputs=output)

# Freeze everything except the last 23 layers, then the first 10 layers.
# (The second slice is already contained in the first whenever the model has
# at least 33 layers.)
for layer in model.layers[:-23] + model.layers[:10]:
    layer.trainable = False

model.summary()

model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy', keras.metrics.Precision(), keras.metrics.Recall()],
)

# Stop training once val_loss has gone 2 epochs without improving.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=2, verbose=0)

# Train MobileNet model

In [19]:
# Train for at most 20 epochs, one full pass over each generator per epoch;
# the early-stopping callback may end the run sooner.
fit_options = dict(
    epochs=20,
    steps_per_epoch=len(train_batches),
    validation_data=valid_batches,
    validation_steps=len(valid_batches),
    verbose=2,
    callbacks=[callback],
)
model.fit(train_batches, **fit_options)

# Evaluate model

In [24]:
# Human-readable name of each experimental treatment, keyed by the mode constants.
modes = {ORIGINAL: "CONTROL", AUGMENTED: "AUGMENTED", GAN: "GAN"}
print(f"**************RESULTS FOR {modes[mode]} SETUP**************")

predictions = model.predict(x=test_batches, steps=len(test_batches), verbose=0)

print("Evaluate on test data")
# No batch_size here: the generator already yields batches, and Keras documents
# that batch_size must not be given for generator / Sequence inputs.
result = model.evaluate(test_batches)
# `result` follows the compile() order: loss, then accuracy, precision, recall.
print("test loss, test acc, test precision, test recall:", result)

# test_batches was created with shuffle=False, so .classes lines up with the
# prediction rows. Passing `labels` keeps the matrix len(classes) x len(classes)
# even if some class never occurs, so the tick labels cannot get misaligned.
cm = confusion_matrix(
    y_true=test_batches.classes,
    y_pred=np.argmax(predictions, axis=-1),
    labels=list(range(len(classes))),
)
plot_confusion_matrix(cm=cm, classes=classes, title='Confusion Matrix')