In [None]:
#@title Resources
import multiprocessing
import torch
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device.type)

!nvidia-smi

In [None]:
#@title 1. Cloning taming-transformers
!git clone https://github.com/CompVis/taming-transformers.git

Cloning into 'taming-transformers'...
remote: Enumerating objects: 1020, done.[K
remote: Counting objects: 100% (224/224), done.[K
remote: Compressing objects: 100% (221/221), done.[K
remote: Total 1020 (delta 3), reused 213 (delta 3), pack-reused 796[K
Receiving objects: 100% (1020/1020), 350.30 MiB | 36.57 MiB/s, done.
Resolving deltas: 100% (205/205), done.


In [None]:
#@title 2. Installing dependencies
# Install the third-party packages the later cells import.
# omegaconf -> OmegaConf (config loading); einops -> rearrange (VQVAE token reshaping);
# pytorch_lightning and DALL-E are presumably needed by the taming-transformers models
# and to unpickle encoder.pkl / decoder.pkl respectively -- TODO confirm.
# NOTE(review): no versions are pinned, so a fresh runtime may resolve newer
# releases than the notebook was written against; pin once known-good versions are recorded.
# `> /dev/null` hides only pip's stdout; messages on stderr (like the resolver
# warning in the saved output) are still shown.
!pip install omegaconf > /dev/null
!pip install pytorch_lightning > /dev/null
!pip install einops > /dev/null
!pip install DALL-E > /dev/null

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.[0m


In [None]:
#@title 3. Importing libraries
import io
import os
import sys
import yaml
import gdown
from math import sqrt
sys.path.append("./taming-transformers")

import requests
import numpy as np
import PIL
from PIL import Image
from PIL import ImageDraw, ImageFont
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
from einops import rearrange

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import torch
torch.set_grad_enabled(False);

from taming.models.vqgan import VQModel, GumbelVQ

In [None]:
#@title 4. Google Drive file IDs for model weights and configs
models_folder = './models'
configs_folder = './configs'

os.makedirs(models_folder, exist_ok=True)
os.makedirs(configs_folder, exist_ok=True)

# Each entry pairs a Google Drive file id with the local file name it is saved
# under (inside models_folder / configs_folder respectively).
models_storage = [
    {'id': file_id, 'name': file_name}
    for file_id, file_name in (
        ('1COec-dpskvwHbIl9QA8qy_9nuOzsRLkl', 'encoder.pkl'),
        ('1pCIvZnVrzA968dqSAi2OEj299Y9YLcDQ', 'decoder.pkl'),
        ('1yB5nPXiJqYnoBEOannq_M5JJ2lpzhp3T', 'vqgan.16384.model.ckpt'),
        ('1UHuUUWX5F4y17oaW8sWuDzrsXyExU-rK', 'vqgan.gumbelf8.model.ckpt'),
        ('1WP6Li2Po8xYcQPGMpmaxIlI1yPB5lF5m', 'sber.gumbelf8.ckpt'),
    )
]

configs_storage = [
    {'id': file_id, 'name': file_name}
    for file_id, file_name in (
        ('1mXu9ThC3ET_uFGPwCYKCbOXqwma7wHo-', 'vqgan.16384.config.yml'),
        ('1M7RvSoiuKBwpF-98sScKng0lsZnwFebR', 'vqgan.gumbelf8.config.yml'),
    )
]

In [None]:
#@title 5. Downloading models
url_template = 'https://drive.google.com/uc?id={}'


def download_from_drive(items, folder):
    """Download every Google Drive file listed in ``items`` into ``folder``.

    Args:
        items: iterable of dicts with an ``'id'`` (Drive file id) and a
            ``'name'`` (local file name) key.
        folder: destination directory (created by the previous cell).

    Files that already exist with non-zero size are skipped, so re-running the
    notebook does not fetch the large checkpoints again.

    Raises:
        RuntimeError: if ``gdown.download`` returns ``None``, which gdown uses to
            signal a failed download (with ``quiet=True`` nothing is printed).
    """
    for item in items:
        out_name = os.path.join(folder, item['name'])
        if os.path.isfile(out_name) and os.path.getsize(out_name) > 0:
            continue
        url = url_template.format(item['id'])
        if gdown.download(url, out_name, quiet=True) is None:
            raise RuntimeError(f"Failed to download {item['name']} from {url}")


download_from_drive(models_storage, models_folder)
download_from_drive(configs_storage, configs_folder)

Downloading...
From: https://drive.google.com/uc?id=1COec-dpskvwHbIl9QA8qy_9nuOzsRLkl
To: /content/models/encoder.pkl
215MB [00:03, 64.1MB/s]
Downloading...
From: https://drive.google.com/uc?id=1pCIvZnVrzA968dqSAi2OEj299Y9YLcDQ
To: /content/models/decoder.pkl
175MB [00:02, 59.0MB/s]
Downloading...
From: https://drive.google.com/uc?id=1yB5nPXiJqYnoBEOannq_M5JJ2lpzhp3T
To: /content/models/vqgan.16384.model.ckpt
980MB [00:10, 96.0MB/s]
Downloading...
From: https://drive.google.com/uc?id=1UHuUUWX5F4y17oaW8sWuDzrsXyExU-rK
To: /content/models/vqgan.gumbelf8.model.ckpt
377MB [00:03, 118MB/s]
Downloading...
From: https://drive.google.com/uc?id=1WP6Li2Po8xYcQPGMpmaxIlI1yPB5lF5m
To: /content/models/sber.gumbelf8.ckpt
920MB [00:09, 93.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=1mXu9ThC3ET_uFGPwCYKCbOXqwma7wHo-
To: /content/configs/vqgan.16384.config.yml
100%|██████████| 692/692 [00:00<00:00, 330kB/s]
Downloading...
From: https://drive.google.com/uc?id=1M7RvSoiuKBwpF-98sScKng0lsZnw

In [None]:
#@title 6. VAE model
def map_pixels(x, eps=0.1):
    """Affinely map values in [0, 1] onto [eps, 1 - eps] (inverse of ``unmap_pixels``)."""
    scale = 1 - 2 * eps
    return scale * x + eps


def unmap_pixels(x, eps=0.1):
    """Undo ``map_pixels``: stretch [eps, 1 - eps] back to [0, 1], clamping any overshoot."""
    stretched = (x - eps) / (1 - 2 * eps)
    return torch.clamp(stretched, min=0, max=1)


class VQVAE(torch.nn.Module):
    """Discrete VAE wrapper around pickled encoder/decoder modules.

    ``get_codebook_indices`` applies ``map_pixels`` itself, so callers pass the
    image un-mapped (presumably in [0, 1]). ``decode`` returns images in [0, 1]
    via ``unmap_pixels``.
    """

    def __init__(self, enc_path, dec_path):
        super().__init__()

        # SECURITY NOTE: torch.load unpickles whole module objects, which can
        # execute arbitrary code; only load files from a trusted source.
        self.enc = torch.load(enc_path, map_location=torch.device('cpu'))
        self.dec = torch.load(dec_path, map_location=torch.device('cpu'))

        self.num_layers = 3
        self.image_size = 256
        self.num_tokens = 8192  # one-hot width used by decode()

    @torch.no_grad()
    def get_codebook_indices(self, img):
        """Encode an image batch into a flat sequence of codebook indices.

        ``map_pixels`` is applied here, so ``img`` must not already be mapped.

        Returns:
            Tensor of shape (batch, h * w): argmax over dim 1 of the encoder logits.
        """
        img = map_pixels(img)
        z_logits = self.enc.blocks(img)
        z = torch.argmax(z_logits, dim=1)
        return rearrange(z, 'b h w -> b (h w)')

    @torch.no_grad()
    def decode(self, img_seq):
        """Decode a (batch, n) tensor of token ids back into images in [0, 1].

        The ``n`` tokens are laid out on a sqrt(n) x sqrt(n) grid.

        Raises:
            ValueError: if ``n`` is not a perfect square.
        """
        _, n = img_seq.shape
        side = int(sqrt(n))
        if side * side != n:
            raise ValueError(f"token sequence length {n} is not a perfect square")
        img_seq = rearrange(img_seq, 'b (h w) -> b h w', h=side)

        z = torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens)
        z = rearrange(z, 'b h w c -> b c h w').float()
        x_stats = self.dec(z).float()
        # Only the first 3 decoder channels are used, squashed with a sigmoid.
        x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
        return x_rec

    def forward(self, img):
        raise NotImplementedError

In [None]:
#@title 7. Additional functions
def download_image(url, timeout=30):
    """Fetch ``url`` over HTTP and open the response body as a PIL image.

    Args:
        url: address of the image.
        timeout: seconds to wait for the server, passed to ``requests.get``.
            Without a timeout a stalled server would block the notebook indefinitely.

    Raises:
        requests.HTTPError: for non-2xx responses (via ``raise_for_status``).
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def load_config(config_path, display=False):
    """Load an OmegaConf config from ``config_path``; print it as YAML when ``display`` is set."""
    config = OmegaConf.load(config_path)
    if not display:
        return config
    as_plain_container = OmegaConf.to_container(config)
    print(yaml.dump(as_plain_container))
    return config

def preprocess(img, target_image_size=256, map_dalle=True):
    """Resize, center-crop and tensorize a PIL image.

    The shorter side is scaled to ``target_image_size`` (up- or down-scaling),
    then the central ``target_image_size`` x ``target_image_size`` square is kept.

    Args:
        img: PIL image of any mode. Non-RGB images (grayscale, palette, RGBA, ...)
            are converted to RGB so the tensor always has 3 channels, as the
            models expect.
        target_image_size: output height and width in pixels.
        map_dalle: if True, apply ``map_pixels`` to the tensor.

    Returns:
        Float tensor of shape (1, 3, target_image_size, target_image_size).
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    short_side = min(img.size)
    scale = target_image_size / short_side
    # TF.resize takes (height, width); PIL's img.size is (width, height).
    new_size = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, new_size, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    if map_dalle:
        img = map_pixels(img)
    return img

def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a VQGAN from ``config.model.params`` and optionally load its weights.

    Args:
        config: OmegaConf config containing a ``model.params`` section.
        ckpt_path: checkpoint whose ``"state_dict"`` entry holds the weights;
            ``None`` keeps the freshly initialised model.
        is_gumbel: build a ``GumbelVQ`` instead of a plain ``VQModel``.

    Returns:
        The model, switched to eval mode.
    """
    model_cls = GumbelVQ if is_gumbel else VQModel
    model = model_cls(**config.model.params)
    if ckpt_path is not None:
        # NOTE: torch.load unpickles the checkpoint; only use trusted files.
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        # strict=False: keys missing from, or extra to, the checkpoint are ignored.
        model.load_state_dict(state_dict, strict=False)
    return model.eval()

def preprocess_vqgan(x):
    """Rescale values from [0, 1] to [-1, 1], the input range fed to the VQGAN models."""
    return x * 2. - 1.

# map_pixels is already defined, with an identical body, in the "6. VAE model"
# cell. It is intentionally not redefined here, so a later edit to one copy
# cannot be silently shadowed by the other.

def vae_postprocess(x):
    """Convert a (C, H, W) tensor with values in [0, 1] into an RGB PIL image.

    Values are clamped to [0, 1], scaled by 255 and truncated to uint8.
    """
    pixels = torch.clamp(x.detach().cpu(), 0., 1.)
    array = (255 * pixels.permute(1, 2, 0).numpy()).astype(np.uint8)
    image = Image.fromarray(array)
    return image if image.mode == "RGB" else image.convert("RGB")
    
def vqgan_postprocess(x):
    """Convert a (C, H, W) tensor with values in [-1, 1] into an RGB PIL image.

    Values are clamped to [-1, 1] and rescaled to [0, 1]. The result is then
    passed to ``vae_postprocess``, so both model families share a single
    tensor-to-image conversion. ``vae_postprocess``'s own [0, 1] clamp is a
    no-op on the rescaled values.
    """
    x = torch.clamp(x.detach().cpu(), -1., 1.)
    return vae_postprocess((x + 1.) / 2.)

def reconstruct_with_vqgan(x, model):
    """Encode ``x`` with a VQGAN and decode the quantized latent back into images.

    Runs without gradient tracking. ``model.encode`` must return a 3-tuple whose
    last item is itself a 3-tuple; only the first item (the latent) is used.
    """
    with torch.no_grad():
        quant, _emb_loss, (_, _, _indices) = model.encode(x)
        return model.decode(quant)

def reconstruct_with_vae(x, model):
    """Round-trip ``x`` through the discrete VAE (image -> token ids -> image), without gradients."""
    with torch.no_grad():
        tokens = model.get_codebook_indices(x)
        return model.decode(tokens)

In [None]:
#@title 8. Config with models info

# Device every model is moved to: first GPU when available, otherwise CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# The 'VAE' entry is built from pickled encoder/decoder files; every other entry
# is a VQGAN described by (name, config file, checkpoint, uses Gumbel quantizer).
# Note that the SBER checkpoint reuses the gumbelf8 config.
models_info = [
    {
        'model_name': 'VAE',
        'enc_path': './models/encoder.pkl',
        'dec_path': './models/decoder.pkl',
    },
    *(
        {
            'model_name': name,
            'config_path': config_path,
            'ckpt_path': ckpt_path,
            'is_gumbel': is_gumbel,
        }
        for name, config_path, ckpt_path, is_gumbel in (
            ('16384', './configs/vqgan.16384.config.yml', './models/vqgan.16384.model.ckpt', False),
            ('gumbelf8', './configs/vqgan.gumbelf8.config.yml', './models/vqgan.gumbelf8.model.ckpt', True),
            ('SBER-gumbelf8', './configs/vqgan.gumbelf8.config.yml', './models/sber.gumbelf8.ckpt', True),
        )
    ),
]

In [None]:
#@title 9. Load models into memory

def _load_model(info):
    """Instantiate one ``models_info`` entry in eval mode on DEVICE.

    The 'VAE' entry is built from its encoder/decoder files; every other entry
    is a VQGAN built from its OmegaConf config and checkpoint.
    """
    if info['model_name'] == 'VAE':
        return VQVAE(info['enc_path'], info['dec_path']).eval().to(DEVICE)
    config = load_config(info['config_path'], display=False)
    return load_vqgan(config,
                      ckpt_path=info['ckpt_path'],
                      is_gumbel=info['is_gumbel']).to(DEVICE)


# Loading inside a helper keeps the temporary model/config objects local, so no
# stray global references to them survive this cell (no manual `del` needed).
models = [
    {'model_name': info['model_name'], 'model': _load_model(info)}
    for info in models_info
]

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

Downloading vgg_lpips model from https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1 to taming/modules/autoencoder/lpips/vgg.pth


8.19kB [00:00, 369kB/s]                    


loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.


In [None]:
#@title 10. Functions for rendering images

def stack_reconstructions(
    images,
    font_path="/usr/share/fonts/truetype/liberation/LiberationSans-BoldItalic.ttf",
    font_size=22,
):
    """Paste images side by side into one RGB strip, with a red title on each tile.

    Args:
        images: non-empty list of dicts with an ``'image'`` (PIL image) and a
            ``'title'`` (str) key. Every tile is assumed to be the size of
            ``images[0]['image']``.
        font_path: TrueType font used for the titles. If it cannot be opened
            (e.g. the font is not installed on this machine), Pillow's built-in
            default font is used instead of failing.
        font_size: title font size.

    Returns:
        RGB PIL image of size (len(images) * width, height).
    """
    tile_w, tile_h = images[0]['image'].size
    try:
        font = ImageFont.truetype(font_path, font_size)
    except OSError:
        font = ImageFont.load_default()
    strip = Image.new("RGB", (len(images) * tile_w, tile_h))

    for i, entry in enumerate(images):
        x_offset = i * tile_w
        strip.paste(entry['image'], (x_offset, 0))
        # Drawn right after each paste, so a title overflowing its tile is
        # covered by the next pasted tile.
        ImageDraw.Draw(strip).text((x_offset, 0), entry['title'], (255, 0, 0), font=font)
    return strip

def reconstruction_pipeline_by_url(models, url, size=256):
    """Download an image, reconstruct it with every model, and plot the results in a row.

    The image is preprocessed once into an un-mapped tensor, which is what both
    reconstruction paths expect:
    * the VAE path (``VQVAE.get_codebook_indices``) applies ``map_pixels``
      itself, so mapping here as well would apply it twice;
    * VQGAN models receive ``preprocess_vqgan`` of the same tensor.
    The left-most tile is that un-mapped crop, i.e. the true original.

    Args:
        models: list of ``{'model_name': str, 'model': module}`` dicts. The entry
            named ``'VAE'`` is treated as the discrete VAE and all others as VQGANs.
        url: image URL, fetched with ``download_image``.
        size: side length of the square crop fed to the models.
    """
    original_image = download_image(url)
    # Batch of one, not passed through map_pixels (see docstring).
    img = preprocess(original_image, target_image_size=size, map_dalle=False)
    x = img.to(DEVICE)

    images = [{
        'image': vae_postprocess(img[0]),
        'title': 'original',
    }]
    for model in models:
        if model['model_name'] == 'VAE':
            pr_imgs = reconstruct_with_vae(x, model['model'])
            pr_img = vae_postprocess(pr_imgs[0])
        else:
            pr_imgs = reconstruct_with_vqgan(preprocess_vqgan(x), model['model'])
            pr_img = vqgan_postprocess(pr_imgs[0])
        images.append({
            'image': pr_img,
            'title': model['model_name'],
        })

    strip = stack_reconstructions(images)

    fig, ax = plt.subplots(figsize=(20, 10))
    ax.imshow(strip)
    ax.axis("off")
    plt.show()
    # Release the figure; this function is called in a loop over many URLs.
    plt.close(fig)

In [None]:
img_urls = [
    'https://static.remove.bg/remove-bg-web/3661dd45c31a4ff23941855a7e4cedbbf6973643/assets/start-0e837dcc57769db2306d8d659f53555feb500b3c5d456879b9c843d1872e7baa.jpg',
    'https://heibox.uni-heidelberg.de/f/be6f4ff34e1544109563/?dl=1',
    'https://static.wikia.nocookie.net/walkingdead/images/d/d9/Арнольд_Шварценеггер.jpg/revision/latest/scale-to-width-down/700?cb=20170729224452&path-prefix=ru',
    'https://portret.msk.ru/images/Portret_na_kone/00016%D0%9F%D0%BE%D1%80%D1%82%D1%80%D0%B5%D1%82%20%D0%BD%D0%B0%20%D0%BA%D0%BE%D0%BD%D0%B5.jpg',
    'https://i.ytimg.com/vi/yTrB1SnUYPc/maxresdefault.jpg',
    'https://img.gazeta.ru/files3/291/12697291/upload-RTX75B8J-pic905-895x505-4746.jpg',
    'https://vyveski66.ru/userfiles/shop/slider/467_skupka-1.jpg',
    'https://tengrinews.kz/userdata/images/u269/2baed16bc6b14b72fe92605122a4c29c.jpg',
]

# These are third-party URLs that may go stale. A failed fetch, or a response
# that is not an image, is reported and skipped instead of aborting the remaining
# reconstructions. Any other error (e.g. inside a model) still propagates.
for img_url in img_urls:
    try:
        reconstruction_pipeline_by_url(models, url=img_url, size=256)
    except (requests.RequestException, PIL.UnidentifiedImageError) as err:
        print(f"Skipping {img_url}: {err}")

Output hidden; open in https://colab.research.google.com to view.