In [None]:
#!L
backend = 'None' # 'YC', 'Colab'

if backend == 'Colab':
    !pip install lpips
    !git clone https://github.com/yandexdataschool/Practical_DL.git
    !sudo apt install -y ninja-build
    %cd /content/Practical_DL/seminar07-gen_models_2
    !wget https://www.dropbox.com/s/2kpsomtla61gjrn/pretrained.tar
    !tar -xvf pretrained.tar
elif backend == 'YC':
    # Yandex Cloud (temporary unavailable)
    %wget https://www.dropbox.com/s/2kpsomtla61gjrn/pretrained.tar
    %tar -xvf pretrained.tar

In [None]:
# Fetch and unpack the pretrained checkpoints (the `pretrained/...` files used below).
# Skip the download when the archive is already present (e.g. fetched by the Colab
# setup branch), so re-running the cell does not download it again.
![ -f pretrained.tar ] || wget https://www.dropbox.com/s/2kpsomtla61gjrn/pretrained.tar
!tar -xvf pretrained.tar

import sys
# Make the seminar's `gans` package importable; the membership check avoids
# piling up duplicate sys.path entries on re-runs.
# NOTE(review): the Colab setup branch `%cd`s into seminar07-gen_models_2 while this
# path points at seminar08-gen_models_3 — confirm which directory hosts `gans`.
_seminar_dir = '/content/Practical_DL/seminar08-gen_models_3'
if _seminar_dir not in sys.path:
    sys.path.append(_seminar_dir)

In [None]:
#!L
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm.auto import tqdm, trange

from gans.gan_load import make_stylegan2

# Report the number of visible CUDA devices and the installed PyTorch version.
print(torch.cuda.device_count())
print(torch.__version__)


def to_image(tensor, adaptive=False):
    """Convert an image tensor to a PIL image.

    Args:
        tensor: a (C, H, W) tensor, or a 4-D batch of which only the first
            element is converted.
        adaptive: if True, min-max normalize the tensor to [0, 1]; otherwise
            map values from [-1, 1] (the range produced by ``2 * x - 1`` in
            ``load_image``) to [0, 1], clamping anything outside.

    Returns:
        A ``PIL.Image`` built from the uint8-scaled tensor.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    tensor = tensor.detach().cpu()

    if adaptive:
        low = tensor.min()
        # Guard the denominator: a constant tensor would otherwise become NaN.
        value_range = (tensor.max() - low).clamp_min(1e-8)
        tensor = (tensor - low) / value_range
    else:
        tensor = ((tensor + 1) / 2).clamp(0, 1)

    return ToPILImage()((255 * tensor).to(torch.uint8))


def to_image_grid(tensor, adaptive=False, **kwargs):
    """Tile a batch of images with ``make_grid`` and convert the grid to a PIL image.

    Extra keyword arguments (e.g. ``nrow``) are forwarded to ``make_grid``;
    ``adaptive`` is forwarded to ``to_image``.
    """
    grid = make_grid(tensor, **kwargs)
    return to_image(grid, adaptive=adaptive)

In [None]:
# NOTE: `!cd` runs in a throwaway subshell, so this line does NOT change the
# kernel's working directory (that would require `%cd`). The relative
# 'pretrained/...' paths used below resolve against the kernel's current directory.
!cd /content/Practical_DL

In [None]:
#!L
# Pretrained StyleGAN2 (FFHQ, 1024x1024) in eval mode.
G = make_stylegan2(resolution=1024,
                   weights='pretrained/stylegan2-ffhq-config-f.pt', target_key='g').eval()


# Sanity check: render a few random samples. Use the generator's own latent size
# (as the inversion and style-mixing cells do) rather than a hard-coded 512.
with torch.no_grad():
    z = torch.randn([4, G.dim_z]).cuda()
    imgs = G(z)

plt.figure(dpi=200)
plt.axis('off')
plt.imshow(to_image_grid(imgs, nrow=6))

# Naive inversions

In [None]:
# Helpers to download a photo and preprocess it into the generator's input format.
import requests
from io import BytesIO
from torchvision import transforms

# NOTE(review): this module-level `zoom` is not read by `load_image` below, which
# takes its own `zoom` argument (default 1.0).
zoom = 1.


def portrait_crop(img, h_percent, w_percent):
    """Crop an image to the top ``h_percent`` of its height and the central
    ``w_percent`` of its width.

    ``img`` must expose ``size`` as ``(width, height)`` and a ``crop(box)`` method
    (as PIL images do); the box is passed as ``[left, top, right, bottom]``.
    """
    width, height = img.size
    side_margin = int(0.5 * (1 - w_percent) * width)
    bottom = int(h_percent * height)
    box = [side_margin, 0, width - side_margin, bottom]
    return img.crop(box)


def load_image(img_url, zoom=1.0, w=1.0, h=1.0):
    """Download an image and turn it into a generator-compatible tensor.

    Args:
        img_url: URL of the image to fetch.
        zoom: the shorter side is resized to ``int(zoom * 1024)`` pixels before
            the 1024x1024 center crop.
        w: fraction of the width to keep (centered), applied before resizing.
        h: fraction of the height to keep (from the top), applied before resizing.

    Returns:
        A (1, 3, 1024, 1024) CUDA tensor with values in [-1, 1].

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    # portrait_crop's signature is (img, h_percent, w_percent): pass h, then w.
    crop = lambda x: portrait_crop(x, h, w)

    normalization = transforms.Compose([
        crop,
        transforms.Resize(int(zoom * 1024)),
        transforms.CenterCrop(1024),
        transforms.ToTensor(),
        lambda x: 2 * x - 1,
    ])

    response = requests.get(img_url, timeout=30)
    response.raise_for_status()
    # Force three channels: grayscale / RGBA / palette images would otherwise
    # produce 1- or 4-channel tensors.
    img = Image.open(BytesIO(response.content)).convert('RGB')
    return normalization(img).unsqueeze(0).cuda()


# Target photos to invert; `img` (the first one) is used by the cells below.
imgs = [
    load_image('https://fotorelax.ru/wp-content/uploads/2015/08/Daniel-Jacob-Radcliffe_6.jpg'),
    load_image('https://www.kinogallery.com/pimages/742/kinogallery.com-742-520627.jpg', h=0.8),
]
img = imgs[0]

plt.figure(dpi=200)
plt.axis('off')
plt.imshow(to_image(img))

In [None]:
import lpips

# LPIPS perceptual metric, evaluated on 256x256 bilinear downsamples.
lpips_model = lpips.LPIPS()
lpips_model.cuda().eval()


def lpips_dist(x, y):
    """LPIPS distance between two image batches after resizing both to 256x256."""
    resize = lambda t: F.interpolate(t, 256, mode='bilinear')
    return lpips_model(resize(x), resize(y))


# Identity feature extractor: the ResNet-18 backbone of the CelebA regressor
# (same as at Seminar 7) with a ReLU -> Linear(512, 512) -> ReLU head.
face_fe = torchvision.models.resnet18()
face_fe.fc = nn.Sequential(nn.ReLU(), nn.Linear(512, 512), nn.ReLU())

# Keep only the backbone weights and strip their 'backbone.' prefix.
regressor_state = torch.load('pretrained/regressor.pth')['model_state_dict']
backbone_prefix = 'backbone.'
state_dict = {
    key[len(backbone_prefix):]: tensor
    for key, tensor in regressor_state.items()
    if key.startswith(backbone_prefix)
}

face_fe.load_state_dict(state_dict)
face_fe.cuda().eval();

In [None]:
def invert(img, G, latent_init, n_steps=500, lr=0.025,
           l2_loss_scale=0.1, lpips_loss_scale=1.0, id_loss_scale=1.0,
           latent_map=lambda x: x, **g_kwargs):
    """Invert ``img`` into the latent space of ``G`` by gradient descent.

    Optimizes a copy of ``latent_init`` with Adam so that
    ``G(latent_map(latent), **g_kwargs)`` matches ``img`` under a weighted sum of
    a pixel MSE, the module-level ``lpips_dist`` and an identity loss (MSE between
    module-level ``face_fe`` features). A term is skipped when its scale is not positive.

    Args:
        img: target image batch; presumably (1, 3, H, W) in [-1, 1] on the GPU,
            as produced by ``load_image`` (TODO confirm for other inputs).
        G: generator, called as ``G(latent_map(latent), **g_kwargs)``.
        latent_init: starting latent. It is copied, so the caller's tensor is
            never modified by the optimization.
        n_steps: number of Adam steps.
        lr: Adam learning rate.
        l2_loss_scale, lpips_loss_scale, id_loss_scale: loss weights.
        latent_map: maps the optimized tensor to what ``G`` expects
            (e.g. ``lambda w: [w]`` for W+ latents).
        **g_kwargs: forwarded to ``G`` (e.g. ``w_space=True``).

    Returns:
        ``(reconstruction, latent, losses)``: the reconstruction rendered from the
        final latent (computed without autograd), the optimized ``nn.Parameter``,
        and the list of per-step total losses.
    """
    # nn.Parameter shares storage with the tensor it wraps; clone so Adam's
    # in-place updates cannot overwrite the caller's tensor (e.g. w_mean, w_inversion).
    latent = nn.Parameter(latent_init.detach().clone().cuda())
    opt = torch.optim.Adam([latent], lr=lr)

    # The target's identity features do not change across steps: compute them once.
    target_id_features = None
    if id_loss_scale > 0.0:
        with torch.no_grad():
            target_id_features = face_fe(img)

    l2_losses = []
    perceptual_losses = []
    id_losses = []
    losses = []
    for i in trange(n_steps):
        opt.zero_grad()

        reconstruction = G(latent_map(latent), **g_kwargs)
        # Placeholder for disabled terms, on the same device as the active ones.
        zero = torch.zeros([], device=reconstruction.device)
        l2_loss, perceptual_loss, id_loss = zero, zero, zero
        if l2_loss_scale > 0.0:
            l2_loss = F.mse_loss(img, reconstruction)
        if lpips_loss_scale > 0.0:
            perceptual_loss = lpips_dist(img, reconstruction).mean()
        if id_loss_scale > 0.0:
            id_loss = F.mse_loss(target_id_features, face_fe(reconstruction))

        loss = l2_loss_scale * l2_loss + lpips_loss_scale * perceptual_loss + id_loss_scale * id_loss
        # With every term disabled the loss is a constant with no graph to backprop.
        if loss.requires_grad:
            loss.backward()

        l2_losses.append(l2_loss.item())
        perceptual_losses.append(perceptual_loss.item())
        id_losses.append(id_loss.item())
        losses.append(loss.item())

        opt.step()
        if i % 100 == 0:
            print(f'{i}: loss: {np.mean(losses[-100:]): 0.2f}; '
                  f'l2-loss: {np.mean(l2_losses[-100:]): 0.2f}; '
                  f'lpips loss: {np.mean(perceptual_losses[-100:]): 0.2f}; '
                  f'id-loss: {np.mean(id_losses[-100:]): 0.2f}')

    # Render the final latent: inside the loop `reconstruction` predates the last
    # optimizer step. This also defines the result when n_steps == 0.
    with torch.no_grad():
        reconstruction = G(latent_map(latent), **g_kwargs)

    return reconstruction, latent, losses


def show_inversion_result(img, reconstruction, losses=None):
    """Plot the target, its reconstruction and (optionally) the loss curve side by side.

    Args:
        img: target image tensor (anything ``to_image_grid`` accepts).
        reconstruction: reconstructed image tensor.
        losses: optional sequence of per-step loss values. When None or empty,
            the third panel is hidden instead of being left as an empty frame.
    """
    _, axs = plt.subplots(1, 3, dpi=250)
    for ax in axs[:2]:
        ax.axis('off')

    axs[0].imshow(to_image_grid(img))
    axs[1].imshow(to_image_grid(reconstruction))
    if losses is not None and len(losses) > 0:
        # Aspect len / max(losses), as before; guard max == 0 to avoid dividing by zero.
        max_loss = max(float(np.max(losses)), 1e-8)
        axs[2].set_aspect(len(losses) / max_loss)
        axs[2].set_title('Loss')
        axs[2].plot(losses)
    else:
        axs[2].axis('off')

In [None]:
# Z-space inversion, starting from a random latent sample.
rec, z, losses = invert(img, G, latent_init=torch.randn([1, G.dim_z]), n_steps=100)
show_inversion_result(img, rec, losses)

In [None]:
# W-space inversion, starting from the mean latent returned by mean_latent(64).
w_mean = G.style_gan2.mean_latent(64)
rec, w, losses = invert(img, G, latent_init=w_mean, n_steps=100, w_space=True)
show_inversion_result(img, rec, losses)

In [None]:
# W+ inversion: the mean latent is repeated into 18 slots that are optimized
# independently and passed to G as a one-element list.
w_mean = G.style_gan2.mean_latent(64)
rec, w_plus, losses = invert(
    img, G,
    latent_init=w_mean.unsqueeze(1).repeat(1, 18, 1),
    n_steps=100,
    latent_map=lambda latent: [latent],
    w_space=True,
)
show_inversion_result(img, rec, losses)

# Pixel2Style2Pixel (pSp) encoder

Reference implementation: https://github.com/eladrich/pixel2style2pixel

In [None]:
# Clone the pSp repository once: `git clone` fails on re-run because the
# destination directory already exists.
![ -d pixel2style2pixel ] || git clone https://github.com/eladrich/pixel2style2pixel
!touch pixel2style2pixel/__init__.py

In [None]:
import sys
from argparse import Namespace

# Make the cloned pSp repository importable (guarded against duplicate entries on re-runs).
if 'pixel2style2pixel' not in sys.path:
    sys.path.append('pixel2style2pixel')
from models.encoders.psp_encoders import GradualStyleEncoder

# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
# Load to CPU; the encoder and the average latent are moved to the GPU below.
encoder_chkpt = torch.load('pretrained/psp_ffhq_encode.pt', map_location='cpu')
encoder = GradualStyleEncoder(50, 'ir_se', opts=Namespace(**encoder_chkpt['opts']))

# Keep only the encoder weights and strip their 'encoder.' prefix. The trailing dot
# avoids picking up (and mis-stripping) keys that merely start with "encoder".
encoder_prefix = 'encoder.'
encoder_state = {name[len(encoder_prefix):]: val
                 for name, val in encoder_chkpt['state_dict'].items()
                 if name.startswith(encoder_prefix)}

encoder.load_state_dict(encoder_state)
encoder.cuda().eval()
# Average latent stored in the checkpoint; it is added to the encoder output at inference.
latent_mean = encoder_chkpt['latent_avg'].cuda()

In [None]:
with torch.no_grad():
    # Encode a 256x256 downsample of the target and shift it by the checkpoint's
    # average latent to get the W+ code, then render it.
    w_inversion = latent_mean[None] + encoder(F.interpolate(img, 256, mode='bilinear'))
    rec = G([w_inversion], w_space=True)
show_inversion_result(img, rec)

### Pixel2Style2Pixel initialization refined by latent optimization

In [None]:
# Refine the pSp encoding by optimization. Pass a copy: wrapping a tensor in
# nn.Parameter shares its storage, so the optimizer could otherwise overwrite
# w_inversion in place, and re-running this cell would then start from the
# already-optimized latent instead of the encoder output.
rec, w_plus, losses = invert(img, G, w_inversion.clone(), n_steps=100, lr=0.005,
                             latent_map=lambda w_plus: [w_plus], w_space=True)
show_inversion_result(img, rec, losses)

# Style mixing

In [None]:
# Two random faces: `target` copies one W vector into all 18 W+ slots;
# `source` is a single W vector that is mixed into the target's slots below.
with torch.no_grad():
    w_target = G.style_gan2.style(torch.randn([1, G.dim_z], device='cuda'))
    w_target = w_target.unsqueeze(1).repeat(1, 18, 1)
    w_source = G.style_gan2.style(torch.randn([1, G.dim_z], device='cuda'))
    target, source = G(w_target, w_space=True), G(w_source, w_space=True)

plt.axis('off')
plt.imshow(to_image_grid(torch.cat([target, source])))

In [None]:
# For i = 0..17, the first i W+ slots take the source style and the rest keep the
# target style (i = 0 reproduces the target).
styled_rec = []
# Inference only: without no_grad, every 1024px generator call builds an autograd graph.
with torch.no_grad():
    for i in range(18):
        w_styled = w_target.clone()
        w_styled[:, :i] = w_source
        styled_rec.append(G([w_styled], w_space=True).detach().cpu())

In [None]:
# All 18 style-mixed images tiled into a single grid.
fig, ax = plt.subplots(dpi=250)
ax.axis('off')
ax.imshow(to_image_grid(torch.cat(styled_rec)))