In [1]:
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input, decode_predictions
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os.path
from math import ceil

In [2]:
# Environment check: TensorFlow version and GPU visibility.
# tf.test.is_gpu_available() is deprecated and its notice points to
# tf.config.list_physical_devices('GPU'), so derive the boolean from that list.
gpus = tf.config.list_physical_devices('GPU')
print(tf.__version__)
print(len(gpus) > 0)
print(gpus)

2.1.0
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [3]:
def plot_history(history, metrics=('categorical_accuracy', 'loss')):
    """Plot training vs. validation curves, one figure per metric.

    Args:
        history: the object returned by ``model.fit``. Its ``.history`` dict
            must contain both ``<metric>`` and ``val_<metric>`` for every
            metric requested.
        metrics: metric keys to plot, in order. The default reproduces the
            accuracy figure followed by the loss figure.
    """
    for metric in metrics:
        # Keep the short axis/title label 'accuracy' for categorical_accuracy.
        label = 'accuracy' if metric == 'categorical_accuracy' else metric
        fig, ax = plt.subplots()
        ax.plot(history.history[metric], label='train')
        # The val_* curves come from validation_data, not the test set.
        ax.plot(history.history['val_' + metric], label='validation')
        ax.set_title('model ' + label)
        ax.set_ylabel(label)
        ax.set_xlabel('epoch')
        ax.legend(loc='upper left')
        plt.show()

In [4]:
# Project root, overridable via the PLANKTON_PROJECT_DIR environment variable
# (e.g. a local checkout) instead of editing this cell; defaults to '/home/jupyter'.
current_dir = os.environ.get('PLANKTON_PROJECT_DIR', '/home/jupyter')
top_layers_checkpoint_path = current_dir + '/best_models/top_layers_best.hdf5'
fine_tuned_checkpoint_path = current_dir + '/best_models/fine_tuned_best.hdf5'

# Training schedule. Epoch numbers run on across both phases: fine-tuning
# starts at starting_epoch + epochs_top_layers.
starting_epoch = 0         # initial_epoch of the head-only phase
epochs_top_layers = 10     # epochs with the backbone frozen
epochs_fine_tuning = 50    # epochs with the whole network unfrozen
batch_size = 128

In [5]:
# One directory per split under the shared data-65 dataset root.
train_dir, validation_dir, test_dir = (
    current_dir + '/planktondata/plankton/data-65/' + split + '/'
    for split in ('train', 'validate', 'test')
)

In [6]:
input_shape = (299,) * 2  # (height, width) passed as target_size to every generator

In [7]:
# Training images: InceptionV3 preprocessing plus random augmentation
# (any rotation, shifts of up to 10% of width/height, horizontal and vertical flips).
datagen_train = ImageDataGenerator(
    rotation_range=360,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest',
    preprocessing_function=preprocess_input,
)

# Validation / test images: preprocessing only, no augmentation (two separate instances).
datagen_validate, datagen_test = (
    ImageDataGenerator(preprocessing_function=preprocess_input) for _ in range(2)
)

In [8]:
print("{}".format(input_shape))  # sanity check of the resize target

(299, 299)


In [10]:
# One directory iterator per split; all share image size, batch size and RGB mode.
# Only training is shuffled, so validation/test batches come out in the same
# order as each generator's .classes array.
_flow_kwargs = dict(target_size=input_shape, batch_size=batch_size, color_mode='rgb')
generator_train = datagen_train.flow_from_directory(
    directory=train_dir, shuffle=True, **_flow_kwargs)
generator_validate = datagen_validate.flow_from_directory(
    directory=validation_dir, shuffle=False, **_flow_kwargs)
generator_test = datagen_test.flow_from_directory(
    directory=test_dir, shuffle=False, **_flow_kwargs)
del _flow_kwargs

Found 699491 images belonging to 65 classes.
Found 6500 images belonging to 65 classes.
Found 6500 images belonging to 65 classes.


In [None]:
print("generators", "done")  # progress marker after the (slow) directory scans

In [None]:
# Each training "epoch" is capped at max_steps_train batches, i.e. a subset of
# the training set, so epoch-end work (validation, checkpointing, the
# confusion-matrix callback) runs more often. Set to None for full passes.
max_steps_train = 1000
steps_train = ceil(generator_train.n / batch_size)
if max_steps_train is not None:
    steps_train = min(steps_train, max_steps_train)
# Validation/test cover every image once; the last batch may be partial.
steps_validate = ceil(generator_validate.n / batch_size)
steps_test = ceil(generator_test.n / batch_size)

In [None]:
# Integer class label of every image, in each generator's file order.
cls_train, cls_validate, cls_test = (
    gen.classes for gen in (generator_train, generator_validate, generator_test)
)

In [None]:
# class_indices maps class name -> integer label. Order names by that label so
# class_names[i] is always the name of label i, independent of dict ordering
# (the confusion-matrix ticks rely on this alignment).
class_names = sorted(generator_train.class_indices, key=generator_train.class_indices.get)
num_classes = generator_train.num_classes

In [None]:
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights are inversely proportional to class frequency, so rarer
# classes contribute more to the training loss.
present_classes = np.unique(cls_train)
balanced_weights = compute_class_weight(class_weight='balanced',
                                        classes=present_classes,
                                        y=cls_train)
# Key each weight by its actual class label, not by its position in
# present_classes, so the mapping stays correct even if some label has no
# training images. Keras looks class_weight up by integer label.
class_weight = {int(label): float(weight)
                for label, weight in zip(present_classes, balanced_weights)}

In [None]:
# InceptionV3 backbone with ImageNet weights and without its classifier head.
# https://keras.io/applications/#inceptionv3
base_model = InceptionV3(weights='imagenet', include_top=False)

# New head: global average pooling -> 1024-unit ReLU -> softmax over our classes.
pooled = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(1024, activation='relu')(pooled)
predictions = Dense(num_classes, activation='softmax')(hidden)
model = Model(inputs=base_model.input, outputs=predictions)

# Set to True to start from previously saved top-layer weights (if the file
# exists) instead of a freshly initialised head.
resume_top_layers = False
if resume_top_layers and os.path.exists(top_layers_checkpoint_path):
    model.load_weights(top_layers_checkpoint_path)
    print("loaded top layer checkpoint: {}".format(top_layers_checkpoint_path))

# Freeze the backbone so the first training phase only updates the new head.
base_model.trainable = False

In [None]:
def top_5_accuracy(y_true, y_pred):
    """Keras metric: is the true class among the 5 highest-scored predictions?

    Thin wrapper around ``tf.keras.metrics.top_k_categorical_accuracy`` with
    k fixed to 5. The function name doubles as the metric's reported name.

    Args:
        y_true: categorical (one-hot) targets, as used with categorical_crossentropy.
        y_pred: predicted class scores (the model's softmax output).
    """
    metric_fn = tf.keras.metrics.top_k_categorical_accuracy
    return metric_fn(y_true, y_pred, k=5)

In [None]:
# Phase 1 compile: Adam at lr=1e-3 while only the new head is trainable.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(1e-3),
    metrics=['categorical_accuracy', top_5_accuracy],
)

In [None]:
# TensorBoard: per-epoch scalars and weight histograms; profile_batch=0 turns
# the profiler off.
tensorboard_callback = TensorBoard(log_dir=current_dir + '/logs6',
                                   histogram_freq=1,
                                   write_graph=True,
                                   write_images=False,
                                   profile_batch=0)
# Keep only the best head weights by validation accuracy, as the
# "top_layers_best" filename promises and as the fine-tuning checkpoint does.
# (save_best_only=False would overwrite the file every epoch and ignore `monitor`.)
checkpoint_callback = ModelCheckpoint(top_layers_checkpoint_path,
                                      monitor='val_categorical_accuracy',
                                      mode='max',
                                      save_best_only=True,
                                      save_weights_only=True,
                                      save_freq='epoch',
                                      verbose=1)

In [None]:
# EVERYTHING CONFUSION MATRIX RELATED:
from sklearn.metrics import confusion_matrix
from datetime import datetime
import itertools
import io

def plot_to_image(figure):
    """Render a matplotlib figure to PNG and return it as a batched TF image.

    The figure is always closed afterwards, even if rendering fails. This
    keeps it from also being displayed inline in the notebook.

    Args:
        figure: the matplotlib Figure to render.

    Returns:
        A tensor of shape (1, height, width, 4) holding the RGBA PNG pixels,
        ready for ``tf.summary.image``.
    """
    buf = io.BytesIO()
    try:
        # Save *this* figure explicitly; plt.savefig would save whichever
        # figure happens to be current, which need not be `figure`.
        figure.savefig(buf, format='png')
    finally:
        plt.close(figure)
    png_bytes = buf.getvalue()
    buf.close()
    image = tf.image.decode_png(png_bytes, channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)


def plot_confusion_matrix(cm, class_names):
    """
    Returns a matplotlib figure containing the row-normalised confusion matrix.

    Each row is divided by its total, so cell (i, j) shows the fraction of
    row-i samples that landed in column j, rounded to 2 decimals. Rows with
    no samples are shown as zeros instead of NaN.

    Args:
        cm (array, shape = [n, n]): a confusion matrix of integer classes
            (rows = true label, columns = predicted label)
        class_names (array, shape = [n]): String names of the integer classes

    Returns:
        The matplotlib Figure. It is not closed here; the caller must close it.
    """
    counts = np.asarray(cm, dtype='float')
    row_totals = counts.sum(axis=1, keepdims=True)
    # Guard empty rows: 0 / 0 would otherwise fill the row with NaN.
    cm_norm = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0)
    cm_norm = np.around(cm_norm, decimals=2)

    figure = plt.figure(figsize=(50, 50))
    # Shade by the same normalised values that are printed, so the text-colour
    # threshold below matches the cell colours.
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45, ha='right')
    plt.yticks(tick_marks, class_names)

    # Use white text if squares are dark; otherwise black.
    threshold = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        color = "white" if cm_norm[i, j] > threshold else "black"
        plt.text(j, i, cm_norm[i, j], horizontalalignment="center", color=color)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after all labels exist so none are clipped.
    plt.tight_layout()
    return figure

# Timestamped TensorBoard run directory for the confusion-matrix images.
logdir = "{}/logs6/image/{}".format(current_dir, datetime.now().strftime("%Y%m%d-%H%M%S"))
# Summary file writer (not a Keras callback) used by log_confusion_matrix.
file_writer_cm = tf.summary.create_file_writer(logdir + '/cm')

def log_confusion_matrix(epoch, logs):
    """Epoch-end hook: log the validation confusion matrix to TensorBoard.

    Args:
        epoch: index of the epoch that just finished; used as the summary step.
        logs: the per-epoch metrics dict Keras passes. It is unused, but the
            LambdaCallback ``on_epoch_end`` signature requires it.
    """
    # generator_validate was created with shuffle=False, so prediction order
    # lines up with cls_validate.
    val_probs = model.predict(generator_validate, steps=steps_validate)
    val_pred = np.argmax(val_probs, axis=1)
    # Pin the label set so the matrix is always len(class_names) square and
    # rows/columns stay aligned with class_names, even if a class is never
    # predicted this epoch.
    cm = confusion_matrix(cls_validate, val_pred, labels=np.arange(len(class_names)))
    figure = plot_confusion_matrix(cm, class_names=class_names)
    cm_image = plot_to_image(figure)
    with file_writer_cm.as_default():
        tf.summary.image("Confusion Matrix", cm_image, step=epoch)

# Wrap log_confusion_matrix so Keras calls it at the end of every epoch.
cm_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=log_confusion_matrix,
)



In [None]:
# Phase 1: train the new head with the backbone frozen, counting epochs from starting_epoch.
history = model.fit(
    generator_train,
    steps_per_epoch=steps_train,
    initial_epoch=starting_epoch,
    epochs=starting_epoch + epochs_top_layers,
    validation_data=generator_validate,
    validation_steps=steps_validate,
    class_weight=class_weight,
    shuffle=True,
    callbacks=[tensorboard_callback, checkpoint_callback, cm_callback],
)

In [None]:
plot_history(history=history)  # learning curves of the head-only phase

In [None]:
result = model.evaluate(generator_test, steps=steps_test)
# Look metrics up by name instead of relying on their position in `result`.
test_metrics = dict(zip(model.metrics_names, result))
print("Test-set classification accuracy: {0:.2%}".format(test_metrics['categorical_accuracy']))
print("Test-set top-5 accuracy: {0:.2%}".format(test_metrics['top_5_accuracy']))

In [None]:
# New checkpoint callback is set
# Phase 2 checkpointing: keep only the best fine-tuned weights by validation accuracy.
checkpoint_callback = ModelCheckpoint(
    fine_tuned_checkpoint_path,
    monitor='val_categorical_accuracy',
    save_best_only=True,
    save_weights_only=True,
    save_freq='epoch',
    mode='auto',
    verbose=1,
)

In [None]:
# Unfreeze all layers and train a bit more:
# Phase 2: unfreeze the whole network (backbone + head) for fine-tuning.
# Undo the earlier freeze on the same object (base_model) explicitly, rather
# than relying only on how the model-level setter propagates. The change takes
# effect at the next compile().
base_model.trainable = True
model.trainable = True

In [None]:
# Phase 2 compile: 10x lower learning rate than the head-only phase.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(1e-4),
    metrics=['categorical_accuracy', top_5_accuracy],
)

In [None]:
# Phase 2: fine-tune the full network, continuing the epoch count where phase 1 stopped.
history = model.fit(
    generator_train,
    steps_per_epoch=steps_train,
    initial_epoch=starting_epoch + epochs_top_layers,
    epochs=starting_epoch + epochs_top_layers + epochs_fine_tuning,
    validation_data=generator_validate,
    validation_steps=steps_validate,
    class_weight=class_weight,
    shuffle=True,
    callbacks=[tensorboard_callback, checkpoint_callback, cm_callback],
)

In [None]:
plot_history(history=history)  # learning curves of the fine-tuning phase

In [None]:
result = model.evaluate(generator_test, steps=steps_test)
# Look metrics up by name instead of relying on their position in `result`.
test_metrics = dict(zip(model.metrics_names, result))
print("Test-set classification accuracy: {0:.2%}".format(test_metrics['categorical_accuracy']))
print("Test-set top-5 accuracy: {0:.2%}".format(test_metrics['top_5_accuracy']))

In [None]:
print("finished", end="\n")  # end-of-run marker