In [1]:
import tensorflow as tf
import numpy as np
import os
import imageio
from tqdm import trange
from skimage.transform import resize

  from ._conv import register_converters as _register_converters


In [2]:
def generate_batches(path, batch_size, output_shape=(160, 160)):
    cur_batch = []
    for filename in sorted(list(os.listdir(path))):
        filename_path = os.path.join(path, filename)
        cur_image = imageio.imread(filename_path)
        cur_image = resize(cur_image, output_shape, mode='reflect')
        cur_batch.append(cur_image)
        if len(cur_batch) == batch_size:
            yield np.array(cur_batch)
            cur_batch = []
            
    if cur_batch:
        yield cur_batch

In [3]:
def get_model_filenames(model_dir):
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if len(meta_files)==0:
        raise ValueError('No meta file found in the model directory (%s)' % model_dir)
    elif len(meta_files)>1:
        raise ValueError('There should not be more than one meta file in the model directory (%s)' % model_dir)
    meta_file = meta_files[0]
    # ckpt = tf.train.get_checkpoint_state(model_dir)
    # if ckpt and ckpt.model_checkpoint_path:
    #     ckpt_file = os.path.basename(ckpt.model_checkpoint_path)
    #     return meta_file, ckpt_file
    ckpt_file = "model-20170512-110547.ckpt-250000"
    return meta_file, ckpt_file

In [4]:
def load_model(model_name='20170512-110547'):
    model_path = os.path.join('model', model_name)
    meta_file, ckpt_file = get_model_filenames(model_path)
    sess = tf.Session()
    saver = tf.train.import_meta_graph(os.path.join(model_path, meta_file))
    saver.restore(sess, os.path.join(model_path, ckpt_file))
    return sess

In [5]:
sess = load_model()

'model_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from model/20170512-110547/model-20170512-110547.ckpt-250000


In [6]:
graph = tf.get_default_graph()

input_tensor = graph.get_tensor_by_name("input:0")
embeddings = graph.get_tensor_by_name("embeddings:0")
phase_train = graph.get_tensor_by_name("phase_train:0")

In [7]:
BATCH_SIZE = 128
IMG_COUNT = 202599

embs = []
img_gen = generate_batches('data/img_align_celeba/', batch_size=BATCH_SIZE)
for _ in trange(0, IMG_COUNT + BATCH_SIZE - 1, BATCH_SIZE):
    images = next(img_gen)
    img_emb = sess.run(embeddings, feed_dict={
        input_tensor : images,
        phase_train : False})
    embs.append(img_emb)

100%|█████████▉| 1583/1584 [42:08<00:01,  1.60s/it]

StopIteration: 

In [11]:
embs_arr = np.vstack(embs)
print(embs_arr.shape)
np.save('embeddings', embs_arr)

(202599, 128)
