"""Predict captions on test images using trained model, with greedy sample method""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from datetime import datetime

import configuration
from ShowAndTellModel import build_model
from coco_utils import load_coco_data, sample_coco_minibatch, decode_captions
from image_utils import image_from_url, write_text_on_image

import numpy as np
import scipy.misc
from scipy.misc import imread
import pandas as pd
import os
from six.moves import urllib
import sys
import tarfile
import json
import argparse

model_config = configuration.ModelConfig()
training_config = configuration.TrainingConfig()

FLAGS = None
verbose = True
mode = 'inference'

pretrain_model_name = 'classify_image_graph_def.pb'
layer_to_extract = 'pool_3:0'
MODEL_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
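
# Overview (descriptive note): the script extracts 'pool_3:0' activations (the
# 2048-float penultimate layer of the pretrained Inception graph) for every test
# image, then feeds those features to the trained Show-and-Tell LSTM, which
# greedily decodes one caption per image.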


def maybe_download_and_extract():
    """Download and extract model tar file."""
    dest_directory = FLAGS.pretrain_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = MODEL_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(MODEL_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def create_graph():
    """Creates a graph from the saved GraphDef file."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(os.path.join(
            FLAGS.pretrain_dir, pretrain_model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')


def extract_features(image_dir):
    """Extract pool_3 features for every image in image_dir.

    Returns the feature array and a DataFrame of the image file names."""
    if not os.path.exists(image_dir):
        print("image_dir does not exist!")
        return None
    maybe_download_and_extract()
    create_graph()
    with tf.Session() as sess:
        # Some useful tensors:
        # 'softmax:0': A tensor containing the normalized prediction across
        #   1000 labels.
        # 'pool_3:0': A tensor containing the next-to-last layer, a 2048-float
        #   description of the image.
        # 'DecodeJpeg/contents:0': A tensor containing a string providing the
        #   JPEG encoding of the image.
        # Runs the extraction tensor by feeding image_data as input to the graph.
        final_array = []
        extract_tensor = sess.graph.get_tensor_by_name(layer_to_extract)
        print("There are a total of " + str(len(os.listdir(image_dir))) + " images to process.")
        all_image_names = os.listdir(image_dir)
        all_image_names = pd.DataFrame({'file_name': all_image_names})
        for img in all_image_names['file_name'].values:
            temp_path = os.path.join(image_dir, img)
            image_data = tf.gfile.FastGFile(temp_path, 'rb').read()
            predictions = sess.run(extract_tensor, {'DecodeJpeg/contents:0': image_data})
            predictions = np.squeeze(predictions)
            final_array.append(predictions)
        final_array = np.array(final_array)
    return final_array, all_image_names
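

# A minimal usage sketch (the path below is illustrative; each image yields one
# 2048-dim row, matching the pool_3 layer size noted above):
#
#   features, names = extract_features('/path/to/test/images')
#   # features.shape == (num_images, 2048)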


def step_inference(sess, features, model, keep_prob):
    """Greedily decode one caption per image: initialize the LSTM state from the
    image features, then feed the previous prediction back in at every time step."""
    batch_size = features.shape[0]
    captions_in = np.ones((batch_size, 1))  # <START> token index is one
    state = None
    final_preds = []
    current_pred = captions_in
    mask = np.zeros((batch_size, model_config.padded_length))
    mask[:, 0] = 1
    # get the initial LSTM state from the image features
    feed_dict = {model['image_feature']: features,
                 model['keep_prob']: keep_prob}
    state = sess.run(model['initial_state'], feed_dict=feed_dict)
    # generate the sentences one token at a time
    for t in range(model_config.padded_length):
        feed_dict = {model['input_seqs']: current_pred,
                     model['initial_state']: state,
                     model['input_mask']: mask,
                     model['keep_prob']: keep_prob}
        current_pred, state = sess.run([model['preds'], model['final_state']], feed_dict=feed_dict)
        current_pred = current_pred.reshape(-1, 1)
        final_preds.append(current_pred)
    return final_preds
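

# Data-flow note (descriptive): step_inference returns a list of padded_length
# arrays, each of shape (batch_size, 1). main() below concatenates them along
# axis 1 into a (batch_size, padded_length) matrix before calling decode_captions.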


def main(_):
    # load the vocabulary dictionary
    data = {}
    with open(FLAGS.dict_file, 'r') as f:
        dict_data = json.load(f)
        for k, v in dict_data.items():
            data[k] = v
    data['idx_to_word'] = {int(k): v for k, v in data['idx_to_word'].items()}

    # extract features for all test images
    features, all_image_names = extract_features(FLAGS.test_dir)

    # Build the TensorFlow graph and run inference
    g = tf.Graph()
    with g.as_default():
        num_of_images = len(os.listdir(FLAGS.test_dir))
        print("Running inference on {} images".format(num_of_images))

        # Build the model.
        model = build_model(model_config, mode, inference_batch=num_of_images)

        # restore the trained weights and generate captions
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            model['saver'].restore(sess, FLAGS.saved_sess)
            print("Model restored! Last step run: ", sess.run(model['global_step']))

            # predictions
            final_preds = step_inference(sess, features, model, 1.0)
            captions_pred = [unpack.reshape(-1, 1) for unpack in final_preds]
            captions_pred = np.concatenate(captions_pred, 1)
            captions_deco = decode_captions(captions_pred, data['idx_to_word'])

            # save the images with the predicted captions written on them
            if not os.path.exists(FLAGS.results_dir):
                os.makedirs(FLAGS.results_dir)
            for j in range(len(captions_deco)):
                this_image_name = all_image_names['file_name'].values[j]
                img_name = os.path.join(FLAGS.results_dir, this_image_name)
                img = imread(os.path.join(FLAGS.test_dir, this_image_name))
                write_text_on_image(img, img_name, captions_deco[j])
    print("\ndone.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pretrain_dir',
        type=str,
        default='/tmp/imagenet/',
        help="""\
        Path to the pretrained model (if not found, it will be downloaded from the web)\
        """
    )
    parser.add_argument(
        '--test_dir',
        type=str,
        default='/home/ubuntu/COCO/testImages/',
        help="""\
        Path to the directory of test images to be captioned\
        """
    )
    parser.add_argument(
        '--results_dir',
        type=str,
        default='/home/ubuntu/COCO/savedTestImages/',
        help="""\
        Path to the directory where the captioned test images are saved\
        """
    )
    parser.add_argument(
        '--saved_sess',
        type=str,
        default="/home/ubuntu/COCO/savedSession/model0.ckpt",
        help="""\
        Path to the saved session (model checkpoint)\
        """
    )
    parser.add_argument(
        '--dict_file',
        type=str,
        default='/home/ubuntu/COCO/dataset/COCO_captioning/coco2014_vocab.json',
        help="""\
        Path to the vocabulary dictionary file\
        """
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
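
# Example invocation (a sketch; the script name and paths below are placeholders,
# not values taken from this repository):
#
#   python predict_test_images.py \
#       --test_dir /path/to/test/images \
#       --results_dir /path/to/output \
#       --saved_sess /path/to/savedSession/model0.ckpt \
#       --dict_file /path/to/coco2014_vocab.json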