In [1]:
import argparse
import os

import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

import lpips
from model import Generator
from utils import *

In [2]:
# --- Configuration -----------------------------------------------------------
# Paths are assembled with os.path.join so the notebook also runs on
# non-Windows hosts; on Windows the resulting strings are unchanged.
imgfiles = [os.path.join('..', 'woman.png')]
ckpt = os.path.join('..', 'stylegan2-ffhq-config-f.pt')
# Number of random latents pushed through the mapping network for statistics.
n_mean_latent = 10**6

# `size` is handed to Generator(...) below; input images are resized and
# center-cropped to `resize` pixels (capped at 256).
size = 1024
resize = min(256, size)

# Prefer the GPU, but fall back to the CPU instead of failing outright.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Preprocessing applied to every input image: resize, center-crop, convert to
# a tensor, then normalize each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


In [3]:
# Load every input image, apply the shared preprocessing, and batch the results.
img_tensors = []
for imgfile in imgfiles:
    # The context manager releases the file handle once the image has been
    # converted and transformed, instead of leaving it to garbage collection.
    with Image.open(imgfile) as src:
        img_tensors.append(transform(src.convert("RGB")))

# Stack along a new leading (batch) dimension and move to the compute device.
imgs = torch.stack(img_tensors, 0).to(device)

In [4]:
# Build the generator and restore the "g_ema" weights from the checkpoint.
g_ema = Generator(size, 512, 8)
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source.
# map_location="cpu" keeps the rest of the checkpoint off the GPU; only the
# generator weights are moved to `device` below.
checkpoint = torch.load(ckpt, map_location="cpu")
# strict=False tolerates key mismatches between checkpoint and model.
g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
del checkpoint  # the full checkpoint is no longer needed once weights are copied
g_ema.eval()
g_ema = g_ema.to(device)

# Prepare data

In [5]:
# Push n_mean_latent random vectors through g_ema.style in chunks (to bound
# peak memory per call) and compute the mean of the resulting latents.
CHUNK = 10**5
count = n_mean_latent
torch.manual_seed(1)  # fixed seed so the sampled latents are reproducible
latent_out_list = []
with torch.no_grad():
    while count > 0:
        print(count)
        # Draw only what is still missing, so the total is exactly
        # n_mean_latent even when it is not a multiple of CHUNK.
        n_draw = min(CHUNK, count)
        noise_sample = torch.randn(n_draw, 512, device=device)
        latent_out_list.append(g_ema.style(noise_sample))
        count -= n_draw
    print("stacking...")
    latent_out = torch.vstack(latent_out_list)
    latent_mean = latent_out.mean(0)
    # latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

1000000
900000
800000
700000
600000
500000
400000
300000
200000
100000
stacking...


In [7]:
# Persist the sampled latents together with their mean so later sections can
# reload them without re-sampling.
latent_out_dict = dict(W=latent_out, mean=latent_mean)
torch.save(latent_out_dict, 'latent_back')

In [11]:
# Apply a leaky ReLU with negative slope 5 to the sampled latents and store the
# result under the key 'P' for the SVD section.
P = F.leaky_relu(latent_out, 5)
P_dict = dict(P=P)
torch.save(P_dict, 'P')

# Load P and do SVD (for PCA)

In [1]:
import torch
# Reload the tensor written to the file 'P' by the data-preparation section.
P_dict = torch.load(f='P')
P = P_dict["P"]

In [2]:
# Center P and take its reduced SVD; S holds the singular values and the
# columns of V the right singular vectors.
P_mean = P.mean(0)
# torch.svd is deprecated; torch.linalg.svd(full_matrices=False) is the
# documented replacement. It returns V transposed (Vh), so transpose back to
# keep the same V convention as before. Slicing off U right away lets its
# memory be released instead of keeping it alive under a throwaway name.
S, Vh = torch.linalg.svd(P - P_mean, full_matrices=False)[1:]
V = Vh.transpose(-2, -1).contiguous()

In [4]:
# Store the singular values, right singular vectors and centering mean for the
# projection section.
svd_dict = dict(S=S, V=V, mean=P_mean)
torch.save(svd_dict, 'svd_S_V')

In [3]:
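# Optional sanity check (left commented out): compare the top-2 singular values
# from a low-rank PCA against the first two values of the full SVD above.
# _, S_2, V_2 = torch.pca_lowrank(P, 2, center=True)
# print(S_2, S[:2])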

tensor([5374.1353, 4875.6348], device='cuda:0') tensor([5470.8291, 5015.5801], device='cuda:0')


# Compute the latent std and save it with the mean

In [2]:
# Reload the sampled latents and their mean from the 'latent_back' file.
latent_dict = torch.load(f='latent_back')
latent_mean = latent_dict["mean"]

In [5]:
latent_out = latent_dict['W']
# Take the sample count from the loaded tensor rather than from a variable set
# in another cell/session, so the divisor always matches the data.
num_samples = latent_out.shape[0]
# Scalar spread: squared deviations are summed over every entry of the tensor
# and divided by (num_samples - 1) before taking the square root.
latent_std = ((latent_out - latent_mean).pow(2).sum() / (num_samples - 1)) ** 0.5

In [6]:
# Save only the summary statistics; the projection section needs nothing else
# from the sampled latents.
latent_stat = dict(mean=latent_mean, std=latent_std)
torch.save(latent_stat, 'latent_stat')

# II2S

In [10]:
# Reload the latent statistics saved earlier.
latent_stat = torch.load('latent_stat')
latent_mean = latent_stat['mean']
latent_std = latent_stat['std']

# Reload the SVD results and precompute the element-wise inverse of the
# singular values used by PN_loss.
svd_dict = torch.load('svd_S_V')
S = svd_dict['S']
V = svd_dict['V']
P_mean = svd_dict['mean']
S_inv = S.reciprocal()


In [11]:
def PN_loss(latent_in, P_mean, V, S_inv):
    """Mean squared magnitude of a latent after projection and rescaling.

    The latent is passed through a leaky ReLU with negative slope 5, centered
    with ``P_mean``, multiplied by ``V``, and scaled element-wise by ``S_inv``.

    Args:
        latent_in: Latent tensor whose last dimension matches the rows of ``V``.
        P_mean: Mean subtracted after the leaky ReLU.
        V: Projection matrix applied on the right.
        S_inv: Element-wise scale applied to the projected coordinates.

    Returns:
        Scalar tensor: the mean of the squared scaled coordinates.
    """
    activated = F.leaky_relu(latent_in, negative_slope=5)
    centered = activated - P_mean
    scaled = torch.matmul(centered, V) * S_inv
    return scaled.pow(2).mean()

In [7]:
# Perceptual loss network; it is placed on the GPU only when `device` names a
# CUDA device.
percept = lpips.PerceptualLoss(
    model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
)

# Take the generator's noise tensors, repeat each one along the first dimension
# once per input image, and refill it in place with fresh normal samples.
noises_single = g_ema.make_noise()
noises = [
    base_noise.repeat(imgs.shape[0], 1, 1, 1).normal_()
    for base_noise in noises_single
]



Setting up Perceptual loss...
Loading model from: d:\S\PyCharmProject\StyleGAN2-ADA-main\stylegan2-pytorch\lpips\weights\v0.1\vgg.pth
...[net-lin [vgg]] initialized
...Done


In [9]:
# Start every image from the mean latent: one copy per image, then one copy per
# generator latent slot (g_ema.n_latent).
n_imgs = imgs.shape[0]
mean_row = latent_mean.detach().clone().unsqueeze(0)
latent_in = mean_row.repeat(n_imgs, 1)
latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

In [12]:
# Only the latent code is optimized; the noise tensors stay fixed.
latent_in.requires_grad_(True)

for noise in noises:
    noise.requires_grad_(False)

In [16]:
# Optimization schedule.
lr_init = 1e-2   # base learning rate passed to get_lr
step = 1300      # number of optimization iterations

# Loss weights (multiplied into the total loss in the optimization loop).
mse = 0.1        # weight of the pixel MSE term
pn = 1e-3        # weight of the PN_loss term
noise_regulr = 1e5  # weight of the noise-regularization term

# Latent-noise injection; a rate of 0.0 disables it in the loop.
noise_rate = 0.0
noise_ramp = 0.75

# Alternative that would also optimize the noise tensors:
# optimizer = optim.Adam([latent_in] + noises, lr=lr_init)
optimizer = optim.Adam(params=[latent_in], lr=lr_init)




In [17]:
# Optimize latent_in so the generated image matches `imgs` under a weighted sum
# of perceptual, pixel-MSE and PN losses. A snapshot of the latent is recorded
# every 100 iterations in latent_path.
pbar = tqdm(range(step))
latent_path = []
for i in pbar:
    # Fraction of the run completed; drives the learning-rate schedule and the
    # latent-noise ramp.
    t = i / step
    lr = get_lr(t, lr_init)
    optimizer.param_groups[0]["lr"] = lr
    if noise_rate > 0:
        noise_strength = latent_std * noise_rate * max(0, 1 - t / noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())
    else:
        latent_n = latent_in
    img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)

    batch, channel, height, width = img_gen.shape

    if height > 256:
        # Average-pool the generated image down to 256 pixels per side by
        # averaging over factor x factor blocks.
        factor = height // 256

        img_gen = img_gen.reshape(
            batch, channel, height // factor, factor, width // factor, factor
        )
        img_gen = img_gen.mean([3, 5])

    p_loss = percept(img_gen, imgs).sum()

    # n_loss = noise_regularize(noises)
    # Noise regularization is disabled; keep a zero placeholder on the same
    # device as the generated image rather than hard-coding CUDA.
    n_loss = torch.zeros(1, device=img_gen.device)

    mse_loss = F.mse_loss(img_gen, imgs)

    # PN_loss has no default arguments: the SVD statistics must be passed
    # explicitly, otherwise the call raises TypeError.
    pn_loss = PN_loss(latent_in, P_mean, V, S_inv)

    loss = p_loss + noise_regulr * n_loss + mse * mse_loss + pn * pn_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    noise_normalize_(noises)

    if (i + 1) % 100 == 0:
        latent_path.append(latent_in.detach().clone())

    pbar.set_description(
        (
            f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
            f" mse: {mse_loss.item():.4f}; lr: {lr:.6f}"
        )
    )

# Render the final result from the optimized latent itself. Using latent_in
# (rather than the last 100-step snapshot) keeps the saved image and the saved
# latent in agreement for any `step`, and avoids an IndexError when no snapshot
# was recorded. no_grad skips building an autograd graph for this forward pass.
with torch.no_grad():
    img_gen, _ = g_ema([latent_in.detach()], input_is_latent=True, noise=noises)

# The combined results file is named after the first input image.
filename = os.path.splitext(os.path.basename(imgfiles[0]))[0] + ".pt"

img_ar = make_image(img_gen)

result_file = {}
for i, input_name in enumerate(imgfiles):
    # Keep a batch dimension of 1 on each per-image noise slice.
    noise_single = [noise[i : i + 1] for noise in noises]

    result_file[input_name] = {
        "img": img_gen[i],
        "latent": latent_in[i].detach(),
        "noise": noise_single,
    }

    # Write the projected image next to the notebook as "<name>-project.png".
    img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
    pil_img = Image.fromarray(img_ar[i])
    pil_img.save(img_name)

torch.save(result_file, filename)


perceptual: 0.0992; noise regularize: 0.0000; mse: 0.0070; lr: 0.000000: 100%|██████████| 1300/1300 [04:43<00:00,  4.58it/s]


In [52]:
# Spot-check the learning-rate schedule at t = 1/1000 with a base rate of 0.1.
get_lr(1e-3, 0.1)

0.002