# Anime faces generator (GAN)

### Imports

In [None]:
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import tensorflow as tf

import mapping, generator
import utils

### Settings

In [None]:
IMAGE_SIZE = 128
NB_CHANNELS = 3
MARGIN = IMAGE_SIZE // 8  # white border (in pixels) around and between tiles in plot()

LATENT_DIM = 512
MAPPING_LAYERS = 8
MIN_IMAGE_SIZE = 4
MAX_FILTERS = 512
MIN_FILTERS = 64
KERNEL_SIZE = 3
ALPHA = 0.2
GAIN = 1.3
MAPPING_LR_RATIO = 0.01

# Number of generator blocks: log2(IMAGE_SIZE) - log2(MIN_IMAGE_SIZE) + 1
# (presumably one block per resolution level — confirm against
# generator.build_generator). Computed with exact integer arithmetic:
# n.bit_length() - 1 == floor(log2(n)) for n >= 1, whereas the float result
# of math.log(n, 2) is not guaranteed to be exact before int() truncates it.
# Assumes IMAGE_SIZE and MIN_IMAGE_SIZE are powers of two.
NB_BLOCKS = (IMAGE_SIZE.bit_length() - 1) - (MIN_IMAGE_SIZE.bit_length() - 1) + 1

### Build model and load weights

In [None]:
# Build the mapping network (z -> w) and the synthesis generator from the
# settings above, then restore their trained weights from disk.
# NOTE(review): the checkpoint directory is "model_64" while IMAGE_SIZE is 128;
# confirm this checkpoint was trained with the current settings.
ma_mapping = mapping.build_mapping(
	LATENT_DIM, MAPPING_LAYERS, MAPPING_LR_RATIO,
	ALPHA, GAIN
)
ma_generator = generator.build_generator(
	LATENT_DIM, IMAGE_SIZE, NB_CHANNELS, MIN_IMAGE_SIZE,
	MAX_FILTERS, MIN_FILTERS, KERNEL_SIZE,
	ALPHA, GAIN
)

ma_mapping.load_weights("./output/models/model_64/ma_mapping.h5")
ma_generator.load_weights("./output/models/model_64/ma_generator.h5")

### Utils

In [None]:
def gen_z(nb):
	"""Draw `nb` latent vectors from a standard normal distribution.

	Returns a float64 array of shape (nb, LATENT_DIM).
	"""
	latent_shape = (nb, LATENT_DIM)
	return np.random.normal(loc = 0.0, scale = 1.0, size = latent_shape)

def gen_noise(nb):
	"""Draw the per-sample noise maps fed to the generator.

	Returns a list of 2 * NB_BLOCKS - 1 float64 arrays, each of shape
	(nb, IMAGE_SIZE, IMAGE_SIZE, 1) with standard-normal entries.
	Presumably one map per noise-injection layer of the generator — confirm
	against generator.build_generator.
	"""
	nb_noise_inputs = 2 * NB_BLOCKS - 1
	noise_shape = (nb, IMAGE_SIZE, IMAGE_SIZE, 1)

	noise_maps = []
	for _ in range(nb_noise_inputs):
		noise_maps.append(np.random.normal(0., 1., noise_shape))
	return noise_maps

def gen_w(z, batch_size = None):
	"""Map latent vectors z to intermediate latents w with the mapping network.

	z: array whose first axis indexes samples (rows of length LATENT_DIM).
	batch_size: number of rows sent through ma_mapping per call; defaults to
		all rows at once. Must be a positive integer.

	Returns a float32 array of shape (len(z), LATENT_DIM). An empty z yields
	an empty array.

	Raises ValueError if batch_size is not positive.
	"""
	nb = z.shape[0]

	if batch_size is None:
		# max(..., 1) keeps the range() step valid when z is empty.
		batch_size = max(nb, 1)
	if batch_size < 1:
		raise ValueError("batch_size must be a positive integer")

	w = np.zeros((nb, LATENT_DIM), dtype = np.float32)

	for start in range(0, nb, batch_size):
		stop = min(start + batch_size, nb)
		w[start:stop] = ma_mapping(tf.convert_to_tensor(z[start:stop])).numpy()

	return w

def gen_images(w, noise, batch_size = None):
	"""Render images from intermediate latents w with the generator.

	w: array whose first axis indexes samples (e.g. the output of gen_w).
	noise: list of noise arrays (e.g. the output of gen_noise); each is sliced
		along its first axis in step with w.
	batch_size: number of samples rendered per generator call; defaults to all
		samples at once. Must be a positive integer.

	The generator is called with: a ones tensor of shape (batch, 1), then the
	same w batch repeated NB_BLOCKS times (no style mixing), then the noise
	batches. Its output goes through utils.denorm_img (presumably mapping it
	back to pixel values — confirm) and is stored in a uint8 array of shape
	(len(w), IMAGE_SIZE, IMAGE_SIZE, NB_CHANNELS).

	Raises ValueError if batch_size is not positive.
	"""
	nb = w.shape[0]

	if batch_size is None:
		# max(..., 1) keeps the range() step valid when w is empty.
		batch_size = max(nb, 1)
	if batch_size < 1:
		raise ValueError("batch_size must be a positive integer")

	generations = np.zeros((nb, IMAGE_SIZE, IMAGE_SIZE, NB_CHANNELS), dtype = np.uint8)

	for start in range(0, nb, batch_size):
		stop = min(start + batch_size, nb)
		const_input = [tf.ones((stop - start, 1))]
		styles = [w[start:stop]] * NB_BLOCKS
		noise_batch = [tf.convert_to_tensor(n[start:stop]) for n in noise]
		gen = ma_generator(const_input + styles + noise_batch)
		generations[start:stop] = utils.denorm_img(gen.numpy())

	return generations

def plot(images, shape, path = None, show = True, margin = None):
	"""Tile a batch of images into a grid, optionally display and/or save it.

	images: uint8 array of shape (N, H, W, C) with N >= shape[0] * shape[1].
	shape: (nb_cols, nb_rows) of the grid; images are laid out row by row.
	path: if not None, the grid is saved there with PIL.
	show: if True, the grid is displayed with matplotlib at one screen pixel
		per image pixel.
	margin: white border in pixels around and between tiles; defaults to MARGIN.

	Returns the grid as a uint8 array of shape
	(margin + nb_rows * (H + margin), margin + nb_cols * (W + margin), C).
	"""
	if margin is None:
		margin = MARGIN

	nb_cols, nb_rows = shape
	_, height, width, channels = images.shape

	output_image = np.full((
		margin + nb_rows * (height + margin),
		margin + nb_cols * (width + margin),
		channels), 255, dtype = np.uint8)

	for i in range(nb_rows * nb_cols):
		row, col = divmod(i, nb_cols)
		r = margin + row * (height + margin)
		c = margin + col * (width + margin)
		# Tile rows span the image height (axis 1), tile columns its width (axis 2).
		output_image[r:r + height, c:c + width] = images[i]

	if show:
		dpi = mpl.rcParams['figure.dpi']
		# figsize is (width, height) in inches; array axis 0 is the height.
		fig = plt.figure(figsize = (output_image.shape[1] / float(dpi), output_image.shape[0] / float(dpi)), dpi = dpi)
		ax = fig.add_axes([0, 0, 1, 1])
		ax.axis('off')
		ax.imshow(output_image)
		plt.show()

	if path is not None:
		Image.fromarray(output_image).save(path)

	return output_image

### Mean w (center for the truncation trick)

In [None]:
# Estimate the mean of the w distribution (the truncation center used by
# gen_images_truncated) by averaging the mapping of 100000 random z.
# z is drawn and mapped batch by batch and summed on the fly instead of
# materializing all samples at once (at LATENT_DIM = 512, the full float64 z
# alone would take ~400 MB).
_MEAN_W_SAMPLES = 100000
_MEAN_W_BATCH = 1000
_w_sum = np.zeros(LATENT_DIM, dtype = np.float64)
for _start in range(0, _MEAN_W_SAMPLES, _MEAN_W_BATCH):
	_nb = min(_MEAN_W_BATCH, _MEAN_W_SAMPLES - _start)
	_w_sum += gen_w(gen_z(_nb)).sum(axis = 0, dtype = np.float64)
# Keep float32, matching the dtype np.mean produced on gen_w's float32 output.
mean_w = (_w_sum / _MEAN_W_SAMPLES).astype(np.float32)

In [None]:
def gen_images_truncated(w, noise, psi = 1., batch_size = None):
	"""Render images from w after pulling each latent toward mean_w.

	The latents actually rendered are  mean_w + psi * (w - mean_w):
	psi = 1 leaves w unchanged and psi = 0 replaces every row with mean_w.
	noise and batch_size are forwarded to gen_images unchanged.
	"""
	offset = w - mean_w
	return gen_images(psi * offset + mean_w, noise, batch_size = batch_size)

In [None]:
# With psi = 0 every w is replaced by mean_w, so this renders the "average"
# face; only the random noise maps still vary between runs.
# (Call order matters: gen_z and gen_noise share NumPy's global RNG.)
z = gen_z(1)
w = gen_w(z)
noise = gen_noise(1)
image = gen_images_truncated(w, noise, psi = 0.0)

plot(image, shape = (1, 1))

### Tests

In [None]:
# Sample a grid of 8 columns x 6 rows of faces with truncation psi = 0.7,
# display it and save it to "test_2.png".
# (Call order matters: gen_z and gen_noise share NumPy's global RNG.)
shape = (8, 6)
z = gen_z(shape[0] * shape[1])
w = gen_w(z)
noise = gen_noise(shape[0] * shape[1])
images = gen_images_truncated(w, noise, psi = 0.7)

plot(images, shape, path = "test_2.png")