<a href="https://colab.research.google.com/github/DavideEva/Moviefy/blob/main/Moviefy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Problem analysis

# Download datasets

In [None]:
from glob import glob
# Local folder that will receive the preprocessed Flickr30k photos
flickr30k_folder = 'Flickr30k-images-preprocessed'
# Delete the `sample_data` folder from the working directory
! rm -rf sample_data

In [None]:
# Download the archive from Google Drive (by file id), extract it into
# `flickr30k_folder`, then delete the zip file
! gdown --id "10c0Xruu2wAE-FpQEIlXm17HlwQAJKKVM"
! mkdir "$flickr30k_folder"
! unzip -qo flickr30k-images-preprocessed.zip -d "$flickr30k_folder"
! rm flickr30k-images-preprocessed.zip

In [None]:
# Studio name -> Google Drive file id of the archive with its frames
films_data = {
    "Ghibli": "1bR_BE-ZZSXW1URBJqPIMFo9VLODYb4_k"
}
# Position of the Ghibli dataset inside the `folders` list built below
ghibli_index = 0
# Sub-folder names inside each studio folder (see "Folder structure")
real_folder = 'real'
smooth_folder = 'smooth'

def load_data(id, name):
  """Download a Google Drive file into folder `name` and unzip it there.

  Args:
    id: Google Drive file id passed to `gdown`.
    name: destination folder (created if missing).

  Returns:
    `name`, unchanged.
  """
  ! mkdir -p "$name"
  ! cd "$name" && gdown --id "$id"
  # Extract every downloaded archive in place and delete it afterwards
  zip_files = glob(f'{name}/*.zip')
  for zip_file in zip_files:
    ! unzip -qo "$zip_file" -d "$name"
    ! rm "$zip_file"
  return name

# Download every studio dataset; `folders` keeps the insertion order of `films_data`
folders = [load_data(id_drive, studio_name) for studio_name, id_drive in films_data.items()]
print(folders)

Folder structure:
```bash
.
├── Studio_Name
|   ├── real
|   |   └── Movie_Name_1
|   |       ├── Scene-1
|   |       |   ├── left
|   |       |   |   ├── 0.jpg
|   |       |   |   ...
|   |       |   |   └──
|   |       |   └── right
|   |       |       ├── 0.jpg
|   |       |       ...
|   |       |       └──
|   |       ├── Scene-2
|   |       ...
|   |       └──
|   |   
|   └── smooth
|       └── Movie_Name_2
|           ├── Scene-1
|           |   ├── left
|           |   └── right
|           ├── Scene-2
|           ...
|           └──
├── Studio_Name_2
...
└──
```

# Import

In [None]:
import numpy as np
import math
import random
import cv2
import psutil
import pickle
import shutil
import os
from os import path
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
%matplotlib inline
import tensorflow as tf
from tensorflow.keras.layers import Layer, InputSpec, LeakyReLU, Input, Conv2D, Activation, Concatenate, Conv2DTranspose, BatchNormalization, AveragePooling2D, Add
from tensorflow import pad
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.utils import plot_model
from tensorflow.keras import Sequential
from tensorflow.keras.initializers import Constant
from tensorflow.keras.preprocessing.image import ImageDataGenerator
! pip install tensorflow-addons
import tensorflow_addons as tfa

# Plot functions

In [None]:
def plot_grid(images, columns, show_axis=False, labels=None):
  """Show a sequence of images on a grid with `columns` columns.

  Args:
    images: sequence of array images (H x W, H x W x 1 or H x W x C).
      Floating point images are expected in [0, 1] and are converted to
      uint8 before being displayed.
    columns: number of columns of the grid; nothing is drawn if <= 0.
    show_axis: when False the axes of every subplot are hidden.
    labels: optional titles, cycled if shorter than `images`; when omitted
      each subplot is titled with its 1-based index.
  """
  if len(images) == 0 or columns <= 0:
    return
  scale = 2
  rows = math.ceil(len(images) / columns)
  height = (1 + rows * 2) * scale
  width = (columns * 4) * scale
  dpi = max(images[0].shape[0], images[0].shape[1]) // 2
  fig = plt.figure(figsize=(width, height), dpi=dpi)
  fig.subplots_adjust(hspace=0.4)
  for index, img in enumerate(images, start=1):
    img = np.asarray(img)
    # `dtype.str` looks like '<f4', so testing for the substring 'float' never
    # matched; check the dtype kind and clip so out-of-range values cannot wrap.
    if np.issubdtype(img.dtype, np.floating):
      img = np.clip(img * 255 + 0.5, 0, 255).astype(np.uint8)
    subplot = fig.add_subplot(rows, columns, index)
    if not show_axis:
      plt.axis('off')
    is_grayscale = img.ndim == 2 or (img.ndim > 2 and img.shape[2] == 1)
    if is_grayscale:
      plt.imshow(img, cmap='gray')
    else:
      plt.imshow(img)
    if labels is not None:
      subplot.set_title(labels[(index - 1) % len(labels)], fontsize=10)
    else:
      subplot.set_title(index, fontsize=10)
  plt.show()

def float_to_int_images(outputs):
  """Convert float images in [0, 1] to uint8 images in [0, 255].

  Values are scaled by 255, offset by 0.5, clipped to [0, 255] and then
  truncated by the uint8 cast.
  """
  converted = []
  for output in outputs:
    scaled = np.clip(output * 255 + 0.5, 0, 255)
    converted.append(scaled.astype(np.uint8))
  return converted

# Global parameters

In [None]:
# Image shape (height, width, channels) produced by the preprocessing stage.
# Must match the shape expected by the networks and by the loss functions.
input_shape = (224, 224, 3)

# Batch size used for training and fetching images
batch_size = 16

# Fraction of each dataset reserved for validation; it is passed as
# `validation_split` to the ImageDataGenerator settings below.
validation_split = 0.2

# NOTE(review): not referenced by the visible training code, which calls
# `model.train(100)` with a literal epoch count.
epochs_count = 10


# Dataset loading and preprocessing

In [None]:
def lambda_generator(batches, λ=lambda x: x):
  """Lazily apply `λ` to every sample of every batch.

  Args:
    batches: iterable of batches. A batch is either a sequence of samples
      or a `(samples, labels)` tuple (any tuple subclass is accepted).
    λ: function applied to each sample; identity by default.

  Yields:
    A list of transformed samples, or `(transformed_samples, labels)` when
    the batch carried labels (labels are passed through untouched).
  """
  for batch in batches:
    # isinstance also recognises tuple subclasses such as namedtuples
    if isinstance(batch, tuple):
      samples, labels = batch
      yield [λ(sample) for sample in samples], labels
    else:
      yield [λ(sample) for sample in batch]

def random_merge_generator(it_1, it_2, p=0.5):
  """Randomly interleave two iterators until both are exhausted.

  At every step an item is drawn from `it_1` with probability `p`,
  otherwise from `it_2`. As soon as the chosen iterator is exhausted the
  remaining items of the other one are yielded in order and the generator
  terminates.

  Args:
    it_1: first iterator (must support `next`).
    it_2: second iterator (must support `next`).
    p: probability of drawing from `it_1`.
  """
  while True:
    chosen, other = (it_1, it_2) if np.random.random() < p else (it_2, it_1)
    try:
      item = next(chosen)
    except StopIteration:
      # Drain the other iterator and stop cleanly. Letting StopIteration
      # escape from `next(other)` inside a generator raises RuntimeError
      # (PEP 479).
      yield from other
      return
    yield item

In [None]:
# Per-channel (RGB) mean and standard deviation used to standardise images
# that have already been scaled to [0, 1].
# `np.asfarray` was removed in NumPy 2.0; `np.asarray(..., dtype=np.float64)`
# produces the same float64 arrays on every NumPy version.
norm_mean = np.asarray([0.485, 0.456, 0.406], dtype=np.float64)
norm_std = np.asarray([0.229, 0.224, 0.225], dtype=np.float64)

def normalize(img):
  """Standardise an image in [0, 1]: subtract `norm_mean`, divide by `norm_std`."""
  centered = img - norm_mean
  return centered / norm_std
def unnormalize(img):
  """Invert `normalize` and clip the result to [0, 1] (returns a TF tensor)."""
  restored = img * norm_std + norm_mean
  return tf.clip_by_value(restored, 0.0, 1.0)
def rescale_and_normalize(img):
  """Scale an image from [0, 255] to [0, 1], then standardise it with `normalize`."""
  unit_range = img / 255.0
  return normalize(unit_range)

def generated_to_images(outputs):
  """Turn normalised images into a list of NumPy images with values in [0, 1]."""
  images = []
  for output in outputs:
    images.append(unnormalize(output).numpy())
  return images

def test():
  """Check that `unnormalize` inverts `normalize` for values in [0, 1]."""
  # `np.asfarray` was removed in NumPy 2.0; build the float array explicitly.
  a = np.asarray([[[1.0, 0.5, 0.5], [0.0, 0.1, 0.9]], [[0.5, 0.6, 0.7], [1.0, 0.1, 0.2]]], dtype=np.float64)
  b = normalize(a)
  c = unnormalize(b)
  assert np.linalg.norm(c - a) < 0.00001
test()

In [None]:
# Keyword arguments shared by every ImageDataGenerator built below
data_generator_settings = {
    'data_format' : 'channels_last',
    'validation_split' : validation_split,
    'preprocessing_function' : rescale_and_normalize,
    # 'rescale' is left disabled: rescale_and_normalize already divides by 255
    #'rescale' : 1.0 / 255,
    'horizontal_flip' : True
}

# Keyword arguments shared by every flow_from_directory call below
data_flow_settings = {
    'color_mode' : 'rgb',
    'batch_size' : batch_size,
    'shuffle' : True,
    'seed' : 42,
    # class_mode None: batches contain images only, without labels
    'class_mode' : None,
    'interpolation' : 'bilinear',
    'target_size' : (input_shape[0], input_shape[1])
}

def cartoon_real_generator(subset='training'):
  """Build a directory iterator over the original (non-smoothed) Ghibli frames.

  Args:
    subset: 'training' or 'validation', forwarded to `flow_from_directory`.
  """
  image_generator = ImageDataGenerator(**data_generator_settings)
  # Ghibli cartoon frames
  frames_directory = path.join(folders[ghibli_index], real_folder)
  return image_generator.flow_from_directory(
      directory=frames_directory,
      subset=subset,
      **data_flow_settings)

def cartoon_real_validation_generator():
  """Build the validation-split iterator over the original Ghibli frames."""
  return cartoon_real_generator(subset='validation')

def cartoon_smooth_generator(subset='training'):
  """Build a directory iterator over the edge-smoothed Ghibli frames.

  Args:
    subset: 'training' or 'validation', forwarded to `flow_from_directory`.
  """
  image_generator = ImageDataGenerator(**data_generator_settings)
  # Ghibli cartoon frames (smoothed version)
  frames_directory = path.join(folders[ghibli_index], smooth_folder)
  return image_generator.flow_from_directory(
      directory=frames_directory,
      subset=subset,
      **data_flow_settings)
  
def cartoon_smooth_validation_generator():
  """Build the validation-split iterator over the smoothed Ghibli frames."""
  return cartoon_smooth_generator(subset='validation')

def real_generator(subset='training'):
  """Build a directory iterator over the Flickr30k photos.

  Args:
    subset: 'training' or 'validation', forwarded to `flow_from_directory`.
  """
  image_generator = ImageDataGenerator(**data_generator_settings)
  # Flickr30k images
  return image_generator.flow_from_directory(
      directory=flickr30k_folder,
      subset=subset,
      **data_flow_settings)

def real_validation_generator():
  """Build the validation-split iterator over the Flickr30k photos."""
  return real_generator(subset='validation')

In [None]:
# Number of batches in one pass over the cartoon training subset
gen = cartoon_real_generator()
batches_per_epoch = len(gen)
print("Batches per epoch:", batches_per_epoch)
del gen

In [None]:
# Preview one batch of original cartoon frames
test_cartoon_real_flow = cartoon_real_generator()
plot_grid(generated_to_images(next(test_cartoon_real_flow)), 4)
del test_cartoon_real_flow

In [None]:
# Preview one batch of smoothed cartoon frames
test_cartoon_edge_fake_flow = cartoon_smooth_generator()
plot_grid(generated_to_images(next(test_cartoon_edge_fake_flow)), 4)
del test_cartoon_edge_fake_flow

In [None]:
# Preview one batch of Flickr30k photos
test_real_flow = real_generator()
plot_grid(generated_to_images(next(test_real_flow)), 4)
del test_real_flow

# Cartoon-GAN

## Custom Convolutional Layers

In [None]:
class ReflectionPadding2D(Layer):
  """Pads the two spatial axes of a channels-last 4D tensor by reflection.

  Args:
    padding: pair `(w_pad, h_pad)`; `w_pad` is added on both sides of the
      width axis (axis 2) and `h_pad` on both sides of the height axis
      (axis 1).
    **kwargs: forwarded to `Layer`.
  """
  def __init__(self, padding=(1, 1), **kwargs):
    self.padding = tuple(padding)
    # self.input_spec = [InputSpec(ndim=4)]
    super(ReflectionPadding2D, self).__init__(**kwargs)

  def compute_output_shape(self, s):
    """Return the padded shape, keeping unknown (None) dimensions unknown."""
    # Use the same (w_pad, h_pad) convention as `call`, so that asymmetric
    # paddings report the shape that is actually produced.
    w_pad, h_pad = self.padding
    height = None if s[1] is None else s[1] + 2 * h_pad
    width = None if s[2] is None else s[2] + 2 * w_pad
    return (s[0], height, width, s[3])

  def call(self, x, mask=None):
    w_pad, h_pad = self.padding
    return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

  def get_config(self):
    """Return the layer config, including `padding`, so it can be rebuilt."""
    config = super(ReflectionPadding2D, self).get_config()
    config['padding'] = tuple(self.padding)
    return config

In [None]:
class Conv2DReflection3x3(Layer):
  """3x3 bias-free convolution preceded by a one-pixel reflection padding.

  Args:
    features: number of output filters.
    stride: stride applied on both spatial axes.
  """
  def __init__(self, features, stride=1):
    super().__init__()
    self.reflectionPadding2D = ReflectionPadding2D()
    self.conv2d = Conv2D(
        features,
        (3,3),
        strides=(stride, stride),
        padding='valid',
        use_bias=False)

  def call(self, inputs, training=False):
    padded = self.reflectionPadding2D(inputs, training=training)
    return self.conv2d(padded, training=training)

## Discriminator
Based on the Cartoon-GAN discriminator, available at [this link](https://github.com/FilipAndersson245/cartoon-gan/blob/master/models/discriminator.py).

In [None]:
# define the discriminator model
def define_discriminator(image_shape):
  """Build the patch discriminator.

  The network contains two stride-2 convolutions, so the output is a
  single-channel map of per-patch probabilities (sigmoid) at a quarter of
  the input resolution.

  Args:
    image_shape: `(height, width, channels)` of the input images.

  Returns:
    A Keras `Model` named 'Discriminator'.
  """
  alpha = 0.2
  epsilon = 1e-5
  # Keras updates moving statistics as `moving * momentum + batch * (1 - momentum)`,
  # the opposite convention of PyTorch. The previous value 0.1 (PyTorch's
  # default) made the moving averages follow almost only the last batch;
  # 0.9 is the Keras equivalent of PyTorch's 0.1.
  momentum = 0.9

  # source image input
  in_image = Input(shape=image_shape)

  # k3n32s1
  d = Conv2DReflection3x3(32, stride=1)(in_image)
  d = LeakyReLU(alpha=alpha)(d)

  # k3n64s2
  d = Conv2DReflection3x3(64, stride=2)(d)
  d = LeakyReLU(alpha=alpha)(d)
  # k3n128s1
  d = Conv2DReflection3x3(128, stride=1)(d)
  d = BatchNormalization(epsilon=epsilon, momentum=momentum)(d)
  d = LeakyReLU(alpha=alpha)(d)

  # k3n128s2
  d = Conv2DReflection3x3(128, stride=2)(d)
  d = LeakyReLU(alpha=alpha)(d)
  # k3n256s1
  d = Conv2DReflection3x3(256, stride=1)(d)
  d = BatchNormalization(epsilon=epsilon, momentum=momentum)(d)
  d = LeakyReLU(alpha=alpha)(d)

  # feature construction block
  # k3n256s1
  d = Conv2DReflection3x3(256, stride=1)(d)
  d = BatchNormalization(epsilon=epsilon, momentum=momentum)(d)
  d = LeakyReLU(alpha=alpha)(d)

  # patch output
  d = Conv2DReflection3x3(1, stride=1)(d)
  patch_out = tf.keras.activations.sigmoid(d)

  # define model
  model = Model(in_image, patch_out, name='Discriminator')
  return model

D = define_discriminator(input_shape)

In [None]:
# Print the discriminator layers; the graphical plot is left disabled
#plot_model(D, show_shapes=True, expand_nested=True)
D.summary()

In [None]:
# Smoke test: run the untrained discriminator on Gaussian noise and show
# the resulting patch map
noise = tf.random.normal([1, *input_shape])
label_image = D(noise, training=False)

plt.imshow(label_image[0, :, :, 0], cmap='gray')
plt.show()

## Generator
Based on the Cartoon-GAN generator, available at [this link](https://github.com/FilipAndersson245/cartoon-gan/blob/master/models/generator.py).

In [None]:
# define the generator model
def define_generator(image_shape):
  """Build the generator: flat block, 2 down blocks, 8 residual blocks,
  2 up blocks and a final 7x7 convolution producing 3 channels.

  Args:
    image_shape: `(height, width, channels)` of the input images.

  Returns:
    A Keras `Model` named 'Generator' whose output has the same spatial
    size as its input.
  """
  alpha = 0.2
  epsilon = 1e-5
  # Keras updates moving statistics as `moving * momentum + batch * (1 - momentum)`,
  # the opposite convention of PyTorch. The previous value 0.1 (PyTorch's
  # default) made the moving averages follow almost only the last batch,
  # which matters because the generator is also run with training=False;
  # 0.9 is the Keras equivalent of PyTorch's 0.1.
  momentum = 0.9

  # source image input
  in_image = Input(shape=image_shape)

  # flat block
  # k7n64s1
  g = Conv2D(64, (7,7), strides=1, padding='same', use_bias=False)(in_image)
  g = BatchNormalization(epsilon=epsilon, momentum=momentum)(g)
  g = LeakyReLU(alpha=alpha)(g)

  def down_block(x, n_features):
    """Halve the spatial resolution and output `n_features` channels."""
    # k3n?s2
    x = Conv2DReflection3x3(n_features, stride=2)(x)
    # k3n?s1
    x = Conv2DReflection3x3(n_features, stride=1)(x)
    x = BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    x = LeakyReLU(alpha=alpha)(x)
    return x

  # 1st down block
  g = down_block(g, 128)

  # 2nd down block
  g = down_block(g, 256)

  def residual_block(x):
    """Two 3x3 convolutions with a skip connection (256 channels)."""
    skip = x
    # k3n256s1
    x = Conv2DReflection3x3(256, stride=1)(x)
    x = BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    x = LeakyReLU(alpha=alpha)(x)
    # k3n256s1
    x = Conv2DReflection3x3(256, stride=1)(x)
    x = BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    x = Add()([x, skip])
    x = LeakyReLU(alpha=alpha)(x)
    return x

  for _ in range(8):
    g = residual_block(g)

  def up_block(x, n_features):
    """Double the spatial resolution and output `n_features` channels."""
    # k3n?s1/2
    # The 'valid' transposed convolution yields 2n+1 pixels; the 2x2 average
    # pooling with stride 1 brings the size back to exactly 2n.
    x = Conv2DTranspose(n_features, (3,3), strides=2)(x)
    x = AveragePooling2D(pool_size=(2,2), strides=1)(x)
    # k3n?s1
    x = Conv2DReflection3x3(n_features, stride=1)(x)
    x = BatchNormalization(epsilon=epsilon, momentum=momentum)(x)
    x = LeakyReLU(alpha=alpha)(x)
    return x

  # 1st up block
  g = up_block(g, 128)

  # 2nd up-block
  g = up_block(g, 64)

  # k7n3s1
  output = Conv2D(3, (7,7), strides=1, padding='same')(g)

  # define model
  model = Model(in_image, output, name='Generator')
  return model


G = define_generator(input_shape)

In [None]:
# Print the generator layers; the graphical plot is left disabled
#plot_model(G, show_shapes=True, expand_nested=True)
G.summary()

In [None]:
# Smoke test: run the untrained generator on Gaussian noise and show the output
noise = tf.random.normal([1, *input_shape])
generated_image = G(noise, training=False)
plot_grid(generated_to_images(generated_image), 1)
plt.show()

## Loss functions

### Binary Cross Entropy

In [None]:
def BCEWithLogitsLoss():
  """Return a binary cross-entropy loss that takes logits as predictions."""
  reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
  return tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=reduction)

def BCELoss():
  """Return a binary cross-entropy loss that takes probabilities as predictions."""
  reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
  return tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction=reduction)

def test():
  """Sanity-check both cross-entropy factories on constant tensors."""
  targets = tf.ones((10, 64))
  predictions = tf.fill((10, 64), 1.5)

  # Predictions interpreted as logits
  with_logits = BCEWithLogitsLoss()
  assert abs(with_logits(targets, predictions).numpy() - 0.2014133) <= 1e-8

  # Predictions interpreted as probabilities
  with_probabilities = BCELoss()
  assert abs(with_probabilities(targets, predictions).numpy() - 0.0) <= 1e-8
test()

### Adversarial Loss
Also called Discriminator loss.

In [None]:
class AdversarialLoss:
  """Discriminator loss: sum of three binary cross-entropy terms.

  Real cartoons must be classified with `cartoon_labels`; generated images
  and edge-smoothed cartoons must be classified with `fake_cartoon_labels`.

  Args:
    cartoon_labels: target patch map for one real cartoon (no batch axis).
    fake_cartoon_labels: target patch map for one fake image (no batch axis).
  """
  def __init__(self, cartoon_labels, fake_cartoon_labels):
    self.base_loss = BCELoss()
    self.cartoon_labels = cartoon_labels
    self.fake_cartoon_labels = fake_cartoon_labels

  def __call__(self, cartoons_outputs, generated_fakes_outputs, cartoon_edge_fakes_outputs):
    """Return the total discriminator loss for one batch of predictions.

    Args:
      cartoons_outputs: discriminator output for real cartoons.
      generated_fakes_outputs: discriminator output for generated images.
      cartoon_edge_fakes_outputs: discriminator output for smoothed cartoons.
    """
    batch_size = len(cartoons_outputs)
    # Repeat the per-image label maps along a new batch axis
    cartoon_labels = tf.stack([self.cartoon_labels] * batch_size)
    fake_cartoon_labels = tf.stack([self.fake_cartoon_labels] * batch_size)
    # Keras losses are called as loss(y_true, y_pred). The labels were
    # previously passed as predictions, which made every term linear in the
    # discriminator output instead of a real cross-entropy.
    D_cartoon_loss = self.base_loss(cartoon_labels, cartoons_outputs)
    D_generated_fake_loss = self.base_loss(fake_cartoon_labels, generated_fakes_outputs)
    D_edge_fake_loss = self.base_loss(fake_cartoon_labels, cartoon_edge_fakes_outputs)

    return D_cartoon_loss + D_generated_fake_loss + D_edge_fake_loss

# alias for clarity
DiscriminatorLoss = AdversarialLoss

def test():
  """Check that the adversarial loss decreases as predictions improve."""
  loss = AdversarialLoss(np.ones((56, 56)), np.zeros((56, 56)))
  shape = (10, 56, 56)

  def evaluate(cartoon_value, generated_value, edge_value):
    # Constant prediction maps for the three kinds of discriminator input
    cartoon = tf.fill(shape, cartoon_value)
    gf = tf.fill(shape, generated_value)
    cartoon_edge = tf.fill(shape, edge_value)
    return loss(cartoon, gf, cartoon_edge)

  l1 = evaluate(0.6, 0.4, 0.3)
  l2 = evaluate(0.9, 0.3, 0.2)
  l3 = evaluate(0.99, 0.01, 0.01)
  assert l1 > l2 > l3 > 0
test()

### Content Loss
Used to enforce content fidelity in the generator.

In [None]:
from tensorflow.keras.applications import vgg19

# ImageNet-pretrained VGG19 without the classification head, used as a
# fixed feature extractor by the content loss
vgg19_model = vgg19.VGG19(include_top=False, 
                    weights='imagenet', 
                    input_shape=input_shape)

# Freeze the whole network and every layer: it must never be trained
vgg19_model.trainable = False
for l in vgg19_model.layers:
  l.trainable = False

class ContentLoss:
  """L1 norm of the difference between the VGG19 features of two image batches."""
  def __init__(self):
    # Callable mapping normalised images to VGG19 features
    self.perception = self._vgg_features

  @staticmethod
  def _vgg_features(img):
    # Back to [0, 255] pixels before the VGG19-specific preprocessing
    pixels = unnormalize(img) * 255.0
    return vgg19_model(vgg19.preprocess_input(pixels), training=False)

  def __call__(self, outputs, targets):
    feature_gap = self.perception(outputs) - self.perception(targets)
    return tf.norm(feature_gap, ord=1)

def test():
  """Content loss is zero for identical batches and positive otherwise."""
  loss = ContentLoss()
  shape = (10, 224, 224, 3)
  outputs = tf.fill(shape, 0.0)
  assert loss(outputs, outputs) == 0.0
  inputs = tf.fill(shape, 1.0)
  assert loss(outputs, inputs) > 0.0
test()

### Generator Loss
Enforces both fooling the discriminator and content fidelity to the original image.

In [None]:
class GeneratorLoss:
  """Generator loss: adversarial term plus `omega` times the content loss.

  Args:
    cartoon_labels: target patch map the discriminator should output for a
      convincing generated image (no batch axis).
    omega: weight of the content loss.
  """
  def __init__(self, cartoon_labels, omega=10):
    self.omega = tf.constant(omega, dtype=tf.float32)
    self.content_loss = ContentLoss()
    self.base_loss = BCELoss()
    self.cartoon_labels = cartoon_labels
  
  def __call__(self, outputs, inputs, outputs_labels):
    """Return the generator loss for one batch.

    Args:
      outputs: images produced by the generator.
      inputs: source images given to the generator.
      outputs_labels: discriminator output for `outputs`.
    """
    batch_size = len(outputs)
    # Repeat the per-image label map along a new batch axis
    cartoon_labels = tf.stack([self.cartoon_labels] * batch_size)
    # Keras losses are called as loss(y_true, y_pred). The labels were
    # previously passed as predictions, which made the adversarial term
    # linear in the discriminator output instead of a real cross-entropy.
    adversarial_term = self.base_loss(cartoon_labels, outputs_labels)
    content_term = self.content_loss(outputs, inputs)
    return adversarial_term + self.omega * content_term

def test():
  """Check the generator loss on three constant scenarios."""
  label_shape = (10, 56, 56)
  image_shape = (10, 224, 224, 3)
  black = tf.fill(image_shape, 0.0)

  # Discriminator fully fooled and identical content: no loss
  loss = GeneratorLoss(tf.ones((56, 56)), omega=100)
  outputs_labels = tf.fill(label_shape, 1.0)
  assert loss(black, black, outputs_labels) == 0.0

  # Discriminator undecided: the adversarial term is positive
  loss = GeneratorLoss(tf.ones((56, 56)), omega=10)
  outputs_labels = tf.fill(label_shape, 0.5)
  assert loss(black, black, outputs_labels) > 0.0

  # Discriminator fooled but different content: the content term is positive
  loss = GeneratorLoss(tf.ones((56, 56)), omega=100)
  outputs_labels = tf.fill(label_shape, 1.0)
  inputs = tf.fill(image_shape, 1.0)
  assert loss(black, inputs, outputs_labels) > 0.0
test()

## Training

In [None]:
class InputIterator(object):
  """Endless iterator yielding fixed-size batches from in-memory arrays.

  Indices are consumed one pass after another; a new (optionally shuffled)
  pass is appended whenever fewer than `batch_size` indices are left, so a
  batch may span two passes.

  Args:
    inputs: an array, or a list of arrays sharing the same first dimension.
    batch_size: number of samples per batch.
    shuffle: whether each pass over the data is shuffled.
    seed: seed of the private random generator used for shuffling.
  """
  def __init__(self, inputs, batch_size=64, shuffle=True, seed=None):
    self._inputs = inputs
    self._inputs_list = isinstance(inputs, list)
    self._N = self._inputs[0].shape[0] if self._inputs_list else self._inputs.shape[0]
    self.batch_size = batch_size
    self._shuffle = shuffle
    self._prng = np.random.RandomState(seed=seed)
    self._next_indices = np.array([], dtype=np.uint)

  def __iter__(self):
    return self

  def __next__(self):
    """Return the next batch (a list of arrays if a list was given).

    Raises:
      ValueError: if the inputs contain no samples.
    """
    # Without samples the refill loop below could never gather a batch and
    # would spin forever.
    if self._N == 0 and len(self._next_indices) < self.batch_size:
      raise ValueError('InputIterator cannot draw batches from empty inputs')

    while len(self._next_indices) < self.batch_size:
      next_ind = np.arange(self._N, dtype=np.uint)
      if self._shuffle:
        self._prng.shuffle(next_ind)
      self._next_indices = np.concatenate((self._next_indices, next_ind))

    ind = self._next_indices[:self.batch_size]
    self._next_indices = self._next_indices[self.batch_size:]

    if self._inputs_list:
      return [inp[ind, ...] for inp in self._inputs]
    return self._inputs[ind, ...]

In [None]:
import warnings

from keras.applications.inception_v3 import InceptionV3
from keras import backend as K
import numpy as np


def update_mean_cov(mean, cov, N, batch):
  """Update a running mean and covariance with a new batch of samples.

  Arguments:
    mean, cov: statistics accumulated over the first N samples.
    N: number of samples seen so far.
    batch: 2D array with one sample per row.
  Returns:
    The tuple (mean, cov, N) including the samples of the batch.
  """
  batch_count = batch.shape[0]
  total = N + batch_count

  # Deviations from the mean before and after the update
  old_deviation = batch - mean
  new_mean = mean + old_deviation.sum(axis=0) / total
  new_deviation = batch - new_mean

  weight_old = (total - batch_count) / total
  new_cov = weight_old * cov + old_deviation.T.dot(new_deviation) / total

  return (new_mean, new_cov, total)


def frechet_distance(mean1, cov1, mean2, cov2):
  """Compute the Frechet distance between two multivariate Gaussians.

  Arguments:
    mean1, cov1, mean2, cov2: The means and covariances of the two
      multivariate Gaussians.
  Returns:
    The Frechet distance between the two distributions.
  """

  def clamp_negative_eigvals(eigvals):
    # Negative eigenvalues are zeroed in place, with a warning
    negative = eigvals < 0
    if negative.any():
      warnings.warn('Rank deficient covariance matrix, '
        'Frechet distance will not be accurate.', Warning)
    eigvals[negative] = 0

  # Square root of cov1 through its eigendecomposition
  eigvals1, eigvecs1 = np.linalg.eigh(cov1)
  clamp_negative_eigvals(eigvals1)
  sqrt_cov1 = (eigvecs1 * np.sqrt(eigvals1)).dot(eigvecs1.T)

  # Eigenvalues of sqrt(cov1) . cov2 . sqrt(cov1)
  product_eigvals = np.linalg.eigvalsh(sqrt_cov1.dot(cov2).dot(sqrt_cov1))
  clamp_negative_eigvals(product_eigvals)

  trace_term = eigvals1.sum() + np.trace(cov2) - 2 * np.sqrt(product_eigvals).sum()
  mean_gap = mean1 - mean2
  return mean_gap.dot(mean_gap) + trace_term


class FrechetInceptionDistance(object):
  """Frechet Inception Distance.

  Class for evaluating Keras-based GAN generators using the Frechet
  Inception Distance (Heusel et al. 2017,
  https://arxiv.org/abs/1706.08500).
  Arguments to constructor:
    generator: a Keras model trained as a GAN generator
    image_range: A tuple giving the range of values in the images output
      by the generator. This is used to rescale to the (-1,1) range
      expected by the Inception V3 network.
    generator_postprocessing: A function, preserving the shape of the
      output, to be applied to all generator outputs for further
      postprocessing. If None (default), no postprocessing will be
      done.
  Attributes: The arguments above all have a corresponding attribute
    with the same name that can be safely changed after initialization.
  Arguments to call:
    real_images: A 4D NumPy array of images from the training dataset,
      or a Python generator outputting training batches. The number of
      channels must be either 3 or 1 (in the latter case, the single
      channel is distributed to each of the 3 channels expected by the
      Inception network).
    generator_inputs: One of the following:
      1. A NumPy array with generator inputs, or
      2. A list of NumPy arrays (if the generator has multiple inputs)
      3. A Python generator outputting batches of generator inputs
        (either a single array or a list of arrays)
    batch_size: The size of the batches in which the data is processed.
      No effect if Python generators are passed as real_images or
      generator_inputs.
    num_batches_real: Number of batches to use to evaluate the mean and
      the covariance of the real samples.
    num_batches_gen: Number of batches to use to evaluate the mean and
      the covariance of the generated samples. If None (default), set
      equal to num_batches_real.
    shuffle: If True (default), samples are randomly selected from the
      input arrays. No effect if real_images or generator_inputs is
      a Python generator.
    seed: A random seed for shuffle (to provide reproducible results)
  Returns (call):
    The Frechet Inception Distance between the real and generated data.
  """
  def __init__(self, generator, image_range,
    generator_postprocessing=None):

    # The Inception network is created lazily on the first call
    self._inception_v3 = None
    self.generator = generator
    self.generator_postprocessing = generator_postprocessing
    self.image_range = image_range
    self._channels_axis = \
      -1 if K.image_data_format()=="channels_last" else -3

  def _setup_inception_network(self):
    """Create the Inception V3 feature extractor (global average pooling)."""
    self._inception_v3 = InceptionV3(
      include_top=False, pooling='avg')
    self._pool_size = self._inception_v3.output_shape[-1]

  def _preprocess(self, images):
    """Rescale images from `image_range` to (-1,1) and expand 1 channel to 3."""
    images = np.asarray(images)
    # Integer images (e.g. uint8) cannot be divided in place and would wrap
    # around on subtraction, so work on a floating point copy.
    if not np.issubdtype(images.dtype, np.floating):
      images = images.astype(np.float32)
    if tuple(self.image_range) != (-1,1):
      low, high = self.image_range
      images = (images - low) / ((high - low) / 2.0) - 1.0
    if images.shape[self._channels_axis] == 1:
      images = np.concatenate([images]*3, axis=self._channels_axis)
    return images

  def _stats(self, inputs, input_type="real", postprocessing=None,
    batch_size=64, num_batches=128, shuffle=True, seed=None):
    """Return the mean and covariance of the Inception features of `inputs`.

    When `input_type` is "generated" every batch is first passed through
    the generator.
    """
    mean = np.zeros(self._pool_size)
    cov = np.zeros((self._pool_size,self._pool_size))
    N = 0

    # Anything that is not already an iterator is assumed to be an array
    # or a list of arrays. Checking up front, instead of catching TypeError
    # around next(), keeps TypeErrors raised inside a user generator visible.
    if not hasattr(inputs, '__next__'):
      inputs = InputIterator(inputs,
        batch_size=batch_size, shuffle=shuffle, seed=seed)

    for _ in range(num_batches):
      batch = next(inputs)

      if input_type=="generated":
        batch = self.generator.predict(batch)
      if postprocessing is not None:
        batch = postprocessing(batch)
      batch = self._preprocess(batch)
      pool = self._inception_v3.predict(batch, batch_size=batch_size)

      (mean, cov, N) = update_mean_cov(mean, cov, N, pool)

    return (mean, cov)

  def __call__(self,
      real_images,
      generator_inputs,
      batch_size=64,
      num_batches_real=128,
      num_batches_gen=None,
      shuffle=True,
      seed=None
    ):

    if self._inception_v3 is None:
      self._setup_inception_network()

    (real_mean, real_cov) = self._stats(real_images,
      "real", batch_size=batch_size, num_batches=num_batches_real,
      shuffle=shuffle, seed=seed)
    if num_batches_gen is None:
      num_batches_gen = num_batches_real
    (gen_mean, gen_cov) = self._stats(generator_inputs,
      "generated", batch_size=batch_size, num_batches=num_batches_gen,
      postprocessing=self.generator_postprocessing,
      shuffle=shuffle, seed=seed)

    return frechet_distance(real_mean, real_cov, gen_mean, gen_cov)

In [None]:
# fd = FrechetInceptionDistance(G, (0, 1))

In [None]:
# cartoon = next(crg)
# smooth = next(csg)
# fd(cartoon_real_validation_generator(), cartoon_real_validation_generator(), shuffle=False, num_batches_real=2)

In [None]:
# Optimizer hyper-parameters, shared by the generator and the discriminator
learning_rate = 1.5e-4
beta1, beta2 = (.5, .99)
weight_decay = 1e-4

In [None]:
# One independent AdamW optimizer per network, with identical settings
discriminator_optimizer = tfa.optimizers.AdamW(
    learning_rate=learning_rate, 
    beta_1=beta1, beta_2=beta2,
    weight_decay=weight_decay
)
generator_optimizer = tfa.optimizers.AdamW(
    learning_rate=learning_rate, 
    beta_1=beta1, beta_2=beta2,
    weight_decay=weight_decay
)

## Model

### Checkpoints

In [None]:
# Local directory where the checkpoint manager stores training checkpoints
local_checkpoint_location = './training_checkpoints/'

#### Google Drive checkpoint backup

In [None]:
use_google_drive = False #@param {type:'boolean'}
google_drive_checkpoint_path = 'Anime-Frames/Checkpoints' #@param {type: 'string'}
reset_checkpoints = True #@param {type: 'boolean'}
google_drive_root = '/content/drive/'
google_drive_checkpoint_location = path.join(google_drive_root, 'MyDrive', google_drive_checkpoint_path)

if use_google_drive:
  # Mount Google Drive and start from the checkpoints saved there
  from google.colab import drive
  drive.mount(google_drive_root)
  os.makedirs(google_drive_checkpoint_location, exist_ok=True)
  shutil.rmtree(local_checkpoint_location, ignore_errors=True)
  shutil.copytree(google_drive_checkpoint_location, local_checkpoint_location)
  print(f'Local files at {local_checkpoint_location} will be backed inside {google_drive_checkpoint_location}')
else:
  # Best-effort cleanup of a Drive mount left over from a previous run.
  # `drive` was previously used here without being imported, so the bare
  # `except` silently swallowed a NameError and nothing was cleaned up.
  try:
    from google.colab import drive
  except ImportError:
    drive = None  # not running inside Colab: nothing to unmount
  if drive is not None:
    try:
      drive.flush_and_unmount()
    except Exception as error:
      print(f'Could not unmount Google Drive: {error}')
    # Never delete the folder while Drive is still mounted on it: that
    # would remove the files stored on Google Drive.
    if not path.ismount(google_drive_root):
      shutil.rmtree(google_drive_root, ignore_errors=True)


if reset_checkpoints:
  # Discard local checkpoints so that training starts from scratch
  shutil.rmtree(local_checkpoint_location, ignore_errors=True)

#### Checkpoint manager

In [None]:
# Everything needed to resume training: both networks, both optimizers and
# the number of completed epochs
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=G,
                                 discriminator=D,
                                 epoch=tf.Variable(0))
# Only the 3 most recent checkpoints are kept on disk
checkpoint_manager = tf.train.CheckpointManager(checkpoint, local_checkpoint_location, max_to_keep=3)

### Implementation

In [None]:
class CartoonGAN:
  """Training harness for the Cartoon-GAN generator/discriminator pair.

  The networks and optimizers are taken from `checkpoint`, after restoring
  the latest checkpoint known to `checkpoint_manager` (if any).

  Args:
    checkpoint: `tf.train.Checkpoint` exposing `generator`, `discriminator`,
      `generator_optimizer`, `discriminator_optimizer` and `epoch`.
    checkpoint_manager: `tf.train.CheckpointManager` bound to `checkpoint`.
    cartoon_real_generator: training batches of original cartoon frames.
    cartoon_smooth_generator: training batches of smoothed cartoon frames.
    real_image_generator: training batches of real photos.
    cartoon_real_generator_val: validation batches of original cartoon frames.
    cartoon_smooth_generator_val: validation batches of smoothed cartoon frames.
    real_image_generator_val: validation batches of real photos.
  """
  def __init__(self,
               checkpoint,
               checkpoint_manager,
               cartoon_real_generator,
               cartoon_smooth_generator,
               real_image_generator,
               cartoon_real_generator_val,
               cartoon_smooth_generator_val,
               real_image_generator_val):
    self.name = 'Cartoon-GAN'
    self.checkpoint = checkpoint
    self.checkpoint_manager = checkpoint_manager

    # Checkpoint restore
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    if checkpoint_manager.latest_checkpoint:
      print(f'Restored from {checkpoint_manager.latest_checkpoint}.')
    else:
      print('Initializing from scratch.')
    self.discriminator = checkpoint.discriminator
    self.generator = checkpoint.generator
    self.d_optimizer = checkpoint.discriminator_optimizer
    self.g_optimizer = checkpoint.generator_optimizer

    self.cartoon_real_generator = cartoon_real_generator # cartoon real gen train
    self.cartoon_smooth_generator = cartoon_smooth_generator # cartoon smooth gen train
    self.cartoon_real_generator_val = cartoon_real_generator_val # cartoon real gen val
    self.cartoon_smooth_generator_val = cartoon_smooth_generator_val # cartoon smooth gen val
    self.real_image_generator = real_image_generator # real image gen train
    self.real_image_generator_val = real_image_generator_val # real image gen val

    # The last batch of an epoch may be smaller than the others, so it is skipped
    self.batches_per_epoch = len(self.cartoon_real_generator) - 1
    self.val_batches_per_epoch = len(self.cartoon_real_generator_val) - 1

    # Per-image target maps, shaped like one discriminator output. The
    # shape is read from this instance's discriminator rather than from
    # the global `D`.
    patch_shape = self.discriminator.compute_output_shape((None, *input_shape))[1:]
    self.cartoon_labels = tf.ones((*patch_shape,))
    self.fake_cartoon_labels = tf.zeros((*patch_shape,))

    self.d_loss_fn = DiscriminatorLoss(self.cartoon_labels, self.fake_cartoon_labels)
    self.g_loss_fn = GeneratorLoss(self.cartoon_labels)
    self.bar_format = '{bar}{desc}: {percentage:3.0f}% {r_bar}'

  @tf.function
  def _step(self,
              x_cartoon_batch_train,
              x_cartoon_smooth_batch_train,
              x_real_train):
    """Run one optimisation step of the discriminator, then of the generator.

    Returns:
      The pair `(d_loss, g_loss)` computed during the step.
    """
    # Fakes for the discriminator update are produced outside the tape:
    # the generator is not trained in this phase.
    generated_real_train = self.generator(x_real_train, training=False)

    with tf.GradientTape() as disc_tape:
      # Discriminator loss
      predictions_cartoon_train = self.discriminator(x_cartoon_batch_train, training=True)
      predictions_cartoon_smooth_train = self.discriminator(x_cartoon_smooth_batch_train, training=True)
      predictions_generated_train = self.discriminator(generated_real_train, training=True)
      d_loss = self.d_loss_fn(predictions_cartoon_train, predictions_generated_train, predictions_cartoon_smooth_train)

    gradients_of_discriminator = disc_tape.gradient(d_loss, self.discriminator.trainable_variables)
    self.d_optimizer.apply_gradients(
        zip(gradients_of_discriminator, self.discriminator.trainable_variables)
    )

    with tf.GradientTape() as gen_tape:
      generated_real_train = self.generator(x_real_train, training=True)
      # Generator loss
      generated_real_labels = self.discriminator(generated_real_train, training=False)
      g_loss = self.g_loss_fn(generated_real_train, x_real_train, generated_real_labels)

    gradients_of_generator = gen_tape.gradient(g_loss, self.generator.trainable_variables)
    self.g_optimizer.apply_gradients(
        zip(gradients_of_generator, self.generator.trainable_variables)
    )

    return d_loss, g_loss

  @tf.function
  def _val_step(self,
              x_cartoon_batch_val,
              x_cartoon_smooth_batch_val,
              x_real_val):
    """Compute `(d_loss, g_loss)` on a validation batch without training."""
    generated_real_val = self.generator(x_real_val, training=False)

    # Discriminator loss
    predictions_cartoon_val = self.discriminator(x_cartoon_batch_val, training=False)
    predictions_cartoon_smooth_val = self.discriminator(x_cartoon_smooth_batch_val, training=False)
    predictions_generated_val = self.discriminator(generated_real_val, training=False)
    d_loss = self.d_loss_fn(predictions_cartoon_val, predictions_generated_val, predictions_cartoon_smooth_val)

    # Generator loss: no weights change during validation, so the
    # discriminator predictions computed above are reused.
    g_loss = self.g_loss_fn(generated_real_val, x_real_val, predictions_generated_val)

    return d_loss, g_loss

  def train(self, epochs):
    """Train until `epochs` epochs are completed, resuming from the checkpoint.

    Every epoch runs a training pass, a validation pass and saves a
    checkpoint (mirrored to Google Drive when `use_google_drive` is set).

    Args:
      epochs: total number of epochs to reach, including those already
        recorded in the checkpoint.

    Returns:
      The pair `(d_losses, g_losses)`: the mean validation discriminator and
      generator losses of every epoch that produced validation batches.
    """
    d_losses = []
    g_losses = []
    starting_epoch = int(self.checkpoint.epoch)
    for epoch in tqdm(range(starting_epoch, epochs),
                      total=epochs,
                      initial=starting_epoch,
                      position=0,
                      desc=f'Training model {self.name}',
                      bar_format=self.bar_format):

      # train phase
      epoch_progress = tqdm(range(self.batches_per_epoch),
                            total=self.batches_per_epoch,
                            position=1,
                            desc=f'Training epoch {epoch}',
                            bar_format=self.bar_format)
      # The bounded progress range comes first so that zip stops before
      # pulling an extra batch from the data generators.
      for (step,
          x_cartoon_batch_train,
          x_cartoon_smooth_batch_train,
          x_real_train) in zip(
             epoch_progress,
             self.cartoon_real_generator,
             self.cartoon_smooth_generator,
             self.real_image_generator):
        d_loss, g_loss = self._step(x_cartoon_batch_train,
                                    x_cartoon_smooth_batch_train,
                                    x_real_train)
        epoch_progress.set_postfix({
          'Discriminator Loss' : f'{d_loss.numpy():.4f}',
          'Generator Loss' : f'{g_loss.numpy():.4f}'
        })

      # validation phase
      epoch_val_progress = tqdm(total=self.val_batches_per_epoch,
                                position=1,
                                desc=f'Validation epoch {epoch}',
                                bar_format=self.bar_format)

      d_val_loss_sum, g_val_loss_sum = 0.0, 0.0
      val_batches = 0
      # The bounded range comes first: with the range last, zip fetched (and
      # discarded) one extra batch from every generator at each epoch.
      for _, crgv, csgv, rigv in zip(range(self.val_batches_per_epoch),
                                     self.cartoon_real_generator_val,
                                     self.cartoon_smooth_generator_val,
                                     self.real_image_generator_val):
        d_val_loss, g_val_loss = self._val_step(crgv, csgv, rigv)
        d_val_loss_sum += float(d_val_loss)
        g_val_loss_sum += float(g_val_loss)
        val_batches += 1
        epoch_val_progress.update(1)

      # No validation batch means no average (and no division by zero)
      if val_batches > 0:
        d_val_loss_mean = d_val_loss_sum / val_batches
        g_val_loss_mean = g_val_loss_sum / val_batches
        d_losses.append(d_val_loss_mean)
        g_losses.append(g_val_loss_mean)
        epoch_val_progress.set_postfix({
          'Discriminator Loss' : f'{d_val_loss_mean:.4f}',
          'Generator Loss' : f'{g_val_loss_mean:.4f}'
        })
      epoch_val_progress.close()

      # checkpoint phase
      epoch_checkpoint_progress = tqdm(total=2 if use_google_drive else 1,
                                 position=1,
                                 desc=f'Saving model checkpoints for epoch {epoch}',
                                 bar_format=self.bar_format)
      self.checkpoint.epoch.assign_add(1)
      save_path = self.checkpoint_manager.save()
      epoch_checkpoint_progress.update(1)
      epoch_checkpoint_progress.set_description(f'Saved checkpoint for epoch {epoch}: {save_path}')

      if use_google_drive:
        # Replace the backup with a copy of the local checkpoint directory
        shutil.rmtree(google_drive_checkpoint_location, ignore_errors=True)
        shutil.copytree(self.checkpoint_manager.directory, google_drive_checkpoint_location)
        epoch_checkpoint_progress.set_description(f'Saved checkpoint backup for epoch {epoch}: {save_path} -> {google_drive_checkpoint_location}')
        epoch_checkpoint_progress.update(1)
      epoch_checkpoint_progress.close()

    return d_losses, g_losses

In [None]:
# Build the trainer with fresh training and validation iterators for the
# three datasets (cartoon, smoothed cartoon, real photos)
model = CartoonGAN(checkpoint,
                   checkpoint_manager,
                   cartoon_real_generator(), 
                   cartoon_smooth_generator(), 
                   real_generator(),
                   cartoon_real_validation_generator(),
                   cartoon_smooth_validation_generator(),
                   real_validation_generator())

### Training

In [None]:
# Train until 100 epochs are completed (resumes from the restored epoch)
# d_losses, g_losses = 
model.train(100)

In [None]:
# tf.config.run_functions_eagerly(False)

In [None]:
# a, b, c = next(cartoon_real_generator()), next(cartoon_smooth_generator()), next(real_generator())
# print(type(a))


# graph = model._step.get_concrete_function(a, b, c).graph
# for node in graph.as_graph_def().node:
#   print(f'{node.input} -> {node.name}')

In [None]:
# Inspect the discriminator patch maps on cartoon frames, smoothed frames
# and all-zero images
cartoon_images = next(cartoon_real_generator())
cartoon_smoothed_images = next(cartoon_smooth_generator())
black_images = np.full((2, 224, 224, 3), 0.0)

print('cartoon')
plot_grid(D(cartoon_images).numpy().squeeze(), 4)
print('smoothed')
plot_grid(D(cartoon_smoothed_images).numpy().squeeze(), 4)
print('black')
plot_grid(D(black_images).numpy().squeeze(), 2)

In [None]:
# Compare real photos with their generated cartoon version, plus the
# generator output for all-zero images
real_images = next(real_generator())
black_images = np.full((4, 224, 224, 3), 0.0)

print('real images')
plot_grid(generated_to_images(real_images), 4)
print('cartoon generated')
plot_grid(generated_to_images(G(real_images)), 4)
print('black')
plot_grid(generated_to_images(G(black_images)), 4)