In [1]:
import scipy.io
import scipy.misc
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

In [2]:
def conv2d(input, kernel_size, stride, num_filter, name = 'conv2d'):
    """Strided 2-D convolution ('SAME' padding) followed by a per-channel bias.

    Args:
        input: 4-D tensor; its 4th dimension is read as the input channel count.
        kernel_size: side length of the square kernel.
        stride: stride applied along both spatial dimensions.
        num_filter: number of output channels.
        name: variable scope under which the variables 'w' and 'b' are created.

    Returns:
        The convolution result with the bias added.
    """
    # Static shape lookup only; it creates no graph ops, so it can live outside the scope.
    in_channels = input.get_shape()[3]

    with tf.variable_scope(name):
        # Kernel drawn from N(0, 0.02), bias starts at zero.
        weights = tf.get_variable(
            'w',
            [kernel_size, kernel_size, in_channels, num_filter],
            tf.float32,
            tf.random_normal_initializer(0.0, 0.02))
        bias = tf.get_variable('b', [1, 1, 1, num_filter], initializer = tf.constant_initializer(0.0))

        convolved = tf.nn.conv2d(input, weights, [1, stride, stride, 1], padding = 'SAME')
        return convolved + bias

def conv2d_transpose(input, kernel_size, stride, num_filter, name = 'conv2d_transpose'):
    """Transposed 2-D convolution ('SAME' padding) followed by a per-channel bias.

    The requested output height and width are the input's dynamic height and
    width multiplied by `stride`, which is the shape a 'SAME' transposed
    convolution with that stride produces.

    Args:
        input: 4-D tensor; its 4th dimension is read as the input channel count.
        kernel_size: side length of the square kernel.
        stride: upsampling factor applied along both spatial dimensions.
        num_filter: number of output channels.
        name: variable scope under which the variables 'w' and 'b' are created.

    Returns:
        The transposed-convolution result with the bias added.
    """
    with tf.variable_scope(name):
        stride_shape = [1, stride, stride, 1]
        # Transposed-conv kernels are laid out [h, w, out_channels, in_channels].
        filter_shape = [kernel_size, kernel_size, num_filter, input.get_shape()[3]]

        # Use the dynamic shape so a batch dimension of None still works, and scale
        # by `stride` (not a fixed 2) so the output shape always agrees with the strides.
        in_shape = tf.shape(input)
        output_shape = tf.stack([in_shape[0], in_shape[1] * stride, in_shape[2] * stride, num_filter])

        W = tf.get_variable('w', filter_shape, tf.float32, tf.random_normal_initializer(0.0, 0.02))
        b = tf.get_variable('b', [1, 1, 1, num_filter], initializer = tf.constant_initializer(0.0))
        return tf.nn.conv2d_transpose(input, W, output_shape, stride_shape, padding = 'SAME') + b

def fc(input, num_output, name = 'fc'):
    """Fully connected (affine) layer: input @ w + b.

    Args:
        input: 2-D tensor; its 2nd dimension is read as the number of input features.
        num_output: number of output units.
        name: variable scope under which the variables 'w' and 'b' are created.

    Returns:
        The matrix product of `input` and the weights, plus the bias.
    """
    with tf.variable_scope(name):
        in_features = input.get_shape()[1]

        # Weights drawn from N(0, 0.02), bias starts at zero.
        weights = tf.get_variable(
            'w',
            [in_features, num_output],
            tf.float32,
            tf.random_normal_initializer(0.0, 0.02))
        bias = tf.get_variable('b', [num_output], initializer = tf.constant_initializer(0.0))

        projected = tf.matmul(input, weights)
        return projected + bias

def batch_norm(input, is_training):
    """Apply tf.contrib batch normalization with this notebook's fixed settings.

    Args:
        input: tensor to normalize.
        is_training: forwarded as `is_training` to the contrib layer.

    Returns:
        The output of `tf.contrib.layers.batch_norm`.
    """
    settings = {
        'decay': 0.99,
        'center': True,
        'scale': True,
        'is_training': is_training,
        # None is passed through so the layer does not defer its updates to a collection.
        'updates_collections': None,
    }
    return tf.contrib.layers.batch_norm(input, **settings)

def leaky_relu(input, alpha = 0.2):
    """Leaky ReLU computed as the element-wise maximum of `alpha * input` and `input`.

    Args:
        input: tensor to activate.
        alpha: multiplier applied to `input` for the first operand of the maximum.

    Returns:
        `tf.maximum(alpha * input, input)`.
    """
    scaled = alpha * input
    return tf.maximum(scaled, input)

In [5]:
class cycleGAN(object):
    """CycleGAN between two 32x32x3 image domains, built as a TF1 static graph.

    Two generators (1->2 and 2->1) and two discriminators (one per domain) are
    created at construction time together with their losses and train ops.
    Adversarial terms use a least-squares objective; cycle consistency uses L1.
    """

    def __init__(self):
        # training hyper-parameters
        self.num_epoch = 10
        self.batch_size = 20
        self.log_step = 1
        self.visualize_step = 200  # not used by the code in this class
        self.code_size = 64        # not used by the code in this class
        self.learning_rate = 1e-4

        # variable-scope names for the four sub-networks
        self.dis_name_1 = 'dis1'
        self.dis_name_2 = 'dis2'
        self.gen_name_1_to_2 = 'gen_1_to_2'
        self.gen_name_2_to_1 = 'gen_2_to_1'

        # Tracks whether each scope has been built already, so that a second call
        # with the same scope name shares the variables instead of creating new ones.
        self.reuse = {
            self.dis_name_1: False,
            self.dis_name_2: False,
            self.gen_name_1_to_2: False,
            self.gen_name_2_to_1: False
        }

        # weight of the cycle-consistency term in the generator loss
        self.lamda = 0.5

        self.input1 = tf.placeholder(tf.float32, [None, 32, 32, 3])
        self.input2 = tf.placeholder(tf.float32, [None, 32, 32, 3])

        self.real_label = tf.placeholder(tf.float32, [None, 1])
        self.fake_label = tf.placeholder(tf.float32, [None, 1])

        self.is_train = tf.placeholder(tf.bool)

        self._init_ops()

    def _discriminator(self, input, scopeName):
        """Build (or reuse) the discriminator living under `scopeName`.

        Three stride-2 convolutions reduce a 32x32 input to 4x4x128, which is
        flattened and mapped to one unnormalized score per example.

        Args:
            input: image batch; the flatten step assumes 32x32 spatial size.
            scopeName: one of the discriminator scope names registered in `self.reuse`.

        Returns:
            Tensor with one score per example (second dimension 1).
        """
        with tf.variable_scope(scopeName, reuse = self.reuse[scopeName]):
            self.reuse[scopeName] = True
            dis_conv1 = conv2d(input, 4, 2, 32, 'conv1')
            dis_lrelu1 = leaky_relu(dis_conv1)
            dis_conv2 = conv2d(dis_lrelu1, 4, 2, 64, 'conv2')
            dis_batchnorm2 = batch_norm(dis_conv2, self.is_train)
            dis_lrelu2 = leaky_relu(dis_batchnorm2)
            dis_conv3 = conv2d(dis_lrelu2, 4, 2, 128, 'conv3')
            dis_batchnorm3 = batch_norm(dis_conv3, self.is_train)
            dis_lrelu3 = leaky_relu(dis_batchnorm3)
            dis_reshape3 = tf.reshape(dis_lrelu3, [-1, 4 * 4 * 128])
            dis_fc4 = fc(dis_reshape3, 1, 'fc4')
            return dis_fc4

    def _generator(self, input, scopeName):
        """Build (or reuse) the image-to-image generator living under `scopeName`.

        The network is an encoder/decoder: three stride-2 convolutions halve the
        resolution three times, then three stride-2 transposed convolutions double
        it three times. The output therefore has the same spatial size as the
        input, which is required because generator outputs are fed into the other
        generator, into the 32x32 discriminators, and compared against the inputs
        in the cycle loss.

        Args:
            input: image batch with 3 channels.
            scopeName: one of the generator scope names registered in `self.reuse`.

        Returns:
            Generated image batch with 3 channels, squashed into (0, 1) by a sigmoid.
        """
        with tf.variable_scope(scopeName, reuse = self.reuse[scopeName]):
            self.reuse[scopeName] = True

            # encoder: downsample by 2 three times
            gen_enc1 = conv2d(input, 4, 2, 32, 'enc1')
            gen_enc_lrelu1 = leaky_relu(gen_enc1)
            gen_enc2 = conv2d(gen_enc_lrelu1, 4, 2, 64, 'enc2')
            gen_enc_batchnorm2 = batch_norm(gen_enc2, self.is_train)
            gen_enc_lrelu2 = leaky_relu(gen_enc_batchnorm2)
            gen_enc3 = conv2d(gen_enc_lrelu2, 4, 2, 128, 'enc3')
            gen_enc_batchnorm3 = batch_norm(gen_enc3, self.is_train)
            gen_enc_lrelu3 = leaky_relu(gen_enc_batchnorm3)

            # decoder: upsample by 2 three times, back to the input resolution
            gen_conv2 = conv2d_transpose(gen_enc_lrelu3, 4, 2, 64, 'conv2')
            gen_batchnorm2 = batch_norm(gen_conv2, self.is_train)
            gen_lrelu2 = leaky_relu(gen_batchnorm2)
            gen_conv3 = conv2d_transpose(gen_lrelu2, 4, 2, 32, 'conv3')
            gen_batchnorm3 = batch_norm(gen_conv3, self.is_train)
            gen_lrelu3 = leaky_relu(gen_batchnorm3)
            gen_conv4 = conv2d_transpose(gen_lrelu3, 4, 2, 3, 'conv4')
            gen_sigmoid4 = tf.sigmoid(gen_conv4)
            return gen_sigmoid4

    def _adviserial_loss(self, logits, labels):
        """Least-squares adversarial loss: mean squared error between scores and targets.

        The mean reduces the per-example errors to a scalar so the fetched loss
        can be logged and plotted as a single number.
        """
        return tf.reduce_mean(tf.square(logits - labels))

    def _cycle_loss(self, logits, labels):
        """L1 cycle-consistency loss between a reconstruction and its source batch."""
        return tf.losses.absolute_difference(logits, labels)

    def _init_ops(self):
        """Wire the four networks together and create losses, optimizers and train ops."""
        # discriminator scores on real images
        self.real_1_dis = self._discriminator(self.input1, self.dis_name_1)
        self.real_2_dis = self._discriminator(self.input2, self.dis_name_2)

        # translations into the opposite domain
        self.generated_1 = self._generator(self.input2, self.gen_name_2_to_1)
        self.generated_2 = self._generator(self.input1, self.gen_name_1_to_2)

        # translate back to the source domain (1 -> 2 -> 1 and 2 -> 1 -> 2)
        self.cycle_fake_1 = self._generator(self.generated_2, self.gen_name_2_to_1)
        self.cycle_fake_2 = self._generator(self.generated_1, self.gen_name_1_to_2)

        # discriminator scores on generated images
        self.fake_1_dis = self._discriminator(self.generated_1, self.dis_name_1)
        self.fake_2_dis = self._discriminator(self.generated_2, self.dis_name_2)

        # trainable variables, grouped per player
        self.gen_1_to_2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,self.gen_name_1_to_2)
        self.gen_2_to_1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,self.gen_name_2_to_1)
        self.gen_scope = self.gen_1_to_2 + self.gen_2_to_1

        self.dis1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,self.dis_name_1)
        self.dis2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,self.dis_name_2)
        self.dis_scope = self.dis1 + self.dis2

        # Discriminators: score real images as real AND generated images as fake.
        dis_loss_1 = (self._adviserial_loss(self.real_1_dis, self.real_label)
                      + self._adviserial_loss(self.fake_1_dis, self.fake_label))
        dis_loss_2 = (self._adviserial_loss(self.real_2_dis, self.real_label)
                      + self._adviserial_loss(self.fake_2_dis, self.fake_label))
        self.dis_loss = dis_loss_1 + dis_loss_2

        # Generators: make the discriminators score generated images as real,
        # while keeping the round trip close to the source image.
        gen_adv_loss = (self._adviserial_loss(self.fake_1_dis, self.real_label)
                        + self._adviserial_loss(self.fake_2_dis, self.real_label))
        cycle_loss = (self._cycle_loss(self.cycle_fake_1, self.input1)
                      + self._cycle_loss(self.cycle_fake_2, self.input2))
        self.gen_loss = gen_adv_loss + self.lamda * cycle_loss

        # Each optimizer only updates its own player's variables.
        dis_optimizer = tf.train.RMSPropOptimizer(self.learning_rate)
        self.dis_train_op = dis_optimizer.minimize(self.dis_loss, var_list = self.dis_scope)

        gen_optimizer = tf.train.RMSPropOptimizer(self.learning_rate)
        self.gen_train_op = gen_optimizer.minimize(self.gen_loss, var_list = self.gen_scope)

    def _sample_feed_dict(self, ones, zeros):
        """Draw a fresh synthetic batch for both domains and pair it with the labels.

        Domain 1 is sampled from a standard normal and domain 2 from an
        exponential(1.0) distribution; no real image data is read here.
        """
        batch_shape = [self.batch_size, 32, 32, 3]
        input1 = np.random.standard_normal(batch_shape)
        input2 = np.random.exponential(1.0, batch_shape)
        return {self.input1: input1,
                self.input2: input2,
                self.real_label: ones,
                self.fake_label: zeros,
                self.is_train: True}

    def train(self, sess):
        """Alternate one discriminator step and one generator step per iteration.

        Initializes all global variables in `sess`, trains for `self.num_epoch`
        epochs on synthetic batches, prints both losses every `self.log_step`
        iterations and plots the smoothed loss curves at the end of every epoch.

        Args:
            sess: session whose graph contains this model.
        """
        sess.run(tf.global_variables_initializer())

        num_train = 10000
        step = 0

        # smooth the loss curve so that it does not fluctuate too much
        smooth_factor = 0.95
        plot_dis_s = 0
        plot_gen_s = 0
        plot_ws = 0

        dis_losses = []
        gen_losses = []

        print('start training')
        for epoch in range(self.num_epoch):
            for i in range(num_train // self.batch_size):

                step += 1

                zeros = np.zeros([self.batch_size, 1])
                ones = np.ones([self.batch_size, 1])

                # discriminator update
                dis_feed_dict = self._sample_feed_dict(ones, zeros)
                _, dis_loss = sess.run([self.dis_train_op, self.dis_loss], feed_dict = dis_feed_dict)

                # generator update on a freshly sampled batch
                gen_feed_dict = self._sample_feed_dict(ones, zeros)
                _, gen_loss = sess.run([self.gen_train_op, self.gen_loss], feed_dict = gen_feed_dict)

                # Exponential moving average; dividing by plot_ws corrects the
                # bias caused by starting the averages at zero.
                plot_dis_s = plot_dis_s * smooth_factor + dis_loss * (1 - smooth_factor)
                plot_gen_s = plot_gen_s * smooth_factor + gen_loss * (1 - smooth_factor)
                plot_ws = plot_ws * smooth_factor + (1 - smooth_factor)
                dis_losses.append(plot_dis_s / plot_ws)
                gen_losses.append(plot_gen_s / plot_ws)

                if step % self.log_step == 0:
                    print('Iteration {0}: dis loss = {1:.4f}, gen loss = {2:.4f}'.format(step, dis_loss, gen_loss))

            plt.plot(dis_losses)
            plt.title('discriminator loss')
            plt.xlabel('iterations')
            plt.ylabel('loss')
            plt.show()

            plt.plot(gen_losses)
            plt.title('generator loss')
            plt.xlabel('iterations')
            plt.ylabel('loss')
            plt.show()
        
        
        

In [None]:
# Clear the default graph before building the model in this cell.
tf.reset_default_graph()

# Build and train the model inside one session, with ops placed on the CPU.
with tf.Session() as sess, tf.device('/cpu:0'):
    cycle_gan = cycleGAN()
    sess.run(tf.global_variables_initializer())
    cycle_gan.train(sess)

start training
