In [1]:
import os
from PIL import Image
import numpy as np
import numpy as np
import tensorflow.keras
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.keras.applications import imagenet_utils
from sklearn.metrics import confusion_matrix
import itertools
import tensorflow as tf
# import matplotlib.pyplot as plt

In [2]:
# Shared configuration: square input resolution (pixels per side, used for
# both height and width) and the number of images per batch.
IMG_SIZE, batch_size = 224, 32

In [3]:
def collect(train_dir="Dataset/dataset/train", val_dir="Dataset/dataset/val", seed=42):
    """Build the training and validation image generators.

    Each directory is read with ``flow_from_directory`` (one subdirectory per
    class, ``categorical`` labels), resized to ``IMG_SIZE`` x ``IMG_SIZE`` and
    batched with the notebook-level ``batch_size``; pixels are rescaled by 1/255.

    Only the training stream is augmented (shear + horizontal flip) and
    shuffled. The validation stream is left un-augmented so validation metrics
    describe real images, and unshuffled so predictions line up with
    ``val_generator.classes`` (needed e.g. for ``confusion_matrix``).

    Args:
        train_dir: root directory of the training images.
        val_dir: root directory of the validation images.
        seed: shuffling seed for the training stream.

    Returns:
        ``(train_generator, val_generator)``.
    """
    # NOTE(review): Keras' ImageNet MobileNet weights expect inputs scaled to
    # [-1, 1] (tf.keras.applications.mobilenet.preprocess_input), while
    # rescale=1/255 yields [0, 1]. Kept unchanged because any inference code
    # presumably applies the same 1/255 scaling -- confirm before switching.
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        horizontal_flip=True,
    )
    # No augmentation for validation: random shears/flips would make the
    # validation metrics noisy and unrepresentative of real inputs.
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        directory=train_dir,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=True,
        seed=seed,
    )

    val_generator = val_datagen.flow_from_directory(
        directory=val_dir,
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=batch_size,
        class_mode="categorical",
        shuffle=False,
    )
    return train_generator, val_generator

In [4]:
def save_model(model, path='eye_status_classifier.h5'):
    """Save ``model`` to ``path`` via ``model.save``.

    The parent directory of ``path`` (if any) is created first, so targets
    such as ``'model/eye_status_classifier.h5'`` work on a fresh checkout.

    Args:
        model: object exposing ``save(path)``, e.g. a Keras model.
        path: destination file; defaults to the original hard-coded name.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    model.save(path)

In [5]:
from tensorflow.keras.models import load_model

def load_pretrained_model(path='eye_status_classifier.h5', custom_objects=None):
    """Load a saved Keras model from ``path``, print its summary and return it.

    Args:
        path: model file to load; defaults to the original hard-coded name
            (the same default ``save_model`` writes to).
        custom_objects: optional mapping of names to Python objects, forwarded
            to ``load_model``. A model compiled with this notebook's custom
            metrics needs them to be restored, e.g.
            ``{'f1_m': f1_m, 'precision_m': precision_m, 'recall_m': recall_m}``.

    Returns:
        The loaded model.
    """
    model = load_model(path, custom_objects=custom_objects)
    model.summary()
    return model

In [6]:
# NOTE: `test_generator` holds the *validation* split returned by collect()
# (Dataset/dataset/val by default); it is used as `validation_data` during
# training, not as an untouched held-out test set.
train_generator, test_generator = collect()
# Both splits must map class names to the same label indices, and the model
# head has exactly 2 outputs.
assert train_generator.class_indices == test_generator.class_indices, \
    "train/val class-index mappings differ"
assert train_generator.num_classes == 2, "model head expects exactly 2 classes"

Found 3779 images belonging to 2 classes.
Found 1067 images belonging to 2 classes.


In [None]:
# Use the tf.keras backend like the rest of the notebook. The standalone
# `keras` package is a separate install that may not match tf.keras, and
# importing it here would silently rebind the `K` imported in the first cell.
from tensorflow.keras import backend as K

# NOTE(review): these are plain function metrics evaluated per batch; Keras
# averages the per-batch values over an epoch, which only approximates the
# dataset-level precision/recall/F1 (tf.keras.metrics.Precision / Recall are
# stateful alternatives). Each sums over every element of the batch -- all
# class columns -- so the values are micro-averaged across classes. A
# prediction counts as positive when K.round maps it to 1.


def _true_positives(y_true, y_pred):
    """Count elements whose label is 1 and whose prediction rounds to 1.

    Assumes ``y_true`` holds 0/1 labels (e.g. one-hot rows).
    """
    return K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))


def recall_m(y_true, y_pred):
    """Batch recall: true positives / actual positives (epsilon-guarded)."""
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return _true_positives(y_true, y_pred) / (possible_positives + K.epsilon())


def precision_m(y_true, y_pred):
    """Batch precision: true positives / predicted positives (epsilon-guarded)."""
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return _true_positives(y_true, y_pred) / (predicted_positives + K.epsilon())


def f1_m(y_true, y_pred):
    """Batch F1: harmonic mean of ``precision_m`` and ``recall_m``."""
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))

In [7]:
EPOCHS = 50

# model.save() below (and the plot cell's savefig) write into "model/";
# create it up front so a fresh checkout does not fail at the very end.
os.makedirs("model", exist_ok=True)

# ImageNet-pretrained MobileNet backbone without its classification head.
# NOTE(review): the whole backbone is trainable (see "Trainable params" in
# the summary) and is fine-tuned from the first epoch at Adam's default
# learning rate; freezing it first (base_model.trainable = False) is a common
# alternative -- left unchanged here.
base_model = tf.keras.applications.MobileNet(
    input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False, weights='imagenet')

model = tf.keras.models.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.2),
    # softmax, not sigmoid: the two classes are mutually exclusive and
    # categorical_crossentropy on one-hot labels expects the outputs to form
    # a probability distribution.
    Dense(2, activation='softmax'),
])
model.summary()

# f1_m / precision_m / recall_m come from the metrics cell, which must be run
# before this one.
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy', f1_m, precision_m, recall_m])

# No batch_size here: the generators already yield batches of `batch_size`
# images, and Keras documents that batch_size must not be given for
# generator inputs.
history = model.fit(train_generator, epochs=EPOCHS, validation_data=test_generator)
model.save("model/eye_status_classifier.h5")

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
mobilenet_1.00_224 (Function (None, 7, 7, 1024)        3228864   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 64)                65600     
_________________________________________________________________
batch_normalization (BatchNo (None, 64)                256       
_________________________________________________________________
dropout (Dropout)            (None, 64)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 130       
Total params: 3,294,850
Trainable params: 3,272,834
Non-trainable params: 22,016
_________________________________________

In [None]:
import matplotlib.pyplot as plt

# Training curves: accuracy per epoch on the training and validation splits.
plt.style.use("ggplot")
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epoch_numbers = range(1, len(train_acc) + 1)  # 1-based epoch numbers on the x-axis

fig, ax = plt.subplots()
# Give the colour once via `color=`; combining it with a fmt string such as
# 'r' makes matplotlib warn that the colour is defined twice.
ax.plot(epoch_numbers, train_acc, color='green', label='training accuracy')
ax.plot(epoch_numbers, val_acc, label='validation accuracy')
ax.set(title='MobileNet eye-status classifier accuracy',
       xlabel='# epochs', ylabel='Accuracy')
ax.legend()

# savefig fails if the target directory does not exist.
os.makedirs("model", exist_ok=True)
fig.savefig("model/mobilenet.png")
plt.show()