In [2]:
import pandas as pd
import os
import tensorflow as tf

import scipy
from scipy.io import loadmat
import re

import string
import imageio
import numpy as np
import matplotlib.pyplot as plt
from utils import *
import random
import time
import nltk

import warnings
warnings.filterwarnings('ignore')

In [3]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path+'/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path+'/word2Id.npy'))
id2word_dict =  dict(np.load(dictionary_path+'/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s'%('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s'%('2428', id2word_dict['2428']))
print('Tokens: <PAD>: %s; <RARE>: %s'%(word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))
print(word2Id_dict['<RARE>'])
print(id2word_dict['5428'])

there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 2428 -> polkadots
Tokens: <PAD>: 5427; <RARE>: 5428
5428
<RARE>


In [4]:
def sent2IdList(line, MAX_SEQ_LENGTH=20):
    MAX_SEQ_LIMIT = MAX_SEQ_LENGTH
    padding = 0
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    prep_line = prep_line.replace('-', ' ')
    tokens = []
    tokens.extend(nltk.tokenize.word_tokenize(prep_line.lower()))
    l = len(tokens)
    padding = MAX_SEQ_LIMIT - l
    for i in range(padding):
        tokens.append('<PAD>')
    line = [word2Id_dict[tokens[k]] if tokens[k] in word2Id_dict else word2Id_dict['<RARE>'] for k in range(len(tokens))]
    
    return line

text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
print(sent2IdList(text))
print(len(vocab))

the flower shown has yellow anther red pistil and bright red petals.
['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']
5427


In [5]:
data_path = './dataset'
df = pd.read_pickle(data_path+'/text2ImgData.pkl')
num_training_sample = len(df)
n_images_train = num_training_sample
print('There are %d image in training data'%(n_images_train))

There are 7370 image in training data


In [6]:
import math
IMAGE_HEIGHT = 64
IMAGE_WIDTH = 64
IMAGE_DEPTH = 3
def training_data_generator(caption, image_path):
    # load in the image according to image path
    imagefile = tf.read_file(image_path)
    image = tf.image.decode_image(imagefile, channels=3)
    float_img = tf.image.convert_image_dtype(image, tf.float32)
    float_img.set_shape([None, None, 3])
    image = float_img
    #degrees = tf.random_uniform((1,), 0,10)
    #image = tf.contrib.image.rotate(image, degrees * math.pi / 180, interpolation='BILINEAR')
    image = tf.image.resize_images(image, size = [80, 80])
    image = tf.random_crop(image, size=[64, 64, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    #image = tf.image.random_brightness(image, max_delta=0.1)
    
    image = tf.minimum(image, 1.0)
    image = tf.maximum(image, 0.0)
    #degrees = tf.random_uniform((1,), 0,45)
    #image = tf.contrib.image.rotate(image, degrees * math.pi / 180, interpolation='BILINEAR')
    #image = tf.image.flip_left_right(image)
    image = (image*2) - 1 
    image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH])
    
    return image, caption

def data_iterator(filenames, batch_size, data_generator):
    # Load the training data into two NumPy arrays
    df = pd.read_pickle(filenames)
    captions = df['Captions'].values
    caption = []
    for i in range(len(captions)):
        word = random.choice(captions[i])
        caption.append(word ) 
    caption = np.asarray(caption)
    image_path = df['ImagePath'].values

    # Assume that each row of `features` corresponds to the same row as `labels`.
    assert caption.shape[0] == image_path.shape[0]
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, image_path))
    dataset = dataset.shuffle(7370)
    dataset = dataset.map(data_generator)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    
    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes
    
    return iterator, output_types, output_shapes

In [67]:
def get_length(sequence):
    mask = tf.logical_not(tf.equal(sequence, word2Id_dict['<PAD>']))
    print(mask.shape)
    length = tf.reduce_sum(tf.cast(mask,tf.int32), 1)
    #length = tf.cast(length, tf.int32)
    return length
def last_relevant(output, length):
    batch_size = tf.shape(output)[0]
    max_length = tf.shape(output)[1]
    out_size = int(output.get_shape()[2])
    index = tf.range(0, batch_size) * max_length + (length - 1)
    flat = tf.reshape(output, [-1, out_size])
    relevant = tf.gather(flat, index)
    return relevant
class TextEncoder:
    """
    Encode text (a caption) into hidden representation
    input: text (a list of id)
    output: hidden representation of input text in dimention of TEXT_DIM
    """
    def __init__(self, text, hparas, training_phase=True, reuse=False, return_embed=False):
        self.text = text
        self.hparas = hparas
        self.train = training_phase
        self.reuse = reuse
        self._build_model()
    def _build_model(self):
        with tf.variable_scope('rnnftxt', reuse=self.reuse):
            # Word embedding
            txt_len = get_length(self.text)
            word_embed_matrix = tf.get_variable('rnn/wordembed', 
                                                shape=(self.hparas['VOCAB_SIZE'], self.hparas['EMBED_DIM']),
                                                initializer=tf.random_normal_initializer(stddev=0.02),
                                                dtype=tf.float32)
            embedded_word_ids = tf.nn.embedding_lookup(word_embed_matrix, self.text)
            # RNN encoder
            LSTMCell = tf.nn.rnn_cell.LSTMCell(self.hparas['TEXT_DIM'], 
                                               initializer=tf.random_normal_initializer(stddev=0.02), 
                                               reuse=self.reuse)
            initial_state = LSTMCell.zero_state(self.hparas['BATCH_SIZE'], dtype=tf.float32)
            rnn_net = tf.nn.dynamic_rnn(cell=LSTMCell, 
                                        inputs=embedded_word_ids, 
                                        initial_state=initial_state, 
                                        dtype=np.float32, time_major=False,
                                        scope='rnn/dynamic',
                                        sequence_length=txt_len)
            self.rnn_net = rnn_net
            self.outputs = last_relevant(rnn_net[0], txt_len)

In [68]:

class Generator:
    def __init__(self, noise_z, text, training_phase, hparas, reuse, is_train, p):
        self.z = noise_z
        self.text = text
        self.train = training_phase
        self.hparas = hparas
        self.gf_dim = 128
        self.reuse = reuse
        self.k_init = tf.random_normal_initializer(stddev=0.02)
        self.is_train = is_train
        self.keep_prob = p
        self._build_model()
    def Encode(self, x):

        with tf.variable_scope('encode', reuse=self.reuse) as scope:
            ep = tf.random_normal(shape=[64, 128])
            
            fc1 = tf.nn.relu(tf.layers.dense(x, 1024, name='z_fc1', 
                                         kernel_initializer=self.k_init, reuse=self.reuse))
            fc2 = tf.nn.relu(tf.layers.dense(x, 1024, name='z_fc2', 
                                         kernel_initializer=self.k_init, reuse=self.reuse))
            self.z_mean = tf.layers.dense(fc2 , 128, name='z_mean', 
                                         kernel_initializer=self.k_init, reuse=self.reuse)
            self.z_sigma = tf.layers.dense(fc2 , 128, name='z_sigma', 
                                         kernel_initializer=self.k_init, reuse=self.reuse)
            z_x = tf.add(self.z_mean, tf.sqrt(tf.exp(self.z_sigma))*ep)

            return z_x    
    def _build_model(self):
        with tf.variable_scope('generator', reuse=self.reuse):
            gf_dim = 64
            g_init = tf.random_normal_initializer(1., 0.2)
            #text_flatten = tf.layers.flatten(self.text)
            text_input = tf.layers.dense(self.text, self.hparas['TEXT_DIM'], name='generator/text_input', 
                                         kernel_initializer=self.k_init, reuse=self.reuse)
            # 
            z_text_concat = tf.concat([self.z, text_input], axis=1, name='generator/z_text_concat')
            #z_text_concat = self.Encode(z_text_concat)
            pre_conv = tf.layers.dense(z_text_concat, gf_dim * 8 * 4 * 4, kernel_initializer=self.k_init, 
                                       name='generator/pre_dense', reuse=self.reuse)
            pre_conv = tf.contrib.gan.features.VBN(pre_conv, gamma_initializer=g_init,name= 'generator/bn0')(pre_conv)
            pre_conv = tf.reshape(pre_conv, [-1, 4, 4, gf_dim*8],name='generator/to_conv')
            pre_conv = tf.nn.selu(pre_conv, name='generator/act_pre_conv1')
            #pre_conv = tf.nn.relu(pre_conv, name='generator/act_pre_conv')
            res0 = tf.keras.layers.UpSampling2D(size=(2, 2))(pre_conv)
            res0 = tf.layers.conv2d(res0, int(res0.shape[-1]), 1, 1, kernel_initializer=self.k_init, reuse=self.reuse, name='res0')
            
            upconv_1 = tf.layers.conv2d_transpose(pre_conv, gf_dim * 8, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='generator/upconv1',
                                                  reuse=self.reuse)
            upconv_1 = tf.contrib.gan.features.VBN(upconv_1, gamma_initializer=g_init, name= 'generator/bn1')(upconv_1)
            upconv_1 = tf.nn.selu(upconv_1, name='generator/act_up_conv1') + tf.nn.selu(res0)
            res1 = tf.keras.layers.UpSampling2D(size=(2, 2))(upconv_1)
            res1 = tf.layers.conv2d(res1, int(res1.shape[-1])//2, 1, 1, kernel_initializer=self.k_init, reuse=self.reuse, name='res1')
            #upconv_1 = tf.nn.dropout(upconv_1,self.keep_prob,name='generator/drop1')
            upconv_2 = tf.layers.conv2d_transpose(upconv_1, gf_dim * 4, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='generator/upconv2',
                                                  reuse=self.reuse)
            upconv_2 = tf.contrib.gan.features.VBN(upconv_2, gamma_initializer=g_init, name= 'generator/bn2')(upconv_2)
            upconv_2 = tf.nn.selu(upconv_2, name='generator/act_up_conv2') + tf.nn.selu(res1)
            res2 = tf.keras.layers.UpSampling2D(size=(2, 2))(upconv_2)
            res2 = tf.layers.conv2d(res2, int(res2.shape[-1])//2, 1, 1, kernel_initializer=self.k_init, reuse=self.reuse, name='res2')
            
            upconv_3 = tf.layers.conv2d_transpose(upconv_2, gf_dim*2, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='generator/upconv3',
                                                  reuse=self.reuse)
            upconv_3 = tf.contrib.gan.features.VBN(upconv_3, gamma_initializer=g_init, name= 'generator/bn3')(upconv_3)
            upconv_3 = tf.nn.selu(upconv_3, name='generator/act_up_conv3') + tf.nn.selu(res2)
            res3 = tf.keras.layers.UpSampling2D(size=(2, 2))(upconv_3)
            res3 = tf.layers.conv2d(res3, 16, 1, 1, kernel_initializer=self.k_init, reuse=self.reuse, name='res3')
            #upconv_3 = tf.nn.dropout(upconv_3, self.keep_prob,name='generator/drop2')
            
            upconv_4 = tf.layers.conv2d_transpose(upconv_3, 16, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='generator/upconv4',
                                                  reuse=self.reuse)
            upconv_4 = tf.contrib.gan.features.VBN(upconv_4, gamma_initializer=g_init, name= 'generator/bn4')(upconv_4)
            upconv_4 = tf.nn.selu(upconv_4, name='generator/act_up_conv4') 
            upconv_4 = tf.layers.conv2d_transpose(upconv_4, 3, 1, 1, padding='same', 
                                                 kernel_initializer=self.k_init, name='generator/upconv4_2',
                                                  reuse=self.reuse)
            #upconv_4 = tf.layers.batch_normalization(upconv_4, training=self.is_train, gamma_initializer=g_init, name= 'generator/bn4')
            upconv_fn = tf.identity(upconv_4, name='generator/act_up_logit')
            
            g_net = tf.nn.tanh(upconv_fn, name='generator/act_up_convfn')
            
            self.generator_net = g_net
            self.outputs = g_net

In [69]:
# resnet structure
class Discriminator:
    def __init__(self, image, text, training_phase, hparas, reuse, is_train):
        self.image = image
        self.text = text
        self.train = training_phase
        self.hparas = hparas
        self.df_dim = 64 # 196 for MSCOCO
        self.reuse = reuse
        self.k_init = tf.random_normal_initializer(stddev=0.02)
        self.is_train = is_train
        self.act_fn = tf.nn.leaky_relu
        
        self._build_model()
        
    
    def _build_model(self):        
        with tf.variable_scope('discriminator', reuse=self.reuse):
            g_init = tf.random_normal_initializer(1., 0.1)
            edge = tf.reduce_mean(self.image, -1,keepdims=True)
            kernel_h = np.array([ [1,2,1], [0,0,0], [-1,-2,-1] ])/4
            kernel_v = np.array([ [1,0,-1], [2,0,-2], [1,0,-1] ])/4
            conv_w_h = tf.constant(kernel_h, dtype=tf.float32, shape=(3, 3, 1, 1),name='Const_h')
            conv_w_v = tf.constant(kernel_v, dtype=tf.float32, shape=(3, 3, 1, 1),name='Const_v')    
            edge1 = tf.nn.conv2d(input=edge, filter=conv_w_h, strides=[1, 1, 1, 1], padding='SAME')
            edge2 = tf.nn.conv2d(input=edge, filter=conv_w_v, strides=[1, 1, 1, 1], padding='SAME')
            self.image = tf.concat([self.image, (edge1+edge2)/2], 3)
            conv_1 = tf.layers.conv2d(self.image, self.df_dim, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='discriminator/conv1',
                                                  reuse=self.reuse)
            conv_1 = tf.contrib.layers.layer_norm(conv_1)
            conv_1 = self.act_fn(conv_1, name='discriminator/act_conv1')
            res0 = tf.layers.conv2d(conv_1, self.df_dim*4, 1, 1, kernel_initializer=self.k_init, reuse=self.reuse, name='res0')
            res0 = tf.layers.average_pooling2d(res0,5,2,padding='same')

            conv_2 = tf.layers.conv2d(conv_1, self.df_dim*4, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='discriminator/conv2',
                                                  reuse=self.reuse)
            conv_2 = tf.contrib.layers.layer_norm(conv_2)
            conv_2 = self.act_fn(conv_2, name='discriminator/act_conv2')
            res1 = tf.layers.conv2d(conv_2, self.df_dim*6, 1, 1, kernel_initializer=self.k_init, reuse=self.reuse, name='res1')
            res1 = tf.layers.average_pooling2d(res1,5,2,padding='same')
            conv_3 = tf.layers.conv2d(conv_2, self.df_dim*6, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='discriminator/conv3',
                                                  reuse=self.reuse)
            conv_3 = tf.contrib.layers.layer_norm(conv_3)
            conv_3 = self.act_fn(conv_3, name='discriminator/act_conv3')
            
            # multi-task branch
            dim_3 = tf.layers.conv2d(conv_3, self.df_dim*3, 5, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='discriminator/dim3_1',
                                                  reuse=self.reuse)
            dim_3 = tf.contrib.layers.layer_norm(dim_3)
            dim_3 = self.act_fn(dim_3)
            dim_3 = tf.layers.conv2d(dim_3, self.df_dim, 1, 1, padding='same', 
                                                 kernel_initializer=self.k_init, name='discriminator/dim3_2',
                                                  reuse=self.reuse)
            dim_3 = tf.contrib.layers.layer_norm(dim_3)
            dim_3 = self.act_fn(dim_3)
            flat_3 = tf.layers.flatten(dim_3,'f_conv')
            self._to_rnn = tf.layers.dense(flat_3, 128, kernel_initializer=self.k_init, 
                                       name='discriminator/_to_rnn', reuse=self.reuse)
            
            # Text part
            text_flat = tf.layers.flatten(self.text,'f_text')
            text_in = tf.layers.dense(text_flat, self.df_dim*6, kernel_initializer=self.k_init, 
                                       name='discriminator/pre_dense', reuse=self.reuse)   
            text_in = self.act_fn(text_in)
            text_in = tf.expand_dims(tf.expand_dims(text_in,1),1)
            text_in = tf.tile(text_in, [1, conv_3.shape[1],conv_3.shape[2], 1])
            feature_concat = tf.concat([text_in, conv_3], 3)
            feature_concat = tf.contrib.layers.layer_norm(feature_concat)
            
            
            conv_4 = tf.layers.conv2d(feature_concat, self.df_dim*6, 1, 1, padding='same', 
                                                 kernel_initializer=self.k_init, name='discriminator/conv4',
                                                  reuse=self.reuse)
            conv_4 = tf.contrib.layers.layer_norm(conv_4)
            conv_4 = self.act_fn(conv_4, name='discriminator/act_conv4')
            
            conv_5 = tf.layers.conv2d(conv_4, self.df_dim*8, 2, 2, padding='same', 
                                                 kernel_initializer=self.k_init, name='discriminator/conv5',
                                                  reuse=self.reuse)
            conv_5 = tf.contrib.layers.layer_norm(conv_5)
            conv_5 = self.act_fn(conv_5, name='discriminator/act_conv5')
            
            conv_5_flat = tf.layers.flatten(conv_5,'f_text')
            self.logits = tf.layers.dense(conv_5_flat, 1, kernel_initializer=self.k_init, 
                                       name='discriminator/final', reuse=self.reuse)              
            self.discriminator_net = self.logits
            self.outputs = self.logits

In [70]:
def get_hparas():
    hparas = {
        'MAX_SEQ_LENGTH' : 20,
        'EMBED_DIM' : 128, # word embedding dimension
        'VOCAB_SIZE' : len(vocab)+1,
        'TEXT_DIM' : 128, # text embedding dimension
        'RNN_HIDDEN_SIZE' : 128,
        'Z_DIM' : 128, # random noise z dimension
        'IMAGE_SIZE' : [64, 64, 3], # render image size
        'BATCH_SIZE' : 64,
        'LR' : 2e-4,
        'BETA' : 0., # AdamOptimizer parameter
        'N_EPOCH' : 800*5,
        'N_SAMPLE' : num_training_sample
    }
    return hparas

In [None]:
class GAN:
    def __init__(self, hparas, training_phase, dataset_path, ckpt_path, inference_path, recover=None):
        self.hparas = hparas
        self.train = training_phase
        self.dataset_path = dataset_path # dataPath+'/text2ImgData.pkl'
        self.ckpt_path = ckpt_path
        self.sample_path = './samples'
        self.inference_path = './inference'
        self.train_compare = False
        
        self._get_session() # get session
        self._get_train_data_iter() # initialize and get data iterator
        self._input_layer() # define input placeholder
        self._get_inference() # build generator and discriminator
        self._get_loss() # define gan loss
        self._get_var_with_name() # get variables for each part of model
        self._optimize() # define optimizer
        self._init_vars()
        self._get_saver()
        self.fixed_sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(self.hparas['BATCH_SIZE'], self.hparas['Z_DIM'])).astype(np.float32)
        if recover is not None:
            self._load_checkpoint(recover)
            
            
        
    def _get_train_data_iter(self):
        if self.train: # training data iteratot
            iterator_train, types, shapes = data_iterator(self.dataset_path+'/text2ImgData.pkl',
                                                          self.hparas['BATCH_SIZE'], training_data_generator)
            iter_initializer = iterator_train.initializer
            self.next_element = iterator_train.get_next()
            self.sess.run(iterator_train.initializer)
            self.iterator_train = iterator_train
        else: # testing data iterator
            iterator_test, types, shapes = data_iterator_test(self.dataset_path+'/testData.pkl', self.hparas['BATCH_SIZE'])
            iter_initializer = iterator_test.initializer
            self.next_element = iterator_test.get_next()
            self.sess.run(iterator_test.initializer)
            self.iterator_test = iterator_test
            
    def _input_layer(self):
        if self.train:
            self.real_image = tf.placeholder('float32',
                                              [self.hparas['BATCH_SIZE'], self.hparas['IMAGE_SIZE'][0],
                                               self.hparas['IMAGE_SIZE'][1], self.hparas['IMAGE_SIZE'][2]],
                                              name='real_image')
            self.noise_level = tf.placeholder('float32', shape=(),name='n_level')
            self.real_image_n = self.real_image #+ tf.random_normal(shape=tf.shape(self.real_image), mean=0.0, stddev=self.noise_level, dtype=tf.float32)
            self.caption = tf.placeholder(dtype=tf.int64, shape=[self.hparas['BATCH_SIZE'], None], name='caption')
            self.z_noise = tf.placeholder(tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['Z_DIM']], name='z_noise')
            self.training_flags = tf.placeholder(dtype=tf.bool, name='is_train')
            self.keep_prob = tf.placeholder(dtype=tf.float32, name='k_prob')
             
        else:
            self.caption = tf.placeholder(dtype=tf.int64, shape=[self.hparas['BATCH_SIZE'], None], name='caption')
            self.z_noise = tf.placeholder(tf.float32, [self.hparas['BATCH_SIZE'], self.hparas['Z_DIM']], name='z_noise')
            self.training_flags = tf.placeholder(dtype=tf.bool, name='is_train')
            self.keep_prob = tf.placeholder(dtype=tf.float32, name='k_prob')
    def _get_inference(self):
        if self.train:
            # GAN training
            # encoding text
            text_encoder = TextEncoder(self.caption, hparas = self.hparas, training_phase=True, reuse=False)
            self.text_encoder = text_encoder
            print(text_encoder.outputs.shape)
            # generating image
            generator = Generator(self.z_noise, text_encoder.outputs, training_phase=True,
                                  hparas=self.hparas, reuse=False, is_train = self.training_flags,p=self.keep_prob)
            self.generator = generator
            print(generator.outputs.shape)
            
            # discriminize
            # fake image
            fake_discriminator = Discriminator(generator.outputs, text_encoder.outputs,
                                               training_phase=True, hparas=self.hparas, reuse=False, is_train = self.training_flags)
            self.fake_discriminator = fake_discriminator
            # real image
            real_discriminator = Discriminator(self.real_image_n, text_encoder.outputs, training_phase=True,
                                              hparas=self.hparas, reuse=True, is_train = self.training_flags)
            # Wrong caption
            self.roll_caps = tf.roll(text_encoder.outputs, 10, 0)
            #if self.train_compare:
            #    print('Training with mismatch samples')
            self.wrong_discriminator = Discriminator(self.real_image_n, self.roll_caps, training_phase=True,
                                             hparas=self.hparas, reuse=True, is_train = self.training_flags)
            
            
            self.real_discriminator = real_discriminator
            #alph = tf.random.uniform([self.hparas['BATCH_SIZE'], 1, 1, 1], 0., 1.)
            difference = self.generator.outputs - self.real_image_n
            interpolates = 0.5*self.real_image_n + 0.5*difference
            interpolates = tf.reshape(interpolates, self.real_image.shape)
            text_embd = text_encoder.outputs
            inter_discriminator = Discriminator(interpolates, text_embd, training_phase=True,
                                             hparas=self.hparas, reuse=True, is_train = self.training_flags)
            l_gp = tf.gradients(self.real_discriminator.outputs, [self.real_image_n])
            gradients = l_gp[0] #+ 1e-8
            #gradients_embd = l_gp[1]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=list(range(1, self.real_image_n.shape.ndims))))
            #slopes_embd = tf.sqrt(tf.reduce_sum(tf.square(gradients_embd), axis=list(range(1, text_embd.shape.ndims))))
            self.gradient_penalty = tf.reduce_mean(tf.maximum(0.0, slopes)**2) #+ tf.reduce_mean(tf.maximum(0.0, slopes_embd)**2)
            
        else: # inference mode
            
            self.text_embed = TextEncoder(self.caption, hparas=self.hparas, training_phase=False, reuse=False)
            self.generate_image_net = Generator(self.z_noise, self.text_embed.outputs, training_phase=False,
                                                hparas=self.hparas, reuse=False, is_train = self.training_flags, p=self.keep_prob)
    def KL_loss(self, mu, log_var):
        return -0.5 * tf.reduce_sum(1 + log_var - tf.pow(mu, 2) - tf.exp(log_var))
        
    def _get_loss(self):
        if self.train:
#             d_loss1 =  tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.real_discriminator.logits,
#                                                                             labels=0.9*tf.ones_like(self.real_discriminator.logits),
#                                                                             name='d_loss1'))
            d_loss1 = tf.reduce_mean(self.real_discriminator.logits)
#             d_loss2 =  tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_discriminator.logits,
#                                                                             labels=tf.zeros_like(self.fake_discriminator.logits),
#                                                                             name='d_loss2'))
            d_loss2 = tf.reduce_mean(self.fake_discriminator.logits)
            
            d_loss3 = tf.reduce_mean(self.wrong_discriminator.logits)
#             d_loss3 =  tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.wrong_discriminator.logits,
#                                                                             labels=tf.zeros_like(self.fake_discriminator.logits),
#                                                                             name='d_loss2'))            
            #if self.train_compare:
            #d_loss3 = tf.reduce_mean(self.wrong_discriminator.logits)
            #self.d_loss = (d_loss2 + d_loss3)/2 - d_loss1 + 10 * self.gradient_penalty
#             self.g_loss =  tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.fake_discriminator.logits,
#                                                                                  labels=tf.ones_like(self.fake_discriminator.logits),
#                                                                                 name='g_loss'))
            #d_loss3 = tf.reduce_mean(self.wrong_discriminator.logits)
            #self.rnn_loss = d_loss3 - d_loss1
            word_sim_T = tf.reduce_sum(self.real_discriminator._to_rnn * self.text_encoder.outputs,1)/(tf.sqrt(tf.reduce_sum(self.real_discriminator._to_rnn* self.real_discriminator._to_rnn, 1)+1e-8) * tf.sqrt(tf.reduce_sum(self.text_encoder.outputs* self.text_encoder.outputs, 1)+1e-8))
            word_sim_F = tf.reduce_sum(self.wrong_discriminator._to_rnn * self.roll_caps,1)/(tf.sqrt(tf.reduce_sum(self.wrong_discriminator._to_rnn* self.wrong_discriminator._to_rnn, 1)+1e-8) *  tf.sqrt(tf.reduce_sum(self.roll_caps* self.roll_caps, 1)+1e-8))
            self.rnn_loss = tf.reduce_mean(tf.maximum(0., 0.5-word_sim_T+word_sim_F))
            #self.kl_loss = self.KL_loss(self.generator.z_mean, self.generator.z_sigma)
            self.d_loss = (d_loss2)  -  d_loss1 + self.rnn_loss +10 * self.gradient_penalty# + self.kl_loss
            self.g_loss = -d_loss2 #+ self.kl_loss/(128*64)
                        

    def _optimize(self):
        if self.train:
            with tf.variable_scope('learning_rate'):
                self.lr_var = tf.Variable(self.hparas['LR'], trainable=False)

            discriminator_optimizer = tf.train.AdamOptimizer(self.lr_var, beta1=self.hparas['BETA'])
            #discriminator_optimizer = tf.contrib.opt.AdamWOptimizer(0.005, self.lr_var, beta1=self.hparas['BETA'])
            #discriminator_optimizer = tf.train.RMSPropOptimizer(self.lr_var)
            generator_optimizer = tf.train.AdamOptimizer(self.lr_var, beta1=self.hparas['BETA'])
            #rnn_optimizer = tf.train.AdamOptimizer(self.lr_var, beta1=self.hparas['BETA'])
            #generator_optimizer = tf.contrib.opt.AdamWOptimizer(0.0005, self.lr_var, beta1=self.hparas['BETA'])
            #generator_optimizer = tf.train.RMSPropOptimizer(self.lr_var)
            self.d_optim = discriminator_optimizer.minimize(self.d_loss, var_list=self.discrim_vars)
            self.d_optim_nornn =  discriminator_optimizer.minimize(self.d_loss, var_list=self.discrim_vars)
            self.g_optim = generator_optimizer.minimize(self.g_loss, var_list=self.generator_vars)
            #self.rnn_optim = rnn_optimizer.minimize(self.rnn_loss, var_list=self.text_encoder_vars)
            # mix_loss
            grads, _ = tf.clip_by_global_norm(tf.gradients(self.rnn_loss, self.text_encoder_vars), 10)
            rnn_optimizer = tf.train.AdamOptimizer(self.lr_var, beta1=self.hparas['BETA'])
            self.rnn_optim = rnn_optimizer.apply_gradients(zip(grads, self.text_encoder_vars))
            
            self.weight_clip_ops = []

#             for var in self.discrim_vars:            
#                 self.weight_clip_ops.append(var.assign(tf.clip_by_value(var, -0.01, 0.01)))
    def training(self):
        n_level = 0.01
        g_loss = d_loss = 0
        for _epoch in range(self.hparas['N_EPOCH']):
            start_time = time.time()
            n_critic = 1
            n_batch_epoch = int(self.hparas['N_SAMPLE']/self.hparas['BATCH_SIZE'])
            current_num = n_critic
            if _epoch < 60:
                n_level = n_level
            else:
                n_level = n_level - 0.001   
            n_level = np.maximum(0.0, n_level)
            for _step in range(n_batch_epoch):
                if current_num == 0:
                    current_num = n_critic
                step_time = time.time()
                image_batch, caption_batch = self.sess.run(self.next_element)
                b_z = np.random.normal(loc=0.0, scale=1.0, 
                                       size=(self.hparas['BATCH_SIZE'], self.hparas['Z_DIM'])).astype(np.float32)

                # update discriminator
                #if _epoch < 60:
                self.discriminator_error, _ = self.sess.run([self.d_loss, self.d_optim],
                                                           feed_dict={
                                                                self.real_image:image_batch,
                                                                self.caption:caption_batch,
                                                                self.z_noise:b_z,
                                                                self.training_flags:True,
                                                                self.noise_level: n_level, 
                                                                self.keep_prob:1})
#                 else:
#                     self.discriminator_error, _ = self.sess.run([self.d_loss, self.d_optim_nornn],
#                                                                feed_dict={
#                                                                     self.real_image:image_batch,
#                                                                     self.caption:caption_batch,
#                                                                     self.z_noise:b_z,
#                                                                     self.training_flags:True,
#                                                                     self.noise_level: n_level, 
#                                                                     self.keep_prob:1})
                d_loss = self.discriminator_error
#                     self.rnn_error, _ = self.sess.run([self.rnn_loss, self.rnn_optim],
#                                                                feed_dict={
#                                                                     self.real_image:image_batch,
#                                                                     self.caption:caption_batch,
#                                                                     self.z_noise:b_z,
#                                                                     self.training_flags:True,
#                                                                     self.keep_prob:1})
                    #rnn_loss = self.rnn_error
                if current_num == 1:
                    self.generator_error, _ = self.sess.run([self.g_loss, self.g_optim],
                                       feed_dict={self.caption: caption_batch, 
                                                  self.z_noise : b_z,
                                                  self.training_flags:True,
                                                  self.noise_level: n_level,
                                                  self.keep_prob:0.8})
                    g_loss = self.generator_error
                self.rnn_optim
                self.rnn_error, _ = self.sess.run([self.rnn_loss, self.rnn_optim],
                                       feed_dict={self.real_image:image_batch,
                                                  self.caption:caption_batch,
                                                  self.z_noise:b_z,
                                                  self.training_flags:True,
                                                  self.noise_level: n_level, 
                                                  self.keep_prob:1})
                #n_level -= 1./self.hparas['N_EPOCH']/int(self.hparas['N_SAMPLE']/self.hparas['BATCH_SIZE'])                 
                
                
                if _step%50==0:
                    print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4fs, d_loss: %.3f, g_loss: %.3f, rnn_loss: %.3f" \
                            % (_epoch, self.hparas['N_EPOCH'], _step, n_batch_epoch, time.time() - step_time,
                               d_loss, g_loss, self.rnn_error))
            if _epoch != 0 and (_epoch+1)%1==0:
                self._save_checkpoint(_epoch)
                self._sample_visiualize(_epoch)
            
    def inference(self):
        for _iters in range(100):
            caption, idx = self.sess.run(self.next_element)
            z_seed = np.random.normal(loc=0.0, scale=1.0, size=(self.hparas['BATCH_SIZE'], self.hparas['Z_DIM'])).astype(np.float32)

            img_gen, rnn_out = self.sess.run([self.generate_image_net.outputs, self.text_embed.outputs],
                                             feed_dict={self.caption : caption, self.z_noise : z_seed, self.training_flags:False,self.keep_prob:1})
            img_gen = (img_gen + 1)/2
            for i in range(self.hparas['BATCH_SIZE']):
                scipy.misc.imsave(self.inference_path+'/inference_{:04d}.png'.format(idx[i]), img_gen[i])
                
    def _init_vars(self):
        self.sess.run(tf.global_variables_initializer())
    
    def _get_session(self):
        self.sess = tf.Session()
    
    def _get_saver(self):
        if self.train:
            self.rnn_saver = tf.train.Saver(var_list=self.text_encoder_vars)
            self.g_saver = tf.train.Saver(var_list=self.generator_vars)
            self.d_saver = tf.train.Saver(var_list=self.discrim_vars)
        else:
            self.rnn_saver = tf.train.Saver(var_list=self.text_encoder_vars)
            self.g_saver = tf.train.Saver(var_list=self.generator_vars)
            
    def _sample_visiualize(self, epoch):
        ni = int(np.ceil(np.sqrt(self.hparas['BATCH_SIZE'])))
        sample_size = self.hparas['BATCH_SIZE']
        max_len = self.hparas['MAX_SEQ_LENGTH']
        
        sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(self.hparas['BATCH_SIZE'], self.hparas['Z_DIM'])).astype(np.float32)
        sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."]*int(sample_size/ni) + ["this flower has petals that are yellow, white and purple and has dark lines"]*int(sample_size/ni) + ["the petals on this flower are white with a yellow center"] * int(sample_size/ni) + ["this flower has a lot of small round pink petals."] * int(sample_size/ni) + ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/ni) + ["the flower has yellow petals and the center of it is brown."] * int(sample_size/ni) + ["this flower has petals that are blue and white."] * int(sample_size/ni) + ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/ni)

        for i, sent in enumerate(sample_sentence):
            sample_sentence[i] = sent2IdList(sent, max_len)
            
        img_gen, rnn_out = self.sess.run([self.generator.outputs, self.text_encoder.outputs],
                                         feed_dict={self.caption : sample_sentence, self.z_noise : self.fixed_sample_seed, self.training_flags:False,self.keep_prob:1})
        save_images((img_gen + 1)/2, [ni, ni], self.sample_path+'/train_{:02d}.png'.format(epoch))
        
    def _get_var_with_name(self):
        t_vars = tf.trainable_variables()

        self.text_encoder_vars = [var for var in t_vars if 'rnn' in var.name]
        self.generator_vars = [var for var in t_vars if 'generator' in var.name]
        self.discrim_vars = [var for var in t_vars if 'discrim' in var.name]
    
    def _load_checkpoint(self, recover):
        if self.train:
            self.rnn_saver.restore(self.sess, self.ckpt_path+'rnn_model_'+str(recover)+'.ckpt')
            self.g_saver.restore(self.sess, self.ckpt_path+'g_model_'+str(recover)+'.ckpt')
            self.d_saver.restore(self.sess, self.ckpt_path+'d_model_'+str(recover)+'.ckpt')
        else:
            self.rnn_saver.restore(self.sess, self.ckpt_path+'rnn_model_'+str(recover)+'.ckpt')
            self.g_saver.restore(self.sess, self.ckpt_path+'g_model_'+str(recover)+'.ckpt')
        print('-----success restored checkpoint--------')
    
    def _save_checkpoint(self, epoch):
        self.rnn_saver.save(self.sess, self.ckpt_path+'rnn_model_'+str(epoch)+'.ckpt')
        self.g_saver.save(self.sess, self.ckpt_path+'g_model_'+str(epoch)+'.ckpt')
        self.d_saver.save(self.sess, self.ckpt_path+'d_model_'+str(epoch)+'.ckpt')
        print('-----success saved checkpoint--------')

In [None]:
tf.reset_default_graph()
checkpoint_path = './checkpoint/'
inference_path = './inference'
gan = GAN(get_hparas(), training_phase=True, dataset_path=data_path, ckpt_path=checkpoint_path, inference_path=inference_path)
gan.training()

(64, ?)
(64, 128)
(64, 64, 64, 3)


In [12]:
def data_iterator_test(filenames, batch_size):
    data = pd.read_pickle(filenames)
    captions = data['Captions'].values
    caption = []
    for i in range(len(captions)):
        caption.append(captions[i])
    caption = np.asarray(caption)
    index = data['ID'].values
    index = np.asarray(index)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    
    iterator = dataset.make_initializable_iterator()
    output_types = dataset.output_types
    output_shapes = dataset.output_shapes
    
    return iterator, output_types, output_shapes
tf.reset_default_graph()
iterator_train, types, shapes = data_iterator_test(data_path+'/testData.pkl', 64)
iter_initializer = iterator_train.initializer
next_element = iterator_train.get_next()

with tf.Session() as sess:
    sess.run(iterator_train.initializer)
    next_element = iterator_train.get_next()
    caption, idex = sess.run(next_element)

In [13]:
tf.reset_default_graph()
gan = GAN(get_hparas(), training_phase=False, dataset_path=data_path, ckpt_path=checkpoint_path, inference_path=inference_path, recover=3)
img = gan.inference()

(64, ?)
INFO:tensorflow:Restoring parameters from ./checkpoint/rnn_model_3.ckpt
INFO:tensorflow:Restoring parameters from ./checkpoint/g_model_3.ckpt
-----success restored checkpoint--------
