In [None]:
import os
# Expose only GPU index 0 to CUDA. This is set before torch is imported
# further down in the notebook.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
import sys
# Make the parent directory and its bundled external dependencies importable
# (both paths are relative to the notebook's working directory).
sys.path.extend(['..', '../external_dependencies'])

import numpy as np
# Re-create the legacy scalar aliases (np.bool, np.int, ...) that NumPy removed
# in 1.24, for older dependencies that still reference them.
#
# np.float_, np.complex_ and np.unicode_ were themselves removed in NumPy 2.0,
# so the aliases are bound to np.float64 / np.complex128 / np.str_ instead.
# On NumPy 1.x these are the very same objects, so behaviour there is unchanged.
np.bool = np.bool_
np.int = np.int_
np.float = np.float64
np.complex = np.complex128
np.object = np.object_
np.unicode = np.str_
np.str = np.str_

In [None]:
import dnnlib
import torch
import pickle
from torch import nn
import numpy as np
from PIL import Image, ImageDraw
from torch.nn import functional as F

from einops import rearrange
from typing import List, Union
from matplotlib import pyplot as plt
from torchvision.utils import make_grid

In [None]:
# Target device used by every `.to(device)` call in this notebook.
device = "cuda"

In [None]:
@torch.no_grad()
def render_tensor(img: Union[torch.Tensor, List[torch.Tensor]], normalize: bool = True, nrow: int = 8) -> Image.Image:
    """Convert an image tensor (or a list of image tensors) into a PIL image.

    Args:
        img: A 2-D, 3-D or 4-D tensor, or a list/tuple of 3-D / 4-D tensors.
            3-D tensors are expanded to size 3 along dim 0 and 4-D tensors
            along dim 1, so that dimension must already be 1 or 3
            (i.e. a channel-first layout is assumed).
        normalize: If True, values are mapped with ``x / 2 + 0.5`` before
            conversion (so [-1, 1] becomes [0, 1]). If False, values are
            expected to already be in [0, 1].
        nrow: Number of images per row, forwarded to ``make_grid`` when the
            input is still 4-D after squeezing.

    Returns:
        A PIL image. Inputs that remain 4-D are tiled into one grid image.

    Raises:
        ValueError: If the tensor is not 2-D, 3-D or 4-D after squeezing
            (the previous implementation silently returned ``None``).
    """
    if isinstance(img, (list, tuple)):
        # Promote unbatched entries to a batch of one, then stack along dim 0.
        img = torch.cat([i if i.dim() == 4 else i[None, ...] for i in img], dim=0).expand(-1, 3, -1, -1)
    elif img.dim() == 3:
        img = img.expand(3, -1, -1)
    elif img.dim() == 4:
        img = img.expand(-1, 3, -1, -1)

    # Drops every singleton dimension (e.g. a batch of one becomes a single image).
    img = img.squeeze()

    if normalize:
        img = img / 2 + .5

    if img.dim() == 4:
        # Tile the batch into a single (3, H', W') grid image.
        img = make_grid(img, nrow=nrow)

    if img.dim() == 3:
        # Channel-first -> channel-last, as expected by PIL.
        img = img.permute(1, 2, 0)
    elif img.dim() != 2:
        raise ValueError(
            f"render_tensor expects a 2-D, 3-D or 4-D image tensor after squeezing, got shape {tuple(img.shape)}"
        )

    # Clip before the uint8 cast so slightly out-of-range values saturate
    # instead of wrapping around.
    pixels = np.clip(img.cpu().numpy() * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)

In [None]:
def to_tensor(img: Union[Image.Image, np.ndarray], normalize=True) -> torch.Tensor:
    """Convert a PIL image or NumPy array into a batched float32 tensor on ``device``.

    Args:
        img: A PIL image, or a NumPy array. PIL images with more than two
            array dimensions are moved from channel-last to channel-first;
            2-D PIL images get a leading channel axis. 3-D NumPy arrays are
            moved to channel-first when they look channel-last (see below);
            arrays of any other rank are passed through unchanged.
        normalize: If True, scale with ``x / 127.5 - 1`` (0..255 -> [-1, 1]);
            otherwise scale with ``x / 255`` (0..255 -> [0, 1]).

    Returns:
        A float32 tensor with a new leading dimension of size 1, moved to the
        notebook-level ``device``.
    """
    if isinstance(img, Image.Image):
        img = np.array(img)
        if img.ndim > 2:
            img = img.transpose(2, 0, 1)
        else:
            img = img[None, ...]
    elif img.ndim == 3:
        # Layout detection is heuristic. The original rule (first two dims
        # equal -> channel-last) is kept, and additionally a trailing dim of
        # 1/3/4 with a leading dim that is not a plausible channel count is
        # treated as channel-last, so non-square H x W x C images are handled.
        # Restricting this branch to 3-D input also avoids the crash the
        # original had on square 2-D arrays (transpose(2, 0, 1) on rank 2).
        channel_counts = (1, 3, 4)
        leading_dims_equal = img.shape[0] == img.shape[1]
        looks_channel_last = img.shape[2] in channel_counts and img.shape[0] not in channel_counts
        if leading_dims_equal or looks_channel_last:
            img = img.transpose(2, 0, 1)

    tensor = torch.from_numpy(img).to(torch.float32)
    if normalize:
        tensor = tensor / 127.5 - 1
    else:
        tensor = tensor / 255.
    return tensor[None, ...].to(device)

In [None]:
import pickle
import random as r
from camera_utils import LookAtPoseSampler, FOV_to_intrinsics
from torch_utils import misc

In [None]:
from tqdm import tqdm

In [None]:
from external_dependencies.decalib import DECAWrapper
# DECA wrapper instance on `device`; its `encode` method is used by
# `get_coeffs` below to extract coefficients from rendered images.
deca = DECAWrapper(device)

In [None]:
# Load the EMA generator from the checkpoint, move it to `device`, and freeze it
# (eval mode, no gradients).
# NOTE(security): pickle.load can execute arbitrary code contained in the file;
# only load checkpoints obtained from a trusted source.
with open("../data/vfhq-celebv-text-64.pkl", "rb") as f:
    G = pickle.load(f)["G_ema"].to(device).eval().requires_grad_(False)

# Attach the mask read from plane_0.png as a (1, 1, H, W) float tensor in [0, 1].
# The image is opened in a `with` block so the underlying file is closed
# as soon as the pixel data has been copied into NumPy.
with Image.open('../data/plane_0.png') as exp_mask_file:
    exp_mask_u8 = np.array(exp_mask_file.convert('L'))
G.exp_mask = (torch.from_numpy(exp_mask_u8).to(torch.float32) / 255.)[None, None, :, :].to(device)

In [None]:
# Fixed camera used as conditioning for every sample in this notebook.
fov_deg = 18.837
intrinsics = FOV_to_intrinsics(fov_deg, device=device)

# Pivot and radius are read from the generator's rendering settings, falling
# back to the defaults given here when the keys are absent.
cam_pivot = torch.tensor(
    G.rendering_kwargs.get('avg_camera_pivot', [0, 0, 0]),
    device=device,
)
cam_radius = G.rendering_kwargs.get('avg_camera_radius', 2.7)

conditioning_cam2world_pose = LookAtPoseSampler.sample(
    np.pi / 2, np.pi / 2, cam_pivot, radius=cam_radius, device=device
)

# One row per camera: the pose flattened to 16 values followed by the
# intrinsics flattened to 9 values.
conditioning_params = torch.cat(
    (conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)),
    dim=1,
)

In [None]:
# Fixed 9-value input for the generator's `lmapping` network, shaped (1, 9).
# NOTE(review): presumably lighting coefficients — confirm against the generator.
l = torch.tensor(
    [0.8214, 0.0908, 0.3353, -0.1008, 0.1011, 0.1123, -0.1217, -0.1401, 0.0878]
).unsqueeze(0).to(device)
wl = G.backbone.lmapping(None, l)

In [None]:
from training.encoder import Encoder
# Build the encoder on `device`, initially frozen and in eval mode.
# NOTE(review): the first argument (50 + 3) presumably matches the width of the
# coefficient vector returned by `get_coeffs`; the meaning of the remaining
# positional arguments is defined in training.encoder — confirm there.
encoder = Encoder(50 + 3, 64, 128, 3, 5).eval().requires_grad_(False).to(device)

In [None]:
# Switch the encoder to training mode with gradients enabled, and optimise all
# of its parameters with Adam at a learning rate of 1e-3.
encoder.train().requires_grad_(True)
optimizer = torch.optim.Adam(encoder.parameters(), lr=0.001)

In [None]:
def get_coeffs(image):
    """Extract the DECA coefficient vector for a single image tensor.

    The image is clamped to [-1, 1], rendered to a PIL image with
    ``render_tensor`` and passed to ``deca.encode``. The returned tensor is the
    ``'exp'`` entry concatenated, along the last dimension, with columns 3
    onward of the ``'pose'`` entry.

    NOTE(review): presumably expression plus the non-global part of the pose;
    confirm against the DECA wrapper.
    """
    pil_image = render_tensor(image.clamp(-1, 1))
    codedict = deca.encode(pil_image)
    parts = (codedict['exp'], codedict['pose'][:, 3:])
    return torch.cat(parts, dim=-1)

In [None]:
def batch_get_coeffs(images):
    """Apply ``get_coeffs`` to each image along dim 0 and concatenate the results.

    Images are processed one at a time; the per-image outputs are joined
    along dim 0.
    """
    per_image = []
    for single_image in images.unbind(dim=0):
        per_image.append(get_coeffs(single_image))
    return torch.cat(per_image)

In [None]:
from lpips import LPIPS

In [None]:
# LPIPS (VGG backbone) on `device`; used as part of the image loss in the
# training loop.
lpips_fn = LPIPS(net='vgg').to(device)

In [None]:
# Number of samples generated per training step.
batch_size = 4

In [None]:
from IPython.display import clear_output

In [None]:
# Train `encoder` to recover the generator's `d`-branch latent from DECA
# coefficients of a generated image:
#   1. sample random z / d latents and render a target image with G,
#   2. extract coefficients from the target image via `batch_get_coeffs`,
#   3. predict the latent as encoder(coeffs) + dmapping.w_avg and re-render,
#   4. minimise L1 + LPIPS between the two images plus 0.1 * (1 - cosine
#      similarity) between the true and predicted latents.
loss_s = []
step = 0
skipped = 0  # batches dropped because coefficient extraction raised

# None keeps the original behaviour (run until interrupted); set an integer to
# make the cell terminate on its own, e.g. for Restart & Run All.
max_steps = None

# Loop-invariant objects, built once instead of on every iteration.
l1_loss_fn = torch.nn.L1Loss()
cosine_fn = torch.nn.CosineSimilarity(dim=-1)
batch_cond = conditioning_params.expand(batch_size, -1)
batch_wl = wl.expand(batch_size, -1, -1)

while max_steps is None or step < max_steps:
    z = torch.randn(batch_size, 512).to(device)
    # Original inline note on this call was "# .5" (presumably an alternative
    # value for the third argument).
    w = G.backbone.mapping(z, batch_cond, 1.)

    d = torch.randn(batch_size, G.d_dim).to(device)
    wd = G.backbone.dmapping(d, None)

    out = G.synthesis(w, wd, batch_wl, batch_cond, use_exp_mask=True)
    try:
        coeffs = batch_get_coeffs(out["image"])
    except Exception as exc:
        # Best-effort: skip batches the coefficient extractor cannot handle.
        # Catching Exception (not a bare except) lets KeyboardInterrupt stop
        # the loop, and the throttled message keeps repeated failures visible.
        skipped += 1
        if skipped == 1 or skipped % 100 == 0:
            print(f"coefficient extraction failed ({skipped} batches skipped so far): {exc!r}")
        continue
    optimizer.zero_grad()

    # The encoder predicts an offset from the average latent of `dmapping`;
    # the single predicted vector is broadcast to wd's shape for synthesis.
    pred_wd = encoder(coeffs)[:, None, :] + G.backbone.dmapping.w_avg[None, None, :]
    pred_out = G.synthesis(w, pred_wd.expand_as(wd), batch_wl, batch_cond, use_exp_mask=True)

    image_loss = l1_loss_fn(out["image"], pred_out["image"]) + lpips_fn(out["image"], pred_out["image"]).mean()
    code_loss = 1. - cosine_fn(wd[:, 0, :], pred_wd[:, 0, :]).mean()

    loss = image_loss + code_loss * 0.1
    loss.backward()
    loss_s.append(float(loss))
    optimizer.step()

    step += 1

    # Refresh the progress display on the first step and every 100 steps after.
    if step % 100 == 0 or step == 1:
        clear_output(wait=True)
        print(step)
        if skipped:
            print(f"skipped batches: {skipped}")
        plt.plot(loss_s[-100:])
        plt.show()
        # Top row: target renders; bottom row: renders from the predicted latent.
        display(render_tensor([out["image"].clamp(-1, 1), pred_out["image"].clamp(-1, 1)], nrow=batch_size))

In [None]:
# Persist the trained encoder weights to 'encoder.pth' in the current working
# directory. This cell only runs once the training loop above has stopped.
torch.save(encoder.state_dict(), 'encoder.pth')