In [None]:
# Import libraries
import tensorflow as tf
import numpy as np
import os
import random
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import confusion_matrix
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg19 import preprocess_input

# Short aliases for the Keras API and its layers module, used throughout the notebook.
tfk, tfkl = tf.keras, tf.keras.layers
# Record which TensorFlow version this notebook was run with.
print(tf.__version__)

In [None]:
# Seed every RNG source this notebook touches (Python `random`, NumPy, TensorFlow)
# so that runs are as repeatable as possible. `seed` is also passed explicitly to
# the data loaders, Dropout layers and initializers later on.
seed = 42

# NOTE(review): setting PYTHONHASHSEED here only affects subprocesses started
# afterwards; the running interpreter's hash seed was fixed at startup.
os.environ['PYTHONHASHSEED'] = str(seed)

# NOTE(review): tf.compat.v1.set_random_seed looks redundant with
# tf.random.set_seed under TF2 — kept to preserve behavior.
for seed_fn in (random.seed, np.random.seed, tf.random.set_seed, tf.compat.v1.set_random_seed):
    seed_fn(seed)

In [None]:
# Root directory of the training images (one sub-directory per class is
# presumably expected by flow_from_directory — confirm against the dataset).
dataset_dir = '../input/ann-challenge/training'

# The 14 class names, listed alphabetically.
labels = [
    'Apple', 'Blueberry', 'Cherry', 'Corn', 'Grape', 'Orange', 'Peach',
    'Pepper', 'Potato', 'Raspberry', 'Soybean', 'Squash', 'Strawberry', 'Tomato',
]

In [None]:
# Training hyper-parameters and image geometry.
batch_size = 32
epochs = 200  # upper bound; EarlyStopping may end training sooner

img_height = img_width = 256
input_shape = (img_height, img_width, 3)  # (height, width, RGB channels)

In [None]:
# Image loader with random data augmentation.
# No `rescale` is set, so pixel values are passed through unscaled (presumably
# 0-255); input preprocessing is done inside the model instead.
# validation_split=0.2 holds out 20% of the images for subset="validation"
# in flow_from_directory.
idg = ImageDataGenerator(
    validation_split=0.2,
    # Geometric augmentations.
    rotation_range=30,
    # NOTE(review): Keras treats integer shift ranges as pixel offsets (±5 px),
    # not fractions of the image size — confirm this is intended.
    width_shift_range=5,
    height_shift_range=5,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    # Pixels exposed by the transforms are filled with the constant 0 (black).
    fill_mode='constant',
    cval=0,
)

In [None]:
# Split the class directories into training and validation subsets.
#
# The validation subset is read through a separate, NON-augmenting generator so
# that validation metrics (and EarlyStopping, which monitors val_accuracy) are
# measured on the real images instead of on random transforms. Its
# validation_split must match the augmenting `idg` (0.2) so that both
# generators partition the files identically and the subsets stay disjoint.
valid_idg = ImageDataGenerator(validation_split=0.2)

train_data_generator = idg.flow_from_directory(
    directory=dataset_dir,
    target_size=(img_height, img_width),  # Keras expects (height, width)
    classes=labels,                       # pin class-index order to `labels`
    batch_size=batch_size,
    seed=seed,
    subset="training")

valid_data_generator = valid_idg.flow_from_directory(
    directory=dataset_dir,
    target_size=(img_height, img_width),
    classes=labels,
    batch_size=batch_size,
    seed=seed,
    shuffle=False,  # stable order so predictions line up with .classes
    subset="validation")

In [None]:
# Download the ImageNet-pretrained VGG19 convolutional base without its
# classification head (include_top=False); it is used as a feature extractor.
# The input shape comes from the config constant instead of a repeated literal.
supernet = tfk.applications.VGG19(
    include_top=False,
    weights="imagenet",
    input_shape=input_shape,
)
supernet.summary()

In [None]:
# Build the classifier: a frozen VGG19 convolutional base followed by a
# trainable fully-connected head.
supernet.trainable = False  # only the new head's weights are updated

inputs = tfk.Input(shape=(img_height, img_width, 3))
# VGG19's own input preprocessing; the ImageDataGenerator applies no rescaling.
x = preprocess_input(inputs)
# Resize to the spatial size the base was built with (taken from the model
# rather than hard-coded). With the current 256x256 settings the input already
# has this size; this keeps the model valid if img_height/img_width change.
x = tfkl.Resizing(supernet.input_shape[1], supernet.input_shape[2],
                  interpolation='bicubic')(x)
x = supernet(x)

# Dense head with dropout regularization; all randomness is seeded.
x = tfkl.Flatten(name='Flattening')(x)
x = tfkl.Dropout(0.3, seed=seed)(x)
x = tfkl.Dense(256, activation='relu',
               kernel_initializer=tfk.initializers.GlorotUniform(seed))(x)
x = tfkl.Dropout(0.3, seed=seed)(x)
x = tfkl.Dense(512, activation='relu',
               kernel_initializer=tfk.initializers.GlorotUniform(seed))(x)
x = tfkl.Dropout(0.3, seed=seed)(x)
# One softmax unit per class, derived from `labels` instead of a hard-coded 14.
outputs = tfkl.Dense(len(labels), activation='softmax',
                     kernel_initializer=tfk.initializers.GlorotUniform(seed))(x)

# Connect input and output through the functional Model API.
VGG19_model = tfk.Model(inputs=inputs, outputs=outputs, name='VGG19_model')

# flow_from_directory is called without class_mode, whose Keras default
# ('categorical') yields one-hot labels, matching CategoricalCrossentropy.
VGG19_model.compile(loss=tfk.losses.CategoricalCrossentropy(),
                    optimizer=tfk.optimizers.Adam(),
                    metrics=['accuracy'])
VGG19_model.summary()

In [None]:
# Abort unless TensorFlow reports a GPU at /device:GPU:0.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')

In [None]:
# Train the head.
# The generators already yield batches of `batch_size`, so batch_size is NOT
# passed to fit(): Keras documents that it must not be specified when the input
# is a generator / keras.utils.Sequence.
# Stop after 10 epochs without val_accuracy improvement and keep the weights of
# the best epoch.
early_stopping = tfk.callbacks.EarlyStopping(
    monitor='val_accuracy', mode='max', patience=10, restore_best_weights=True)

history = VGG19_model.fit(
    train_data_generator,
    epochs=epochs,
    validation_data=valid_data_generator,
    callbacks=[early_stopping],
).history  # keep only the per-epoch metrics dict

In [None]:
tfk.models.save_model(VGG19_model, 'VGG19_Model')  # persist the trained model to 'VGG19_Model'

In [None]:
# Plot the training curves.
def _plot_history_metric(history, metric, title):
    """Plot the per-epoch training and validation curves of one metric.

    Args:
        history: the dict returned by ``Model.fit(...).history``; must contain
            the keys ``metric`` and ``'val_' + metric``.
        metric: base metric name, e.g. ``'loss'`` or ``'accuracy'``.
        title: title of the figure's axes.

    Returns:
        The matplotlib Axes the two curves were drawn on.
    """
    _, ax = plt.subplots(figsize=(15, 5))
    ax.plot(history[metric], label='Training', alpha=.8, color='#ff7f0e')
    ax.plot(history['val_' + metric], label='Validation', alpha=.8, color='#4D61E2')
    ax.set(title=title, xlabel='Epoch', ylabel=metric)
    ax.legend(loc='upper left')
    ax.grid(alpha=.3)
    return ax

_plot_history_metric(history, 'loss', 'Categorical Crossentropy')
_plot_history_metric(history, 'accuracy', 'Accuracy')
plt.show()

In [None]:
# Display the highest validation accuracy recorded over all training epochs.
val_accuracy_per_epoch = history['val_accuracy']
max(val_accuracy_per_epoch)