# Generate the images for an interpolation video

### Imports

In [None]:
import sys
sys.path.append('../')

import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from scipy.interpolate import interp1d

from training.generator import Generator
from training import utils
from training.settings import *

# Reset the random state through the project helper in training.utils,
# presumably so the random seeds drawn later in this notebook are
# reproducible between runs (TODO confirm in training/utils.py).
utils.reset_rand()

### Check GPU

In [None]:
# Project helper from training.utils; presumably reports whether a GPU is
# available for DEVICE (NOTE(review): confirm in training/utils.py).
utils.check_gpu()

### Import the model

In [None]:
# Checkpoint path, relative to the repository root (the notebook runs one
# directory below it, hence the '..' prefix when loading).
MODEL_PATH = 'models/faces_256.pt'

model = Generator().to(DEVICE)
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved
# from a CUDA device can still be loaded on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(os.path.join('..', MODEL_PATH), map_location = DEVICE))
# Inference only: eval mode, and no gradients tracked for the parameters.
model.eval()
model.requires_grad_(False)

### Mean W image

In [None]:
# Generate a single image with argument 0.0 and display it without axes.
# NOTE(review): per the section title this is the image of the mean W
# (presumably psi = 0 truncation) — confirm against Generator.generate_one.
image = model.generate_one(0.0)

plt.imshow(image)
plt.axis('off')
plt.show()

### Settings

In [None]:
# Video settings. About NB_SEEDS * SEED_TIME * FRAME_RATE grid frames are written.
SAVE_DIR = 'interpolation'	# Output folder for the frames, created under '..'
SHAPE = (3, 2)				# Grid passed to utils.create_grid; SHAPE[0] * SHAPE[1] animations
NB_SEEDS = 15				# Keyframe seeds per grid cell; the animation loops back to the first
SEED_TIME = 1.5				# Seconds spent moving from one seed to the next
FRAME_RATE = 30				# Video frames per second (sets the number of frames per seed)
PSI = 0.5					# Passed as psi to model.gen_w (presumably truncation; confirm)

### Generate images

In [None]:
# Render SHAPE[0] * SHAPE[1] independent looping latent walks and save them as
# a sequence of grid frames: '../SAVE_DIR/0.png', '../SAVE_DIR/1.png', ...
#
# For each grid cell, NB_SEEDS random keyframes are drawn: one W vector each,
# plus one slice of every noise tensor returned by model.gen_noise. They are
# joined by a looping cubic spline. Frame k samples the spline at
# t = k / seed_frames, so consecutive seeds are seed_frames frames apart and
# t sweeps [0, NB_SEEDS).


def _looping_cubic_interp(keyframes, t):
	"""Cubically interpolate ``keyframes`` as a closed loop, sampled at ``t``.

	Keyframe ``k`` sits at position ``k`` and positions repeat every
	``len(keyframes)``. Sampling ``t`` in ``[0, len(keyframes))`` therefore
	walks through every keyframe and returns smoothly to the first one.

	Args:
		keyframes: array whose axis 0 indexes the keyframes (any trailing shape).
		t: 1-D array of sample positions, within ``[0, len(keyframes))``.

	Returns:
		Array of shape ``(len(t),) + keyframes.shape[1:]``.
	"""
	count = len(keyframes)
	# Two extra copies on each side keep the spline's boundary effects away
	# from the sampled range, so the loop closes smoothly.
	positions = np.arange(-count * 2, count * 3)
	values = np.concatenate([keyframes] * 5, axis = 0)
	return interp1d(positions, values, kind = 'cubic', axis = 0)(t)


# round() rather than int(): float products such as 2.3 * 30 can land just
# below the intended integer and would otherwise be truncated to one frame less.
seed_frames = round(SEED_TIME * FRAME_RATE)
if seed_frames < 1:
	raise ValueError('SEED_TIME * FRAME_RATE must give at least one frame per seed')

# Sample positions shared by every grid cell (loop-invariant).
t = np.arange(NB_SEEDS * seed_frames) / float(seed_frames)
global_generations = []

for cell in range(SHAPE[0] * SHAPE[1]):

	print(f'Image {cell + 1}')

	w_seeds = model.gen_w(NB_SEEDS, psi = PSI).detach().to('cpu').numpy()
	# Build a new list instead of writing back into gen_noise's return value.
	noise_seeds = [n.detach().to('cpu').numpy() for n in model.gen_noise(NB_SEEDS)]

	w = torch.as_tensor(_looping_cubic_interp(w_seeds, t), dtype = torch.float32, device = DEVICE)
	noise = [
		torch.as_tensor(_looping_cubic_interp(n, t), dtype = torch.float32, device = DEVICE)
		for n in noise_seeds
	]

	# NOTE(review): every frame of this cell is rendered in a single call. For
	# long videos this may exhaust GPU memory; consider chunking along axis 0.
	generations = model.w_to_images(w, noise).detach().to('cpu').numpy()
	global_generations.append(generations)

print('Saving images')

# Stack to (frame, cell, ...), so each frame index holds one image per grid cell.
global_generations = np.stack(global_generations, axis = 1)

save_path = os.path.join('..', SAVE_DIR)
os.makedirs(save_path, exist_ok = True)

for frame_index, cell_images in enumerate(global_generations):
	images = utils.create_grid(cell_images, SHAPE)
	Image.fromarray(images).save(os.path.join(save_path, f'{frame_index}.png'))