### Download data from here: http://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/

## Params

In [1]:
BATCH_SIZE = 32
LEARNING_RATE = 0.0002
EPOCHS = 1111
RESTORE_TRAINING= False

## Data preprocess

In [2]:
'''
This file consists of the helper functions for processing
'''

## Package Imports
from PIL import Image
import os
import tensorflow as tf
import numpy as np
import glob
try:
    import wget
except:
    print ("Can't import wget as you are probably on windows laptop")

def extract_files(data_dir,type = 'bags'):
    '''
    :param data_dir: Input directory
    :param type: bags or shoes
    :return: saves the cropped files to the bags to shoes directory
    '''
    input_file_dir = os.path.join(os.getcwd(),data_dir, "train")
    result_dir = os.path.join(os.getcwd(),type)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    file_names= os.listdir(input_file_dir)
    for file in file_names:
        input_image = Image.open(os.path.join(input_file_dir,file))
        input_image = input_image.resize([128, 64])
        input_image = input_image.crop([64, 0, 128, 64])  # Cropping only the colored image. Excluding the edge image
        input_image.save(os.path.join(result_dir,file))


def generate_dataset():
    '''
    Before executing this function. Follow these steps;
    1. Download the datasets
    Handbags data Link 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2handbags.tar.gz'
    Shoes data Link 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2shoes.tar.gz'

    2. Extract the tar files.

    3. Execute this function. This function will extract the handbags and shoe images from the datasets.
    '''
    if not os.path.exists(os.path.join(os.getcwd(), "edges2handbags")):
        try:
            print ("Downloading dataset")
            bag_data_link = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2handbags.tar.gz'
            shoe_data_link = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/edges2shoes.tar.gz'
    
            wget.download(bag_data_link)
            wget.download(shoe_data_link)
    
            with tarfile.open('./edges2handbags.tar.gz') as tar:
                tar.extractall()
                tar.close()
    
            with tarfile.open('./edges2shoes.tar.gz') as tar:
                tar.extractall()
                tar.close()
        except:
            print ("It seems you are on windows laptop. Please download the data as instructed in README before executing the code")

    extract_files("edges2handbags", 'bags')
    extract_files("edges2shoes", 'shoes')


def load_data(load_type = 'train'):
    shoelist = glob.glob(os.path.join(os.getcwd(), "shoes/*jpg"))
    shoe_data = np.array([np.array(Image.open(fname)) for fname in shoelist]).astype(np.float32)
    baglist = glob.glob(os.path.join(os.getcwd(), "bags/*jpg"))
    bags_data = np.array([np.array(Image.open(fname)) for fname in baglist]).astype(np.float32)
    shoe_data = shoe_data/255.
    bags_data = bags_data/255.
    return shoe_data, bags_data


def save_image(global_step, img_data, file_name):
    sample_results_dir = os.path.join(os.getcwd(), "sample_results", "epoch_" +str(global_step))
    if not os.path.exists(sample_results_dir):
        os.makedirs(sample_results_dir)


    result = Image.fromarray((img_data[0] * 255).astype(np.uint8))
    result.save(os.path.join(sample_results_dir, file_name + ".jpg"))



def discriminator(x,initializer, scope_name ='discriminator',  reuse=False):
    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()
        conv1 = tf.contrib.layers.conv2d(inputs=x, num_outputs=32, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, weights_initializer=initializer,
                                         scope="disc_conv1")  # 32 x 32 x 32
        conv2 = tf.contrib.layers.conv2d(inputs=conv1, num_outputs=64, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                         weights_initializer=initializer, scope="disc_conv2")  # 16 x 16 x 64
        conv3 = tf.contrib.layers.conv2d(inputs=conv2, num_outputs=128, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                         weights_initializer=initializer, scope="disc_conv3")  # 8 x 8 x 128
        conv4 = tf.contrib.layers.conv2d(inputs=conv3, num_outputs=256, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                         weights_initializer=initializer, scope="disc_conv4")  # 4 x 4 x 256
        conv5 = tf.contrib.layers.conv2d(inputs=conv4, num_outputs=512, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                         weights_initializer=initializer, scope="disc_conv5")  # 2 x 2 x 512
        fc1 = tf.reshape(conv5, shape=[tf.shape(x)[0], 2 * 2 * 512])
        fc1 = tf.contrib.layers.fully_connected(inputs=fc1, num_outputs=512, reuse=reuse, activation_fn=tf.nn.leaky_relu,
                                                normalizer_fn=tf.contrib.layers.batch_norm,
                                                weights_initializer=initializer, scope="disc_fc1")
        fc2 = tf.contrib.layers.fully_connected(inputs=fc1, num_outputs=1, reuse=reuse, activation_fn=tf.nn.sigmoid,
                                                weights_initializer=initializer, scope="disc_fc2")

        return fc2


def generator(x, initializer, scope_name = 'generator',reuse=False):
    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()
        conv1 = tf.contrib.layers.conv2d(inputs=x, num_outputs=32, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, weights_initializer=initializer,
                                         scope="disc_conv1")  # 32 x 32 x 32
        conv2 = tf.contrib.layers.conv2d(inputs=conv1, num_outputs=64, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                         weights_initializer=initializer, scope="disc_conv2")  # 16 x 16 x 64
        conv3 = tf.contrib.layers.conv2d(inputs=conv2, num_outputs=128, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                         weights_initializer=initializer, scope="disc_conv3")  # 8 x 8 x 128
        conv4 = tf.contrib.layers.conv2d(inputs=conv3, num_outputs=256, kernel_size=4, stride=2, padding="SAME",
                                         reuse=reuse, activation_fn=tf.nn.leaky_relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                         weights_initializer=initializer, scope="disc_conv4")  # 4 x 4 x 256

        deconv1 = tf.contrib.layers.conv2d(conv4, num_outputs=4 * 128, kernel_size=4, stride=1, padding="SAME",
                                               activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                               weights_initializer=initializer, scope="gen_conv1")
        deconv1 = tf.reshape(deconv1, shape=[tf.shape(x)[0], 8, 8, 128])

        deconv2 = tf.contrib.layers.conv2d(deconv1, num_outputs=4 * 64, kernel_size=4, stride=1, padding="SAME",
                                               activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                               weights_initializer=initializer, scope="gen_conv2")
        deconv2 = tf.reshape(deconv2, shape=[tf.shape(x)[0], 16, 16, 64])

        deconv3 = tf.contrib.layers.conv2d(deconv2, num_outputs=4 * 32, kernel_size=4, stride=1, padding="SAME",
                                               activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                               weights_initializer=initializer, scope="gen_conv3")
        deconv3 = tf.reshape(deconv3, shape=[tf.shape(x)[0], 32, 32, 32])

        deconv4 = tf.contrib.layers.conv2d(deconv3, num_outputs=4 * 16, kernel_size=4, stride=1, padding="SAME",
                                               activation_fn=tf.nn.relu, normalizer_fn=tf.contrib.layers.batch_norm,
                                               weights_initializer=initializer, scope="gen_conv4")
        deconv4 = tf.reshape(deconv4, shape=[tf.shape(x)[0], 64, 64, 16])

        recon = tf.contrib.layers.conv2d(deconv4, num_outputs=3, kernel_size=4, stride=1, padding="SAME", \
                                             activation_fn=tf.nn.relu, scope="gen_conv5")

        return recon
    

## Build a DiscoGAN model

In [3]:
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


class DiscoGAN:
    def __init__(self):
        with tf.variable_scope('Input'):
            self.X_bags = tf.placeholder(shape = [None, 64, 64, 3], name='bags', dtype=tf.float32)
            self.X_shoes = tf.placeholder(shape= [None, 64, 64, 3], name='shoes',dtype= tf.float32)
        self.initializer = tf.truncated_normal_initializer(stddev=0.02)
        self.define_network()
        self.define_loss()
        self.get_trainable_params()
        self.define_optimizer()
        self.summary_()

    def define_network(self):
        
        # Generators
        # This one is used to generate fake data
        self.gen_b_fake = generator(self.X_shoes, self.initializer,scope_name="generator_sb")
        self.gen_s_fake =   generator(self.X_bags, self.initializer,scope_name="generator_bs")

        # Reconstruction Generators
        # Note that parameters are being used from previous layers
        self.gen_recon_s = generator(self.gen_b_fake, self.initializer,scope_name="generator_sb",  reuse=True)
        self.gen_recon_b = generator(self.gen_s_fake,  self.initializer, scope_name="generator_bs", reuse=True)

        # Discriminator for Shoes
        self.disc_s_real = discriminator(self.X_shoes,self.initializer, scope_name="discriminator_s")
        self.disc_s_fake = discriminator(self.gen_s_fake,self.initializer, scope_name="discriminator_s", reuse=True)

        # Discriminator for Bags
        self.disc_b_real = discriminator(self.X_bags,self.initializer,scope_name="discriminator_b")
        self.disc_b_fake = discriminator(self.gen_b_fake, self.initializer, reuse=True,scope_name="discriminator_b")

        # Defining Discriminators of Bags and Shoes

    def define_loss(self):
        # Reconstruction loss for generators
        self.const_loss_s = tf.reduce_mean(tf.losses.mean_squared_error(self.gen_recon_s, self.X_shoes))
        self.const_loss_b = tf.reduce_mean(tf.losses.mean_squared_error(self.gen_recon_b, self.X_bags))

        # Generator loss for GANs
        self.gen_s_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_s_fake, labels=tf.ones_like(self.disc_s_fake)))
        self.gen_b_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_b_fake, labels=tf.ones_like(self.disc_b_fake)))

        # Total Generator Loss
        self.gen_loss =  (self.const_loss_b + self.const_loss_s)  + self.gen_s_loss + self.gen_b_loss

        # Cross Entropy loss for discriminators for shoes and bags
        # Shoes
        self.disc_s_real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_s_real, labels=tf.ones_like(self.disc_s_real)))
        self.disc_s_fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_s_fake, labels=tf.zeros_like(self.disc_s_fake)))
        self.disc_s_loss = self.disc_s_real_loss + self.disc_s_fake_loss  # Combined


        # Bags
        self.disc_b_real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_b_real, labels=tf.ones_like(self.disc_b_real)))
        self.disc_b_fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=self.disc_b_fake, labels=tf.zeros_like(self.disc_b_fake)))
        self.disc_b_loss = self.disc_b_real_loss + self.disc_b_fake_loss

        # Total Discriminator Loss
        self.disc_loss = self.disc_b_loss + self.disc_s_loss

    def get_trainable_params(self):
        '''
        This function is useful for obtaining trainable parameters which need to be trained either with discriminator or generator loss
        :return:
        '''
        self.disc_params = []
        self.gen_params = []
        for var in tf.trainable_variables():
            if 'generator' in var.name:
                self.gen_params.append(var)
            elif 'discriminator' in var.name:
                self.disc_params.append(var)

    def define_optimizer(self):
        self.disc_optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.disc_loss, var_list=self.disc_params)
        self.gen_optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(self.gen_loss, var_list=self.gen_params)

    def summary_(self):
        # Store the losses
        tf.summary.scalar("gen_loss", self.gen_loss)
        tf.summary.scalar("gen_s_loss", self.gen_s_loss)
        tf.summary.scalar("gen_b_loss", self.gen_b_loss)
        tf.summary.scalar("const_loss_s", self.const_loss_s)
        tf.summary.scalar("const_loss_b", self.const_loss_b)
        tf.summary.scalar("disc_loss", self.disc_loss)
        tf.summary.scalar("disc_b_loss", self.disc_b_loss)
        tf.summary.scalar("disc_s_loss", self.disc_s_loss)

        # Histograms for all vars
        for var in tf.trainable_variables():
            tf.summary.histogram(var.name, var)

        self.summary_ = tf.summary.merge_all()


## Train

In [4]:
import tensorflow as tf
import os
import random


def train(model):
    # Load the data first
    # Define a function to load the next batch
    # start training

    # Define a function to get the data for the next batch
    def get_next_batch(BATCH_SIZE, type ="shoes"):
        if type == "shoes":
            next_batch_indices = random.sample(range(0, X_shoes.shape[0]), BATCH_SIZE)
            batch_data = X_shoes[next_batch_indices,:,:,:]
        elif type == "bags":
            next_batch_indices = random.sample(range(0, X_bags.shape[0]), BATCH_SIZE)
            batch_data = X_bags[next_batch_indices, :, :, :]
        return batch_data

    # Loading the dataset
    print ("Loading Dataset")
    X_shoes, X_bags = load_data(load_type='train')

    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        if RESTORE_TRAINING:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state("./model")
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Model Loaded')
            start_epoch = int(str(ckpt.model_checkpoint_path).split('-')[-1].split(".")[0])
            print ("Start EPOCH", start_epoch)
        else:
            saver = tf.train.Saver(tf.global_variables())
            tf.global_variables_initializer().run()
            if not os.path.exists("logs"):
                os.makedirs("logs")
            start_epoch = 0

        # Starting training from here
        train_writer = tf.summary.FileWriter(os.getcwd() + '/logs', graph=sess.graph)
        print ("Starting Training")
        for global_step in range(start_epoch,EPOCHS):
            shoe_batch = get_next_batch(BATCH_SIZE,"shoes")
            bag_batch = get_next_batch(BATCH_SIZE,"bags")
            feed_dict_batch = {model.X_bags: bag_batch, model.X_shoes: shoe_batch}
            op_list = [model.disc_optimizer, model.gen_optimizer, model.disc_loss, model.gen_loss, model.summary_]
            _, _, disc_loss, gen_loss, summary_ = sess.run(op_list, feed_dict=feed_dict_batch)
            shoe_batch = get_next_batch(BATCH_SIZE, "shoes")
            bag_batch = get_next_batch(BATCH_SIZE, "bags")
            feed_dict_batch = {model.X_bags: bag_batch, model.X_shoes: shoe_batch}
            _, gen_loss = sess.run([model.gen_optimizer, model.gen_loss], feed_dict=feed_dict_batch)
            if global_step%10 ==0:
                train_writer.add_summary(summary_,global_step)

            if global_step%100 == 0:
                print("EPOCH:" + str(global_step) + "\tGenerator Loss: " + str(gen_loss) + "\tDiscriminator Loss: " + str(disc_loss))


            if global_step % 1000 == 0:

                shoe_sample = get_next_batch(1, "shoes")
                bag_sample = get_next_batch(1, "bags")

                ops = [model.gen_s_fake, model.gen_b_fake, model.gen_recon_s, model.gen_recon_b]
                gen_s_fake, gen_b_fake, gen_recon_s, gen_recon_b = sess.run(ops, feed_dict={model.X_shoes: shoe_sample, model.X_bags: bag_sample})

                save_image(global_step, gen_s_fake, str("gen_s_fake_") + str(global_step))
                save_image(global_step,gen_b_fake, str("gen_b_fake_") + str(global_step))
                save_image(global_step, gen_recon_s, str("gen_recon_s_") + str(global_step))
                save_image(global_step, gen_recon_b, str("gen_recon_b_") + str(global_step))

            if global_step % 1000 == 0:
                if not os.path.exists("./model"):
                    os.makedirs("./model")
                saver.save(sess, "./model" + '/model-' + str(global_step) + '.ckpt')
                print("Saved Model")

def main():
    # Get the dataset first.

    if not os.path.exists(os.path.join(os.getcwd(), "bags")):
        print("Generating Dataset")
        generate_dataset()
    # Create the model
    print ("Defining the model")
    model = DiscoGAN()
    print ("Training")
    train(model)


if __name__ == "__main__":
    main()


Defining the model

For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.
INFO:tensorflow:Summary name generator_sb/disc_conv1/weights:0 is illegal; using generator_sb/disc_conv1/weights_0 instead.
INFO:tensorflow:Summary name generator_sb/disc_conv1/biases:0 is illegal; using generator_sb/disc_conv1/biases_0 instead.
INFO:tensorflow:Summary name generator_sb/disc_conv2/weights:0 is illegal; using generator_sb/disc_conv2/weights_0 instead.
INFO:tensorflow:Summary name generator_sb/disc_conv2/BatchNorm/beta:0 is illegal; using generator_sb/disc_conv2/BatchNorm/beta_0 instead.
INFO:tensorflow:Summary name generator_sb/disc_conv3/weights:0 is illegal; using generator_sb/disc_conv3/weights_0 in