diff --git a/iSeefood/SeefoodAI.py b/iSeefood/SeefoodAI.py index 9575573..7c4cbbf 100644 --- a/iSeefood/SeefoodAI.py +++ b/iSeefood/SeefoodAI.py @@ -20,15 +20,16 @@ import numpy as np import tensorflow as tf from PIL import Image +import os + class SeefoodAI(object): - # Single private instance + # Single private instance __instance = None - - # Initilize global variables. - global sess, class_scores, x_input, keep_prob - - # Initilize an AI object to be running + # Initilize global variables. + global sess, class_scores, x_input, keep_prob, scores + + # Initilize an AI object to be running def __init__(self): ''' Virtually private constructor ''' if SeefoodAI.__instance != None: @@ -37,25 +38,24 @@ def __init__(self): SeefoodAI.__instance = self SeefoodAI.__instance.__setup() print "+ Seefood AI instance has been created!" - - + @staticmethod def getInstance(): ''' Static access method ''' if SeefoodAI.__instance is None: SeefoodAI() return SeefoodAI.__instance - def __setup(self): ''' Setting-up the SeefoodAI instance''' - # try initializing the AI instance attrs, catch possible errors. - # TODO: Make it pretty :) - global sess, class_scores, x_input, keep_prob + # try initializing the AI instance attrs, catch possible errors. + # TODO: Make it pretty :) + global sess, class_scores, x_input, keep_prob try: - sess = tf.Session() - saver = tf.train.import_meta_graph('saved_model/model_epoch5.ckpt.meta') + sess = tf.Session() + saver = tf.train.import_meta_graph( + 'saved_model/model_epoch5.ckpt.meta') saver.restore(sess, tf.train.latest_checkpoint('saved_model/')) graph = tf.get_default_graph() x_input = graph.get_tensor_by_name('Input_xn/Placeholder:0') @@ -67,11 +67,15 @@ def __setup(self): print '++++++ [No errors occured during initialization +++++' print("+ Setting up instance ....") - + def process(self, image_path): '''TODO: Accept file path ''' - global sess, class_scores, x_input, keep_prob - # Open passed image, then convert it to RGB + + if not self.pathValidation(image_path): # Validate given path. + return -1 + + global sess, class_scores, x_input, keep_prob + # Open passed image, then convert it to RGB image = Image.open(image_path).convert('RGB') # Resize image to 227x227 image = image.resize((227, 227), Image.BILINEAR) @@ -79,26 +83,50 @@ def process(self, image_path): img_tensor = [np.asarray(image, dtype=np.float32)] print '+ Looking for food in ' + image_path + ' ...... ' - #Run the image in the model. + # Run the image in the model. scores = sess.run(class_scores, {x_input: img_tensor, keep_prob: 1.}) - - print '+ Statistics: ', scores + self.__setScores(scores) # Set score variable + print '+ Statistics: ', self.getScores() # Calculate score and display result - self.__scoresCalculation(scores) + + self.__setScores(scores) print("_________ Analyzing image.... Done! __________") def pathValidation(self, filePath): ''' TODO: Validate given file path ''' - if isinstance(filePath, basestring): + if isinstance(filePath, basestring): # Verify that instance is a string type & !empty. + if self.checkFileExtension(filePath) and self.directoryExist(filePath): # Verify path existance + print '+ Path validation ... ' + # check if the path end with .png || .jpg return True else: return False - def __scoresCalculation(self, scores): + def directoryExist(self, filePath): + ''' Verify the existance of the given path ''' + if os.path.exists(filePath): + return True + return False + + def checkFileExtension(self, filePath): + ''' Verify that the given path points to an image file (png, jpg) ''' + if filePath.endswith('.png') or filePath.endswith('.jpg'): + return True + else: + return False + + def __setScores(self, stat): ''' TODO: Optimaze and generate a final score''' + global scores + scores = stat # if np.argmax = 0; then the first class_score was higher, e.g., the model sees food. # if np.argmax = 1; then the second class_score was higher, e.g., the model does not see food. if np.argmax(scores) == 1: print "+ Result: Oops! No food here... :( " else: print "+ Results: YAY! I see food! :)" + + def getScores(self): + ''' Return last analyzed image stat. ''' + global scores + return scores