In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import cv2

from keras.models import Sequential
from keras.layers.core import Dense, Flatten, Dropout, Lambda
from keras.layers.convolutional import Conv2D
from keras.layers import MaxPooling2D
from keras.optimizers import Adam
from keras.utils import to_categorical
from sklearn.utils import shuffle

from sklearn.model_selection import train_test_split

%matplotlib inline

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


# Helper functions 

In [2]:
def get_detected_traffic_lights(boxes, scores, classes, min_score=0.1, traffic_light_class=10):
    """
    Filter traffic lights from detected objects.

    Args:
        boxes (np): coordinates of objects in image, one row per detection
        scores (np): confidence factor for detected objects
        classes (np): class for each object
        min_score (float): minimum confidence (inclusive) a detection needs to be kept
        traffic_light_class (int): class id of traffic lights (10 in the COCO label map)
    Returns:
        filtered_boxes (np): coordinates of traffic lights in image
        filtered_scores (np): confidence factors for detected traffic lights
    """
    # boolean mask over detections; classes may come back as floats (10.0),
    # which still compare equal to the integer class id
    keep = (classes == traffic_light_class) & (scores >= min_score)

    filtered_boxes = boxes[keep]
    filtered_scores = scores[keep]
    return filtered_boxes, filtered_scores
        
def crop_detection(image, boxes, path, img):
    """
    Cut traffic lights out of an image and save each one as a JPEG in the folder given by path.

    Each crop is written as `<path>/<name>_<i>_cropped.jpg`, where `<name>` is the
    file name of `img` without directory and extension and `<i>` is the row index
    of the box in `boxes`.

    Args:
        image (np): RGB image with a leading batch dimension, shape (1, height, width, channels)
        boxes (np): normalized coordinates, one row [ymin, xmin, ymax, xmax] per traffic light;
            not modified
        path (str): folder for saving traffic lights; created if it does not exist
        img (str): path + name of the source image; only its file name is used
    """
    height = image.shape[1]
    width = image.shape[2]

    # Box coordinates are normalized and need to be converted into pixel
    # coordinates. Scale a copy so the caller's array is left untouched.
    scale = np.array([height, width, height, width])
    pixel_boxes = (np.asarray(boxes, dtype=float) * scale).astype(int)  # cropping needs integers

    # file name without folder and extension; os.path handles any folder depth
    # and the platform's path separator
    name = os.path.splitext(os.path.basename(img))[0]

    # cv2.imwrite does not create missing folders
    os.makedirs(path, exist_ok=True)

    # loop over every detected traffic light
    for i, (ymin, xmin, ymax, xmax) in enumerate(pixel_boxes):
        cropped = image[0, ymin:ymax, xmin:xmax]
        if cropped.size == 0:
            # degenerate box (zero width or height after conversion): nothing to save
            continue

        # image is RGB, but cv2.imwrite expects BGR channel order
        save_img = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(path, name + '_' + str(i) + '_cropped.jpg'), save_img)

# Graph functions

Frozen graph of [SSD MobileNet](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md), pretrained on COCO and provided by the [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection).

In [3]:
def load_SSD_MobileNet(path_frozen_model):
    """
    Load a frozen SSD MobileNet graph and look up its input and output tensors.

    Args:
        path_frozen_model (str): file path of the serialized (frozen) graph
    Returns:
        graph: tf.Graph the frozen model was imported into
        image_tensor: input tensor the image batch is fed to
        detection_boxes: output tensor with coordinates of detected objects
        detection_scores: output tensor with a confidence factor per detected object
        detection_classes: output tensor with the class of each detected object
    """
    graph = tf.Graph()
    with graph.as_default():
        # parse the serialized GraphDef and import it with an empty name prefix,
        # so the tensors keep their original names
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_frozen_model, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
            tf.import_graph_def(graph_def, name='')

        # input image + boxes of detected objects + how certain the classifier
        # is + classification for each box
        tensor_names = ('image_tensor:0', 'detection_boxes:0',
                        'detection_scores:0', 'detection_classes:0')
        image_tensor, detection_boxes, detection_scores, detection_classes = [
            graph.get_tensor_by_name(tensor_name) for tensor_name in tensor_names
        ]
    return graph, image_tensor, detection_boxes, detection_scores, detection_classes

In [4]:
def detect_objects_in_single_img(detection_graph, image_tensor, detection_boxes,
                                 detection_scores, detection_classes, image, sess=None):
    """
    Use the loaded and pretrained graph for object detection on one image.

    Args:
        detection_graph: graph the tensors below belong to
        image_tensor: input tensor for graph
        detection_boxes: output tensor for coordinates of detected objects
        detection_scores: output tensor for confidence factor of detected objects
        detection_classes: output tensor for detected classes
        image (np): batch holding a single image (leading batch dimension of size 1)
        sess (tf.Session, optional): open session on detection_graph to reuse.
            If None, a new session is created for this call and closed afterwards;
            passing one in avoids creating a session for every image.
    Returns:
        boxes (np): coordinates of objects in image (batch dimension removed)
        scores (np): confidence factor for detected objects
        classes (np): class for each object
    """
    fetches = [detection_boxes, detection_scores, detection_classes]
    feed = {image_tensor: image}

    if sess is None:
        with tf.Session(graph=detection_graph) as own_sess:
            boxes, scores, classes = own_sess.run(fetches, feed_dict=feed)
    else:
        boxes, scores, classes = sess.run(fetches, feed_dict=feed)

    # Remove only the batch dimension. A bare np.squeeze would also drop the
    # detection axis when exactly one detection is returned (e.g. boxes of
    # shape (1, 1, 4) would become (4,) instead of (1, 4)).
    boxes = np.squeeze(boxes, axis=0)
    scores = np.squeeze(scores, axis=0)
    classes = np.squeeze(classes, axis=0)

    return boxes, scores, classes


# Run graph and detect objects in images

In [5]:
frozen_model = 'frozen_inference_graph.pb'  # path to frozen model
# glob returns files in arbitrary (filesystem) order; sort for a reproducible run
test_images = sorted(glob.glob('images_for_detection/*.jpg'))  # paths to input images
path_cropped_images = 'images_for_detection/cropped'  # folder for saving the cropped traffic lights

In [6]:
# Load the frozen model once, then detect objects in every test image and
# save a crop of each detected traffic light.
detection_graph, image_tensor, detection_boxes, detection_scores, detection_classes = \
    load_SSD_MobileNet(frozen_model)
for img in test_images:
    image = cv2.imread(img)  # returns None if the file cannot be read
    if image is None:
        # skip unreadable files explicitly instead of failing inside cv2.cvtColor
        print('Skipping unreadable image: ' + img)
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; convert to RGB
    image = np.expand_dims(image, 0)  # extra batch dim needed by graph
    boxes, scores, classes = detect_objects_in_single_img(detection_graph, image_tensor, detection_boxes,
                                                          detection_scores, detection_classes, image)
    # keep only traffic lights, then save their crops
    boxes, scores = get_detected_traffic_lights(boxes, scores, classes)
    crop_detection(image, boxes, path_cropped_images, img)