In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from glob import glob
import os
import time

In [2]:
# Training image paths: domain X is read from trainB, domain Y from trainA.
# glob returns files in arbitrary (OS-dependent) order, so sort to make the
# selected subset reproducible between runs.
path_x = sorted(glob('./monet2photo/trainB/*.jpg'))
path_y = sorted(glob('./monet2photo/trainA/*.jpg'))
# The two lists are later zipped with tf.data.Dataset.from_tensor_slices, which
# requires equal lengths. Trim both to the smaller domain instead of relying on
# a hard-coded file count.
num_pairs = min(len(path_x), len(path_y))
path_x = path_x[:num_pairs]
path_y = path_y[:num_pairs]

# Fixed preview image rendered after every epoch. decode_jpeg returns raw uint8
# pixels (0..255); it is NOT resized or normalized here.
test_input = tf.io.read_file('./monet2photo/testB/2014-08-03 17_39_45.jpg')
test_input = tf.image.decode_jpeg(test_input, channels=3)

BATCH_SIZE = 1
epochs = 200
# Weights are initialized from a Gaussian distribution N(0, 0.02).
initializer = tf.random_normal_initializer(0., 0.02)
# We train our networks from scratch, with an initial learning rate of 0.0002
# that follows a cosine decay towards 0 (alpha=0) over 250000 optimizer steps.
learning_rate = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=2e-4,
                                                          decay_steps=250000,
                                                          alpha=0)
beta_1 = 0.5  # Adam first-moment decay rate
alpha = 0.2   # We use leaky ReLUs with a slope of 0.2
LAMBDA = 10   # We set λ = 10 (weight of the cycle-consistency loss)

In [3]:
def _load_image(path):
    """Read one JPEG file and return it as a 256x256x3 float32 tensor in [-1, 1]."""
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [256, 256])
    # Map pixel values from [0, 255] to [-1, 1].
    img = tf.cast(img, dtype=tf.float32)
    return (img / 127.5) - 1


def load_data(path_x, path_y):
    """Load and augment one (x, y) training pair from two file paths.

    Args:
        path_x: scalar string tensor, path of a JPEG from domain X.
        path_y: scalar string tensor, path of a JPEG from domain Y.

    Returns:
        Tuple (img_x, img_y) of 256x256x3 float32 tensors scaled to [-1, 1],
        each mirrored left-right independently with probability 0.5.
    """
    # tf.image.random_flip_left_right already flips with a 1-in-2 chance, so no
    # extra random gate is needed (gating it would lower the flip rate to 0.25).
    img_x = tf.image.random_flip_left_right(_load_image(path_x))
    img_y = tf.image.random_flip_left_right(_load_image(path_y))
    return img_x, img_y

def generate_save_img(model, epoch, test_input):
    """Run `model` on one image and save/show input and prediction side by side.

    Args:
        model: generator to evaluate; called with a batch of one image and
            training=False. Its output is treated as being in [-1, 1].
        epoch: integer used in the output file name.
        test_input: a single H x W x 3 image. Integer dtypes are treated as raw
            pixels in [0, 255] and are resized to 256x256 and scaled to [-1, 1]
            (the same preprocessing load_data applies to training images).
            Float inputs are assumed to be already preprocessed to [-1, 1].

    Side effects:
        Writes ./CYCLEGAN/image/image_at_epoch_XXXX.png (creating the directory
        if needed) and displays the figure.
    """
    image = tf.convert_to_tensor(test_input)
    if image.dtype.is_integer:
        # The generator is trained on [-1, 1] inputs; feeding raw 0..255 pixels
        # would put the preview far outside the training distribution.
        resized = tf.image.resize(image, [256, 256])  # float32, still 0..255
        shown_input = tf.cast(resized, dtype=tf.uint8)
        model_input = (resized / 127.5) - 1
    else:
        model_input = tf.cast(image, dtype=tf.float32)
        shown_input = tf.cast((model_input + 1) * 127.5, dtype=tf.uint8)

    batch = tf.expand_dims(model_input, axis=0)
    preds = model(batch, training=False)
    # Map the prediction from [-1, 1] back to displayable 0..255 pixels.
    preds = (preds + 1) * 127.5
    preds = tf.cast(preds, dtype=tf.uint8)

    # savefig does not create missing directories.
    os.makedirs('./CYCLEGAN/image', exist_ok=True)

    plt.figure(figsize=(7,7))
    display_list = [shown_input, preds[0]]
    title = ['Input Image', 'Predicted image']
    for i in range(2):
        plt.subplot(1,2,i+1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')
    plt.savefig('./CYCLEGAN/image/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    plt.close()

In [4]:
# Input pipeline: (path_x, path_y) pairs -> shuffled -> decoded/augmented -> batched.
# Shuffling happens on the path strings, BEFORE map(load_data), so the shuffle
# buffer holds short strings instead of decoded 256x256x3 float32 images.
# shuffle() reshuffles on every pass over the dataset by default.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((path_x, path_y))
    .shuffle(10000)
    .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [5]:
#Reflection padding was used to reduce artifacts
#c7s1-k denotes a 7×7 Convolution-InstanceNorm-ReLU layer with k filters and stride 1
def c7s1_k(inputs, filters, kernel_size, strides, padding, kernel_initializer):
    """Reflection-pad by 3, then apply Conv2D -> InstanceNormalization -> ReLU.

    The 3-pixel reflection border is added on axes 1 and 2 of `inputs`
    regardless of `kernel_size`; the remaining arguments are forwarded to
    tf.keras.layers.Conv2D unchanged.

    Returns:
        The activated feature map.
    """
    border = [[0, 0], [3, 3], [3, 3], [0, 0]]
    x = tf.pad(inputs, border, mode='REFLECT')

    conv_layer = tf.keras.layers.Conv2D(filters,
                                        kernel_size,
                                        strides=strides,
                                        padding=padding,
                                        kernel_initializer=kernel_initializer)
    x = conv_layer(x)
    x = tfa.layers.InstanceNormalization()(x)
    return tf.keras.layers.ReLU()(x)

#dk denotes a 3 × 3 Convolution-InstanceNorm-ReLU layer with k filters and stride 2
def dk(inputs, filters, kernel_size, strides, padding, kernel_initializer):
    """Reflection-pad by 1, then apply Conv2D -> InstanceNormalization -> ReLU.

    A 1-pixel reflection border is added on axes 1 and 2 of `inputs`; the
    remaining arguments are forwarded to tf.keras.layers.Conv2D unchanged.

    Returns:
        The activated feature map.
    """
    border = [[0, 0], [1, 1], [1, 1], [0, 0]]
    x = tf.pad(inputs, border, mode='REFLECT')

    conv_layer = tf.keras.layers.Conv2D(filters,
                                        kernel_size,
                                        strides=strides,
                                        padding=padding,
                                        kernel_initializer=kernel_initializer)
    x = conv_layer(x)
    x = tfa.layers.InstanceNormalization()(x)
    return tf.keras.layers.ReLU()(x)

#Rk denotes a residual block that contains two 3 × 3 convolutional layers with the same number of filters on both layers
def Rk(inputs, filters, kernel_size, strides, padding, kernel_initializer):
    """Residual block: two reflection-padded convolutions plus a skip connection.

    Each convolution is preceded by a 1-pixel reflection pad on axes 1 and 2
    and uses the caller's `padding` mode, so with a 3x3 kernel, stride 1 and
    padding='valid' the spatial size is preserved and the result can be added
    to `inputs`. `filters` must equal the channel count of `inputs` for the
    Add to be valid.

    Args:
        inputs: input feature map.
        filters, kernel_size, strides, padding, kernel_initializer: forwarded
            to both tf.keras.layers.Conv2D layers.

    Returns:
        inputs + conv2(pad(relu(conv1(pad(inputs))))).
    """
    border = [[0, 0], [1, 1], [1, 1], [0, 0]]

    # First conv (with ReLU). Uses `padding` rather than a hard-coded 'same':
    # the reflection pad already supplies the border, and zero padding on top of
    # it would grow the feature map by 2 pixels.
    rpad1 = tf.pad(inputs, border, 'REFLECT')
    Rk1 = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                                 padding=padding,
                                 kernel_initializer=kernel_initializer,
                                 activation="relu")(rpad1)

    # Second conv (linear) consumes the reflection-padded activation.
    rpad2 = tf.pad(Rk1, border, 'REFLECT')
    Rk2 = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                                 padding=padding,
                                 kernel_initializer=kernel_initializer)(rpad2)

    # NOTE(review): unlike the other building blocks in this file, no
    # InstanceNormalization is applied after the convolutions here. Confirm
    # this is intentional.
    return tf.keras.layers.Add()([inputs, Rk2])

#uk denotes a 3 × 3 fractional-strided-Convolution-InstanceNorm-ReLU layer with k filters and stride 1/2.
def uk(inputs, filters, kernel_size, strides, padding, kernel_initializer):
    """Upsampling block: Conv2DTranspose -> InstanceNormalization -> ReLU.

    All arguments other than `inputs` are forwarded to
    tf.keras.layers.Conv2DTranspose unchanged.

    Returns:
        The activated feature map.
    """
    deconv_layer = tf.keras.layers.Conv2DTranspose(filters,
                                                   kernel_size,
                                                   strides=strides,
                                                   padding=padding,
                                                   kernel_initializer=kernel_initializer)
    x = deconv_layer(inputs)
    x = tfa.layers.InstanceNormalization()(x)
    return tf.keras.layers.ReLU()(x)

#Ck denotes a 4 × 4 Convolution-InstanceNorm-LeakyReLU layer with k filters and stride 2
def ck(inputs, filters, kernel_size, strides, padding, kernel_initializer, alpha):
    """Discriminator block: Conv2D -> InstanceNormalization -> LeakyReLU.

    Args:
        inputs: input feature map.
        filters, kernel_size, strides, padding, kernel_initializer: forwarded
            to tf.keras.layers.Conv2D.
        alpha: negative-side slope passed to tf.keras.layers.LeakyReLU.

    Returns:
        The activated feature map.
    """
    conv_layer = tf.keras.layers.Conv2D(filters,
                                        kernel_size,
                                        strides=strides,
                                        padding=padding,
                                        kernel_initializer=kernel_initializer)
    x = conv_layer(inputs)
    x = tfa.layers.InstanceNormalization()(x)
    return tf.keras.layers.LeakyReLU(alpha=alpha)(x)

In [6]:
def generator():
    """Build the 9-residual-block generator as a Keras functional model.

    Layout: c7s1-64, d128, d256, 9 x R256, u128, u64, then a 7x7 tanh
    convolution with 3 output channels. Input shape is fixed to 256x256x3 and
    every convolution uses the module-level `initializer`.

    Returns:
        tf.keras.Model mapping a 256x256x3 image to the 3-channel tanh output.
    """
    inputs = tf.keras.layers.Input(shape=[256,256,3])

    # Stem: 7x7 convolution, stride 1.
    x = c7s1_k(inputs, filters=64, kernel_size=7, strides=1, padding='valid',
               kernel_initializer=initializer)

    # Two stride-2 downsampling blocks.
    for n_filters in (128, 256):
        x = dk(x, filters=n_filters, kernel_size=3, strides=2, padding='valid',
               kernel_initializer=initializer)

    # Nine residual blocks at 256 filters.
    for _ in range(9):
        x = Rk(x, filters=256, kernel_size=3, strides=1, padding='valid',
               kernel_initializer=initializer)

    # Two stride-2 transposed-convolution upsampling blocks.
    for n_filters in (128, 64):
        x = uk(x, filters=n_filters, kernel_size=3, strides=2, padding='same',
               kernel_initializer=initializer)

    # Output projection to 3 channels in (-1, 1).
    to_image = tf.keras.layers.Conv2D(filters=3, kernel_size=7, strides=1,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      activation='tanh')
    return tf.keras.Model(inputs=inputs, outputs=to_image(x))

In [7]:
def discriminator():
    """Build the patch discriminator as a Keras functional model.

    Layout: C64 (no InstanceNorm) -> C128 -> C256 -> C512, each 4x4 with
    stride 2 and LeakyReLU(alpha), with Dropout(0.3) after C128, C256 and
    C512, followed by a 4x4 stride-2 sigmoid convolution with one filter.
    With the fixed 256x256x3 input and 'same' padding, the five stride-2
    convolutions yield an 8x8x1 map of per-patch probabilities.

    Returns:
        tf.keras.Model mapping a 256x256x3 image to the probability map.
    """
    inputs = tf.keras.layers.Input(shape=[256,256,3])

    # We do not use InstanceNorm for the first C64 layer.
    c64 = tf.keras.layers.Conv2D(filters=64, kernel_size=4, strides=2,
                                 padding='same',
                                 kernel_initializer=initializer)(inputs)
    x = tf.keras.layers.LeakyReLU(alpha=alpha)(c64)

    for n_filters in (128, 256, 512):
        x = ck(x, filters=n_filters, kernel_size=4, strides=2, padding='same',
               kernel_initializer=initializer, alpha=alpha)
        x = tf.keras.layers.Dropout(0.3)(x)

    # After the last layer, we apply a convolution to produce a 1-dimensional
    # output. It uses the same N(0, 0.02) initializer as every other
    # convolution instead of silently falling back to the Keras default.
    outputs = tf.keras.layers.Conv2D(filters=1, kernel_size=4, strides=2,
                                     padding='same',
                                     kernel_initializer=initializer,
                                     activation='sigmoid')(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [8]:
# Two generators and two discriminators. As used in train_step:
# G translates X -> Y, F translates Y -> X, Dx judges X images, Dy judges Y images.
G, F = generator(), generator()
Dx, Dy = discriminator(), discriminator()

In [9]:
# Shared loss objects.
# l1loss: mean absolute error, used for the cycle-consistency and identity terms.
l1loss = tf.keras.losses.MeanAbsoluteError()
# beloss: binary cross-entropy on probabilities (the discriminators end in a
# sigmoid, hence from_logits=False).
beloss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [10]:
def d_loss(real_img, fake_img):
    """Discriminator loss: BCE(real -> 1) + BCE(fake -> 0).

    Despite the parameter names, train_step passes discriminator OUTPUTS here:
    `real_img` is the discriminator's response to real images and `fake_img`
    its response to generated images.
    """
    loss_on_real = beloss(tf.ones_like(real_img), real_img)
    loss_on_fake = beloss(tf.zeros_like(fake_img), fake_img)
    return loss_on_real + loss_on_fake

In [11]:
def g_loss(fake_img):
    """Adversarial generator loss: BCE between an all-ones target and `fake_img`.

    `fake_img` is the discriminator's output for generated images (see
    train_step); the generator is rewarded when that output is close to 1.
    """
    target_real = tf.ones_like(fake_img)
    return beloss(target_real, fake_img)

In [12]:
def cycle_loss(real_img, cycle_img):
    """Cycle-consistency loss: LAMBDA times the mean absolute error between
    an image and its round-trip reconstruction."""
    return LAMBDA * l1loss(real_img, cycle_img)

In [13]:
def identity_loss(real_img, same_img):
    """Identity-mapping loss: 0.5 * LAMBDA times the mean absolute error between
    an image and the output of the generator that should leave it unchanged."""
    weight = LAMBDA * 0.5  # identity mapping loss was 0.5λ
    return weight * l1loss(real_img, same_img)

In [14]:
# One independent Adam optimizer per network (G, F, Dx, Dy), all configured with
# the same learning-rate schedule and beta_1.
g_optimizer, f_optimizer, dx_optimizer, dy_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=beta_1)
    for _ in range(4)
)

In [15]:
# Checkpointing: track all four networks together with their optimizers so that
# both weights and optimizer state are saved.
checkpoint_dir = './CYCLEGAN/checkpoint'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
_tracked_objects = {
    'G': G,
    'F': F,
    'Dx': Dx,
    'Dy': Dy,
    'g_optimizer': g_optimizer,
    'f_optimizer': f_optimizer,
    'dx_optimizer': dx_optimizer,
    'dy_optimizer': dy_optimizer,
}
checkpoint = tf.train.Checkpoint(**_tracked_objects)

In [16]:
@tf.function
def train_step(real_x, real_y):
    """Run one optimization step for G, F, Dx and Dy on a batch pair.

    Args:
        real_x: batch of real images from domain X.
        real_y: batch of real images from domain Y.

    All four losses are recorded on one persistent tape; every gradient is
    computed before any optimizer update is applied. Returns nothing.
    """
    with tf.GradientTape(persistent=True) as tape:
        # X -> Y -> X round trip.
        fake_y = G(real_x, training=True)
        cycle_x = F(fake_y, training=True)

        # Y -> X -> Y round trip.
        fake_x = F(real_y, training=True)
        cycle_y = G(fake_x, training=True)

        # Identity mappings: each generator applied to its own target domain.
        same_x = F(real_x, training=True)
        same_y = G(real_y, training=True)

        # Discriminator responses to real and generated images.
        d_real_x = Dx(real_x, training=True)
        d_real_y = Dy(real_y, training=True)
        d_fake_x = Dx(fake_x, training=True)
        d_fake_y = Dy(fake_y, training=True)

        # Generator objectives: adversarial + shared cycle term + identity term.
        adv_g = g_loss(d_fake_y)
        adv_f = g_loss(d_fake_x)
        cycle_total = cycle_loss(real_x, cycle_x) + cycle_loss(real_y, cycle_y)
        total_g_loss = adv_g + cycle_total + identity_loss(real_y, same_y)
        total_f_loss = adv_f + cycle_total + identity_loss(real_x, same_x)

        # Discriminator objectives.
        dx_loss = d_loss(d_real_x, d_fake_x)
        dy_loss = d_loss(d_real_y, d_fake_y)

    # (loss, network, optimizer) triples, in update order.
    updates = (
        (total_g_loss, G, g_optimizer),
        (total_f_loss, F, f_optimizer),
        (dx_loss, Dx, dx_optimizer),
        (dy_loss, Dy, dy_optimizer),
    )

    # Phase 1: compute every gradient from the same forward pass.
    all_gradients = [tape.gradient(loss, net.trainable_variables)
                     for loss, net, _ in updates]

    # Phase 2: apply the updates.
    for gradients, (_, net, optimizer) in zip(all_gradients, updates):
        optimizer.apply_gradients(zip(gradients, net.trainable_variables))

In [17]:
def train(dataset, epochs):
    """Train for `epochs` full passes over `dataset`.

    After every epoch a preview of G on the module-level `test_input` is saved
    and shown, and the epoch duration is printed; a checkpoint is written
    after every 5th epoch.

    Args:
        dataset: iterable yielding (real_x, real_y) batches.
        epochs: number of passes over `dataset`.
    """
    for epoch_number in range(1, epochs + 1):
        start = time.time()

        for batch_x, batch_y in dataset:
            train_step(batch_x, batch_y)

        generate_save_img(G, epoch_number, test_input)

        if epoch_number % 5 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        elapsed = time.time()-start
        print('Time for epoch {} is {} sec'.format(epoch_number, elapsed))

In [None]:
%%time
# Run the full training loop; the %%time cell magic above reports how long the whole cell took.
train(train_dataset, epochs)