In [None]:
import tensorflow as tf
from tensorflow.keras import Sequential, Model, applications
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dropout, ReLU, LeakyReLU, Input, Concatenate, ZeroPadding2D
from tensorflow_addons.layers import InstanceNormalization

from IPython.display import clear_output
import matplotlib.pyplot as plt
import matplotlib as mpl
from PIL import Image

import numpy as np

import time


# Passed as `num_parallel_calls` to the dataset `map` calls below.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
def show_batch(image_batch):
    """Plot every image of ``image_batch`` on a grid that is 8 columns wide.

    Args:
        image_batch: Indexable collection of images that supports ``len()``
            and whose items can be passed to ``plt.imshow``.
    """
    columns = 8
    # Integer ceiling division: exactly as many rows as the batch needs.
    # The previous float expression (len / columns + 1) produced a fractional
    # figure height and an extra empty row for batch sizes that are a
    # multiple of `columns`. At least one row is kept for an empty batch.
    rows = max(1, -(-len(image_batch) // columns))
    plt.figure(figsize = (16, 2 * rows))
    for n in range(len(image_batch)):
        plt.subplot(rows, columns, n + 1)
        plt.imshow(image_batch[n], cmap='gray')
        plt.axis('off')
        
        
def generate_and_save_images(model, epoch, test_input):
    """Run ``model`` on ``test_input``, plot channel 0 of every prediction and save the grid.

    The figure is written to ``./epochs/image_at_epoch_XXXX.png`` (the
    directory is created when missing) and then shown.

    Args:
        model: Callable accepting ``(test_input, training=False)`` and
            returning a 4-D batch of images.
        epoch: Integer used to build the output file name.
        test_input: Input batch forwarded to ``model``.
    """
    import math
    import os

    # Notice `training` is set to False.
    # This is so all layers run in inference mode.
    predictions = model(test_input, training=False)

    count = predictions.shape[0]
    # Keep the original 4x4 layout for up to 16 images and grow the square
    # grid beyond that, so plt.subplot does not raise on larger batches.
    side = max(4, math.ceil(math.sqrt(count)))

    fig = plt.figure(figsize=(8,8))

    for i in range(count):
        plt.subplot(side, side, i+1)
        plt.imshow(predictions[i, :, :, 0] * 255.0, cmap='gray')
        plt.axis('off')

    # savefig does not create missing directories.
    os.makedirs('./epochs', exist_ok=True)
    plt.savefig('./epochs/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

## Load and prepare data

In [None]:
def decode_img(img):
  """Decode the encoded image bytes ``img`` into a 3-channel image tensor."""
  # NOTE(review): the dataset globs match '*.jpg' while this calls decode_png;
  # presumably the op accepts JPEG input as well — confirm against the TF docs.
  return tf.image.decode_png(img, channels = 3)


def get_bytes_and_label(file_path):
  """Read the file at ``file_path`` and return its decoded image.

  Despite the name, only the image is returned; no label is produced.
  """
  raw_bytes = tf.io.read_file(file_path)
  return decode_img(raw_bytes)

In [None]:
def createDataset(path, flag):
    """Build a dataset of decoded images from the ``.jpg`` files under ``path``.

    Args:
        path: Directory prefix to which the glob pattern is appended.
        flag: When truthy, match files one sub-directory below ``path``
            (``/*/*.jpg``); otherwise match files directly inside ``path``.

    Returns:
        A ``tf.data.Dataset`` yielding one decoded image per matched file.
    """
    pattern = '/*/*.jpg' if flag else '/*.jpg'
    files = tf.data.Dataset.list_files(path + pattern)
    return files.map(get_bytes_and_label, num_parallel_calls = AUTOTUNE)

# Art style folders; INDEX picks the one used when ALL is False.
path = ['cezanne', 'monet', 'ukiyoe', 'vangogh']

INDEX = 3
ALL = False  # True: load every style folder under art_train / art_test
ID = True    # True: add the identity loss to the generator losses in train_step


if ALL:
    # One sub-directory per style, hence the nested glob (flag = True).
    art_train = createDataset('./Datasets/art_train', True)
    art_test = createDataset('./Datasets/art_test', True)

else:
    art_train = createDataset(f'./Datasets/art_train/{path[INDEX]}', False)
    art_test = createDataset(f'./Datasets/art_test/{path[INDEX]}', False)

# The photo datasets are the same in both modes.
photos_train = createDataset('./Datasets/photos/train', False)
photos_test = createDataset('./Datasets/photos/test', False)

print(len(art_train), len(art_test), len(photos_train), len(photos_test))

In [None]:
# Shuffle buffer size used when building the input pipelines.
BUFFER_SIZE = 1000
# Number of images per batch.
BATCH_SIZE = 1
# Spatial size of the random crops and of the model inputs.
IMG_WIDTH = 256
IMG_HEIGHT = 256

def random_crop(image):
  """Return a random IMG_HEIGHT x IMG_WIDTH crop of ``image`` with 3 channels."""
  crop_size = [IMG_HEIGHT, IMG_WIDTH, 3]
  return tf.image.random_crop(image, size = crop_size)


# Cast to float32 and rescale so that pixel values in [0, 255] land in [-1, 1].
def normalize(image):
  as_float = tf.cast(image, tf.float32)
  return (as_float / 127.5) - 1


def random_jitter(image):
  """Resize to 286x286 (nearest neighbour), crop randomly, then flip left-right at random."""
  enlarged = tf.image.resize(image, [286, 286], method = tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  cropped = random_crop(enlarged)
  return tf.image.random_flip_left_right(cropped)


def preprocess_image_train(image):
  """Training preprocessing: random jitter followed by normalization."""
  jittered = random_jitter(image)
  return normalize(jittered)


def preprocess_image_test(image):
  """Test preprocessing: normalization only (no resize, crop or flip)."""
  normalized = normalize(image)
  return normalized

In [None]:
# Training pipelines: cache the decoded images, then jitter/normalize on every pass.
art_train = (
    art_train
    .cache()
    .map(preprocess_image_train, num_parallel_calls = AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

photos_train = (
    photos_train
    .cache()
    .map(preprocess_image_train, num_parallel_calls = AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Test pipelines: normalize first, then cache the normalized images.
art_test = (
    art_test
    .map(preprocess_image_test, num_parallel_calls = AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

photos_test = (
    photos_test
    .map(preprocess_image_test, num_parallel_calls = AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)


# One batch per domain, reused later for the preview plots.
samples_art = next(iter(art_train))
samples_photo = next(iter(photos_train))

# For each domain show the sample next to a freshly jittered copy;
# `* 0.5 + 0.5` maps the normalized values back to [0, 1] for imshow.
for _plain_title, _jitter_title, _samples in (
    ('Art', 'Art with random jitter', samples_art),
    ('Photo', 'Photo with random jitter', samples_photo),
):
    plt.subplot(121)
    plt.title(_plain_title)
    plt.imshow(_samples[0] * 0.5 + 0.5)

    plt.subplot(122)
    plt.title(_jitter_title)
    plt.imshow(random_jitter(_samples[0]) * 0.5 + 0.5)

    plt.show()

## **Model cycleGAN**

In [None]:
# Channel count used for the model input shapes and the generator's last layer.
OUTPUT_CHANNELS = 3


def downsampler(filters, size, apply_instancenorm = True):
    """Build a stride-2 convolution block: Conv2D -> [InstanceNormalization] -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_instancenorm: Whether to insert instance normalization after
            the convolution.

    Returns:
        A ``Sequential`` model containing the block's layers.
    """
    conv = Conv2D(
        filters,
        size,
        strides = 2,
        padding = 'same',
        kernel_initializer = tf.random_normal_initializer(0., 0.02),
        use_bias = False,
    )

    layers = [conv]
    if apply_instancenorm:
        layers.append(InstanceNormalization())
    layers.append(LeakyReLU())

    return Sequential(layers)


def upsampler(filters, size, apply_dropout = False):
    """Build a stride-2 transposed-convolution block.

    Layer order: Conv2DTranspose -> InstanceNormalization -> [Dropout(0.5)] -> ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: Whether to insert a 0.5 dropout before the activation.

    Returns:
        A ``Sequential`` model containing the block's layers.
    """
    deconv = Conv2DTranspose(
        filters,
        size,
        strides = 2,
        padding = 'same',
        kernel_initializer = tf.random_normal_initializer(0., 0.02),
        use_bias = False,
    )

    layers = [deconv, InstanceNormalization()]
    if apply_dropout:
        layers.append(Dropout(0.5))
    layers.append(ReLU())

    return Sequential(layers)

def generator():
  """Build the encoder-decoder generator with skip connections.

  Eight stride-2 downsamplers encode the IMG_HEIGHT x IMG_WIDTH input. Seven
  upsamplers decode it, each followed by concatenation with the matching
  encoder output, and a final stride-2 transposed convolution with tanh
  restores the input resolution with OUTPUT_CHANNELS channels.

  Returns:
      A Keras ``Model`` mapping an image batch to a same-sized image batch
      with values in [-1, 1] (tanh output).
  """
  inputs = Input(shape = [IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS])

  down_stack = [
      downsampler(32, 4, apply_instancenorm = False),
      downsampler(64, 4),
      downsampler(128, 4),
      downsampler(256, 4),
      downsampler(256, 4),
      downsampler(256, 4),
      downsampler(512, 4),
      downsampler(512, 4)
  ]

  # One upsampler per skip connection (7 of them). An eighth entry,
  # upsampler(32, 4), used to be listed here, but zip() below stops at the
  # shorter sequence, so it was never part of the model and has been dropped.
  up_stack = [
      upsampler(512, 4, apply_dropout = True),
      upsampler(512, 4, apply_dropout = True),
      upsampler(256, 4, apply_dropout = True),
      upsampler(256, 4),
      upsampler(256, 4),
      upsampler(128, 4),
      upsampler(64, 4)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = Conv2DTranspose(OUTPUT_CHANNELS, 4, strides = 2, padding = 'same', kernel_initializer = initializer, activation = 'tanh')

  x = inputs
  skips = []

  # Encoder: keep every intermediate output for the skip connections.
  for down in down_stack:
      x = down(x)
      skips.append(x)

  # The bottleneck output is the decoder input, not a skip; the remaining
  # outputs are consumed deepest-first.
  skips = reversed(skips[:-1])

  # Decoder: upsample, then concatenate with the encoder output of equal size.
  for up, skip in zip(up_stack, skips):
      x = up(x)
      x = Concatenate()([x, skip])

  x = last(x)

  return Model(inputs = inputs, outputs = x)


def discriminator():
    """Build the convolutional discriminator.

    Three stride-2 downsamplers are followed by two zero-padded 4x4
    convolutions. The last convolution has a single filter and no
    activation, so the model outputs a map of raw scores.

    Returns:
        A Keras ``Model`` mapping an image batch to that score map.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    image = Input(shape = [IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS], name = 'input')

    # (filters, apply_instancenorm) for the three downsampling stages.
    features = image
    for filters, use_norm in ((64, False), (128, True), (256, True)):
        features = downsampler(filters, 4, use_norm)(features)

    features = ZeroPadding2D()(features)
    features = Conv2D(512, 4, strides = 1, kernel_initializer = initializer, use_bias = False)(features)
    features = InstanceNormalization()(features)
    features = LeakyReLU()(features)

    features = ZeroPadding2D()(features)
    scores = Conv2D(1, 4, strides = 1, kernel_initializer = initializer)(features)

    return Model(inputs = image, outputs = scores)


In [None]:
# generator_g translates X -> Y and generator_f translates Y -> X;
# discriminator_x scores X-domain images and discriminator_y Y-domain images.
generator_g, generator_f = generator(), generator()
discriminator_x, discriminator_y = discriminator(), discriminator()


# Weight of the cycle-consistency loss (the identity loss uses half of it).
LAMBDA = 10


# The discriminators output raw scores, hence from_logits = True.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits = True)


def discriminator_loss(real, generated):
  """Half the sum of the cross-entropy on real scores (target 1) and generated scores (target 0)."""
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
  return (loss_on_real + loss_on_generated) * 0.5


def generator_loss(generated):
  """Cross-entropy of the discriminator scores for generated images against an all-ones target."""
  target = tf.ones_like(generated)
  return loss_obj(target, generated)


def calc_cycle_loss(real_image, cycled_image):
  """Mean absolute difference between an image and its cycled reconstruction, scaled by LAMBDA."""
  mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_diff


def identity_loss(real_image, same_image):
  """Mean absolute difference between an image and its identity mapping, scaled by LAMBDA * 0.5."""
  mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_diff


# One independent Adam optimizer per network, all with the same hyper-parameters.
generator_g_optimizer = Adam(2e-4, beta_1 = 0.5)
generator_f_optimizer = Adam(2e-4, beta_1 = 0.5)

discriminator_x_optimizer = Adam(2e-4, beta_1 = 0.5)
discriminator_y_optimizer = Adam(2e-4, beta_1 = 0.5)

In [None]:
# One checkpoint directory per training configuration.
if ALL:
  checkpoint_path = "./checkpoints/all"
elif ID:
  checkpoint_path = f"./checkpoints/{path[INDEX]}"
else:
  checkpoint_path = f"./checkpoints/{path[INDEX]}_noID"


# Track all four networks together with their optimizers.
ckpt = tf.train.Checkpoint(
    generator_g = generator_g,
    generator_f = generator_f,
    discriminator_x = discriminator_x,
    discriminator_y = discriminator_y,
    generator_g_optimizer = generator_g_optimizer,
    generator_f_optimizer = generator_f_optimizer,
    discriminator_x_optimizer = discriminator_x_optimizer,
    discriminator_y_optimizer = discriminator_y_optimizer,
)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep = 1)

# `contrast` scales the generator outputs in the preview plot below:
# 8 when no checkpoint was restored, 1 after a restore.
contrast = 8

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')
  contrast = 1

In [None]:
# Translate the sample batches in both directions.
to_photo = generator_g(samples_art)
to_art = generator_f(samples_photo)

plt.figure(figsize = (8, 8))

imgs = [samples_art, to_photo, samples_photo, to_art]
title = ['Art', 'To Photo', 'Photo', 'To Art']

for i, (panel_title, batch) in enumerate(zip(title, imgs)):
  plt.subplot(2, 2, i + 1)
  plt.title(panel_title)

  if i % 2:
    # Odd panels hold generator outputs, scaled by `contrast` before display.
    plt.imshow(batch[0] * 0.5 * contrast + 0.5)

  else:
    # Even panels hold the real samples.
    plt.imshow(batch[0] * 0.5 + 0.5)

plt.show()

## **Model Style Transfer**

In [None]:
# Frozen ImageNet VGG19 without the classification head, used as feature extractor.
base_model = applications.VGG19(include_top = False, weights='imagenet')
base_model.trainable = False

content_layers = ['block5_conv2']

style_layers = [
    'block1_conv1',
    'block2_conv1',
    'block3_conv1',
    'block4_conv1',
    'block5_conv1',
]

# Content layers come first, so output 0 is the content feature map and the
# remaining outputs are the style feature maps (processInput relies on this).
layer_names = content_layers + style_layers
outputs = list(map(lambda layer_name: base_model.get_layer(layer_name).output, layer_names))
model = Model(inputs = [base_model.input], outputs = outputs)


def clip_0_1(image):
  """Clamp every value of ``image`` to the range [0, 1]."""
  lower, upper = 0.0, 1.0
  return tf.clip_by_value(image, clip_value_min = lower, clip_value_max = upper)


# Gram matrix of a feature map: channel-by-channel products summed over the two
# middle axes, divided by the product of those two axis sizes.
def gramMatrix(input_tensor):
    shape = tf.shape(input_tensor)
    num_locations = tf.cast(shape[1] * shape[2], tf.float32)
    channel_products = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    return channel_products / num_locations


@tf.function()
def processInput(image):
    """Extract style and content features of ``image`` with the VGG19 extractor.

    Args:
        image: Image batch handed to ``vgg19.preprocess_input``.

    Returns:
        A tuple ``(styleOutputs, contentOutputs)``: the Gram matrices of the
        style-layer outputs and a one-element list with the content-layer output.
    """
    features = model(applications.vgg19.preprocess_input(image))

    # Output 0 is the content layer; the rest are style layers.
    contentOutputs = list(features[:1])
    styleOutputs = list(map(gramMatrix, features[1:]))

    return(styleOutputs, contentOutputs)


# Weights of the style, content and total-variation terms in computeLoss.
styleW = 1e-2
contentW = 1e4
variationW = 30


def computeLoss(original, style, content):
    """Weighted sum of style, content and total-variation losses for ``original``.

    Args:
        original: Image being optimized; it is multiplied by 255 before
            feature extraction.
        style: ``processInput`` result whose first element supplies the
            target Gram matrices.
        content: ``processInput`` result whose second element supplies the
            target content features.

    Returns:
        ``styleW * styleLoss + contentW * contentLoss + variationW * variationLoss``.
    """
    img_style, img_content = processInput(original * 255)
    target_style = style[0]
    target_content = content[1]

    # Mean squared Gram-matrix difference, averaged over the style layers.
    per_layer = [
        tf.reduce_mean((img_style[n] - target_style[n]) ** 2)
        for n in range(len(style_layers))
    ]
    styleLoss = tf.add_n(per_layer) / len(style_layers)

    contentLoss = tf.reduce_mean((img_content[0] - target_content[0]) ** 2)
    variationLoss = tf.image.total_variation(original)

    return styleW * styleLoss + contentW * contentLoss + variationW * variationLoss

## **Training**

In [None]:
# Number of passes over the zipped training datasets in the training loop.
EPOCHS = 50

@tf.function
def train_step(real_x, real_y):
  """Run one optimization step of both generators and both discriminators.

  Args:
      real_x: Batch of real images from domain X.
      real_y: Batch of real images from domain Y.

  Returns:
      A tuple ``(total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss)``.
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent = True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training = True)
    cycled_x = generator_f(fake_y, training = True)

    fake_x = generator_f(real_y, training = True)
    cycled_y = generator_g(fake_x, training = True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training = True)
    same_y = generator_g(real_y, training = True)

    disc_real_x = discriminator_x(real_x, training = True)
    disc_real_y = discriminator_y(real_y, training = True)

    disc_fake_x = discriminator_x(fake_x, training = True)
    disc_fake_y = discriminator_y(fake_y, training = True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    # Both cycle directions are summed and shared by the two generators.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    # (+ identity loss when the module-level ID flag is set).

    if ID:
      total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
      total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    else:
      total_gen_g_loss = gen_g_loss + total_cycle_loss
      total_gen_f_loss = gen_f_loss + total_cycle_loss

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))

  return total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss

In [None]:
# Optimizer for the style-transfer image updates performed by `train`.
opt = Adam(learning_rate = 0.02, beta_1 = 0.99, epsilon = 1e-1)

def train(img, styleOutputs, contentOutputs):
  """Apply one gradient step of the style-transfer loss to the variable ``img``.

  Args:
      img: ``tf.Variable`` holding the image being optimized; updated in place.
      styleOutputs: Target features forwarded to ``computeLoss`` as ``style``.
      contentOutputs: Target features forwarded to ``computeLoss`` as ``content``.
  """
  with tf.GradientTape() as tape:
    loss = computeLoss(img, styleOutputs, contentOutputs)

  opt.apply_gradients([(tape.gradient(loss, img), img)])

  # Keep the optimized image inside the [0, 1] range after every step.
  img.assign(clip_0_1(img))

 

def generate_images(model, test_input, steps = 2001, show_every = 1000):
  """Translate ``test_input`` with ``model`` and refine it with style transfer.

  The input and the generator prediction are shown first. The input image is
  then optimized with ``train`` so that its style features approach those of
  the prediction while its content features stay close to the input's. The
  intermediate result is shown every ``show_every`` steps.

  Args:
      model: Generator to apply (its last layer is a tanh, so it outputs
          values in [-1, 1]).
      test_input: Eager float tensor holding one image with values in [0, 1].
      steps: Number of optimization steps.
      show_every: Plot the optimized image whenever the step index is a
          multiple of this value.
  """
  # The generators are trained on images normalized to [-1, 1] (see
  # `normalize`), so map the [0, 1] input into that range before predicting
  # and map the tanh output back to [0, 1] for plotting and for VGG.
  prediction = model(test_input * 2.0 - 1.0)
  prediction = prediction * 0.5 + 0.5

  plt.imshow(test_input.numpy().squeeze())
  plt.show()
  plt.imshow(prediction.numpy().squeeze())
  plt.show()

  # processInput is fed 0-255 values: style targets come from the
  # prediction, content targets from the input.
  styleOutputs = processInput(prediction * 255)
  contentOutputs = processInput(test_input * 255)

  # The optimization starts from the input image itself.
  test_input = tf.Variable(test_input)

  for i in range(steps):
      train(test_input, styleOutputs, contentOutputs)

      if i % show_every == 0:
          plt.imshow(test_input.numpy().squeeze())
          plt.show()


def show_images(model, test_input):
  """Plot the first image of ``test_input`` next to the model's prediction for it.

  Args:
      model: Generator to apply to ``test_input``.
      test_input: Image batch normalized to [-1, 1], as produced by the
          training pipeline; the generator output (tanh) is in the same range.
  """
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    # (The rescale was missing, so imshow clipped the negative half of the range.)
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()


In [None]:
# Set to True to run the CycleGAN training loop.
TRAIN = False

if TRAIN:
  for epoch in range(EPOCHS):
    start = time.time()

    # Per-epoch running sums of the four step losses.
    total_gen_g_loss_e = total_gen_f_loss_e = disc_x_loss_e = disc_y_loss_e = 0.0

    paired_batches = tf.data.Dataset.zip((art_train, photos_train))
    for n, (image_x, image_y) in enumerate(paired_batches):
      total_gen_g_loss, total_gen_f_loss, disc_x_loss, disc_y_loss = train_step(image_x, image_y)

      total_gen_g_loss_e += total_gen_g_loss
      total_gen_f_loss_e += total_gen_f_loss
      disc_x_loss_e += disc_x_loss
      disc_y_loss_e += disc_y_loss

      # Lightweight progress indicator: print every tenth step index.
      if n % 10 == 0:
        print(n, end = ' ')

    clear_output(wait = True)

    # Preview both translation directions on the fixed sample batches.
    show_images(generator_f, samples_photo)
    show_images(generator_g, samples_art)

    print("gen_art_loss: ", total_gen_g_loss_e, 
          "gen_photo_loss: ", total_gen_f_loss_e, 
          "disc_art_loss: ", disc_x_loss_e, 
          "disc_photo_loss: ", disc_y_loss_e, '\n')

    ckpt_save_path = ckpt_manager.save()
    print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))

    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time() - start))

## **Generate Images**

### Real Photos Gen

In [None]:
# Set to True to translate a single art image from disk with generator_g.
SHOWSET = False

if SHOWSET:
    def loadImagePIL(path):
        """Load the image at ``path`` as a 256x256 RGB float32 array scaled to [0, 1]."""
        # The context manager closes the underlying file once the pixels
        # have been copied by convert().
        with Image.open(path) as source:
            im = source.convert('RGB')
        im = im.resize((int(256), int(256)))
        im_array = np.array(im).astype(np.float32) / 255.0
        
        return im_array

    # Use a dedicated name: rebinding the module-level `path` list to a string
    # would break `path[INDEX]` when the dataset or checkpoint cells are re-run.
    image_path = './Datasets/art_test/cezanne/00100.jpg'

    originalImage = loadImagePIL(image_path)
    originalImage = np.expand_dims(originalImage, axis=0)
    originalTensor = tf.constant(originalImage)
    plt.imshow(tf.squeeze(originalTensor, axis=0))

    # NOTE(review): originalTensor is in [0, 1] whereas the training pipeline
    # normalizes images to [-1, 1] (see `normalize`) — confirm which range
    # generator_g and show_images should receive here.
    show_images(generator_g, originalTensor)

### Art Gen

In [None]:
# Set to True to stylize one batch of the photo test set instead of a custom file.
DATASET = False

if DATASET:
    # NOTE(review): photos_test batches are normalized to [-1, 1] by
    # preprocess_image_test, while generate_images plots its input directly and
    # scales it by 255 for VGG, i.e. treats it as [0, 1] — confirm whether the
    # batch should be rescaled with `* 0.5 + 0.5` before this call.
    for photo_batch in photos_test.take(1):
        generate_images(generator_f, photo_batch)

In [None]:
if not DATASET: 
    def loadImagePIL(path):
        """Load the image at ``path`` as a 256x256 RGB float32 array scaled to [0, 1]."""
        # The context manager closes the underlying file once the pixels
        # have been copied by convert().
        with Image.open(path) as source:
            im = source.convert('RGB')
        im = im.resize((int(256), int(256)))
        im_array = np.array(im).astype(np.float32) / 255.0
        
        return im_array

    # Use a dedicated name: rebinding the module-level `path` list to a string
    # would break `path[INDEX]` when the dataset or checkpoint cells are re-run.
    image_path = './Results/foto2/foto2.png'

    originalImage = loadImagePIL(image_path)
    originalImage = np.expand_dims(originalImage, axis=0)
    originalTensor = tf.constant(originalImage)
    plt.imshow(tf.squeeze(originalTensor, axis=0))

    # Translate the photo with generator_f, then refine it with style transfer.
    generate_images(generator_f, originalTensor)