In [None]:
#@title Resources
import multiprocessing
import torch
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device.type)

!nvidia-smi

CPU: 2
RAM GB: 12.7
PyTorch version: 1.9.0+cu102
CUDA version: 10.2
cuDNN version: 7605
device: cuda
Tue Sep 28 06:49:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    26W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                   

In [None]:
#@title 1. Cloning taming-transformers
!git clone https://github.com/CompVis/taming-transformers.git

Cloning into 'taming-transformers'...
remote: Enumerating objects: 1020, done.[K
remote: Counting objects: 100% (224/224), done.[K
remote: Compressing objects: 100% (221/221), done.[K
remote: Total 1020 (delta 3), reused 213 (delta 3), pack-reused 796[K
Receiving objects: 100% (1020/1020), 350.30 MiB | 19.32 MiB/s, done.
Resolving deltas: 100% (206/206), done.


In [None]:
#@title 2. Dataset downloading
!mkdir -p ./coco/images

!wget http://images.cocodataset.org/zips/val2017.zip -P ./coco/images
!unzip -q ./coco/images/val2017.zip -d ./coco/images
!rm ./coco/images/val2017.zip

!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P ./coco
!unzip -q ./coco/annotations_trainval2017.zip -d ./coco
!rm ./coco/annotations_trainval2017.zip

## !rm -r ./coco

--2021-09-28 06:49:59--  http://images.cocodataset.org/zips/val2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.100.251
Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.100.251|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 815585330 (778M) [application/zip]
Saving to: ‘./coco/images/val2017.zip’


2021-09-28 06:50:49 (15.7 MB/s) - ‘./coco/images/val2017.zip’ saved [815585330/815585330]

--2021-09-28 06:50:55--  http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.217.194.201
Connecting to images.cocodataset.org (images.cocodataset.org)|52.217.194.201|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 252907541 (241M) [application/zip]
Saving to: ‘./coco/annotations_trainval2017.zip’


2021-09-28 06:51:03 (32.0 MB/s) - ‘./coco/annotations_trainval2017.zip’ saved [252907541/252907541]



In [None]:
#@title 3. Installing dependencies
!pip install omegaconf > /dev/null
!pip install pytorch_lightning > /dev/null
!pip install einops > /dev/null
!pip install DALL-E > /dev/null
!pip install torch_fidelity > /dev/null

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.[0m


In [None]:
#@title 4. Lbraries importing
import io
import os
import sys
import yaml
import gdown
import glob
import random
from math import sqrt
sys.path.append("./taming-transformers")
import warnings
warnings.filterwarnings('ignore')

import requests
import numpy as np
import pandas as pd
import PIL
from PIL import Image
from PIL import ImageDraw, ImageFont
from matplotlib import pyplot as plt
from omegaconf import OmegaConf
from einops import rearrange
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from torch_fidelity.metrics import calculate_metrics

torch.set_grad_enabled(False);

from taming.models.vqgan import VQModel, GumbelVQ

In [None]:
#@title 5. Random seed
def seed_everything(seed=17):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_everything()

In [None]:
#@title 6. Configs to download models from google disk
models_folder = './models'
configs_folder = './configs'

os.makedirs(models_folder, exist_ok=True)
os.makedirs(configs_folder, exist_ok=True)

models_storage = [
{
    'id': '1COec-dpskvwHbIl9QA8qy_9nuOzsRLkl',
    'name': 'encoder.pkl',
},
{
    'id': '1pCIvZnVrzA968dqSAi2OEj299Y9YLcDQ',
    'name': 'decoder.pkl',
},
{
    'id': '1yB5nPXiJqYnoBEOannq_M5JJ2lpzhp3T',
    'name': 'vqgan.16384.model.ckpt',
},
{
    'id': '1UHuUUWX5F4y17oaW8sWuDzrsXyExU-rK',
    'name': 'vqgan.gumbelf8.model.ckpt',
},
{
    'id': '1WP6Li2Po8xYcQPGMpmaxIlI1yPB5lF5m',
    'name': 'sber.gumbelf8.ckpt',
},
]

configs_storage = [{
    'id': '1mXu9ThC3ET_uFGPwCYKCbOXqwma7wHo-',
    'name': 'vqgan.16384.config.yml',
},{
    'id': '1M7RvSoiuKBwpF-98sScKng0lsZnwFebR',
    'name': 'vqgan.gumbelf8.config.yml',
}]

In [None]:
#@title 7. Models downloading
url_template = 'https://drive.google.com/uc?id={}'

for item in models_storage:
    out_name = os.path.join(models_folder, item['name'])
    url = url_template.format(item['id'])
    gdown.download(url, out_name, quiet=True)

for item in configs_storage:
    out_name = os.path.join(configs_folder, item['name'])
    url = url_template.format(item['id'])
    gdown.download(url, out_name, quiet=True)

In [None]:
#@title 8. Downloading file with splitted data
csv_folder = './csv'
os.makedirs(csv_folder, exist_ok=True)

csv_id = '1_ccngzEbw_NiqsVH60lv8S8I82JrSNPv'
url = url_template.format(csv_id)
out_name = os.path.join(csv_folder, 'categories.csv')

gdown.download(url, out_name, quiet=True)

'./csv/categories.csv'

In [None]:
#@title 9. VAE model
def map_pixels(x, eps=0.1):
    return (1 - 2 * eps) * x + eps


def unmap_pixels(x, eps=0.1):
    return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)


class VQVAE(torch.nn.Module):
    def __init__(self, enc_path, dec_path):
        super().__init__()

        self.enc = torch.load(enc_path, map_location=torch.device('cpu'))
        self.dec = torch.load(dec_path, map_location=torch.device('cpu'))

        self.num_layers = 3
        self.image_size = 256
        self.num_tokens = 8192

    @torch.no_grad()
    def get_codebook_indices(self, img):
        img = map_pixels(img)
        z_logits = self.enc.blocks(img)
        z = torch.argmax(z_logits, dim=1)
        return rearrange(z, 'b h w -> b (h w)')

    def decode(self, img_seq):
        b, n = img_seq.shape
        img_seq = rearrange(img_seq, 'b (h w) -> b h w', h=int(sqrt(n)))

        z = torch.nn.functional.one_hot(img_seq, num_classes=self.num_tokens)
        z = rearrange(z, 'b h w c -> b c h w').float()
        x_stats = self.dec(z).float()
        x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
        return x_rec

    def forward(self, img):
        raise NotImplementedError

In [None]:
#@title 10. Additional functions
def download_image(url):
    resp = requests.get(url)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config

def preprocess(img, target_image_size=256, map_dalle=True):
    s = min(img.size)
        
    r = target_image_size / s
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = T.ToTensor()(img)
    if map_dalle: 
        img = map_pixels(img)
    return img


def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    if is_gumbel:
        model = GumbelVQ(**config.model.params)
    else:
        model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()

def preprocess_vqgan(x):
    x = 2.*x - 1.
    return x

def map_pixels(x, eps=0.1):
    return (1 - 2 * eps) * x + eps

def vae_postprocess(x):
    x = x.detach().cpu()
    x = torch.clamp(x, 0., 1.)
    x = x.permute(1,2,0).numpy()
    x = (255*x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x

def vqgan_postprocess(x):
    x = x.detach().cpu()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.)/2.
    x = x.permute(1,2,0).numpy()
    x = (255*x).astype(np.uint8)
    x = Image.fromarray(x)
    if not x.mode == "RGB":
        x = x.convert("RGB")
    return x

def reconstruct_with_vqgan(x, model):
    z, _, [_, _, indices] = model.encode(x)
    xrec = model.decode(z)
    return xrec

def reconstruct_with_vae(x, model):
    img_seq = model.get_codebook_indices(x)
    out_img = model.decode(img_seq)
    return out_img

def load_image(tar_path, img_name):
    with tarfile.open(tar_path) as tar:
        if img_name not in tar.getnames():
            return None
        member = tar.getmember(img_name)
        image = tar.extractfile(member)
        image = PIL.Image.open(io.BytesIO(image.read()))
    return image

def find_image_from_tars(tar_files, img_name):
    image = None
    for i, tar_path in enumerate(tar_files):
        image = load_image(tar_path, img_name)
        if image is not None:
            break
    return image

In [None]:
#@title 11. Config with models info
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

models_info = [{
    'model_name': 'VAE',
    'enc_path': './models/encoder.pkl',
    'dec_path': './models/decoder.pkl',
},{
    'model_name': '16384',
    'config_path': './configs/vqgan.16384.config.yml',
    'ckpt_path': './models/vqgan.16384.model.ckpt',
    'is_gumbel': False,
},{
    'model_name': 'gumbelf8',
    'config_path': './configs/vqgan.gumbelf8.config.yml',
    'ckpt_path': './models/vqgan.gumbelf8.model.ckpt',
    'is_gumbel': True,
},{
    'model_name': 'SBER-gumbelf8',
    'config_path': './configs/vqgan.gumbelf8.config.yml',
    'ckpt_path': './models/sber.gumbelf8.ckpt',
    'is_gumbel': True,
},]

In [None]:
#@title 12. Load models into memory
models = []
for model_info in models_info:
    if model_info['model_name'] == 'VAE':
        model = VQVAE(model_info['enc_path'], 
                      model_info['dec_path']).eval().to(DEVICE)
    else:
        config = load_config(model_info['config_path'], display=False)
        model = load_vqgan(config, 
                           ckpt_path=model_info['ckpt_path'], 
                           is_gumbel=model_info['is_gumbel']).to(DEVICE)
    models.append({
        'model_name': model_info['model_name'],
        'model': model,
    })
    model = None
    config = None
    del model
    del config

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

Downloading vgg_lpips model from https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1 to taming/modules/autoencoder/lpips/vgg.pth


8.19kB [00:00, 5.64MB/s]                   

loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.





Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.


In [None]:
#@title 13. Domains and parameters
domains = [
    'indoor', 'kitchen', 'appliance', 'electronic', 'furniture', 'outdoor', 
    'sports', 'food', 'vehicle', 'animal', 'accessory', 'person', 'face', 'text'
]

BATCH_SIZE = 32
NUM_WORKERS = 2
IMG_SIZE = 256

In [None]:
#@title 14. Creating folder for domains
IMAGES_FOLDER = './images'
os.makedirs(IMAGES_FOLDER, exist_ok=True)

df = pd.read_csv('./csv/categories.csv', index_col=0)

for _ in range(2):
    for domain in domains:
        for model in models:
            if 'folder_path' in model.keys():
                model['folder_path'][domain] = os.path.join(IMAGES_FOLDER, model['model_name'], domain)
                os.makedirs(model['folder_path'][domain], exist_ok=True)
            else:
                model['folder_path'] = {
                    domain: os.path.join(IMAGES_FOLDER, model['model_name'], domain)
                }

In [None]:
#@title 15. Dict. Image name - Image Path
orig_images_folder = '/content/coco/images/val2017/*.jpg'
all_images = glob.glob(orig_images_folder)

img_name2path = {}
for i, row in tqdm(df.iterrows()):
    img_name = row['index']
    cur_path = None
    for img_path in all_images:
        if img_name in img_path:
          cur_path = img_path
    if cur_path is None:
        raise
    
    img_name2path[img_name] = cur_path

4873it [00:03, 1601.05it/s]


In [None]:
#@title 16. Dataset class
class DatasetRetriever(Dataset):
    def __init__(self, img_paths, map_dalle=False):
        self.img_paths = img_paths
        self.map_dalle = map_dalle

    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = PIL.Image.open(img_path).convert('RGB')
        x_vqgan = preprocess(image, target_image_size=IMG_SIZE, map_dalle=self.map_dalle)
        return x_vqgan, os.path.basename(img_path)

In [None]:
#@title 17. Inference
for model in models:
    for domain in domains:
        df_domain = df[df[domain] == True].copy()
        
        map_dalle = model['model_name'] == 'VAE'
        selected_images = [x for x in all_images if os.path.basename(x) in df_domain['index'].values]
        dataset = DatasetRetriever(selected_images, map_dalle=map_dalle)
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=SequentialSampler(dataset), 
                                pin_memory=False, drop_last=False, num_workers=NUM_WORKERS,)

        for images, img_names in dataloader:
            images = images.to(DEVICE)
            if model['model_name'] == 'VAE':
                pr_imgs = reconstruct_with_vae(images, model['model'])
            else:
                pr_imgs = reconstruct_with_vqgan(preprocess_vqgan(images), model['model'])

            for pr_img, img_name in zip(pr_imgs, img_names):
                save_path = os.path.join(model['folder_path'][domain], img_name)
                if model['model_name'] == 'VAE':
                    vae_postprocess(pr_img).save(save_path)
                else:
                    vqgan_postprocess(pr_img).save(save_path)

In [None]:
#@title 18. Original dataset splitting
original_files_root = './images/original'
os.makedirs(original_files_root, exist_ok=True)

for domain in domains:
    df_domain = df[df[domain] == True].copy()
        
    data_folder = os.path.join(original_files_root, domain)
    os.makedirs(data_folder, exist_ok=True)
    
    for i, row in tqdm(df_domain.iterrows(), desc=domain, total=df_domain.shape[0]):
        img_name = row['index']
        # image = find_image_from_tars(tar_files, img_name)
        img_path = img_name2path[img_name]
        image = PIL.Image.open(img_path).convert('RGB')
        img_save_path = os.path.join(data_folder, img_name)
        image = image.resize((IMG_SIZE, IMG_SIZE), Image.ANTIALIAS)
        image.save(img_save_path)

indoor: 100%|██████████| 645/645 [00:08<00:00, 77.03it/s]
kitchen: 100%|██████████| 904/904 [00:11<00:00, 75.77it/s]
appliance: 100%|██████████| 316/316 [00:04<00:00, 78.69it/s]
electronic: 100%|██████████| 589/589 [00:07<00:00, 79.83it/s]
furniture: 100%|██████████| 1239/1239 [00:16<00:00, 77.17it/s]
outdoor: 100%|██████████| 553/553 [00:07<00:00, 75.90it/s]
sports: 100%|██████████| 919/919 [00:11<00:00, 78.04it/s]
food: 100%|██████████| 702/702 [00:09<00:00, 75.56it/s]
vehicle: 100%|██████████| 1139/1139 [00:14<00:00, 76.07it/s]
animal: 100%|██████████| 1003/1003 [00:13<00:00, 75.43it/s]
accessory: 100%|██████████| 718/718 [00:09<00:00, 74.57it/s]
person: 100%|██████████| 2652/2652 [00:34<00:00, 76.81it/s]
face: 100%|██████████| 1488/1488 [00:19<00:00, 76.06it/s]
text: 100%|██████████| 1706/1706 [00:22<00:00, 75.89it/s]


In [None]:
#@title 19. Metrics calculation

df_isc = pd.DataFrame(columns=['model', *domains])
df_fid = pd.DataFrame(columns=['model', *domains])


for model in models + [{'model_name': 'original'}]:
    row_isc = {'model': model['model_name']}
    row_fid = {'model': model['model_name']}
    
    for domain in domains:
        input_folder = os.path.join(IMAGES_FOLDER, model['model_name'], domain)
        original_folder = os.path.join(IMAGES_FOLDER, 'original', domain)
        
        metrics = calculate_metrics(input1=input_folder, input2=original_folder, 
                                    cuda=True, isc=True, fid=True, verbose=False,)
        
        row_isc[domain] = round(metrics['inception_score_mean'], 3)
        row_fid[domain] = round(metrics['frechet_inception_distance'], 3) 
        
    df_isc = df_isc.append(row_isc, ignore_index=True)
    df_fid = df_fid.append(row_fid, ignore_index=True)
            

Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth


  0%|          | 0.00/91.2M [00:00<?, ?B/s]

In [None]:
tdf_isc = df_isc.set_index('model').T
tdf_isc

model,VAE,16384,gumbelf8,SBER-gumbelf8,original
indoor,9.769,10.744,11.707,11.688,11.638
kitchen,9.726,11.354,12.333,12.152,11.813
appliance,5.705,6.024,6.154,6.199,5.89
electronic,7.83,9.509,9.712,9.606,9.497
furniture,10.861,13.346,14.5,14.531,14.592
outdoor,8.163,9.52,10.668,10.293,10.451
sports,7.467,8.544,8.814,8.841,8.962
food,7.954,8.725,9.39,9.434,9.191
vehicle,10.527,12.947,14.24,14.559,14.233
animal,11.933,14.249,15.999,15.879,15.857


In [None]:
tdf_fid = df_fid.set_index('model').T
tdf_fid

model,VAE,16384,gumbelf8,SBER-gumbelf8,original
indoor,74.734,57.925,45.432,44.686,-0.0
kitchen,66.424,47.086,36.735,36.579,-0.0
appliance,80.359,70.604,53.225,52.064,-0.0
electronic,77.856,64.034,50.759,50.447,-0.0
furniture,53.438,38.204,29.51,29.569,-0.0
outdoor,91.932,58.877,46.309,45.287,-0.0
sports,65.54,39.961,32.219,31.756,-0.0
food,76.974,53.109,41.018,41.413,-0.0
vehicle,60.318,34.259,26.721,26.463,-0.0
animal,64.25,41.52,32.039,32.078,-0.0


In [None]:
#@title 18. Number of images for each domain
counts = []
for domain in domains:
    imgs_path = os.path.join(IMAGES_FOLDER, 'original', domain)
    img_count = glob.glob(imgs_path + '/*.jpg')
    counts.append(len(img_count))
counts = np.array(counts)
counts

array([ 645,  904,  316,  589, 1239,  553,  919,  702, 1139, 1003,  718,
       2652, 1488, 1706])

In [None]:
# IS. Weighted average for each domain
for column in tdf_isc.columns:
    print(column, (tdf_isc[column].values * counts / counts.sum()).sum())

VAE 11.13271611884993
16384 13.647365195910245
gumbelf8 15.202788169903247
SBER-gumbelf8 15.31654546078364
original 15.278036505867014


In [None]:
# FID. Weighted average for each domain
for column in tdf_fid.columns:
    print(column, (tdf_fid[column].values * counts / counts.sum()).sum())

VAE 59.75355410691003
16384 38.912181431414254
gumbelf8 30.304300349962254
SBER-gumbelf8 30.135895423042612
original 0.0
