In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import (
    Dense,
    Conv2D,
    Flatten,
    Dropout,
    MaxPooling2D,
)

import os
import cv2
import matplotlib.pyplot as plt

# Root of the dataset used for training and evaluation. TRAIN_DIR and TEST_DIR
# below are its "train" and "test" subdirectories, and preprocess_images()
# writes its output underneath this directory.
PATH = "C:\\ImageSets\\SmallSetAugmentedNormalized8x8Old"
# NOTE(review): not referenced anywhere in the visible notebook — confirm it is still needed.
PREPROCESSED_IMAGES_PATH = "C:\\new_set_of_augmented_images"
# Source directory handed to augment_data() in the commented-out call below.
BEFORE_AUGMENTATION_IMAGES_PATH = "C:\\images"
# Output directory of augment_data() and input directory of preprocess_images()
# in the commented-out calls below.
AUGMENTED_IMAGES_PATH = "C:\\ImageSets\\SmallSetAugmented\\test\\aug_other"

TEST_DIR = os.path.join(PATH, "test")
TRAIN_DIR = os.path.join(PATH, "train")

# Target size (in pixels) the data generators resize every image to; also the
# spatial input shape of the network.
IMG_WIDTH = 64
IMG_HEIGHT = 64
# Number of training epochs.
EPOCHS = 1
# Number of images per batch produced by the train/validation generators.
BATCH_SIZE = 16
# Number of augmented images augment_data() produces before it stops.
NO_OF_AUGMENTED_IMAGES = 600

In [None]:
def data_info():
    """Count the files under TRAIN_DIR and TEST_DIR, print both totals and return them.

    Every file reached by a recursive walk is counted, regardless of its
    extension.

    Returns:
        tuple: (number of files under TRAIN_DIR, number of files under TEST_DIR).
    """

    def _count_files(directory):
        # os.walk yields one (root, dirs, files) triple per visited directory;
        # summing the lengths of the file lists gives the recursive file count.
        return sum(len(filenames) for _, _, filenames in os.walk(directory))

    train_count = _count_files(TRAIN_DIR)
    test_count = _count_files(TEST_DIR)

    print("Total training images:", train_count)
    print("Total testing images:", test_count, "\n")

    return train_count, test_count

In [None]:
def augment_data(
    BEFORE_AUGMENTATION_IMAGES_PATH, AUGMENTED_IMAGES_PATH, NO_OF_AUGMENTED_IMAGES
):
    """Generate randomly augmented copies of the images in a directory.

    Batches of a single image are drawn from an augmenting Keras generator,
    which is configured (via ``save_to_dir``) to write each generated image
    to the output directory with the ``aug`` filename prefix.

    Args:
        BEFORE_AUGMENTATION_IMAGES_PATH: Directory the source images are read
            from (passed to ``flow_from_directory`` as ``directory``).
        AUGMENTED_IMAGES_PATH: Directory the generated images are saved to.
        NO_OF_AUGMENTED_IMAGES: Number of batches to draw before stopping.
            At least one batch is always drawn, because the limit is only
            checked after a batch has been produced.
    """
    augmenter = ImageDataGenerator(
        rescale=1.0 / 255,
        rotation_range=5,
        brightness_range=[0.6, 1.4],
        channel_shift_range=50.0,
        zoom_range=[0.9, 1.1],
    )
    augmented_batches = augmenter.flow_from_directory(
        batch_size=1,
        directory=BEFORE_AUGMENTATION_IMAGES_PATH,
        save_to_dir=AUGMENTED_IMAGES_PATH,
        save_prefix="aug",
        color_mode="rgb",
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        class_mode="binary",
    )
    # The batch contents are irrelevant here; drawing a batch is what triggers
    # the save, so only the running count is inspected.
    for generated_so_far, _ in enumerate(augmented_batches, start=1):
        if generated_so_far >= NO_OF_AUGMENTED_IMAGES:
            break


# augment_data(BEFORE_AUGMENTATION_IMAGES_PATH, AUGMENTED_IMAGES_PATH, NO_OF_AUGMENTED_IMAGES)

In [None]:
def preprocess_images(path):
    """Write a CLAHE-equalized grayscale copy of every PNG under `path` into PATH.

    Each ``.png`` file (extension matched case-insensitively) found by a
    recursive walk of `path` is loaded, converted to grayscale, contrast
    equalized with CLAHE and written below the module-level ``PATH``
    directory. The destination sub-directory is the tail of the source
    directory starting at its last ``\\t`` component (for example
    ``\\test\\...`` or ``\\train\\...``); when the source directory contains
    no such component the image is written directly into ``PATH``.

    Destination directories are created when missing. Files that OpenCV
    cannot read, or cannot write, are reported and skipped instead of
    aborting the whole run.

    Args:
        path: Directory that is searched recursively for PNG images.
    """
    destination_path = PATH
    # The CLAHE settings do not depend on the file, so build the operator once
    # instead of once per image.
    clahe = cv2.createCLAHE(clipLimit=0.1, tileGridSize=(8, 8))
    for root, dirs, files in os.walk(path):
        # rfind returns -1 when the marker is absent; slicing with -1 would
        # silently keep only the last character of `root`, so fall back to an
        # empty sub-path in that case.
        index = root.rfind("\\t")
        sub_path = root[index:] if index != -1 else ""
        destination_dir = destination_path + sub_path
        for file in files:
            # Match on the extension, not on "png" appearing anywhere in the name.
            if not file.lower().endswith(".png"):
                continue
            source_file = os.path.join(root, file)
            image = cv2.imread(source_file)
            if image is None:
                # cv2.imread signals an unreadable/corrupt file by returning None.
                print("Skipping unreadable image:", source_file)
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            image = clahe.apply(image)
            destination_file = destination_dir + "\\" + file
            print(destination_file)
            # cv2.imwrite does not create missing directories.
            os.makedirs(destination_dir, exist_ok=True)
            if not cv2.imwrite(destination_file, image):
                print("Failed to write image:", destination_file)


# preprocess_images(AUGMENTED_IMAGES_PATH)

In [None]:
def generate_data():
    """Build the training and validation batch generators.

    Both generators rescale pixel values by 1/255 and yield shuffled batches
    of BATCH_SIZE grayscale images resized to (IMG_HEIGHT, IMG_WIDTH) with
    categorical labels. Training batches are read from TRAIN_DIR and
    validation batches from TEST_DIR.

    Returns:
        tuple: (train_data_gen, val_data_gen).
    """
    shared_flow_options = dict(
        batch_size=BATCH_SIZE,
        shuffle=True,
        color_mode="grayscale",
        target_size=(IMG_HEIGHT, IMG_WIDTH),
        class_mode="categorical",
    )

    def _batches_from(directory):
        # Each split gets its own ImageDataGenerator instance.
        rescaler = ImageDataGenerator(rescale=1.0 / 255)
        return rescaler.flow_from_directory(directory=directory, **shared_flow_options)

    train_data_gen = _batches_from(TRAIN_DIR)
    val_data_gen = _batches_from(TEST_DIR)

    return train_data_gen, val_data_gen

In [None]:
def define_architecture():
    """Build the CNN, print its summary and return it (uncompiled).

    The network takes single-channel (IMG_HEIGHT, IMG_WIDTH, 1) images and
    stacks three 3x3 "same"-padded ReLU convolutions (16, 32 and 64 filters),
    each followed by max pooling, with 50% dropout after the first and third
    pooling stages. The flattened features feed a 512-unit ReLU dense layer
    and a 5-unit sigmoid output layer.

    NOTE(review): the output layer uses sigmoid while the data generators
    produce categorical (one-hot) labels — confirm that sigmoid rather than
    softmax is intended.

    Returns:
        The Sequential model.
    """
    model_new = Sequential()

    # Convolutional stage 1 (declares the input shape).
    model_new.add(
        Conv2D(
            16,
            3,
            padding="same",
            activation="relu",
            input_shape=(IMG_HEIGHT, IMG_WIDTH, 1),
        )
    )
    model_new.add(MaxPooling2D())
    model_new.add(Dropout(0.5))

    # Convolutional stage 2.
    model_new.add(Conv2D(32, 3, padding="same", activation="relu"))
    model_new.add(MaxPooling2D())

    # Convolutional stage 3.
    model_new.add(Conv2D(64, 3, padding="same", activation="relu"))
    model_new.add(MaxPooling2D())
    model_new.add(Dropout(0.5))

    # Classifier head.
    model_new.add(Flatten())
    model_new.add(Dense(512, activation="relu"))
    model_new.add(Dense(5, activation="sigmoid"))

    model_new.summary()

    return model_new

In [None]:
def compile_and_train_cnn(
    model_new, train_data_gen, total_train, val_data_gen, total_val
):
    """Compile the model with Adam / binary cross-entropy and train it for EPOCHS epochs.

    Args:
        model_new: Keras model to compile and train (modified in place).
        train_data_gen: Generator yielding training batches.
        total_train: Total number of training images (not batches).
        val_data_gen: Generator yielding validation batches.
        total_val: Total number of validation images (not batches).

    Returns:
        The History object returned by ``fit``; its ``history`` dict holds
        the per-epoch loss and accuracy values.

    NOTE(review): binary cross-entropy is used although the generators in
    this notebook produce categorical (one-hot) labels — confirm this
    pairing is intended.
    """
    model_new.compile(
        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
    )

    # Keras counts steps in batches, not images. Passing the raw image counts
    # made every epoch iterate roughly BATCH_SIZE times over the data, so
    # convert them to the number of batches needed for one full pass
    # (ceiling division keeps the final partial batch). This assumes the
    # generators were built with BATCH_SIZE, as generate_data() does.
    steps_per_epoch = (total_train + BATCH_SIZE - 1) // BATCH_SIZE
    validation_steps = (total_val + BATCH_SIZE - 1) // BATCH_SIZE

    # Model.fit accepts generators directly; fit_generator is deprecated and
    # no longer available in recent Keras releases.
    information = model_new.fit(
        train_data_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=EPOCHS,
        validation_data=val_data_gen,
        validation_steps=validation_steps,
    )

    return information

In [None]:
def sub_plot(epochs_range, first_val, second_val, label, plt):
    """Draw a training-vs-validation curve pair in one panel of a 1x2 figure.

    Args:
        epochs_range: X values shared by both curves.
        first_val: Y values of the training curve.
        second_val: Y values of the validation curve.
        label: Metric name used for the title and legend entries. The label
            "Accuracy" selects the left panel; any other label selects the
            right panel.
        plt: Plotting object exposing the pyplot-style ``subplot``, ``plot``,
            ``legend`` and ``title`` functions.
    """
    panel = 1 if label == "Accuracy" else 2
    plt.subplot(1, 2, panel)

    curves = (("Training ", first_val), ("Validation ", second_val))
    for prefix, values in curves:
        plt.plot(epochs_range, values, label=prefix + label)

    plt.legend(loc="lower right")
    plt.title(label)

In [None]:
def plot_graphs(information):
    """Plot training/validation accuracy and loss side by side and show the figure.

    Args:
        information: Object whose ``history`` dict contains the per-epoch
            lists "accuracy", "val_accuracy", "loss" and "val_loss" (as
            returned by compile_and_train_cnn).
    """
    history = information.history
    acc = history["accuracy"]
    val_acc = history["val_accuracy"]
    loss = history["loss"]
    val_loss = history["val_loss"]

    # Derive the x-axis from the epochs actually recorded instead of the
    # EPOCHS constant, so the curves still plot when the history length
    # differs from EPOCHS (e.g. training stopped early or EPOCHS was edited
    # after training).
    epochs_range = range(len(acc))
    plt.figure(figsize=(8, 8))

    sub_plot(epochs_range, acc, val_acc, "Accuracy", plt)
    sub_plot(epochs_range, loss, val_loss, "Loss", plt)

    plt.show()

In [None]:
def convert_model(model_new):
    """Save the model and export a quantized TensorFlow Lite version of it.

    The model is first saved to the path "model", then converted from that
    saved model with the default TFLite optimizations enabled, and the
    result is written to "model.tflite" in the current working directory.

    Args:
        model_new: Trained Keras model to export.
    """
    model_new.save("model")
    converter = tf.lite.TFLiteConverter.from_saved_model("model")
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_quantized_model = converter.convert()
    # Use a context manager so the file is flushed and closed deterministically
    # instead of relying on garbage collection of the anonymous file object.
    with open("model.tflite", "wb") as tflite_file:
        tflite_file.write(tflite_quantized_model)

In [None]:
def main():
    """Run the full pipeline: count data, train the CNN, plot curves, export to TFLite."""
    # Dataset sizes (also printed by data_info itself).
    n_train, n_val = data_info()
    print(n_train, n_val)

    # Batch generators and network.
    train_batches, val_batches = generate_data()
    cnn = define_architecture()

    # Train, then visualise the recorded metrics.
    training_history = compile_and_train_cnn(
        cnn, train_batches, n_train, val_batches, n_val
    )
    plot_graphs(training_history)

    # Export the trained network.
    convert_model(cnn)


# Run the whole pipeline when this cell is executed.
main()