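# Generate a timelapse of the training progress

Renders a fixed grid of samples while interpolating between consecutive saved generator checkpoints, saving one PNG per frame.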

### Imports

In [None]:
import sys
sys.path.append('../')

import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from scipy.interpolate import interp1d

from training.generator import Generator
from training import utils
from training.settings import *

# Reset random state via the project helper before any z / noise is drawn
# (presumably seeds the RNGs so the sampled grid is reproducible — confirm in training.utils)
utils.reset_rand()

### Check GPU

In [None]:
# Project helper; presumably reports whether a GPU is available for DEVICE — see training.utils.check_gpu
utils.check_gpu()

### Settings

In [None]:
# Directory holding one sub-directory per saved checkpoint (joined onto '..' when used)
MODELS_DIR = 'output/models'
# Output directory for the rendered frames (joined onto '..' when used)
SAVE_DIR = 'timelapse'
# Shape of the grid of images in each frame; SHAPE[0] * SHAPE[1] samples are drawn
SHAPE = (4, 3)
# Number of frames generated between each pair of consecutive models
IMAGES_PER_MODEL = 8
# Psi value passed to z_to_w for the generation (presumably the truncation factor)
PSI = 0.5

### Generate images

In [None]:
# Fixed latents and noise, shared by every frame so only the weights change over time
z = Generator.gen_z(SHAPE[0] * SHAPE[1])
noise = Generator.gen_noise(SHAPE[0] * SHAPE[1])

models_root = os.path.join('..', MODELS_DIR)
models = os.listdir(models_root)
# NOTE(review): assumes every entry is a checkpoint directory whose name holds an integer
# training step starting at character 8, followed by '_' or the end of the name — confirm.
models.sort(key = lambda x: int(x[8:].split('_')[0]))

if not models:
	raise FileNotFoundError(f'No checkpoints found in {models_root}')


def _load_generator(generator, model_name):
	"""Load `<model_name>/ma_generator.pt` into `generator` and freeze it for inference.

	Also resets `generator.mean_w` to None before the caller runs `z_to_w`
	(presumably so the psi mean is recomputed for the new weights — confirm).
	"""
	path = os.path.join(models_root, model_name, 'ma_generator.pt')
	generator.load_state_dict(torch.load(path, map_location = DEVICE))
	generator.eval()
	generator.requires_grad_(False)
	generator.mean_w = None


def _lerp_generator(target, start, end, t):
	"""Overwrite `target`'s tensors with start + t * (end - start).

	Parameters and buffers are matched by name rather than position, so a buffer that is
	missing (or None) in either source cannot shift the pairing; such buffers are left
	unchanged on `target`. Non-floating-point buffers cannot be lerped and are copied from
	whichever endpoint is nearer to `t`.
	"""
	start_params = dict(start.named_parameters())
	end_params = dict(end.named_parameters())
	for name, param in target.named_parameters():
		param.copy_(start_params[name].lerp(end_params[name], t))

	start_buffers = dict(start.named_buffers())
	end_buffers = dict(end.named_buffers())
	for name, buffer in target.named_buffers():
		if name not in start_buffers or name not in end_buffers:
			continue
		if buffer.is_floating_point():
			buffer.copy_(start_buffers[name].lerp(end_buffers[name], t))
		else:
			buffer.copy_(end_buffers[name] if t >= 0.5 else start_buffers[name])


# `model` receives the interpolated weights; it is only ever used for inference
model = Generator().to(DEVICE)
model.eval()
model.requires_grad_(False)
model_1 = Generator().to(DEVICE)
model_2 = Generator().to(DEVICE)

images = []

with torch.no_grad():

	# Prime the "end" slot with the first checkpoint; each segment below swaps it into the "start" slot
	_load_generator(model_2, models[0])
	w_2 = model_2.z_to_w(z, psi = PSI)

	for i in range(len(models) - 1):

		print(f'Model {i + 1}')

		# Segment i runs from models[i] to models[i + 1]. models[i] is already loaded as the
		# previous segment's end, so reuse it (and its w) instead of reloading it from disk;
		# this also keeps consecutive segments identical at their shared endpoint.
		model_1, model_2 = model_2, model_1
		w_1 = w_2
		_load_generator(model_2, models[i + 1])
		w_2 = model_2.z_to_w(z, psi = PSI)

		for j in range(IMAGES_PER_MODEL):
			t = j / IMAGES_PER_MODEL
			_lerp_generator(model, model_1, model_2, t)
			w = w_1.lerp(w_2, t)
			images.append(model.w_to_images(w, noise).to('cpu').numpy())

	# t only reaches (IMAGES_PER_MODEL - 1) / IMAGES_PER_MODEL above, so the last checkpoint
	# (or the only one, if there is a single checkpoint) is rendered explicitly
	images.append(model_2.w_to_images(w_2, noise).to('cpu').numpy())

print("Saving images")

save_dir = os.path.join('..', SAVE_DIR)
# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail if the directory already exists (no separate existence check needed)
os.makedirs(save_dir, exist_ok = True)

# One PNG per frame, named by its index in `images`
for i, frame in enumerate(images):
	img = utils.create_grid(frame, SHAPE)
	Image.fromarray(img).save(os.path.join(save_dir, f'{i}.png'))