In [None]:
%matplotlib inline

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import os
from PIL import Image
import tensorflow as tf
import math
from skimage.transform import resize

In [None]:
!bazel build -c opt //im2txt/...

In [None]:
import sys
import collections

In [None]:
sys.path.insert(0, 'bazel-bin/im2txt/run_inference.runfiles/im2txt/')

In [None]:
from __future__ import absolute_import
from im2txt import configuration
from im2txt.inference_utils import caption_generator
from im2txt.inference_utils import vocabulary
from im2txt import show_and_tell_model
from im2txt.inference_utils import inference_wrapper_base

In [None]:
# Set TensorFlow's logging threshold to INFO for the rest of the notebook.
tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
# Checkpoint handed to `build_graph_from_config` when the inference graph is
# built. NOTE(review): machine-specific absolute path -- adjust before running
# on another machine.
checkpoint_path = '/Users/aghasy/tmp/im2sem/model/3/model.ckpt-1000000'

In [None]:
class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase):
    """Model wrapper class for performing inference with a ShowAndTellModel.

    Besides the LSTM initial state, `feed_image` can evaluate additional
    tensors in the same `sess.run` call and hand the results to a callback.

    Attributes:
        fetches: dict mapping a result key to a tensor (or tensor name) that
            should be evaluated together with the initial state.
        post_feed_image: callable invoked with the dict of fetched values
            after every `feed_image` call; its return value is ignored.
    """

    def __init__(self):
        super(InferenceWrapper, self).__init__()
        self.fetches = {}
        # The hook is called with one argument (the fetched-values dict), so
        # the default must accept it; by default it does nothing.
        self.post_feed_image = lambda result: result

    def build_model(self, model_config):
        """Builds a ShowAndTellModel in "inference" mode and returns it."""
        model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference")
        model.build()
        return model

    def feed_image(self, sess, encoded_image):
        """Feeds an encoded image and returns the fetched LSTM initial state.

        Args:
            sess: TensorFlow session used to run the graph.
            encoded_image: value fed into the "image_feed:0" placeholder.

        Returns:
            The evaluated 'lstm/initial_state:0' tensor.
        """
        # Work on a copy so the user-supplied `fetches` dict is not mutated.
        fetches = dict(self.fetches)
        fetches['initial_state'] = tf.get_default_graph().get_tensor_by_name(
            'lstm/initial_state:0')
        result = sess.run(fetches=fetches,
                          feed_dict={"image_feed:0": encoded_image})
        if self.post_feed_image is not None:
            self.post_feed_image(result)
        return result['initial_state']

    def inference_step(self, sess, input_feed, state_feed):
        """Runs one decoding step.

        Args:
            sess: TensorFlow session used to run the graph.
            input_feed: value fed into "input_feed:0".
            state_feed: value fed into "lstm/state_feed:0".

        Returns:
            A 3-tuple (softmax_output, state_output, None); the third slot is
            always None here.
        """
        softmax_output, state_output = sess.run(
            fetches=["softmax:0", "lstm/state:0"],
            feed_dict={
                "input_feed:0": input_feed,
                "lstm/state_feed:0": state_feed,
            })
        return softmax_output, state_output, None


In [None]:
# Build the inference graph from a default ModelConfig. `restore_fn` is the
# callable that later cells invoke with a session before running inference.
model = InferenceWrapper()
config = configuration.ModelConfig()
restore_fn = model.build_graph_from_config(config, checkpoint_path)

In [None]:
def test_uninitialized(sess):
    """Assert that every global variable is initialized in `sess`.

    Prints the names of any uninitialized variables before failing.

    Args:
        sess: TensorFlow session in which the variables are checked.

    Raises:
        AssertionError: if at least one global variable is not initialized.
    """
    global_vars = tf.global_variables()
    # One boolean per variable: True when the variable IS initialized.
    is_initialized = sess.run(
        [tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [
        var for (var, flag) in zip(global_vars, is_initialized) if not flag]
    if not_initialized_vars:
        names = [str(var.name) for var in not_initialized_vars]
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('Not initalized varibles are:  %s' % names)
        # from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
        # print_tensors_in_checkpoint_file(checkpoint_path, all_tensors=True, tensor_name='')
    assert not not_initialized_vars, (
        '%d global variable(s) are not initialized' % len(not_initialized_vars))

In [None]:
# Smoke check: call restore_fn on a fresh session, then assert that no global
# variable is left uninitialized.
with tf.Session() as sess:
    restore_fn(sess)
    test_uninitialized(sess)

In [None]:
sys.path.insert(0, '..')
sys.path.insert(0, '../slim/')

In [None]:
from object_detection.utils import ops as utils_ops

In [None]:
!cd ..; protoc object_detection/protos/*.proto --python_out=.

In [None]:
sys.path.insert(0, '../object_detection')
from utils import label_map_util
from utils import visualization_utils as vis_util

In [None]:
# Label-map file and the `max_num_classes` value used when converting the label
# map to categories.
PATH_TO_LABELS = os.path.abspath('../object_detection/data/mscoco_label_map.pbtxt')
NUM_CLASSES = 90

In [None]:
# Load the label map and build `category_index`, which the drawing code passes
# to the visualization helper to label detected boxes.
label_map  = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

In [None]:
# Directory with the images to caption.
# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
TEST_IMAGES_DIR = '/Users/aghasy/Desktop/tmp/test_images/'
# os.listdir returns entries in arbitrary order; sort so that the list (and any
# slicing applied to it later) is deterministic across runs and file systems.
TEST_IMAGE_PATHS = [os.path.join(TEST_IMAGES_DIR, x) for x in sorted(os.listdir(TEST_IMAGES_DIR))]
# Figure size (width, height) passed to plt.figure(figsize=...).
IMAGE_SIZE = (12, 8)

In [None]:
# Handle to the graph's 'image_feed:0' tensor, i.e. the tensor name that
# InferenceWrapper.feed_image feeds the encoded image into.
image_tensor = tf.get_default_graph().get_tensor_by_name('image_feed:0')

In [None]:
def get_detection_tensors():
    """Look up the detection-related tensors present in the default graph.

    Returns:
        dict mapping each wanted key (the tensor name without its ':0'
        suffix) to the corresponding graph tensor. Keys whose tensor does not
        exist in the graph are left out.
    """
    graph = tf.get_default_graph()

    # Names of every output tensor produced by any op in the graph.
    available_names = set()
    for operation in graph.get_operations():
        for out_tensor in operation.outputs:
            available_names.add(out_tensor.name)

    wanted_keys = (
        'num_detections', 'detection_boxes', 'detection_scores',
        'detection_classes', 'detector_image', 'decode/DecodeJpeg',
    )
    return {
        key: graph.get_tensor_by_name(key + ':0')
        for key in wanted_keys
        if key + ':0' in available_names
    }

In [None]:
def draw_detection_result(output_dict):
    """Draw detection boxes/labels on the decoded image and show it.

    Opens a new matplotlib figure of size IMAGE_SIZE and displays the
    annotated image in it.

    Args:
        output_dict: dict of fetched values; must contain the keys
            'decode/DecodeJpeg', 'detection_boxes', 'detection_classes' and
            'detection_scores'. The detection entries are indexed with [0]
            (presumably a leading batch dimension of size 1 -- TODO confirm).

    Raises:
        KeyError: if one of the required keys is missing from `output_dict`.
    """
    # Convert to uint8 and keep only the first three channels.
    image_np = np.array(output_dict['decode/DecodeJpeg']).astype(np.uint8)[:, :, :3]

    # Scale the box outline with the image size. Floor division keeps the
    # thickness an integer on both Python 2 and 3, and the clamp avoids a
    # zero-width outline for images smaller than 100 pixels on a side.
    line_thickness = max(1, min(image_np.shape[:-1]) // 100)

    # Visualization of the results of a detection (draws onto image_np).
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        output_dict['detection_boxes'][0],
        output_dict['detection_classes'][0].astype(np.uint8),
        output_dict['detection_scores'][0],
        category_index,
        instance_masks=None,
        use_normalized_coordinates=True,
        line_thickness=line_thickness)
    plt.figure(figsize=IMAGE_SIZE)
    plt.imshow(image_np)

In [None]:
# Make feed_image evaluate the detection tensors alongside the LSTM initial
# state, and pass the fetched values to draw_detection_result after each image.
model.fetches = get_detection_tensors()
model.post_feed_image = draw_detection_result

In [None]:
# Vocabulary file and the caption generator built on top of the wrapped model.
# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
vocab_file = '/Users/aghasy/tmp/im2sem/model/word_counts.txt'
vocab = vocabulary.Vocabulary(vocab_file)
generator = caption_generator.CaptionGenerator(model, vocab)

In [None]:
# Caption every test image: beam search triggers feed_image (which also draws
# the detection figure through the post_feed_image hook), then the captions are
# written into that figure's title.
with tf.Session() as sess:
    restore_fn(sess)
    # NOTE(review): the first entry of TEST_IMAGE_PATHS is skipped -- presumably
    # a non-image file; confirm against the directory contents.
    for image_path in TEST_IMAGE_PATHS[1:]:
        # Image files are binary data; read them as bytes so the content is
        # not altered by text-mode newline translation or decoding.
        with open(image_path, 'rb') as content_file:
            content = content_file.read()
        captions = generator.beam_search(sess, content)
        title = "Captions for image %s:" % os.path.basename(image_path)
        for i, caption in enumerate(captions):
            # Ignore begin and end words.
            sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
            sentence = " ".join(sentence)
            title += "\n  %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob))
        plt.title(title)
        plt.draw()
        plt.pause(0.001)