# Setup the environment

In [None]:
# Exact version pins so a fresh runtime installs the same releases on every run.
!pip install kornia==0.4.0 tqdm==4.45.0 sk-video==1.1.10

In [None]:
!git clone https://github.com/tals/derivative-works
!git clone https://github.com/jacobgil/dlib_facedetector_pytorch
    
 # Used for BigGAN source image generation and there's still some bad dependencies on it
!pip install pytorch_pretrained_biggan

In [None]:
import sys
import os
# Prepend the cloned checkouts so their modules take precedence on import.
# Iteration order matches the original inserts, so the dlib checkout ends up first.
for _repo_dir in ("derivative-works/research", "dlib_facedetector_pytorch"):
    sys.path.insert(0, os.path.join(os.getcwd(), _repo_dir))

In [None]:
%matplotlib inline
# Reload edited modules (e.g. the cloned `src` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys, math, json, random
import numpy as np
import kornia
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from tqdm.autonotebook import tqdm
from pytorch_pretrained_biggan import BigGAN, convert_to_images
from src.notebook_utils import imshow, imgrid, pltshow, draw_tensors
from src.pytorch_utils import augment
from src.palette import random_biggan, load_directory, load_images
from src.collage import Collager
from src.collage_save import CollageSaver

In [None]:
# Side length in pixels of the working palette images (see the resize in the palette cell).
# Also passed to Collager, and used to pick the BigGAN-deep checkpoint when USE_BIGGAN is on.
img_size = 512

# Face Loss

In [None]:
import os, sys
import dlib_torch_converter
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path

MODEL_PATH = Path(dlib_torch_converter.__file__).parent / 'face.xml'
# Fail early with an actionable message (an `assert` would be stripped under `python -O`).
if not MODEL_PATH.exists():
    raise FileNotFoundError(f"dlib face model not found at {MODEL_PATH}")

class DlibFaceLoss:
    """Loss computed from activations of the dlib model at MODEL_PATH.

    The model is loaded with `dlib_torch_converter`, its final two modules are
    dropped, and the remaining network is run on inputs bilinearly resized to
    INPUT_SIZE x INPUT_SIZE.

    Args:
        filter_index: output channel whose centre-pixel activation is maximised,
            or ``'all'`` to return every channel. Unused when a target image is given.
        target_image_path: optional image path. When given, the loss is the L2
            distance between the input's activations and this image's activations.
    """

    # Spatial size every image is resized to before entering the network.
    INPUT_SIZE = 128

    def __init__(self, filter_index=1, target_image_path=None):
        self.filter_index = filter_index
        self.dlib_model = dlib_torch_converter.get_model(str(MODEL_PATH)).eval().cuda()
        # Keep every module except the last two, so we read intermediate activations.
        module_names = list(self.dlib_model._modules.keys())[:-2]
        self.model = nn.Sequential(*[self.dlib_model._modules[name] for name in module_names])
        self.model.eval()
        # The network is a fixed critic: gradients only need to reach the input image,
        # so skip computing weight gradients on every optimisation step.
        self.model.requires_grad_(False)
        self.model.zero_grad()
        self.target_activations = None
        if target_image_path:
            # Force RGB so grayscale/RGBA files still give a 3-channel tensor,
            # and close the file handle once the image is decoded.
            with Image.open(target_image_path) as target_img:
                target_ten = transforms.ToTensor()(target_img.convert('RGB')).unsqueeze(0)
            with torch.no_grad():
                self.target_activations = self.model(self._resize(target_ten).cuda())

    def _resize(self, img_tensors):
        # Bilinear resize to the working resolution. align_corners=False is PyTorch's
        # default; it is spelled out only to silence the upsampling warning.
        return F.interpolate(img_tensors, size=(self.INPUT_SIZE, self.INPUT_SIZE),
                             mode='bilinear', align_corners=False)

    def __call__(self, img_tensors):
        """Return the loss for a batch of images in the [0, 1] range.

        Returns:
            With a target image: the scalar L2 distance between activations.
            Otherwise: the negated activation at the centre of the output map,
            shape (batch,) for an integer ``filter_index`` or (batch, channels)
            for ``'all'``.
        """
        self.model.zero_grad()
        out = self.model(self._resize(img_tensors))
        if self.target_activations is not None:
            return torch.dist(out, self.target_activations)
        # Centre index taken from the height axis and used for both spatial axes.
        mid = out.size(2) // 2
        channels = slice(None) if self.filter_index == 'all' else self.filter_index
        return -out[:, channels, mid, mid]

# Default configuration has no target image, so the loss is the negated channel-1 activation
# at the centre of the truncated network's output. Minimising it raises that activation.
face_loss = DlibFaceLoss()

# Mask generator

In [None]:
from src.gan import Generator
import subprocess

# Generator mapping 100-d latents to single-channel 128px masks, per its constructor args
# (the checkpoint name suggests a DCGAN).
MASK_MODEL_URL = "https://drive.google.com/u/0/uc?id=1IhoB6lxbKxL66F0X99ntL-t3-XKnxDPZ&export=download"
model_path = 'deriv_works_dcgan_gen_128'

mask_generator = Generator(img_size=128, latent_size=100, channels=1).cuda()
# Download only once, so "Restart & Run All" does not re-fetch the weights every time.
# The explicit -O makes the saved file name match model_path.
if not os.path.exists(model_path):
    subprocess.run(['gdown', MASK_MODEL_URL, '-O', model_path], check=True)

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
mask_generator.load_state_dict(torch.load(model_path))
mask_generator.eval();

In [None]:
# Sample ten random latents and show the generated masks side by side.
latents = torch.randn(10, 100).cuda()
masks = mask_generator(latents)
pltshow(np.hstack(masks[:, 0].detach().cpu().numpy()))

# Make or load the palette

In [None]:
import subprocess
def download_urls(urls, dest_dir="dataset"):
    """Download each URL into ``dest_dir`` with wget and return the local paths.

    Files that already exist locally are skipped, so re-running the cell does not
    make wget save ``name.1``-style duplicates.

    Args:
        urls: iterable of URLs; each file is saved under the last ``/``-separated
            component of its URL.
        dest_dir: directory passed to wget's ``-P`` option.

    Returns:
        List of ``f"{dest_dir}/{basename}"`` paths, in the same order as ``urls``.

    Raises:
        subprocess.CalledProcessError: if wget exits with a non-zero status.
    """
    # Materialise once: the URLs are iterated twice below.
    urls = list(urls)
    paths = [f"{dest_dir}/{url.split('/')[-1]}" for url in urls]
    missing = [url for url, path in zip(urls, paths) if not os.path.exists(path)]
    if missing:
        # "-i -" reads the URL list from stdin, so one wget call fetches all of them.
        subprocess.run(['wget', '-i', '-', '-P', dest_dir],
                       input="\n".join(missing), universal_newlines=True, check=True)
    return paths

# Palette source images; download_urls saves them under ./dataset.
urls = [
    "https://artbreeder.b-cdn.net/imgs/afc622a41966e3482a17.jpeg",
    "https://artbreeder.b-cdn.net/imgs/e8f11a059e51ce49f1fb.jpeg",
    "https://artbreeder.b-cdn.net/imgs/f9c1c5f14783165a5536.jpeg",
    "https://artbreeder.b-cdn.net/imgs/fb6d9b30088fb6a2aedfdbea.jpeg",
]
dataset = download_urls(urls)
dataset  # local file paths, shown as the cell output

In [None]:
USE_BIGGAN = False
if USE_BIGGAN:
    n_refs = 24*2
    biggan = BigGAN.from_pretrained(f'biggan-deep-{img_size}').cuda()
    # Later cells read palette_imgs, palette_imgs_large and img_names, so define all three here too.
    palette_imgs = random_biggan(n_refs, img_size, biggan, seed=1, truncation=.4)
    # BigGAN-deep has no 1024px checkpoint, so upsample to give the 1024px export a large palette.
    # NOTE(review): assumes random_biggan returns an (N, C, H, W) float tensor like load_images — confirm.
    palette_imgs_large = F.interpolate(palette_imgs, size=(1024, 1024), mode='bilinear')
    img_names = []  # no source files to record for generated images
else:
    img_names = dataset
    palette_imgs_large = load_images(img_names, 1024)
    palette_imgs = F.interpolate(palette_imgs_large, size=(img_size, img_size), mode='bilinear')

In [None]:
# Spread roughly 20 patches across the palette, but always use at least one patch per image:
# plain integer division gives 0 once the palette has more than 20 images.
patch_per_img = max(1, 20 // palette_imgs.shape[0])

In [None]:
# Preview the working-resolution palette images.
draw_tensors(palette_imgs)

In [None]:
# Collage renderer at img_size over the palette. mask_generator presumably supplies the patch
# masks (inferred from the argument order — TODO confirm against src/collage.py).
collager = Collager(palette_imgs, mask_generator, img_size, patch_per_img)

# View random collages

In [None]:
# Render two random, unoptimized collages as a quick sanity check.
with torch.no_grad():
    imgs = []
    for preview_seed in range(2):
        random_params = collager.makeRandom(seed=preview_seed, trans_scale=.2)
        imgs.append(collager(*random_params)[0])
    draw_tensors(torch.stack(imgs).squeeze())

# Optimization

In [None]:
n_steps = 600
seed = 8
lr = 2e-2
frames = []
if seed is not None:
    # Seed the Python, NumPy and torch RNGs so the run is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# The random collage parameters are the optimisation variables. Z (the first of them)
# is also pulled toward zero by an L2-norm penalty.
collage_data = collager.makeRandom(seed=seed, trans_scale=.2)
params = collage_data
Z = collage_data[0]
for x in collage_data:
    x.requires_grad_(True)
opt = torch.optim.Adam(params, lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=lr, total_steps=n_steps)
loss_history = []

# Context manager so the progress bar is closed when the loop ends or fails.
with tqdm(total=n_steps) as pbar:
    for i in range(n_steps):
        pbar.update()
        opt.zero_grad()
        norm_loss = .25 * Z.norm()
        img = collager(*collage_data, return_data=False)[0]
        aug = augment(img, n=3)
        # face_loss expects [0, 1] inputs; the collage is presumably in [-1, 1].
        fl = face_loss((aug + 1) * .5).mean()
        # Face term + latent-norm penalty, with a small reward for a brighter image.
        loss = fl + norm_loss - .01 * img.mean()
        loss_history.append(loss.item())
        # NOTE(review): retain_graph=True is kept from the original. It is only needed if part of
        # the graph (e.g. something cached inside Collager) is reused across steps — confirm.
        loss.backward(retain_graph=True)
        opt.step()
        scheduler.step()
        pbar.set_description(f"fl: {fl.item():.3f}")
        frames.append(np.array(convert_to_images(img.detach().cpu())[0]))
        if (i > 0 and i % 50 == 0) or i == n_steps - 1:
            draw_tensors(img)

In [None]:
# Snapshot the optimized result, detached from the autograd graph.
opt_img = img.detach()
opt_collage_data = tuple(t.detach() for t in collage_data)
opt_history = loss_history

# Explicit axes with labels. Assigning to `_` would shadow IPython's last-output variable,
# so the trailing `;` is used to suppress the repr instead.
fig, ax = plt.subplots()
ax.plot(opt_history)
ax.set(title="Optimization loss", xlabel="step", ylabel="loss");

# Export video, high-res image and masks

In [None]:
saver = CollageSaver()
saver.save_palette(palette_imgs)
print(saver.path)
saver.save_video(frames)

# Re-render the optimized collage parameters at 1024px with the full-resolution palette
# (2x the default 512px img_size).
export_collager = Collager(palette_imgs_large, mask_generator, 1024, patch_per_img)
with torch.no_grad():
    export_out = export_collager(*collage_data, return_data=True)
    hires, data = export_out
    saver.save(hires, data, final=True)

# Record which source images built the palette (written as JSON despite the .txt name).
if img_names:
    names_path = saver.path / 'image_names.txt'
    with open(names_path, 'w') as outfile:
        json.dump(img_names, outfile)

In [None]:
# View video: embed the saved mp4 inline as a base64 data URL.
from IPython.display import HTML
from base64 import b64encode

video_path = saver.path / (saver.key + '.mp4')
# Read inside a context manager so the file handle is closed after embedding.
with open(video_path, 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""<video width=400 controls><source src="%s" type="video/mp4"></video>""" % data_url)