In [None]:
# Fetch the StyleGAN2 PyTorch code, then download the car generator checkpoint
# and the saved car latents into the repo directory.
!git clone https://github.com/nagolinc/stylegan2-pytorch.git /content/stylegan2-pytorch/
%cd /content/stylegan2-pytorch/
!git pull
!wget https://github.com/EvgenyKashin/random-colabs/releases/download/v0.1/stylegan2-car-config-f.pt -O /content/stylegan2-pytorch/stylegan2-car-config-f.pt
# NOTE(review): the release asset is named `.pt.zip` but is saved as `.pt` and later
# passed straight to torch.load — presumably it is a regular torch checkpoint; confirm.
!wget https://github.com/EvgenyKashin/random-colabs/releases/download/v0.2/cars_dlatents.pt.zip -O /content/stylegan2-pytorch/cars_dlatents.pt

# ftfy and regex are imported by the tokenizer cell; ninja is presumably needed to
# build the repo's custom ops (TODO confirm). The two downloads are the CLIP BPE
# vocabulary and the CLIP ViT-B/32 weights, saved as model_clip.pt.
!pip install ninja ftfy regex
!wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz
!wget https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt -O model_clip.pt

In [None]:
import subprocess

# Ask nvcc for its version banner and pull the number out of the "release X.Y" field.
nvcc_banner = subprocess.check_output(["nvcc", "--version"]).decode("UTF-8")
release_fields = [field for field in nvcc_banner.split(", ") if field.startswith("release")]
CUDA_version = release_fields[0].split(" ")[-1]
print("CUDA version:", CUDA_version)

# Wheel tag appended to the torch/torchvision versions in the pip command below.
# CUDA 10.2 uses the untagged wheel; any release not listed falls back to "+cu110".
torch_version_suffix = {
    "10.0": "+cu100",
    "10.1": "+cu101",
    "10.2": "",
}.get(CUDA_version, "+cu110")
  
# Install pinned torch / torchvision builds; {torch_version_suffix} is interpolated
# from the Python variable chosen above for the detected CUDA release.
!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html

import numpy as np
import torch

print("Torch version:", torch.__version__)

In [None]:
# Work from the repo root so that `from model import Generator, Discriminator`
# in the next cell resolves against the cloned code.
%cd /content/stylegan2-pytorch/

In [None]:
import argparse
from argparse import Namespace
import math
import os
from collections import defaultdict
import matplotlib.pyplot as plt

import torch
from torch import optim
from torch.nn import functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from tqdm import tqdm
import lpips

from model import Generator, Discriminator


def noise_regularize(noises):
    """Penalize spatial self-correlation of the noise maps at several scales.

    For each map, adds the squared mean of the product between the map and a
    copy of itself rolled by one position along dim 3 and along dim 2. The map
    is then averaged over 2x2 blocks and the penalty repeated, until the
    side length (taken from ``shape[2]``) is 8 or less.

    Args:
        noises: iterable of tensors; the reshape used for pooling implies a
            (N, 1, S, S) layout with square maps.

    Returns:
        The accumulated penalty: a 0-dim tensor, or the int ``0`` when
        ``noises`` is empty.
    """
    total = 0

    for current in noises:
        side = current.shape[2]

        while True:
            along_width = (current * torch.roll(current, shifts=1, dims=3)).mean().pow(2)
            along_height = (current * torch.roll(current, shifts=1, dims=2)).mean().pow(2)
            total = total + along_width + along_height

            if side <= 8:
                break

            # 2x2 average pooling expressed as reshape + mean over the block axes.
            half = side // 2
            current = current.reshape([-1, 1, half, 2, half, 2]).mean([3, 5])
            side = half

    return total


def noise_normalize_(noises):
    """Standardize every noise tensor in place to zero mean and unit std.

    The statistics are taken over the whole tensor before it is modified, and
    the update is written through ``.data``. Returns ``None``.
    """
    for tensor in noises:
        center, spread = tensor.mean(), tensor.std()
        tensor.data.add_(-center).div_(spread)


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Learning rate at normalized progress ``t`` with linear warm-up and cosine decay.

    Args:
        t: progress through the run; the caller passes ``i / step``.
        initial_lr: peak learning rate.
        rampdown: fraction of the run, counted from the end, over which the
            rate follows a half-cosine down to zero.
        rampup: fraction of the run, counted from the start, over which the
            rate grows linearly from zero.

    Returns:
        ``initial_lr`` scaled by the product of the two ramps.
    """
    decay_phase = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(decay_phase * math.pi)
    warmup_factor = min(1, t / rampup)

    return initial_lr * (cosine_factor * warmup_factor)


def latent_noise(latent, strength):
    """Return ``latent`` plus standard-normal noise scaled by ``strength``.

    The input is not modified; the noise has the same shape, dtype and device
    as ``latent``.
    """
    return latent + torch.randn_like(latent) * strength


def make_image(tensor):
    """Convert a batch of images in [-1, 1] to a uint8 NumPy array.

    Values are clamped to [-1, 1], mapped to [0, 255], truncated to uint8 and
    moved to the CPU with the channel axis last.

    The clamp is done out of place: ``detach()`` shares storage with its
    source, so an in-place clamp here would silently overwrite the caller's
    tensor.

    Args:
        tensor: 4-D tensor; the ``permute(0, 2, 3, 1)`` implies an
            (N, C, H, W) layout.

    Returns:
        ``numpy.ndarray`` of dtype uint8 with shape (N, H, W, C).
    """
    return (
        tensor.detach()
        .clamp(min=-1, max=1)  # new tensor; the caller's data stays untouched
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
  
def prepare_texts(texts):
    """Tokenize descriptions into a zero-padded id tensor for the CLIP text encoder.

    Each description is prefixed with "This is ", BPE-encoded with the global
    ``tokenizer`` and wrapped in the start/end-of-text ids. Rows are padded
    with zeros up to ``model_clip.context_length``.

    Args:
        texts: iterable of description strings.

    Returns:
        Long tensor of shape (len(texts), context_length) on ``device``.
    """
    encoded = [tokenizer.encode("This is " + desc) for desc in texts]

    batch = torch.zeros(len(encoded), model_clip.context_length, dtype=torch.long)
    start_id = tokenizer.encoder['<|startoftext|>']
    end_id = tokenizer.encoder['<|endoftext|>']

    for row, ids in enumerate(encoded):
        wrapped = [start_id] + ids + [end_id]
        batch[row, :len(wrapped)] = torch.tensor(wrapped)

    return batch.to(device)

def image_from_latent(latent):
    """Render a latent with the generator and return the first image as a PIL image.

    The generator is called with the same ``truncation`` and
    ``truncation_latent`` settings that ``optimize_latent_to_text`` uses, so
    the picture shown is the one the optimization actually scored. Rendering
    without them would display a different image from the one whose loss was
    minimized.

    Reads the globals ``g_ema``, ``noises``, ``args`` and ``latent_mean``;
    ``noises`` is only bound after the optimization cell has been run.

    Args:
        latent: latent passed to ``g_ema`` with ``input_is_latent=True``.

    Returns:
        ``PIL.Image.Image`` built from the first sample of the batch.
    """
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        gen, _ = g_ema(
            [latent],
            input_is_latent=True,
            noise=noises,
            truncation=args.truncation,
            truncation_latent=latent_mean,
        )
    img = make_image(gen)[0]
    return Image.fromarray(img)

In [None]:
#@title

import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE code.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte value is assigned, in
    increasing order, the next code point starting at 256, so no byte maps to
    a whitespace or control character.

    Returns a dict with 256 entries keyed by byte value: the self-mapped
    bytes first, then the remapped ones. The result is cached.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in self_mapped}
    shifted = 0
    for byte in range(2 ** 8):
        if byte not in table:
            table[byte] = chr(2 ** 8 + shifted)
            shifted += 1
    return table


def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    ``word`` is a non-empty sequence of symbols (variable-length strings).
    A single-symbol word yields an empty set.
    """
    word[0]  # an empty word is an error (IndexError), not an empty set
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Fix the text with ftfy, unescape HTML entities twice, and strip the ends."""
    repaired = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(repaired))
    return unescaped.strip()


def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and strip the ends."""
    return re.sub(r'\s+', ' ', text).strip()


class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzip-compressed merges file.

    The vocabulary is built from the 256 byte symbols, the same symbols with
    an end-of-word marker ``</w>``, one entry per merge rule, and the two
    special tokens ``<|startoftext|>`` and ``<|endoftext|>``.

    The split pattern uses ``\\p{L}`` / ``\\p{N}`` classes, so ``re`` must be
    the third-party ``regex`` module, as imported at the top of this cell.
    """

    def __init__(self, bpe_path: str = "bpe_simple_vocab_16e6.txt.gz"):
        """Load the merge rules from ``bpe_path`` and build the lookup tables."""
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Read through a context manager so the gzip handle is closed promptly.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip the first line (presumably a header) and keep at most
        # 49152 - 256 - 2 = 48894 merge rules.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank means the merge is applied earlier.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # The special tokens are pre-seeded so they are never split.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the merge rules to one token and return its space-joined symbols.

        The last character carries the ``</w>`` end-of-word marker. Results
        are memoized in ``self.cache``, except for the single-symbol early
        return.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the tail unchanged.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text``, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map the UTF-8 bytes to their printable stand-ins before merging.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; each ``</w>`` marker becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

## Init CLIP and StyleGAN2 (SG2)

In [None]:
# Run configuration. Several of these values are overridden again in the
# "most important params" cell further down before the optimization is run.
args = Namespace(
    ckpt='/content/stylegan2-pytorch/stylegan2-car-config-f.pt',  # checkpoint holding the "g_ema" and "d" weights
    size=512,  # passed to Generator(...) and Discriminator(...)
    lr_rampup=0.05,
    lr_rampdown=0.25,
    noise=0.0,  # scales the random perturbation added to the latent at every step
    noise_ramp=0.75,  # fraction of the run after which that perturbation reaches zero
    noise_regularize=0,  # weight of noise_regularize(); non-zero also makes the noise maps trainable
    w_plus=False,  # True: repeat the latent g_ema.n_latent times along a new dim 1
    step=200,  # number of optimizer iterations
    lr=0.02,  # peak learning rate handed to get_lr()
    clip_weight=1,  # the *_weight fields scale the matching terms of the loss
    lpips_weight=0,
    l2_weight=0,
    latent_reg_weight=0,
    disc_weight=0,
    truncation=0.8  # passed to g_ema together with latent_mean as truncation_latent
)

device = "cuda"
# TorchScript export of CLIP, downloaded as model_clip.pt in the setup cell.
model_clip = torch.jit.load("model_clip.pt").to(device).eval()

n_mean_latent = 10000  # number of z samples used below to estimate the latent mean / std
resize = min(args.size, 256)
resize_clip = model_clip.input_resolution.item()  # side length of the images CLIP expects


# Reads bpe_simple_vocab_16e6.txt.gz from the current working directory.
tokenizer = SimpleTokenizer()

# Preprocessing for PIL images into the generator's value range
# (mean 0.5 / std 0.5 maps [0, 1] to [-1, 1]).
# NOTE(review): not referenced by any other cell visible here.
transform_gen = Compose(
    [
        Resize(resize),
        CenterCrop(resize),
        ToTensor(),
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


# Preprocessing for PIL images going into CLIP: bicubic resize, center crop,
# then per-channel normalization.
# NOTE(review): not referenced by any other cell visible here.
transform_clip = Compose([
    Resize(resize_clip, interpolation=Image.BICUBIC),
    CenterCrop(resize_clip),
    ToTensor(),
    Normalize([0.48145466, 0.4578275, 0.40821073],
              [0.26862954, 0.26130258, 0.27577711])
])

class UnNormalize(object):
    """Undo a per-channel ``Normalize(mean, std)`` and clamp the result to [0, 1].

    ``mean`` and ``std`` must be 1-D tensors with one entry per channel; they
    are broadcast over the two trailing spatial axes.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): normalized image(s) of size (..., C, H, W).
        Returns:
            Tensor: ``tensor * std + mean`` per channel, clamped to [0, 1].
        """
        # Normalize does (t - m) / s, so the inverse is t * s + m.
        scale = self.std[:, None, None]
        shift = self.mean[:, None, None]
        return (tensor * scale + shift).clamp(0.0, 1.0)


# CLIP preprocessing applied to generator output tensors inside the optimization
# loop: bicubic resize and crop to CLIP's input size, map [-1, 1] back to [0, 1]
# (x * 0.5 + 0.5), then apply CLIP's per-channel normalization.
# The un-normalization constants are created on `device`, matching the rest of
# the notebook, instead of being forced onto CUDA.
transform_clip_after_gen = Compose([
    Resize(resize_clip, interpolation=Image.BICUBIC),
    CenterCrop(resize_clip),
    UnNormalize(torch.tensor([0.5, 0.5, 0.5], device=device),
                torch.tensor([0.5, 0.5, 0.5], device=device)),
    Normalize([0.48145466, 0.4578275, 0.40821073],
              [0.26862954, 0.26130258, 0.27577711])
])


def clip_similarity_score(img, text_features):
    """Mean dot product between the text features and the L2-normalized CLIP image features.

    Args:
        img: image batch already preprocessed for the global ``model_clip``.
        text_features: text embeddings; the caller normalizes them beforehand.

    Returns:
        0-dim tensor: the mean over all text/image pairs.
    """
    embeddings = model_clip.encode_image(img).float()
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    return (text_features @ embeddings.T).mean()


# Load the generator ("g_ema" entry) and the discriminator ("d" entry) from the
# same checkpoint and put both in eval mode on `device`.
checkpoint_sg2 = torch.load(args.ckpt)
# NOTE(review): 512 and 8 are presumably the style dimension and the number of
# mapping layers — confirm against model.Generator.
g_ema = Generator(args.size, 512, 8)
g_ema.load_state_dict(checkpoint_sg2["g_ema"], strict=True)
g_ema = g_ema.to(device).eval()

disc = Discriminator(args.size)
disc.load_state_dict(checkpoint_sg2['d'], strict=True)
disc.to(device).eval()

# Estimate latent statistics by pushing n_mean_latent random z vectors through
# g_ema.style.
with torch.no_grad():
    noise_sample = torch.randn(n_mean_latent, 512, device=device)
    latent_out = g_ema.style(noise_sample)

    latent_mean = latent_out.mean(0)
    # A single scalar: sqrt of the summed squared deviations divided by the sample count.
    latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5

# LPIPS perceptual distance with a VGG backbone; runs on the GPU when `device` is CUDA.
percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

# Saved latents keyed by image path; each entry is later indexed with ['latent'].
# NOTE(review): presumably latents projected from real car photos — confirm.
dlatents = torch.load('/content/stylegan2-pytorch/cars_dlatents.pt')
keys = ['cars_crop/5_3.jpg', 'cars_crop/3_4.jpg', 'cars_crop/8_1.jpg', 'cars_crop/9_2.jpg']

## Optimization loop

In [None]:
def optimize_latent_to_text(latent, text):
    """Optimize a latent so the generated image matches text descriptions under CLIP.

    The loss is the weighted sum ``-clip_score * clip_weight + lpips * lpips_weight
    + mse * l2_weight + noise_reg * noise_regularize + latent_mse * latent_reg_weight
    - disc_score * disc_weight``, where the LPIPS, MSE and latent terms are
    measured against the starting point. Adam is used with the schedule from
    ``get_lr``, configured through ``args.lr``, ``args.lr_rampdown`` and
    ``args.lr_rampup``.

    Reads the globals ``args``, ``g_ema``, ``disc``, ``model_clip``,
    ``percept``, ``latent_mean``, ``latent_std`` and ``transform_clip_after_gen``.

    Args:
        latent: starting latent, passed to ``g_ema`` with ``input_is_latent=True``.
            The caller passes a tensor with a leading batch dimension of 1.
        text: list of description strings. The noise maps are created with
            batch size ``len(text)``.
            NOTE(review): the latent batch is not expanded to match — confirm
            intended usage for more than one description.

    Returns:
        Tuple ``(latent_path, losses, noises)``:
        ``latent_path`` holds a detached snapshot of the latent every 10 steps,
        plus the final latent when ``args.step`` is not a multiple of 10;
        ``losses`` maps a series name to a list with one float per step;
        ``noises`` is the list of noise tensors used for generation.
    """
    n_samples = len(text)
    noises = [
        noise.repeat(n_samples, 1, 1, 1).normal_() for noise in g_ema.make_noise()
    ]

    # The text features are a fixed target: no autograd graph is needed.
    with torch.no_grad():
        texts = prepare_texts(text)
        text_features = model_clip.encode_text(texts).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # detach() guarantees a leaf tensor that the optimizer can own.
    latent_in = latent.detach().clone()
    latent_init = latent_in.clone()

    # Reference image of the starting latent for the LPIPS and MSE terms.
    with torch.no_grad():
        img_gen_init, _ = g_ema([latent_init], input_is_latent=True, noise=noises,
                                truncation=args.truncation, truncation_latent=latent_mean)

    if args.w_plus:
        latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)

    latent_in.requires_grad = True

    params = [latent_in]
    if args.noise_regularize != 0:
        # The noise maps are only trained when their regularizer is active.
        for noise in noises:
            noise.requires_grad = True
        params += noises

    optimizer = optim.Adam(params, lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    losses = defaultdict(list)

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr, rampdown=args.lr_rampdown, rampup=args.lr_rampup)
        optimizer.param_groups[0]["lr"] = lr
        noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises,
                           truncation=args.truncation, truncation_latent=latent_mean)

        lpips_loss = percept(img_gen, img_gen_init).sum()
        mse_loss = F.mse_loss(img_gen, img_gen_init)
        n_loss = noise_regularize(noises)
        latent_loss = F.mse_loss(latent_init, latent_n)
        # Reduce to a scalar so the total loss stays 0-dim for any batch size.
        disc_score = disc(img_gen).mean()

        img_clip = transform_clip_after_gen(img_gen)
        clip_score = clip_similarity_score(img_clip, text_features)

        loss = -clip_score * args.clip_weight + lpips_loss * args.lpips_weight \
              + mse_loss * args.l2_weight + n_loss * args.noise_regularize \
              + latent_loss * args.latent_reg_weight - disc_score * args.disc_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        noise_normalize_(noises)

        grad_norm = latent_in.grad.norm().item()
        losses['loss'].append(loss.item())
        losses['loss_grad'].append(grad_norm)
        losses['clip_score'].append(clip_score.item())
        losses['lpips_loss'].append(lpips_loss.item())
        losses['latent_loss'].append(latent_loss.item())
        losses['disc_score'].append(disc_score.item())

        if (i + 1) % 10 == 0:
            latent_path.append(latent_in.detach().clone())

        pbar.set_description(
            (   f"Loss: {loss.item():.4f}; Loss gr: {grad_norm:.4f}; "
                f"latent diff: {latent_loss.item():.4f}; "
                f"lpips: {lpips_loss.item():.4f}; "
            ), refresh=True
        )

    # Keep the final result even when the step count is not a multiple of 10;
    # otherwise latent_path[-1] would be stale, or missing for runs under 10 steps.
    if args.step % 10 != 0:
        latent_path.append(latent_in.detach().clone())

    return latent_path, losses, noises

### The most important params for tweaking

In [None]:
# Knobs most worth tweaking. Each one is assigned exactly once here, so an edit
# always takes effect.
args.lpips_weight = 0  # 1e-1
args.latent_reg_weight = 0  # 1e-1
args.noise = 0.0
args.noise_regularize = 0
args.disc_weight = 0.0
args.clip_weight = 1
args.lr = 0.04
args.step = 150
args.truncation = 0.8

In [None]:
# Candidate text prompts. prepare_texts() prefixes each one with "This is "
# before tokenizing, so they are written as noun phrases.
descriptions = [
    'light weight rear wheel drive coupe highly tuned Japanese automobile reflecting the original styles of drifting with a big spoiler',
    'white car standing in front of the camera ',
    'modified big jeep car for racing',
    'car standing in front of the mountains',
    'car turned forward',
    'a car standing in front of the ocean on the sand',
    'nissan car',  # toyota Mercedes ferrari
    'black and white photo of the car'
]

In [None]:
# Three alternative starting points for the optimization.

# (a) The mean latent, with a leading batch dimension added.
latent_in_mean = latent_mean.detach().clone().unsqueeze(0)  #  .repeat(1, 1)
# 1 3 4 5  (NOTE(review): presumably other seeds that were tried — confirm)
torch.manual_seed(4)
latent_z = torch.randn(1, 512, device=device)
# (b) A random z drawn with a fixed seed and mapped through g_ema.style.
latent_in_seed = g_ema.style(latent_z).detach()  # .repeat(1, 1)

# (c) The latent stored under the first key of `dlatents`, with a batch dimension added.
latent_in_car = torch.unsqueeze(dlatents[keys[0]]['latent'], 0).detach().clone()

In [None]:
# Run the optimization for a single prompt. `noises` is kept at module level
# because image_from_latent() reads it as a global.
latent_path, losses, noises = optimize_latent_to_text(latent_in_car, [descriptions[6]])  # latent_in_car or latent_in_seed or latent_in_mean

## Results visualization

In [None]:
# Last saved snapshot of the optimized latent.
image_from_latent(latent_path[-1])

In [None]:
# First saved snapshot. Snapshots are taken every 10 steps, so this is the
# latent after 10 optimizer steps, not the starting latent.
image_from_latent(latent_path[0])

In [None]:
# One figure per logged series, titled with the series name; the x axis is the step index.
for k, v in losses.items():
    plt.figure()
    plt.title(k)
    plt.plot(v)

In [None]:
# Render every saved snapshot in order.
# NOTE(review): path_ar is never filled in the visible cells.
# `display` is not imported here; it relies on the notebook environment providing it.
path_ar = []
for i, latent in enumerate(latent_path):
    # `i % 1` is always 0, so every snapshot is shown; raise the modulus to thin them out.
    if i % 1 == 0:
        display(image_from_latent(latent))