<a href="https://colab.research.google.com/github/DavideEva/Moviefy/blob/main/Moviefy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Problem analysis

# Download datasets

In [None]:
# Maps each studio name to the Google Drive id of the text file that lists
# the ids of that studio's downloadable archives.
films_data = {"Ghibli": "1RR18MAxLoZQWsxrmfb1hYif2MhOcsgO2"}
keys = list(films_data)
# Position of the Ghibli entry inside `keys` (and, later, inside `folders`).
ghibli_index = keys.index("Ghibli")
films_data[keys[ghibli_index]]

In [None]:
from glob import glob

def load_data(name, id_txt):

  file_name = f'list-{name}.txt'

  ! gdown --id "$id_txt" -O "$file_name"

  lines = []
  with open(file_name, "r") as f:
    lines = f.readlines()
  
  ! mkdir "$name"

  for line in lines:
    id = line.strip()
    ! cd "$name" && gdown --id "$id"

  zip_files = glob(f'{name}/*.zip')
  for zip_file in zip_files:
    ! unzip -qo "$zip_file" -d "$name"
    ! rm "$zip_file"
  
  return name

In [None]:
# Download every configured studio; `folders` keeps the extracted folder names
# in the same order as `films_data`, so `folders[ghibli_index]` is the Ghibli one.
folders = [load_data(studio_name, id_list_id) for studio_name, id_list_id in films_data.items()]

# Import

In [None]:
import numpy as np
import math
import random
import cv2
from matplotlib import pyplot as plt
%matplotlib inline
import tensorflow as tf
from tensorflow.keras.layers import Layer, InputSpec, LeakyReLU, Input, Conv2D, Activation, Concatenate, Conv2DTranspose, BatchNormalization, AveragePooling2D, Add
from tensorflow import pad
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.utils import plot_model
from tensorflow.keras import Sequential
from tensorflow.keras.initializers import Constant
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Plot functions

In [None]:
def plot_grid(images, columns, show_axis=False, labels=None):
  """Show `images` in a grid with `columns` columns.

  Args:
    images: Sequence of image arrays. Float images are expected to hold
      values in [0, 1] and are converted to uint8 for display.
    columns: Number of grid columns; nothing is drawn if it is <= 0.
    show_axis: Whether to keep the axes visible on each subplot.
    labels: Optional titles, cycled if there are fewer labels than images.
      When omitted (or empty) each subplot is titled with its 1-based index.
  """
  if len(images) == 0 or columns <= 0:
    return
  rows = math.ceil(len(images) / columns)
  height = 1 + rows * 2
  width = columns * 4
  # Resolution follows the first image's size; clamp so tiny images never give dpi 0.
  dpi = max(1, max(images[0].shape[0], images[0].shape[1]) // 2)
  fig = plt.figure(figsize=(width, height), dpi=dpi)
  fig.subplots_adjust(hspace=0.4)
  has_labels = labels is not None and len(labels) > 0
  for index, img in enumerate(images, start=1):
    # `dtype.str` looks like '<f4', so test the dtype kind to detect floats.
    if img.dtype.kind == 'f':
      # Clip first so out-of-range values cannot wrap around in uint8.
      img = (img.clip(0.0, 1.0) * 255).astype('uint8')
    sp = fig.add_subplot(rows, columns, index)
    if not show_axis:
      sp.axis('off')
    sp.imshow(img)
    if has_labels:
      sp.set_title(labels[(index - 1) % len(labels)], fontsize=10)
    else:
      sp.set_title(index, fontsize=10)

# Global parameters

In [None]:
# Shape of the raw frames found in the dataset files. Only the width/height
# proportion matters (here 16:9): it is used to derive the loading width.
raw_shape = (1080, 1920, 3)

# The same as https://github.com/FilipAndersson245/cartoon-gan/blob/5a09f4e2cfad42accfc1792dedfba95f9ab6fb83/utils/datasets.py#L32
# Size at which frames are loaded before edge smoothing. It should be at least
# as large as input_shape, because smoothed frames are shrunk to input_shape afterwards.
preprocess_shape = (384, 384, 3)

# Dimension after the preprocess stage
# Should be the dimension expected by the network and the loss functions
input_shape = (224, 224, 3)

# Batch size used for training and fetching images
batch_size = 32

# Passed to ImageDataGenerator as `validation_split`; the data flows below
# read the 'training' subset of that split.
validation_split = 0.2

# Dataset loading and preprocessing

In [None]:
def smooth_edges(img, kernel_size=5, canny_thresholds=(150, 500)):
  """Return a copy of `img` with a Gaussian blur applied around its edges.

  Edges are found with Canny on the grayscale image and dilated with a
  `kernel_size` square; every pixel inside the dilated mask is replaced by the
  Gaussian-weighted average of its neighbourhood in the original image.

  Args:
    img: RGB image of shape (height, width, 3). Expected to be float in
      [0, 1], since it is scaled by 255 before edge detection.
    kernel_size: Side of the dilation and Gaussian kernels.
    canny_thresholds: (low, high) hysteresis thresholds passed to cv2.Canny.

  Returns:
    A new array with the same shape as `img`; `img` itself is not modified.
  """
  # Default parameters taken from https://github.com/FilipAndersson245/cartoon-gan/blob/master/utils/datasets.py
  # Pad by half the kernel so the window taken at (row, col) of the padded
  # image is centred on pixel (row, col) of the original. Padding by one more
  # than that would shift every window one pixel up and to the left.
  pad_size = kernel_size // 2
  gray_img = cv2.cvtColor(np.uint8(img*255), cv2.COLOR_RGB2GRAY)
  pad_img = np.pad(img, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='reflect')
  edges = cv2.Canny(gray_img, canny_thresholds[0], canny_thresholds[1])
  dilation = cv2.dilate(edges, np.ones((kernel_size, kernel_size), np.uint8))
  # Separable 1-D Gaussian turned into a 2-D kernel via the outer product.
  gauss = cv2.getGaussianKernel(kernel_size, 0)
  gauss = gauss * gauss.transpose(1, 0)
  # Extra trailing axis lets the kernel broadcast over the colour channels.
  gauss_3d = gauss[:, :, np.newaxis]
  rows, cols = np.where(dilation != 0)
  gauss_img = np.copy(img)
  for row, col in zip(rows, cols):
    # Windows are read from the untouched padded original, so the result does
    # not depend on the order in which pixels are processed.
    window = pad_img[row:row + kernel_size, col:col + kernel_size, :]
    gauss_img[row, col, :] = np.sum(np.multiply(window, gauss_3d), axis=(0, 1))
  return gauss_img

In [None]:
def left_cropper(pre = lambda x: x):
  """Build a function that keeps the left-most square of an image.

  The returned function first applies `pre`, then keeps as many leading
  columns as the image has rows.
  """
  def left_crop(img):
    processed = pre(img)
    side = processed.shape[0]
    return processed[:, :side, :]
  return left_crop
def right_cropper(pre = lambda x: x):
  """Build a function that keeps the right-most square of an image.

  The returned function first applies `pre`, then keeps as many trailing
  columns as the image has rows.
  """
  def right_crop(img):
    processed = pre(img)
    side = processed.shape[0]
    return processed[:, -side:, :]
  return right_crop
def smoother(pre = lambda x: x):
  """Build a function that applies `pre` and then `smooth_edges` to an image."""
  def smooth(img):
    return smooth_edges(pre(img))
  return smooth
def resizer(size, pre = lambda x: x):
  """Build a function that applies `pre` and resizes the result to `size`.

  `size` is given as (height, width, ...); cv2.resize wants (width, height),
  hence the swap. Area interpolation is used.
  """
  def resize(img):
    target = (size[1], size[0])
    return cv2.resize(pre(img), target, interpolation=cv2.INTER_AREA)
  return resize

In [None]:
def lambda_generator(batches, λ = lambda x: x):
  """Lazily map `λ` over every item of every batch.

  Yields one list per batch, holding `λ(item)` for each item in that batch.
  """
  yield from (list(map(λ, batch)) for batch in batches)

def random_merge_generator(it_1, it_2, p = 0.5):
  """Randomly interleave two iterables into a single stream.

  At every step an item is taken from `it_1` with probability `p`, otherwise
  from `it_2`. When the chosen source is exhausted, the rest of the other
  one is yielded and the generator ends. Relative order within each source
  is preserved.

  Args:
    it_1: First iterable or iterator.
    it_2: Second iterable or iterator.
    p: Probability of drawing the next item from `it_1`.
  """
  # Accept plain iterables too; iter() is a no-op on iterators.
  it_1, it_2 = iter(it_1), iter(it_2)
  while True:
    chosen, other = (it_1, it_2) if np.random.random() < p else (it_2, it_1)
    try:
      item = next(chosen)
    except StopIteration:
      # Drain the remaining source and finish cleanly. Calling next() on an
      # exhausted iterator here would surface as RuntimeError (PEP 479).
      yield from other
      return
    yield item

def cartoon_generator(pre = lambda x: x, preprocess_shape = preprocess_shape, raw_shape = raw_shape):
  """Stream batches of Ghibli training frames, applying `pre` to every image.

  Frames are loaded with the height of `preprocess_shape` and a width that
  keeps the proportions of `raw_shape`.
  """
  # Same proportions as raw, same height as the requested preprocess shape
  target_height = preprocess_shape[0]
  target_width = raw_shape[1] * target_height // raw_shape[0]
  directory_flow = cartoon_real_generator.flow_from_directory(
      directory = folders[ghibli_index],
      subset = 'training',
      target_size = (target_height, target_width),
      **data_flow_settings
  )
  return lambda_generator(directory_flow, pre)

In [None]:
# Keyword arguments for the ImageDataGenerator constructor.
data_generator_settings = dict(
    data_format = 'channels_last',
    validation_split = validation_split,
    # Scale pixel values by 1/255
    rescale = 1.0 / 255,
)

# Keyword arguments shared by every flow_from_directory call.
data_flow_settings = dict(
    color_mode = 'rgb',
    # No labels: the flows yield images only
    class_mode = None,
    batch_size = batch_size,
    shuffle = True,
    interpolation = 'bilinear',
)

cartoon_real_generator = ImageDataGenerator(**data_generator_settings)

In [None]:
# Preview one batch of real cartoon frames: frames are loaded at the
# input_shape height and cropped to a square taken from either the left or
# the right side, the two sources being mixed at random.
test_cartoon_real_flow = random_merge_generator(
  cartoon_generator(left_cropper(), input_shape),
  cartoon_generator(right_cropper(), input_shape)
)
plot_grid(next(test_cartoon_real_flow), 4)
# Drop the preview flow so it does not linger in the notebook namespace.
del test_cartoon_real_flow

In [None]:
# Preview one batch of edge-smoothed cartoon frames: frames are loaded at the
# preprocess_shape height, cropped to a left/right square, edge-smoothed and
# finally resized to input_shape.
test_cartoon_edge_fake_flow = random_merge_generator(
  cartoon_generator(resizer(input_shape, smoother(left_cropper())), preprocess_shape),
  cartoon_generator(resizer(input_shape, smoother(right_cropper())), preprocess_shape)
)
plot_grid(next(test_cartoon_edge_fake_flow), 4)
# Drop the preview flow so it does not linger in the notebook namespace.
del test_cartoon_edge_fake_flow

# Cartoon-GAN

## Utility Layers

In [None]:
class ReflectionPadding2D(Layer):
  """Pads the two spatial axes of a 4-D channels-last tensor by reflection.

  Args:
    padding: `(height_pad, width_pad)`; each amount is added on both sides
      of the corresponding axis.
    **kwargs: Forwarded to the base `Layer` (e.g. `name`).
  """

  def __init__(self, padding=(1, 1), **kwargs):
    # Initialise the base layer before assigning any attribute.
    super(ReflectionPadding2D, self).__init__(**kwargs)
    self.padding = tuple(padding)

  def compute_output_shape(self, s):
    """Return the padded shape, leaving unknown (None) dimensions unknown."""
    h_pad, w_pad = self.padding
    height = None if s[1] is None else s[1] + 2 * h_pad
    width = None if s[2] is None else s[2] + 2 * w_pad
    return (s[0], height, width, s[3])

  def call(self, x, mask=None):
    # padding[0] belongs to axis 1 (height) and padding[1] to axis 2 (width),
    # matching compute_output_shape.
    h_pad, w_pad = self.padding
    return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

  def get_config(self):
    """Include `padding` so the layer can be rebuilt from its config."""
    config = super(ReflectionPadding2D, self).get_config()
    config['padding'] = self.padding
    return config

In [None]:
class Conv2DReflection3x3(Layer):
  """3x3 bias-free convolution preceded by one pixel of reflection padding.

  Args:
    features: Number of output channels.
    stride: Stride applied on both spatial axes.
    **kwargs: Forwarded to the base `Layer` (e.g. `name`).
  """

  def __init__(self, features, stride=1, **kwargs):
    super().__init__(**kwargs)
    # Kept so get_config can describe the layer.
    self.features = features
    self.stride = stride
    self.reflectionPadding2D = ReflectionPadding2D()
    # 'valid' because the padding is already supplied by the reflection layer.
    self.conv2d = Conv2D(features, (3,3), strides=(stride, stride), padding='valid', use_bias=False)

  def call(self, inputs, training=False):
    x = self.reflectionPadding2D(inputs, training=training)
    return self.conv2d(x, training=training)

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created."""
    config = super().get_config()
    config.update({'features': self.features, 'stride': self.stride})
    return config

## Discriminator

In [None]:
# define the discriminator model
def define_discriminator(image_shape):
  """Build the patch discriminator.

  Args:
    image_shape: Shape of the input images, without the batch axis.

  Returns:
    A Keras `Model` mapping an image to a one-channel map of raw scores
    (the last convolution has no activation).
  """
  alpha = 0.2
  epsilon = 1e-5
  # Keras' `momentum` is the weight kept on the running statistics, i.e. the
  # complement of PyTorch's. 0.9 here corresponds to PyTorch's default 0.1;
  # 0.1 would make the running statistics track almost only the last batch.
  momentum = 0.9

  def conv_block(x, n_features, stride, normalize):
    # 3x3 reflection-padded conv -> optional batch norm -> LeakyReLU
    x = Conv2DReflection3x3(n_features, stride=stride)(x)
    if normalize:
      x = BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    return LeakyReLU(alpha=alpha)(x)

  # source image input
  in_image = Input(shape=image_shape)

  # k3n32s1
  d = conv_block(in_image, 32, stride=1, normalize=False)

  # k3n64s2
  d = conv_block(d, 64, stride=2, normalize=False)
  # k3n128s1
  d = conv_block(d, 128, stride=1, normalize=True)

  # k3n128s2
  d = conv_block(d, 128, stride=2, normalize=False)
  # k3n256s1
  d = conv_block(d, 256, stride=1, normalize=True)

  # feature construction block
  # k3n256s1
  d = conv_block(d, 256, stride=1, normalize=True)

  # patch output
  patch_out = Conv2DReflection3x3(1, stride=1)(d)

  # define model
  model = Model(in_image, patch_out)
  return model

In [None]:
# Discriminator instance working on network-sized (input_shape) images.
D = define_discriminator(input_shape)

In [None]:
# Render the discriminator graph with tensor shapes, expanding nested layers.
plot_model(D, show_shapes=True, expand_nested=True)

## Generator

In [None]:
# define the generator model
def define_generator(image_shape):
  """Build the generator: flat block, two down blocks, eight residual blocks,
  two up blocks and a final 7x7 convolution producing 3 channels.

  Args:
    image_shape: Shape of the input images, without the batch axis.

  Returns:
    A Keras `Model` mapping an image to a 3-channel output (the last
    convolution has no activation).
  """
  alpha = 0.2
  epsilon = 1e-5
  # Keras' `momentum` is the weight kept on the running statistics, i.e. the
  # complement of PyTorch's. 0.9 here corresponds to PyTorch's default 0.1;
  # 0.1 would make the running statistics track almost only the last batch.
  momentum = 0.9

  def norm_act(x):
    # batch norm followed by LeakyReLU
    x = BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    return LeakyReLU(alpha=alpha)(x)

  def down_block(x, n_features):
    # k3n?s2
    x = Conv2DReflection3x3(n_features, stride=2)(x)
    # k3n?s1
    x = Conv2DReflection3x3(n_features, stride=1)(x)
    return norm_act(x)

  def residual_block(x):
    skip = x
    # k3n256s1
    x = Conv2DReflection3x3(256, stride=1)(x)
    x = norm_act(x)
    # k3n256s1
    x = Conv2DReflection3x3(256, stride=1)(x)
    x = BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    # The activation comes after the skip connection is added.
    x = Add()([x, skip])
    x = LeakyReLU(alpha=alpha)(x)
    return x

  def up_block(x, n_features):
    # k3n?s1/2
    # The stride-2 transposed conv ('valid' padding) yields 2n+1 pixels per
    # axis; the stride-1 2x2 average pool trims that to exactly 2n.
    x = Conv2DTranspose(n_features, (3,3), strides=2)(x)
    x = AveragePooling2D(pool_size=(2,2), strides=1)(x)
    # k3n?s1
    x = Conv2DReflection3x3(n_features, stride=1)(x)
    return norm_act(x)

  # source image input
  in_image = Input(shape=image_shape)

  # flat block
  # k7n64s1
  g = Conv2D(64, (7,7), strides=1, padding='same', use_bias=False)(in_image)
  g = norm_act(g)

  # 1st down block
  g = down_block(g, 128)

  # 2nd down block
  g = down_block(g, 256)

  for _ in range(8):
    g = residual_block(g)

  # 1st up block
  g = up_block(g, 128)

  # 2nd up-block
  g = up_block(g, 64)

  # k7n3s1
  output = Conv2D(3, (7,7), strides=1, padding='same')(g)

  # define model
  model = Model(in_image, output)
  return model

In [None]:
# Generator instance working on network-sized (input_shape) images.
G = define_generator(input_shape)

In [None]:
# Render the generator graph with tensor shapes, expanding nested layers.
plot_model(G, show_shapes=True, expand_nested=True)

## Loss functions

In [None]:
def BCEWithLogitsLoss():
  """Create a binary cross-entropy loss that takes raw logits.

  Returns:
    A `tf.keras.losses.BinaryCrossentropy` with `from_logits=True` and no
    reduction. Like every Keras loss it is called as `loss(y_true, y_pred)`.
  """
  # The Keras class is spelled `BinaryCrossentropy` (lowercase "e").
  return tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.NONE)

In [None]:
class AdversarialLoss:
  """Sum of three binary cross-entropy terms for the discriminator.

  Args:
    cartoon_labels: Target labels for real cartoon images.
    fake_cartoon_labels: Target labels for generated and edge-smoothed images.
  """

  def __init__(self, cartoon_labels, fake_cartoon_labels):
    self.base_loss = BCEWithLogitsLoss()
    self.cartoon_labels = cartoon_labels
    self.fake_cartoon_labels = fake_cartoon_labels

  def __call__(self, cartoon, generated_fake, cartoon_edge_fake):
    """Return the summed loss over the three prediction batches.

    Each argument is compared with its labels as logits; presumably these
    are the discriminator outputs for real cartoons, generator outputs and
    edge-smoothed cartoons respectively (verify against the training loop).
    """
    # Keras losses take (y_true, y_pred): labels first, logits second.
    D_cartoon_loss = self.base_loss(self.cartoon_labels, cartoon)
    D_generated_fake_loss = self.base_loss(self.fake_cartoon_labels, generated_fake)
    D_edge_fake_loss = self.base_loss(self.fake_cartoon_labels, cartoon_edge_fake)

    return D_cartoon_loss + D_generated_fake_loss + D_edge_fake_loss

# alias for clarity
DiscriminatorLoss = AdversarialLoss

In [None]:
from tensorflow.keras.applications.vgg19 import VGG19

# ImageNet-pretrained VGG19 without its classifier head, used as a fixed
# feature extractor for the content loss.
vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
# The extractor is only a measuring tool: its weights must never be trained.
vgg19.trainable = False

class ContentLoss:
  """L1 distance between the VGG19 features of two image batches."""

  def __init__(self):
    # Call the model directly instead of `vgg19.predict`: `predict` returns
    # NumPy arrays, so no gradient could flow back through the loss.
    # NOTE(review): images are fed to VGG19 as-is; confirm whether
    # `vgg19.preprocess_input` should be applied first.
    self.perception = vgg19

  def __call__(self, outputs, inputs):
    """Return the sum of absolute feature differences as a scalar tensor."""
    diff = self.perception(outputs) - self.perception(inputs)
    # With no axis given, tf.norm treats the tensor as one flat vector.
    return tf.norm(diff, ord=1)

In [None]:
class GeneratorLoss:
  """Binary cross-entropy term plus an `omega`-weighted content loss.

  Args:
    omega: Weight applied to the content loss.
  """

  def __init__(self, omega=10):
    self.omega = omega
    self.content_loss = ContentLoss()
    self.base_loss = BCEWithLogitsLoss()

  def __call__(self, outputs, inputs):
    """Return `bce(inputs, outputs) + omega * content(outputs, inputs)`.

    Args:
      outputs: Predictions being optimised.
      inputs: Targets the predictions are compared against.
    """
    # Keras losses take (y_true, y_pred): the target goes first and the
    # prediction (treated as logits) second.
    # NOTE(review): the same tensors feed both terms; an adversarial term
    # usually compares discriminator logits with "real" labels. Confirm
    # against the training loop.
    adversarial = self.base_loss(inputs, outputs)
    return adversarial + self.omega * self.content_loss(outputs, inputs)