In [None]:
"""This module implements data feeding and training loop to create model
to classify X-Ray chest images as a lab example for BSU students.
"""

__author__ = 'Alexander Soroka, soroka.a.m@gmail.com'
__copyright__ = """Copyright 2020 Alexander Soroka"""


import argparse
import glob
import numpy as np
import tensorflow as tf
import time
from tensorflow.python import keras as keras
from tensorflow.python.keras.callbacks import LearningRateScheduler


# ---- Training / data-pipeline configuration --------------------------------
LOG_DIR = 'logs'          # directory component of the TensorBoard log path
SHUFFLE_BUFFER = 4        # shuffle-buffer size setting
BATCH_SIZE = 64           # examples per batch
NUM_CLASSES = 6           # number of classes (depth of the one-hot labels)
PARALLEL_CALLS = 4        # parallelism setting for dataset map calls
RESIZE_TO = 224           # images are resized to RESIZE_TO x RESIZE_TO
TRAINSET_SIZE = 14034     # presumably the number of training records - confirm
VALSET_SIZE = 3000        # presumably the number of validation records - confirm


def parse_proto_example(proto):
    """Decode one serialized tf.train.Example into an (image, one-hot label) pair.

    The record must hold JPEG bytes under 'image/encoded' and an int64 class id
    under 'image/class/label' (missing fields fall back to '' and 0). The image
    is decoded with 3 channels, converted to float32 with
    tf.image.convert_image_dtype and resized to RESIZE_TO x RESIZE_TO. The label
    is one-hot encoded with NUM_CLASSES entries.

    :param proto: scalar string tensor containing a serialized Example.
    :returns: (image, one_hot_label) tuple of tensors.
    """
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/class/label': tf.io.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.io.parse_single_example(proto, feature_spec)

    target_size = tf.constant([RESIZE_TO, RESIZE_TO])
    pixels = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    pixels = tf.image.convert_image_dtype(pixels, dtype=tf.float32)
    pixels = tf.image.resize(pixels, target_size)

    one_hot_label = tf.one_hot(parsed['image/class/label'], depth=NUM_CLASSES)
    return pixels, one_hot_label


def normalize(image, label):
    """Standardize *image* with tf.image.per_image_standardization.

    The label is forwarded untouched, so this can be used directly in
    ``Dataset.map`` over (image, label) pairs.
    """
    standardized = tf.image.per_image_standardization(image)
    return standardized, label

def resize(image, label):
    """Resize *image* to RESIZE_TO x RESIZE_TO; the label is forwarded untouched."""
    size = tf.constant([RESIZE_TO, RESIZE_TO])
    resized = tf.image.resize(image, size)
    return resized, label

def create_dataset(filenames, batch_size):
    """Create a batched dataset from TFRecord files, without augmentation.

    Records are parsed, resized and standardized in file order. Unlike
    create_aug_dataset, no noise is added and nothing is shuffled, which makes
    this pipeline suitable for validation/evaluation data.

    :param filenames: list of TFRecord file paths (e.g. the result of glob.glob).
    :param batch_size: number of examples per batch.
    :returns: tf.data.Dataset yielding (images, one_hot_labels) batches.
    """
    return (
        tf.data.TFRecordDataset(filenames)
        # num_parallel_calls keeps element order by default, so output is unchanged.
        .map(parse_proto_example, num_parallel_calls=PARALLEL_CALLS)
        # NOTE: parse_proto_example already resizes to RESIZE_TO, so this step
        # is effectively redundant; kept to leave the pipeline unchanged.
        .map(resize, num_parallel_calls=PARALLEL_CALLS)
        .map(normalize, num_parallel_calls=PARALLEL_CALLS)
        .batch(batch_size)
        # prefetch() counts elements of *this* dataset, i.e. whole batches;
        # prefetch(batch_size) buffered 64 full batches. Let tf.data tune it.
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

def create_aug_dataset(filenames, batch_size):
    """Create a shuffled, noise-augmented, batched dataset from TFRecord files.

    Each record is parsed, resized, standardized and passed through augment()
    (additive Gaussian noise). Examples are then shuffled with a buffer of
    ``5 * batch_size`` elements and batched.

    :param filenames: list of TFRecord file paths (e.g. the result of glob.glob).
    :param batch_size: number of examples per batch.
    :returns: tf.data.Dataset yielding (images, one_hot_labels) batches.
    """
    return (
        tf.data.TFRecordDataset(filenames)
        .map(parse_proto_example, num_parallel_calls=PARALLEL_CALLS)
        .map(resize, num_parallel_calls=PARALLEL_CALLS)
        .map(normalize, num_parallel_calls=PARALLEL_CALLS)
        .map(augment, num_parallel_calls=PARALLEL_CALLS)
        .shuffle(buffer_size=5 * batch_size)
        .batch(batch_size)
        # prefetch() after batch() counts batches; the previous
        # prefetch(2 * batch_size) buffered 128 full batches of float32 images.
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

def augment(image, label, stddev=0.1, clip_min=-1.0, clip_max=1.0):
    """Add zero-mean Gaussian noise to *image* and clip the result.

    :param image: float32 image tensor; in this file it arrives already
        standardized by normalize().
    :param label: label tensor, returned unchanged.
    :param stddev: standard deviation of the additive noise (default 0.1).
    :param clip_min: lower bound applied after adding noise (default -1.0).
    :param clip_max: upper bound applied after adding noise (default 1.0).
    :returns: (noisy_image, label) tuple.
    """
    with tf.name_scope('Add_gaussian_noise'):
        noise = tf.random.normal(shape=tf.shape(image), mean=0.0,
                                 stddev=stddev, dtype=tf.float32)
        # NOTE(review): per_image_standardization yields zero-mean,
        # unit-variance images, so a substantial share of pixels typically lies
        # outside [-1, 1]. The default bounds therefore saturate the image
        # itself, not just the noise. Confirm the intended bounds.
        noise_img = tf.clip_by_value(image + noise, clip_min, clip_max)
    return noise_img, label

def build_model():
    """Build a classifier on top of an ImageNet-pretrained MobileNetV2 backbone.

    Architecture: MobileNetV2 (no classification head) -> global average
    pooling -> dropout(0.4) -> Dense(NUM_CLASSES, softmax).

    :returns: uncompiled tf.keras.Model mapping (RESIZE_TO, RESIZE_TO, 3) float
        images to NUM_CLASSES class probabilities.
    """
    input_shape = (RESIZE_TO, RESIZE_TO, 3)
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet')

    inputs = tf.keras.Input(shape=input_shape)
    # No explicit `training=` argument. Passing training=True bakes training mode
    # into the functional graph, so BatchNormalization would use batch statistics
    # (and keep updating its moving averages) during evaluate()/predict() too.
    # Left unset, Keras uses training mode inside fit() and inference mode
    # elsewhere.
    x = base_model(inputs)
    # GlobalAveragePooling2D already returns (batch, channels), so no Flatten is needed.
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.4)(x)
    outputs = tf.keras.layers.Dense(
        NUM_CLASSES, activation=tf.keras.activations.softmax)(x)
    return tf.keras.Model(inputs, outputs)

def main():
    """Train the classifier on TFRecords found under the hard-coded Drive paths.

    Training data (``train*``) goes through the augmented, shuffled pipeline.
    Validation data (``val*``) goes through the plain pipeline. Metrics are
    logged to TensorBoard, and the weights with the best
    ``val_categorical_accuracy`` are saved to ``weights_file``.

    :raises FileNotFoundError: if either glob pattern matches no files.
    """
    train_path = '/content/drive/My Drive/SMOMI/dataset/train*'
    test_path = '/content/drive/My Drive/SMOMI/dataset/val*'

    # glob order is arbitrary; sort so the file order is reproducible across runs.
    train_files = sorted(glob.glob(train_path))
    validation_files = sorted(glob.glob(test_path))
    for pattern, files in ((train_path, train_files), (test_path, validation_files)):
        if not files:
            # An empty file list would otherwise produce an empty dataset.
            raise FileNotFoundError('No TFRecord files match {!r}'.format(pattern))

    train_dataset = create_aug_dataset(train_files, BATCH_SIZE)
    # Validation must not be augmented or shuffled. Otherwise val metrics, and
    # the checkpoint selection that monitors them, are computed on noisy data.
    validation_dataset = create_dataset(validation_files, BATCH_SIZE)

    model = build_model()

    model.compile(
        # `lr` is a deprecated alias; `learning_rate` is the supported keyword.
        optimizer=tf.optimizers.Adam(learning_rate=3e-6),
        loss=tf.keras.losses.categorical_crossentropy,
        metrics=[tf.keras.metrics.categorical_accuracy],
    )

    weights_file = "/content/drive/My Drive/SMOMI/w-lab4_4-3.hdf5"
    log_dir = '/content/drive/My Drive/SMOMI/{}/lab4_4-3-gug/ilcd-{}'.format(LOG_DIR, time.time())
    model.fit(
        train_dataset,
        epochs=100,
        validation_data=validation_dataset,
        callbacks=[
            tf.keras.callbacks.TensorBoard(log_dir),
            tf.keras.callbacks.ModelCheckpoint(
                filepath=weights_file,
                monitor='val_categorical_accuracy',
                mode='max',
                save_best_only=True,
                save_weights_only=True,
                verbose=1),
        ],
    )


if __name__ == '__main__':
    main()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
Epoch 1/100
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
    220/Unknown - 99s 448ms/step - loss: 1.7727 - categorical_accuracy: 0.1759
Epoch 00001: val_categorical_accuracy improved from -inf to 0.24633, saving model to /content/drive/My Drive/SMOMI/w-lab4_4-3.hdf5
Epoch 2/100
Epoch 00002: val_categorical_accuracy improved from 0.24633 to 0.32400, saving model to /content/drive/My Drive/SMOMI/w-lab4_4-3.hdf5
Epoch 3/100
Epoch 00003: val_categorical_accuracy improved from 0.32400 to 0.38933, saving model to /content/drive/My Drive/SMOMI/w-lab4_4-3.hdf5
Epoch 4/100
Epoch 00004: val_categorical_accuracy improved from 0.38933 to 0.43600, saving model to /content/drive/My Drive/SMOMI/w-lab4_4-3.hdf5
Epoch 5/100
Epoch 00005: val_categorical_accuracy improved from 0.43600 to 0.44867, saving model to /conte