In [1]:
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [2]:
import os
import numpy as np

In [3]:
import config
import utils

In [4]:
# Dedicated graph for the frozen detection model, kept separate from TF's default
# graph; the next cell imports the serialized model into it and the inference
# cell opens a Session on it.
detection_graph = tf.Graph()

In [5]:
with detection_graph.as_default():
    # Read the serialized (frozen) GraphDef bytes; the file is closed as soon as
    # the read finishes.
    with tf.gfile.GFile(config.mask_model_infer_path, mode='rb') as graph_file:
        serialized_graph = graph_file.read()

    # Deserialize and splice the model into detection_graph. No `name` is passed,
    # so TF's default "import" scope is used -- the inference cell relies on that
    # prefix when it looks tensors up as "import/<tensor>:0".
    od_graph_def = tf.GraphDef()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def)

In [6]:
# Class lookup loaded from config.class_map_file and handed to
# utils.draw_bounding_box in the inference cell; presumably maps the model's
# numeric class ids to display labels -- TODO confirm against utils.get_class_map.
class_map = utils.get_class_map(config.class_map_file)

In [7]:
# Images to run detection on. The inference loop joins each entry with
# config.test_imgs_dir, so these are presumably bare file names rather than
# full paths -- verify against utils.get_dir_images.
img_names = utils.get_dir_images(config.test_imgs_dir)

In [8]:
def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,
                                     image_width):
    """Paste per-box masks back onto full-image canvases.

    Every mask is assumed to span exactly its own bounding box. It is resampled
    onto an ``image_height`` x ``image_width`` canvas so that it lands where
    that box sits in the image; canvas pixels outside the box are filled with 0.

    Args:
        box_masks: tf.float32 tensor of shape [num_masks, mask_height, mask_width].
        boxes: tf.float32 tensor of shape [num_masks, 4]. Row i holds the
            normalized corners [ymin, xmin, ymax, xmax] of mask i's box.
        image_height: height of the produced masks (the image height).
        image_width: width of the produced masks (the image width).

    Returns:
        tf.float32 tensor of shape [num_masks, image_height, image_width].
    """
    # NOTE(review): carries an upstream "TODO(rathodv): make public" note, so
    # this appears to be copied from another codebase.

    def _reframe_nonempty():
        """Resample all masks onto the image canvas (needs at least one mask)."""
        # crop_and_resize wants a 4-D [batch, height, width, channels] input.
        masks_4d = tf.expand_dims(box_masks, axis=3)
        mask_count = tf.shape(masks_4d)[0]

        # In its own frame each mask covers the unit square [0, 0, 1, 1].
        unit_squares = tf.concat(
            [tf.zeros([mask_count, 2]), tf.ones([mask_count, 2])], axis=1)

        # Re-express that unit square relative to each box. The resulting crop
        # window extends past the mask wherever the image extends past the box,
        # and those regions are filled with extrapolation_value.
        # NOTE(review): a zero-height/width box divides by zero here.
        corner_pairs = tf.reshape(unit_squares, [-1, 2, 2])
        box_lo = tf.expand_dims(boxes[:, 0:2], 1)
        box_hi = tf.expand_dims(boxes[:, 2:4], 1)
        crop_windows = tf.reshape(
            (corner_pairs - box_lo) / (box_hi - box_lo), [-1, 4])

        return tf.image.crop_and_resize(
            image=masks_4d,
            boxes=crop_windows,
            box_ind=tf.range(mask_count),
            crop_size=[image_height, image_width],
            extrapolation_value=0.0)

    def _no_masks():
        """Empty result with the same rank/dtype as the non-empty branch."""
        return tf.zeros([0, image_height, image_width, 1], dtype=tf.float32)

    has_masks = tf.shape(box_masks)[0] > 0
    reframed = tf.cond(has_masks, _reframe_nonempty, _no_masks)
    # Drop the channel axis added for crop_and_resize.
    return tf.squeeze(reframed, axis=3)

In [9]:
# Run the frozen detector on every test image, draw the detections and save the
# annotated image to config.result_imgs_dir.
with tf.Session(graph=detection_graph) as sess:
    # Resolve the model's input/output tensors once: they are loop-invariant.
    # Keeping them under their own names means the numpy results computed in
    # the loop never shadow the tensor handles.
    image_tensor = detection_graph.get_tensor_by_name("import/image_tensor:0")
    output_tensors = [
        detection_graph.get_tensor_by_name("import/detection_boxes:0"),
        detection_graph.get_tensor_by_name("import/detection_scores:0"),
        detection_graph.get_tensor_by_name("import/detection_classes:0"),
        detection_graph.get_tensor_by_name("import/detection_masks:0"),
        detection_graph.get_tensor_by_name("import/num_detections:0"),
    ]

    # Create the output directory up front in case utils.save_image does not
    # create missing directories itself. isdir + makedirs (rather than
    # exist_ok=True) keeps this Python 2 compatible; an empty path means the
    # current directory and needs no creation.
    if config.result_imgs_dir and not os.path.isdir(config.result_imgs_dir):
        os.makedirs(config.result_imgs_dir)

    for img_name in img_names:
        img_path = os.path.join(config.test_imgs_dir, img_name)
        # Squeezed on axis 0 below, so load_image presumably returns a batch of
        # one image, e.g. [1, height, width, channels] -- TODO confirm.
        test_img = utils.load_image(img_path)

        (boxes, scores, classes, masks,
         num_detections) = sess.run(output_tensors,
                                    feed_dict={image_tensor: test_img})

        # Strip the size-1 batch dimension from the per-image results.
        boxes = np.squeeze(boxes, axis=0)
        scores = np.squeeze(scores, axis=0)
        classes = np.squeeze(classes, axis=0)
        masks = np.squeeze(masks, axis=0)
        test_img = np.squeeze(test_img, axis=0)

        # Detections to draw are chosen by utils.get_detections (presumably those
        # scoring at least config.threshold_score); num_detections is fetched
        # but not used for that selection.
        detections = utils.get_detections(scores, config.threshold_score)

        # test_img is saved right after drawing without being reassigned, so
        # draw_bounding_box presumably annotates it in place -- TODO confirm.
        utils.draw_bounding_box(test_img, detections, boxes, classes, class_map, masks)

        save_path = os.path.join(config.result_imgs_dir, img_name)
        utils.save_image(save_path, test_img)