"""Predict captions on test images using trained model, with greedy sample method"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from datetime import datetime
import configuration
from ShowAndTellModel import build_model
from coco_utils import load_coco_data, sample_coco_minibatch, decode_captions
from image_utils import image_from_url, write_text_on_image
import numpy as np
from scipy.misc import imread
import pandas as pd
import os
from six.moves import urllib
import sys
import tarfile
import json
import argparse
model_config = configuration.ModelConfig()
training_config = configuration.TrainingConfig()
FLAGS = None
verbose = True
mode = 'inference'
pretrain_model_name = 'classify_image_graph_def.pb'
layer_to_extract = 'pool_3:0'
MODEL_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
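# Note: the Inception (2015-12-05) graph downloaded below is used purely as a
# fixed feature extractor. 'pool_3:0' is its next-to-last layer and produces
# one 2048-float description per image, which the caption model takes as input.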
def maybe_download_and_extract():
    """Download and extract model tar file."""
    dest_directory = FLAGS.pretrain_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = MODEL_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(MODEL_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def create_graph():
    """Creates a graph from the saved GraphDef file."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(os.path.join(
            FLAGS.pretrain_dir, pretrain_model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
def extract_features(image_dir):
    """Run every image in image_dir through the pretrained network and
    return the extracted feature array plus the image-name DataFrame."""
    if not os.path.exists(image_dir):
        sys.exit("image_dir does not exist!")
    maybe_download_and_extract()
    create_graph()
    with tf.Session() as sess:
        # Some useful tensors:
        # 'softmax:0': A tensor containing the normalized prediction across
        #   1000 labels.
        # 'pool_3:0': A tensor containing the next-to-last layer with a
        #   2048-float description of the image.
        # 'DecodeJpeg/contents:0': A tensor containing a string providing the
        #   JPEG encoding of the image.
        # Runs the pool_3 tensor by feeding image_data as input to the graph.
        final_array = []
        extract_tensor = sess.graph.get_tensor_by_name(layer_to_extract)
        all_image_names = os.listdir(image_dir)
        print("There are a total of " + str(len(all_image_names)) + " images to process.")
        all_image_names = pd.DataFrame({'file_name': all_image_names})
        for img in all_image_names['file_name'].values:
            temp_path = os.path.join(image_dir, img)
            image_data = tf.gfile.FastGFile(temp_path, 'rb').read()
            predictions = sess.run(extract_tensor, {'DecodeJpeg/contents:0': image_data})
            final_array.append(np.squeeze(predictions))
    final_array = np.array(final_array)
    return final_array, all_image_names
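# Greedy decoding: starting from the <START> token, each step feeds the
# previous word prediction (model['preds']) back in as the next input while
# carrying the LSTM state forward, for up to padded_length steps.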
def step_inference(sess, features, model, keep_prob):
    """Generate captions one word per step, feeding each prediction back in."""
    batch_size = features.shape[0]
    captions_in = np.ones((batch_size, 1))  # <START> token index is one
    final_preds = []
    current_pred = captions_in
    mask = np.zeros((batch_size, model_config.padded_length))
    mask[:, 0] = 1
    # get the initial LSTM state from the image feature
    feed_dict = {model['image_feature']: features,
                 model['keep_prob']: keep_prob}
    state = sess.run(model['initial_state'], feed_dict=feed_dict)
    # start to generate sentences
    for t in range(model_config.padded_length):
        feed_dict = {model['input_seqs']: current_pred,
                     model['initial_state']: state,
                     model['input_mask']: mask,
                     model['keep_prob']: keep_prob}
        current_pred, state = sess.run([model['preds'], model['final_state']], feed_dict=feed_dict)
        current_pred = current_pred.reshape(-1, 1)
        final_preds.append(current_pred)
    return final_preds
def main(_):
    # load the vocabulary dictionary
    with open(FLAGS.dict_file, 'r') as f:
        data = json.load(f)
    data['idx_to_word'] = {int(k): v for k, v in data['idx_to_word'].items()}
    # extract features for all test images
    features, all_image_names = extract_features(FLAGS.test_dir)
    # Build the TensorFlow graph and run inference
    g = tf.Graph()
    with g.as_default():
        num_of_images = len(os.listdir(FLAGS.test_dir))
        print("Running inference on {} images".format(num_of_images))
        # Build the model.
        model = build_model(model_config, mode, inference_batch=num_of_images)
        # restore the trained weights and run prediction
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            model['saver'].restore(sess, FLAGS.saved_sess)
            print("Model restored! Last step run: ", sess.run(model['global_step']))
            # predictions: stack the per-step outputs into a
            # (batch_size, padded_length) array, then decode to words
            final_preds = step_inference(sess, features, model, 1.0)
            captions_pred = [unpack.reshape(-1, 1) for unpack in final_preds]
            captions_pred = np.concatenate(captions_pred, 1)
            captions_deco = decode_captions(captions_pred, data['idx_to_word'])
            # save the images with the predicted captions written on them
            if not os.path.exists(FLAGS.results_dir):
                os.makedirs(FLAGS.results_dir)
            for j in range(len(captions_deco)):
                this_image_name = all_image_names['file_name'].values[j]
                img_name = os.path.join(FLAGS.results_dir, this_image_name)
                img = imread(os.path.join(FLAGS.test_dir, this_image_name))
                write_text_on_image(img, img_name, captions_deco[j])
    print("\ndone.")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pretrain_dir',
        type=str,
        default='/tmp/imagenet/',
        help="""\
        Path to the pretrained model (if not found, it will be downloaded)\
        """
    )
    parser.add_argument(
        '--test_dir',
        type=str,
        default='/home/ubuntu/COCO/testImages/',
        help="""\
        Path to the directory of test images to be captioned\
        """
    )
    parser.add_argument(
        '--results_dir',
        type=str,
        default='/home/ubuntu/COCO/savedTestImages/',
        help="""\
        Path to the directory where captioned test images are saved\
        """
    )
    parser.add_argument(
        '--saved_sess',
        type=str,
        default="/home/ubuntu/COCO/savedSession/model0.ckpt",
        help="""\
        Path to the saved session checkpoint\
        """
    )
    parser.add_argument(
        '--dict_file',
        type=str,
        default='/home/ubuntu/COCO/dataset/COCO_captioning/coco2014_vocab.json',
        help="""\
        Path to the vocabulary dictionary file\
        """
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)