In [None]:
import multiprocessing
import torch
from psutil import virtual_memory

# Total physical RAM of the host, converted from bytes to GiB (one decimal).
ram_gb = round(virtual_memory().total / 1024**3, 1)

# Report the hardware/software environment so recorded outputs can be interpreted later.
print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device.type)

# Shell magic: print the driver's view of the GPU.
!nvidia-smi

CPU: 2
RAM GB: 12.7
PyTorch version: 1.11.0+cu113
CUDA version: 11.3
cuDNN version: 8200
device: cuda
Mon Jul  4 19:49:23 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P8    13W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                  

In [None]:
# Fetch the two repositories the notebook depends on.
# --depth 1 skips the git history (only the working tree is used here), and the
# directory test keeps the cell re-runnable without "already exists" errors.
![ -d taming-transformers ] || git clone --depth 1 https://github.com/CompVis/taming-transformers.git
![ -d ARC ] || git clone --depth 1 https://github.com/fchollet/ARC.git

Cloning into 'taming-transformers'...
remote: Enumerating objects: 1335, done.[K
remote: Total 1335 (delta 0), reused 0 (delta 0), pack-reused 1335[K
Receiving objects: 100% (1335/1335), 409.77 MiB | 48.37 MiB/s, done.
Resolving deltas: 100% (277/277), done.
Cloning into 'ARC'...
remote: Enumerating objects: 1159, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 1159 (delta 0), reused 2 (delta 0), pack-reused 1156[K
Receiving objects: 100% (1159/1159), 473.41 KiB | 4.30 MiB/s, done.
Resolving deltas: 100% (670/670), done.


In [None]:
import os
# Sanity check: list the working directory (the cloned repositories should appear here).
os.listdir("./")

['.config', 'ARC', 'taming-transformers', 'sample_data']

In [None]:
# %pip installs into the environment of the running kernel (plain !pip may hit a
# different interpreter); -q keeps the output short without hiding errors.
# TODO: pin versions once a known-good set has been recorded.
%pip install -q omegaconf pytorch_lightning einops DALL-E

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.[0m


In [None]:
import io
import os
import sys
import yaml
import gdown
import glob
import json
from math import sqrt
# Make the cloned taming-transformers checkout importable; this must run before
# the `taming` import at the bottom of this cell.
sys.path.append("./taming-transformers")

import requests
import numpy as np
import PIL
from PIL import Image
from PIL import ImageDraw, ImageFont
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
from einops import rearrange

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import torch
# Globally disable autograd; the trailing semicolon suppresses the cell's output.
torch.set_grad_enabled(False);

from taming.models.vqgan import VQModel, GumbelVQ

In [None]:
# Local directories that receive the downloaded checkpoints and configs.
models_folder = './models'
configs_folder = './configs'

# Create both directories up front; existing ones are left untouched.
os.makedirs(models_folder, exist_ok=True)
os.makedirs(configs_folder, exist_ok=True)

# Google Drive file id -> local file name, for every model file.
models_storage = [
    dict(id='1COec-dpskvwHbIl9QA8qy_9nuOzsRLkl', name='encoder.pkl'),
    dict(id='1pCIvZnVrzA968dqSAi2OEj299Y9YLcDQ', name='decoder.pkl'),
    dict(id='1yB5nPXiJqYnoBEOannq_M5JJ2lpzhp3T', name='vqgan.16384.model.ckpt'),
    dict(id='1UHuUUWX5F4y17oaW8sWuDzrsXyExU-rK', name='vqgan.gumbelf8.model.ckpt'),
    dict(id='1WP6Li2Po8xYcQPGMpmaxIlI1yPB5lF5m', name='sber.gumbelf8.ckpt'),
]

# Same layout for the VQGAN configuration files.
configs_storage = [
    dict(id='1mXu9ThC3ET_uFGPwCYKCbOXqwma7wHo-', name='vqgan.16384.config.yml'),
    dict(id='1M7RvSoiuKBwpF-98sScKng0lsZnwFebR', name='vqgan.gumbelf8.config.yml'),
]

In [None]:
url_template = 'https://drive.google.com/uc?id={}'


def _download_from_drive(storage, folder):
    """Download every ``{'id', 'name'}`` entry of ``storage`` into ``folder``.

    Files that are already present are skipped, so re-running the cell does
    not fetch the checkpoints again. A failed download is reported on stderr
    instead of passing silently; it does not abort the remaining downloads.
    """
    for item in storage:
        out_name = os.path.join(folder, item['name'])
        if os.path.exists(out_name):
            continue
        url = url_template.format(item['id'])
        if gdown.download(url, out_name, quiet=True) is None:
            print('WARNING: failed to download {} from {}'.format(item['name'], url),
                  file=sys.stderr)


_download_from_drive(models_storage, models_folder)
_download_from_drive(configs_storage, configs_folder)

In [None]:
def map_pixels(x, eps=0.1):
    """Affinely squeeze values from [0, 1] into [eps, 1 - eps]."""
    scale = 1 - 2 * eps
    return scale * x + eps


def unmap_pixels(x, eps=0.1):
    """Undo the [eps, 1 - eps] squeeze and clamp the result to [0, 1]."""
    rescaled = (x - eps) / (1 - 2 * eps)
    return torch.clamp(rescaled, 0, 1)


class VQVAE(torch.nn.Module):
    """Discrete VAE built from a separately stored encoder and decoder.

    Both halves are restored with ``torch.load`` onto the CPU. Images are
    turned into flat sequences of codebook indices and back.
    """

    def __init__(self, enc_path, dec_path):
        """Load the encoder from ``enc_path`` and the decoder from ``dec_path``."""
        super().__init__()

        cpu = torch.device('cpu')
        # NOTE: torch.load unpickles arbitrary objects - only load trusted files.
        self.enc = torch.load(enc_path, map_location=cpu)
        self.dec = torch.load(dec_path, map_location=cpu)

        self.num_layers = 3
        self.image_size = 256
        self.num_tokens = 8192  # number of classes used for the one-hot codes in decode()

    @torch.no_grad()
    def get_codebook_indices(self, img):
        """Encode ``img`` into a (batch, h*w) tensor of argmax codebook indices.

        ``map_pixels`` is applied here, before the encoder blocks run.
        """
        logits = self.enc.blocks(map_pixels(img))
        token_grid = torch.argmax(logits, dim=1)
        return rearrange(token_grid, 'b h w -> b (h w)')

    def decode(self, img_seq):
        """Decode a (batch, n) index tensor into images; the grid height is int(sqrt(n))."""
        _, seq_len = img_seq.shape
        side = int(sqrt(seq_len))
        token_grid = rearrange(img_seq, 'b (h w) -> b h w', h=side)

        # One-hot codes, moved to channels-first layout for the decoder.
        codes = torch.nn.functional.one_hot(token_grid, num_classes=self.num_tokens)
        codes = rearrange(codes, 'b h w c -> b c h w').float()
        stats = self.dec(codes).float()
        # Only the first three output channels are used for the image.
        return unmap_pixels(torch.sigmoid(stats[:, :3]))

    def forward(self, img):
        """Not supported; use ``get_codebook_indices`` and ``decode`` directly."""
        raise NotImplementedError

In [None]:
def download_image(url, timeout=30):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled connection.

    Returns:
        The ``PIL.Image`` opened from the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def load_config(config_path, display=False):
    """Read an OmegaConf config from ``config_path``; optionally print it as YAML."""
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    print(yaml.dump(OmegaConf.to_container(cfg)))
    return cfg

def preprocess(img, target_image_size=256, map_dalle=True):
    """Resize, centre-crop and tensorise a PIL image.

    The shorter side is scaled to ``target_image_size`` (Lanczos resampling),
    then a square of that size is cropped from the centre.

    Args:
        img: PIL image (its ``size`` is read as ``(width, height)``).
        target_image_size: Side length of the square output, in pixels.
        map_dalle: When true, ``map_pixels`` is applied to the tensor.

    Returns:
        A ``(1, C, target_image_size, target_image_size)`` float tensor.
    """
    width, height = img.size
    scale = target_image_size / min(width, height)
    # torchvision expects (height, width).
    new_size = (round(scale * height), round(scale * width))
    # Use torchvision's own enum: passing a PIL integer constant is deprecated.
    img = TF.resize(img, new_size, interpolation=T.InterpolationMode.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    if map_dalle:
        img = map_pixels(img)
    return img

def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a VQGAN from ``config`` and optionally restore weights.

    ``GumbelVQ`` is instantiated when ``is_gumbel`` is true, ``VQModel``
    otherwise. The model is returned in eval mode.
    """
    model_cls = GumbelVQ if is_gumbel else VQModel
    model = model_cls(**config.model.params)
    if ckpt_path is None:
        return model.eval()

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # strict=False: missing or unexpected keys in the checkpoint are tolerated.
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model.eval()

def preprocess_vqgan(x):
    """Rescale values from [0, 1] to [-1, 1]."""
    return 2. * x - 1.

# NOTE: identical to the map_pixels defined next to the VQVAE class; this copy
# shadows it once the cell runs.
def map_pixels(x, eps=0.1):
    """Squeeze values from [0, 1] into [eps, 1 - eps]."""
    span = 1 - 2 * eps
    return span * x + eps

def vae_postprocess(x):
    """Turn a (C, H, W) tensor with values in [0, 1] into an RGB PIL image."""
    pixels = torch.clamp(x.detach().cpu(), 0., 1.)
    hwc = pixels.permute(1, 2, 0).numpy()
    picture = Image.fromarray((255 * hwc).astype(np.uint8))
    return picture if picture.mode == "RGB" else picture.convert("RGB")
    
def vqgan_postprocess(x):
    """Turn a (C, H, W) tensor with values in [-1, 1] into an RGB PIL image."""
    pixels = torch.clamp(x.detach().cpu(), -1., 1.)
    pixels = (pixels + 1.) / 2.  # back to [0, 1]
    hwc = pixels.permute(1, 2, 0).numpy()
    picture = Image.fromarray((255 * hwc).astype(np.uint8))
    return picture if picture.mode == "RGB" else picture.convert("RGB")

def reconstruct_with_vqgan(x, model):
    """Encode ``x`` with ``model`` and decode the latent straight back.

    ``model.encode`` must return ``(z, <ignored>, [<ignored>, <ignored>, indices])``;
    only ``z`` is passed on to ``model.decode``.
    """
    with torch.no_grad():
        latent, _, [_, _, _] = model.encode(x)
        return model.decode(latent)

def reconstruct_with_vae(x, model):
    """Round-trip ``x`` through ``model``: codebook indices, then decoded image."""
    with torch.no_grad():
        tokens = model.get_codebook_indices(x)
        return model.decode(tokens)

In [None]:
# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Full list of supported models, kept as an inert string for reference.
'''
models_info = [{
    'model_name': 'VAE',
    'enc_path': './models/encoder.pkl',
    'dec_path': './models/decoder.pkl',
},{
    'model_name': '16384',
    'config_path': './configs/vqgan.16384.config.yml',
    'ckpt_path': './models/vqgan.16384.model.ckpt',
    'is_gumbel': False,
},{
    'model_name': 'gumbelf8',
    'config_path': './configs/vqgan.gumbelf8.config.yml',
    'ckpt_path': './models/vqgan.gumbelf8.model.ckpt',
    'is_gumbel': True,
},{
    'model_name': 'SBER-gumbelf8',
    'config_path': './configs/vqgan.gumbelf8.config.yml',
    'ckpt_path': './models/sber.gumbelf8.ckpt',
    'is_gumbel': True,
},]
'''

# Models that are actually loaded and compared in this run.
models_info = [
    {
        'model_name': 'SBER-gumbelf8',
        'config_path': './configs/vqgan.gumbelf8.config.yml',
        'ckpt_path': './models/sber.gumbelf8.ckpt',
        'is_gumbel': True,
    },
]

In [None]:

# Instantiate every model listed in models_info and move it to DEVICE.
models = []
for model_info in models_info:
    if model_info['model_name'] != 'VAE':
        # VQGAN variants are built from a config plus a checkpoint.
        config = load_config(model_info['config_path'], display=False)
        model = load_vqgan(
            config,
            ckpt_path=model_info['ckpt_path'],
            is_gumbel=model_info['is_gumbel'],
        ).to(DEVICE)
    else:
        # The VAE is restored from separate encoder / decoder files.
        model = VQVAE(model_info['enc_path'], model_info['dec_path']).eval().to(DEVICE)

    models.append(dict(model_name=model_info['model_name'], model=model))

    # Drop the loop-local names so only `models` keeps the model alive.
    # Binding both to None first makes the `del` valid on the VAE branch,
    # where `config` is never assigned.
    model = config = None
    del model, config

Working with z of shape (1, 256, 32, 32) = 262144 dimensions.


In [None]:

def stack_reconstructions(images):
    """Paste images side by side and label each one with its title.

    Args:
        images: Non-empty list of dicts with an ``'image'`` (PIL image) and a
            ``'title'`` (str). The tile size is taken from the first image.

    Returns:
        A new RGB PIL image of width ``len(images) * w``, with every title
        drawn in red at the top-left corner of its tile.
    """
    gt_img = images[0]['image']
    w, h = gt_img.size[0], gt_img.size[1]
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSans-BoldItalic.ttf", 12)
    except OSError:
        # The Liberation font is not installed everywhere; fall back to
        # Pillow's built-in font instead of failing the whole visualisation.
        font = ImageFont.load_default()

    canvas = Image.new("RGB", (len(images) * w, h))
    draw = ImageDraw.Draw(canvas)  # one drawing context for all tiles

    for i, entry in enumerate(images):
        canvas.paste(entry['image'], (i * w, 0))
        draw.text((i * w, 0), entry['title'], (255, 0, 0), font=font)
    return canvas

def reconstruction_pipeline_by_url(models, url, size=256):
    """Download an image, reconstruct it with every model and plot the results.

    Args:
        models: List of dicts with ``'model_name'`` and ``'model'``. The name
            ``'VAE'`` selects the VAE path; every other name is treated as a
            VQGAN.
        url: Address of the image to download.
        size: Side length (pixels) of the square the image is resized/cropped to.

    The figure shows the preprocessed original followed by one reconstruction
    per model, left to right.
    """
    original_image = download_image(url)

    # Preprocess once, WITHOUT map_pixels:
    #  * the "original" panel must show the true pixel values, not the
    #    contrast-compressed [eps, 1 - eps] range;
    #  * VQVAE.get_codebook_indices applies map_pixels itself, so mapping here
    #    as well would apply it twice on the VAE path.
    x = preprocess(original_image, target_image_size=size, map_dalle=False)
    images = [{
        'image': vae_postprocess(x[0]),
        'title': 'original',
    }]
    x = x.to(DEVICE)

    for model in models:
        if model['model_name'] == 'VAE':
            pr_imgs = reconstruct_with_vae(x, model['model'])
            pr_img = vae_postprocess(pr_imgs[0])
        else:
            pr_imgs = reconstruct_with_vqgan(preprocess_vqgan(x), model['model'])
            pr_img = vqgan_postprocess(pr_imgs[0])

        images.append({
            'image': pr_img,
            'title': model['model_name']
        })

    img = stack_reconstructions(images)

    plt.figure(figsize=(20, 10))
    plt.imshow(img)
    plt.axis(False)
    plt.show()

In [None]:
# Colour index -> RGB hex string. Index 10 is white, the same colour that
# pad_grid_and_convert uses for padding.
COLOR_TRANSLATOR = {0: "000000",
                    1: "0074d9",
                    2: "ff4136",
                    3: "2ecc40",
                    4: "ffdc00",
                    5: "aaaaaa",
                    6: "f012be",
                    7: "ff851b",
                    8: "7fdbff",
                    9: "870c25",
                    10: "ffffff"}

# Colour index -> array of the three RGB components (0..255).
COLOR_TRANSLATOR = {
    k: np.array([int(v[i:i + 2], 16) for i in (0, 2, 4)])
    for k, v in COLOR_TRANSLATOR.items()
}


def pad_grid_and_convert(grid: np.ndarray, max_grid_size: int = 32) -> np.ndarray:
    """Render a grid of colour indices as an RGB image centred on a white canvas.

    Args:
        grid: 2-D array whose values are keys of ``COLOR_TRANSLATOR``.
        max_grid_size: Side length of the square output canvas.

    Returns:
        A ``(max_grid_size, max_grid_size, 3)`` uint8 array. Cells not covered
        by ``grid`` are white (255).

    Raises:
        ValueError: If ``grid`` is not 2-D or does not fit into the canvas.
            Without this check the negative offsets would silently wrap
            around the canvas edges.
        KeyError: If ``grid`` contains a value missing from ``COLOR_TRANSLATOR``.
    """
    grid = np.asarray(grid)
    if grid.ndim != 2:
        raise ValueError(f"grid must be 2-D, got shape {grid.shape}")

    x_size, y_size = grid.shape
    if x_size > max_grid_size or y_size > max_grid_size:
        raise ValueError(
            f"grid of shape {grid.shape} does not fit into a "
            f"{max_grid_size}x{max_grid_size} canvas"
        )

    new_grid = np.full(shape=[max_grid_size, max_grid_size, 3], fill_value=255, dtype=np.uint8)

    x_offset = (max_grid_size - x_size) // 2
    y_offset = (max_grid_size - y_size) // 2

    # View onto the part of the canvas that the grid occupies.
    region = new_grid[x_offset:x_offset + x_size, y_offset:y_offset + y_size]
    # Paint one colour at a time with a boolean mask instead of a per-cell loop.
    for value in np.unique(grid):
        region[grid == value] = COLOR_TRANSLATOR[value.item()]

    return new_grid


def reconstruct_arc(models, size=256):
    """Reconstruct every ARC grid with every model and plot the comparisons.

    Reads all ``*.json`` files under ``./ARC/data/training`` and
    ``./ARC/data/evaluation``. Each file must hold ``'train'`` and ``'test'``
    lists whose items carry an ``'input'`` and an ``'output'`` grid. Every grid
    is padded to 32x32, upscaled by ``size // 32`` and shown next to one
    reconstruction per model.

    Args:
        models: List of dicts with ``'model_name'`` and ``'model'``. The name
            ``'VAE'`` selects the VAE path; every other name is treated as a
            VQGAN.
        size: Side length (pixels) of the images fed to the models.

    Raises:
        ValueError: If ``size`` is smaller than the 32-cell grid canvas.
    """
    grid_size = 32
    upscale = size // grid_size
    if upscale < 1:
        raise ValueError(f"size must be at least {grid_size}, got {size}")

    # Sorted so figures come out in the same order on every run.
    files = (sorted(glob.glob("./ARC/data/training/*.json"))
             + sorted(glob.glob("./ARC/data/evaluation/*.json")))

    for file in files:
        # Context manager closes the file handle promptly.
        with open(file) as fh:
            data = json.load(fh)

        for split in ['train', 'test']:
            for pair in data[split]:
                for grid in (np.array(pair['input']), np.array(pair['output'])):
                    canvas = pad_grid_and_convert(grid, max_grid_size=grid_size)
                    canvas = np.repeat(canvas, axis=0, repeats=upscale)
                    canvas = np.repeat(canvas, axis=1, repeats=upscale)
                    original_image = Image.fromarray(canvas)

                    # Preprocess once, WITHOUT map_pixels:
                    #  * the "original" panel must show the true colours, not
                    #    the contrast-compressed [eps, 1 - eps] range;
                    #  * VQVAE.get_codebook_indices applies map_pixels itself,
                    #    so mapping here too would apply it twice.
                    x = preprocess(original_image, target_image_size=size, map_dalle=False)
                    images = [{
                        'image': vae_postprocess(x[0]),
                        'title': 'original',
                    }]
                    x = x.to(DEVICE)

                    for model in models:
                        if model['model_name'] == 'VAE':
                            pr_imgs = reconstruct_with_vae(x, model['model'])
                            pr_img = vae_postprocess(pr_imgs[0])
                        else:
                            pr_imgs = reconstruct_with_vqgan(preprocess_vqgan(x), model['model'])
                            pr_img = vqgan_postprocess(pr_imgs[0])

                        images.append({
                            'image': pr_img,
                            'title': model['model_name']
                        })

                    stacked = stack_reconstructions(images)

                    fig = plt.figure(figsize=(10, 5))
                    plt.imshow(stacked)
                    plt.axis(False)
                    plt.show()
                    # Thousands of figures are produced; release each one explicitly.
                    plt.close(fig)

In [None]:
# Plot original-vs-reconstruction comparisons for every ARC grid using the loaded models.
reconstruct_arc(models, size=256)