In [None]:
"""
For more detail:
https://seachaos.com/transfer-learning-with-gan-cyclegan-from-scratch-1afc9ab7c7d1
"""

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm

import random

In [None]:
# Open the unpaired 'cycle_gan/horse2zebra' collection from TensorFlow Datasets.
# as_supervised=True makes every element an (image, label) tuple, which is the
# shape _process_img expects.
dataset, dataset_info = tfds.load(
    'cycle_gan/horse2zebra',
    with_info=True,
    as_supervised=True,
)

# Domain A / domain B splits (presumably A = horses, B = zebras -- confirm
# against dataset_info), for training and for previews.
train_a = dataset['trainA']
train_b = dataset['trainB']
test_a = dataset['testA']
test_b = dataset['testB']

In [None]:
# --- Hyper-parameters ---------------------------------------------------------
# Images per training step; set to 16 or less, if you don't have enough VRAM.
batch_size = 32

# img_size: side length of the images fed to the networks.
# big_img_size: side length the images are stored at, leaving room for random
# crops before they are shrunk to img_size.
img_size, big_img_size = 128, 192

# Initial learning rate used for both optimizers.
LR = 1.2e-4

In [None]:
def _process_img(image, label):
    """Resize one image to the storage resolution and rescale its values.

    The image is resized to (big_img_size, big_img_size) and then mapped with
    ``image / 127.5 - 1.0``, which sends pixel values in 0..255 to [-1, 1].
    The label is returned unchanged.
    """
    resized = tf.image.resize(image, (big_img_size, big_img_size))
    scaled = (resized / 127.5) - 1.0
    return scaled, label

def prepare_data(data, b=batch_size):
    """Build the input pipeline for one dataset split.

    Args:
        data: dataset of (image, label) pairs.
        b: used both as the shuffle buffer size and as the batch size.

    Returns:
        The dataset after cache -> _process_img map -> shuffle -> batch.
    """
    pipeline = data.cache()
    pipeline = pipeline.map(_process_img, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.shuffle(b)
    return pipeline.batch(b)

# Preprocessing pipelines, one per split.
ds_train_a = prepare_data(train_a)
ds_train_b = prepare_data(train_b)
ds_test_a = prepare_data(test_a)
ds_test_b = prepare_data(test_b)


# Materialise every split as one in-memory tensor (index 0 = domain A,
# index 1 = domain B) so that batches can later be drawn by random index.
# Only the image half of each (image, label) batch is kept.
x_train_sets = [
    tf.concat([batch[0] for batch in split], axis=0)
    for split in (ds_train_a, ds_train_b)
]

x_test_sets = [
    tf.concat([batch[0] for batch in split], axis=0)
    for split in (ds_test_a, ds_test_b)
]

# Sanity check: total image count over both domains, then min/max of domain A.
print(
    'x_train_all: ',
    sum(s.shape[0] for s in x_train_sets),
    x_train_sets[0].numpy().min(),
    x_train_sets[0].numpy().max(),
)
print(
    'x_test_all: ',
    sum(s.shape[0] for s in x_test_sets),
    x_test_sets[0].numpy().min(),
    x_test_sets[0].numpy().max(),
)


In [None]:
def _rand_pick(data, augment=True):
    """Draw a random batch of ``batch_size`` distinct images from ``data``.

    Args:
        data: tensor of images stacked along axis 0.
        augment: when true, the batch goes through one tf.image.random_crop
            call (window sides between 1.0x and 1.5x ``img_size``) followed by
            tf.image.random_flip_left_right.

    Returns:
        The batch resized to (img_size, img_size).
    """
    candidates = range(len(data))
    # replace=False: no image appears twice in the same batch.
    chosen = np.random.choice(candidates, size=batch_size, replace=False)
    batch = tf.gather(data, chosen, axis=0)

    if augment:
        # Height and width factors are drawn independently, so the crop window
        # need not be square; the final resize squashes it back to a square.
        scale_h = random.uniform(1.0, 1.5)
        scale_w = random.uniform(1.0, 1.5)
        crop_shape = (batch_size, int(img_size * scale_h), int(img_size * scale_w), 3)
        batch = tf.image.random_crop(batch, size=crop_shape)
        batch = tf.image.random_flip_left_right(batch)

    return tf.image.resize(batch, (img_size, img_size))

def get_x_train():
    """Return one augmented training batch per domain as ``(xa, xb)``."""
    batch_a = _rand_pick(x_train_sets[0])
    batch_b = _rand_pick(x_train_sets[1])
    return batch_a, batch_b

def get_x_test():
    """Return one un-augmented test batch per domain as ``(xa, xb)``."""
    batch_a = _rand_pick(x_test_sets[0], augment=False)
    batch_b = _rand_pick(x_test_sets[1], augment=False)
    return batch_a, batch_b

In [None]:
# Verify "get_x_train" output
def cvtImg(x):
    """Map values from the [-1, 1] range to [0, 1] via ``(x + 1) / 2``."""
    shifted = x + 1.0
    return shifted / 2.0

def show(x, S=12):
    """Plot up to ``S`` images from ``x`` side by side in a single row.

    ``x`` is passed through cvtImg before plotting. The row always has ``S``
    slots; when ``x`` holds fewer images the remaining slots stay empty.
    """
    images = cvtImg(x)
    shown = min(len(images), S)

    plt.figure(figsize=(15, 3))
    for slot in range(shown):
        plt.subplot(1, S, slot + 1)
        plt.imshow(images[slot])
        plt.axis('off')
    plt.show()

# Smoke test: draw one augmented training batch per domain and display it.
for _ in range(1):
    xa, xb = get_x_train()
    xa = xa.numpy()
    # Value range and batch shape of the domain-A batch after augmentation.
    print(xa.min(), xa.max(), xa.shape)
    show(xa)
    show(xb.numpy())

In [None]:
# VGG16 (default weights, classification head removed) used as a shared, frozen
# feature extractor for both the generator and the discriminator.
# NOTE(review): images reaching this model are scaled to [-1, 1] elsewhere in
# the notebook rather than passed through vgg16.preprocess_input -- confirm
# this mismatch is intentional.
base_model = tf.keras.applications.VGG16(
    input_shape=(img_size, img_size, 3),
    include_top=False,
)

x_input = base_model.input
x = x_input

# Names of the VGG16 layers whose activations are exposed, in network order.
outputs = [
    'block2_conv2',
    'block3_conv3',
    'block4_conv3',
    'block5_conv1',
    'block5_pool',
]

# Re-wrap VGG16 as a model that returns all of those activations at once.
x_output = [base_model.get_layer(layer_name).output for layer_name in outputs]
base_model = tf.keras.models.Model(x_input, x_output)

# Freeze the extractor; only the layers stacked on top of it are trained.
base_model.trainable = False

# base_model.summary() # if you want see more detail about VGG16

In [None]:
# Activation function name shared by the generator and the discriminator.
act_name = 'gelu'

def act(x):
    """Apply a fresh LayerNormalization followed by the ``act_name`` activation."""
    normalized = layers.LayerNormalization()(x)
    return layers.Activation(act_name)(normalized)

In [None]:
def conv_with_cmd(x_img_input, x_cmd, f=64, sp=4):
    """Convolve a feature map and gate its channels with a command vector.

    Args:
        x_img_input: feature map to convolve.
        x_cmd: command vector from which the per-channel gate is derived.
        f: number of output channels (and of gate values).
        sp: unused; kept so existing callers keep working.

    Returns:
        Conv2D(f) -> BatchNormalization -> ``act_name`` activation of
        ``x_img_input``, multiplied by a sigmoid gate reshaped to (1, 1, f).
    """
    # Gate branch: command vector -> one value in (0, 1) per output channel.
    gate = layers.Dense(128)(x_cmd)
    gate = layers.BatchNormalization()(gate)
    gate = layers.Activation(act_name)(gate)

    gate = layers.Dense(f)(gate)
    gate = layers.BatchNormalization()(gate)
    gate = layers.Activation('sigmoid')(gate)

    # (1, 1, f) so the gate broadcasts over both spatial dimensions.
    gate = layers.Reshape((1, 1, f))(gate)

    # Image branch.
    feat = layers.Conv2D(f, kernel_size=3, padding='same')(x_img_input)
    feat = layers.BatchNormalization()(feat)
    feat = layers.Activation(act_name)(feat)

    return feat * gate

In [None]:
def create_gen_model():
    """Build the generator: frozen VGG16 features plus a command-gated decoder.

    The deepest VGG16 output is condensed into a 128-value command vector.
    The decoder then starts from that same deepest output and repeatedly
    upsamples by 2, concatenates the next shallower feature map (finally the
    input image itself) and applies conv_with_cmd. A 3-channel tanh
    convolution produces the output image.
    """
    image_in = layers.Input(shape=(img_size, img_size, 3))

    # Feature maps of the shared extractor, shallowest first. The names follow
    # their presumed spatial size for img_size = 128.
    x64, x32, x16, x8, x4 = base_model(image_in)

    # Command vector summarising the deepest features.
    cmd = layers.Conv2D(256, kernel_size=3, padding='same')(x4)
    cmd = act(cmd)
    cmd = layers.GlobalMaxPool2D()(cmd)
    cmd = layers.Dense(128)(cmd)
    cmd = layers.BatchNormalization()(cmd)
    cmd = layers.Activation(act_name)(cmd)

    # Decoder
    feat = conv_with_cmd(x4, cmd, f=512)

    # (skip connection, filters) per stage.
    # If you don't have enough VRAM, try reducing the filter counts.
    decoder_plan = (
        (x8, 512),
        (x16, 384),
        (x32, 256),
        (x64, 256),
        (image_in, 256),
    )
    for skip, filters in decoder_plan:
        feat = layers.UpSampling2D(2)(feat)
        feat = layers.Concatenate()([feat, skip])
        feat = conv_with_cmd(feat, cmd, f=filters)

    # Final output; tanh keeps it in the same [-1, 1] range as the inputs.
    rgb = layers.Conv2D(3, kernel_size=3, padding='same')(feat)
    rgb = layers.BatchNormalization()(rgb)
    rgb = layers.Activation('tanh')(rgb)

    return tf.keras.models.Model(image_in, rgb)

gen = create_gen_model()
# gen.summary() # if you want see more detail about model

In [None]:
def create_dis_model():
    """Build the discriminator: frozen VGG16 features plus a 4-way classifier.

    The output is a softmax over four classes. The label arrays elsewhere in
    the notebook use 0 = fake A, 1 = fake B, 2 = real A, 3 = real B.
    """
    image_in = layers.Input(shape=(img_size, img_size, 3))

    # Only the two deepest feature maps of the shared extractor are used.
    _, _, _, x8, x4 = base_model(image_in)

    feat = layers.Conv2D(512, kernel_size=3, padding='same')(x8)
    feat = act(feat)
    # Pool so this branch can be concatenated with the deepest feature map.
    feat = layers.MaxPool2D()(feat)

    feat = layers.Concatenate()([feat, x4])
    feat = layers.Conv2D(512, kernel_size=3, padding='same')(feat)
    feat = act(feat)

    feat = layers.GlobalMaxPool2D()(feat)

    # Two hidden dense stages.
    for units in (384, 128):
        feat = layers.Dense(units)(feat)
        feat = layers.BatchNormalization()(feat)
        feat = layers.Activation(act_name)(feat)

    # Class probabilities.
    probs = layers.Dense(4)(feat)
    probs = layers.BatchNormalization()(probs)
    probs = layers.Activation('softmax')(probs)

    return tf.keras.models.Model(image_in, probs)

dis = create_dis_model()
# dis.summary() # if you want see more detail about model

In [None]:
# Constant per-batch class ids for the 4-way discriminator:
#   0 = fake A, 1 = fake B, 2 = real A, 3 = real B
# (float arrays of length batch_size).
y_false_a = np.full(batch_size, 0.0)
y_false_b = np.full(batch_size, 1.0)
y_true_a = np.full(batch_size, 2.0)
y_true_b = np.full(batch_size, 3.0)

In [None]:
# One AdamW optimizer per network, so the generator and the discriminator keep
# separate optimizer state. Both start from the same learning rate LR.
opt_gen = tf.keras.optimizers.AdamW(learning_rate=LR)
opt_dis = tf.keras.optimizers.AdamW(learning_rate=LR)

In [None]:
@tf.function
def _train_dis(x, y_t):
    """Run one optimizer step of the discriminator on a single batch.

    Args:
        x: batch of images fed to ``dis``.
        y_t: one class id per image (see the y_false_* / y_true_* arrays).

    Returns:
        The batch mean of the sparse categorical cross-entropy.
    """
    # Only one gradient is taken from this tape, so a non-persistent tape is
    # enough; it releases its resources as soon as gradient() returns.
    with tf.GradientTape() as tape:
        # NOTE(review): dis is called without training=True, so its
        # BatchNormalization layers presumably run in inference mode here --
        # confirm this is intended.
        y_p = dis(x)
        loss = tf.losses.sparse_categorical_crossentropy(y_t, y_p)
        loss = tf.reduce_mean(loss)

    variables = dis.trainable_variables
    grads = tape.gradient(loss, variables)
    opt_dis.apply_gradients(zip(grads, variables))

    # Inside tf.function, float() is presumably rewritten by AutoGraph into a
    # cast, so callers receive a scalar tensor -- TODO confirm.
    return float(loss)

def train_dis():
    """Run one discriminator training step on a fresh batch of both domains.

    The discriminator is trained on real A, generated A (from real B), real B
    and generated B (from real A), in that order.

    Returns:
        ``(loss_a, loss_b)`` as Python floats, each the sum of the real and
        the fake loss for that domain.
    """
    dis.trainable = True
    gen.trainable = False
    base_model.trainable = False

    xa, xb = get_x_train()

    # The generator is called directly instead of through gen.predict():
    # Keras documents predict() as meant for large inputs and recommends
    # model(x, training=False) for single batches inside a loop.

    # train dis A
    xa_fake = gen(xb, training=False)
    loss_a = \
        _train_dis(xa, y_true_a) + \
        _train_dis(xa_fake, y_false_a)

    # train dis B
    xb_fake = gen(xa, training=False)
    loss_b = \
        _train_dis(xb, y_true_b) + \
        _train_dis(xb_fake, y_false_b)

    return float(loss_a), float(loss_b)

train_dis()

In [None]:
@tf.function
def _train_gen_cycle(x_real, y_t, y_f):
    """Run one optimizer step of the generator on a single batch.

    The loss is the sum of two terms:
      * the discriminator's cross-entropy between ``y_t`` and its prediction
        for the translated images;
      * the mean squared error between ``x_real`` and the images obtained by
        sending the translation through the generator a second time.

    Args:
        x_real: batch of real images from the source domain.
        y_t: class ids the discriminator should assign to the translation.
        y_f: unused; kept so existing callers keep working.

    Returns:
        The combined scalar loss.
    """
    # Only one gradient is taken from this tape, so a non-persistent tape is
    # enough; it releases its resources as soon as gradient() returns.
    with tf.GradientTape() as tape:
        x_fake = gen(x_real)  # forward

        # discriminator
        y_p = dis(x_fake)
        loss_dis = tf.losses.sparse_categorical_crossentropy(y_t, y_p)

        # revert: the same generator is applied again to map back
        x_revert = gen(x_fake)
        loss_revert = tf.losses.mse(x_real, x_revert)

        loss = tf.reduce_mean(loss_dis) + tf.reduce_mean(loss_revert)

    variables = gen.trainable_variables
    grads = tape.gradient(loss, variables)
    opt_gen.apply_gradients(zip(grads, variables))

    # Inside tf.function, float() is presumably rewritten by AutoGraph into a
    # cast, so callers receive a scalar tensor -- TODO confirm.
    return float(loss)

def train_gen():
    """Run one generator training step in each direction on a fresh batch.

    Returns:
        ``(loss_a, loss_b)`` as Python floats: the loss for translating the
        domain-A batch and the loss for translating the domain-B batch.
    """
    # gen.trainable = True is set first and base_model is frozen afterwards,
    # keeping the shared VGG16 extractor out of training.
    gen.trainable = True
    dis.trainable = False
    base_model.trainable = False

    batch_a, batch_b = get_x_train()

    # A -> B: the discriminator should answer "real B".
    loss_from_a = _train_gen_cycle(batch_a, y_true_b, y_true_a)
    # B -> A: the discriminator should answer "real A".
    loss_from_b = _train_gen_cycle(batch_b, y_true_a, y_true_b)

    return float(loss_from_a), float(loss_from_b)

train_gen()

In [None]:
def _preview(x_real, title=None):
    """Plot real images (top row) above their generated translations (bottom row).

    Args:
        x_real: tensor of images in the [-1, 1] range; at most the first 9
            are shown.
        title: optional figure title.
    """
    x_fake = gen.predict(x_real, verbose=0)
    x_real = cvtImg(x_real.numpy())
    x_fake = cvtImg(x_fake)

    plt.figure(figsize=(25, 5))
    if title:
        plt.suptitle(title)
    # Number of columns comes from the images actually passed in, so an input
    # with fewer than 9 images cannot be indexed out of range.
    s = min(len(x_real), 9)
    for i in range(s):
        plt.subplot(2, s, i + 1)
        plt.axis('off')
        plt.imshow(x_real[i])
        plt.subplot(2, s, i + 1 + s)
        plt.axis('off')
        plt.imshow(x_fake[i])
    plt.show()

def preview(useTest=True):
    """Show generator output for up to 9 images of each domain.

    Args:
        useTest: take the images from the test sets when true, otherwise from
            an augmented training batch.
    """
    batch_a, batch_b = get_x_test() if useTest else get_x_train()
    _preview(batch_a[:9], 'A -> B')
    _preview(batch_b[:9], 'B -> A')

preview()

In [None]:
def train(steps=200):
    """Alternate discriminator and generator updates for ``steps`` iterations.

    Args:
        steps: number of (train_dis, train_gen) pairs to run; defaults to 200.

    The latest losses are shown in the progress bar description.
    """
    bar = trange(steps)
    for _ in bar:
        lda, ldb = train_dis()
        lga, lgb = train_gen()
        msg = f'gen: {lga:.5f}, {lgb:.5f} | dis: {lda:.5f}, {ldb:.5f}'
        bar.set_description(msg)

def go(rounds=50, preview_every=5, lr_decay=0.98):
    """Run the full training schedule.

    Args:
        rounds: number of train() calls; defaults to 50.
        preview_every: show a preview after every this-many rounds, starting
            with round 0; defaults to 5.
        lr_decay: factor both learning rates are multiplied by after each
            round; defaults to 0.98.
    """
    for i in trange(rounds):
        train()
        if i % preview_every == 0:
            preview()

        # Exponential learning-rate decay, applied to both optimizers.
        for opt in (opt_dis, opt_gen):
            opt.learning_rate = opt.learning_rate * lr_decay

        lg = opt_gen.learning_rate.numpy()
        ld = opt_dis.learning_rate.numpy()
        print(f'run: {i}')
        print(f'LR gen: {lg:.7f}')
        print(f'LR dis: {ld:.7f}')


go()
preview()