# In [None]:
from __future__ import print_function, division
import scipy
import tensorflow as tf
from PIL import Image
from keras_contrib.layers.normalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, Add
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.layers.core import Lambda
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os


class Logo2Font():
    """Dual-input conditional GAN that renders font glyphs in the style of a logo.

    ``generator`` maps a style image (``img_A``) and a content image
    (``img_Base``) to a generated glyph (``fake_B``). ``generator_r`` maps the
    generated glyph plus a conditioning image (``img_cycle``) back towards the
    style image, which gives a cycle-reconstruction loss. ``edge`` is a single
    convolution with a fixed kernel; it is frozen inside the combined model and
    adds an edge-matching loss. ``discriminator`` is a conditional PatchGAN over
    (style, content, glyph).
    """

    def __init__(self):
        # Input shape: single-channel 256x256 images.
        self.img_rows = 256
        self.img_cols = 256
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'font'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Output shape of D (PatchGAN): the discriminator applies four
        # stride-2 convolutions, so each spatial side shrinks by 2**4.
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        optimizer = Adam(0.0002, 0.5)

        # NOTE: layers below are not given explicit names, so Keras derives
        # their names from creation order. ``train`` loads checkpoints with
        # ``by_name=True``; reordering model construction may therefore break
        # compatibility with previously saved weights.

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the edge-extraction conv (never trained directly)
        self.edge = self.get_edge()
        self.edge.compile(loss='categorical_crossentropy',
            optimizer=optimizer)

        #-------------------------
        # Construct Computational
        #   Graph of Generator
        #-------------------------

        # Build the generators
        self.generator = self.build_generator()
        self.generator_r = self.build_generator()

        # For the combined model we will only train the generators; D and the
        # edge extractor are frozen here (D was compiled above while trainable,
        # so its own train_on_batch still updates it).
        self.discriminator.trainable = False
        self.edge.trainable = False

        # Input images and their conditioning images
        img_A = Input(shape=self.img_shape)
        img_Base = Input(shape=self.img_shape)
        img_cycle = Input(shape=self.img_shape)

        # By conditioning on A (style) and Base (content) generate a fake B,
        # then try to reconstruct A from it.
        fake_B = self.generator([img_A, img_Base])
        re_A = self.generator_r([fake_B, img_cycle])

        # Get fake_B edge
        edge_B = self.edge(fake_B)

        # Discriminator judges (style condition, content condition, generated image)
        valid = self.discriminator([img_A, img_Base, fake_B])

        self.combined = Model(inputs=[img_A, img_Base, img_cycle], outputs=[valid, fake_B, edge_B, re_A])
        self.combined.compile(loss=['mse', 'mae', 'mae', 'mae'],
                              loss_weights=[10, 100, 100, 100],
                              optimizer=optimizer)

        # Load the fixed logo test set (content glyphs + three logo styles)
        imgs, adidas_, huawei_, nike_ = self.data_loader.load_logo()
        self.imgs = imgs
        self.adidas_ = adidas_
        self.huawei_ = huawei_
        self.nike_ = nike_

    def build_generator(self):
        """Build the dual-input U-Net generator.

        Inputs are ``[img_A, img_Base]`` (style, content), each of shape
        ``self.img_shape``. The content branch has six stride-2 blocks
        (c1..c6) and the style branch has seven (s1..s7). The decoder starts
        from the deepest style feature ``s7``. It upsamples while concatenating
        skips from the content branch (c6, c5, c4) at coarse scales and from
        the style branch (s3, s2, s1) at fine scales. The output is a tanh
        image with ``self.channels`` channels.
        """

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Downsampling block: stride-2 Conv -> LeakyReLU -> optional BN."""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Upsampling block: x2 upsample -> Conv(relu) -> Dropout? -> BN -> concat skip."""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        img_A = Input(shape=self.img_shape)
        img_Base = Input(shape=self.img_shape)

        # Content Downsampling (256 -> 4)
        c1 = conv2d(img_Base, self.gf, bn=False)
        c2 = conv2d(c1, self.gf*2)
        c3 = conv2d(c2, self.gf*4)
        c4 = conv2d(c3, self.gf*8)
        c5 = conv2d(c4, self.gf*8)
        c6 = conv2d(c5, self.gf*8)

        # Style Downsampling (256 -> 2)
        s1 = conv2d(img_A, self.gf, bn=False)
        s2 = conv2d(s1, self.gf*2)
        s3 = conv2d(s2, self.gf*4)
        s4 = conv2d(s3, self.gf*8)
        s5 = conv2d(s4, self.gf*8)
        s6 = conv2d(s5, self.gf*8)
        s7 = conv2d(s6, self.gf*8)

        # Upsampling: content skips at coarse scales, style skips at fine scales
        u1 = deconv2d(s7, c6, self.gf*8)
        u2 = deconv2d(u1, c5, self.gf*8)
        u3 = deconv2d(u2, c4, self.gf*8)
        u4 = deconv2d(u3, s3, self.gf*4)
        u5 = deconv2d(u4, s2, self.gf*2)
        u6 = deconv2d(u5, s1, self.gf)

        u7 = UpSampling2D(size=2)(u6)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model([img_A, img_Base], output_img)

    def build_discriminator(self):
        """Build the conditional PatchGAN discriminator.

        Takes ``[img_A, img_Base, img_B]``, concatenates them along the channel
        axis and applies four stride-2 blocks. It returns a 1-channel map of
        shape ``self.disc_patch`` with no output activation.
        """

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer: stride-2 Conv -> LeakyReLU -> optional BN."""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)
        img_Base = Input(shape=self.img_shape)

        combined_imgs = Concatenate(axis=-1)([img_A, img_Base, img_B])

        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([img_A, img_Base, img_B], validity)

    @staticmethod
    def _edge_kernel(shape, dtype=None):
        """Keras kernel initializer holding the fixed 4x4 edge-extraction filter.

        Accepts the ``dtype`` keyword that Keras passes to initializers. The
        4x4 filter is broadcast to ``shape`` (``(4, 4, in_channels, filters)``),
        so it also works for multi-channel inputs. The filter taps sum to zero,
        so a constant region produces zero response (before bias).

        Raises:
            ValueError: if ``shape`` is not broadcast-compatible with a 4x4 kernel.
        """
        base = np.array([(1, 1, 1, 1),
                         (1, -4, -2, 1),
                         (1, -2, -4, 1),
                         (1, 1, 1, 1)])
        if dtype is None:
            dtype = 'float32'
        # tf DType objects expose their numpy equivalent; strings pass through.
        np_dtype = getattr(dtype, 'as_numpy_dtype', dtype)
        return np.broadcast_to(base[:, :, None, None], tuple(shape)).astype(np_dtype)

    def get_edge(self):
        """Build a one-layer model that extracts an edge map from an image.

        It is a single 4x4 convolution with stride 4 and ``padding='same'``,
        initialised with ``_edge_kernel``. A 256x256 input gives a 64x64
        single-channel map. The layer keeps Keras' default zero-initialised bias.
        """
        img = Input(shape=self.img_shape)
        img_edge = Conv2D(filters=1, kernel_size=4, kernel_initializer=self._edge_kernel,
                          strides=4, padding='same')(img)

        return Model(img, img_edge)

    def train(self, epochs, batch_size=1, sample_interval=50, resume_epoch=175):
        """Train the discriminator and both generators.

        Args:
            epochs: number of passes over ``self.data_loader.load_batch``.
            batch_size: images per batch requested from the data loader.
            sample_interval: write sample images every this many batches.
            resume_epoch: epoch number of the checkpoint in
                ``logo2font_model_bf/`` to resume from when all three weight
                files (G, Gr, D) exist. Defaults to 175.
        """
        resume_paths = ["logo2font_model_bf/%s_dualmodelw%d.hdf5" % (prefix, resume_epoch)
                        for prefix in ("G", "Gr", "D")]
        if all(os.path.exists(path) for path in resume_paths):
            for model, path in zip((self.generator, self.generator_r, self.discriminator),
                                   resume_paths):
                model.load_weights(path, by_name=True)
            print("Load weights!!")
        elif any(os.path.exists(path) for path in resume_paths):
            # Loading only some of the models would mix trained and fresh weights.
            print("Incomplete checkpoint for epoch %d; training from scratch." % resume_epoch)

        start_time = datetime.datetime.now()

        # Per-epoch history used by showlogs
        logs = []
        # D is trained once every ``dt_rate`` batches; adapted from D accuracy.
        dt_rate = 5
        d_loss = g_loss = None
        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B, imgs_Base, imgs_cycle) in enumerate(self.data_loader.load_batch(batch_size)):
                # Adversarial ground truths, sized from the actual batch so a
                # short final batch (if the loader yields one) still matches.
                n_imgs = imgs_A.shape[0]
                valid = np.ones((n_imgs,) + self.disc_patch)
                fake = np.zeros((n_imgs,) + self.disc_patch)

                # -------------------------------
                #  Train Discriminator
                # -------------------------------
                if batch_i % dt_rate == 0:
                    # Conditions and generate a translated version
                    fake_B = self.generator.predict([imgs_A, imgs_Base])

                    # Train the discriminator (original images = real / generated = fake)
                    d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_Base, imgs_B], valid)
                    d_loss_fake = self.discriminator.train_on_batch([imgs_A, imgs_Base, fake_B], fake)
                    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                    # D accuracy above 80%: train D less often; below 50%: more often.
                    if 100*d_loss[1] > 80:
                        dt_rate += 1
                    elif 100*d_loss[1] < 50 and dt_rate > 1:
                        dt_rate -= 1

                # ---------------------------------
                #  Train Generators (A->B, then B->A)
                # ---------------------------------
                B_edge = self.edge.predict(imgs_B)
                g_loss = self.combined.train_on_batch([imgs_A, imgs_Base, imgs_cycle], [valid, imgs_B, B_edge, imgs_A])

                A_edge = self.edge.predict(imgs_A)
                gc_loss = self.combined.train_on_batch([imgs_B, imgs_cycle, imgs_Base], [valid, imgs_A, A_edge, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] [Gc loss: %f] time: %s" % (
                    epoch, epochs,
                    batch_i, self.data_loader.n_batches,
                    d_loss[0], 100*d_loss[1],
                    g_loss[0],
                    gc_loss[0],
                    elapsed_time))

                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

                self._save_checkpoints(epoch)

            # Skip logging if no batch has ever run (d_loss/g_loss still unset).
            if d_loss is not None:
                logs.append([epoch, d_loss[0], d_loss[1], g_loss[0]])
                self.showlogs(logs, epoch)

    def _save_checkpoints(self, epoch):
        """Save weights every 10 epochs and full models every 15 epochs.

        Weights are saved when ``(epoch + 9) % 10 == 0`` (epochs 1, 11, 21, ...).
        Full models are saved when ``(epoch + 9) % 15 == 0`` (epochs 6, 21, 36, ...).
        Existing files are never overwritten, so within an epoch only the
        first call writes anything.
        """
        if (epoch + 9) % 10 == 0:
            os.makedirs("logo2font_model_bf", exist_ok=True)
            if not os.path.exists("logo2font_model_bf/G_dualmodelw%d.hdf5" % epoch):
                self.generator.save_weights("logo2font_model_bf/G_dualmodelw%d.hdf5" % epoch, overwrite=True)
                self.generator_r.save_weights("logo2font_model_bf/Gr_dualmodelw%d.hdf5" % epoch, overwrite=True)
                self.discriminator.save_weights("logo2font_model_bf/D_dualmodelw%d.hdf5" % epoch, overwrite=True)

        if (epoch + 9) % 15 == 0:
            os.makedirs("logo2font_model_bf/trained_model", exist_ok=True)
            if not os.path.exists("logo2font_model_bf/trained_model/G_dmodel%d.hdf5" % epoch):
                self.generator.save("logo2font_model_bf/trained_model/G_dmodel%d.hdf5" % epoch, overwrite=True)
                self.generator_r.save("logo2font_model_bf/trained_model/Gr_dmodel%d.hdf5" % epoch, overwrite=True)
                self.discriminator.save("logo2font_model_bf/trained_model/D_dmodel%d.hdf5" % epoch, overwrite=True)

    @staticmethod
    def _save_image_grid(images, rows, cols, path, titles=None, max_images=None):
        """Save ``images`` as a ``rows`` x ``cols`` grayscale grid to ``path``.

        Panels are filled row by row. When given, ``titles[i]`` labels every
        drawn panel of row ``i``. At most ``max_images`` images are drawn (all
        of them by default); leftover axes stay blank. All axes are hidden.
        """
        n_draw = len(images) if max_images is None else min(len(images), max_images)
        fig, axs = plt.subplots(rows, cols, squeeze=False)
        for idx in range(rows * cols):
            row, col = divmod(idx, cols)
            ax = axs[row, col]
            if idx < n_draw:
                ax.imshow(images[idx], cmap='gray')
                if titles is not None:
                    ax.set_title(titles[row])
            ax.axis('off')
        fig.savefig(path)
        plt.close(fig)

    def sample_images(self, epoch, batch_i):
        """Write sample figures for the given epoch/batch under ``images/``.

        Produces five kinds of figure:
        - style/content/generated/original grids from a 3-image test batch;
        - cycle-reconstruction grids;
        - edge maps of generated vs. original glyphs;
        - the fixed content glyphs rendered in the adidas, nike and huawei
          logo styles.
        """
        for suffix in ('bf', 'bf_r', 'bf_adi', 'bf_nike', 'bf_huawei', 'bf_out'):
            os.makedirs('images/%s_%s' % (self.dataset_name, suffix), exist_ok=True)

        r, c = 4, 3

        imgs_A, imgs_B, imgs_Base, imgs_cycle = self.data_loader.load_data(batch_size=3, is_testing=True)

        # Random test batch
        fake_B = self.generator.predict([imgs_A, imgs_Base])
        re_A = np.squeeze(self.generator_r.predict([fake_B, imgs_cycle]))

        # Edge maps of generated and ground-truth glyphs
        edge_B = self.edge.predict(fake_B)
        B_edge = self.edge.predict(imgs_B)
        edge_imgs = np.squeeze(np.concatenate([edge_B, B_edge]))
        edge_imgs = 0.5 * edge_imgs + 0.5

        imgs_A = np.squeeze(imgs_A)
        imgs_B = np.squeeze(imgs_B)
        imgs_Base = np.squeeze(imgs_Base)
        imgs_cycle = np.squeeze(imgs_cycle)
        fake_B = np.squeeze(fake_B)

        gen_imgs = np.concatenate([imgs_A, imgs_Base, fake_B, imgs_B])
        # Rescale to 0-1 (generator output is tanh, i.e. [-1, 1]; inputs
        # presumably share that range).
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Style', 'Content', 'Generated', 'Original']
        self._save_image_grid(gen_imgs, r, c,
                              "images/%s_bf/%d_%d.png" % (self.dataset_name, epoch, batch_i),
                              titles=titles)

        # Cycle reconstruction (not rescaled; imshow autoscales each panel)
        re_imgs = np.concatenate([fake_B, imgs_cycle, re_A, imgs_A])
        titles_r = ['Generated', 'Content', 'Rebuild', 'Original']
        self._save_image_grid(re_imgs, r, c,
                              "images/%s_bf_r/%d_%d.png" % (self.dataset_name, epoch, batch_i),
                              titles=titles_r)

        # Edge-control visualisation: generated edges (row 0) vs. original (row 1)
        self._save_image_grid(edge_imgs, 2, c,
                              "images/%s_bf_out/%d_%d.png" % (self.dataset_name, epoch, batch_i))

        # Fixed logo test set: the content glyphs in each logo's style
        # (up to 26 glyphs on a 5x6 grid).
        for suffix, style in (('adi', self.adidas_), ('nike', self.nike_), ('huawei', self.huawei_)):
            font = np.squeeze(self.generator.predict([style, self.imgs]))
            self._save_image_grid(font, 5, 6,
                                  "images/%s_bf_%s/%d_%d.png" % (self.dataset_name, suffix, epoch, batch_i),
                                  max_images=26)

    def showlogs(self, logs, epoch):
        """Plot D loss, D accuracy and G loss against epoch and save the figure.

        ``logs`` rows are ``[epoch, d_loss, d_acc, g_loss]``. The figure goes
        to ``images/<dataset>_bf_log/<epoch>.png``.
        """
        os.makedirs('images/%s_bf_log' % self.dataset_name, exist_ok=True)
        logs = np.array(logs)
        names = ["d_loss", "d_acc", "g_loss"]
        # Use an explicit figure so we never draw onto a stale current figure.
        fig = plt.figure()
        for i, name in enumerate(names):
            ax = fig.add_subplot(2, 2, i + 1)
            ax.plot(logs[:, 0], logs[:, i + 1])
            ax.set_xlabel("epoch")
            ax.set_ylabel(name)
        fig.tight_layout()
        fig.savefig("images/%s_bf_log/%d.png" % (self.dataset_name, epoch))
        plt.close(fig)
        
                                          
if __name__ == '__main__':
    # Build the GAN (also loads the data loader and fixed logo set) and train it.
    Logo2Font().train(epochs=441, batch_size=10, sample_interval=300)
    