# Train

In [None]:
import tensorflow as tf
import numpy as np
import model
import pickle
from os.path import join
import h5py
from Utils import image_processing
import scipy.misc
import random
import json
import os
import shutil


# Checkpoint restored before training starts; set to None or '' to start from
# freshly initialized weights. './Data/Models/temp.ckpt' (the rolling checkpoint
# the training loop writes) is the other commonly used choice.
resume_model = './Data/Models/ep_10_nice1152.ckpt'
# Number of samples per optimisation step (also fixes the model's placeholder shapes).
batch_size = 64

def initialize_uninitialized_vars(sess):
    """Run initializers only for global variables that are not yet initialized.

    Variables that already hold a value (for example ones restored from a
    checkpoint) are left untouched.

    Args:
        sess: session used both to query initialization state and to run the
            initializer.
    """
    global_vars = tf.global_variables()
    # One boolean per variable: True when it still needs initializing.
    needs_init_flags = sess.run([~tf.is_variable_initialized(v) for v in global_vars])
    pending = [v for v, needs_init in zip(global_vars, needs_init_flags) if needs_init]
    if not pending:
        return
    sess.run(tf.variables_initializer(pending))


def load_training_data(data_dir):
    """Load the caption vectors and build a shuffled list of training images.

    Reads ``<data_dir>/flower_train.h5``. Each item in that file is keyed by an
    image file name (``batch_gen`` later joins it onto ``flowers/jpg/``), and its
    data is copied into memory as a NumPy array of that image's caption vectors.

    Args:
        data_dir: directory containing ``flower_train.h5``.

    Returns:
        dict with
          'image_list'  -- image names, sorted then shuffled with ``random``;
          'captions'    -- {image name: np.ndarray of caption vectors}
                           (the original author notes shape (5, 4800) per image);
          'data_length' -- number of images.
    """
    # Open read-only and close the file once everything is copied into memory.
    # The original never closed the handle and relied on h5py's version-dependent
    # default mode (older h5py opened in append mode).
    with h5py.File(join(data_dir, 'flower_train.h5'), 'r') as h:
        flower_captions = {name: np.array(ds) for name, ds in h.items()}

    # Sort before shuffling so the order is reproducible for a given `random` seed.
    image_list = sorted(flower_captions)
    random.shuffle(image_list)

    total_img_cnt = len(image_list)
    print(total_img_cnt)
    return {
        'image_list': image_list,
        'captions': flower_captions,
        'data_length': total_img_cnt,
    }


def visualize_G(data_dir, real_images, generated_images, image_files):
    """Save a grid of generated images stacked above the matching real images.

    The batch is laid out row-major on ``w`` rows by ``h`` columns, where
    ``w = max(1, batch // 8)`` and ``h = batch // w``. The generated grid goes on
    top and the real grid below, and the result is written to
    ``<data_dir>/samples/both.jpg``.

    Args:
        data_dir: root directory; the image is written under ``samples/``.
        real_images: array of shape (batch, H, W, 3).
        generated_images: array with the same shape as ``real_images``.
        image_files: unused; kept for call-site compatibility.
    """
    batch_size = real_images.shape[0]
    # Tile size comes from the data instead of a hard-coded 64.
    tile_h, tile_w = real_images.shape[1], real_images.shape[2]

    # At least one row, so batches smaller than 8 no longer divide by zero.
    w = max(1, batch_size // 8)
    h = batch_size // w
    print('w,h', w, h)

    def _grid(images):
        # Paste images[0 : w*h] row-major onto one (w*tile_h, h*tile_w, 3) canvas.
        canvas = np.zeros((w * tile_h, h * tile_w, 3), dtype=np.float32)
        for i in range(w):
            for j in range(h):
                canvas[i*tile_h:(i+1)*tile_h, j*tile_w:(j+1)*tile_w, :] = images[i*h + j]
        return canvas

    both = np.concatenate([_grid(generated_images), _grid(real_images)], axis=0)
    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this line needs an
    # older SciPy (with Pillow). Swap in another image writer when upgrading.
    scipy.misc.imsave(join(data_dir, 'samples/both.jpg'), both)



def batch_gen(batch_no, batch_size, image_size, z_dim, 
    caption_vector_length, split, data_dir, loaded_data = None):
    """Assemble one training batch for the text-conditioned GAN.

    Each of the ``batch_size`` slots uses the image at position
    ``batch_no * batch_size + slot`` (wrapping around the image list). For each
    slot this loads the real image, a randomly chosen *different* "wrong" image,
    and one randomly chosen caption vector of the real image.

    Args:
        batch_no: index of the batch within the epoch.
        batch_size: number of samples in the batch.
        image_size: size passed to ``image_processing.load_image_array`` and the
            spatial size of the returned image buffers (assumes the loader
            returns (image_size, image_size, 3) arrays; TODO confirm).
        z_dim: length of each noise vector.
        caption_vector_length: number of leading caption-vector entries to keep.
        split: unused; kept for call-site compatibility.
        data_dir: root directory containing ``flowers/jpg/``.
        loaded_data: dict returned by ``load_training_data``; required.

    Returns:
        (real_images, wrong_images, captions, z_noise, image_files):
        - the two image arrays have shape (batch_size, image_size, image_size, 3);
        - captions has shape (batch_size, caption_vector_length);
        - z_noise has shape (batch_size, z_dim) and is drawn from N(0, 0.1**2);
        - image_files lists the real image paths.

    Raises:
        ValueError: if ``loaded_data`` is None or its image list is empty.
    """
    if loaded_data is None or not loaded_data['image_list']:
        raise ValueError('batch_gen needs loaded_data with a non-empty image_list')

    image_list = loaded_data['image_list']
    n_images = len(image_list)

    # Sized from image_size (was hard-coded 64x64).
    real_images = np.zeros((batch_size, image_size, image_size, 3))
    wrong_images = np.zeros((batch_size, image_size, image_size, 3))
    captions = np.zeros((batch_size, caption_vector_length))

    image_files = []
    start = batch_no * batch_size
    for cnt in range(batch_size):
        idx = (start + cnt) % n_images
        image_name = image_list[idx]
        image_file = join(data_dir, 'flowers/jpg/' + image_name)
        real_images[cnt] = image_processing.load_image_array(image_file, image_size)

        # Mismatched image for the "wrong image + real caption" loss term. Draw
        # uniformly from every image except the real one so the two never
        # coincide; with a single image there is no alternative.
        if n_images > 1:
            wrong_image_id = random.randrange(n_images - 1)
            if wrong_image_id >= idx:
                wrong_image_id += 1
        else:
            wrong_image_id = idx
        wrong_image_file = join(data_dir, 'flowers/jpg/' + image_list[wrong_image_id])
        wrong_images[cnt] = image_processing.load_image_array(wrong_image_file, image_size)

        # Pick any of this image's captions (previously assumed exactly five).
        image_captions = loaded_data['captions'][image_name]
        random_caption = random.randrange(len(image_captions))
        captions[cnt, :] = image_captions[random_caption][:caption_vector_length]
        image_files.append(image_file)

    z_noise = np.random.normal(0, 0.1, [batch_size, z_dim])
    return real_images, wrong_images, captions, z_noise, image_files



epochs = 200
data_dir = "Data"
image_size = 64
z_dim = 100
caption_vector_length = 4800
learning_rate = 0.000001   # shared by the discriminator and generator Adam optimizers
save_every_batches = 30    # write the sample grid + temp checkpoint this often
save_every_epochs = 5      # write an epoch checkpoint this often
model_options = {
    'z_dim' : z_dim,
    't_dim' :256,
    'batch_size' : batch_size,
    'image_size' : image_size,   # was a second hard-coded 64
    'gf_dim' : 64,
    'df_dim' : 64,
    'gfc_dim' :1024,
    'caption_vector_length' : caption_vector_length
}


gan = model.GAN(model_options)

input_tensors, variables, loss, outputs, checks, debug = gan.build_model()
d_optim = tf.train.AdamOptimizer(learning_rate, beta1 = 0.5).minimize(loss['d'], var_list=variables['d'])
g_optim = tf.train.AdamOptimizer(learning_rate, beta1 = 0.5).minimize(loss['g'], var_list=variables['g'])

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

sess = tf.InteractiveSession(config = config)
tf.global_variables_initializer().run()
saver = tf.train.Saver()

if resume_model:
    saver.restore(sess, resume_model)
    print('model restore:', resume_model)

# saver.save and the sample writer both fail when their target directory is missing.
models_dir = join(data_dir, 'Models')
os.makedirs(models_dir, exist_ok=True)
os.makedirs(join(data_dir, 'samples'), exist_ok=True)

loaded_data = load_training_data(data_dir)
batches_per_epoch =  int(loaded_data['data_length']/batch_size)
print(loaded_data['data_length'])

# Per-term discriminator losses, fetched alongside each D update (loop-invariant).
check_ts = [checks['d_loss1'], checks['d_loss2'], checks['d_loss3']]

for i in range(epochs):
    print("ep", i)
    for batch_no in range(batches_per_epoch):
        real_images, wrong_images, caption_vectors, z_noise, image_files = batch_gen(
            batch_no, batch_size, image_size, z_dim, caption_vector_length,
            'train', data_dir, loaded_data)

        # The D and G steps consume the same batch.
        feed_dict = {
            input_tensors['t_real_image'] : real_images,
            input_tensors['t_wrong_image'] : wrong_images,
            input_tensors['t_real_caption'] : caption_vectors,
            input_tensors['t_z'] : z_noise,
        }

        # DISCR UPDATE
        ## DON'T CHANGE FOR RANGE, KEEP IT 1!!!!
        for _ in range(1):
            _, d_loss, gen, d1, d2, d3 = sess.run(
                [d_optim, loss['d'], outputs['generator']] + check_ts,
                feed_dict=feed_dict)

        print("real/wrong/fake/total loss:", d1, d2, d3, d_loss)

        # GEN UPDATE
        ## DON'T CHANGE FOR RANGE, KEEP IT 1!!!!
        for _ in range(1):
            _, g_loss, gen, de = sess.run(
                [g_optim, loss['g'], outputs['generator'], debug],
                feed_dict=feed_dict)

        print("g_loss:", g_loss)

        if (batch_no % save_every_batches) == 0:
            print("Saving Images, Model")
            visualize_G(data_dir, real_images, gen, image_files)
            save_path = saver.save(sess, join(models_dir, "temp.ckpt"))

    # Also checkpoint the final epoch. With the defaults the every-5-epochs rule
    # alone would skip it (199 % 5 != 0).
    if i % save_every_epochs == 0 or i == epochs - 1:
        save_path = saver.save(sess, join(models_dir, "ep_{}.ckpt".format(i)))


# Model

In [None]:
import tensorflow as tf
from Utils import ops
import copy as cp
def lrelu(x, leak=0.2):
    """Leaky ReLU: ``x`` where positive, ``leak * x`` elsewhere.

    Computed as the element-wise maximum of ``x`` and ``leak * x``. This equals
    the piecewise definition whenever ``leak < 1``.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)

class GAN:
    """Text-conditioned DCGAN: a generator plus a caption-aware discriminator (TF1 graph mode).

    ``options`` keys read here: 'batch_size', 'image_size', 'z_dim', 't_dim',
    'gf_dim', 'df_dim' and 'caption_vector_length'. The optional key
    'bn_training' (default False) is passed as ``training=`` to every
    batch-normalization layer.
    """

    def __init__(self, options):
        self.options = options
        # NOTE(review): tf.layers.batch_normalization defaults to training=False.
        # In that mode it normalizes with its moving mean/variance, and nothing
        # in this class or the training script updates those. The layers
        # therefore presumably act as a fixed per-channel affine map rather than
        # real batch norm. The default below keeps that behavior, so existing
        # checkpoints are unaffected. Set options['bn_training'] = True to
        # normalize with batch statistics, and also run tf.GraphKeys.UPDATE_OPS
        # with the optimizers if the moving averages should be tracked.
        bn_training = options.get('bn_training', False)
        self.bn = lambda x: tf.layers.batch_normalization(x, training=bn_training)
        self.linear = tf.layers.dense

    def build_model(self):
        """Build the full training graph.

        Returns:
            input_tensors: placeholders 't_real_image', 't_wrong_image',
                't_real_caption', 't_z'.
            variables: {'d': discriminator trainable vars, 'g': generator trainable vars}.
            loss: {'g': generator loss, 'd': discriminator loss}.
            outputs: {'generator': generated images, values in (0, 1)}.
            checks: the individual discriminator loss terms and the logits.
            debug: the generator's output before the final tanh.
        """
        img_size = self.options['image_size']
        batch_size = self.options['batch_size']
        t_real_image = tf.placeholder('float32', [batch_size, img_size, img_size, 3], name='real_image')
        t_wrong_image = tf.placeholder('float32', [batch_size, img_size, img_size, 3], name='wrong_image')
        t_real_caption = tf.placeholder('float32', [batch_size, self.options['caption_vector_length']], name='real_caption_input')
        t_z = tf.placeholder('float32', [batch_size, self.options['z_dim']])

        fake_image, debug = self.generator(t_z, t_real_caption)

        # One discriminator scores three inputs, each paired with the real
        # caption. Weights are shared through the AUTO_REUSE "discriminator" scope.
        pred_real, logit_real = self.discriminator(t_real_image, t_real_caption, print_dim=True)
        pred_wrong, logit_wrong = self.discriminator(t_wrong_image, t_real_caption)
        pred_fake, logit_fake = self.discriminator(fake_image, t_real_caption)

        # Generator: make the discriminator label (generated image, caption) as real.
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logit_fake, labels=tf.ones_like(pred_fake)))

        # Discriminator targets: real image -> 1, wrong image -> 0,
        # generated image -> 0.
        d_loss1 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logit_real, labels=tf.ones_like(pred_real)))
        d_loss2 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logit_wrong, labels=tf.zeros_like(pred_wrong)))
        d_loss3 = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logit_fake, labels=tf.zeros_like(pred_fake)))
        # The two negative terms together carry the same weight as the positive one.
        self.d_loss = d_loss1 + (d_loss2 + d_loss3) * 0.5

        self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

        input_tensors = {
            't_real_image' : t_real_image,
            't_wrong_image' : t_wrong_image,
            't_real_caption' : t_real_caption,
            't_z' : t_z
        }

        variables = {
            'd' : self.d_vars,
            'g' : self.g_vars
        }

        loss = {
            'g' : self.g_loss,
            'd' : self.d_loss
        }

        outputs = {
            'generator' : fake_image
        }

        checks = {
            'd_loss1': d_loss1,
            'd_loss2': d_loss2,
            'd_loss3': d_loss3,
            'logit_real': logit_real,
            # Legacy key with a stray trailing space, kept for existing callers.
            'logit_real ': logit_real,
            # Previously held pred_wrong (the sigmoid output), not the logit.
            'logit_wrong': logit_wrong,
            'logit_fake': logit_fake
        }

        return input_tensors, variables, loss, outputs, checks, debug

    def build_generator(self):
        """Build a generator-only graph (the original note says generate_image.py uses it).

        Returns:
            input_tensors: placeholders 't_real_caption' and 't_z'.
            outputs: {'generator': generated images, values in (0, 1)}.
        """
        batch_size = self.options['batch_size']
        t_real_caption = tf.placeholder('float32', [batch_size, self.options['caption_vector_length']], name='real_caption_input')
        t_z = tf.placeholder('float32', [batch_size, self.options['z_dim']])
        fake_image, _ = self.generator(t_z, t_real_caption)

        input_tensors = {
            't_real_caption' : t_real_caption,
            't_z' : t_z
        }

        outputs = {
            'generator' : fake_image
        }

        return input_tensors, outputs

    def generator(self, t_z, t_text_embedding):
        """Map noise plus a caption vector to a batch of images.

        Returns:
            (images, pre_activation). ``images = tanh(h4) / 2 + 0.5``, so values
            lie in (0, 1). ``pre_activation`` is ``h4`` before the tanh.
        """
        with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
            s = self.options['image_size']
            s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)
            batch_size = self.options['batch_size']
            gf_dim = self.options['gf_dim']

            # Compress the caption vector to t_dim and append it to the noise.
            reduced_text_embedding = lrelu(self.linear(t_text_embedding, self.options['t_dim']))
            z_concat = tf.concat([t_z, reduced_text_embedding], 1)
            z_ = self.linear(z_concat, gf_dim*8*s16*s16)
            h0 = tf.reshape(z_, [-1, s16, s16, gf_dim*8])
            h0 = lrelu(self.bn(h0))
            # Dropout with 0.5 is always applied here, including in
            # build_generator graphs; there is no train/test switch.
            h0 = tf.nn.dropout(h0, 0.5)

            h1 = ops.deconv2d(h0, [batch_size, s8, s8, gf_dim*4], name='g_h1')
            h1 = lrelu(self.bn(h1))
            h1 = tf.nn.dropout(h1, 0.5)

            h2 = ops.deconv2d(h1, [batch_size, s4, s4, gf_dim*2], name='g_h2')
            h2 = lrelu(self.bn(h2))
            h2 = tf.nn.dropout(h2, 0.5)

            h3 = ops.deconv2d(h2, [batch_size, s2, s2, gf_dim*1], name='g_h3')
            h3 = lrelu(self.bn(h3))

            h4 = ops.deconv2d(h3, [batch_size, s, s, 3], name='g_h4')

            print("G", h0.shape, h1.shape, h2.shape, h3.shape, h4.shape)
            debug = h4
            return (tf.tanh(h4)/2. + 0.5), debug

    def discriminator(self, image, t_text_embedding, print_dim=False):
        """Score (image, caption) pairs.

        Returns:
            (sigmoid(logit), logit), where the logit comes from a 1-unit dense layer.
        """
        with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
            df_dim = self.options['df_dim']
            h0 = lrelu(ops.conv2d_legacy(image, df_dim)) #32

            h1 = lrelu(self.bn(ops.conv2d_legacy(h0, df_dim*2, name='d_h1_conv'))) #16
            h2 = lrelu(self.bn(ops.conv2d_legacy(h1, df_dim*4, name='d_h2_conv'))) #8
            h3 = lrelu(self.bn(ops.conv2d_legacy(h2, df_dim*8, name='d_h3_conv'))) #4
            if print_dim:
                print("D", h0.shape, h1.shape, h2.shape, h3.shape)

            # Add the text embedding: project the caption and tile it over h3's grid.
            reduced_text_embeddings = lrelu(self.linear(t_text_embedding, self.options['t_dim']))
            reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings, 1)
            reduced_text_embeddings = tf.expand_dims(reduced_text_embeddings, 2)
            # Tile to h3's actual spatial size instead of a hard-coded 4, so image
            # sizes other than 64 work (for 64 this is still 4x4).
            tile_h, tile_w = h3.get_shape().as_list()[1:3]
            if tile_h is None or tile_w is None:
                # Presumably four stride-2 convolutions (see the #32..#4 notes above).
                tile_h = tile_w = int(self.options['image_size'] / 16)
            tiled_embeddings = tf.tile(reduced_text_embeddings, [1, tile_h, tile_w, 1], name='tiled_embeddings')

            h3_concat = tf.concat([h3, tiled_embeddings], 3, name='h3_concat')
            h3_new = lrelu(self.bn(ops.conv2d_legacy(h3_concat, df_dim*8, 1, 1, 1, 1, name='d_h3_conv_new'))) #4

            h4 = self.linear(tf.reshape(h3_new, [self.options['batch_size'], -1]), 1)

            return tf.nn.sigmoid(h4), h4
