In [None]:
import tensorflow as tf
import numpy as np
import os
import sys
import cPickle
import time

# Short alias for TF-Slim (tf.contrib.slim); the checkpoint-restore helpers
# used by build_feature_extractor() are accessed through it.
slim = tf.contrib.slim

In [None]:
def _imread(img_path):
    """Open the image file at ``img_path`` and return the opened image object.

    NOTE(review): ``Image`` is not imported in the visible cells -- presumably
    it is PIL's ``Image`` module (``from PIL import Image``); confirm and add
    the import to the import cell.
    """
    opened = Image.open(img_path)
    return opened

def _img_scale_on_short_edge(img, new_size):
    width, height = img.size
    if width < height:
        height = int(height * new_size / width)
        width = new_size
    else:
        width = int(width * new_size / height)
        height = new_size
    new_img = img.resize((width, height))
    return new_img

def _center_crop(img):
    width, height = img.size
    result_size = min(width, height)
    x = int((width - result_size) / 2)
    y = int((height - result_size) / 2)
    crop_box = (x, y, x+result_size, y+result_size)
    new_img = img.crop(crop_box)
    return new_img

def loads_images(image_names, images_dir, image_size):
    """Load a list of images into one normalised float32 batch.

    Each image is opened with ``_imread``, converted to 3-channel RGB, resized
    so its short edge equals ``image_size``, center-cropped to a square and
    rescaled with ``x / 127.5 - 1`` (maps 8-bit pixel values to ``[-1, 1]``).

    Args:
        image_names: iterable of file names relative to ``images_dir``.
        images_dir: directory containing the image files.
        image_size: edge length in pixels of the square output images.

    Returns:
        ``np.float32`` array of shape ``(N, image_size, image_size, 3)`` where
        ``N`` is the number of names; ``N == 0`` yields an empty 4-D batch
        rather than a 1-D array.
    """
    batch = []
    for image_name in image_names:
        # Use tf.logging like the rest of the notebook; the bare ``logging``
        # module is never imported here.
        tf.logging.info("load image %s" % (image_name))
        img = _imread(os.path.join(images_dir, image_name))
        # Grayscale / palette / RGBA files would otherwise produce arrays
        # without exactly 3 channels and break the batch shape.
        img = img.convert('RGB')
        img = _img_scale_on_short_edge(img, image_size)
        img = _center_crop(img)
        batch.append(np.asarray(img, dtype=np.float32) / 127.5 - 1)
    if not batch:
        # np.asarray([]) would be float64 with shape (0,), which cannot be fed
        # to a [None, size, size, 3] placeholder.
        return np.zeros((0, image_size, image_size, 3), dtype=np.float32)
    return np.asarray(batch)

In [None]:
def parse_token_file(token_file):
    """Parse a caption token file into ``{image_name: [description, ...]}``.

    Every non-blank line is expected to have the form
    ``<image_name>#<caption_index><TAB><description>``. Only the first tab and
    the last ``#`` are treated as separators, so descriptions may contain tabs
    and image names may contain ``#``. Blank lines are skipped.

    Args:
        token_file: path of the token file to read.

    Returns:
        dict mapping each image name to the list of its descriptions, in file
        order.

    Raises:
        ValueError: if a non-blank line has no tab, or its id part has no ``#``.
    """
    image_name_to_tokens = {}
    with open(token_file, 'r') as f:
        # Stream the file instead of materialising every line up front.
        for line in f:
            line = line.strip('\r\n')
            if not line.strip():
                continue
            image_id, description = line.split('\t', 1)
            image_name, _ = image_id.rsplit('#', 1)
            image_name_to_tokens.setdefault(image_name, []).append(description)
    return image_name_to_tokens
  
def convert_token_to_id(image_name_to_tokens, vocab):
    """Encode every description of every image with ``vocab.encode``.

    Args:
        image_name_to_tokens: dict mapping image name -> list of descriptions.
        vocab: object whose ``encode(description)`` returns the encoded form.

    Returns:
        dict with the same keys, each mapped to the list of
        ``vocab.encode(description)`` results in the original order.
    """
    return {
        name: [vocab.encode(text) for text in texts]
        for name, texts in image_name_to_tokens.items()
    }

In [None]:
def build_feature_extractor(model_name, checkpoint_path):
    """Build an inference graph for a slim network and a checkpoint restorer.

    Network factory reference:
      https://github.com/tensorflow/models/blob/master/research/slim/nets/nets_factory.py
      https://github.com/tensorflow/models/tree/master/research

    NOTE(review): ``nets_factory`` is not imported in the visible cells --
    presumably ``from nets import nets_factory`` from the slim research repo;
    confirm and add it to the import cell.

    Args:
        model_name: network name passed to ``nets_factory.get_network_fn``.
        checkpoint_path: checkpoint file to restore the variables from.

    Returns:
        Tuple ``(images_placeholder, network_output, restore_fn, image_size)``:
        the float32 input placeholder of shape ``[None, size, size, 3]``, the
        first output of the network function, a callable that takes a session
        and loads the checkpoint, and the network's default input size.
    """
    # num_classes=0 presumably makes slim drop the classification layer and
    # return the features feeding it -- TODO confirm against nets_factory docs.
    net_fn = nets_factory.get_network_fn(model_name, num_classes=0, is_training=False)
    input_size = net_fn.default_image_size
    image_batch = tf.placeholder(tf.float32, [None, input_size, input_size, 3])
    # The second return value (end points) is not needed here.
    net_output, _ = net_fn(image_batch)
    init_from_checkpoint = slim.assign_from_checkpoint_fn(
        checkpoint_path, slim.get_variables_to_restore())
    return image_batch, net_output, init_from_checkpoint, input_size

In [None]:
# Network name handed to build_feature_extractor().
model_name = 'inception_v4'
# Checkpoint file whose weights are restored into that network.
checkpoint_path = 'checkpoint_inception_v4/inception_v4.ckpt'
# Directory holding the image files named in the token file.
images_dir = 'flickr30k_images/'
# Caption file read by parse_token_file().
token_file = 'results_20130124.token'
# Directory that receives the pickled per-batch feature files.
output_dir = 'features2'

In [None]:
# Build the inference graph and read the caption metadata.
images_placeholder, logits, restore_fn, image_size = build_feature_extractor(
    model_name, checkpoint_path)

image_name_to_tokens = parse_token_file(token_file)

# Keep only images that are listed in the token file AND present on disk.
existed_all_image_names = []
for image_name in image_name_to_tokens:
    if os.path.exists(os.path.join(images_dir, image_name)):
        existed_all_image_names.append(image_name)

tf.logging.info("image_size: %d" % image_size)
tf.logging.info("num of all images: %d" % len(existed_all_image_names))

if not os.path.exists(output_dir):
    os.mkdir(output_dir)

example_in_batch = 100       # images stored per output pickle
example_in_mini_batch = 10   # images pushed through the network per sess.run

# Integer ceiling division: plain `/` yields a float on Python 3, which
# range() rejects.
num_batches = (
    len(existed_all_image_names) + example_in_batch - 1) // example_in_batch
num_mini_batches = (
    example_in_batch + example_in_mini_batch - 1) // example_in_mini_batch

# The context manager releases the session's resources when extraction ends.
with tf.Session() as sess:
    # Initialise first, THEN restore. Running the initializer after
    # restore_fn would overwrite the checkpoint weights with fresh initial
    # values and the extracted features would come from an untrained network.
    sess.run(tf.global_variables_initializer())
    restore_fn(sess)

    for i in range(num_batches):
        image_names = existed_all_image_names[
            i*example_in_batch: (i+1)*example_in_batch]
        features = []
        for j in range(num_mini_batches):
            mini_image_names = image_names[
                j*example_in_mini_batch: (j+1)*example_in_mini_batch]
            if not mini_image_names:
                # The final batch can hold fewer than example_in_batch images;
                # never feed an empty mini-batch to the network.
                break

            mini_image_data = loads_images(mini_image_names, images_dir, image_size)
            logits_val = sess.run(logits,
                                  feed_dict={images_placeholder: mini_image_data})
            # The network output is expected to be spatially 1x1; flatten it
            # to (num_example, num_channel).
            num_example, num_width, num_height, num_channel = logits_val.shape
            assert num_width == 1 and num_height == 1
            logits_val = np.reshape(logits_val, (num_example, num_channel))
            features.append(logits_val)
        features = np.vstack(features)
        tf.logging.info(
            "batch %d/%d: features shape %s" % (i + 1, num_batches, features.shape))
        # Pickle streams must be written in binary mode.
        with open(os.path.join(output_dir, "image_features-%d.pickle" % i), 'wb') as f:
            cPickle.dump((image_names, features), f)