# Intro

The goal of this project is to generate drawings/sketches from photos. This type of task is often called style transfer, and a popular approach to it is to use GANs ([review](https://arxiv.org/abs/1705.04058)).

I've decided to start with [CycleGAN](https://arxiv.org/abs/1703.10593), which allows image-to-image translation without paired examples. This makes it fairly easy to use: we do not need to obtain or create a reference output for each input image. Instead, the inputs are simply a set of images to be translated and a set of images showing the desired style.

Other interesting options are [ArtPDGAN](https://link.springer.com/chapter/10.1007/978-3-030-50436-6_21) and [Im2Pencil](https://arxiv.org/pdf/1903.08682.pdf). Both focus specifically on drawings but require paired images (original and reference).

## Datasets & Implementation

The datasets come from the ["I’m Something of a Painter Myself"](https://www.kaggle.com/c/gan-getting-started) Kaggle competition and [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch).

The implementation of CycleGAN is based on the "[Monet CycleGAN Tutorial](https://www.kaggle.com/amyjang/monet-cyclegan-tutorial)".


### Preprocessing

Images are **resized** to 256x256 and **scaled** to <-1;1>. All images are converted to 3 channels for consistency.
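A minimal NumPy sketch of the scaling and its inverse, assuming 8-bit pixel values in [0, 255]. It mirrors the `/ 127.5 - 1` used in `decode_img` and the `* 127.5 + 127.5` used when displaying results; the helper names `to_model_range` and `to_pixels` are hypothetical.

```python
import numpy as np

def to_model_range(pixels: np.ndarray) -> np.ndarray:
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return pixels.astype(np.float32) / 127.5 - 1.0

def to_pixels(scaled: np.ndarray) -> np.ndarray:
    """Map values in [-1, 1] back to uint8 pixels in [0, 255]."""
    # round before casting so float error cannot truncate e.g. 126.9999 down to 126
    return np.clip(np.rint(scaled * 127.5 + 127.5), 0, 255).astype(np.uint8)

pixels = np.array([0, 127, 255], dtype=np.uint8)
scaled = to_model_range(pixels)
print(scaled)             # the endpoints map to -1.0 and 1.0
print(to_pixels(scaled))  # the round trip recovers the original pixels
```

Rounding before the cast is a small design choice of this sketch; the notebook's display code truncates with `astype(np.uint8)`, which is close enough for plotting.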

## Setup & Running

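The notebook can be run directly on Kaggle without any setup. For running locally, you can use Kaggle's [Dockerfile](https://github.com/Kaggle/docker-python/blob/main/Dockerfile.tmpl) as a reference. The main dependency is `TensorFlow 2.6.0`, together with TensorFlow Addons (`tensorflow_addons`), which provides instance normalization and the `TimeStopping` callback.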

# Setup

In [None]:
from pathlib import Path

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

# from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
!nvidia-smi

Enable TPU if needed.

In [None]:
# Not using TPU
# try:
#     raise ValueError('disable TPU')
#     tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
#     print('Device:', tpu.master())
#     tf.config.experimental_connect_to_cluster(tpu)
#     tf.tpu.experimental.initialize_tpu_system(tpu)
#     strategy = tf.distribute.experimental.TPUStrategy(tpu)
# except:
#     print('...fail...')
#     strategy = tf.distribute.get_strategy()
# print('Number of replicas:', strategy.num_replicas_in_sync)

strategy = tf.distribute.get_strategy()
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
tf.__version__

# Data

## Sketch Dataset

The dataset is from https://github.com/HaohanWang/ImageNet-Sketch.

Most images have 3 channels, but some have only 1. Sizes vary, with a median of 570x600.

In [None]:
img_height, img_width = 256, 256
channels = 3

In [None]:
def decode_img(image):
    # convert to 3-channel
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [img_height, img_width])
    # scale to <-1,1>
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    return image

def process_path(file_path):
    # does not work when using TPU
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img

In [None]:
list_ds = tf.data.Dataset.list_files('../input/imagenetsketch/sketch/*/*.JPEG', shuffle=True)
# batch images so they have the right shape for neural net
styled_ds = list_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(1)

In [None]:
def plot(image, ax=None):
    if ax is None:
        ax = plt.gca()
    ax.imshow(image * 0.5 + 0.5)

In [None]:
fig, axs = plt.subplots(1, 5, figsize=(17,17))
for batch, ax in zip(styled_ds.take(5), axs):
    plot(batch[0], ax)
plt.show()

## Photo Dataset

From https://www.kaggle.com/c/gan-getting-started.

In [None]:
list_ds = tf.data.Dataset.list_files('../input/gan-getting-started/photo_jpg/*.jpg', shuffle=True)
photo_ds = list_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(1)

In [None]:
fig, axs = plt.subplots(1, 5, figsize=(17,17))
for batch, ax in zip(photo_ds.take(5), axs):
    plot(batch[0], ax)
plt.show()

# Architecture

I'll be using CycleGAN with the UNET architecture for the generators. The networks are composed of `downsample` and `upsample` blocks. They use instance normalization, which is included in TensorFlow Addons.
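Instance normalization normalizes each channel of each image separately, using the mean and variance over that channel's spatial positions (unlike batch normalization, which pools statistics across the batch). A minimal NumPy sketch of the computation on an NHWC array, assuming we leave out the learnable scale (gamma) and offset (beta) that `tfa.layers.InstanceNormalization` also applies; `instance_norm` and its `eps` value are illustrative, not the library's internals.

```python
import numpy as np

def instance_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize an (N, H, W, C) array per sample and per channel."""
    # statistics over the spatial axes only -> shape (N, 1, 1, C)
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
# two 8x8 "images" with 3 channels, each channel with a very different offset and scale
x = rng.normal(loc=[0.0, 5.0, -3.0], scale=[1.0, 10.0, 0.5], size=(2, 8, 8, 3))
y = instance_norm(x)
print(y.mean(axis=(1, 2)).round(6))  # ~0 for every (sample, channel) pair
print(y.std(axis=(1, 2)).round(3))   # ~1 for every (sample, channel) pair
```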

## Building Blocks

In [None]:
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

    if apply_instancenorm:
        result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    result.add(layers.LeakyReLU())

    return result

In [None]:
def upsample(filters, size, apply_dropout=False):
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

    result.add(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))

    if apply_dropout:
        result.add(layers.Dropout(0.5))

    result.add(layers.ReLU())

    return result

## Generator

The generator downsamples and then upsamples the image while using skip connections (UNET architecture).
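With `strides=2` and `padding='same'`, each `downsample` block halves the height and width and each `upsample` block doubles them; after every upsampling step, the matching encoder output is concatenated along the channel axis. A small plain-Python sketch of that bookkeeping, assuming a 256x256 input and the filter counts used in `Generator`, reproduces the shape comments in the code (`generator_shapes` is a hypothetical helper).

```python
import math

def generator_shapes(size=256,
                     down_filters=(64, 128, 256, 512, 512, 512, 512, 512),
                     up_filters=(512, 512, 512, 512, 256, 128, 64)):
    """Return (side, channels) after each down block and after each up block + skip concat."""
    downs = []
    side = size
    for filters in down_filters:
        side = math.ceil(side / 2)           # 'same' padding with stride 2 -> ceil(side / 2)
        downs.append((side, filters))
    skips = list(reversed(downs[:-1]))       # the 1x1 bottleneck is not used as a skip
    ups = []
    for filters, (skip_side, skip_channels) in zip(up_filters, skips):
        side *= 2                            # stride-2 transposed conv doubles the size
        assert side == skip_side             # sizes must match for Concatenate()
        ups.append((side, filters + skip_channels))
    return downs, ups

downs, ups = generator_shapes()
print(downs)  # [(128, 64), (64, 128), (32, 256), (16, 512), (8, 512), (4, 512), (2, 512), (1, 512)]
print(ups)    # [(2, 1024), (4, 1024), (8, 1024), (16, 1024), (32, 512), (64, 256), (128, 128)]
```

The final `Conv2DTranspose` layer then doubles 128 to 256 and maps the 128 channels to `OUTPUT_CHANNELS = 3`.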

In [None]:
def Generator():
    inputs = layers.Input(shape=[img_height,img_width,3])

    # bs = batch size, assuming (h, w) = (256, 256)
    down_stack = [
        downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, 256, 256, 3)

    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

## Discriminator

The discriminator classifies images as real or generated. Instead of a single score, it outputs a smaller 2D map (30x30 for a 256x256 input) in which each value rates one patch of the image; higher values indicate that the patch looks real.
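Each value in the output map depends only on a limited patch of the input, its receptive field. A small plain-Python sketch, assuming the layer list below mirrors `Discriminator` (three 4x4 stride-2 `downsample` blocks with `'same'` padding, then two 4x4 stride-1 convolutions each preceded by one pixel of zero padding), computes both the output size and that patch size. It comes out at 70x70, the patch size of the PatchGAN discriminators used in the CycleGAN paper. `DISC_LAYERS`, `output_size` and `receptive_field` are hypothetical names.

```python
import math

# (kernel_size, stride, padding) for each convolution in Discriminator, in order;
# padding is "same" or the number of zero-padding pixels added on each side
DISC_LAYERS = [
    (4, 2, "same"),  # downsample(64, 4, False)
    (4, 2, "same"),  # downsample(128, 4)
    (4, 2, "same"),  # downsample(256, 4)
    (4, 1, 1),       # ZeroPadding2D() + Conv2D(512, 4, strides=1)
    (4, 1, 1),       # ZeroPadding2D() + Conv2D(1, 4, strides=1)
]

def output_size(size, spec):
    """Side of the output map for a square input of side `size`."""
    for kernel, stride, padding in spec:
        if padding == "same":
            size = math.ceil(size / stride)
        else:
            # explicit zero padding followed by a 'valid' convolution
            size = (size + 2 * padding - kernel) // stride + 1
    return size

def receptive_field(spec):
    """Side of the input patch that a single output value depends on."""
    field = 1
    for kernel, stride, _ in reversed(spec):
        # walking backwards from one output unit: r_in = (r_out - 1) * stride + kernel
        field = (field - 1) * stride + kernel
    return field

print(output_size(256, DISC_LAYERS))  # 30
print(receptive_field(DISC_LAYERS))   # 70
```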

In [None]:
def Discriminator():
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    x = inp

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    norm1 = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=last)

In [None]:
with strategy.scope():
    styled_generator = Generator() # transforms photos to drawings
    photo_generator = Generator() # transforms drawings to photos

    styled_discriminator = Discriminator() # differentiates real and generated drawings
    photo_discriminator = Discriminator() # differentiates real and generated photos

### Sample Run

In [None]:
example = next(iter(photo_ds))

In [None]:
example.shape

In [None]:
styled = styled_generator(example)

# draw both images side by side in a single figure
plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.title("Original")
plot(example[0])

plt.subplot(1, 2, 2)
plt.title("Generated")
plot(styled[0])
plt.show()

## CycleGAN

In [None]:
class CycleGan(keras.Model):
    def __init__(
        self,
        styled_generator,
        photo_generator,
        styled_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        super(CycleGan, self).__init__()
        self.s_gen = styled_generator
        self.p_gen = photo_generator
        self.s_disc = styled_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        
    def compile(
        self,
        s_gen_optimizer,
        p_gen_optimizer,
        s_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        super(CycleGan, self).compile()
        self.s_gen_optimizer = s_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.s_disc_optimizer = s_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        
    def train_step(self, batch_data):
        # @TODO: rename
        real_styled, real_photo = batch_data
        
        with tf.GradientTape(persistent=True) as tape:
            # photo to styled back to photo
            fake_styled = self.s_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_styled, training=True)

            # styled to photo back to styled
            fake_photo = self.p_gen(real_styled, training=True)
            cycled_styled = self.s_gen(fake_photo, training=True)

            # identity mapping: each generator applied to an image already in its output domain
            same_styled = self.s_gen(real_styled, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator outputs (logits) for real images
            disc_real_styled = self.s_disc(real_styled, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator outputs (logits) for generated images
            disc_fake_styled = self.s_disc(fake_styled, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # evaluates generator loss
            styled_gen_loss = self.gen_loss_fn(disc_fake_styled)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # evaluates total cycle consistency loss
            total_cycle_loss = self.cycle_loss_fn(real_styled, cycled_styled, self.lambda_cycle) + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)

            # evaluates total generator loss
            total_styled_gen_loss = styled_gen_loss + total_cycle_loss + self.identity_loss_fn(real_styled, same_styled, self.lambda_cycle)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)

            # evaluates discriminator loss
            styled_disc_loss = self.disc_loss_fn(disc_real_styled, disc_fake_styled)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Calculate the gradients for generator and discriminator
        styled_generator_gradients = tape.gradient(total_styled_gen_loss,
                                                  self.s_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        styled_discriminator_gradients = tape.gradient(styled_disc_loss,
                                                      self.s_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.s_gen_optimizer.apply_gradients(zip(styled_generator_gradients,
                                                 self.s_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.s_disc_optimizer.apply_gradients(zip(styled_discriminator_gradients,
                                                  self.s_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))
        
        return {
            "styled_gen_loss": total_styled_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "styled_disc_loss": styled_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

## Loss functions

### Discriminator

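The discriminator should predict 1s for real images and 0s for generated ones (after a sigmoid, since it outputs logits and the losses use `from_logits=True`). Its loss is the average of the binary cross-entropy on real images and on generated images.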
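Because the discriminator returns logits, the sigmoid is folded into the cross-entropy. A minimal NumPy sketch of the same computation, using the numerically stable form of binary cross-entropy with logits, `max(z, 0) - z * y + log(1 + exp(-|z|))`. `bce_with_logits` and `np_discriminator_loss` are hypothetical helpers, and unlike the Keras loss with `Reduction.NONE` they average over the whole patch map instead of keeping per-patch values.

```python
import numpy as np

def bce_with_logits(labels, logits):
    """Mean binary cross-entropy computed directly from logits (numerically stable)."""
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    return np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z))))

def np_discriminator_loss(real_logits, fake_logits):
    real_loss = bce_with_logits(np.ones_like(real_logits), real_logits)   # real patches -> 1
    fake_loss = bce_with_logits(np.zeros_like(fake_logits), fake_logits)  # fake patches -> 0
    return 0.5 * (real_loss + fake_loss)

patch_map = (1, 30, 30, 1)  # shape of one discriminator output
undecided = np_discriminator_loss(np.zeros(patch_map), np.zeros(patch_map))
confident = np_discriminator_loss(np.full(patch_map, 8.0), np.full(patch_map, -8.0))
print(undecided)  # log(2): a logit of 0 means p = 0.5 for every patch
print(confident)  # close to 0: real patches scored high, fake patches scored low
```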

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)

        generated_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.zeros_like(generated), generated)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss * 0.5

### Generator

The first generator loss (the adversarial loss) measures how well the generator fooled the discriminator: ideally, the discriminator should output only 1s for the generated images.
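Written out for one generated image, with $D(\cdot)_p$ the discriminator logit for patch $p$ of the output map $P$ and $\sigma$ the sigmoid, the adversarial loss is the cross-entropy against a target of 1. This is a sketch of what `generator_loss` computes, shown averaged over patches, whereas the code keeps the per-patch values:

$$
\mathcal{L}_{\text{adv}}(G) = \frac{1}{|P|}\sum_{p \in P} -\log \sigma\big(D(G(x))_p\big)
$$

It approaches 0 when the discriminator scores every generated patch as real, and equals $\log 2 \approx 0.693$ when every logit is 0.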

In [None]:
with strategy.scope():
    def generator_loss(generated):
        return tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(generated), generated)

The original and the twice-transformed images (photo → drawing → photo, and drawing → photo → drawing) should be similar. The cycle consistency loss is the average absolute difference between them, multiplied by `LAMBDA`.
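A tiny NumPy sketch of the cycle consistency loss on made-up four-pixel "images" (illustrative values, not model outputs), with `LAMBDA = 10` as in the default `lambda_cycle` of the `CycleGan` class:

```python
import numpy as np

LAMBDA = 10  # same default weight as lambda_cycle in CycleGan

real = np.array([0.0, 0.5, -1.0, 1.0])    # original image, already scaled to [-1, 1]
cycled = np.array([0.1, 0.5, -0.8, 1.0])  # after photo -> drawing -> photo

mean_abs_diff = np.mean(np.abs(real - cycled))  # (0.1 + 0 + 0.2 + 0) / 4 = 0.075
cycle_loss = LAMBDA * mean_abs_diff             # 10 * 0.075 = 0.75
print(cycle_loss)
```

Because the difference is absolute (L1) rather than squared, small per-pixel deviations are penalized linearly.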


In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

        return LAMBDA * loss1

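The identity loss compares an image with the output of the generator that translates *into* that image's domain (e.g. a drawing passed through the photo-to-drawing generator), which should leave it unchanged. It is weighted by half of `LAMBDA`.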
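Putting the pieces of `train_step` together, with $G_S$ the photo-to-drawing generator (`s_gen`), $G_P$ the drawing-to-photo generator (`p_gen`), $x_S$ a real drawing, $x_P$ a real photo, $\lambda$ = `lambda_cycle`, $\lVert\cdot\rVert_1$ standing for the mean absolute difference, and $\mathcal{L}_{\text{adv}}(G_S)$ for `generator_loss` applied to the drawing discriminator's output on $G_S(x_P)$, the loss minimized by $G_S$ is (a summary of the code, written as math):

$$
\mathcal{L}_{G_S} = \mathcal{L}_{\text{adv}}(G_S)
+ \lambda\Big(\big\lVert x_S - G_S(G_P(x_S))\big\rVert_1 + \big\lVert x_P - G_P(G_S(x_P))\big\rVert_1\Big)
+ \frac{\lambda}{2}\big\lVert x_S - G_S(x_S)\big\rVert_1
$$

$\mathcal{L}_{G_P}$ is symmetric: it shares the same cycle term, uses its own adversarial term from the photo discriminator, and has the identity term $\frac{\lambda}{2}\lVert x_P - G_P(x_P)\rVert_1$.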

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        loss = tf.reduce_mean(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * loss

# Training

In [None]:
with strategy.scope():
    styled_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    styled_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
with strategy.scope():
    cycle_gan_model = CycleGan(
        styled_generator, photo_generator, styled_discriminator, photo_discriminator
    )

    cycle_gan_model.compile(
        s_gen_optimizer = styled_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        s_disc_optimizer = styled_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )

### Preparing zipped dataset

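We need to combine the two datasets so that each training step receives a (drawing, photo) pair. The datasets also have different sizes, so both are repeated indefinitely before zipping; otherwise zipping would stop at the end of the shorter one.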

In [None]:
len(photo_ds), len(styled_ds)

#### Dataset combination example

In [None]:
# in general you want to batch after shuffling, but it doesn't matter for batch_size=1
ds_A = tf.data.Dataset.from_tensor_slices(['a', 'b', 'c']).batch(1).shuffle(3, reshuffle_each_iteration=True).repeat()
ds_B = tf.data.Dataset.range(5).batch(1).shuffle(5, reshuffle_each_iteration=True).repeat()
ds_zipped_dummy = tf.data.Dataset.zip((ds_A, ds_B))

list(ds_zipped_dummy.take(10).as_numpy_iterator())

## Custom Callbacks

In [None]:
class ShowSamplesCallback(keras.callbacks.Callback):
    def __init__(self, save_to_dir: str):
        super().__init__()
        self.save_to_dir = save_to_dir
    
    def on_train_begin(self, logs=None):
        self.original = []
        self.styled = []
    
    def on_train_end(self, logs=None):
        if self.save_to_dir:
            dir_ = Path(self.save_to_dir)
            dir_.mkdir(parents=True, exist_ok=True)
            for i, (original, styled) in enumerate(zip(self.original, self.styled)):
                tf.keras.preprocessing.image.save_img(dir_ / f'{i + 1}_original.jpg', original)
                tf.keras.preprocessing.image.save_img(dir_ / f'{i + 1}_styled.jpg', styled)
    
    def on_epoch_end(self, epoch, logs=None):
        styled_generator = self.model.s_gen
        # sample predict
        img_batch = next(iter(photo_ds.take(1)))
        prediction = styled_generator(img_batch, training=False)[0].numpy()
        prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
        img = img_batch[0]
        img = (img * 127.5 + 127.5).numpy().astype(np.uint8)
        self.original.append(img)
        self.styled.append(prediction)
        
        # plot
        fig, axs = plt.subplots(1, 2, figsize=(6,6), squeeze=True)
        axs[0].imshow(img)
        axs[1].imshow(prediction)
        axs[0].set_title("Input Photo [{:02d}]".format(epoch + 1))
        axs[1].set_title("Generated [{:02d}]".format(epoch + 1))
        axs[0].axis("off")
        axs[1].axis("off")
        plt.show()

In [None]:
class SaveGeneratorCallback(keras.callbacks.Callback):
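    def __init__(self, epochs: int, path: str):
        super().__init__()
        self.epochs = epochs
        self.dir_ = Path(path)
    
    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-based: with epochs=50 this saves after (1-based) epochs 1, 51, 101, ...
        if epoch % self.epochs == 0:
            styled_generator = self.model.s_gen
            path = self.dir_ / f'weights.{epoch + 1}.hdf5'
            print(f'[SaveGeneratorCallback] Saving weights to {path}.')
            # pass a str so the .hdf5 extension is recognized and HDF5 format is used
            styled_generator.save_weights(str(path))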

### Training

In [None]:
HOUR_TO_SECONDS = 3600

In [None]:
_buffer_size = 3000

history = cycle_gan_model.fit(
    # zipped dataset - infinitely repeating
    tf.data.Dataset.zip((
        # buffer_size should be at least the size of the dataset for uniform shuffling
        # - https://stackoverflow.com/a/47025850/3936732
        # - however the dataset is too large for that so we limit the size
        styled_ds.shuffle(_buffer_size, reshuffle_each_iteration=True).repeat(),
        photo_ds.shuffle(_buffer_size, reshuffle_each_iteration=True).repeat(),
    )),
    
    steps_per_epoch=1000, # batches per epoch
    epochs=100,
    # --- callbacks ---
    callbacks=[
        tfa.callbacks.TimeStopping(seconds=HOUR_TO_SECONDS * 6, verbose=1),
        ShowSamplesCallback('/kaggle/working/imgs'),
        SaveGeneratorCallback(50, '/kaggle/working'),
    ],
)

# Visualize

In [None]:
_, ax = plt.subplots(2, 5, figsize=(16, 12))
for i, img in enumerate(photo_ds.take(5)):
    prediction = styled_generator(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[0, i].imshow(img)
    ax[1, i].imshow(prediction)
    ax[0, i].set_title("Input Photo")
    ax[1, i].set_title("Generated")
    ax[0, i].axis("off")
    ax[1, i].axis("off")
plt.tight_layout()
plt.show()

# Model Saving/Loading

## Loading weights only (saved with callback)

In [None]:
_TEST = False

In [None]:
if _TEST:
    generator = Generator()
    generator.load_weights('/kaggle/working/weights.1.hdf5')
    sample_photo = next(iter(photo_ds))
    prediction = generator(sample_photo)
    plot(sample_photo[0])
    plt.show()
    plot(prediction[0])
    plt.show()

## Final Model

In [None]:
!mkdir -p /kaggle/working/saved_model
styled_generator.save('/kaggle/working/saved_model/styled_generator')

In [None]:
loaded_styled_generator = tf.keras.models.load_model('/kaggle/working/saved_model/styled_generator')

In [None]:
n_rows = 3 * 2
n_cols = 5
_, ax = plt.subplots(n_rows, n_cols, figsize=(n_cols*4, n_rows*4))
row = 0
i = 0
for img in photo_ds.take(n_rows//2*n_cols):
    prediction = loaded_styled_generator(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)
    
    ax[row + 0, i].imshow(img)
    ax[row + 1, i].imshow(prediction)
    ax[row + 0, i].set_title("Input Photo")
    ax[row + 1, i].set_title("Generated")
    ax[row + 0, i].axis("off")
    ax[row + 1, i].axis("off")
    
    i += 1
    if i >= n_cols:
        i = 0
        row += 2
    
plt.tight_layout()
plt.show()