In [None]:
from songe.tf_layer import tf_toolkit as tl
from songe.tf_layer import activation as ac
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import gridspec as gridspec
import random
from skimage import io
import tensorflow as tf 
from PIL import Image
import glob
import os

In [None]:
def show_images(images):
    """Draw a batch of square images on a near-square grid in a new 7x7 inch figure.

    Args:
        images: array-like whose first axis indexes the images. A 4-D batch
            whose last axis has length 3 is drawn in colour. Any other batch,
            including 3-D ``[N, H, W]`` input, is flattened per image and
            drawn as one grayscale picture.

    Returns:
        None. The figure stays matplotlib's current figure, so the caller can
        follow up with ``plt.savefig`` / ``plt.show``.

    Raises:
        ValueError: from ``reshape`` when an image is not square, because the
            side length is recovered as the rounded-up square root of the
            per-channel pixel count.
    """
    images = np.asarray(images)
    # Checking ndim first keeps 3-D grayscale batches from indexing a missing axis.
    is_rgb = images.ndim == 4 and images.shape[3] == 3
    channels = 3 if is_rgb else 1

    flat = np.reshape(images, [images.shape[0], -1])
    grid_side = int(np.ceil(np.sqrt(flat.shape[0])))
    img_side = int(np.ceil(np.sqrt(flat.shape[1] / channels)))

    if is_rgb:
        target_shape = [img_side, img_side, channels]
        imshow_kwargs = {}
    else:
        target_shape = [img_side, img_side]
        imshow_kwargs = {"cmap": plt.cm.gray}

    plt.figure(figsize=(7, 7))
    grid = gridspec.GridSpec(grid_side, grid_side)
    grid.update(wspace=0.05, hspace=0.05)

    for cell, pixels in enumerate(flat):
        ax = plt.subplot(grid[cell])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(pixels.reshape(target_shape), **imshow_kwargs)
    return

        
    

In [None]:
def _save_right_halves(src_dir, dst_dir):
    """Resize every ``*.jpg`` in ``src_dir`` to 128x64 and save its right 64x64 half in ``dst_dir``."""
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(dst_dir, exist_ok=True)
    # Sorted so the index -> source-file mapping is the same on every run.
    for idx, src_file in enumerate(sorted(glob.glob(src_dir + '*.jpg'))):
        # Context manager releases the file handle once the crop is written.
        with Image.open(src_file) as image:
            right_half = image.resize((128, 64)).crop((64, 0, 128, 64))
            right_half.save(dst_dir + str(idx) + '.jpg')


def preprocessing(bag_dir="./datasets/discogan/edges2handbags/train/",
                  shoe_dir="./datasets/discogan/edges2shoes/train/",
                  bag_path='./datasets/discogan/edges2handbags/bags_train/',
                  shoe_path="./datasets/discogan/edges2shoes/shoes_train/"):
    """Write 64x64 crops of the handbag and shoe training images to disk.

    Every ``*.jpg`` under a source directory is resized to 128x64 and its
    right half is saved as ``<index>.jpg`` in the matching output directory.
    NOTE(review): presumably the right half of each source image is the photo
    and the left half the edge map — confirm against the dataset layout.

    Args:
        bag_dir: directory holding the raw handbag images (trailing slash expected).
        shoe_dir: directory holding the raw shoe images (trailing slash expected).
        bag_path: output directory for handbag crops; created if missing.
        shoe_path: output directory for shoe crops; created if missing.

    Returns:
        None.
    """
    _save_right_halves(bag_dir, bag_path)
    _save_right_halves(shoe_dir, shoe_path)

    
    

In [None]:
# Build the cropped 64x64 training images on disk; every run of this cell
# reprocesses and rewrites all of them.
preprocessing()

In [None]:
class discogan:
    """DiscoGAN-style model: two generators and two discriminators over two image domains.

    ``train`` feeds bag images as domain A and shoe images as domain B. The
    graph is built with the TF1 API (placeholders, variable scopes, sessions)
    and the project helpers ``tl`` (layers) and ``ac`` (activations).
    """

    def __init__(self, batch_size, training_epoch):
        """Store hyper-parameters and dataset locations; no graph is built here.

        Args:
            batch_size: number of images per domain in each training step.
            training_epoch: number of passes over the data in ``train``.
        """
        self.batch_size = batch_size
        self.training_epoch = training_epoch
        self.c_dim = 3
        self.image_shape = [64, 64, self.c_dim]
        self.lambda_ = 10  # weight of the reconstruction term in the generator loss
        self.learning_rate = 0.0002
        self.bag_path = './datasets/discogan/edges2handbags/bags_train/'
        self.shoes_path = './datasets/discogan/edges2shoes/shoes_train/'

    def discriminator(self, x, name, reuse=False, is_train=True):
        """Return the discriminator output (used as logits) for image batch ``x``.

        Args:
            x: image tensor to score.
            name: suffix appended to ``"discriminator"`` to form the variable scope.
            reuse: passed to ``tf.variable_scope`` to share existing weights.
            is_train: currently unused by this method.

        NOTE(review): presumably ``tl.conv(x, channels, kernel, stride, name)``
        — confirm against ``songe.tf_layer``.
        """
        with tf.variable_scope("discriminator" + name, reuse=reuse):
            net = ac.lrelu(tl.conv(x, 64, 4, 2, "d_1"))
            net = ac.lrelu(tl.batch_norm(tl.conv(net, 64 * 2, 4, 2, "d_2")))
            net = ac.lrelu(tl.batch_norm(tl.conv(net, 64 * 4, 4, 2, "d_3")))
            net = ac.lrelu(tl.batch_norm(tl.conv(net, 64 * 8, 4, 2, "d_4")))
            net = tl.conv(net, 1, 4, 2, "d_5", padding="VALID")
        return net

    def generator(self, x, name, reuse=False, is_train=True):
        """Map image batch ``x`` through an encoder/decoder ending in ``tanh``.

        Args:
            x: image tensor to translate.
            name: suffix appended to ``"generator"`` to form the variable scope.
            reuse: passed to ``tf.variable_scope`` to share existing weights.
            is_train: currently unused by this method.

        NOTE(review): presumably ``tl.conv`` / ``tl.conv_tran`` take
        ``(x, channels, kernel, stride, name)`` — confirm against
        ``songe.tf_layer``.
        """
        with tf.variable_scope("generator" + name, reuse=reuse):
            # Encoder
            net = ac.lrelu(tl.conv(x, 64, 4, 2, "g_1"))
            net = ac.lrelu(tl.batch_norm(tl.conv(net, 64 * 2, 4, 2, "g_2")))
            net = ac.lrelu(tl.batch_norm(tl.conv(net, 64 * 4, 4, 2, "g_3")))
            net = ac.lrelu(tl.batch_norm(tl.conv(net, 64 * 8, 4, 2, "g_4")))
            net = ac.lrelu(tl.batch_norm(tl.conv(net, 100, 4, 1, "g_5", padding="VALID")))

            # Decoder
            net = ac.relu(tl.batch_norm(tl.conv_tran(net, 64 * 8, 4, 2, "g_6", padding="VALID")))
            net = ac.relu(tl.batch_norm(tl.conv_tran(net, 64 * 4, 4, 2, "g_7")))
            net = ac.relu(tl.batch_norm(tl.conv_tran(net, 64 * 2, 4, 2, "g_8")))
            net = ac.relu(tl.batch_norm(tl.conv_tran(net, 64, 4, 2, "g_9")))
            net = ac.tanh(tl.conv_tran(net, 3, 4, 2, "g_10"))
        return net

    def build_model(self):
        """Create placeholders, both translation cycles, the losses and the two train ops.

        Sets ``input_A``, ``input_B``, ``sample_A2B``, ``sample_B2A``,
        ``loss_D``, ``loss_G``, ``d_trainer`` and ``g_trainer`` on the instance.
        """
        self.input_A = tf.placeholder(tf.float32, shape=[None] + self.image_shape)
        self.input_B = tf.placeholder(tf.float32, shape=[None] + self.image_shape)

        # Translations and their round trips; every call after the first
        # reuses the weights of its scope.
        A2B = self.generator(self.input_A, "_inputAtoB")
        B2A = self.generator(self.input_B, "_inputBtoA")
        self.sample_A2B = self.generator(self.input_A, "_inputAtoB", reuse=True)
        self.sample_B2A = self.generator(self.input_B, "_inputBtoA", reuse=True)
        A2B2A = self.generator(A2B, "_inputBtoA", reuse=True)
        B2A2B = self.generator(B2A, "_inputAtoB", reuse=True)

        # Discriminator logits on real images.
        real_A_logit = self.discriminator(self.input_A, "_A_discriminator")
        real_B_logit = self.discriminator(self.input_B, "_B_discriminator")

        # Discriminator logits on generated images. Each discriminator must
        # judge fakes of ITS OWN domain: B2A is a fake A, A2B is a fake B.
        fake_A_logit = self.discriminator(B2A, "_A_discriminator", reuse=True)
        fake_B_logit = self.discriminator(A2B, "_B_discriminator", reuse=True)

        def bce(logits, labels):
            # Mean sigmoid cross-entropy of logits against the given labels.
            return tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))

        # Discriminator: real images -> 1, generated images -> 0.
        real_A_loss = bce(real_A_logit, tf.ones_like(real_A_logit))
        real_B_loss = bce(real_B_logit, tf.ones_like(real_B_logit))
        fake_A_loss = bce(fake_A_logit, tf.zeros_like(fake_A_logit))
        fake_B_loss = bce(fake_B_logit, tf.zeros_like(fake_B_logit))

        # Generator: trained so the discriminator judges generated images as real.
        g_A_B_loss = bce(fake_B_logit, tf.ones_like(fake_B_logit))
        g_B_A_loss = bce(fake_A_logit, tf.ones_like(fake_A_logit))

        # Total discriminator loss
        loss_D_A = (real_A_loss + fake_A_loss) * 0.5
        loss_D_B = (real_B_loss + fake_B_loss) * 0.5
        self.loss_D = loss_D_A + loss_D_B

        # Reconstruction loss: a round trip should give back the input.
        A_reconstruction_loss = tf.reduce_sum(tf.losses.mean_squared_error(self.input_A, A2B2A))
        B_reconstruction_loss = tf.reduce_sum(tf.losses.mean_squared_error(self.input_B, B2A2B))

        # Total generator loss
        self.loss_G = (g_A_B_loss + g_B_A_loss) * 0.5 + self.lambda_ * (A_reconstruction_loss + B_reconstruction_loss)

        # Split trainable variables by their enclosing variable scope. Prefix
        # matching also catches variables whose own name lacks a g_/d_ tag and
        # cannot be fooled by an unrelated "d_"/"g_" substring.
        t_vars = tf.trainable_variables()
        g_vars = [var for var in t_vars if var.name.startswith("generator")]
        d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

        # Define optimizers
        self.d_trainer = tf.train.RMSPropOptimizer(self.learning_rate).minimize(self.loss_D, var_list=d_vars)
        self.g_trainer = tf.train.RMSPropOptimizer(self.learning_rate).minimize(self.loss_G, var_list=g_vars)

    def train(self):
        """Build the graph, then train for ``training_epoch`` epochs.

        After each epoch a grid of up to 16 A->B and 16 B->A translations of
        the last batch is saved to ``./result/discogan/<epoch>.png`` and shown.

        Raises:
            ValueError: if either image directory holds fewer than
                ``batch_size`` jpg files.
        """
        self.build_model()

        result_dir = './result/discogan/'
        os.makedirs(result_dir, exist_ok=True)  # savefig does not create directories

        bag_files = glob.glob(self.bag_path + '*.jpg')
        shoe_files = glob.glob(self.shoes_path + '*.jpg')
        # The smaller domain bounds the epoch so neither list runs dry mid-epoch.
        total_batch = min(len(bag_files), len(shoe_files)) // self.batch_size
        if total_batch == 0:
            raise ValueError(
                "need at least batch_size={} images per domain, found {} bags and {} shoes".format(
                    self.batch_size, len(bag_files), len(shoe_files)))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(self.training_epoch):
                bag_file_list = list(bag_files)
                shoe_file_list = list(shoe_files)

                for iteration in range(total_batch):
                    random.shuffle(bag_file_list)
                    random.shuffle(shoe_file_list)

                    bag_image, bag_file_list = batch(bag_file_list, self.batch_size)
                    shoe_image, shoe_file_list = batch(shoe_file_list, self.batch_size)

                    # NOTE(review): inputs are scaled to [0, 1] while the
                    # generator ends in tanh ([-1, 1]) — confirm this is intended.
                    bag_image = bag_image / 255.0
                    shoe_image = shoe_image / 255.0
                    feed = {self.input_A: bag_image, self.input_B: shoe_image}

                    _, g_loss_val = sess.run([self.g_trainer, self.loss_G], feed_dict=feed)
                    _, d_loss_val = sess.run([self.d_trainer, self.loss_D], feed_dict=feed)

                    print("Epcch : {} , D_loss : {} , G_loss : {}".format(epoch, d_loss_val, g_loss_val))

                # Only the last batch of the epoch is visualised, so the
                # sample generators run once per epoch instead of every step.
                sampleAtoB, sampleBtoA = sess.run([self.sample_A2B, self.sample_B2A], feed_dict=feed)
                sample_image = np.concatenate([sampleAtoB[:16], sampleBtoA[:16]], axis=0)
                show_images(sample_image)
                plt.savefig(result_dir + '{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
                plt.show()
                plt.close()  # release the figure so memory does not grow per epoch
                    

In [None]:
def batch(file_list, batch_size):
    """Load the first ``batch_size`` files as one image array and return the unused rest.

    Args:
        file_list: iterable of image file paths; it is copied, not modified.
        batch_size: number of leading paths to load.

    Returns:
        Tuple ``(image, remaining)``: ``image`` is the result of
        ``io.ImageCollection(...).concatenate()`` over the selected paths, and
        ``remaining`` is a new list with every path that was not loaded, in
        its original order.
    """
    file_list = list(file_list)
    batch_files = file_list[:batch_size]
    # Slice instead of popping by a growing index: popping from a shrinking
    # list removes every other element rather than the batch just taken.
    remaining = file_list[batch_size:]

    image = io.ImageCollection(batch_files).concatenate()

    if len(remaining) < batch_size:
        # Not enough files left to fill another batch.
        print("next epoch")

    return image, remaining
    

In [None]:
# Clear the default TF graph so re-running this cell does not collide with
# variables left over from a previous build.
tf.reset_default_graph()
# batch_size=256, training_epoch=200. The constructor only stores settings;
# the graph is built and training starts when a.train() is called.
a = discogan(256,200)