# Selfie-to-webtoon translation with U-GAT-IT (TensorFlow 2.0.0-rc0)

In [None]:
import tensorflow as tf
import glob
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras import layers, models, Model
# Fail fast: this notebook is pinned to one exact TensorFlow build.
if not tf.__version__ == '2.0.0-rc0':
    raise Exception('Tensorflow Version is not correct, Required Version : 2.0.0-rc0')

In [None]:
# Shared initializers / regularizer.
# NOTE(review): none of these three names is referenced elsewhere in the
# visible notebook — confirm they are still needed.
init_he = tf.initializers.he_normal(seed=2019)
init_rand_norm = tf.random_normal_initializer(mean=0.0, stddev=0.02)
kernel_regularizer = tf.keras.regularizers.l2(0.0001)

### Normalization and network building blocks

In [None]:
def instance_norm(x):
    """Instance normalization without learned affine parameters.

    Each sample's channels are normalized with the mean and variance taken
    over axes 1 and 2 (the spatial axes of the channels-last feature maps
    produced by the Conv2D layers in this notebook).

    Args:
        x: feature-map tensor, e.g. [batch, H, W, C].
    Returns:
        Tensor of the same shape as ``x``.
    """
    eps = 1e-5
    # keepdims=True keeps the statistics at [B, 1, 1, C] so they broadcast per
    # sample. Without it they are [B, C] and would be aligned with x's trailing
    # [W, C] axes, which is only correct when the batch size is 1.
    ins_mean, ins_var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    return (x - ins_mean) / tf.sqrt(ins_var + eps)

def layer_norm(x):
    """Layer normalization without learned affine parameters.

    Each sample is normalized with the mean and variance taken over axes
    1, 2 and 3 (all non-batch axes of a 4-D feature map).

    Args:
        x: 4-D feature-map tensor, e.g. [batch, H, W, C].
    Returns:
        Tensor of the same shape as ``x``.
    """
    eps = 1e-5
    # keepdims=True keeps the statistics at [B, 1, 1, 1]. Without it they are
    # [B] and broadcast against the channel axis, which is only correct when
    # the batch size is 1.
    ln_mean, ln_var = tf.nn.moments(x, axes=[1, 2, 3], keepdims=True)
    return (x - ln_mean) / tf.sqrt(ln_var + eps)

def layer_instance_norm(x):
    """Layer-Instance Normalization (LIN).

    Blends instance- and layer-normalized activations with a per-channel
    ``rho`` (starts at 0, clipped to [0, 1]), then applies a per-channel
    affine ``gamma`` (starts at 1) / ``beta`` (starts at 0).

    NOTE(review): rho/gamma/beta are bare ``tf.Variable``s created outside any
    Keras layer, so they are not registered as model weights; when applied to
    symbolic Keras tensors TF 2.0 appears to freeze them as constants —
    confirm. This function is not called anywhere in the visible notebook.
    """
    num_channels = np.shape(x)[-1]
    ins_normed = instance_norm(x)
    layer_normed = layer_norm(x)
    mix = tf.Variable(
        initial_value=np.zeros(num_channels),
        constraint=lambda v: tf.clip_by_value(v, clip_value_min=0.0, clip_value_max=1.0),
        dtype=tf.float32,
    )
    scale = tf.Variable(initial_value=np.ones([num_channels]), dtype=tf.float32)
    shift = tf.Variable(initial_value=np.zeros([num_channels]), dtype=tf.float32)

    blended = mix * ins_normed + (1 - mix) * layer_normed
    return blended * scale + shift

def adaptive_instance_layer_norm(x, gamma, beta, smoothing=True):
    """Adaptive Layer-Instance Normalization (AdaLIN).

    Blends instance- and layer-normalized ``x`` with a per-channel ``rho`` and
    applies the externally supplied ``gamma`` / ``beta`` (from ``MLP``).

    Args:
        x: feature map to normalize.
        gamma, beta: affine parameters broadcastable against ``x``.
        smoothing: if True, ``rho`` starts at 0.9 and is clipped to [0, 0.9];
            otherwise it starts at 1.0 and is clipped to [0, 1].

    NOTE(review): ``rho`` is a bare ``tf.Variable`` outside any Keras layer, so
    it is presumably not trained as part of the Generator model — confirm.
    """
    num_channels = np.shape(x)[-1]
    ins_normed = instance_norm(x)
    layer_normed = layer_norm(x)
    rho_cap = 0.9 if smoothing else 1.0
    rho = tf.Variable(
        initial_value=np.full([num_channels], rho_cap),
        constraint=lambda v: tf.clip_by_value(v, clip_value_min=0.0, clip_value_max=rho_cap),
        dtype=tf.float32,
    )
    blended = rho * ins_normed + (1 - rho) * layer_normed
    return blended * gamma + beta

def resblock(x_init, channels):
    """Residual block: conv3x3 -> IN -> ReLU -> conv3x3 -> IN, plus identity skip.

    Both convolutions use ``channels`` filters with stride 1, so ``x_init``
    must already have ``channels`` channels for the skip addition.
    """
    h = x_init
    for stage in range(2):
        h = layers.Conv2D(channels, 3, strides=1, padding='same')(h)
        h = instance_norm(h)
        if stage == 0:
            h = layers.Activation('relu')(h)
    return h + x_init

def adaptive_ins_layer_resblock(x_init, ch, gamma, beta) :
    """Residual block using AdaLIN: conv3x3 -> AdaLIN -> ReLU -> conv3x3 -> AdaLIN, plus skip.

    The same ``gamma`` / ``beta`` are used by both normalization steps;
    ``x_init`` must have ``ch`` channels for the skip addition.
    """
    h = x_init
    for stage in range(2):
        h = layers.Conv2D(ch, 3, strides=1, padding='same')(h)
        h = adaptive_instance_layer_norm(h, gamma, beta)
        if stage == 0:
            h = layers.Activation('relu')(h)
    return h + x_init

def MLP(x, ch):
    """Predict the AdaLIN affine parameters ``gamma`` and ``beta`` from a feature map.

    Global-average-pools ``x``, applies two ReLU fully-connected layers of
    width C, then two separate C-unit projections for gamma and beta.

    Args:
        x: channels-last feature map with C channels.
        ch: ignored — it is overwritten with C. The Generator passes 256,
            while its concatenated CAM features (and the AdaLIN blocks) have
            512 channels, so the actual channel count must win.
    Returns:
        (gamma, beta), each reshaped to [batch, 1, 1, C].
    """
    ch = np.shape(x)[-1]
    x = global_avg_pooling(x)
    # Hidden layers keep width C; the 1-unit CAM helper previously used here
    # squeezed the whole MLP through a single scalar.
    for _ in range(2):
        x = fully_connected(x, ch)
        x = layers.Activation('relu')(x)
    gamma = tf.reshape(fully_connected(x, ch), shape=[-1, 1, 1, ch])
    # Reshape beta's own projection (it used to reshape gamma, making beta == gamma).
    beta = tf.reshape(fully_connected(x, ch), shape=[-1, 1, 1, ch])
    return gamma, beta

def downsample(filters, size, input_shape = None):
    """Stride-2 Conv2D (no bias, N(0, 0.02) kernel init) followed by ReLU.

    ``input_shape`` is accepted but unused. Returns a ``tf.keras.Sequential``.
    """
    conv = tf.keras.layers.Conv2D(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False,
    )
    return tf.keras.Sequential([conv, tf.keras.layers.ReLU()])
def upsample(filters, size):
    """Stride-2 Conv2DTranspose (no bias, N(0, 0.02) kernel init) followed by ReLU.

    Returns a ``tf.keras.Sequential``.
    """
    deconv = tf.keras.layers.Conv2DTranspose(
        filters, size, strides=2, padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02), use_bias=False,
    )
    return tf.keras.Sequential([deconv, tf.keras.layers.ReLU()])

def global_avg_pooling(x):
    """Mean over axes 1 and 2 (the spatial axes of a channels-last map)."""
    return tf.reduce_mean(x, axis=[1, 2])

def global_max_pooling(x):
    """Maximum over axes 1 and 2 (the spatial axes of a channels-last map)."""
    return tf.reduce_max(x, axis=[1, 2])

def flatten(x) :
    """Apply a fresh Keras ``Flatten`` layer to ``x``."""
    flatten_layer = layers.Flatten()
    return flatten_layer(x)

def fully_connected_with_w(x):
    """Bias-free 1-unit dense projection that also returns its kernel (CAM helper).

    Args:
        x: input tensor; flattened to [batch, channels] first.
    Returns:
        (logit, weights):
            logit: ``flatten(x) @ spectral_norm(w)``, shape [batch, 1].
            weights: the raw (not spectrally normalized) kernel as a
                [channels] vector; callers use it to reweight feature maps.
    """
    x = flatten(x)
    channels = x.get_shape().as_list()[-1]
    # dtype must be passed by keyword: the second positional parameter of
    # tf.Variable is `trainable`, not `dtype`.
    w = tf.Variable(tf.random.normal([channels, 1], mean=0.0, stddev=0.02), dtype=tf.float32)
    # NOTE(review): `w` is a bare tf.Variable rather than a Keras layer weight,
    # so it is not in any Model's trainable_variables. On symbolic Keras tensors
    # TF 2.0 appears to capture it as a constant. It would need to live in a
    # Layer (add_weight) to be trained — confirm.
    logit = tf.matmul(x, spectral_norm(w))
    weights = tf.gather(tf.transpose(w), 0)
    return logit, weights
def fully_connected(x, units):
    """Bias-free dense projection ``flatten(x) @ spectral_norm(w)``.

    Args:
        x: input tensor; flattened to [batch, channels] first.
        units: number of output features.
    Returns:
        Tensor of shape [batch, units].
    """
    x = flatten(x)
    channels = x.get_shape().as_list()[-1]
    # dtype must be passed by keyword: the second positional parameter of
    # tf.Variable is `trainable`, not `dtype`.
    w = tf.Variable(tf.random.normal([channels, units], mean=0.0, stddev=0.02), dtype=tf.float32)
    # NOTE(review): as in fully_connected_with_w, `w` is not a Keras layer
    # weight and is presumably frozen when used inside a functional Model — confirm.
    return tf.matmul(x, spectral_norm(w))

def spectral_norm(w, iteration=1):
    """Divide ``w`` by a power-iteration estimate of its largest singular value.

    ``w`` is viewed as a 2-D matrix [-1, last_dim]. ``iteration`` power steps
    start from a random vector ``u``. The singular vectors are wrapped in
    ``stop_gradient`` so gradients flow only through ``w``.

    NOTE(review): ``u`` is created fresh on every call and never assigned back,
    so the power-iteration state is not carried between calls as in standard
    spectral normalization. With iteration=1 the sigma estimate may be coarse.
    ``iteration`` must be >= 1 (``v`` is undefined otherwise).

    Returns:
        Tensor with ``w``'s original shape.
    """
    original_shape = w.shape.as_list()
    w_2d = tf.reshape(w, [-1, original_shape[-1]])
    u = tf.Variable(tf.random.normal([1, original_shape[-1]]))

    u_est, v_est = u, None
    for _ in range(iteration):
        v_est = tf.nn.l2_normalize(tf.matmul(u_est, tf.transpose(w_2d)))
        u_est = tf.nn.l2_normalize(tf.matmul(v_est, w_2d))

    u_est = tf.stop_gradient(u_est)
    v_est = tf.stop_gradient(v_est)

    sigma = tf.matmul(tf.matmul(v_est, w_2d), tf.transpose(u_est))
    return tf.reshape(w_2d / sigma, original_shape)

### Generator

In [None]:
def Generator():
    """Build the generator as a Keras functional ``Model``.

    Layout (``ch = 256``):
      * encoder: two stride-2 ``downsample`` stages (128, 256 filters) and
        four ``resblock``s at 256 channels;
      * CAM attention: GAP- and GMP-pooled features each go through
        ``fully_connected_with_w``; the returned kernels reweight the feature
        map and the two 1-unit logits are concatenated into ``cam_logit``;
      * ``MLP`` predicts AdaLIN gamma/beta from the concatenated 512-channel
        map, which a 1x1 ReLU conv fuses before four AdaLIN residual blocks;
      * decoder: two stride-2 ``upsample`` stages and a 1x1 tanh conv to 3 channels.

    Returns:
        ``Model`` mapping a [batch, 256, 256, 3] image to
        ``[translated_image, cam_logit]`` (``cam_logit`` is [batch, 2]).
    """
    ch = 256
    down_stack = [
        downsample(ch // 2, 3),
        downsample(ch, 3),
    ]
    up_stack = [
        upsample(ch, 3),
        upsample(ch // 2, 3),
    ]
    inputs = tf.keras.layers.Input(shape=[256, 256, 3])

    # Encoder.
    x = inputs
    for down in down_stack:
        x = down(x)
    for _ in range(4):
        x = resblock(x, ch)

    # CAM attention on average- and max-pooled descriptors.
    gap_att = global_avg_pooling(x)
    cam_gap_logit, cam_x_weight = fully_connected_with_w(gap_att)
    x_gap = tf.multiply(x, cam_x_weight)

    gmp_att = global_max_pooling(x)
    cam_gmp_logit, cam_x_weight = fully_connected_with_w(gmp_att)
    x_gmp = tf.multiply(x, cam_x_weight)

    cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
    x = tf.concat([x_gap, x_gmp], axis=-1)

    # AdaLIN parameters; MLP sizes them to x's actual 512 channels.
    gamma, beta = MLP(x, ch)

    x = tf.keras.layers.Conv2D(ch * 2, 1, strides=1, padding='same', use_bias=False, activation='relu')(x)

    # Decoder: AdaLIN residual blocks share one gamma/beta, then upsampling.
    for _ in range(4):
        x = adaptive_ins_layer_resblock(x, ch * 2, gamma, beta)
    for up in up_stack:
        x = up(x)

    last = layers.Conv2D(3, 1, strides=1, padding='same', activation='tanh')(x)
    return Model(inputs, [last, cam_logit])

### Discriminator

In [None]:
def dis_downsample(filters, size, input_shape = None):
    """Discriminator down-sampling stage: stride-2 Conv2D (no bias) + LeakyReLU.

    Args:
        filters: number of output channels.
        size: convolution kernel size.
        input_shape: accepted for parity with ``downsample``; unused.
    Returns:
        A ``tf.keras.Sequential`` with the two layers.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    result = tf.keras.Sequential()
    result.add(tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', kernel_initializer=initializer, use_bias=False))
    result.add(tf.keras.layers.LeakyReLU())
    return result


def _discriminator_head(x, ch):
    """CAM attention + 1x1-conv output head shared by both discriminator branches.

    Returns:
        (logit_map, cam_logit): a 1-channel logit map with ``x``'s spatial size,
        and the [batch, 2] concatenation of the GAP- and GMP-based CAM logits.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    # Reweight channels with the kernels returned by the CAM projections.
    cam_gap_logit, gap_weight = fully_connected_with_w(global_avg_pooling(x))
    x_gap = tf.multiply(x, gap_weight)
    cam_gmp_logit, gmp_weight = fully_connected_with_w(global_max_pooling(x))
    x_gmp = tf.multiply(x, gmp_weight)

    cam_logit = tf.concat([cam_gap_logit, cam_gmp_logit], axis=-1)
    x = tf.concat([x_gap, x_gmp], axis=-1)
    x = tf.keras.layers.Conv2D(ch, 1, strides=1, padding='same', kernel_initializer=initializer, use_bias=False)(x)
    # Non-linearity between the 1x1 convs; two stacked linear convs would
    # collapse into a single linear map.
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Conv2D(1, 1, strides=1, padding='same', kernel_initializer=initializer, use_bias=False)(x)
    return x, cam_logit


def discriminator_global(x, ch):
    """Global branch: four stride-2 LeakyReLU stages (ch/8 .. ch filters), then the CAM head.

    Returns:
        (logit_map, cam_logit) as produced by ``_discriminator_head``.
    """
    down_stack = [
        dis_downsample(ch // 8, 3),
        dis_downsample(ch // 4, 3),
        dis_downsample(ch // 2, 3),
        dis_downsample(ch, 3),
    ]
    for down in down_stack:
        x = down(x)
    return _discriminator_head(x, ch)


def discriminator_local(x, ch):
    """Local branch: two stride-2 LeakyReLU stages (ch/2, ch filters), then the CAM head.

    Fewer down-sampling stages than the global branch, so the output logit map
    is larger.

    Returns:
        (logit_map, cam_logit) as produced by ``_discriminator_head``.
    """
    down_stack = [
        dis_downsample(ch // 2, 3),
        dis_downsample(ch, 3),
    ]
    for down in down_stack:
        x = down(x)
    return _discriminator_head(x, ch)

def Discriminator():
    """Build the two-branch discriminator ``Model`` for [batch, 256, 256, 3] images.

    The local and global branches share the input. Outputs are nested as
    ``[[local_logit, global_logit], [local_cam_logit, global_cam_logit]]``.
    """
    ch = 256
    image = tf.keras.layers.Input(shape=[256, 256, 3])
    local_logit, local_cam_logit = discriminator_local(image, ch)
    global_logit, global_cam_logit = discriminator_global(image, ch)
    logits = [local_logit, global_logit]
    cam_logits = [local_cam_logit, global_cam_logit]
    return Model(image, [logits, cam_logits])

### Load saved models (or build new ones)

In [None]:
# Resume from ./model_save only when all four expected checkpoint files exist;
# otherwise build fresh networks. The explicit check means stray files in the
# folder neither block resuming (which would lead to the next save overwriting
# the checkpoints) nor trigger a load of unrelated files.
model_paths = {name: './model_save/{}.h5'.format(name) for name in ('gen_g', 'gen_f', 'dis_x', 'dis_y')}
if all(os.path.isfile(path) for path in model_paths.values()):
    print('load model')
    gen_g = models.load_model(model_paths['gen_g'])
    gen_f = models.load_model(model_paths['gen_f'])
    dis_x = models.load_model(model_paths['dis_x'])
    dis_y = models.load_model(model_paths['dis_y'])
else:
    print('make new model')
    gen_g = Generator()
    gen_f = Generator()
    dis_x = Discriminator()
    dis_y = Discriminator()

In [None]:
EPOCHS = 2
lr = 0.0001  # learning rate shared by all four Adam optimizers
gen_g_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)
gen_f_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)
dis_x_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)
dis_y_optimizer = tf.keras.optimizers.Adam(lr, beta_1=0.5)

# BCE on raw logits: the discriminators' final 1x1 convs have no activation.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

### loss

In [None]:
#LAMBDA = 10
def discriminator_loss(real, generated):
    """Half the sum of the BCE-with-logits losses for real (target 1) and generated (target 0) logits."""
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
    return 0.5 * (loss_on_real + loss_on_fake)

def generator_loss(generated):
    """BCE-with-logits loss pushing ``generated`` logits toward the "real" label 1."""
    target = tf.ones_like(generated)
    return loss_obj(target, generated)

def calc_cycle_loss(real_image, cycled_image, cycle_lambda=10):
    """Weighted mean absolute error between an image and its cycle reconstruction.

    Args:
        real_image, cycled_image: tensors of matching shape.
        cycle_lambda: loss weight. It is a parameter (default 10) because no
            module-level ``LAMBDA`` is defined in this notebook.
    Returns:
        Scalar tensor ``cycle_lambda * mean(|real_image - cycled_image|)``.
    """
    loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return cycle_lambda * loss1

def identity_loss(real_image, same_image, identity_lambda=10):
    """Half-weighted mean absolute error between an image and its identity mapping.

    Args:
        real_image, same_image: tensors of matching shape.
        identity_lambda: base loss weight (default 10; no module-level
            ``LAMBDA`` is defined in this notebook).
    Returns:
        Scalar tensor ``identity_lambda * 0.5 * mean(|real_image - same_image|)``.
    """
    loss = tf.reduce_mean(tf.abs(real_image - same_image))
    return identity_lambda * 0.5 * loss

def cam_loss(source, non_source) :
    """CAM loss: ``source`` logits should score 1 and ``non_source`` logits 0.

    Each term is the mean sigmoid cross-entropy over its logits; the two
    terms are summed.
    """
    def _mean_sigmoid_ce(logits, make_labels):
        labels = make_labels(logits)
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

    source_term = _mean_sigmoid_ce(source, tf.ones_like)
    non_source_term = _mean_sigmoid_ce(non_source, tf.zeros_like)
    return source_term + non_source_term

def L1_loss(x, y):
    """Mean absolute difference between ``x`` and ``y``."""
    return tf.reduce_mean(tf.abs(x - y))

def discriminate_real(x_A, x_B):
    """Score real images: ``x_A`` with ``dis_x`` and ``x_B`` with ``dis_y``.

    Returns:
        (real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit).
    """
    a_logits, a_cam_logits = dis_x(x_A)
    b_logits, b_cam_logits = dis_y(x_B)
    return a_logits, a_cam_logits, b_logits, b_cam_logits

def discriminate_fake(x_ba, x_ab):
    """Score translated images: ``x_ba`` (fake A) with ``dis_x``, ``x_ab`` (fake B) with ``dis_y``.

    Returns:
        (fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit).
    """
    a_logits, a_cam_logits = dis_x(x_ba)
    b_logits, b_cam_logits = dis_y(x_ab)
    return a_logits, a_cam_logits, b_logits, b_cam_logits

def generate_images(model, test_input, epoch):
    """Plot the first input next to the model's prediction, save it, and show it.

    Args:
        model: generator returning ``(images, cam_logit)`` when called.
        test_input: batch of input images; only index 0 is plotted.
        epoch: integer used in ``./save/image_at_epoch_{epoch:04d}.jpg``.
    """
    prediction, _ = model(test_input)

    fig = plt.figure(figsize=(12, 12))
    panels = [('Input Image', test_input[0]), ('Predicted Image', prediction[0])]
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(image)
        plt.axis('off')
    # savefig does not create missing directories.
    os.makedirs('./save', exist_ok=True)
    plt.savefig('./save/image_at_epoch_{:04d}.jpg'.format(epoch))
    plt.show()
    # Called from inside the training loop; pyplot keeps every unclosed figure alive.
    plt.close(fig)


### dataset

In [None]:
basic_dir = './data/'
img_folder = ['selfie', 'cartoon']
# np.load on an .npz returns an NpzFile that keeps the archive open; indexing
# reads the array into memory, so close the file right after.
with np.load(basic_dir + 'selfie.npz') as selfie_npz:
    X_array = selfie_npz['arr_0']
with np.load(basic_dir + 'cartoon.npz') as cartoon_npz:
    Y_array = cartoon_npz['arr_0']
print(X_array.shape, Y_array.shape)

### train

In [None]:
@tf.function  # trace once and run the whole update as a graph
def train_step(real_x, real_y, gen_loss_weights):
    """Run one joint update of both generators and both discriminators.

    From the forward passes below: gen_g maps domain A -> B and gen_f maps
    B -> A; dis_x scores domain-A images and dis_y domain-B images. "_A" loss
    terms concern gen_f's outputs (fake A) and "_B" terms gen_g's outputs.

    Args:
        real_x: batch of domain-A images (the training loop passes X_array slices).
        real_y: batch of domain-B images (Y_array slices).
        gen_loss_weights: four weights [adversarial, cycle, identity, CAM]
            applied to the generator loss terms.
    """
    with tf.GradientTape(persistent=True) as tape:
        x_ab, cam_ab = gen_g(real_x)
        x_ba, cam_ba = gen_f(real_y)

        x_aba, _ = gen_f(x_ab)
        x_bab, _ = gen_g(x_ba)

        x_aa, cam_aa = gen_f(real_x)
        x_bb, cam_bb = gen_g(real_y)

        real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = discriminate_real(real_x, real_y)
        fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = discriminate_fake(x_ba, x_ab)

        cam_A = cam_loss(cam_ba, cam_aa)
        cam_B = cam_loss(cam_ab, cam_bb)

        # Index 0 / 1 are the local / global discriminator branches.
        G_ad_loss_A = generator_loss(fake_A_logit[0]) + generator_loss(fake_A_logit[1])
        G_ad_loss_B = generator_loss(fake_B_logit[0]) + generator_loss(fake_B_logit[1])

        D_ad_loss_A = discriminator_loss(real_A_logit[0], fake_A_logit[0]) + discriminator_loss(real_A_logit[1], fake_A_logit[1])
        D_ad_loss_B = discriminator_loss(real_B_logit[0], fake_B_logit[0]) + discriminator_loss(real_B_logit[1], fake_B_logit[1])

        D_cam_loss_A = discriminator_loss(real_A_cam_logit[0], fake_A_cam_logit[0]) + discriminator_loss(real_A_cam_logit[1], fake_A_cam_logit[1])
        D_cam_loss_B = discriminator_loss(real_B_cam_logit[0], fake_B_cam_logit[0]) + discriminator_loss(real_B_cam_logit[1], fake_B_cam_logit[1])

        identity_A = L1_loss(x_aa, real_x)
        identity_B = L1_loss(x_bb, real_y)

        reconstruction_A = L1_loss(x_aba, real_x)
        reconstruction_B = L1_loss(x_bab, real_y)

        Generator_A_gan = gen_loss_weights[0] * G_ad_loss_A
        Generator_A_cycle = gen_loss_weights[1] * reconstruction_B
        Generator_A_identity = gen_loss_weights[2] * identity_A
        Generator_A_cam = gen_loss_weights[3] * cam_A

        Generator_B_gan = gen_loss_weights[0] * G_ad_loss_B
        Generator_B_cycle = gen_loss_weights[1] * reconstruction_A
        Generator_B_identity = gen_loss_weights[2] * identity_B
        Generator_B_cam = gen_loss_weights[3] * cam_B

        Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
        Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam
        # Generator_A_loss's adversarial/identity/CAM terms depend only on gen_f,
        # and Generator_B_loss's only on gen_g. Both generators appear in both
        # cycle terms. Differentiating a generator against the "other" loss
        # would train it on cycle loss alone, so both use the combined loss.
        Generator_loss = Generator_A_loss + Generator_B_loss

        Discriminator_A_loss = D_ad_loss_A + D_cam_loss_A
        Discriminator_B_loss = D_ad_loss_B + D_cam_loss_B

    generator_g_gradients = tape.gradient(Generator_loss,
                                          gen_g.trainable_variables)
    generator_f_gradients = tape.gradient(Generator_loss,
                                          gen_f.trainable_variables)

    discriminator_x_gradients = tape.gradient(Discriminator_A_loss,
                                              dis_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(Discriminator_B_loss,
                                              dis_y.trainable_variables)

    gen_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                        gen_g.trainable_variables))
    gen_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                        gen_f.trainable_variables))

    dis_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                        dis_x.trainable_variables))
    dis_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                        dis_y.trainable_variables))

In [None]:
batch_size = 1
gen_loss_weights = [1, 10, 100, 1000]  # [adversarial, cycle, identity, CAM] — tune these lambdas between runs
# Each step pairs a selfie batch with a cartoon batch, so stop at the shorter
# dataset instead of pairing a batch with an empty one.
n_samples = min(len(X_array), len(Y_array))
# Image rendered after each epoch. Taken from the training data because
# `sample_x` is only defined in a later cell.
preview_x = X_array[:1]
os.makedirs('./save', exist_ok=True)
os.makedirs('./model_save', exist_ok=True)
for epoch in range(EPOCHS):
    start = time.time()
    for step, offset in enumerate(range(0, n_samples, batch_size)):
        image_x = X_array[offset:offset + batch_size]
        image_y = Y_array[offset:offset + batch_size]
        train_step(image_x, image_y, gen_loss_weights)
        if step % 10 == 0:
            print(step // 10, '/', end='')
    print()

    generate_images(gen_g, preview_x, epoch)

    gen_g.save('./model_save/gen_g.h5')
    gen_f.save('./model_save/gen_f.h5')
    dis_x.save('./model_save/dis_x.h5')
    dis_y.save('./model_save/dis_y.h5')

    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

### generate image

In [None]:
def generate_img(model, test_input):
    """Display the first input image next to the model's translation of it."""
    translated, _ = model(test_input)

    plt.figure(figsize=(12, 12))
    panels = (('Input Image', test_input[0]), ('Predicted Image', translated[0]))
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(image)
        plt.axis('off')

    plt.show()
# Preview one selfie: load at 256x256, divide pixel values by 255, add a batch axis.
sample_x = img_to_array(load_img('./data/selfie/female_10.jpg', target_size = (256, 256))) / 255.
sample_x = sample_x.reshape((-1, 256, 256, 3))
generate_img(gen_g, sample_x)

### Translate and save images for the whole selfie set

In [None]:
def generate_img(model, test_input, epoch):
    """Save a side-by-side input / prediction figure as ./result/{epoch:04d}.jpg.

    Args:
        model: generator returning ``(images, cam_logit)`` when called.
        test_input: batch of input images; only index 0 is plotted.
        epoch: integer used as the file index (any running index works).
    """
    prediction, _ = model(test_input)

    fig = plt.figure(figsize=(12, 12))
    panels = [('Input Image', test_input[0]), ('Predicted Image', prediction[0])]
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(image)
        plt.axis('off')
    # savefig does not create missing directories.
    os.makedirs('./result', exist_ok=True)
    plt.savefig('./result/{:04d}.jpg'.format(epoch))
    # This is called once per image in a loop; without closing, every figure
    # stays alive (and open in memory) until the cell finishes.
    plt.close(fig)

# Translate every selfie in X_array and save each comparison figure under its index.
for i, sample_x in enumerate(X_array):
    sample_x = sample_x.reshape((-1, 256, 256, 3))  # add the batch axis the model expects
    generate_img(gen_g, sample_x, i)