# Find the seed (latent W and noise inputs) of an existing image

### Imports

In [None]:
import sys
sys.path.append('../')

import os
import torch
from torch import nn
import torchvision as tv
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from scipy.interpolate import interp1d

from training.generator import Generator
from training import utils
from training.settings import *

# Project helper from training.utils; presumably reseeds the random number
# generators so runs are reproducible (confirm in training/utils.py).
utils.reset_rand()

### Check GPU

In [None]:
# Project helper from training.utils; presumably reports whether a GPU is
# available for DEVICE (confirm in training/utils.py).
utils.check_gpu()

### Import the model

In [None]:
MODEL_PATH = 'models/faces_256.pt'

# Build the generator on DEVICE and load its trained weights. map_location makes
# a checkpoint saved on one device (e.g. a GPU) loadable on another (e.g. a
# CPU-only machine) instead of failing inside torch.load.
model = Generator().to(DEVICE)
model.load_state_dict(torch.load(os.path.join('..', MODEL_PATH), map_location = DEVICE))

# Inference only: eval mode and no gradients for the generator's own weights;
# only the latent W and the noise inputs are optimized below.
model.eval()
model.requires_grad_(False)

### Mean W image

In [None]:
# Preview the image produced by model.generate_one(0.0); per the cell heading this
# is the mean-W image (presumably 0.0 is a truncation factor — confirm in
# training.generator).
image = model.generate_one(0.0)

ax = plt.gca()
ax.imshow(image)
ax.axis('off')
plt.show()

### Settings

In [None]:
# Input / output (both paths are joined onto '..' when used)
IMAGE_PATH = 'test.png'     # image to project
SAVE_DIR = 'projector'      # directory receiving one PNG per optimization step

# Optimization schedule
NB_STEPS = 1000             # total number of projection steps
MAX_LEARNING_RATE = 0.1     # peak learning rate, reached after warmup
WARMUP_STEPS = 50           # steps of linear learning-rate warmup
COOLDOWN_STEPS = 250        # final steps of cosine learning-rate ramp-down

# Regularization
W_NOISE_SCALE = 0.05        # initial noise on W, relative to the spread of W
W_NOISE_STEPS = 750         # steps over which the W noise decays to zero
NOISE_REG_STRENGTH = 100000 # weight of the noise-map regularization term

### Import the target image

In [None]:
# Resize the shorter side to IMAGE_SIZE, center-crop to IMAGE_SIZE, then convert
# to a float tensor (ToTensor yields values in [0, 1]).
convert = transforms.Compose([
	transforms.Resize(IMAGE_SIZE),
	transforms.CenterCrop(IMAGE_SIZE),
	transforms.ToTensor()
])

file = os.path.join('..', IMAGE_PATH)

# Pick the PIL mode matching the generator's channel count. Two-channel models
# load RGB and keep the first two channels below.
if NB_CHANNELS == 1:
	mode = 'L'
elif NB_CHANNELS <= 3:
	mode = 'RGB'
else:
	mode = 'RGBA'

# The context manager closes the file handle Image.open keeps; convert() returns
# a fully loaded copy, so it stays valid after the file is closed.
with Image.open(file) as source:
	target = source.convert(mode)

# Map [0, 1] to [-1, 1], the range vgg16_forward maps back from.
target = convert(target) * 2 - 1

if NB_CHANNELS == 2:
	target = target[:2]

# Add a batch dimension: the target is a fixed (1, C, H, W) tensor with no gradient.
target = target.to(DEVICE).unsqueeze(0).detach().requires_grad_(False)

### Import VGG16

In [None]:
# Frozen, ImageNet-pretrained VGG16 used as a fixed perceptual feature extractor.
# Only its convolutional `features` stack is kept (re-wrapped as a Sequential),
# so only that part needs to be moved to DEVICE.
_vgg16_weights = tv.models.VGG16_Weights.DEFAULT
vgg16 = nn.Sequential(*tv.models.vgg16(weights = _vgg16_weights).features).to(DEVICE)
vgg16.eval()
vgg16.requires_grad_(False)

# The preprocessing pipeline torchvision bundles with these weights.
vgg16_transform = _vgg16_weights.transforms()


def vgg16_forward(x: torch.Tensor) -> torch.Tensor:
	"""Return VGG16 convolutional features for a batch of images in [-1, 1].

	The batch is mapped back to [0, 1], passed through the torchvision VGG16
	preprocessing (``vgg16_transform``) and then through ``vgg16`` (the
	convolutional ``features`` stack only).

	VGG16 takes exactly 3 input channels, so a single-channel (grayscale) batch
	is replicated to 3 channels and a batch with more than 3 channels (e.g.
	RGBA) keeps only its first 3. Three-channel batches are unchanged; other
	channel counts are passed through as-is.

	Args:
		x: image batch, assumed (N, C, H, W) with values in [-1, 1].

	Returns:
		The output of ``vgg16`` for the preprocessed batch.
	"""
	x = (x + 1.0) / 2.0

	# Adapt the channel count to what VGG16's first convolution expects.
	channels = x.shape[1]
	if channels == 1:
		x = x.repeat(1, 3, 1, 1)
	elif channels > 3:
		x = x[:, :3]

	x = vgg16_transform(x)
	return vgg16(x)


# VGG16 features of the target image. They stay fixed and are compared with the
# generated image's features in the projection loss.
target_features = vgg16_forward(target)

### Initialize the seed

In [None]:
# Estimate the mean W and its overall spread from MEAN_W_SAMPLES mapped latents.
# w_std is a Python float: the root of the total squared deviation per sample.
ws = model.gen_w(MEAN_W_SAMPLES)
w_mean = ws.mean(dim = 0, keepdim = True).detach()
w_std = (ws - w_mean).square().sum().div(MEAN_W_SAMPLES).sqrt().item()

# Parameters being optimized: W, starting from the mean W, plus one set of
# noise inputs from the generator.
w = w_mean.clone()
w.requires_grad_(True)
noise = [n.requires_grad_(True) for n in model.gen_noise(1)]

# The learning rate is set every step by the schedule in the projection loop.
optimizer = torch.optim.Adam([w, *noise], lr = 0.0)

### Training

In [None]:
def _learning_rate(step: int) -> float:
	"""Return the learning rate for `step`: linear warmup, then half-cosine ramp-down, scaled by MAX_LEARNING_RATE."""
	warmup = min(1.0, step / WARMUP_STEPS)
	cooldown = min(1.0, (NB_STEPS - step) / COOLDOWN_STEPS)
	cooldown = 0.5 - 0.5 * np.cos(cooldown * np.pi)
	return MAX_LEARNING_RATE * warmup * cooldown


def _noise_regularization(noise_maps) -> torch.Tensor:
	"""Penalize spatial autocorrelation of the noise maps at several scales.

	For each map, add the squared mean of its product with a copy rolled by one
	pixel along dim 2 and along dim 3. Then 2x2 average-pool the map and repeat
	until its size along dim 2 is at most 8. Returns a 0-dim tensor (zero when
	there are no noise maps).
	"""
	reg = torch.zeros((), device = noise_maps[0].device) if noise_maps else torch.zeros(())

	for n in noise_maps:
		while True:
			reg = reg + (n * torch.roll(n, shifts = 1, dims = 2)).mean().square()
			reg = reg + (n * torch.roll(n, shifts = 1, dims = 3)).mean().square()

			if n.shape[2] <= 8:
				break

			n = nn.functional.avg_pool2d(n, kernel_size = 2)

	return reg


save_dir = os.path.join('..', SAVE_DIR)
os.makedirs(save_dir, exist_ok = True)

for step in range(NB_STEPS):

	optimizer.zero_grad(set_to_none = True)

	lr = _learning_rate(step)
	for param_group in optimizer.param_groups:
		param_group['lr'] = lr

	# Explore around W with noise that decays quadratically to zero over
	# W_NOISE_STEPS. The perturbed latent is a separate tensor: rebinding `w` would
	# accumulate noise across steps and detach `w` from the optimizer's parameter.
	w_noise_scale = w_std * W_NOISE_SCALE * max(0.0, 1.0 - step / W_NOISE_STEPS) ** 2
	w_noisy = w + w_noise_scale * torch.randn_like(w)

	gen_image = model.synthesis(w_noisy, noise)
	gen_features = vgg16_forward(gen_image)

	main_loss = (gen_features - target_features).square().sum()
	noise_reg = _noise_regularization(noise)
	loss = main_loss + NOISE_REG_STRENGTH * noise_reg

	loss.backward()
	optimizer.step()

	# Re-center each noise map and rescale it to unit mean square. This is done
	# IN PLACE so the tensors stay the ones registered with the optimizer;
	# rebinding them to new tensors would stop the noise from being optimized.
	with torch.no_grad():
		for n in noise:
			n -= n.mean()
			n *= n.square().mean().rsqrt()

	save_image = utils.denormalize(gen_image.detach().squeeze(0))
	Image.fromarray(save_image).save(os.path.join(save_dir, f'{step}.png'))

	print(f'Steps: {step:,} / {NB_STEPS}  |  Loss: {main_loss.item():.3f}  |  Noise Regularisation: {noise_reg.item():.3f}')