In [1]:
import pandas as pd
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import tensorflow as tf

import scipy
from scipy.io import loadmat
import re

import string
import imageio
import numpy as np
import matplotlib.pyplot as plt
from utils import *
import random
import time
from tqdm import trange

import warnings
warnings.filterwarnings('ignore')

In [2]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower',
                                                     word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('2428',
                                                     id2word_dict['2428']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'],
                                         word2Id_dict['<RARE>']))

there are 6375 vocabularies in total
Word to id mapping, for example: flower -> 2428
Id to word mapping, for example: 2428 -> flower
Tokens: <PAD>: 6372; <RARE>: 6374


In [3]:
# Transform a sentence into its IDs and then add padding
def sent2IdList(line, MAX_SEQ_LENGTH=20):
  MAX_SEQ_LIMIT = MAX_SEQ_LENGTH
  padding = 0
  prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
  prep_line = prep_line.replace('-', ' ')
  prep_line = prep_line.replace('-', ' ')
  prep_line = prep_line.replace('  ', ' ')
  prep_line = prep_line.replace('.', '')
  tokens = prep_line.split(' ')
  tokens = [
      tokens[i] for i in range(len(tokens))
      if tokens[i] != ' ' and tokens[i] != ''
  ]
  l = len(tokens)
  padding = MAX_SEQ_LIMIT - l
  for i in range(padding):
    tokens.append('<PAD>')
  line = [
      word2Id_dict[tokens[k]]
      if tokens[k] in word2Id_dict else word2Id_dict['<RARE>']
      for k in range(len(tokens))
  ]

  return line

In [4]:
text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
print(sent2IdList(text))

the flower shown has yellow anther red pistil and bright red petals.
['2435', '2428', '2505', '2431', '2437', '2465', '2446', '2457', '2429', '2455', '2446', '6374', '6372', '6372', '6372', '6372', '6372', '6372', '6372', '6372']


## Dataset

In [5]:
data_path = './dataset'
df = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(df)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))

There are 7370 image in training data


In [6]:
df.head(5)

Unnamed: 0,Captions,ImagePath
1855,"[[2430, 2428, 2431, 2427, 2436, 2432, 2450, 24...",/102flowers/image_08110.jpg
6790,"[[2430, 2428, 2431, 2427, 2436, 2432, 2440, 24...",/102flowers/image_07749.jpg
7908,"[[2435, 2428, 2505, 2431, 2444, 2427, 2433, 24...",/102flowers/image_04381.jpg
1805,"[[2430, 2428, 2431, 2563, 2437, 2427, 2433, 24...",/102flowers/image_04518.jpg
5679,"[[2435, 2428, 2427, 2432, 5409, 2429, 2432, 24...",/102flowers/image_07620.jpg


In [13]:
df_img = df['ImagePath'].values
df_caption = df['Captions'].values
d_captions = []
for i in trange(len(df_caption)):
    caps = []
    for caption in df_caption[i]:
        cap = []
        cap.append(word2Id_dict['<ST>'])
        for word in caption:
            cap.append(word)
        cap.append(word2Id_dict['<ED>'])
        caps.append(cap)
    d_captions.append(caps)
    
d_captions = np.asarray(d_captions)
df_ = pd.DataFrame({
    'Captions': d_captions,
    'ImagePath': df_img                
})

df_.head(10)

100%|███████████████████████████████████████████████████████████████████████████| 7370/7370 [00:00<00:00, 26201.78it/s]


Unnamed: 0,Captions,ImagePath
0,"[[798, 2430, 2428, 2431, 2427, 2436, 2432, 245...",/102flowers/image_08110.jpg
1,"[[798, 2430, 2428, 2431, 2427, 2436, 2432, 244...",/102flowers/image_07749.jpg
2,"[[798, 2435, 2428, 2505, 2431, 2444, 2427, 243...",/102flowers/image_04381.jpg
3,"[[798, 2430, 2428, 2431, 2563, 2437, 2427, 243...",/102flowers/image_04518.jpg
4,"[[798, 2435, 2428, 2427, 2432, 5409, 2429, 243...",/102flowers/image_07620.jpg
5,"[[798, 2430, 2428, 2442, 2450, 2439, 2441, 243...",/102flowers/image_00724.jpg
6,"[[798, 2428, 2433, 2438, 2427, 2429, 2487, 244...",/102flowers/image_00550.jpg
7,"[[798, 2430, 2428, 2442, 2438, 2439, 2441, 243...",/102flowers/image_07209.jpg
8,"[[798, 2428, 2431, 2427, 2436, 2432, 2440, 243...",/102flowers/image_02334.jpg
9,"[[798, 2518, 2428, 2470, 2451, 2510, 2448, 242...",/102flowers/image_07389.jpg


In [14]:
#df_.to_csv('./dataset/text_ImgPath.csv')

In [15]:
df_t = pd.read_csv('./dataset/text_ImgPath.csv')

print (type(eval(df_t['Captions'][0])))

<class 'list'>


### Create Dataset Iterator by dataset api

In [16]:
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_DEPTH = 3

In [17]:
def train_data_generator(caption, image_path):
    # load in the image according to image path
    imagefile = tf.read_file(data_path + image_path)
    image = tf.image.decode_image(imagefile, channels=3)
    float_img = tf.image.convert_image_dtype(image, tf.float32)
    float_img.set_shape([None, None, 3])
    image = tf.image.resize_images(float_img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    
    image = tf.nn.l2_normalize(image, dim=[2])
    
    image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH])
    
    return image, caption

def data_iterator(filenames, batch_size, data_generator):
    # Load the training data into two Numpy arrays
    # df = pd.read_pickle(filenames)
    captions = df_['Captions'].values
    caption = []
    
    for i in range(len(captions)):
        caption.append(random.choice(captions[i]))
    caption = np.asarray(caption)
    
    image_path = df_['ImagePath'].values
    
    assert caption.shape[0] == image_path.shape[0]
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, image_path))
    dataset = dataset.map(data_generator)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    
    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes
    
    return iterator, output_types, output_shapes

def data_iterator_rnn(filenames, batch_size):
    captions = df_['Captions'].values
    caption = []
    
    for _ in range(3):
        for i in range(len(captions)):
            caption.append(random.choice(captions[i]))
    caption = np.asarray(caption)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    
    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes
    
    return iterator, output_types, output_shapes

### Iterate the data_iterator

In [18]:
BATCH_SIZE = 64

In [19]:
tf.reset_default_graph()

iterator_train, types, shapes = data_iterator(
    data_path + '/text2ImgData.pkl', BATCH_SIZE, train_data_generator)
iter_initializer = iterator_train.initializer
next_element = iterator_train.get_next()

with tf.Session() as sess:
  sess.run(iterator_train.initializer)
  next_element = iterator_train.get_next()
  image, text = sess.run(next_element)


## Build Model

### Text Encoder

A RNN encoder that captures the meaning of input text

    Input: text (a list of ids)
    Output: Hidden representation of input text


In [20]:
class TextEncoder(object):
    def __init__(self, sess, hparas, training_phase=True, reuse=False, return_embed=False):
        self.hparas = hparas
        self.sess = sess
        self.training_phase = training_phase
        self.return_embed = return_embed
        self.reuse = reuse
        self.sess = tf.Session()
        
        self._build_model()
        
    def _build_model(self):
        
        with tf.variable_scope('rnnftxt', reuse=self.reuse):
            # if self.training_phase:
            self.word_embed_matrix = tf.get_variable(
                'rnn/wordembed',
                shape=(self.hparas['VOCAB_SIZE'], self.hparas['EMBED_DIM']),
                initializer=tf.random_normal_initializer(stddev=0.02),
                dtype=tf.float32)
            #else:
            #    self.word_embed_matrix = tf.Variable(self.embed_matrix)
                
            self.text = tf.placeholder(dtype=tf.int64, shape=[self.hparas['BATCH_SIZE'], None], name='caption')
            embedded_word_ids = tf.nn.embedding_lookup(self.word_embed_matrix, self.text)
            embedded_word_ref = tf.nn.embedding_lookup(self.word_embed_matrix, self.text)
            
            # seq = tf.one_hot(self.seq, 22)
        with tf.variable_scope('rnncell', reuse=self.reuse):
            LSTMCell = tf.contrib.rnn.BasicLSTMCell(
                self.hparas['TEXT_DIM'],
                reuse=self.reuse)
            initial_state = LSTMCell.zero_state(
                self.hparas['BATCH_SIZE'],
                dtype=tf.float32)
            rnn_net = tf.nn.dynamic_rnn(
                cell=LSTMCell,
                inputs=embedded_word_ids,
                initial_state=initial_state,
                dtype=tf.float32,
                time_major=False,
                scope='rnn/dynamic')
            self.rnn_net = rnn_net
            self.outputs_last = rnn_net[0][:, -1, :]
            self.outputs = rnn_net[0]
        
        
        
        if self.training_phase:
            with tf.variable_scope('rnn/logits'):
                self.logits = tf.contrib.layers.fully_connected(
                    self.outputs, 
                    self.hparas['TEXT_DIM'], 
                    None)
                
            with tf.variable_scope('rnn/loss'):
                self.loss = tf.reduce_sum(
                    tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.logits[:, :-1], 
                    labels=embedded_word_ref[:, 1:]))
            
            with tf.variable_scope('rnn/optim'):
                self.optim = tf.train.AdamOptimizer(self.hparas['LR'], beta1=self.hparas['BETA']) \
                                        .minimize(self.loss)
            
        
        self.sess = tf.Session()
        self.saver = tf.train.Saver()
        global_step = tf.train.get_or_create_global_step()
        self.sess.run(tf.global_variables_initializer()) 
            
        
    def train(self, iterator_train):
        
        self.sess.run(iterator_train.initializer)
        
        self.losses = []
        for _epoch in trange(self.hparas['N_EPOCH']):
            start_time = time.time()
            
            epoch_loss = 0
            for _step in range(100):
                next_element = iterator_train.get_next()
                text = self.sess.run(next_element)
                
                loss, _ = self.sess.run([self.loss, self.optim],
                                        feed_dict={
                                            self.text: text
                                        })
                epoch_loss += loss
            self.losses.append(epoch_loss)
            
        self.save(global_step)
        
    def inference(self, text):
        output_embedded = self.sess.run(self.outputs_last, feed_dict={self.text: text})
        return output_embedded
        
    
    def save(self, global_step):
        gs = global_step.eval(self.sess)
        self.saver.save(self.sess, 'model/' + 'preTrainRnn.ckpt', global_step=gs)

    def restore(self):
        ckpt = tf.train.get_checkpoint_state('./model')
        if ckpt and ckpt.model_checkpoint_path:
            print (ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            print ("restore Fail!!")

### Generator

In [21]:
class Generator(object):
    """
        Using encodded text(hidden representation) to generate fake image data
        Inputs: Hidden representation of input text and random noise z as random seed
        Outputs: Target image in size 64x64x3
    """
    def __init__(self, noise_z, text, training_phase, hparas, reuse):
        """
        Args:
            noise_z: random generated fake image data, probabily with noise
            text: encodded text (hidden representation)
            training_phase: bool variable, indicate whether is during train phase or not
            hparas: hyperparameters
            reuse: bool variable, indicate if reuse trained weights or not
        """
        self.z = noise_z
        self.text = text
        self.train = training_phase
        self.hparas = hparas
        self.reuse = reuse

        self._build_model()

    def _build_model(self):
        with tf.variable_scope('generator', reuse=self.reuse):
            text_flatten = tf.contrib.layers.flatten(self.text)
            text_input = tf.layers.dense(
                    text_flatten,
                    self.hparas['TEXT_DIM'],
                    name='gen/text_input',
                    activation=tf.nn.leaky_relu,
                    reuse=self.reuse)
            
            z_text_concat = tf.concat(
                    [self.z, text_input],
                    axis=1,
                    name='gen/z_text_concat')
            
            z_text = tf.layers.dense(
                z_text_concat,
                units=16 * 16 * 128,
                name='gen/z_text_dense',
                reuse=self.reuse)
            z_text = tf.nn.tanh(z_text)
            
            z_reshape = tf.reshape(
                    z_text,
                    [self.hparas['BATCH_SIZE'], 16, 16, z_text_concat.shape[-1]],
                    name='gen/z_reshape')
            
            self.deconv1 = tf.layers.conv2d_transpose(z_reshape, 64, 5,
                                                      strides=2,
                                                      padding='same',
                                                      name='gen/deconv1')
            
            self.deconv2 = tf.layers.conv2d_transpose(self.deconv1, 3, 5,
                                                      strides=2,
                                                      padding='same',
                                                      name='gen/deconv2')

            g_net = self.deconv2
            g_net = tf.nn.tanh(g_net)
            self.generator_net = g_net
            self.outputs = g_net

### Discriminator

In [28]:
class Discriminator(object):
    """
        A binary classifier that discriminate real/fake image data
    1. True Image:
        Inputs: true image and pair text
        Outputs: a float to represent the result expected to be 1
    
    2. Fake Image:
        Inputs: generated fake image and paired image
        Outputs: a float to represent the result expected to be 0

    """

    def __init__(self, image, text, training_phase, hparas, reuse):
        self.image = image
        self.text = text
        self.training_phase = training_phase
        self.hparas = hparas
        self.reuse = reuse

        self._build_model()

    def _build_model(self):
        with tf.variable_scope('discriminator', reuse=self.reuse):
            text_flatten = tf.contrib.layers.flatten(self.text)
            text_input = tf.layers.dense(
                    text_flatten,
                    self.hparas['TEXT_DIM'],
                    activation=tf.nn.leaky_relu,
                    name='dis/text_input',
                    reuse=self.reuse)
            
            self.conv1 = tf.layers.conv2d(
                inputs=self.image,
                filters=32,
                kernel_size=[3, 3], #[5, 5]
                padding='same',
                activation=tf.nn.leaky_relu,
                name='discrim/conv1',
                reuse=self.reuse)
            
            self.pool1 = tf.layers.average_pooling2d(
                inputs=self.conv1,
                pool_size=[3, 3], #[4, 4]
                strides=2,
                padding='same',
                name='discrim/pool1')
            
            self.conv2 = tf.layers.conv2d(
                inputs=self.pool1,
                filters=64,
                kernel_size=[3, 3],
                padding='same',
                activation=tf.nn.leaky_relu,
                name='discrim/conv2d',
                reuse=self.reuse)
            
            self.pool2 = tf.layers.average_pooling2d(
                inputs=self.conv2,
                pool_size=[3, 3],
                strides=2,
                padding='same',
                name='discrim/pool2')
            
            image_flatten = tf.contrib.layers.flatten(self.pool2)
            
            image_input = tf.layers.dense(
                    image_flatten,
                    self.hparas['TEXT_DIM'],
                    name='dis/image_input',
                    reuse=self.reuse)

            img_text_concate = tf.concat(
                    [text_input, image_input],
                    axis=1,
                    name='dis/concate')

            d_net = tf.layers.dense(
                    img_text_concate,
                    1,
                    name='dis/d_net',
                    activation=tf.nn.leaky_relu,
                    reuse=self.reuse
                    )
            
            self.logits = d_net      
            
            d_net = tf.nn.dropout(d_net, keep_prob=0.5)
            
            net_output = tf.nn.sigmoid(d_net)
            
            self.discriminator = net_output
            self.outputs = net_output

### Build Main GAN Model

In [29]:
class GAN:

    def __init__(self,
               hparas,
               training_phase,
               dataset_path,
               ckpt_path,
               inference_path,
               recover=None):
        self.hparas = hparas
        self.train = training_phase
        self.dataset_path = dataset_path  # dataPath+'/text2ImgData.pkl'
        self.ckpt_path = ckpt_path
        self.sample_path = './samples'
        self.inference_path = './inference'

        self._get_session()  # get session
        self._get_train_data_iter()  # initialize and get data iterator
        self._input_layer()  # define input placeholder
        self._get_inference()  # build generator and discriminator
        self._get_loss()  # define gan loss
        self._get_var_with_name()  # get variables for each part of model
        self._optimize()  # define optimizer
        self._init_vars()
        self._get_saver()

        if recover is not None:
            self._load_checkpoint(recover)

    def _get_train_data_iter(self):
        if self.train:  # training data iteratot
            
            iterator_train, types, shapes = data_iterator(
                self.dataset_path + '/text2ImgData.pkl', self.hparas['BATCH_SIZE'],
                train_data_generator)
            
            iter_initializer = iterator_train.initializer
            next_element = iterator_train.get_next()
            # self.sess.run(iterator_train.initializer)
            self.iterator_train = iterator_train
        else:  # testing data iterator
            iterator_train, types, shapes = data_iterator_test(
                self.dataset_path + '/testData.pkl', self.hparas['BATCH_SIZE'])
            iter_initializer = iterator_train.initializer
            next_element = iterator_train.get_next()
            self.sess.run(iterator_train.initializer)
            self.iterator_test = iterator_train

    def _input_layer(self):
        if self.train:
            self.real_image = tf.placeholder(
              'float32', [
                  self.hparas['BATCH_SIZE'], self.hparas['IMAGE_SIZE'][0],
                  self.hparas['IMAGE_SIZE'][1], self.hparas['IMAGE_SIZE'][2]
              ],
              name='real_image')
            self.caption = tf.placeholder(
              dtype=tf.int64,
              shape=[self.hparas['BATCH_SIZE'], None],
              name='caption')
            self.z_noise = tf.placeholder(
              tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['Z_DIM']],
              name='z_noise')
            self.embed_text = tf.placeholder(
                tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['EMBED_DIM']])
            
        else:
            self.caption = tf.placeholder(
              dtype=tf.int64,
              shape=[self.hparas['BATCH_SIZE'], None],
              name='caption')
            self.z_noise = tf.placeholder(
              tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['Z_DIM']],
              name='z_noise')
            self.embed_text = tf.placeholder(
                tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['EMBED_DIM']])

    def _get_inference(self):
        if self.train:
            
            # GAN training
            # encoding text
            self.text_encoder = TextEncoder(
              self.sess, hparas=self.hparas, training_phase=False, reuse=False)
            # self.text_encoder.restore()
            # generating image
            generator = Generator(
              self.z_noise,
              self.embed_text,
              training_phase=True,
              hparas=self.hparas,
              reuse=False)
            self.generator = generator

            # discriminize
            # real image
            real_discriminator = Discriminator(
              self.real_image,
              self.embed_text,
              training_phase=True,
              hparas=self.hparas,
              reuse=False)
            self.real_discriminator = real_discriminator
            
            # fake image
            fake_discriminator = Discriminator(
              generator.outputs,
              self.embed_text,
              training_phase=True,
              hparas=self.hparas,
              reuse=True)
            self.fake_discriminator = fake_discriminator
            

        else:  # inference mode

            self.text_embed = TextEncoder(
              self.sess, hparas=self.hparas, training_phase=False, reuse=False)
            # self.text_embed.restore()
            self.generate_image_net = Generator(
              self.z_noise,
              self.embed_text,
              training_phase=False,
              hparas=self.hparas,
              reuse=False)

    def _get_loss(self):
        if self.train:
            self.epsilon = np.random.rand(1)[0]
            
            x_hat = (self.epsilon * self.real_image) + (1-self.epsilon) * self.generator.outputs
            
            x_hat_discriminator = Discriminator(
                x_hat,
                self.embed_text,
                training_phase=True,
                hparas=self.hparas,
                reuse=True)
            
            d_hat_loss = tf.reduce_mean(tf.square(tf.norm(
                tf.gradients(x_hat_discriminator.logits, x_hat),
                ord=2) - 1))
            
            d_loss1 = tf.reduce_mean(self.real_discriminator.logits)
            d_loss2 = tf.reduce_mean(self.fake_discriminator.logits)
            
            self.d_loss = d_loss2 - d_loss1 + \
                          10 * d_hat_loss
            
            self.g_loss = -tf.reduce_mean(self.fake_discriminator.logits)

    def _optimize(self):
        if self.train:
            with tf.variable_scope('learning_rate'):
                self.lr_var = tf.Variable(self.hparas['LR'], trainable=False)

                discriminator_optimizer = tf.train.AdamOptimizer(
                  self.lr_var, beta1=0.0, beta2=0.9)
                generator_optimizer = tf.train.AdamOptimizer(
                  self.lr_var, beta1=0.0, beta2=0.9)
                self.d_optim = discriminator_optimizer.minimize(
                  self.d_loss, var_list=self.discrim_vars)
                self.g_optim = generator_optimizer.minimize(
                  self.g_loss, var_list=self.generator_vars + self.text_encoder_vars)

    def training(self):
        
        self.sess.run(self.iterator_train.initializer)
        counter = 1
        n_critic = 3
        for _epoch in trange(self.hparas['N_EPOCH']):
            start_time = time.time()

            if _epoch != 0 and (_epoch % self.hparas['DECAY_EVERY'] == 0):
                new_lr_decay = self.hparas['LR_DECAY']**(
                    _epoch // self.hparas['DECAY_EVERY'])
                self.sess.run(tf.assign(self.lr_var, self.hparas['LR'] * new_lr_decay))
                print("new lr %f" % (self.hparas['LR'] * new_lr_decay))

            n_batch_epoch = int(self.hparas['N_SAMPLE'] / self.hparas['BATCH_SIZE'])
            for _step in range(n_batch_epoch):
                step_time = time.time()
                next_element = self.iterator_train.get_next()
                image_batch, caption_batch = self.sess.run(next_element)
                
                b_z = np.random.normal(
                    loc=0.0,
                    scale=1.0,
                    size=(self.hparas['BATCH_SIZE'],
                          self.hparas['Z_DIM'])).astype(np.float32)
                
                if counter % n_critic:
                
                    text_out = self.text_encoder.inference(caption_batch)
                
                    # update discriminator
                    self.discriminator_error, _ = self.sess.run(
                        [self.d_loss, self.d_optim],
                        feed_dict={
                            self.real_image: image_batch,
                            self.embed_text: text_out,
                            self.z_noise: b_z
                        })
                else:

                    # update generate
                    self.generator_error, _ = self.sess.run(
                        [self.g_loss, self.g_optim],
                        feed_dict={self.embed_text: text_out,
                                   self.z_noise: b_z})
                counter += 1
                if _step % 50 == 0 and _step > 10:
                    print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.3f, g_loss: %.3f" \
                          % (_epoch, self.hparas['N_EPOCH'], _step, n_batch_epoch,
                             time.time() - step_time,
                             self.discriminator_error, self.generator_error))
            if _epoch != 0 and (_epoch + 1) % 5 == 0:
                self._save_checkpoint(_epoch)
                self._sample_visiualize(_epoch)

    def inference(self):
        for _iters in trange(100):
            caption, idx = self.sess.run(self.iterator_test.get_next())
            z_seed = np.random.normal(
              loc=0.0,
              scale=1.0,
              size=(self.hparas['BATCH_SIZE'],
                    self.hparas['Z_DIM'])).astype(np.float32)
            
            rnn_out = self.text_embed.inference(caption)
            
            img_gen = self.sess.run(
                self.generate_image_net.outputs,
                feed_dict={
                    self.z_noise: z_seed,
                    self.embed_text: rnn_out
                })

            """img_gen, rnn_out = self.sess.run(
              [self.generate_image_net.outputs, self.text_embed.outputs],
              feed_dict={self.caption: caption,
                         self.z_noise: z_seed})"""
            for i in range(self.hparas['BATCH_SIZE']):
                scipy.misc.imsave(
                    self.inference_path + '/inference_{:04d}.png'.format(idx[i]),
                    img_gen[i])

    def _init_vars(self):
        self.sess.run(tf.global_variables_initializer())

    def _get_session(self):
        self.sess = tf.Session()

    def _get_saver(self):
        if self.train:
            self.rnn_saver = tf.train.Saver(var_list=self.text_encoder_vars)
            self.g_saver = tf.train.Saver(var_list=self.generator_vars)
            self.d_saver = tf.train.Saver(var_list=self.discrim_vars)
        else:
            self.rnn_saver = tf.train.Saver(var_list=self.text_encoder_vars)
            self.g_saver = tf.train.Saver(var_list=self.generator_vars)

    def _sample_visiualize(self, epoch):
        ni = int(np.ceil(np.sqrt(self.hparas['BATCH_SIZE'])))
        sample_size = self.hparas['BATCH_SIZE']
        max_len = self.hparas['MAX_SEQ_LENGTH']

        sample_seed = np.random.normal(
            loc=0.0, scale=1.0, size=(sample_size,
                                      self.hparas['Z_DIM'])).astype(np.float32)
        sample_sentence = [
            "the flower shown has yellow anther red pistil and bright red petals."
        ] * int(sample_size / ni) + [
            "this flower has petals that are yellow, white and purple and has dark lines"
        ] * int(sample_size / ni) + [
            "the petals on this flower are white with a yellow center"
        ] * int(sample_size / ni) + [
            "this flower has a lot of small round pink petals."
        ] * int(sample_size / ni) + [
            "this flower is orange in color, and has petals that are ruffled and rounded."
        ] * int(sample_size / ni) + [
            "the flower has yellow petals and the center of it is brown."
        ] * int(sample_size / ni) + [
            "this flower has petals that are blue and white."
        ] * int(sample_size / ni) + [
            "these white flowers have petals that start off white in color and end in a white towards the tips."
        ] * int(sample_size / ni)

        sample_captions = []
        
        for sent in sample_sentence:
            sample_captions.append(([word2Id_dict['<ST>']] + \
                                    sent2IdList(sent, max_len) + \
                                    [word2Id_dict['<ED>']]))
            
            
        """for i, sent in enumerate(sample_sentence):
            sample_sentence[i] = sent2IdList(sent, max_len)"""
            
        
        sample_captions = np.asarray(sample_captions)

        rnn_out = self.text_encoder.inference(sample_captions)
        
        img_gen = self.sess.run(
                self.generator.outputs,
                feed_dict={
                    self.z_noise: sample_seed,
                    self.embed_text: rnn_out
                })
        """img_gen, rnn_out = self.sess.run(
            [self.generator.outputs, self.text_encoder.outputs],
            feed_dict={self.caption: sample_sentence,
                       self.z_noise: sample_seed})"""
        
        save_images(img_gen, [ni, ni],
                    self.sample_path + '/train_{:02d}.png'.format(epoch))

    def _get_var_with_name(self):
        t_vars = tf.trainable_variables()

        self.text_encoder_vars = [var for var in t_vars if 'rnn' in var.name]
        self.generator_vars = [var for var in t_vars if 'generator' in var.name]
        self.discrim_vars = [var for var in t_vars if 'discrim' in var.name]

    def _load_checkpoint(self, recover):
        if self.train:
            self.rnn_saver.restore(
              self.sess, self.ckpt_path + 'rnn_model_' + str(recover) + '.ckpt')
            self.g_saver.restore(self.sess,
                               self.ckpt_path + 'g_model_' + str(recover) + '.ckpt')
            self.d_saver.restore(self.sess,
                               self.ckpt_path + 'd_model_' + str(recover) + '.ckpt')
        else:
            self.rnn_saver.restore(
              self.sess, self.ckpt_path + 'rnn_model_' + str(recover) + '.ckpt')
            self.g_saver.restore(self.sess,
                               self.ckpt_path + 'g_model_' + str(recover) + '.ckpt')
        print('-----success restored checkpoint--------')

    def _save_checkpoint(self, epoch):
        self.rnn_saver.save(self.sess,
                            self.ckpt_path + 'rnn_model_' + str(epoch) + '.ckpt')
        self.g_saver.save(self.sess,
                          self.ckpt_path + 'g_model_' + str(epoch) + '.ckpt')
        self.d_saver.save(self.sess,
                          self.ckpt_path + 'd_model_' + str(epoch) + '.ckpt')
        print('-----success saved checkpoint--------')

## Training

In [34]:
epoch = 50

In [35]:
def get_hparas():
    hparas = {
        'MAX_SEQ_LENGTH': 20,
        'EMBED_DIM': 64,  # word embedding dimension
        'VOCAB_SIZE': len(vocab),
        'TEXT_DIM': 64,  # text embrdding dimension
        'RNN_HIDDEN_SIZE': 64,
        'Z_DIM': 64,  # random noise z dimension
        'IMAGE_SIZE': [64, 64, 3],  # render image size
        'BATCH_SIZE': 64,
        'LR': 0.0001,
        'DECAY_EVERY': 10,
        'LR_DECAY': 0.5,
        'RNN_EPOCH': 10, # For pretrain RNN
        'BETA': 0.5,  # AdamOptimizer parameter
        'N_EPOCH': epoch,
        'N_SAMPLE': num_training_sample
    }
    return hparas

In [36]:
tf.reset_default_graph()
checkpoint_path = './checkpoint/'
inference_path = './inference'
gan = GAN(
    get_hparas(),
    training_phase=True,
    dataset_path=data_path,
    ckpt_path=checkpoint_path,
    inference_path=inference_path)

In [None]:
gan.training()


  0%|                                                                                           | 0/50 [00:00<?, ?it/s]


Epoch: [ 0/50] [  50/ 115] time: 1.3383s, d_loss: -2.653, g_loss: -1.136
Epoch: [ 0/50] [ 100/ 115] time: 1.5406s, d_loss: -2.373, g_loss: -0.898


  2%|█▌                                                                              | 1/50 [03:19<2:43:12, 199.84s/it]

Epoch: [ 1/50] [  50/ 115] time: 0.9470s, d_loss: -1.720, g_loss: -0.487
Epoch: [ 1/50] [ 100/ 115] time: 0.8011s, d_loss: -1.734, g_loss: -0.276


  4%|███▏                                                                            | 2/50 [05:01<2:16:23, 170.48s/it]

Epoch: [ 2/50] [  50/ 115] time: 0.9611s, d_loss: -1.122, g_loss: -1.265
Epoch: [ 2/50] [ 100/ 115] time: 0.9415s, d_loss: -0.817, g_loss: -2.040


  6%|████▊                                                                           | 3/50 [06:50<1:59:05, 152.04s/it]

Epoch: [ 3/50] [  50/ 115] time: 0.8412s, d_loss: -0.656, g_loss: -0.798
Epoch: [ 3/50] [ 100/ 115] time: 0.9570s, d_loss: -0.511, g_loss: -0.642


  8%|██████▍                                                                         | 4/50 [08:33<1:45:15, 137.28s/it]

Epoch: [ 4/50] [  50/ 115] time: 1.0388s, d_loss: -0.261, g_loss: -1.203
Epoch: [ 4/50] [ 100/ 115] time: 0.8026s, d_loss: -0.363, g_loss: -1.266
-----success saved checkpoint--------


 10%|████████                                                                        | 5/50 [10:43<1:41:11, 134.92s/it]

Epoch: [ 5/50] [  50/ 115] time: 1.1826s, d_loss: -0.309, g_loss: -1.103
Epoch: [ 5/50] [ 100/ 115] time: 0.9275s, d_loss: -0.329, g_loss: -0.645


 12%|█████████▌                                                                      | 6/50 [12:36<1:34:13, 128.50s/it]

Epoch: [ 6/50] [  50/ 115] time: 0.8613s, d_loss: -0.567, g_loss: -0.943
Epoch: [ 6/50] [ 100/ 115] time: 0.9921s, d_loss: -0.360, g_loss: -1.780


 14%|███████████▏                                                                    | 7/50 [14:28<1:28:28, 123.46s/it]

Epoch: [ 7/50] [  50/ 115] time: 0.8182s, d_loss: -0.539, g_loss: -2.675
Epoch: [ 7/50] [ 100/ 115] time: 1.4428s, d_loss: -0.459, g_loss: -2.821


 16%|████████████▊                                                                   | 8/50 [16:14<1:22:42, 118.15s/it]

Epoch: [ 8/50] [  50/ 115] time: 2.1572s, d_loss: -0.487, g_loss: -2.750
Epoch: [ 8/50] [ 100/ 115] time: 0.9761s, d_loss: -0.367, g_loss: -1.936


 18%|██████████████▍                                                                 | 9/50 [18:29<1:24:16, 123.33s/it]

Epoch: [ 9/50] [  50/ 115] time: 0.9255s, d_loss: -0.473, g_loss: -1.894
Epoch: [ 9/50] [ 100/ 115] time: 1.1546s, d_loss: -0.371, g_loss: -3.499
-----success saved checkpoint--------


 20%|███████████████▊                                                               | 10/50 [20:39<1:23:30, 125.27s/it]

new lr 0.000050
Epoch: [10/50] [  50/ 115] time: 1.0759s, d_loss: -0.517, g_loss: -1.264
Epoch: [10/50] [ 100/ 115] time: 0.8137s, d_loss: -0.491, g_loss: -1.453


 22%|█████████████████▍                                                             | 11/50 [22:33<1:19:20, 122.07s/it]

Epoch: [11/50] [  50/ 115] time: 0.8588s, d_loss: -0.568, g_loss: -1.886
Epoch: [11/50] [ 100/ 115] time: 0.8337s, d_loss: -0.432, g_loss: -1.032


 24%|██████████████████▉                                                            | 12/50 [24:06<1:11:37, 113.09s/it]

Epoch: [12/50] [  50/ 115] time: 0.6126s, d_loss: -0.497, g_loss: -1.441
Epoch: [12/50] [ 100/ 115] time: 0.7585s, d_loss: -0.474, g_loss: -1.132


 26%|████████████████████▌                                                          | 13/50 [25:30<1:04:32, 104.65s/it]

Epoch: [13/50] [  50/ 115] time: 0.7746s, d_loss: -0.497, g_loss: -1.238
Epoch: [13/50] [ 100/ 115] time: 1.6103s, d_loss: -0.363, g_loss: -1.249


 28%|██████████████████████                                                         | 14/50 [27:48<1:08:47, 114.65s/it]

Epoch: [14/50] [  50/ 115] time: 1.4972s, d_loss: -0.478, g_loss: -1.672
Epoch: [14/50] [ 100/ 115] time: 0.9425s, d_loss: -0.435, g_loss: -1.744
-----success saved checkpoint--------


 30%|███████████████████████▋                                                       | 15/50 [30:42<1:17:15, 132.45s/it]

Epoch: [15/50] [  50/ 115] time: 0.5475s, d_loss: -0.575, g_loss: -1.386
Epoch: [15/50] [ 100/ 115] time: 0.8017s, d_loss: -0.576, g_loss: -1.745


 32%|█████████████████████████▎                                                     | 16/50 [32:07<1:06:51, 118.00s/it]

Epoch: [16/50] [  50/ 115] time: 0.6990s, d_loss: -0.618, g_loss: -1.224
Epoch: [16/50] [ 100/ 115] time: 0.5318s, d_loss: -0.618, g_loss: -1.740


 34%|███████████████████████████▌                                                     | 17/50 [33:22<57:46, 105.05s/it]

Epoch: [17/50] [  50/ 115] time: 0.6852s, d_loss: -0.505, g_loss: -1.233
Epoch: [17/50] [ 100/ 115] time: 0.9204s, d_loss: -0.511, g_loss: -1.344


 36%|█████████████████████████████▌                                                    | 18/50 [34:46<52:44, 98.90s/it]

Epoch: [18/50] [  50/ 115] time: 0.6457s, d_loss: -0.695, g_loss: -1.500
Epoch: [18/50] [ 100/ 115] time: 0.6843s, d_loss: -0.464, g_loss: -1.836


 38%|███████████████████████████████▏                                                  | 19/50 [36:09<48:36, 94.09s/it]

Epoch: [19/50] [  50/ 115] time: 0.8287s, d_loss: -0.522, g_loss: -1.706
Epoch: [19/50] [ 100/ 115] time: 0.8864s, d_loss: -0.529, g_loss: -1.964
-----success saved checkpoint--------


 40%|████████████████████████████████▊                                                 | 20/50 [37:50<48:01, 96.03s/it]

new lr 0.000025


## Testing

### Define testing data interator

In [None]:
def data_iterator_test(filenames, batch_size):
  data = pd.read_pickle(filenames)
  captions = data['Captions'].values
  caption = []
  for i in range(len(captions)):
    caption.append(captions[i])
  caption = np.asarray(caption)
  index = data['ID'].values
  index = np.asarray(index)

  dataset = tf.data.Dataset.from_tensor_slices((caption, index))
  dataset = dataset.repeat()
  dataset = dataset.batch(batch_size)

  iterator = dataset.make_initializable_iterator()
  output_types = dataset.output_types
  output_shapes = dataset.output_shapes

  return iterator, output_types, output_shapes

In [None]:
tf.reset_default_graph()
iterator_train, types, shapes = data_iterator_test(data_path + '/testData.pkl',
                                                   64)
iter_initializer = iterator_train.initializer
next_element = iterator_train.get_next()

with tf.Session() as sess:
  sess.run(iterator_train.initializer)
  next_element = iterator_train.get_next()
  caption, idex = sess.run(next_element)

### Inference test data

In [None]:
tf.reset_default_graph()
gan = GAN(
    get_hparas(),
    training_phase=False,
    dataset_path=data_path,
    ckpt_path=checkpoint_path,
    inference_path=inference_path,
    recover=epoch-1)
img = gan.inference()