In [1]:
import pandas as pd
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import tensorflow as tf

import scipy
from scipy.io import loadmat
import re

import string
import imageio
import numpy as np
import matplotlib.pyplot as plt
from utils import *
import random
import time
from tqdm import trange

import warnings
warnings.filterwarnings('ignore')

  return f(*args, **kwds)


In [2]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower',
                                                     word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('2428',
                                                     id2word_dict['2428']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'],
                                         word2Id_dict['<RARE>']))


there are 6375 vocabularies in total
Word to id mapping, for example: flower -> 2428
Id to word mapping, for example: 2428 -> flower
Tokens: <PAD>: 6372; <RARE>: 6374


In [3]:
text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
print(sent2IdList(text))

the flower shown has yellow anther red pistil and bright red petals.
None
None


In [4]:
data_path = './dataset'
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(df)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))

There are 7370 image in training data


In [5]:
df.head(5)

Unnamed: 0,Captions,ImagePath
1855,"[[2430, 2428, 2431, 2427, 2436, 2432, 2450, 24...",/102flowers/image_08110.jpg
6790,"[[2430, 2428, 2431, 2427, 2436, 2432, 2440, 24...",/102flowers/image_07749.jpg
7908,"[[2435, 2428, 2505, 2431, 2444, 2427, 2433, 24...",/102flowers/image_04381.jpg
1805,"[[2430, 2428, 2431, 2563, 2437, 2427, 2433, 24...",/102flowers/image_04518.jpg
5679,"[[2435, 2428, 2427, 2432, 5409, 2429, 2432, 24...",/102flowers/image_07620.jpg


In [6]:
df_img = df['ImagePath'].values
df_caption = df['Captions'].values
d_captions = []
for i in trange(len(df_caption)):
    caps = []
    for caption in df_caption[i]:
        cap = []
        cap.append(word2Id_dict['<ST>'])
        for word in caption:
            cap.append(word)
        cap.append(word2Id_dict['<ED>'])
        caps.append(cap)
    d_captions.append(caps)
    
d_captions = np.asarray(d_captions)
df_ = pd.DataFrame({
    'Captions': d_captions,
    'ImagePath': df_img                
})

df_.head(10)

100%|██████████| 7370/7370 [00:00<00:00, 23159.84it/s]


Unnamed: 0,Captions,ImagePath
0,"[[798, 2430, 2428, 2431, 2427, 2436, 2432, 245...",/102flowers/image_08110.jpg
1,"[[798, 2430, 2428, 2431, 2427, 2436, 2432, 244...",/102flowers/image_07749.jpg
2,"[[798, 2435, 2428, 2505, 2431, 2444, 2427, 243...",/102flowers/image_04381.jpg
3,"[[798, 2430, 2428, 2431, 2563, 2437, 2427, 243...",/102flowers/image_04518.jpg
4,"[[798, 2435, 2428, 2427, 2432, 5409, 2429, 243...",/102flowers/image_07620.jpg
5,"[[798, 2430, 2428, 2442, 2450, 2439, 2441, 243...",/102flowers/image_00724.jpg
6,"[[798, 2428, 2433, 2438, 2427, 2429, 2487, 244...",/102flowers/image_00550.jpg
7,"[[798, 2430, 2428, 2442, 2438, 2439, 2441, 243...",/102flowers/image_07209.jpg
8,"[[798, 2428, 2431, 2427, 2436, 2432, 2440, 243...",/102flowers/image_02334.jpg
9,"[[798, 2518, 2428, 2470, 2451, 2510, 2448, 242...",/102flowers/image_07389.jpg


In [7]:
df_t = pd.read_csv('./dataset/text_ImgPath.csv')

print (type(eval(df_t['Captions'][0])))

<class 'list'>


In [8]:
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_DEPTH = 3

def train_data_generator(caption, image_path):
    # load in the image according to image path
    imagefile = tf.read_file(data_path + image_path)
    image = tf.image.decode_image(imagefile, channels=3)
    float_img = tf.image.convert_image_dtype(image, tf.float32)
    float_img.set_shape([None, None, 3])
    image = tf.image.resize_images(float_img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH])
    
    return image, caption

def data_iterator(filenames, batch_size, data_generator):
    # Load the training data into two Numpy arrays
    # df = pd.read_pickle(filenames)
    captions = df_['Captions'].values
    caption = []
    
    for i in range(len(captions)):
        caption.append(random.choice(captions[i]))
    caption = np.asarray(caption)
    
    image_path = df_['ImagePath'].values
    
    assert caption.shape[0] == image_path.shape[0]
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, image_path))
    dataset = dataset.map(data_generator)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    
    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes
    
    return iterator, output_types, output_shapes

def data_iterator_rnn(filenames, batch_size):
    captions = df_['Captions'].values
    caption = []
    
    for _ in range(3):
        for i in range(len(captions)):
            caption.append(random.choice(captions[i]))
    caption = np.asarray(caption)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    
    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes
    
    return iterator, output_types, output_shapes

In [9]:
tf.reset_default_graph()
BATCH_SIZE = 64
iterator_train, types, shapes = data_iterator(
    data_path + '/text2ImgData.pkl', BATCH_SIZE, train_data_generator)
iter_initializer = iterator_train.initializer
next_element = iterator_train.get_next()

with tf.Session() as sess:
    sess.run(iterator_train.initializer)
    next_element = iterator_train.get_next()
    image, text = sess.run(next_element)
    
print (text)

[[b'798' b'2430' b'2428' ..., b'6372' b'6372' b'1784']
 [b'798' b'2430' b'2428' ..., b'6372' b'6372' b'1784']
 [b'798' b'2430' b'2442' ..., b'6372' b'6372' b'1784']
 ..., 
 [b'798' b'2430' b'2428' ..., b'6372' b'6372' b'1784']
 [b'798' b'2430' b'2428' ..., b'6372' b'6372' b'1784']
 [b'798' b'2435' b'2444' ..., b'6372' b'6372' b'1784']]


In [10]:
tf.reset_default_graph()
BATCH_SIZE = 64
iterator_train, types, shapes = data_iterator_rnn(
    data_path + '/text2ImgData.pkl', BATCH_SIZE)

with tf.Session() as sess:
    sess.run(iterator_train.initializer)
    next_element = iterator_train.get_next()
    text = sess.run(next_element)
    
print (text)

[[b'798' b'2430' b'2428' ..., b'2443' b'6372' b'1784']
 [b'798' b'2430' b'2428' ..., b'6372' b'6372' b'1784']
 [b'798' b'2435' b'2428' ..., b'6372' b'6372' b'1784']
 ..., 
 [b'798' b'2435' b'2427' ..., b'6372' b'6372' b'1784']
 [b'798' b'2430' b'2428' ..., b'6372' b'6372' b'1784']
 [b'798' b'2430' b'2428' ..., b'6372' b'6372' b'1784']]


In [11]:
def get_hparas():
    hparas = {
        'MAX_SEQ_LENGTH': 20,
        'EMBED_DIM': 64,  # word embedding dimension
        'VOCAB_SIZE': len(vocab),
        'TEXT_DIM': 64,  # text embrdding dimension
        'RNN_HIDDEN_SIZE': 64,
        'Z_DIM': 64,  # random noise z dimension
        'IMAGE_SIZE': [64, 64, 3],  # render image size
        'BATCH_SIZE': 64,
        'LR': 0.002,
        'DECAY_EVERY': 100,
        'LR_DECAY': 0.5,
        'RNN_EPOCH': 10, # For pretrain RNN
        'BETA': 0.5,  # AdamOptimizer parameter
        'N_EPOCH': 20,
        'N_SAMPLE': num_training_sample
    }
    return hparas

In [12]:
class TextEncoder(object):
    def __init__(self, sess, hparas, training_phase=True, reuse=False, return_embed=False):
        self.hparas = hparas
        self.sess = sess
        self.training_phase = training_phase
        self.return_embed = return_embed
        self.reuse = reuse
        self.sess = tf.Session()
        
        self._build_model()
        
    def _build_model(self):
        
        with tf.variable_scope('rnnftxt', reuse=self.reuse):
            # if self.training_phase:
            self.word_embed_matrix = tf.get_variable(
                'rnn/wordembed',
                shape=(self.hparas['VOCAB_SIZE'], self.hparas['EMBED_DIM']),
                initializer=tf.random_normal_initializer(stddev=0.02),
                dtype=tf.float32)
            #else:
            #    self.word_embed_matrix = tf.Variable(self.embed_matrix)
                
            self.text = tf.placeholder(dtype=tf.int64, shape=[self.hparas['BATCH_SIZE'], None], name='caption')
            embedded_word_ids = tf.nn.embedding_lookup(self.word_embed_matrix, self.text)
            embedded_word_ref = tf.nn.embedding_lookup(self.word_embed_matrix, self.text)
            
            # seq = tf.one_hot(self.seq, 22)
        with tf.variable_scope('rnncell', reuse=self.reuse):
            LSTMCell = tf.contrib.rnn.BasicLSTMCell(
                self.hparas['TEXT_DIM'],
                reuse=self.reuse)
            initial_state = LSTMCell.zero_state(
                self.hparas['BATCH_SIZE'],
                dtype=tf.float32)
            rnn_net = tf.nn.dynamic_rnn(
                cell=LSTMCell,
                inputs=embedded_word_ids,
                initial_state=initial_state,
                dtype=tf.float32,
                time_major=False,
                scope='rnn/dynamic')
            self.rnn_net = rnn_net
            self.outputs_last = rnn_net[0][:, -1, :]
            self.outputs = rnn_net[0]
        
        
        
        if self.training_phase:
            with tf.variable_scope('rnn/logits'):
                self.logits = tf.contrib.layers.fully_connected(
                    self.outputs, 
                    self.hparas['TEXT_DIM'], 
                    None)
                
            with tf.variable_scope('rnn/loss'):
                self.loss = tf.reduce_sum(
                    tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.logits[:, :-1], 
                    labels=embedded_word_ref[:, 1:]))
            
            with tf.variable_scope('rnn/optim'):
                self.optim = tf.train.AdamOptimizer(self.hparas['LR'], beta1=self.hparas['BETA']) \
                                        .minimize(self.loss)
            
        
        self.sess = tf.Session()
        self.saver = tf.train.Saver()
        global_step = tf.train.get_or_create_global_step()
        self.sess.run(tf.global_variables_initializer()) 
            
        
    def train(self, iterator_train):
        
        self.sess.run(iterator_train.initializer)
        
        self.losses = []
        for _epoch in trange(self.hparas['N_EPOCH']):
            start_time = time.time()
            
            epoch_loss = 0
            for _step in range(100):
                next_element = iterator_train.get_next()
                text = self.sess.run(next_element)
                
                loss, _ = self.sess.run([self.loss, self.optim],
                                        feed_dict={
                                            self.text: text
                                        })
                epoch_loss += loss
            self.losses.append(epoch_loss)
            
        self.save(global_step)
        
    def inference(self, text):
        output_embedded = self.sess.run(self.outputs_last, feed_dict={self.text: text})
        return output_embedded
        
    
    def save(self, global_step):
        gs = global_step.eval(self.sess)
        self.saver.save(self.sess, 'model/' + 'preTrainRnn.ckpt', global_step=gs)

    def restore(self):
        ckpt = tf.train.get_checkpoint_state('./model')
        if ckpt and ckpt.model_checkpoint_path:
            print (ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            print ("restore Fail!!")

tf.reset_default_graph()
BATCH_SIZE = 64
iterator_train, types, shapes = data_iterator_rnn(
    data_path + '/text2ImgData.pkl', BATCH_SIZE)

hparas = get_hparas()
caption_bt = tf.placeholder(dtype=tf.int64, shape=[hparas['BATCH_SIZE'], None], name='caption')
preTrain_RNN = Rnn(caption_bt, hparas=hparas, training_phase=True, reuse=False)

In [13]:
class Generator(object):
    """
        Using encodded text(hidden representation) to generate fake image data
        Inputs: Hidden representation of input text and random noise z as random seed
        Outputs: Target image in size 64x64x3
    """
    def __init__(self, noise_z, text, training_phase, hparas, reuse):
        """
        Args:
            noise_z: random generated fake image data, probabily with noise
            text: encodded text (hidden representation)
            training_phase: bool variable, indicate whether is during train phase or not
            hparas: hyperparameters
            reuse: bool variable, indicate if reuse trained weights or not
        """
        self.z = noise_z
        self.text = text
        self.train = training_phase
        self.hparas = hparas
        self.reuse = reuse

        self._build_model()

    def _build_model(self):
        with tf.variable_scope('generator', reuse=self.reuse):
            text_flatten = tf.contrib.layers.flatten(self.text)
            text_input = tf.layers.dense(
                    text_flatten,
                    self.hparas['TEXT_DIM'],
                    name='generator/text_input',
                    reuse=self.reuse)
            
            z_text_concat = tf.concat(
                    [self.z, text_input],
                    axis=1,
                    name='generator/z_text_concat')

            g_net = tf.layers.dense(
                    z_text_concat,
                    64*64*3,
                    name='generator/g_net',
                    reuse=self.reuse)

            g_net = tf.reshape(
                    g_net,
                    [-1, 64, 64, 3],
                    name='generator/g_net_reshape')

            self.generator_net = g_net
            self.outputs = g_net

In [14]:
class Discriminator(object):
    """
        A binary classifier that discriminate real/fake image data
    1. True Image:
        Inputs: true image and pair text
        Outputs: a float to represent the result expected to be 1
    
    2. Fake Image:
        Inputs: generated fake image and paired image
        Outputs: a float to represent the result expected to be 0

    """

    def __init__(self, image, text, training_phase, hparas, reuse):
        self.image = image
        self.text = text
        self.training_phase = training_phase
        self.hparas = hparas
        self.reuse = reuse

        self._build_model()

    def _build_model(self):
        with tf.variable_scope('discriminator', reuse=self.reuse):
            text_flatten = tf.contrib.layers.flatten(self.text)
            text_input = tf.layers.dense(
                    text_flatten,
                    self.hparas['TEXT_DIM'],
                    name='discrim/text_input',
                    reuse=self.reuse)

            image_flatten = tf.contrib.layers.flatten(self.image)
            image_input = tf.layers.dense(
                    image_flatten,
                    self.hparas['TEXT_DIM'],
                    name='discrim/image_input',
                    reuse=self.reuse)

            img_text_concate = tf.concat(
                    [text_input, image_input],
                    axis=1,
                    name='discrim/concate')

            d_net = tf.layers.dense(
                    img_text_concate,
                    1,
                    name='discrim/d_net',
                    reuse=self.reuse
                    )

            self.logits = d_net
            net_output = tf.nn.sigmoid(d_net)
            self.discriminator = net_output
            self.outputs = net_output

In [17]:
class GAN:

    def __init__(self,
               hparas,
               training_phase,
               dataset_path,
               ckpt_path,
               inference_path,
               recover=None):
        self.hparas = hparas
        self.train = training_phase
        self.dataset_path = dataset_path  # dataPath+'/text2ImgData.pkl'
        self.ckpt_path = ckpt_path
        self.sample_path = './samples'
        self.inference_path = './inference'

        self._get_session()  # get session
        self._get_train_data_iter()  # initialize and get data iterator
        self._input_layer()  # define input placeholder
        self._get_inference()  # build generator and discriminator
        self._get_loss()  # define gan loss
        self._get_var_with_name()  # get variables for each part of model
        self._optimize()  # define optimizer
        self._init_vars()
        self._get_saver()

        if recover is not None:
            self._load_checkpoint(recover)

    def _get_train_data_iter(self):
        if self.train:  # training data iteratot
            
            iterator_train, types, shapes = data_iterator(
                self.dataset_path + '/text2ImgData.pkl', self.hparas['BATCH_SIZE'],
                train_data_generator)
            
            iter_initializer = iterator_train.initializer
            next_element = iterator_train.get_next()
            # self.sess.run(iterator_train.initializer)
            self.iterator_train = iterator_train
        else:  # testing data iterator
            iterator_train, types, shapes = data_iterator_test(
                self.dataset_path + '/testData.pkl', self.hparas['BATCH_SIZE'])
            iter_initializer = iterator_train.initializer
            next_element = iterator_train.get_next()
            self.sess.run(iterator_train.initializer)
            self.iterator_test = iterator_train

    def _input_layer(self):
        if self.train:
            self.real_image = tf.placeholder(
              'float32', [
                  self.hparas['BATCH_SIZE'], self.hparas['IMAGE_SIZE'][0],
                  self.hparas['IMAGE_SIZE'][1], self.hparas['IMAGE_SIZE'][2]
              ],
              name='real_image')
            self.caption = tf.placeholder(
              dtype=tf.int64,
              shape=[self.hparas['BATCH_SIZE'], None],
              name='caption')
            self.z_noise = tf.placeholder(
              tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['Z_DIM']],
              name='z_noise')
            self.embed_text = tf.placeholder(
                tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['EMBED_DIM']])
            
        else:
            self.caption = tf.placeholder(
              dtype=tf.int64,
              shape=[self.hparas['BATCH_SIZE'], None],
              name='caption')
            self.z_noise = tf.placeholder(
              tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['Z_DIM']],
              name='z_noise')
            self.embed_text = tf.placeholder(
                tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['EMBED_DIM']])

    def _get_inference(self):
        if self.train:
            
            # GAN training
            # encoding text
            self.text_encoder = TextEncoder(
              self.sess, hparas=self.hparas, training_phase=False, reuse=False)
            self.text_encoder.restore()
            # generating image
            generator = Generator(
              self.z_noise,
              self.embed_text,
              training_phase=True,
              hparas=self.hparas,
              reuse=False)
            self.generator = generator
            
            self.epsilon = np.random.uniform(.0, .1, size=(1))
            
            self.x_hat = self.epsilon * self.real_image + (1-self.epsilon) * self.generator.outputs

            # discriminize
            # real image
            real_discriminator = Discriminator(
              self.real_image,
              self.embed_text,
              training_phase=True,
              hparas=self.hparas,
              reuse=False)
            self.real_discriminator = real_discriminator
            
            # combined image discriminator
            x_hat_discriminator = Discriminator(
                self.x_hat,
                self.embed_text,
                training_phase=True,
                hparas=self.hparas,
                reuse=True)
            self.x_hat_discriminator = x_hat_discriminator
            
            # fake image
            fake_discriminator = Discriminator(
              generator.outputs,
              self.embed_text,
              training_phase=True,
              hparas=self.hparas,
              reuse=True)
            self.fake_discriminator = fake_discriminator
            

        else:  # inference mode

            self.text_embed = TextEncoder(
              self.sess, hparas=self.hparas, training_phase=False, reuse=False)
            self.text_embed.restore()
            self.generate_image_net = Generator(
              self.z_noise,
              self.embed_text,
              training_phase=False,
              hparas=self.hparas,
              reuse=False)

    def _get_loss(self):
        if self.train:
            
            d_hat_loss = tf.reduce_mean(tf.square(tf.norm(
                tf.gradients(self.x_hat_discriminator.logits, self.x_hat),
                ord=2,
                axis=1) - 1))
            
            d_loss1 = tf.reduce_mean(
              tf.nn.sigmoid_cross_entropy_with_logits(
                  logits=self.real_discriminator.logits,
                  labels=tf.ones_like(self.real_discriminator.logits),
                  name='d_loss1'))
            d_loss2 = tf.reduce_mean(
              tf.nn.sigmoid_cross_entropy_with_logits(
                  logits=self.fake_discriminator.logits,
                  labels=tf.zeros_like(self.fake_discriminator.logits),
                  name='d_loss2'))
            
            self.d_loss = d_loss2 - d_loss1 + \
                          10 * d_hat_loss
            self.g_loss = -tf.reduce_mean(
              tf.nn.sigmoid_cross_entropy_with_logits(
                  logits=self.fake_discriminator.logits,
                  labels=tf.ones_like(self.fake_discriminator.logits),
                  name='g_loss'))

    def _optimize(self):
        if self.train:
            with tf.variable_scope('learning_rate'):
                self.lr_var = tf.Variable(self.hparas['LR'], trainable=False)

                discriminator_optimizer = tf.train.AdamOptimizer(
                  self.lr_var, beta1=0.0, beta2=0.9)
                generator_optimizer = tf.train.AdamOptimizer(
                  self.lr_var, beta1=0.0, beta2=0.9)
                self.d_optim = discriminator_optimizer.minimize(
                  self.d_loss, var_list=self.discrim_vars)
                self.g_optim = generator_optimizer.minimize(
                  self.g_loss, var_list=self.generator_vars + self.text_encoder_vars)

    def training(self):
        
        self.sess.run(self.iterator_train.initializer)
        counter = 1
        n_critic = 5
        for _epoch in trange(self.hparas['N_EPOCH']):
            start_time = time.time()

            if _epoch != 0 and (_epoch % self.hparas['DECAY_EVERY'] == 0):
                new_lr_decay = self.hparas['LR_DECAY']**(
                    _epoch // self.hparas['DECAY_EVERY'])
                self.sess.run(tf.assign(self.lr_var, self.hparas['LR'] * new_lr_decay))
                print("new lr %f" % (self.hparas['LR'] * new_lr_decay))

            n_batch_epoch = int(self.hparas['N_SAMPLE'] / self.hparas['BATCH_SIZE'])
            for _step in range(n_batch_epoch):
                step_time = time.time()
                next_element = self.iterator_train.get_next()
                image_batch, caption_batch = self.sess.run(next_element)
                
                b_z = np.random.normal(
                    loc=0.0,
                    scale=1.0,
                    size=(self.hparas['BATCH_SIZE'],
                          self.hparas['Z_DIM'])).astype(np.float32)
                
                if counter % n_critic:
                
                    text_out = self.text_encoder.inference(caption_batch)
                
                    # update discriminator
                    self.discriminator_error, _ = self.sess.run(
                        [self.d_loss, self.d_optim],
                        feed_dict={
                            self.real_image: image_batch,
                            self.embed_text: text_out,
                            self.z_noise: b_z
                        })
                else:

                    # update generate
                    self.generator_error, _ = self.sess.run(
                        [self.g_loss, self.g_optim],
                        feed_dict={self.embed_text: text_out,
                                   self.z_noise: b_z})
                counter += 1
                if _step % 50 == 0 and _step > 10:
                    print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.3f, g_loss: %.3f" \
                          % (_epoch, self.hparas['N_EPOCH'], _step, n_batch_epoch,
                             time.time() - step_time,
                             self.discriminator_error, self.generator_error))
            if _epoch != 0 and (_epoch + 1) % 5 == 0:
                self._save_checkpoint(_epoch)
                # self._sample_visiualize(_epoch)

    def inference(self):
        for _iters in trange(100):
            caption, idx = self.sess.run(self.iterator_test.get_next())
            z_seed = np.random.normal(
              loc=0.0,
              scale=1.0,
              size=(self.hparas['BATCH_SIZE'],
                    self.hparas['Z_DIM'])).astype(np.float32)
            
            rnn_out = self.text_embed.inference(caption)
            
            img_gen = self.sess.run(
                self.generate_image_net.outputs,
                feed_dict={
                    self.z_noise: z_seed,
                    self.embed_text: rnn_out
                })

            """img_gen, rnn_out = self.sess.run(
              [self.generate_image_net.outputs, self.text_embed.outputs],
              feed_dict={self.caption: caption,
                         self.z_noise: z_seed})"""
            for i in range(self.hparas['BATCH_SIZE']):
                scipy.misc.imsave(
                    self.inference_path + '/inference_{:04d}.png'.format(idx[i]),
                    img_gen[i])

    def _init_vars(self):
        self.sess.run(tf.global_variables_initializer())

    def _get_session(self):
        self.sess = tf.Session()

    def _get_saver(self):
        if self.train:
            self.rnn_saver = tf.train.Saver(var_list=self.text_encoder_vars)
            self.g_saver = tf.train.Saver(var_list=self.generator_vars)
            self.d_saver = tf.train.Saver(var_list=self.discrim_vars)
        else:
            self.rnn_saver = tf.train.Saver(var_list=self.text_encoder_vars)
            self.g_saver = tf.train.Saver(var_list=self.generator_vars)

    def _sample_visiualize(self, epoch):
        ni = int(np.ceil(np.sqrt(self.hparas['BATCH_SIZE'])))
        sample_size = self.hparas['BATCH_SIZE']
        max_len = self.hparas['MAX_SEQ_LENGTH']

        sample_seed = np.random.normal(
            loc=0.0, scale=1.0, size=(sample_size,
                                      self.hparas['Z_DIM'])).astype(np.float32)
        sample_sentence = [
            "the flower shown has yellow anther red pistil and bright red petals."
        ] * int(sample_size / ni) + [
            "this flower has petals that are yellow, white and purple and has dark lines"
        ] * int(sample_size / ni) + [
            "the petals on this flower are white with a yellow center"
        ] * int(sample_size / ni) + [
            "this flower has a lot of small round pink petals."
        ] * int(sample_size / ni) + [
            "this flower is orange in color, and has petals that are ruffled and rounded."
        ] * int(sample_size / ni) + [
            "the flower has yellow petals and the center of it is brown."
        ] * int(sample_size / ni) + [
            "this flower has petals that are blue and white."
        ] * int(sample_size / ni) + [
            "these white flowers have petals that start off white in color and end in a white towards the tips."
        ] * int(sample_size / ni)

        for i, sent in enumerate(sample_sentence):
            sample_sentence[i] = sent2IdList(sent, max_len)

        rnn_out = self.text_encoder.inference(sample_sentence)
        
        img_gen = self.sess.run(
                self.generator.outputs,
                feed_dict={
                    self.z_noise: sample_seed,
                    self.embed_text: rnn_out
                })
        """img_gen, rnn_out = self.sess.run(
            [self.generator.outputs, self.text_encoder.outputs],
            feed_dict={self.caption: sample_sentence,
                       self.z_noise: sample_seed})"""
        print (type(img_gen))
        img_gen = np.asarray(img_gen)
        print (img_gen.shape)
        save_images(img_gen, [ni, ni],
                    self.sample_path + '/train_{:02d}.png'.format(epoch))

    def _get_var_with_name(self):
        t_vars = tf.trainable_variables()

        self.text_encoder_vars = [var for var in t_vars if 'rnn' in var.name]
        self.generator_vars = [var for var in t_vars if 'generator' in var.name]
        self.discrim_vars = [var for var in t_vars if 'discrim' in var.name]

    def _load_checkpoint(self, recover):
        if self.train:
            self.rnn_saver.restore(
              self.sess, self.ckpt_path + 'rnn_model_' + str(recover) + '.ckpt')
            self.g_saver.restore(self.sess,
                               self.ckpt_path + 'g_model_' + str(recover) + '.ckpt')
            self.d_saver.restore(self.sess,
                               self.ckpt_path + 'd_model_' + str(recover) + '.ckpt')
        else:
            self.rnn_saver.restore(
              self.sess, self.ckpt_path + 'rnn_model_' + str(recover) + '.ckpt')
            self.g_saver.restore(self.sess,
                               self.ckpt_path + 'g_model_' + str(recover) + '.ckpt')
        print('-----success restored checkpoint--------')

    def _save_checkpoint(self, epoch):
        self.rnn_saver.save(self.sess,
                            self.ckpt_path + 'rnn_model_' + str(epoch) + '.ckpt')
        self.g_saver.save(self.sess,
                          self.ckpt_path + 'g_model_' + str(epoch) + '.ckpt')
        self.d_saver.save(self.sess,
                          self.ckpt_path + 'd_model_' + str(epoch) + '.ckpt')
        print('-----success saved checkpoint--------')

In [18]:
tf.reset_default_graph()
checkpoint_path = './checkpoint/'
inference_path = './inference'
gan = GAN(
    get_hparas(),
    training_phase=True,
    dataset_path=data_path,
    ckpt_path=checkpoint_path,
    inference_path=inference_path)
gan.training()

./model/preTrainRnn.ckpt-0
INFO:tensorflow:Restoring parameters from ./model/preTrainRnn.ckpt-0


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: [ 0/20] [  50/ 115] time: 0.2606s, d_loss: -4627.164, g_loss: -61208.746
Epoch: [ 0/20] [ 100/ 115] time: 0.2714s, d_loss: -8312.949, g_loss: -645429.500


  5%|▌         | 1/20 [00:28<09:08, 28.86s/it]

Epoch: [ 1/20] [  50/ 115] time: 0.2610s, d_loss: -9799.005, g_loss: -2107764.250
Epoch: [ 1/20] [ 100/ 115] time: 0.2759s, d_loss: -9429.382, g_loss: -3229984.500


 10%|█         | 2/20 [00:58<08:41, 28.98s/it]

Epoch: [ 2/20] [  50/ 115] time: 0.2640s, d_loss: -10455.911, g_loss: -4880709.000
Epoch: [ 2/20] [ 100/ 115] time: 0.2704s, d_loss: -9716.897, g_loss: -6336608.000


 15%|█▌        | 3/20 [01:27<08:16, 29.22s/it]

Epoch: [ 3/20] [  50/ 115] time: 0.2728s, d_loss: -11270.501, g_loss: -8427815.000
Epoch: [ 3/20] [ 100/ 115] time: 0.2807s, d_loss: -11264.778, g_loss: -10274084.000


 20%|██        | 4/20 [01:58<07:55, 29.71s/it]

Epoch: [ 4/20] [  50/ 115] time: 0.2796s, d_loss: -12193.809, g_loss: -12892684.000
Epoch: [ 4/20] [ 100/ 115] time: 0.3024s, d_loss: -13129.537, g_loss: -15023148.000


 25%|██▌       | 5/20 [02:30<07:36, 30.46s/it]

-----success saved checkpoint--------
Epoch: [ 5/20] [  50/ 115] time: 0.2896s, d_loss: -14209.454, g_loss: -18118926.000
Epoch: [ 5/20] [ 100/ 115] time: 0.2974s, d_loss: -14938.394, g_loss: -20593070.000


 30%|███       | 6/20 [03:03<07:13, 30.97s/it]

Epoch: [ 6/20] [  50/ 115] time: 0.2986s, d_loss: -16382.497, g_loss: -24146276.000
Epoch: [ 6/20] [ 100/ 115] time: 0.3005s, d_loss: -17012.910, g_loss: -26895306.000


 35%|███▌      | 7/20 [03:36<06:50, 31.58s/it]

Epoch: [ 7/20] [  50/ 115] time: 0.3120s, d_loss: -18100.441, g_loss: -31047616.000
Epoch: [ 7/20] [ 100/ 115] time: 0.3176s, d_loss: -19513.688, g_loss: -34210696.000


 40%|████      | 8/20 [04:10<06:27, 32.27s/it]

Epoch: [ 8/20] [  50/ 115] time: 0.3120s, d_loss: -21048.264, g_loss: -38699108.000
Epoch: [ 8/20] [ 100/ 115] time: 0.3139s, d_loss: -22645.482, g_loss: -42198496.000


 45%|████▌     | 9/20 [04:44<06:02, 32.97s/it]

Epoch: [ 9/20] [  50/ 115] time: 0.3245s, d_loss: -23774.645, g_loss: -47231304.000
Epoch: [ 9/20] [ 100/ 115] time: 0.3252s, d_loss: -25487.984, g_loss: -51542336.000


 50%|█████     | 10/20 [05:20<05:39, 33.99s/it]

-----success saved checkpoint--------
Epoch: [10/20] [  50/ 115] time: 0.3404s, d_loss: -27025.039, g_loss: -56909400.000
Epoch: [10/20] [ 100/ 115] time: 0.3322s, d_loss: -28660.713, g_loss: -61617320.000


 55%|█████▌    | 11/20 [05:57<05:12, 34.75s/it]

Epoch: [11/20] [  50/ 115] time: 0.3386s, d_loss: -30201.861, g_loss: -67014044.000
Epoch: [11/20] [ 100/ 115] time: 0.3418s, d_loss: -32084.277, g_loss: -71928432.000


 60%|██████    | 12/20 [06:34<04:43, 35.48s/it]

Epoch: [12/20] [  50/ 115] time: 0.3462s, d_loss: -33773.734, g_loss: -77839760.000
Epoch: [12/20] [ 100/ 115] time: 0.3471s, d_loss: -35919.184, g_loss: -83962640.000


 65%|██████▌   | 13/20 [07:12<04:13, 36.17s/it]

Epoch: [13/20] [  50/ 115] time: 0.3490s, d_loss: -37579.285, g_loss: -89217280.000
Epoch: [13/20] [ 100/ 115] time: 0.3496s, d_loss: -39702.281, g_loss: -96389096.000


 70%|███████   | 14/20 [07:51<03:41, 36.91s/it]

Epoch: [14/20] [  50/ 115] time: 0.3556s, d_loss: -41904.531, g_loss: -101585296.000
Epoch: [14/20] [ 100/ 115] time: 0.3575s, d_loss: -43974.004, g_loss: -108568208.000


 75%|███████▌  | 15/20 [08:31<03:09, 37.93s/it]

-----success saved checkpoint--------
Epoch: [15/20] [  50/ 115] time: 0.3616s, d_loss: -46135.164, g_loss: -114915864.000
Epoch: [15/20] [ 100/ 115] time: 0.3634s, d_loss: -48373.180, g_loss: -122466784.000


 80%|████████  | 16/20 [09:11<02:34, 38.54s/it]

Epoch: [16/20] [  50/ 115] time: 0.3684s, d_loss: -50555.484, g_loss: -129163000.000
Epoch: [16/20] [ 100/ 115] time: 0.3700s, d_loss: -52872.543, g_loss: -137036288.000


 85%|████████▌ | 17/20 [09:51<01:57, 39.16s/it]

Epoch: [17/20] [  50/ 115] time: 0.3773s, d_loss: -56020.500, g_loss: -143692720.000
Epoch: [17/20] [ 100/ 115] time: 0.3784s, d_loss: -58019.945, g_loss: -153428832.000


 90%|█████████ | 18/20 [10:33<01:19, 39.78s/it]

Epoch: [18/20] [  50/ 115] time: 0.3920s, d_loss: -61131.852, g_loss: -159472416.000
Epoch: [18/20] [ 100/ 115] time: 0.3782s, d_loss: -62923.660, g_loss: -170521328.000


 95%|█████████▌| 19/20 [11:15<00:40, 40.45s/it]

Epoch: [19/20] [  50/ 115] time: 0.3997s, d_loss: -66642.883, g_loss: -176016608.000
Epoch: [19/20] [ 100/ 115] time: 0.3890s, d_loss: -68481.297, g_loss: -187879712.000


100%|██████████| 20/20 [11:59<00:00, 41.54s/it]

-----success saved checkpoint--------





In [19]:
def data_iterator_test(filenames, batch_size):
    data = pd.read_pickle(filenames)
    captions = data['Captions'].values
    caption = []
    for i in range(len(captions)):
        caption.append(([word2Id_dict['<ST>']] + captions[i] + [word2Id_dict['<ED>']]))
    caption = np.asarray(caption)
    index = data['ID'].values
    index = np.asarray(index)

    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes

    return iterator, output_types, output_shapes

In [20]:
tf.reset_default_graph()
iterator_train, types, shapes = data_iterator_test(data_path + '/testData.pkl',
                                                   64)
iter_initializer = iterator_train.initializer
next_element = iterator_train.get_next()

with tf.Session() as sess:
    sess.run(iterator_train.initializer)
    next_element = iterator_train.get_next()
    caption, idex = sess.run(next_element)

In [21]:
tf.reset_default_graph()
gan = GAN(
    get_hparas(),
    training_phase=False,
    dataset_path=data_path,
    ckpt_path=checkpoint_path,
    inference_path=inference_path,
    recover=19)
img = gan.inference()

./model/preTrainRnn.ckpt-0
INFO:tensorflow:Restoring parameters from ./model/preTrainRnn.ckpt-0
INFO:tensorflow:Restoring parameters from ./checkpoint/rnn_model_19.ckpt
INFO:tensorflow:Restoring parameters from ./checkpoint/g_model_19.ckpt
-----success restored checkpoint--------


100%|██████████| 100/100 [00:09<00:00, 10.15it/s]
