In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path

from IPython import display

import cv2
import os

In [None]:
from os import environ
# Tell the TensorBoard notebook extension which tensorboard executable to launch.
# setdefault keeps any value already configured in the environment instead of clobbering it
# with this machine-specific path.
environ.setdefault('TENSORBOARD_BINARY', r'C:\Users\johna\Documents\A_Documents\Programming\Python\.my_envs\datascience\Scripts\tensorboard.exe')

In [None]:
# %load_ext tensorboard

In [None]:
class Paths:
	"""Filesystem layout of the experiment.

	All outputs live under ./history/<CURRENT_MODEL>; DS_PATH points at the training images.
	"""
	CURRENT_MODEL = 'v8'
	CWD = Path.cwd()

	HISTORY_DIR = CWD/'history'
	MODEL_DIR = HISTORY_DIR/CURRENT_MODEL

	LOG_DIR = MODEL_DIR/'logs'
	SAVED_MODELS_DIR = MODEL_DIR/'saved_models'
	GENERATED_IMAGES_DIR = MODEL_DIR/'generated_images'

	CHECKPOINT_DIR = SAVED_MODELS_DIR/'checkpoints'
	SINGLE_IMAGE_DIR = GENERATED_IMAGES_DIR/'single'
	COLLECTION_DIR = GENERATED_IMAGES_DIR/'collection'

	# Dataset folder. Set the WGAN_DS_PATH environment variable to use another location
	# without editing this machine-specific default.
	DS_PATH = Path(os.environ.get('WGAN_DS_PATH', r'E:\datasets\cat-faces'))
	# DS_PATH = Path(r'E:\datasets\monet-paintings\monet_jpg')

In [None]:
# Listed parent-first: mkdir() creates a single level, so each folder's parent must appear earlier.
FOLDERS_TO_SETUP = [
	Paths.HISTORY_DIR, Paths.MODEL_DIR,
	Paths.SAVED_MODELS_DIR, Paths.GENERATED_IMAGES_DIR,
	Paths.CHECKPOINT_DIR, Paths.SINGLE_IMAGE_DIR, Paths.COLLECTION_DIR,
]

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Looping (rather than indexing [0]) keeps this cell working on CPU-only machines
# and applies the setting to every visible GPU.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
	tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
def mkdir(path):
	"""Create the single directory `path`; an already existing directory is not an error.

	Intermediate directories are not created, so the parent must already exist.
	"""
	path.mkdir(mode=0o777, parents=False, exist_ok=True)

In [None]:
def setup_folders(paths: list):
	"""Create each directory in `paths`, in list order, via `mkdir`."""
	for folder in paths:
		mkdir(folder)

In [None]:
def display_image(image):
	"""Draw `image` on the current matplotlib axes after casting its pixels to uint8."""
	pixels = image.astype('uint8')
	plt.imshow(pixels)

### data preparation

In [None]:
# (height, width, channels) that every training image must have; the generator's output is checked against it too
IMAGE_SHAPE = (64, 64, 3)
# number of images per training batch
BATCHSIZE = 32

In [None]:
def normalize_images(images):
	"""Map pixel values from [0, 255] to [-1, 1] (the range of the generator's tanh output)."""
	centered = images - 127.5
	return centered / 127.5

In [None]:
def denormalize_images(images):
	"""Map images from [-1, 1] back to uint8 pixel values in [0, 255].

	Inverse of `normalize_images`. Values are rounded to the nearest integer, because a plain
	uint8 cast truncates and a value landing just below an integer (e.g. 0.99999) would drop
	to the integer below. They are also clipped to [0, 255] so out-of-range inputs cannot
	wrap around when cast.

	Args:
		images: array-like or eager tensor, nominally in [-1, 1].
	Returns:
		np.ndarray of dtype uint8 with the same shape as `images`.
	"""
	pixels = np.asarray(images, dtype='float32') * 127.5 + 127.5
	return np.clip(np.rint(pixels), 0, 255).astype('uint8')

In [None]:
def resize_images(images, new_size=IMAGE_SHAPE[1::-1], interpolation=cv2.INTER_AREA):
	"""Resize every image with OpenCV and stack the results into one float32 array.

	Args:
		images: iterable of image arrays.
		new_size: target size in OpenCV's `dsize` order, i.e. (width, height). The default is
			IMAGE_SHAPE's (height, width) reversed; passing IMAGE_SHAPE[:-1] as-is would swap
			the axes for non-square shapes.
		interpolation: OpenCV interpolation flag passed to cv2.resize.
	Returns:
		float32 array with one resized image per input image.
	"""
	return np.array(
		[cv2.resize(image, tuple(new_size), interpolation=interpolation) for image in images],
		dtype='float32',
	)

In [None]:
def prepare_images(extensions=('.jpg', '.jpeg', '.png', '.bmp')):
	"""Load every image file in Paths.DS_PATH and scale its pixels to [-1, 1].

	Args:
		extensions: file suffixes (case-insensitive) treated as images; other files and
			sub-directories of the dataset folder are skipped.
	Returns:
		float32 array of shape (n_images, *IMAGE_SHAPE) with values in [-1, 1].
	Raises:
		AssertionError: if no image file is found or an image does not have IMAGE_SHAPE.
	"""
	allowed_suffixes = {suffix.lower() for suffix in extensions}
	# sorted() makes the load order deterministic across runs and platforms
	image_files = sorted(
		path for path in Paths.DS_PATH.glob('*')
		if path.is_file() and path.suffix.lower() in allowed_suffixes
	)
	assert image_files, f'no image files found in {Paths.DS_PATH}'

	loaded_images = []
	for image_file in image_files:
		image = plt.imread(image_file)
		# plt.imread returns PNGs as floats in [0, 1] but other formats as integers in
		# [0, 255]; bring everything to [0, 255] before normalizing
		if image.dtype.kind == 'f':
			image = image * 255.0
		# check each image so a mismatch names the offending file instead of failing on a ragged array
		assert image.shape == IMAGE_SHAPE, f'{image_file.name} has the wrong shape! {image.shape}'
		loaded_images.append(image)

	images = np.array(loaded_images, dtype='float32')
	return normalize_images(images)

In [None]:
images = prepare_images()
# cache() must not come after shuffle(): it would freeze the first epoch's order and replay it
# every epoch. The images are already an in-memory NumPy array, so no cache is needed at all.
train_dataset = (
	tf.data.Dataset.from_tensor_slices(images)
	.shuffle(len(images), reshuffle_each_iteration=True)
	.batch(BATCHSIZE, drop_remainder=True)
	.prefetch(tf.data.AUTOTUNE)
)

### model building

In [None]:
# The learning-rate schedules are indexed by optimizer step (one per batch), not by image.
# So the decay horizon is expressed in epochs and converted using the number of full batches
# per epoch (the dataset uses drop_remainder=True). max(1, ...) avoids a zero horizon when
# there are fewer images than one batch.
DECAY_EPOCHS = 1000
FINAL_STEP = DECAY_EPOCHS * max(1, len(images) // BATCHSIZE)
# length of the noise vector fed to the generator
LATENT_DIM = 100

In [None]:
def generate_sample_images(generator, latent_dim=LATENT_DIM, training=False, n_rows=3, n_cols=4, save_path=None):
	"""Sample a grid of images from `generator` and display it in place of the cell's previous output.

	Args:
		generator: model mapping noise of shape (n, latent_dim) to images in [-1, 1].
		latent_dim: length of each noise vector.
		training: forwarded to the generator call (changes how its BatchNormalization layers behave).
		n_rows, n_cols: grid layout; n_rows * n_cols images are sampled (default 3 x 4).
		save_path: if given, the figure is also written there as a 100-dpi PNG before it is shown.
	"""
	noise = tf.random.normal((n_rows * n_cols, latent_dim))
	generated_images = denormalize_images(generator(noise, training=training))

	# replace the previous grid instead of stacking outputs in the cell
	display.clear_output(True)

	# squeeze=False keeps axs 2-D even for a single row/column
	fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
	for image, ax in zip(generated_images, axs.flatten()):
		ax.imshow(image)
		ax.axis('off')

	# save before showing: inline backends close figures once they are shown
	if save_path is not None:
		fig.savefig(save_path, format='png', dpi=100)

	plt.show()

#### blocks

In [None]:
from tensorflow.keras import layers

In [None]:
def conv_block(x, n_filters, kernel_size=(3, 3), strides=(2, 2), padding='same'):
	"""Conv2D (stride 2 by default) -> LeakyReLU(0.2) -> Dropout(0.3), applied to `x`."""
	block_layers = (
		layers.Conv2D(filters=n_filters, kernel_size=kernel_size, strides=strides, padding=padding),
		layers.LeakyReLU(0.2),
		layers.Dropout(0.3),
	)
	for layer in block_layers:
		x = layer(x)
	return x

In [None]:
def normalized_conv_block(x, n_filters, kernel_size=(3, 3), strides=(2, 2), padding='same'):
	"""Conv2D (stride 2 by default) -> LeakyReLU(0.2) -> LayerNormalization -> Dropout(0.3), applied to `x`."""
	convolved = layers.Conv2D(filters=n_filters, kernel_size=kernel_size, strides=strides, padding=padding)(x)
	activated = layers.LeakyReLU(0.2)(convolved)
	normalized = layers.LayerNormalization()(activated)
	return layers.Dropout(0.3)(normalized)

In [None]:
def transposed_conv_block(x, n_filters, kernel_size=(3, 3), strides=(2, 2), padding='same'):
	"""Bias-free Conv2DTranspose (stride 2 by default) -> LeakyReLU(0.2) -> BatchNormalization, applied to `x`."""
	upsampled = layers.Conv2DTranspose(
		filters=n_filters, kernel_size=kernel_size, strides=strides, padding=padding, use_bias=False
	)(x)
	activated = layers.LeakyReLU(0.2)(upsampled)
	return layers.BatchNormalization()(activated)

In [None]:
def upsampling_conv_block(x, n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same'):
	"""UpSampling2D -> Conv2D -> LeakyReLU(0.2) -> LayerNormalization, applied to `x`."""
	x = layers.UpSampling2D()(x)
	# the Conv2D layer must be called on x; without the call the layer object itself replaced x
	x = layers.Conv2D(n_filters, kernel_size, strides, padding)(x)
	x = layers.LeakyReLU(0.2)(x)
	x = layers.LayerNormalization()(x)
	return x

In [None]:
def upsampling_transposed_conv_block(x, n_filters, kernel_size=(3, 3), strides=(1, 1), padding='same'):
	"""UpSampling2D -> Conv2DTranspose -> LeakyReLU(0.2) -> BatchNormalization, applied to `x`."""
	for layer in (
		layers.UpSampling2D(),
		layers.Conv2DTranspose(filters=n_filters, kernel_size=kernel_size, strides=strides, padding=padding),
		layers.LeakyReLU(0.2),
		layers.BatchNormalization(),
	):
		x = layer(x)
	return x

#### models

##### generator

In [None]:
def get_generator_model(input_size, optimizer, loss_function):
	"""Build and compile the generator network.

	A latent vector is projected by a bias-free Dense layer to an 8x8x256 feature map. That map
	goes through two `upsampling_transposed_conv_block`s, then UpSampling2D + Conv2DTranspose to
	a 3-channel map, and finally tanh.

	Args:
		input_size: latent dimension (a scalar) or the full input shape as a tuple.
		optimizer: optimizer passed to `compile`.
		loss_function: loss passed to `compile`.
	Returns:
		Compiled `tf.keras.Model` named 'generator'.
	Raises:
		AssertionError: if the output shape does not match IMAGE_SHAPE.
	"""
	initial_dimensions = (8, 8)
	initial_channels = 256

	# accept a bare scalar latent size as well as a shape tuple
	input_shape = (int(input_size),) if np.isscalar(input_size) else tuple(input_size)
	inputs = layers.Input(input_shape)

	# np.prod instead of np.product, which is deprecated and removed in NumPy 2.0
	x = layers.Dense(int(np.prod(initial_dimensions)) * initial_channels, use_bias=False)(inputs)
	x = layers.LeakyReLU(0.2)(x)
	x = layers.BatchNormalization()(x)

	x = layers.Reshape((*initial_dimensions, initial_channels))(x)

	x = upsampling_transposed_conv_block(x, 256)
	x = upsampling_transposed_conv_block(x, 256)

	x = layers.UpSampling2D()(x)
	x = layers.Conv2DTranspose(3, kernel_size=(3, 3), padding='same')(x)
	outputs = tf.keras.activations.tanh(x)

	assert outputs.get_shape()[1:] == IMAGE_SHAPE, f'output tensor\'s shapes are wrong! {outputs.get_shape()}'

	model = tf.keras.Model(inputs, outputs, name='generator')

	model.compile(
		optimizer=optimizer,
		loss=loss_function
	)

	return model

In [None]:
def compute_generator_loss(fake_predictions):
	"""WGAN generator loss: the negated mean critic score of generated images.

	Minimising it pushes the critic's scores on fakes upward.
	"""
	mean_fake_score = tf.reduce_mean(fake_predictions)
	return -mean_fake_score

##### critic

In [None]:
def get_critic_model(input_size, optimizer, loss_function):
	"""Build and compile the critic network.

	Two strided, layer-normalised conv blocks (64 filters with a 5x5 kernel, then 128 filters)
	are flattened, passed through Dropout(0.2), and reduced to one score per image by Dense(1)
	with no activation.

	Args:
		input_size: shape of one input image.
		optimizer: optimizer passed to `compile`.
		loss_function: loss passed to `compile`.
	Returns:
		Compiled `tf.keras.Model` named 'critic'.
	"""
	image_input = layers.Input(input_size)

	features = normalized_conv_block(image_input, 64, (5, 5))
	features = normalized_conv_block(features, 128)

	flattened = layers.Flatten()(features)
	flattened = layers.Dropout(0.2)(flattened)
	score = layers.Dense(1)(flattened)

	critic_model = tf.keras.Model(image_input, score, name='critic')
	critic_model.compile(optimizer=optimizer, loss=loss_function)
	return critic_model

In [None]:
def compute_critic_loss(real_predictions, fake_predictions):
	"""WGAN critic loss: mean over the batch of (fake score - real score).

	The gradient penalty is not included here; the training step adds it separately.
	"""
	score_gap = fake_predictions - real_predictions
	return tf.reduce_mean(score_gap)

##### wgan

In [None]:
class WGAN(tf.keras.Model):
	"""Wasserstein GAN with gradient penalty (WGAN-GP) built from a generator and a critic.

	Both sub-models must already be compiled: the update methods call their `loss` functions
	and `optimizer`s directly, which is why `WGAN.compile()` can be called without arguments.

	Args:
		generator: model mapping noise of shape (batch, latent_dim) to images.
		critic: model mapping images to one score per image.
		latent_dim: length of the noise vectors fed to the generator.
		batchsize: number of samples `update_generator()` draws when called without an explicit
			batch size; during training the real batch's size is used instead.
		critic_extra_steps: number of critic updates per generator update (must be >= 1).
		lambd: weight of the gradient-penalty term.
	"""
	def __init__(self, generator, critic, latent_dim=LATENT_DIM, batchsize=BATCHSIZE, critic_extra_steps=1, lambd=10):
		super().__init__()
		self.generator = generator
		self.critic = critic
		self.latent_dim = tf.constant(latent_dim, dtype='int32')
		# honour the `batchsize` argument (the global BATCHSIZE used to be read here)
		self.batchsize = tf.Variable(batchsize, dtype='int32')
		# a Python int, so the critic loop in train_step is unrolled while tracing instead of
		# being converted by AutoGraph into a tf.while_loop over a tensor bound
		self.critic_extra_steps = int(critic_extra_steps)
		if self.critic_extra_steps < 1:
			raise ValueError(f'critic_extra_steps must be at least 1, got {critic_extra_steps}')
		self.lambd = tf.constant(lambd, dtype='float32')

	def compute_gradient_penalty(self, generated_images, real_images):
		"""Return lambd * mean((||d critic(x_hat) / d x_hat||_2 - 1)^2).

		x_hat = eps * real + (1 - eps) * fake, with one eps ~ U[0, 1) per sample.
		Images are expected to be rank-4 (batch, height, width, channels).
		"""
		batch_size = tf.shape(real_images)[0]
		epsilon = tf.random.uniform((batch_size, 1, 1, 1), 0, 1)
		interpolated_images = epsilon * real_images + (1 - epsilon) * generated_images

		with tf.GradientTape() as gp_tape:
			gp_tape.watch(interpolated_images)
			prediction = self.critic(interpolated_images, training=True)

		grads = gp_tape.gradient(prediction, interpolated_images)

		# the small constant keeps sqrt differentiable when a gradient is exactly zero
		norms = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=(1, 2, 3)) + 1e-12)
		return self.lambd * tf.reduce_mean(tf.square(norms - 1))

	def update_critic(self, real_images):
		"""Apply one critic optimisation step and return the critic loss (including the penalty)."""
		batch_size = tf.shape(real_images)[0]
		noise = tf.random.normal(shape=(batch_size, self.latent_dim))

		with tf.GradientTape() as c_tape:
			generated_images = self.generator(noise, training=True)
			fake_predictions = self.critic(generated_images, training=True)
			real_predictions = self.critic(real_images, training=True)

			c_cost = self.critic.loss(real_predictions, fake_predictions)
			gradient_penalty = self.compute_gradient_penalty(generated_images, real_images)
			c_loss = c_cost + gradient_penalty

		c_gradients = c_tape.gradient(c_loss, self.critic.trainable_variables)
		self.critic.optimizer.apply_gradients(zip(c_gradients, self.critic.trainable_variables))

		return c_loss

	def update_generator(self, batch_size=None):
		"""Apply one generator optimisation step and return the generator loss.

		Args:
			batch_size: number of fake samples to draw; defaults to `self.batchsize`.
		"""
		if batch_size is None:
			batch_size = self.batchsize.read_value()
		noise = tf.random.normal((batch_size, self.latent_dim))

		with tf.GradientTape() as g_tape:
			generated_images = self.generator(noise, training=True)
			fake_predictions = self.critic(generated_images, training=True)
			g_loss = self.generator.loss(fake_predictions)

		g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)
		self.generator.optimizer.apply_gradients(zip(g_gradients, self.generator.trainable_variables))

		return g_loss

	def train_step(self, images):
		"""One WGAN-GP step: `critic_extra_steps` critic updates followed by one generator update."""
		# Keep the batch size local. Replacing self.batchsize with a tensor created while tracing
		# would leave a graph-only tensor on the model object.
		batch_size = tf.shape(images)[0]

		# the first critic update sits outside the loop so c_loss is always defined
		c_loss = self.update_critic(images)
		for _ in range(self.critic_extra_steps - 1):
			c_loss = self.update_critic(images)

		g_loss = self.update_generator(batch_size)

		return {'c_loss': c_loss, 'g_loss': g_loss}

#### callbacks

In [None]:
class VisualizerOnBatch(tf.keras.callbacks.Callback):
	"""Keras callback that redraws a grid of generator samples after every training batch.

	The model being trained must expose a `generator` attribute.
	"""
	def __init__(self, latent_dim=LATENT_DIM):
		# let Callback initialise its own state before adding ours
		super().__init__()
		self.latent_dim = latent_dim

	def on_batch_end(self, batch, logs=None):
		generate_sample_images(self.model.generator, self.latent_dim, training=True)

In [None]:
class VisualizerOnEpoch(tf.keras.callbacks.Callback):
	"""Keras callback that shows a 3x4 grid of generator samples after every epoch.

	Every `save_every` epochs the grid is also saved to Paths.COLLECTION_DIR as
	image_at_epoch_<epoch>.png. The model being trained must expose a `generator` attribute.
	"""
	def __init__(self, latent_dim=LATENT_DIM, save_every=5):
		# let Callback initialise its own state before adding ours
		super().__init__()
		self.latent_dim = latent_dim
		self.save_every = save_every

	def on_epoch_end(self, epoch, logs=None):
		noise = tf.random.normal((12, self.latent_dim))
		generated_images = denormalize_images(self.model.generator(noise, training=True))

		display.clear_output(True)

		fig, axs = plt.subplots(3, 4, figsize=(20, 15))
		for image, ax in zip(generated_images, axs.flatten()):
			ax.imshow(image)
			ax.axis('off')

		# Save this callback's own figure (the old code called savefig on an unrelated global `f`).
		# Do it before plt.show(), since inline backends close figures once they are shown.
		if epoch % self.save_every == 0:
			Paths.COLLECTION_DIR.mkdir(parents=True, exist_ok=True)
			fig.savefig(Paths.COLLECTION_DIR/f'image_at_epoch_{epoch}.png', format='png', dpi=100)

		plt.show()

In [None]:
class Checkpointer(tf.keras.callbacks.Callback):
	"""Keras callback that saves a tf.train.Checkpoint every `save_every` epochs.

	Args:
		checkpoint: checkpoint to save. If None, one tracking the trained model's `generator` and
			`critic` is created on first use and then reused, so its save counter (and the
			numbered checkpoint files) keeps increasing.
		save_every: a checkpoint is written when `epoch % save_every == 0`.
	"""
	def __init__(self, checkpoint=None, save_every=25):
		super().__init__()
		self.checkpoint = checkpoint
		self.save_every = save_every

	# Keras passes both `epoch` and `logs`; the old one-argument signature raised a TypeError
	def on_epoch_end(self, epoch, logs=None):
		if epoch % self.save_every != 0:
			return
		if self.checkpoint is None:
			self.checkpoint = tf.train.Checkpoint(generator=self.model.generator, critic=self.model.critic)
		Paths.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
		self.checkpoint.save(str(Paths.CHECKPOINT_DIR/'checkpoint'))

In [None]:
class CustomLearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
	"""Square-root decay from `init_lr` to `final_lr` over `final_step` optimizer steps.

	lr(step) = init_lr - sqrt(min(step / final_step, 1)) * (init_lr - final_lr),
	i.e. `init_lr` at step 0 and `final_lr` from `final_step` onwards.
	"""
	def __init__(self, init_lr, final_lr, final_step):
		super().__init__()
		self.init_lr = tf.cast(init_lr, 'float32')
		self.final_lr = tf.cast(final_lr, 'float32')
		self.final_step = tf.cast(final_step, 'float32')

	def __call__(self, step):
		# optimizers may pass their integer iteration counter; cast before mixing with floats
		step = tf.cast(step, 'float32')
		# clamp against this schedule's own final_step (the global FINAL_STEP was used before,
		# which made schedules with a different horizon jump to final_lr early)
		progress = tf.minimum(step / self.final_step, 1.0)
		return self.init_lr - tf.sqrt(progress) * (self.init_lr - self.final_lr)

	def get_config(self):
		# plain Python floats keep the config JSON-serialisable
		return {
			'init_lr': float(tf.get_static_value(self.init_lr)),
			'final_lr': float(tf.get_static_value(self.final_lr)),
			'final_step': float(tf.get_static_value(self.final_step))
			}

### training

In [None]:
# square-root learning-rate decay; the critic decays to a lower rate over five times as many steps
generator_schedule = CustomLearningRateSchedule(init_lr=1e-4, final_lr=1e-5, final_step=FINAL_STEP)
critic_schedule = CustomLearningRateSchedule(init_lr=1e-4, final_lr=1e-7, final_step=5 * FINAL_STEP)

# Adam with beta_1=0 and beta_2=0.9, the settings used in the WGAN-GP paper
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=generator_schedule, beta_1=0, beta_2=0.9)
critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_schedule, beta_1=0, beta_2=0.9)

In [None]:
generator = get_generator_model(input_size=LATENT_DIM, optimizer=generator_optimizer, loss_function=compute_generator_loss)
critic = get_critic_model(input_size=IMAGE_SHAPE, optimizer=critic_optimizer, loss_function=compute_critic_loss)

In [None]:
visualizer_on_batch, visualizer_on_epoch, checkpointer = VisualizerOnBatch(), VisualizerOnEpoch(), Checkpointer()
tensorboard_cb = tf.keras.callbacks.TensorBoard(
	log_dir=Paths.LOG_DIR,
	histogram_freq=1,
	profile_batch='300, 310',  # profile batches 300 through 310
)

In [None]:
# Every enabled callback slows training, so the list is empty for the fastest run.
# To enable one, append any of: visualizer_on_batch, visualizer_on_epoch, checkpointer, tensorboard_cb
callbacks = []

In [None]:
# checkpoint object tracking both networks
checkpoint = tf.train.Checkpoint(generator=generator, critic=critic)

In [None]:
wgan = WGAN(generator=generator, critic=critic)
# no arguments needed: the training step uses the generator's and critic's own compiled optimizers and losses
wgan.compile()

In [None]:
EPOCHS = 2
# create the output folders that callbacks and weight saving write into (existing ones are left alone)
setup_folders(FOLDERS_TO_SETUP)
# train_dataset already reshuffles every epoch; fit's `shuffle` argument is ignored for tf.data datasets
history = wgan.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks, verbose=1)

In [None]:
# sample one image from the trained generator and show it on its own
noise = tf.random.normal((1, LATENT_DIM))
generated_image = denormalize_images(wgan.generator(noise, training=False))

f = plt.figure(figsize=(5, 5))
axis = f.add_subplot(1, 1, 1)
axis.imshow(generated_image[0])
axis.axis('off')

plt.show()

In [None]:
# show a grid of images sampled from the trained generator
generate_sample_images(generator=wgan.generator)

In [None]:
# sample another grid (fresh random noise is drawn on every call)
generate_sample_images(generator=wgan.generator)

In [None]:
def save_weights(model=None):
	"""Save the generator and critic weights of `model` under Paths.SAVED_MODELS_DIR.

	Args:
		model: object with `generator` and `critic` models; defaults to the global `wgan`.
	"""
	model = wgan if model is None else model
	Paths.SAVED_MODELS_DIR.mkdir(parents=True, exist_ok=True)
	model.generator.save_weights(Paths.SAVED_MODELS_DIR/'generator')
	model.critic.save_weights(Paths.SAVED_MODELS_DIR/'critic')
# usage: save_weights()

In [None]:
def load_weights(model=None):
	"""Restore the generator and critic weights of `model` from Paths.SAVED_MODELS_DIR.

	Args:
		model: object with `generator` and `critic` models; defaults to the global `wgan`.
	"""
	model = wgan if model is None else model
	model.generator.load_weights(Paths.SAVED_MODELS_DIR/'generator')
	model.critic.load_weights(Paths.SAVED_MODELS_DIR/'critic')
# usage: load_weights()

In [None]:
f, axs = plt.subplots(1, 2, figsize=(30, 9))
loss_panels = (('c_loss', 'critic loss'), ('g_loss', 'generator loss'))
for ax, (loss_key, panel_title) in zip(axs, loss_panels):
	ax.plot(wgan.history.history[loss_key])
	ax.title.set_text(panel_title)
plt.show()

#### references
http://modelai.gettysburg.edu/2020/wgan/Resources/Lesson5/WGAN-GP.pdf  
https://keras.io/examples/generative/wgan_gp/  
https://github.com/caogang/wgan-gp/blob/master/gan_mnist.py  
https://developers.google.com/machine-learning/gan/loss  
https://www.youtube.com/watch?v=pG0QZ7OddX4  
https://distill.pub/2016/deconv-checkerboard/  