In [1]:
import tensorflow as tf
import numpy as np
# import argparse
import os
import time
import model
import utils


# Expose only GPU index 0 to this process via the CUDA_VISIBLE_DEVICES
# environment variable.
# NOTE(review): this runs after `import tensorflow`; it presumably still takes
# effect because no session has been created yet at this point -- confirm, or
# move the assignment above the tensorflow import to be safe.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# def arg_parser():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--image_size", default = 256, type = int)
#     parser.add_argument("--crop_size", default = 70, type = int)
#     parser.add_argument("--batch_size", default = 128, type = int)     
#     parser.add_argument("--pre_train_iter", default = 20000, type = int)
#     parser.add_argument("--iter", default = 100000, type = int)
#     parser.add_argument("--learning_rate", default = 1e-4, type = float)
#     parser.add_argument("--gpu_fraction", default = 0.5, type = float)
#     parser.add_argument("--save_dir", default = 'saved_models')
#     parser.add_argument("--train_out_dir", default = 'train_output')
#     parser.add_argument("--test_out_dir", default = 'test_output')
#     parser.add_argument("--mode", default = 'train')
    
#     args = parser.parse_args()
#     return args


class CartoonGAN():
    """Training / inference driver for a CartoonGAN-style model (TF1 graph API).

    Typical use is ``build_model()`` followed by either ``input_setup()`` +
    ``train()`` or ``test()``.  The network definitions come from the
    project-local ``model`` module and data / image helpers from ``utils``.
    """

    # Default output locations.  The names now match their use: checkpoints
    # and TensorBoard logs go to the save dir, sample grids written during
    # training go to the train-output dir, and test() results go to the
    # test-output dir.  (Previously the three values were assigned to each
    # other's attributes.)
    DEFAULT_SAVE_DIR = 'saved_models'
    DEFAULT_TRAIN_OUT_DIR = 'train_output'
    DEFAULT_TEST_OUT_DIR = 'test_output'

    def __init__(self):
        """Set the hyper-parameters and create the graph input placeholders."""
        self.image_size = 256
        self.crop_size = 70
        self.batch_size = 8
        self.pre_train_iter = 20000
        self.iter = 100000
        self.learning_rate = 1e-4
        # Fraction of GPU memory the session is allowed to claim.
        self.gpu_fraction = 0.5
        self.save_dir = self.DEFAULT_SAVE_DIR
        self.train_out_dir = self.DEFAULT_TRAIN_OUT_DIR
        self.test_out_dir = self.DEFAULT_TEST_OUT_DIR
        # Not read anywhere in this class; kept for compatibility.
        self.lambda_ = 10

        self.is_train = tf.placeholder(tf.bool)
        self.photo_input = tf.placeholder(tf.float32, [None, None, None, 3], name="photo")
        self.cartoon_input = tf.placeholder(tf.float32, [None, None, None, 3], name="cartoon")
        self.blur_input = tf.placeholder(tf.float32, [None, None, None, 3], name="blur")

    def input_setup(self):
        """Load the photo and cartoon file lists used by ``train()``."""
        self.celeba_list = utils.get_filename_list('real_world')
        self.cartoon_list = utils.get_filename_list('cartoon_original')
        print('Finished loading data')

    def build_model(self):
        """Build the graph, the losses, the three optimizers, saver and session.

        Creates ``self.fake_cartoon``, the discriminator logits for the
        cartoon / generated / blurred inputs, ``self.init_optim`` (VGG loss
        only, used for pre-training), ``self.d_optim``, ``self.g_optim``,
        ``self.saver`` (generator variables only) and ``self.sess``.
        """
        self.fake_cartoon = model.generator(self.photo_input, name='generator',
                                            reuse=tf.AUTO_REUSE, is_train=self.is_train)

        # All three discriminator passes use the same scope name with
        # AUTO_REUSE, so they are expected to share one set of variables.
        self.real_logit_cartoon = model.multi_patch_discriminator(
            self.cartoon_input, self.crop_size, name='discriminator', reuse=tf.AUTO_REUSE)
        self.fake_logit_cartoon = model.multi_patch_discriminator(
            self.fake_cartoon, self.crop_size, name='discriminator', reuse=tf.AUTO_REUSE)
        self.logit_blur = model.multi_patch_discriminator(
            self.blur_input, self.crop_size, name='discriminator', reuse=tf.AUTO_REUSE)

        VGG_loss = utils.vgg_loss(self.photo_input, self.fake_cartoon)

        # log(sigmoid(x)) == log_sigmoid(x) and log(1 - sigmoid(x)) ==
        # log_sigmoid(-x).  The fused op avoids log(0) = -inf (and the NaN
        # gradients that follow) when a logit saturates the sigmoid.
        g_loss = -tf.reduce_mean(tf.log_sigmoid(self.fake_logit_cartoon)) + 5e3 * VGG_loss

        d_loss = -tf.reduce_mean(tf.log_sigmoid(self.real_logit_cartoon)
                                 + tf.log_sigmoid(-self.fake_logit_cartoon)
                                 + tf.log_sigmoid(-self.logit_blur))

        all_vars = tf.trainable_variables()
        d_vars = [var for var in all_vars if 'discriminator' in var.name]
        g_vars = [var for var in all_vars if 'generator' in var.name]

        # Run everything registered in UPDATE_OPS before each optimizer step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.init_optim = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate, beta1=0., beta2=0.9
            ).minimize(VGG_loss, var_list=g_vars, colocate_gradients_with_ops=True)
            self.d_optim = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate, beta1=0., beta2=0.9
            ).minimize(d_loss, var_list=d_vars, colocate_gradients_with_ops=True)
            self.g_optim = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate, beta1=0., beta2=0.9
            ).minimize(g_loss, var_list=g_vars, colocate_gradients_with_ops=True)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar('g_loss', g_loss)
        self.d_A_loss_summ = tf.summary.scalar('d_loss', d_loss)
        self.VGG_loss_summ = tf.summary.scalar('VGG_loss', VGG_loss)

        # Only the trainable generator variables are checkpointed.
        self.saver = tf.train.Saver(g_vars)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_fraction)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        print('Finished building model')

    def _checkpoint_prefix(self, name):
        """Return the checkpoint path prefix ``<save_dir>/<name>``."""
        return self.save_dir + '/' + name

    def _next_training_feed(self):
        """Draw one photo / cartoon / blur batch and build the training feed.

        Returns:
            ``(photo_batch, feed_dict)``; the photo batch is returned
            separately so the caller can reuse it for a sample dump.
        """
        photo_batch = utils.next_batch(self.batch_size, self.image_size, self.celeba_list)
        cartoon_batch, blur_batch = utils.next_blur_batch(self.batch_size,
                                                          self.image_size,
                                                          self.cartoon_list)
        feed = {self.photo_input: photo_batch,
                self.cartoon_input: cartoon_batch,
                self.blur_input: blur_batch,
                self.is_train: True}
        return photo_batch, feed

    def _dump_samples(self, photo_batch, out_dir, filename):
        """Run the generator on ``photo_batch`` and write the fused image grid."""
        batch_image = self.sess.run([self.fake_cartoon],
                                    feed_dict={self.photo_input: photo_batch,
                                               self.is_train: True})
        batch_image = np.squeeze(batch_image)
        utils.print_fused_image(batch_image, out_dir, filename, 4)

    def train(self):
        """Pre-train the generator on the VGG loss, then run adversarial training.

        Requires ``build_model()`` and ``input_setup()`` to have been called.
        If the final pre-training checkpoint already exists in ``save_dir`` it
        is restored and the pre-training loop is skipped.
        """
        if not os.path.exists(self.train_out_dir):
            os.makedirs(self.train_out_dir)

        # Initializing the global variables
        init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
        train_writer = tf.summary.FileWriter(self.save_dir + "/train", self.sess.graph)
        summary_op = tf.summary.merge_all()

        try:
            # NOTE(review): no graph ops appear to be created inside this
            # scope, so the device pin looks like a no-op; kept as in the
            # original pending confirmation.
            with tf.device('/device:GPU:0'):
                sess = self.sess
                sess.run(init)
                start_time = time.time()

                print("before pre train")

                # Pre-training iterations
                pre_train_prefix = self._checkpoint_prefix('pre_train')
                last_pre_step = self.pre_train_iter - 1
                if os.path.isfile('{}-{}.meta'.format(pre_train_prefix, last_pre_step)):
                    self.saver.restore(sess, '{}-{}'.format(pre_train_prefix, last_pre_step))
                    print('Finished loading pre_trained model')
                else:
                    for step in range(self.pre_train_iter):
                        photo_batch, feed = self._next_training_feed()
                        sess.run([self.init_optim], feed_dict=feed)

                        if (step + 1) % 50 == 0:
                            print('pre_train iteration:[%d/%d], time cost:%f' \
                                    %(step+1, self.pre_train_iter, time.time()-start_time))
                            start_time = time.time()

                        if (step + 1) % 1000 == 0:
                            self._dump_samples(photo_batch, self.train_out_dir,
                                               str(step) + '_pre_train.png')

                        # Checked independently of the logging interval so the
                        # checkpoint is written for any pre_train_iter value.
                        if step + 1 == self.pre_train_iter:
                            self.saver.save(sess, pre_train_prefix, global_step=step)

                print("after pre train")
                print("before train")

                # Training iterations
                model_prefix = self._checkpoint_prefix('model')
                for step in range(self.iter):
                    photo_batch, feed = self._next_training_feed()

                    # Generator step first, then discriminator step on the
                    # same batch (summaries are evaluated with the latter).
                    sess.run([self.g_optim], feed_dict=feed)
                    _, summary = sess.run([self.d_optim, summary_op], feed_dict=feed)

                    train_writer.add_summary(summary, step)

                    if (step + 1) % 10 == 0:
                        print('train iteration:[%d/%d], time cost:%f' \
                                %(step+1, self.iter, time.time()-start_time))
                        start_time = time.time()

                    if (step + 1) % 500 == 0:
                        self._dump_samples(photo_batch, self.train_out_dir,
                                           str(step) + '.png')

                    # Periodic checkpoint, plus one at the very last step so
                    # the final weights are kept for any value of self.iter.
                    if (step + 1) % 20000 == 0 or step + 1 == self.iter:
                        self.saver.save(sess, model_prefix, global_step=step)

                print("after train")
        finally:
            # Flush pending summaries and release the event file.
            train_writer.close()

    def test(self):
        """Restore the latest checkpoint and write 100 generated image grids.

        Requires ``build_model()`` to have been called.

        Raises:
            ValueError: if ``save_dir`` contains no checkpoint.
        """
        if not os.path.exists(self.test_out_dir):
            os.makedirs(self.test_out_dir)

        self.test_list = utils.get_filename_list('actress')

        checkpoint = tf.train.latest_checkpoint(self.save_dir)
        if checkpoint is None:
            raise ValueError('No checkpoint found in %r; train the model first.'
                             % (self.save_dir,))

        # Initialise everything first: the saver only restores the generator's
        # trainable variables, so the rest must still be given a value.
        init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
        self.sess.run(init)
        self.saver.restore(self.sess, checkpoint)

        for idx in range(100):
            photo_batch = utils.next_batch(self.batch_size, self.image_size, self.test_list)
            # NOTE(review): is_train is fed as True at inference time (inside
            # _dump_samples), as in the original; since only trainable
            # variables are checkpointed this may be intentional -- confirm
            # against model.generator.
            self._dump_samples(photo_batch, self.test_out_dir, str(idx) + '.png')



def main(mode='train'):
    """Build the CartoonGAN graph and run it in the requested mode.

    Args:
        mode: ``'train'`` (default) loads the training file lists and runs
            training; ``'test'`` runs ``CartoonGAN.test()``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    # Validate before building anything so a typo fails fast instead of
    # silently doing nothing.
    if mode not in ('train', 'test'):
        raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))

    # Named `gan` rather than `model` to avoid shadowing the imported
    # `model` module inside this function.
    gan = CartoonGAN()

    print("train mode test mode")

    gan.build_model()
    if mode == 'train':
        gan.input_setup()
        gan.train()
    else:
        gan.test()
    
# Only start a run when executed as a script / notebook cell; importing this
# file from another module must not kick off training.
if __name__ == "__main__":
    print("Log before main function")

    main()

    print("Log after main function")

Log before main function
train mode test mode
Finished loading vgg19.npy
Finished loading vgg19.npy
Finished building vgg19: 0s
Finished building vgg19: 0s
Finished building model
Finished loading data
before pre train


ResourceExhaustedError: OOM when allocating tensor with shape[8,128,128,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[Node: gradients/zeros_66-0-1-TransposeNCHWToNHWC-LayoutOptimizer = Transpose[T=DT_FLOAT, Tperm=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](gradients/zeros_66, PermConstNCHWToNHWC-LayoutOptimizer)]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
