## Keras implementation of https://github.com/junyanz/CycleGAN

In [None]:
# tf.Session(config=tf.ConfigProto(log_device_placement=True))

In [None]:
import keras.backend as K
import tensorflow as tf

In [None]:
# Fix the Keras backend learning phase to 1 before any of the models below
# are constructed.
K.set_learning_phase(1)

In [None]:

from keras.optimizers import RMSprop, SGD, Adam
from keras.models import Sequential, Model
from keras.layers import Conv2D, ZeroPadding2D, BatchNormalization, Input, Dropout
from keras.layers import Conv2DTranspose, UpSampling2D, Activation, Add, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras_contrib.layers.normalization import InstanceNormalization

In [None]:
# Weight initializers shared by the layer factories below.

# Convolution kernels: normal distribution, mean 0, stddev 0.02.
conv_init = RandomNormal(0, 0.02)
# Batch-normalization gamma (scale): normal distribution, mean 1, stddev 0.02.
gamma_init = RandomNormal(1., 0.02)

In [None]:
def conv2d(f, *a, **k):
    """Create a ``Conv2D`` layer with ``f`` filters and the shared kernel initializer.

    Any further positional and keyword arguments are forwarded to ``Conv2D``
    unchanged. The kernel initializer is always ``conv_init``.
    """
    layer = Conv2D(f, *a, kernel_initializer=conv_init, **k)
    return layer
def batchnorm():
    """Create the batch-normalization layer used throughout the networks.

    Normalizes over axis 3 with momentum 0.9 and epsilon 1e-5; gamma is
    initialized from ``gamma_init``.
    """
    settings = dict(momentum=0.9, axis=3, epsilon=1e-5,
                    gamma_initializer=gamma_init)
    return BatchNormalization(**settings)

In [None]:
def conv_block(x, filters, size, stride=(2, 2), has_norm_layer=True, use_norm_instance=False,
               has_activation_layer=True, use_leaky_relu=False, padding='same'):
    """Apply convolution -> optional normalization -> optional activation.

    Args:
        x: input tensor (channels-last, as produced by the ``Input`` layers
            of the models in this file).
        filters: number of convolution filters.
        size: side length of the square kernel.
        stride: convolution strides.
        has_norm_layer: add a normalization layer after the convolution.
        use_norm_instance: use ``InstanceNormalization`` instead of batch
            normalization (only relevant when ``has_norm_layer`` is true).
        has_activation_layer: add an activation after the normalization.
        use_leaky_relu: use ``LeakyReLU(alpha=0.2)`` instead of ReLU.
        padding: padding mode passed to the convolution.

    Returns:
        The output tensor of the last layer that was added.
    """
    x = conv2d(filters, (size, size), strides=stride, padding=padding)(x)
    if has_norm_layer:
        if use_norm_instance:
            # The feature axis is 3 for channels-last tensors, matching
            # batchnorm(); axis=1 would normalize over image rows instead.
            x = InstanceNormalization(axis=3)(x)
        else:
            x = batchnorm()(x)
    if has_activation_layer:
        if use_leaky_relu:
            x = LeakyReLU(alpha=0.2)(x)
        else:
            x = Activation('relu')(x)
    return x

def res_block(x, filters=256, use_dropout=False):
    """Residual block: two 3x3 stride-1 ``conv_block`` calls plus a skip connection.

    The second ``conv_block`` has no activation. When ``use_dropout`` is
    true, ``Dropout(0.5)`` is inserted between the two. The result is the
    element-wise sum of the block output and the input ``x``.
    """
    shortcut = x
    branch = conv_block(shortcut, filters, 3, (1, 1))
    if use_dropout:
        branch = Dropout(0.5)(branch)
    branch = conv_block(branch, filters, 3, (1, 1), has_activation_layer=False)
    merged = Add()([branch, shortcut])
    return merged

# decoder block
def up_block(x, filters, size, use_conv_transpose=True, use_norm_instance=False):
    """Upsample ``x`` by a factor of 2, then normalize and apply ReLU.

    Args:
        x: input tensor (channels-last).
        filters: number of output filters.
        size: kernel size of the (transposed) convolution.
        use_conv_transpose: when true, upsample with a stride-2
            ``Conv2DTranspose``; otherwise use ``UpSampling2D`` followed by a
            stride-1 ``conv_block``.
        use_norm_instance: use instance normalization instead of batch
            normalization. Previously this flag only toggled the bias of the
            transposed convolution; it now selects the normalization layer in
            both branches.

    Returns:
        The upsampled tensor.
    """
    if use_conv_transpose:
        x = Conv2DTranspose(filters, kernel_size=size, strides=2, padding='same',
                            use_bias=bool(use_norm_instance),
                            kernel_initializer=conv_init)(x)
        if use_norm_instance:
            # Feature axis 3 (channels-last), consistent with batchnorm().
            x = InstanceNormalization(axis=3)(x)
        else:
            x = batchnorm()(x)
        x = Activation('relu')(x)
    else:
        x = UpSampling2D()(x)
        x = conv_block(x, filters, size, (1, 1), use_norm_instance=use_norm_instance)
    return x

In [None]:
# Defines the PatchGAN discriminator

In [None]:
def n_layer_discriminator(image_size=256, input_nc=3, ndf=64, hidden_layers=3):
    """Build the discriminator model.

    Args:
        image_size: side length of the square input image.
        input_nc: number of input channels.
        ndf: number of filters of the first convolution.
        hidden_layers: number of strided hidden convolutions; hidden layer
            ``i`` (1-based) uses ``2 ** i * ndf`` filters.

    Returns:
        A Keras ``Model`` mapping an image to a single-channel map with
        sigmoid activation.
    """
    def pad(tensor):
        # Explicit 1-pixel zero border; the following convolution is unpadded.
        return ZeroPadding2D(padding=(1, 1))(tensor)

    inputs = Input(shape=(image_size, image_size, input_nc))

    # First layer: no normalization, leaky ReLU.
    features = pad(inputs)
    features = conv_block(features, ndf, 4, has_norm_layer=False,
                          use_leaky_relu=True, padding='valid')
    features = pad(features)

    # Hidden layers: 4x4 convolutions with the default conv_block stride.
    for depth in range(1, hidden_layers + 1):
        width = 2 ** depth * ndf
        features = conv_block(features, width, 4, use_leaky_relu=True, padding='valid')
        features = pad(features)

    outputs = conv2d(1, (4, 4), activation='sigmoid', strides=(1, 1))(features)
    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Defines the generator

In [None]:
def resnet_generator(image_size=256, input_nc=3, res_blocks=6, use_conv_transpose=True):
    """Build the ResNet-style generator.

    Structure: one 7x7 stride-1 block, two 3x3 stride-2 blocks, ``res_blocks``
    residual blocks, two ``up_block`` calls, and a final 7x7 convolution with
    3 output channels and tanh activation.

    Args:
        image_size: side length of the square input image.
        input_nc: number of input channels.
        res_blocks: number of residual blocks.
        use_conv_transpose: forwarded to ``up_block``.

    Returns:
        Tuple ``(model, inputs, outputs)``: the Keras ``Model`` together with
        its input and output tensors.
    """
    inputs = Input(shape=(image_size, image_size, input_nc))

    # Encoder
    h = conv_block(inputs, 64, 7, (1, 1))
    for width in (128, 256):
        h = conv_block(h, width, 3, (2, 2))

    # Residual blocks
    for _ in range(res_blocks):
        h = res_block(h)

    # Decoder
    for width in (128, 64):
        h = up_block(h, width, 3, use_conv_transpose=use_conv_transpose)

    outputs = conv2d(3, (7, 7), activation='tanh', strides=(1, 1), padding='same')(h)
    model = Model(inputs=inputs, outputs=outputs)
    return model, inputs, outputs

In [None]:
# Base directory for saved result images and weight checkpoints.
dpath = '/home/lin/Downloads/weights-data/'
# Side length of the square images used for the placeholder inputs and for
# reshaping in display_image.
image_size=256
# NOTE: batch_size is assigned again in the training cell.
batch_size = 1
# Number of channels of the placeholder inputs of the discriminator models.
input_nc = 3

In [None]:
# Build the two discriminators with default settings.
# NOTE: both names are rebound to freshly built discriminators in a later
# cell, so these instances are not the ones used for training.
netD_A = n_layer_discriminator()
netD_B = n_layer_discriminator()
# netD_A.summary()
# netD_B.summary()

In [None]:
def criterion_GAN(output, target, use_lsgan=True):
    """Adversarial loss between discriminator ``output`` and ``target``.

    Args:
        output: discriminator prediction tensor.
        target: tensor of the same shape holding the desired prediction.
        use_lsgan: when true, use the least-squares loss; otherwise use
            binary cross-entropy.

    Returns:
        LSGAN: the squared error averaged over all non-batch axes, with a
        leading axis of size 1 added. Otherwise: the scalar mean binary
        cross-entropy.
    """
    if use_lsgan:
        diff = output-target
        dims = list(range(1,K.ndim(diff)))
        return K.expand_dims((K.mean(diff**2, dims)), 0)
    # Cross-entropy is the NEGATIVE mean log-likelihood; without the minus
    # sign, minimizing this value would push predictions away from target.
    # 1e-12 keeps log() away from 0.
    log_likelihood = K.log(output+1e-12)*target + K.log(1-output+1e-12)*(1-target)
    return -K.mean(log_likelihood)
    
def criterion_cycle(rec, real):
    """Cycle-consistency loss: mean absolute error between ``rec`` and ``real``.

    The absolute difference is averaged over every axis except the first, and
    a leading axis of size 1 is added to the result.
    """
    abs_err = K.abs(rec - real)
    non_batch_axes = list(range(1, K.ndim(abs_err)))
    per_sample = K.mean(abs_err, non_batch_axes)
    return K.expand_dims(per_sample, 0)

In [None]:
def netG_loss(inputs, cycle_loss_weight=1.0):
    """Combined generator loss for both translation directions.

    Args:
        inputs: six tensors in the order ``[netD_B_predict_fake, rec_A,
            real_A, netD_A_predict_fake, rec_B, real_B]``.
        cycle_loss_weight: multiplier applied to the sum of the two
            cycle-consistency losses.

    Returns:
        Sum of the two adversarial losses (discriminator predictions compared
        against all-ones targets) and the weighted cycle losses.
    """
    (netD_B_predict_fake, rec_A, real_A,
     netD_A_predict_fake, rec_B, real_B) = inputs

    def fool_discriminator(prediction):
        # The generator wants the discriminator to output 1 for its fakes.
        return criterion_GAN(prediction, K.ones_like(prediction))

    adv_A = fool_discriminator(netD_B_predict_fake)
    cyc_A = criterion_cycle(rec_A, real_A)

    adv_B = fool_discriminator(netD_A_predict_fake)
    cyc_B = criterion_cycle(rec_B, real_B)

    return adv_A + adv_B + cycle_loss_weight * (cyc_A + cyc_B)

In [None]:
def netD_loss(netD_predict):
    """Discriminator loss from its predictions on real and fake images.

    Args:
        netD_predict: pair ``(prediction_on_real, prediction_on_fake)``.

    Returns:
        Half the sum of the GAN loss of the real prediction against ones and
        of the fake prediction against zeros.
    """
    pred_real, pred_fake = netD_predict

    real_term = criterion_GAN(pred_real, K.ones_like(pred_real))
    fake_term = criterion_GAN(pred_fake, K.zeros_like(pred_fake))

    return 0.5 * (real_term + fake_term)

In [None]:
# Build the discriminators used by the training models below:
# netD_A scores domain-A images, netD_B scores domain-B images.
netD_A = n_layer_discriminator()
netD_B = n_layer_discriminator()
# netD_A.summary()
# netD_B.summary()

In [None]:
# Build the two generators. resnet_generator returns (model, input tensor,
# output tensor): netG_A maps real_A -> fake_B, netG_B maps real_B -> fake_A.
netG_A, real_A, fake_B = resnet_generator(use_conv_transpose=True)
netG_B, real_B, fake_A = resnet_generator(use_conv_transpose=True)
# netG_A.summary()
# netG_B.summary()

In [None]:
# make generator train function

In [None]:
# Wire up both cycles: A -> fake_B -> rec_A and B -> fake_A -> rec_B, plus the
# discriminator scores of the generated images.
netD_B_predict_fake = netD_B(fake_B)
rec_A = netG_B(fake_B)
netD_A_predict_fake = netD_A(fake_A)
rec_B = netG_A(fake_A)
lambda_layer_inputs = [netD_B_predict_fake, rec_A, real_A, netD_A_predict_fake, rec_B, real_B]

# Only the generators are updated by this model; the discriminators are
# frozen. The flags are set before compile() so they apply to this model.
for net, is_trainable in ((netG_A, True), (netG_B, True), (netD_A, False), (netD_B, False)):
    for l in net.layers:
        l.trainable = is_trainable

# The model's output IS the generator loss (computed inside the Lambda layer).
# It is trained with 'mae' against all-zero targets (see the training loop).
netG_train_function = Model([real_A, real_B], Lambda(netG_loss)(lambda_layer_inputs))
# Pass the configured optimizer instance to compile(). Previously the Adam
# object was created and discarded, and the string 'adam' selected an
# optimizer with default hyperparameters instead.
netG_train_function.compile(
    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999, epsilon=None, decay=0.0), 'mae')

In [None]:
# make discriminator A train function

In [None]:
# Discriminator A's prediction on real domain-A images.
netD_A_predict_real = netD_A(real_A)

# Separate placeholder for generated A images: the training loop feeds
# pre-computed fakes as arrays, so no gradient reaches the generators here.
_fake_A = Input(shape=(image_size, image_size, input_nc))
_netD_A_predict_fake = netD_A(_fake_A)

# Only netD_A is trainable in this model; flags are set before compile().
for net, is_trainable in ((netG_A, False), (netG_B, False), (netD_A, True), (netD_B, False)):
    for l in net.layers:
        l.trainable = is_trainable

# The model's output is the discriminator loss itself; it is trained with
# 'mae' against all-zero targets.
netD_A_train_function = Model([real_A, _fake_A], Lambda(netD_loss)([netD_A_predict_real, _netD_A_predict_fake]))
# Use the same Adam settings as written in the generator cell instead of the
# defaults selected by the string 'adam'.
netD_A_train_function.compile(
    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999, epsilon=None, decay=0.0), 'mae')

In [None]:
# make discriminator B train function

In [None]:
# Discriminator B's prediction on real domain-B images.
netD_B_predict_real = netD_B(real_B)

# Separate placeholder for generated B images: the training loop feeds
# pre-computed fakes as arrays, so no gradient reaches the generators here.
_fake_B = Input(shape=(image_size, image_size, input_nc))
_netD_B_predict_fake = netD_B(_fake_B)

# Only netD_B is trainable in this model; flags are set before compile().
for net, is_trainable in ((netG_A, False), (netG_B, False), (netD_B, True), (netD_A, False)):
    for l in net.layers:
        l.trainable = is_trainable

# The model's output is the discriminator loss itself; it is trained with
# 'mae' against all-zero targets.
netD_B_train_function = Model([real_B, _fake_B], Lambda(netD_loss)([netD_B_predict_real, _netD_B_predict_fake]))
# Use the same Adam settings as written in the generator cell instead of the
# defaults selected by the string 'adam'.
netD_B_train_function.compile(
    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999, epsilon=None, decay=0.0), 'mae')

In [None]:
import numpy as np
import glob
import time
from PIL import Image
from random import randint, shuffle

def load_data(file_pattern):
    """Return the list of paths matching the glob pattern ``file_pattern``."""
    matches = glob.glob(file_pattern)
    return matches

def read_image(img, loadsize=286, imagesize=256):
    """Load an image file and return a randomly cropped/flipped float array.

    The image is converted to RGB, resized to ``loadsize`` x ``loadsize``,
    scaled from [0, 255] to [-1, 1], cropped to ``imagesize`` x ``imagesize``
    at a random position, and mirrored horizontally with probability 1/2.

    Args:
        img: path (or file object) accepted by ``Image.open``.
        loadsize: side length the image is resized to before cropping.
        imagesize: side length of the returned crop.

    Returns:
        ``float32`` array of shape ``(imagesize, imagesize, 3)`` when
        ``loadsize >= imagesize``.
    """
    # The context manager closes the underlying file as soon as the pixels
    # have been decoded, instead of leaving it to garbage collection.
    with Image.open(img) as source:
        rgb = source.convert('RGB').resize((loadsize, loadsize), Image.BICUBIC)
    pixels = np.array(rgb).astype(np.float32)
    pixels = (pixels - 127.5) / 127.5

    # Random jitter: rows and columns get independent offsets, and randint is
    # inclusive, so the largest valid offset is loadsize - imagesize.
    max_offset = max(0, loadsize - imagesize)
    h_offset = randint(0, max_offset)
    w_offset = randint(0, max_offset)
    pixels = pixels[h_offset:h_offset + imagesize,
                    w_offset:w_offset + imagesize, :]

    # Random horizontal flip.
    if randint(0, 1):
        pixels = pixels[:, ::-1]
    return pixels

def try_read_img(data, index):
    """Read ``data[index]``, falling back to the following entries on failure.

    Entries are tried in order starting at ``index`` and wrapping around the
    end of ``data``, so each path is attempted at most once.

    Args:
        data: sequence of image paths.
        index: position of the preferred image.

    Returns:
        The array returned by ``read_image`` for the first readable entry.

    Raises:
        ValueError: if ``data`` is empty.
        RuntimeError: if no entry of ``data`` could be read.
    """
    count = len(data)
    if count == 0:
        raise ValueError('try_read_img: no image paths were given')

    last_error = None
    for step in range(count):
        path = data[(index + step) % count]
        try:
            return read_image(path)
        except Exception as err:
            # Deliberately broad: any failure to decode one file (including
            # warnings promoted to errors) just moves on to the next file.
            last_error = err
    raise RuntimeError(
        'try_read_img: none of the {} images could be read'.format(count)
    ) from last_error

# Lists of training image paths for the two domains.
train_A = load_data('/home/lin/Downloads/m-cycle/trainA/*')
train_B = load_data('/home/lin/Downloads/m-cycle/trainB/*')
# Print how many files were found in each training set.
print(len(train_A))
print(len(train_B))

# Lists of validation image paths for the two domains.
val_A = load_data('/home/lin/Downloads/m-cycle/testA/*')
val_B = load_data('/home/lin/Downloads/m-cycle/testB/*')

In [None]:
def minibatch(data, batch_size):
    """Endless generator yielding ``(epoch, batch)`` from a list of image paths.

    ``batch`` is a ``float32`` array built from ``try_read_img`` results.
    A value passed in through ``send()`` is used as the size of the next
    batch; otherwise ``batch_size`` is used. When the next batch would run
    past the end of ``data``, the list is shuffled IN PLACE, reading restarts
    at the beginning and the epoch counter is incremented.
    """
    total = len(data)
    epoch = 0
    cursor = 0
    requested = None

    while True:
        size = requested or batch_size
        if cursor + size > total:
            # Not enough items left for a full batch: start a new epoch.
            shuffle(data)
            cursor = 0
            epoch += 1
        batch = [try_read_img(data, k) for k in range(cursor, cursor + size)]
        cursor += size
        requested = yield epoch, np.float32(batch)

def minibatchAB(dataA, dataB, batch_size):
    """Endless generator yielding ``(epoch, A, B)`` batches from both domains.

    Drives one ``minibatch`` generator per domain. A value sent into this
    generator is forwarded to both as the next batch size. The reported epoch
    is the larger of the two underlying epoch counters.
    """
    gen_a = minibatch(dataA, batch_size)
    gen_b = minibatch(dataB, batch_size)
    requested = None
    while True:
        # send(None) on a fresh generator is equivalent to next().
        epoch_a, imgs_a = gen_a.send(requested)
        epoch_b, imgs_b = gen_b.send(requested)
        requested = yield max(epoch_a, epoch_b), imgs_a, imgs_b

In [None]:
from IPython.display import display
def display_image(X, rows=1):
    """Tile a batch of images into a grid, save it as JPEG and display it.

    Args:
        X: array of images with values in [-1, 1]; it is rescaled to
            [0, 255], clipped and reshaped to
            ``(-1, image_size, image_size, 3)``.
        rows: number of grid rows; must divide the number of images.

    The grid is written to ``dpath + 'results/'`` under a timestamp-based
    name. The directory is created if needed, and a numeric suffix is added
    when a file with the same timestamp already exists, so consecutive calls
    within the same second no longer overwrite each other.
    """
    import os

    assert X.shape[0]%rows == 0
    int_X = ((X*127.5+127.5).clip(0,255).astype('uint8'))
    int_X = int_X.reshape(-1,image_size,image_size, 3)
    # (rows, cols, H, W, 3) -> (rows, H, cols, W, 3) -> (rows*H, cols*W, 3)
    int_X = int_X.reshape(rows, -1, image_size, image_size,3).swapaxes(1,2).reshape(rows*image_size,-1, 3)
    pil_X = Image.fromarray(int_X)

    out_dir = dpath+'results/'
    os.makedirs(out_dir, exist_ok=True)
    t = str(round(time.time()))
    out_path = out_dir + t + '.jpg'
    suffix = 0
    while os.path.exists(out_path):
        suffix += 1
        out_path = '{}{}_{}.jpg'.format(out_dir, t, suffix)
    pil_X.save(out_path, 'JPEG')
    display(pil_X)

In [None]:
# Sanity check of the input pipeline: draw two batches of 6 training images
# per domain and display them.
train_batch = minibatchAB(train_A, train_B, 6)

_, A, B = next(train_batch)
display_image(A)
display_image(B)
_, A, B = next(train_batch)
display_image(A)
display_image(B)
# Drop the preview generator and arrays.
del train_batch, A, B

In [None]:
# Same sanity check for the validation sets: one batch of 6 images per domain.
val_batch = minibatchAB(val_A, val_B, 6)

_, A, B = next(val_batch)
display_image(A)
display_image(B)
# Drop the preview generator and arrays.
del val_batch, A, B

In [None]:
def get_output(netG_alpha, netG_beta, X):
    """Run one full cycle on ``X`` and return ``[translated, reconstructed]``.

    ``translated`` is ``netG_alpha.predict(X)`` and ``reconstructed`` is
    ``netG_beta.predict(translated)``.
    """
    translated = netG_alpha.predict(X)
    reconstructed = netG_beta.predict(translated)
    return [translated, reconstructed]

In [None]:
def get_combined_output(netG_alpha, netG_beta, X):
    """Run ``get_output`` on every sample of ``X`` separately and stack the results.

    Each sample is passed as a slice ``X[i:i+1]`` (so it keeps its leading
    axis). In the returned array, index 0 along the first axis holds the
    translated outputs and index 1 the reconstructed ones; the second axis
    indexes the samples.
    """
    per_sample = []
    for idx in range(X.shape[0]):
        per_sample.append(get_output(netG_alpha, netG_beta, X[idx:idx + 1]))
    stacked = np.array(per_sample)
    # (sample, kind, 1, ...) -> (kind, sample, 1, ...) -> drop the size-1 axis
    # -- assumes predict() returns arrays with a leading axis of length 1.
    return stacked.swapaxes(0, 1)[:, :, 0]

In [None]:
def show_generator_image(A,B, netG_alpha,  netG_beta):
    """Display inputs, translations and reconstructions for both domains.

    ``A`` is cycled through ``netG_alpha`` then ``netG_beta``; ``B`` through
    ``netG_beta`` then ``netG_alpha``. The images are shown in a 3-row grid
    in the order: A, B, translated A, translated B, reconstructed A,
    reconstructed B.
    """
    assert A.shape==B.shape

    translated_A, reconstructed_A = get_combined_output(netG_alpha, netG_beta, A)
    translated_B, reconstructed_B = get_combined_output(netG_beta, netG_alpha, B)

    grid = np.concatenate([A, B,
                           translated_A, translated_B,
                           reconstructed_A, reconstructed_B])
    display_image(grid, 3)

In [None]:
def get_generater_function(netG):
    """Return a backend function mapping ``netG``'s first input to its first output.

    The returned callable takes a one-element list with the input array and
    returns a one-element list with the generated array.
    """
    source = netG.inputs[0]
    generated = netG.outputs[0]
    return K.function([source], [generated])

# Backend functions used by the training loop to generate the fake images
# that are fed to the discriminator training models.
netG_A_function = get_generater_function(netG_A)
netG_B_function = get_generater_function(netG_B)

In [None]:
import warnings
# Turn PIL's DecompressionBombWarning into an exception, so loading such an
# image fails instead of only emitting a warning.
warnings.simplefilter('error', Image.DecompressionBombWarning)

In [None]:
import time
from IPython.display import clear_output
time_start = time.time()
how_many_epochs = 2
# The counter starts at a non-zero value; it is also used in the checkpoint
# file names.
iteration_count = 410000
# Remember the starting value so the timing statistic below only averages
# over iterations actually run in this session.
start_iteration = iteration_count
epoch_count = 0
batch_size = 1
display_freq = 1000
val_batch = minibatchAB(val_A, val_B, batch_size=4)
train_batch = minibatchAB(train_A, train_B, batch_size)

# The training models output their loss directly, so they are fitted with
# 'mae' against an all-zero target. The target never changes, so build it once.
target_label = np.zeros((batch_size, 1))

while epoch_count < how_many_epochs:
    epoch_count, A, B = next(train_batch)

    # Generate the fakes for the discriminator updates.
    try:
        _fake_B = netG_A_function([A])[0]
        _fake_A = netG_B_function([B])[0]
    except Exception:
        # Best effort: if this batch cannot be processed, retry once with
        # the next batch (a second failure propagates).
        epoch_count, A, B = next(train_batch)
        _fake_B = netG_A_function([A])[0]
        _fake_A = netG_B_function([B])[0]

    netG_train_function.train_on_batch([A, B], target_label)

    netD_B_train_function.train_on_batch([B, _fake_B], target_label)
    netD_A_train_function.train_on_batch([A, _fake_A], target_label)

    iteration_count += 1

    if iteration_count % display_freq == 0:
        clear_output()
        # Average over the iterations of this run, not the absolute counter.
        traintime = (time.time() - time_start) / (iteration_count - start_iteration)
        print('epoch_count: {}  iter_count: {}  timecost/iter: {}s'.format(epoch_count, iteration_count, traintime))
        # Use separate names for the validation images: rebinding val_A/val_B
        # here would replace the path lists that later cells pass to
        # minibatchAB.
        _, val_imgs_A, val_imgs_B = next(val_batch)
        show_generator_image(val_imgs_A, val_imgs_B, netG_A, netG_B)
        save_name = dpath + '{}' + str(iteration_count) + '.h5'

        # Checkpoint every network and every training model.
        netG_A.save_weights(save_name.format('tf_GA_weights'))
        netG_B.save_weights(save_name.format('tf_GB_weights'))
        netD_A.save_weights(save_name.format('tf_DA_weights'))
        netD_B.save_weights(save_name.format('tf_DB_weights'))
        netG_train_function.save_weights(save_name.format('tf_G_train_weights'))
        netD_A_train_function.save_weights(save_name.format('tf_D_A_train_weights'))
        netD_B_train_function.save_weights(save_name.format('tf_D_B_train_weights'))

In [None]:
# inference

In [None]:
# Show the current learning-phase value of the backend (cell output).
K.learning_phase()

In [None]:
# Load the generator weights from the checkpoint files whose names end in
# '200000.h5' under dpath.
save_name = dpath + '{}' + '200000.h5'
netG_A.load_weights(save_name.format('tf_GA_weights'))
netG_B.load_weights(save_name.format('tf_GB_weights'))

In [None]:
# Validation batch generator for the inference cells below (2 images per
# domain). minibatchAB expects val_A / val_B to be lists of image paths.
val_batch = minibatchAB(val_A, val_B, batch_size=2)

In [None]:
# run batch normalization layer in inference mode

In [None]:
# Draw one validation batch, load the generator weights from the checkpoint
# files whose names end in '400000.h5', and display the translations and
# reconstructions computed through Model.predict.
_,A, B = next(val_batch)
save_name = dpath + '{}' + '400000.h5'
netG_A.load_weights(save_name.format('tf_GA_weights'))
netG_B.load_weights(save_name.format('tf_GB_weights'))
show_generator_image(A,B, netG_A, netG_B)

In [None]:
# run batch normalization layer in training mode

In [None]:
def get_cycle_generater (netG_alpha, netG_beta):
    """Return a backend function computing one full translation cycle.

    The function takes ``[input_batch, learning_phase]`` and returns
    ``[translated, reconstructed]``, where ``translated`` is the output of
    ``netG_alpha`` and ``reconstructed`` is ``netG_beta`` applied to it.
    """
    source = netG_alpha.inputs[0]
    translated = netG_alpha.outputs[0]
    reconstructed = netG_beta([translated])
    return K.function([source, K.learning_phase()], [translated, reconstructed])

# Cycle A: A -> netG_A -> fake B -> netG_B -> reconstructed A.
cycleA_generater = get_cycle_generater(netG_A, netG_B)
# Cycle B: B -> netG_B -> fake A -> netG_A -> reconstructed B.
# The reconstruction must go through netG_A; passing netG_B twice would apply
# the B->A generator to an image that is already in domain A.
cycleB_generater = get_cycle_generater(netG_B, netG_A)

In [None]:
def show_netG(A,B):
    """Display inputs, translations and reconstructions using the cycle functions.

    Each sample is passed separately through ``cycleA_generater`` /
    ``cycleB_generater`` with the learning-phase argument set to 1. The
    images are shown in a 3-row grid in the order: A, B, translated A,
    translated B, reconstructed A, reconstructed B.
    """
    assert A.shape==B.shape

    def run_cycle(generater, X):
        # One call per sample; each call returns [translated, reconstructed].
        outputs = [generater([X[k:k+1], 1]) for k in range(X.shape[0])]
        return np.array(outputs).swapaxes(0,1)[:,:,0]

    rA = run_cycle(cycleA_generater, A)
    rB = run_cycle(cycleB_generater, B)
    grid = np.concatenate([A, B, rA[0], rB[0], rA[1], rB[1]])
    display_image(grid, 3)

In [None]:
# Draw another validation batch and display the results computed through the
# cycle backend functions (called with learning phase 1 inside show_netG).
_,A, B = next(val_batch)
show_netG(A,B)