In [None]:
import tempfile
import time
import zlib

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image, ImageColor, ImageDraw, ImageFont, ImageOps
from six import BytesIO
from six.moves.urllib.request import urlopen

In [None]:
# TF Hub handle of the detection module to load; per the URL this is the
# Open Images V4 SSD / MobileNet V2 model, version 1.
module_handle = "https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1"

# Load the module from the handle above (requires network access or a
# populated TF Hub cache on first use).
model = hub.load(module_handle)

In [None]:
# Show the names of the signatures exposed by the loaded module.
model.signatures.keys()

In [None]:
# Select the module's 'default' signature; it is called as `detector(images)`
# further down in the notebook.
detector = model.signatures['default']

In [None]:
def display_image(image):
    """Show `image` on a new 20x15-inch matplotlib figure with the grid off."""
    plt.figure(figsize=(20, 15))
    # plt.gca() yields the current axes of the figure just created.
    current_axes = plt.gca()
    current_axes.grid(False)
    plt.imshow(image)

def download_and_resize_image(url, new_width=256, new_height=256, display=False):
    """Download an image, fit it to the requested size and save it as a JPEG.

    Args:
        url: Address of the image to fetch.
        new_width: Width in pixels of the saved image.
        new_height: Height in pixels of the saved image.
        display: When true, show the saved (RGB) image via `display_image`.

    Returns:
        Path of a newly created temporary ``.jpg`` file. The file is NOT
        removed automatically; the caller owns it.
    """
    # Read the whole payload, then always release the connection.
    response = urlopen(url)
    try:
        image_data = BytesIO(response.read())
    finally:
        response.close()

    pil_image = Image.open(image_data)
    # ImageOps.fit resizes and crops to exactly (new_width, new_height).
    pil_image = ImageOps.fit(pil_image, (new_width, new_height), Image.LANCZOS)
    # JPEG output below is written from an RGB image.
    pil_image_rgb = pil_image.convert("RGB")

    # Create the temp file only once there is something to write, and let the
    # context manager close the handle (mkstemp would hand back a raw file
    # descriptor that must be closed explicitly). delete=False keeps the file
    # on disk after the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
        pil_image_rgb.save(tmp_file, format="JPEG", quality=90)
        filename = tmp_file.name

    print("Image downloaded to %s." % filename)

    if display:
        # Show exactly what was written to disk (the RGB-converted image).
        display_image(pil_image_rgb)

    return filename

In [None]:
# Image to run detection on.
image_url = "https://m.media-amazon.com/images/M/MV5BZjk1ZjliYTgtZWI0Yi00NmJiLTg5NmEtMDQ0ODM1OGE2ZDIwXkEyXkFqcGdeQXRyYW5zY29kZS13b3JrZmxvdw@@._V1_.jpg"
# Alternative test images (assign one of these to `image_url` to try it):
#   https://upload.wikimedia.org/wikipedia/commons/6/60/Naxos_Taverna.jpg
#   https://www.verywellhealth.com/thmb/yV6RJjwnJs8lqaOUSx-dWnQvFfU=/1500x0/filters:no_upscale():max_bytes(150000):strip_icc()/fruits-to-avoid-if-you-have-diabetes-1087587-primary-recirc-3a95a09a48cb46b49d5632326f9405d3.jpg

# Download, fit to 1280x856, display it, and keep the temp-file path for the detector.
downloaded_image_path = download_and_resize_image(image_url, 1280, 856, True)

In [None]:
def _text_size(font, text):
    """Return (width, height) of `text` for `font` using Pillow's public API."""
    if hasattr(font, "getbbox"):
        # getbbox returns (left, top, right, bottom) relative to the drawing
        # origin; right/bottom give the extent measured from that origin.
        bbox = font.getbbox(text)
        return bbox[2], bbox[3]
    # Older Pillow releases without getbbox: public getsize returns (w, h).
    return font.getsize(text)


def draw_bounding_box_on_image(image, y_min, x_min, y_max, x_max,
                               color, font, thickness=4, display_str_list=()):
    """Draw one bounding box with stacked text labels onto a PIL image.

    Args:
        image: PIL image to draw on (modified in place).
        y_min, x_min, y_max, x_max: Box edges as fractions of the image
            height/width; they are multiplied by the image size below.
        color: Fill colour for the box outline and label background.
        font: Pillow font object used to measure and render the labels.
        thickness: Width of the box outline.
        display_str_list: Label strings; the last entry is drawn lowest.
    """
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size

    # Scale the fractional box coordinates to pixel coordinates.
    (left, right, top, bottom) = (x_min * im_width, x_max * im_width,
                                  y_min * im_height, y_max * im_height)

    # Outline of the detection box, drawn as one closed polyline.
    draw.line([(left, top), (left, bottom),
               (right, bottom), (right, top), (left, top)],
              width=thickness,
              fill=color)

    # Measure each label once. Going through the private `font.font` object is
    # avoided because its `getsize` return shape is not the same for every
    # font type.
    text_sizes = [_text_size(font, ds) for ds in display_str_list]
    # Each label occupies its text height plus a 5% margin above and below.
    total_display_str_height = (1 + 2 * 0.05) * sum(h for _, h in text_sizes)

    # If the label stack fits above the box, its bottom edge sits on the box's
    # top edge. Otherwise the stack is shifted down by its own height so it
    # starts at the box's top edge and extends downward.
    if top > total_display_str_height:
        text_bottom = top
    else:
        text_bottom = top + total_display_str_height

    # Walk the labels in reverse so they are drawn from bottom to top.
    for display_str, (text_width, text_height) in zip(display_str_list[::-1],
                                                      text_sizes[::-1]):
        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                        (left + text_width, text_bottom)],
                       fill=color)
        draw.text((left + margin, text_bottom - text_height - margin),
                  display_str,
                  fill="black",
                  font=font)
        # Move up by the full height of the rectangle just drawn (text plus
        # both margins) so consecutive labels do not overlap.
        text_bottom -= text_height + 2 * margin

In [None]:
def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):
    """Overlay labelled detection boxes on an image array.

    Args:
        image: NumPy image array. It is modified in place when at least one
            box is drawn, and is also returned.
        boxes: Array whose rows unpack as (ymin, xmin, ymax, xmax); the values
            are handed unchanged to `draw_bounding_box_on_image`.
        class_names: Per-box class names as byte strings (they are decoded as
            ASCII for the label).
        scores: Per-box scores; shown in the label as an integer percentage.
        max_boxes: Only the first `max_boxes` rows of `boxes` are considered.
        min_score: Boxes scoring below this value are skipped.

    Returns:
        The `image` array that was passed in.
    """
    colors = list(ImageColor.colormap.values())

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf",
                                  25)
    except IOError:
        print("Font not found, using default font.")
        font = ImageFont.load_default()

    # The PIL copy is created lazily on the first accepted box and reused for
    # all later boxes, instead of converting the whole image once per box.
    image_pil = None

    for i in range(min(boxes.shape[0], max_boxes)):
        # Only display detection boxes that have the minimum score or higher.
        if scores[i] < min_score:
            continue

        if image_pil is None:
            image_pil = Image.fromarray(np.uint8(image)).convert("RGB")

        ymin, xmin, ymax, xmax = tuple(boxes[i])
        display_str = "{}: {}%".format(class_names[i].decode("ascii"),
                                       int(100 * scores[i]))
        # crc32 is stable across interpreter runs, unlike the built-in hash()
        # of bytes, so a given class keeps the same colour after a restart.
        color = colors[zlib.crc32(class_names[i]) % len(colors)]

        # Draw one bounding box and overlay its class label onto the image.
        draw_bounding_box_on_image(image_pil,
                                   ymin,
                                   xmin,
                                   ymax,
                                   xmax,
                                   color,
                                   font,
                                   display_str_list=[display_str])

    # Write the annotated pixels back into the caller's array once.
    if image_pil is not None:
        np.copyto(image, np.array(image_pil))

    return image

In [None]:
def load_img(path):
    """Read the file at `path` and decode it as a 3-channel JPEG tensor."""
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)

In [None]:
def run_detector(detector, path):
    """Run `detector` on the image at `path`, report timing, and show the boxes.

    Args:
        detector: Callable applied to the batched float32 image; its result is
            read as a mapping of names to tensors.
        path: Path of the JPEG file to load with `load_img`.
    """
    image_tensor = load_img(path)

    # Convert to float32 and prepend a batch axis before calling the detector.
    as_float = tf.image.convert_image_dtype(image_tensor, tf.float32)
    batched_input = as_float[tf.newaxis, ...]

    # Time only the detector call itself.
    started_at = time.time()
    raw_outputs = detector(batched_input)
    finished_at = time.time()

    # Turn every output tensor into a NumPy array.
    outputs = {}
    for name, tensor in raw_outputs.items():
        outputs[name] = tensor.numpy()

    print("Found %d objects." % len(outputs["detection_scores"]))
    print("Inference time: ", finished_at - started_at)

    # Draw on a NumPy copy of the decoded (pre-conversion) image.
    annotated = draw_boxes(image_tensor.numpy(),
                           outputs["detection_boxes"],
                           outputs["detection_class_entities"],
                           outputs["detection_scores"])

    display_image(annotated)

In [None]:
# Run detection on the image downloaded earlier and display the annotated result.
run_detector(detector, downloaded_image_path)