# Generate the images for a style mixing interpolation video

### Imports

In [None]:
import sys
sys.path.append('../')

import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from scipy.interpolate import interp1d

from training.generator import Generator
from training import utils
from training.settings import *

# Reset the random number generators via the project helper — presumably seeds
# them so the generated seeds/images are reproducible; confirm in training/utils.
utils.reset_rand()

### Check GPU

In [None]:
# Project helper (training/utils, not shown): presumably reports whether a GPU
# is available to the model's DEVICE — TODO confirm.
utils.check_gpu()

### Load the model

In [None]:
MODEL_PATH = 'models/faces_256.pt'

# Build the generator on the target device, then load the trained weights.
# The checkpoint path is resolved relative to the parent of the working directory.
model = Generator().to(DEVICE)
checkpoint_path = os.path.join('..', MODEL_PATH)
model.load_state_dict(torch.load(checkpoint_path, map_location = DEVICE))

# Inference only: eval mode and frozen parameters (both calls return the module).
model.eval().requires_grad_(False)

### Mean W image

In [None]:
# Render one image with argument 0.0 — per the section title this is the image
# of the mean W (presumably psi = 0 truncation; confirm in the Generator class).
image = model.generate_one(0.0)

# Show it without axis ticks or frame.
plt.imshow(image)
plt.gca().set_axis_off()
plt.show()

### Settings

In [None]:
SAVE_DIR = 'style_mixing'       # Output directory for the frames (created inside '..')
NB_IMAGES = 4                   # Number of fixed images each mixed with the interpolated sequence
NB_SEEDS = 15                   # Number of random seeds the looping interpolation cycles through
SEED_TIME = 1.5                 # Seconds spent travelling from one seed to the next
FRAME_RATE = 30                 # Frames per second of the output video
MIX_POINT = 2                   # W layer split passed to model.style_mix (also scaled to pick the noise split)
INTERPOLATED_FIRST = True       # If True, the interpolated sequence supplies the layers before MIX_POINT
PSI = 0.5                       # Truncation psi passed to model.gen_w

### Generate images

In [None]:
def _looping_interpolator(seeds):
	"""Build a cubic interpolator that cycles through `seeds` along axis 0.

	Seed k sits at integer time k. The seeds are tiled 5 times over the sample
	grid [-2N, 3N) (N = len(seeds)), so time N lands on seeds[0] again and a
	sequence sampled on [0, N) loops seamlessly; the 2N padding on each side
	keeps the spline's boundary behaviour away from the sampled range.

	Returns a scipy `interp1d` object; calling it with an array of times of
	shape (F,) returns an array of shape (F,) + seeds.shape[1:].
	"""
	nb_seeds = len(seeds)
	x = np.arange(-nb_seeds * 2, nb_seeds * 3)
	# Repeat along axis 0 only, whatever the rank of the seed array
	# (np.tile with too few reps would prepend 1s and tile the wrong axis).
	y = np.tile(seeds, [5] + [1] * (seeds.ndim - 1))
	return interp1d(x, y, kind = 'cubic', axis = 0)


seed_frames = int(SEED_TIME * FRAME_RATE)	# Frames spent between two consecutive seeds
nb_frames = NB_SEEDS * seed_frames			# Total number of frames in the looping video
# Frame k is sampled at time k / seed_frames (in seed units), covering [0, NB_SEEDS).
times = np.arange(nb_frames) / float(seed_frames)

print('Interpolated image')

w_seeds = model.gen_w(NB_SEEDS, psi = PSI).detach().to('cpu').numpy()
noise_seeds = [n.detach().to('cpu').numpy() for n in model.gen_noise(NB_SEEDS)]

# interp1d evaluates a whole array of times in one call: no per-frame Python loop needed.
w = torch.as_tensor(_looping_interpolator(w_seeds)(times), dtype = torch.float32, device = DEVICE)
noise = [
	torch.as_tensor(_looping_interpolator(n)(times), dtype = torch.float32, device = DEVICE)
	for n in noise_seeds
]

# NOTE(review): all frames go through w_to_images in a single call; if the GPU
# runs out of memory, split the frames into chunks (unless w_to_images already
# batches internally — not visible here).
interpolated_generation = model.w_to_images(w, noise).detach().to('cpu').numpy()
generations = []			# One fixed image per style-mixing column
style_mix_generations = []	# For each fixed image, its mix with every interpolated frame
# Noise inputs are split at the same relative depth as the W layers.
noise_mix_point = int((MIX_POINT / NB_W) * NB_NOISE)

for i in range(NB_IMAGES):

	print(f'Style mixing image {i + 1}')

	image_w = model.gen_w(1, psi = PSI)
	image_noise = model.gen_noise(1)

	generations.append(model.w_to_images(image_w, image_noise).detach().to('cpu').numpy()[0])

	# Repeat the fixed image's latent and noise for every frame of the interpolation.
	image_w = image_w.repeat((nb_frames, 1))
	image_noise = [n.repeat((nb_frames, 1, 1, 1)) for n in image_noise]

	if INTERPOLATED_FIRST:
		image_w = model.style_mix(w, image_w, MIX_POINT)
		image_noise = noise[:noise_mix_point] + image_noise[noise_mix_point:]
	else:
		image_w = model.style_mix(image_w, w, MIX_POINT)
		image_noise = image_noise[:noise_mix_point] + noise[noise_mix_point:]

	style_mix_generations.append(model.w_to_images(image_w, image_noise).detach().to('cpu').numpy())

print("Saving images")

save_path = os.path.join('..', SAVE_DIR)
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(save_path, exist_ok = True)

blank = np.ones((NB_CHANNELS, IMAGE_SIZE, IMAGE_SIZE))	# Filler for the empty corner tile

for i in range(nb_frames):
	# Tile order: blank + the NB_IMAGES fixed images, then the interpolated frame
	# + its NB_IMAGES style-mixed versions; create_grid arranges them with shape
	# (NB_IMAGES + 1, 2) — presumably (columns, rows), TODO confirm in training/utils.
	tiles = [blank] + generations + [interpolated_generation[i]] + [mix[i] for mix in style_mix_generations]
	grid = utils.create_grid(np.array(tiles), (NB_IMAGES + 1, 2))
	Image.fromarray(grid).save(os.path.join(save_path, f'{i}.png'))