In [None]:
import numpy as np
import tensorflow as tf
import math, cv2, os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
%matplotlib inline
import matplotlib.image  as mpimg
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from random import randint
from sklearn.utils import shuffle

print(tf.__version__)

# Dataset locations. flow_from_directory derives class labels from the
# sub-directory names under each split folder.
train_dataset_path = "./dataset/seg_train"
test_dataset_path = "./dataset/seg_test"
pred_dataset_path = "./dataset/seg_pred"
checkpoint_path = "./checkpoint/cp.ckpt"

IMAGE_SIZE = (150, 150)  # must match the model's input_shape
BATCH_SIZE = 32
SEED = 42  # fixed so the shuffled training batch order is reproducible

# Both generators only rescale pixel values to [0, 1]; no augmentation.
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)


def make_flow(datagen, directory, shuffle_batches):
    """Build a categorical RGB batch iterator over `directory` using the shared settings."""
    return datagen.flow_from_directory(
        directory,
        target_size=IMAGE_SIZE,
        color_mode='rgb',
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        shuffle=shuffle_batches,
        seed=SEED)


train_generator = make_flow(train_datagen, train_dataset_path, shuffle_batches=True)
# Evaluation metrics do not depend on order; keeping the test split unshuffled
# makes predictions line up with test_generator.classes.
test_generator = make_flow(test_datagen, test_dataset_path, shuffle_batches=False)

In [None]:
# Small CNN classifier for 150x150 RGB images. Every Conv2D uses 'valid'
# padding (explicitly, or Keras' default), so the spatial size shrinks as noted.
model = tf.keras.models.Sequential([
    # NOTE(review): this conv has no activation, so it and the following
    # (bias-free) conv compose into a single affine map before the first
    # non-linearity -- confirm whether activation='relu' was intended.
    tf.keras.layers.Conv2D(18, (3, 3), input_shape=(150, 150, 3)),  # -> 148x148x18

    # Conv -> BatchNorm -> ReLU blocks. The convs drop their bias because
    # BatchNormalization(center=True) learns an offset; scale=False follows the
    # Keras BN docs, which note the scale can be dropped when the next layer
    # (here a ReLU) is linear-ish and can absorb it.
    tf.keras.layers.Conv2D(kernel_size=3, filters=12, use_bias=False, padding='valid'),  # -> 146x146x12
    tf.keras.layers.BatchNormalization(center=True, scale=False),
    tf.keras.layers.Activation('relu'),

    tf.keras.layers.Conv2D(kernel_size=3, filters=18, use_bias=False, padding='valid', strides=2),  # -> 72x72x18
    tf.keras.layers.BatchNormalization(center=True, scale=False),
    tf.keras.layers.Activation('relu'),

    tf.keras.layers.Conv2D(kernel_size=6, filters=24, use_bias=False, padding='valid', strides=2),  # -> 34x34x24
    tf.keras.layers.BatchNormalization(center=True, scale=False),
    tf.keras.layers.Activation('relu'),

    tf.keras.layers.Conv2D(kernel_size=6, filters=32, use_bias=False, padding='valid', strides=2),  # -> 15x15x32
    tf.keras.layers.BatchNormalization(center=True, scale=False),
    tf.keras.layers.Activation('relu'),

    tf.keras.layers.Flatten(),  # -> 7200 features

    tf.keras.layers.Dense(60, use_bias=False),
    tf.keras.layers.BatchNormalization(center=True, scale=False),
    tf.keras.layers.Activation('relu'),

    tf.keras.layers.Dropout(0.4),
    # Assumes the dataset has 6 class folders -- TODO keep in sync with
    # train_generator.num_classes.
    tf.keras.layers.Dense(6, activation='softmax')
])

# `learning_rate` is the supported keyword; `lr` is deprecated and newer Keras
# releases warn about it and ignore it (falling back to the default rate) or
# reject it. 0.01 matches lr_decay(0), the scheduler's starting rate.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

In [None]:
def lr_decay(epoch):
    """Learning-rate schedule consumed by ``LearningRateScheduler``.

    Exponential decay: the rate is 0.01 at epoch 0 and is multiplied by 0.666
    for every subsequent epoch, i.e. ``0.01 * 0.666 ** epoch``.
    """
    base_rate = 0.01
    decay_per_epoch = 0.666
    return base_rate * math.pow(decay_per_epoch, epoch)
# Re-sets the optimizer's learning rate from lr_decay at the start of each epoch.
lr_decay_callback = tf.keras.callbacks.LearningRateScheduler(lr_decay, verbose=True)

class myCallback(tf.keras.callbacks.Callback):
  """Stop training once the epoch's training accuracy exceeds 95%."""

  # The accuracy metric is logged as 'accuracy' or 'acc' depending on the
  # TF/Keras version; try both instead of crashing on a missing key.
  _ACCURACY_KEYS = ('accuracy', 'acc')
  _THRESHOLD = 0.95

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}  # avoid a shared mutable default argument
    acc = next((logs[key] for key in self._ACCURACY_KEYS if key in logs), None)
    if acc is not None and acc > self._THRESHOLD:
      print("\nReached 95% accuracy so cancelling training!")
      self.model.stop_training = True

callbacks_max_acc = myCallback()

def get_images(directory, image_files=None):
    """Load the images in `directory` as RGB arrays resized to 150x150.

    Parameters
    ----------
    directory : str
        Folder whose files are read with OpenCV.
    image_files : iterable of str, optional
        File names (relative to `directory`) to load. Defaults to the sorted
        directory listing, so the function no longer relies on a global
        defined in a later cell.

    Returns
    -------
    list of np.ndarray
        One (150, 150, 3) RGB array per readable file, shuffled with a fixed
        random_state so the order is reproducible. Entries OpenCV cannot decode
        (non-images, sub-directories) are skipped.
    """
    if image_files is None:
        image_files = sorted(os.listdir(directory))

    images = []
    for image_file in image_files:
        image = cv2.imread(os.path.join(directory, image_file))
        if image is None:  # imread returns None for unreadable / non-image entries
            continue
        # OpenCV decodes to BGR; convert so channel order matches the training
        # generators (color_mode='rgb') and matplotlib shows true colours.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        images.append(cv2.resize(image, (150, 150)))

    return shuffle(images, random_state=817328462)

# Folder part of the checkpoint path (everything before the file name).
checkpoint_dir = os.path.split(checkpoint_path)[0]

# Saves only the model weights to `checkpoint_path`, logging each save.
# NOTE(review): this callback is not in the callbacks list of the training cell,
# so no checkpoint is written unless it is added there.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
)

# To restore previously saved weights, uncomment:
# model.load_weights(checkpoint_path) # Load

In [None]:
# Train for up to 15 epochs, validating on the test split after each epoch.
# `fit` accepts Keras generators directly; `fit_generator` is deprecated and
# no longer exists in Keras 3.
history = model.fit(
      train_generator,
      epochs=15,
      validation_data=test_generator,
      callbacks=[lr_decay_callback, callbacks_max_acc])

In [None]:
def plot_history_metric(history_dict, metric_keys, title, ylabel):
    """Plot the train and validation curves of one metric from a Keras History dict.

    `metric_keys` lists candidate key names in order of preference. The accuracy
    metric is logged as 'accuracy' or 'acc' depending on the Keras version. The
    validation series is read from the matching 'val_' + key entry.
    """
    key = next((k for k in metric_keys if k in history_dict), None)
    if key is None:
        raise KeyError(f"none of {metric_keys} found in history: {sorted(history_dict)}")

    fig, ax = plt.subplots()
    ax.plot(history_dict[key])
    ax.plot(history_dict['val_' + key])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Test'], loc='upper left')
    plt.show()


plot_history_metric(history.history, ('accuracy', 'acc'), 'Model accuracy', 'Accuracy')
plot_history_metric(history.history, ('loss',), 'Model loss', 'Loss')

In [None]:
# File names in the prediction folder. Kept as a module-level name because
# other code in the notebook may read `all_image_paths`.
all_image_paths = os.listdir(pred_dataset_path)
print(all_image_paths[:10])  # quick look at the first few names

# Stack the loaded images into a single array for batch prediction.
pred_images = np.array(get_images(pred_dataset_path))
pred_images.shape

In [None]:
# Show randomly chosen prediction images on a 5x5 grid, each with a bar chart
# of the model's class probabilities underneath.
GRID_ROWS, GRID_COLS = 5, 5
N_PANELS = GRID_ROWS * GRID_COLS

if len(pred_images) == 0:
    raise ValueError("pred_images is empty - nothing to visualise")

# Class names ordered by label index, so bar i is labelled with class i.
class_names = sorted(train_generator.class_indices, key=train_generator.class_indices.get)

# random.randint includes both endpoints, so the upper bound is len - 1.
sample_idx = [randint(0, len(pred_images) - 1) for _ in range(N_PANELS)]
sample_images = pred_images[sample_idx]
# Apply the same 1/255 rescaling the training generators used, and predict all
# panels in one batch instead of one call per image.
sample_probs = model.predict(sample_images / 255.0)

fig = plt.figure(figsize=(30, 30))
outer = gridspec.GridSpec(GRID_ROWS, GRID_COLS, wspace=0.2, hspace=0.2)

for i in range(N_PANELS):
    inner = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[i], wspace=0.1, hspace=0.1)

    # Top cell: the image itself.
    img_ax = fig.add_subplot(inner[0])
    img_ax.imshow(sample_images[i])
    img_ax.set_xticks([])
    img_ax.set_yticks([])

    # Bottom cell: predicted probability per class.
    bar_ax = fig.add_subplot(inner[1])
    probs = sample_probs[i]
    bar_ax.bar(range(len(probs)), probs)
    bar_ax.set_ylim(0, 1)  # softmax outputs; keeps panels comparable
    bar_ax.set_xticks(range(len(probs)))
    if len(class_names) == len(probs):
        bar_ax.set_xticklabels(class_names, rotation=45)

plt.show()