<a href="https://colab.research.google.com/github/SMatusik/fruit_detection_recognition/blob/main/fruit_detection_recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
'''
README
1st: put this script inside the TensorFlow Object Detection API library directory
2nd: place the inference_graph folder in this directory
3rd: install all of the libraries listed below
4th: convert the models from h5 to TensorFlow format using export_keras_model.py
5th: put the converted models in the location given in the Server class
6th: choose a function (webcam_detection or image_detection) and set the file
locations used inside it
'''

#importing needed libraries
import os
import sys
import numpy as np
import tarfile
import zipfile
import six.moves.urllib as urllib
from io import StringIO
from matplotlib import pyplot as plt
from collections import defaultdict

#importing deep learning libraries
import tensorflow as tf
from keras import models

#importing TFServing
import requests
import argparse
import json
import signal
import subprocess
import time

#importing image analysis libraries
import cv2
from PIL import Image
from keras.preprocessing import image
from IPython.display import clear_output

#import TensorFlow Object Detection API Libraries
from utils import label_map_util
from utils import visualization_utils as vis_util

# Video source setup.
# If a capture object named `cap` already exists in the module globals (for
# example because this notebook cell was run before), release it first so the
# device is not left open.
# Pass 0 to VideoCapture for the webcam, or a file path for a video file.
# NOTE(review): `file` is assigned here but not used anywhere in the visible
# code; VideoCapture below always opens device 0. Replace the 0 with `file`
# to analyse a video file instead.
file = 'file_name.avi'
if 'cap' in globals():
      cap.release() 
cap = cv2.VideoCapture(0)

# Add the parent directory to the module search path.
# NOTE(review): presumably needed so the Object Detection API packages can be
# found - confirm; the `utils` imports above already ran before this line.
sys.path.append("..")

def limit(value, max_val, min_val):
    '''
    Clamp value to the range bounded by min_val and max_val.

    The upper bound is checked first, so max_val is returned for anything
    above it and min_val for anything below min_val; values already inside
    the range are returned unchanged. Note the argument order: max, then min.
    '''
    if value > max_val:
        return max_val
    if value < min_val:
        return min_val
    return value

# Detector configuration:
#   MODEL_NAME     - directory that holds the exported detection model
#   PATH_TO_CKPT   - frozen inference graph (.pb) inside that directory
#   PATH_TO_LABELS - label map that maps class ids to display names
#   NUM_CLASSES    - upper bound on the number of label-map entries converted
#                    to categories below (the freshness code handles ids 1-6)
MODEL_NAME = 'inference_graph'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = 'training/labelmap.pbtxt'
NUM_CLASSES = 6


# Build the inference graph: read the serialized GraphDef from disk and import
# it into `detection_graph`. name='' is passed so the tensors can later be
# fetched by their plain names (e.g. 'image_tensor:0').
# NOTE(review): tf.GraphDef / tf.gfile look like the TensorFlow 1.x API
# surface - confirm against the installed TensorFlow version.
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
    
# Load the label map, convert it to a list of categories and build
# `category_index`, which the visualization helper receives when drawing boxes.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

class Server():
    '''
    Handle for one TensorFlow Serving process that serves a single model.

    __init__ - store the model name, the REST port and the models directory
    startServer - launch tensorflow_model_server for this model
    shutdownServer - terminate the server's whole process group so that no
                     process is left behind; safe to call when not running

    Pass base_path (or change its default) to point at the directory that
    contains your exported models, one sub-directory per model name.
    '''
    def __init__(self, name, port, base_path="/home/sebastian/Servers/"):
        # None means "no server process is currently running".
        self.tf_server = None
        self.name = name
        self.port = port
        self.base_path = base_path

    def startServer(self):
        '''
        Start tensorflow_model_server for this model on self.port.

        Raises FileNotFoundError (from subprocess.Popen) when the
        tensorflow_model_server executable cannot be found.
        '''
        # An argument list (no shell) keeps the model name and path from
        # being re-interpreted by a shell.
        command = [
            "tensorflow_model_server",
            "--model_base_path=" + os.path.join(self.base_path, self.name),
            "--rest_api_port=" + str(self.port),
            "--model_name=" + self.name,
        ]
        try:
            # start_new_session puts the server in its own session/process
            # group, which is what shutdownServer signals.
            self.tf_server = subprocess.Popen(command,
                                              stdout=subprocess.DEVNULL,
                                              start_new_session=True)
            print("Started TensorFlow Serving " + self.name + " server")
        except KeyboardInterrupt:
            print("Exception! Shutting down " + self.name + " servers")
            self.shutdownServer()

    def shutdownServer(self):
        '''
        Send SIGTERM to the server's process group.

        Does nothing (apart from a message) when the server was never
        started or has already been shut down.
        '''
        if self.tf_server is None:
            print("Server " + self.name + " is not running")
            return
        try:
            os.killpg(os.getpgid(self.tf_server.pid), signal.SIGTERM)
        except ProcessLookupError:
            # The process group is already gone; nothing left to terminate.
            pass
        self.tf_server = None
        print("Server " + self.name + " successfully shutdown")
    


class boxPrediction:
    '''
    Bounding box of one detected fruit together with its freshness score.

    The box is enlarged before cropping so that the whole fruit is passed to
    the classification network.
    __init__ - store the box coordinates; prediction starts at 0
    prepareBox - convert a relative box to pixel coordinates, enlarge it by a
                 margin, clip it to the image and keep the result in the
                 object properties
    '''
    def __init__(self, xmin, xmax, ymin, ymax):
        self.xmin, self.xmax = xmin, xmax
        self.ymin, self.ymax = ymin, ymax
        self.prediction = 0

    def prepareBox(self, bbox, compenser, im_width, im_height):
        # bbox is unpacked as (ymin, xmin, ymax, xmax); each value is scaled
        # by the image size, so it is expected to be a fraction of the image.
        rel_top, rel_left, rel_bottom, rel_right = bbox

        # Scale to pixels, then push every edge outwards by `compenser`.
        top = int(int(im_height * rel_top) - compenser)
        bottom = int(int(im_height * rel_bottom) + compenser)
        left = int(int(im_width * rel_left) - compenser)
        right = int(int(im_width * rel_right) + compenser)

        # Clip the enlarged box so it stays inside the image.
        self.xmin = limit(left, im_width, 0)
        self.xmax = limit(right, im_width, 0)
        self.ymin = limit(top, im_height, 0)
        self.ymax = limit(bottom, im_height, 0)
        self.prediction = 0

def prepareImage(image1, reshape):
    '''
    Convert an OpenCV BGR image into an array scaled by 1/255.

    The image is converted from BGR to RGB, passed through PIL and turned
    into an array with keras' img_to_array. When reshape is truthy a leading
    axis of length 1 is added before scaling.
    '''
    rgb = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
    pixels = image.img_to_array(Image.fromarray(rgb))
    if reshape:
        pixels = pixels.reshape((1,) + pixels.shape)
    return pixels / 255

def predict_class(prediction):
    '''
    Turn a classifier score into a label: strictly above 0.5 is "rotten",
    anything else is "fresh".
    '''
    return "rotten" if prediction > 0.5 else "fresh"
    
def visualize_results(img, apples, bananas, oranges, pears, peppers, tomatoes):
    '''
    Write the freshness label and score of every classified fruit onto img
    and show the result in the 'detection' window.

    Each fruit list holds objects with xmin, ymin and prediction attributes;
    the text is placed 20 px right of and 30 px below the box's top-left
    corner.
    '''
    # One text colour per fruit kind, in the order of the parameters.
    # Peppers and tomatoes share the same colour.
    text_colors = (
        (apples, (255, 255, 0)),
        (bananas, (255, 0, 0)),
        (oranges, (19, 69, 139)),
        (pears, (0, 0, 255)),
        (peppers, (0, 255, 255)),
        (tomatoes, (0, 255, 255)),
    )
    for fruits, color in text_colors:
        for fruit in fruits:
            caption = predict_class(fruit.prediction) + " " + str(fruit.prediction)
            anchor = (fruit.xmin+20, fruit.ymin+30)
            cv2.putText(img, caption, anchor,
                        cv2.FONT_HERSHEY_DUPLEX, 0.6, color)

    cv2.imshow('detection', img)
                    
                    
# Detector class id -> (TF Serving model name, REST port, elongated crop?).
# Names and ports must match the Server objects started by this script.
# "Elongated" fruit crops are brought to a 200x150 landscape image (rotating
# portrait crops); the others are resized to a 200x200 square.
_FRESHNESS_MODELS = {
    1.0: ("apple", 9000, False),
    2.0: ("banana", 9001, True),
    3.0: ("orange", 9002, False),
    4.0: ("pear", 9003, True),
    5.0: ("pepper", 9004, True),
    6.0: ("tomato", 9005, False),
}


def _resize_crop_for_classifier(crop, elongated):
    # Resize a fruit crop to the size sent to its freshness model.
    if not elongated:
        return cv2.resize(crop, (200, 200))
    height, width, _ = crop.shape
    if height > width:
        # Portrait crop: resize upright, then rotate it into landscape.
        crop = cv2.resize(crop, (150, 200))
        return cv2.rotate(crop, cv2.ROTATE_90_CLOCKWISE)
    return cv2.resize(crop, (200, 150))


def _request_freshness(model_name, port, img):
    # Ask the local TF Serving REST endpoint for the score of one image.
    payload = {
        "instances": [{'input_image': img.tolist()}]
        }
    url = 'http://localhost:{}/v1/models/{}:predict'.format(port, model_name)
    response = requests.post(url, json=payload)
    # Fail with the HTTP status instead of a KeyError on 'predictions'.
    response.raise_for_status()
    pred = json.loads(response.content.decode('utf-8'))
    return round(pred['predictions'][0][0], 3)


def freshness_recognition(boxes, classes, scores, image_np_copy, compenser):
    '''
    Classify every confidently detected fruit as fresh or rotten.

    Detections scoring above 0.65 are cropped from an 800x600 copy of the
    frame (the box is enlarged by `compenser` pixels on every side), resized
    for their model and sent to that fruit's TensorFlow Serving endpoint.
    Class ids: 1 - apple, 2 - banana, 3 - orange, 4 - pear, 5 - pepper,
    6 - tomato; other ids are ignored.

    Parameters:
        boxes, classes, scores - detector outputs as numpy arrays; boxes are
            unpacked as (ymin, xmin, ymax, xmax) fractions of the image.
            Assumes the shapes (1, N, 4), (1, N), (1, N) returned by the
            detection graph - TODO confirm.
        image_np_copy - BGR frame without any drawings on it
        compenser - margin in pixels added around every box before cropping

    Returns:
        (apples, bananas, oranges, pears, peppers, tomatoes) - six lists of
        boxPrediction objects whose `prediction` holds the score rounded to
        three decimals.

    Raises:
        requests.HTTPError when a model server answers with an error status.
    '''
    min_score_thresh = 0.65
    confident = scores > min_score_thresh
    bboxes = boxes[confident]
    bclasses = classes[confident]

    found = {name: [] for name, _, _ in _FRESHNESS_MODELS.values()}

    # Only touch the image when there is something to classify.
    if bclasses.size > 0:
        im_width, im_height = (800, 600)
        image_np_new = cv2.resize(image_np_copy, (im_width, im_height))

        for bbox, class_id in zip(bboxes, bclasses):
            model = _FRESHNESS_MODELS.get(float(class_id))
            if model is None:
                continue
            model_name, port, elongated = model

            obiekt = boxPrediction(0, 0, 0, 0)
            obiekt.prepareBox(bbox, compenser, im_width, im_height)

            image_cropped = image_np_new[obiekt.ymin:obiekt.ymax,
                                         obiekt.xmin:obiekt.xmax]
            if image_cropped.size == 0:
                # Degenerate box: nothing to resize or classify.
                continue

            image_cropped = _resize_crop_for_classifier(image_cropped, elongated)
            img = prepareImage(image_cropped, 0)

            obiekt.prediction = _request_freshness(model_name, port, img)
            found[model_name].append(obiekt)

    return (found["apple"], found["banana"], found["orange"],
            found["pear"], found["pepper"], found["tomato"])



def webcam_detection():
    '''
    Run fruit detection and freshness recognition on the global video
    capture `cap`, frame by frame, until the stream ends or 'q' is pressed.

    Every frame is resized to 800x600, passed through the detection graph,
    annotated with boxes and freshness labels and shown in a window. The
    windows are closed and the capture is released when the loop ends, also
    when an error interrupts it.
    '''
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # The graph tensors do not change between frames: look them up once.
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            fetches = [
                detection_graph.get_tensor_by_name('detection_boxes:0'),
                detection_graph.get_tensor_by_name('detection_scores:0'),
                detection_graph.get_tensor_by_name('detection_classes:0'),
            ]

            try:
                while cap.isOpened():
                    ret, image_np = cap.read()
                    if not ret:
                        # No frame could be read (e.g. end of the stream).
                        break
                    image_np = cv2.resize(image_np, (800, 600), interpolation = cv2.INTER_AREA)
                    # The frame is converted from BGR to RGB before being
                    # fed to the detector.
                    image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
                    image_np_expanded = np.expand_dims(image_rgb, axis=0)

                    boxes, scores, classes = sess.run(
                        fetches,
                        feed_dict={image_tensor: image_np_expanded})

                    # Keep an undecorated copy for cropping; the boxes are
                    # drawn onto image_np itself.
                    image_np_copy = image_np.copy()

                    # NOTE(review): boxes are drawn from a score of 0.6 while
                    # freshness_recognition only classifies above 0.65, so
                    # some drawn boxes get no freshness label - confirm
                    # whether that is intended.
                    vis_util.visualize_boxes_and_labels_on_image_array(
                        image_np,
                        np.squeeze(boxes),
                        np.squeeze(classes).astype(np.int32),
                        np.squeeze(scores),
                        category_index,
                        use_normalized_coordinates=True,
                        line_thickness=8,
                        min_score_thresh=0.6)

                    apples, bananas, oranges, pears, peppers, tomatoes = freshness_recognition(
                        boxes, classes, scores, image_np_copy, 30)

                    visualize_results(image_np, apples, bananas, oranges, pears, peppers, tomatoes)

                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
            finally:
                # Always free the window and the capture device.
                cv2.destroyAllWindows()
                cap.release()
       

def image_detection(path_to_image="evaluate_final/img_multiple001.jpg"):
    '''
    Run fruit detection and freshness recognition on a single image file and
    show the annotated result until 'q' is pressed.

    Parameters:
        path_to_image - image file to analyse; pass your own path or change
                        the default

    Raises:
        FileNotFoundError when the image cannot be read.
    '''
    image_np = cv2.imread(path_to_image)
    if image_np is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError("Could not read image: " + str(path_to_image))

    image_np = cv2.resize(image_np, (800, 600), interpolation = cv2.INTER_AREA)
    # The image is converted from BGR to RGB before being fed to the detector.
    image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    image_np_expanded = np.expand_dims(image_rgb, axis=0)

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            fetches = [
                detection_graph.get_tensor_by_name('detection_boxes:0'),
                detection_graph.get_tensor_by_name('detection_scores:0'),
                detection_graph.get_tensor_by_name('detection_classes:0'),
            ]
            boxes, scores, classes = sess.run(
                fetches,
                feed_dict={image_tensor: image_np_expanded})

    # Keep an undecorated copy for cropping; boxes are drawn onto image_np.
    image_np_copy = image_np.copy()

    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8,
        min_score_thresh=0.65)

    apples, bananas, oranges, pears, peppers, tomatoes = freshness_recognition(
        boxes, classes, scores, image_np_copy, 30)

    try:
        # Keep showing the result until the user presses 'q'.
        while True:
            visualize_results(image_np, apples, bananas, oranges, pears, peppers, tomatoes)

            if cv2.waitKey(25) & 0xFF == ord('q'):
                break
    finally:
        cv2.destroyAllWindows()
                


# One TensorFlow Serving instance per fruit model. The names and ports have
# to match the URLs requested in freshness_recognition.
server_apple = Server("apple", 9000)
server_banana = Server("banana", 9001)
server_orange = Server("orange", 9002)
server_pear = Server("pear", 9003)
server_pepper = Server("pepper", 9004)
server_tomato = Server("tomato", 9005)

servers = [server_apple, server_banana, server_orange,
           server_pear, server_pepper, server_tomato]

# Servers are recorded as they start so that exactly those are shut down
# again, even when starting one of them or the detection itself fails.
started_servers = []
try:
    for server in servers:
        server.startServer()
        started_servers.append(server)

    # choose what function you need
    #image_detection()
    webcam_detection()
finally:
    for server in started_servers:
        server.shutdownServer()



Started TensorFlow Serving apple server
Started TensorFlow Serving banana server
Started TensorFlow Serving orange server
Started TensorFlow Serving pear server
Started TensorFlow Serving pepper server
Started TensorFlow Serving tomato server
Server apple successfully shutdown
Server banana successfully shutdown
Server orange successfully shutdown
Server pear successfully shutdown
Server pepper successfully shutdown
Server tomato successfully shutdown
