-
Notifications
You must be signed in to change notification settings - Fork 1
/
infer.py
66 lines (54 loc) · 1.95 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import sys, os
import tensorflow as tf
import numpy as np
import scipy.misc
import skimage.io
import skimage.transform
from ops import *
from model import *
#import skipthoughts
#model = skipthoughts.load_model()
#vecs = skipthoughts.encode(model, ['blue hair red eyes', 'brown hair blue eyes'])
#print(vec.shape)
# Generate a batch of images from a trained GAN generator checkpoint.
# Usage: python infer.py LOG_DIR OUT_DIR
#   LOG_DIR -- directory containing TF checkpoints (and where summaries go)
#   OUT_DIR -- directory to write generated JPEGs (created if missing)
if len(sys.argv) < 3:
    sys.exit("usage: python infer.py LOG_DIR OUT_DIR")
LOG_DIR = sys.argv[1]
OUT_DIR = sys.argv[2]
os.makedirs(OUT_DIR, exist_ok=True)  # imsave fails if the directory is absent

# ---- Graph definition ----
with tf.variable_scope('input'):
    z_dim = 100  # latent-vector dimensionality expected by build_dec
    z = tf.placeholder(tf.float32, [None, z_dim], name='z')
with tf.variable_scope('generator'):
    # build_dec comes from model.py (not visible here); presumably the
    # generator/decoder network -- interface: z -> batch of images.
    fake_img = build_dec(z)

# initialize and saver
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=5)
sess = tf.Session()

# Restore a trained model; inference with fresh weights is meaningless,
# so abort with a nonzero exit status when no checkpoint is found.
ckpt = tf.train.get_checkpoint_state(LOG_DIR)
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
    print("=====Reading model parameters from %s=====" % ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Checkpoint paths end in "-<global_step>"; recover the step count.
    prev_step_num = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
else:
    print("=====Model Loading Error=====")
    sys.exit(1)  # was bare exit(); sys.exit is the correct call in scripts

try:
    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for step in range(1):  # single inference pass
        if coord.should_stop():
            break
        BATCH_SIZE = 30
        # Sample a batch of latent vectors z ~ N(0, I).
        batch_z = np.array(
            np.random.multivariate_normal(
                np.zeros(z_dim, dtype=np.float32),
                np.identity(z_dim, dtype=np.float32),
                BATCH_SIZE),
            dtype=np.float32)
        fake_img_eval = sess.run(fake_img, feed_dict={z: batch_z})
        print(fake_img_eval.shape)
        # scipy.misc.imsave was deprecated and removed in SciPy >= 1.3;
        # skimage.io.imsave (already imported above) is the replacement.
        # NOTE(review): assumes build_dec emits values skimage can serialize
        # (float in [0, 1] or uint8) -- confirm against the model's output
        # activation.
        for idx, img in enumerate(fake_img_eval):
            save_path = os.path.join(OUT_DIR, '%d.jpg' % idx)
            skimage.io.imsave(save_path, img)
except Exception as e:
    # Propagate the error to the queue-runner threads so they shut down.
    coord.request_stop(e)
finally:
    coord.request_stop()
    coord.join(threads)
    sess.close()