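# **Latent Data Extraction**

Encode every CelebA-HQ image with the pretrained `CompVis/ldm-celebahq-256` VQ-VAE and cache the latents as pickle shards, so latent-diffusion training can skip the encoder.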

In [2]:
# Mount Google Drive so the project folder under MyDrive is reachable from this Colab VM.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [3]:
# Switch into the project's data directory on the mounted Drive; the dataset download cell
# creates CelebAMask-HQ relative to this directory. Absolute Colab/Drive path: adjust it for your layout.
%cd /content/drive/MyDrive/AI_VIETNAM/AIO2023/Module 09/[Exercise]-Stable-Diffusion-Model/code/data

/content/drive/MyDrive/AI_VIETNAM/AIO2023/Module 09/[Exercise]-Stable-Diffusion-Model/code/data


In [1]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# TODO: pin versions (e.g. diffusers==X.Y.Z) so the notebook stays reproducible.
%pip install -q einops diffusers accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m853.4 kB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.3/297.3 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m53.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m65.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m731.7/731.7 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m410.6/410.6 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━

In [5]:
import torch
import pickle
import torchvision
import numpy as np
import torch.nn as nn
from torch.optim import Adam

from einops import einsum

from diffusers import VQModel
from transformers import CLIPTokenizer, CLIPTextModel

import os
import glob
import random
from PIL import Image
from tqdm import tqdm
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

## **Download Dataset**

In [None]:
# Download CelebAMask-HQ into ./CelebAMask-HQ and work inside it.
# -p makes the mkdir a no-op when the folder already exists (safe to re-run).
!mkdir -p CelebAMask-HQ
%cd CelebAMask-HQ

# Download the image archive
# -q keeps the 30k-file listing out of the notebook; -o overwrites instead of prompting on re-run.
!gdown 1jlQ8umhpJo8lVgC9q4_1q_t_Frv1kZ3f
!unzip -q -o image.zip

# Download the image descriptions (captions)
!gdown 1X1EFCyralNN2Bg3LhelL_lShrSrmTitW
!unzip -q -o text.zip

# Download the train and test split files
!gdown 1GdeTdBpi_IV7AuBpJAhLElqjswRmOy-7 -O train.pickle
!gdown 1JNxgdvPMI_HHUq2-JUuJp8L7cD-74OAf -O test.pickle

# Rename the extracted folder to the name CelebDataset globs for (CelebA-HQ-img).
# Guarded so a re-run does not move `images` *inside* an existing CelebA-HQ-img.
!test -d CelebA-HQ-img || mv images CelebA-HQ-img
!rm image.zip text.zip

## **Dataset**

In [6]:
def load_latents(latent_path):
    r"""
    Load cached latents stored as ``*.pkl`` shards into a single dict.

    Each shard is a pickled dict. For every entry ``k -> v`` the result maps
    ``k -> v[0]``. In this notebook the shards are written by the latent
    extraction cell, which keys them by image path and stores the encoder
    output of a batch of one image, so ``v[0]`` drops that batch dimension.
    Shards are read in sorted filename order. If a key appears in several
    shards, the entry from the last shard read wins (deterministically).

    :param latent_path: directory containing the ``*.pkl`` shards.
    :return: dict mapping each key to ``v[0]``; empty if no shard is found.
    """
    latent_maps = {}
    for fname in sorted(glob.glob(os.path.join(latent_path, '*.pkl'))):
        # NOTE(review): pickle.load can execute arbitrary code. Only load shards you created yourself.
        with open(fname, 'rb') as f:
            shard = pickle.load(f)
        for k, v in shard.items():
            latent_maps[k] = v[0]
    return latent_maps

In [7]:
class CelebDataset(Dataset):
    r"""
    CelebA-HQ dataset that centre-crops and resizes every image.

    Images are discovered under ``<im_path>/CelebA-HQ-img`` (png/jpg/jpeg).
    The dataset can be restricted to a train/test split, and it can attach text
    captions read from ``<im_path>/celeba-caption/<name>.txt``. When a latent
    cache covering every selected image is supplied, it serves those latents
    instead of pixels.

    :param split: ``'all'`` to keep every image; otherwise the name of a split
        whose ``<im_path>/<split>.pickle`` lists the image names to keep.
    :param im_path: dataset root directory.
    :param im_size: side length after resize + centre crop.
    :param im_channels: stored only; images are not converted to this channel count.
    :param im_ext: stored only; discovery always looks for png, jpg and jpeg.
    :param use_latents: serve latents from ``latent_path`` when every image has one.
    :param latent_path: directory of ``*.pkl`` latent shards (read with ``load_latents``).
    :param condition_config: dict with a ``'condition_types'`` list; ``'text'`` enables captions.
    """

    def __init__(self, split, im_path, im_size=256, im_channels=3, im_ext='jpg',
                 use_latents=False, latent_path=None, condition_config=None):
        self.split = split
        if self.split != 'all':
            self.split_filter = self._load_split_filter(im_path, self.split)

        self.im_size = im_size
        self.im_channels = im_channels
        self.im_ext = im_ext
        self.im_path = im_path
        self.latent_maps = None
        self.use_latents = False
        # Built once here instead of on every __getitem__ call.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(self.im_size),
            torchvision.transforms.CenterCrop(self.im_size),
            torchvision.transforms.ToTensor(),
        ])

        self.condition_types = [] if condition_config is None else condition_config['condition_types']
        self.images, self.texts = self.load_images(im_path)

        # Whether to load images or to load latents
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            # Latents are usually extracted for the whole dataset (split='all'). A subset split
            # can still use them, as long as every selected image has an entry.
            if self._latents_cover_images(self.images, latent_maps):
                self.use_latents = True
                self.latent_maps = latent_maps
                print('Found {} latents'.format(len(self.latent_maps)))
            else:
                print('Latents not found')

    @staticmethod
    def _load_split_filter(im_path, split):
        """Return the set of entries stored in ``<im_path>/<split>.pickle``."""
        # The split pickles are downloaded next to the images, i.e. into im_path.
        # NOTE(review): pickle.load can execute arbitrary code. Only load trusted split files.
        with open(os.path.join(im_path, '{}.pickle'.format(split)), 'rb') as f:
            names = pickle.load(f)
        # presumably a collection of image-name strings. A set makes each membership test O(1).
        return set(names)

    @staticmethod
    def _latents_cover_images(image_paths, latent_maps):
        """True when ``latent_maps`` is non-empty and has an entry for every image path."""
        return bool(latent_maps) and all(p in latent_maps for p in image_paths)

    def load_images(self, im_path):
        r"""
        Collect image paths (and optionally captions) under ``im_path``.

        :param im_path: dataset root; images are searched in ``CelebA-HQ-img``.
        :return: ``(image_paths, captions)``. Paths are sorted for a deterministic order;
            ``captions[i]`` is the list of caption lines for ``image_paths[i]``. Captions
            are empty when text conditioning is off.
        :raises AssertionError: if ``im_path`` is missing, or if captions are missing
            while text conditioning is on.
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        fnames = []
        for ext in ('png', 'jpg', 'jpeg'):
            fnames += glob.glob(os.path.join(im_path, 'CelebA-HQ-img/*.{}'.format(ext)))
        # glob order is filesystem-dependent; sort so indices are reproducible across runs.
        fnames.sort()

        ims = []
        texts = []
        for fname in tqdm(fnames):
            im_name = os.path.split(fname)[1].split('.')[0]

            if self.split != 'all' and im_name not in self.split_filter:
                continue

            ims.append(fname)

            if 'text' in self.condition_types:
                caption_file = os.path.join(im_path, 'celeba-caption/{}.txt'.format(im_name))
                with open(caption_file, encoding='utf-8') as f:
                    texts.append([line.strip() for line in f])

        if 'text' in self.condition_types:
            assert len(texts) == len(ims), "Condition Type Text but could not find captions for all images"

        print('Found {} images'.format(len(ims)))
        print('Found {} captions'.format(len(texts)))

        return ims, texts

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return the image tensor in [-1, 1] (or its cached latent), plus a cond dict when conditioning is on."""
        cond_inputs = {}
        if 'text' in self.condition_types:
            # A different caption may be drawn on each access.
            cond_inputs['text'] = random.choice(self.texts[index])

        if self.use_latents:
            sample = self.latent_maps[self.images[index]]
        else:
            # `with` closes the file even if the transform raises.
            with Image.open(self.images[index]) as im:
                sample = self.transform(im)
            # Convert input from ToTensor's [0, 1] range to [-1, 1].
            sample = (2 * sample) - 1

        if len(self.condition_types) == 0:
            return sample
        return sample, cond_inputs

## **Model Config**

In [8]:
# Experiment configuration. This notebook reads only `dataset_params` and `train_params`;
# `diffusion_params` and `ldm_params` describe the latent-diffusion model trained on the cached latents.
config = dict(
    dataset_params=dict(
        im_path="data/CelebAMask-HQ",  # relative path: resolved against the current working directory
        im_channels=3,
        im_size=256,
        name="celebhq",
    ),
    diffusion_params=dict(
        num_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
    ),
    ldm_params=dict(
        down_channels=[256, 384, 512, 768],
        mid_channels=[768, 512],
        down_sample=[True, True, True],
        attn_down=[True, True, True],
        time_emb_dim=512,
        norm_channels=32,
        num_heads=16,
        conv_out_channels=128,
        num_down_layers=2,
        num_mid_layers=2,
        num_up_layers=2,
        condition_config=dict(
            condition_types=["text"],
            text_condition_config=dict(
                text_embed_model="clip",
                train_text_embed_model=False,
                text_embed_dim=512,
                cond_drop_prob=0.1,
            ),
        ),
    ),
    train_params=dict(
        seed=1111,
        task_name="celebhq",  # also used as the output directory name
        ldm_batch_size=16,
        ldm_epochs=100,
        num_samples=1,  # images encoded/decoded for the sanity-check grids
        num_grid_rows=1,
        ldm_lr=0.000005,
        save_latents=True,  # whether the extraction cell writes latent shards
        vqvae_latent_dir_name='vqvae_latents',
        cf_guidance_scale=1.0,
        ldm_ckpt_name="ddpm_ckpt_text_cond_clip.pth",
    ),
)


## **Extract Latent Data using VQModel**

In [12]:
# Work from the project's code directory: relative paths such as
# config['dataset_params']['im_path'] ('data/CelebAMask-HQ') and the task_name output folder resolve against it.
%cd /content/drive/MyDrive/AI_VIETNAM/AIO2023/Module 09/[Exercise]-Stable-Diffusion-Model/code

/content/drive/MyDrive/AI_VIETNAM/AIO2023/Module 09/[Exercise]-Stable-Diffusion-Model/code


In [None]:
# Encode a few random images with the pretrained VQ-VAE to sanity-check reconstructions, then
# (if train_params['save_latents']) cache the latent of every image as pickle shards.
dataset_config = config['dataset_params']
train_config = config['train_params']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Apply the configured seed so the visualised sample indices are reproducible.
torch.manual_seed(train_config['seed'])

im_dataset = CelebDataset(split='all',
                          im_path=dataset_config['im_path'],
                          im_size=dataset_config['im_size'],
                          im_channels=dataset_config['im_channels'])
# shuffle=False with batch_size=1 keeps loader step `idx` aligned with im_dataset.images[idx].
data_loader = DataLoader(im_dataset, batch_size=1, shuffle=False)

num_images = train_config['num_samples']
ngrid = train_config['num_grid_rows']

# torch.randint's upper bound is exclusive: len(im_dataset) lets the last image be drawn too.
idxs = torch.randint(0, len(im_dataset), (num_images,))
ims = torch.cat([im_dataset[idx][None, :] for idx in idxs.tolist()]).float()
ims = ims.to(device)

vae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
vae.eval()
vae = vae.to(device)

task_dir = train_config['task_name']
latent_path = os.path.join(task_dir, train_config['vqvae_latent_dir_name'])
LATENTS_PER_SHARD = 1000  # images per .pkl shard

os.makedirs(task_dir, exist_ok=True)
with torch.no_grad():
    encoded_output = vae.encode(ims).latents
    decoded_output = vae.decode(encoded_output).sample
    # Clamp to [-1, 1], then map to [0, 1] so the tensors can be saved as images.
    encoded_output = (torch.clamp(encoded_output, -1., 1.) + 1) / 2
    decoded_output = (torch.clamp(decoded_output, -1., 1.) + 1) / 2
    ims = (ims + 1) / 2

    to_pil = torchvision.transforms.ToPILImage()
    input_grid = to_pil(make_grid(ims.cpu(), nrow=ngrid))
    encoder_grid = to_pil(make_grid(encoded_output.cpu(), nrow=ngrid))
    decoder_grid = to_pil(make_grid(decoded_output.cpu(), nrow=ngrid))

    input_grid.save(os.path.join(task_dir, 'input_samples.png'))
    encoder_grid.save(os.path.join(task_dir, 'encoded_samples.png'))
    decoder_grid.save(os.path.join(task_dir, 'reconstructed_samples.png'))

    os.makedirs(latent_path, exist_ok=True)
    if train_config['save_latents']:
        # Shards are plain pickles of {image path: latent tensor}; load_latents reads them back.
        latent_fnames = glob.glob(os.path.join(latent_path, '*.pkl'))
        assert len(latent_fnames) == 0, 'Latents already present. Delete all latent files and re-run'
        print('Saving Latents for {}'.format(dataset_config['name']))

        fname_latent_map = {}
        part_count = 0
        for idx, im in enumerate(tqdm(data_loader)):
            encoded_output = vae.encode(im.float().to(device)).latents
            fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu()
            # Flush a shard every LATENTS_PER_SHARD images; `with` guarantees the file is closed.
            if (idx + 1) % LATENTS_PER_SHARD == 0:
                with open(os.path.join(latent_path, '{}.pkl'.format(part_count)), 'wb') as f:
                    pickle.dump(fname_latent_map, f)
                part_count += 1
                fname_latent_map = {}
        # Write the remainder that did not fill a whole shard.
        if fname_latent_map:
            with open(os.path.join(latent_path, '{}.pkl'.format(part_count)), 'wb') as f:
                pickle.dump(fname_latent_map, f)
        print('Done saving latents')

100%|██████████| 30000/30000 [00:00<00:00, 333080.41it/s]


Found 30000 images
Found 0 captions


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vqvae/config.json:   0%|          | 0.00/486 [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/221M [00:00<?, ?B/s]

Saving Latents for celebhq


  6%|▌         | 1702/30000 [26:33<6:25:35,  1.22it/s]

In [None]:
# NOTE(review): this cell repeats the previous extraction cell verbatim (that run was interrupted);
# keep only one of the two.
# Encode a few random images with the pretrained VQ-VAE to sanity-check reconstructions, then
# (if train_params['save_latents']) cache the latent of every image as pickle shards.
dataset_config = config['dataset_params']
train_config = config['train_params']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Apply the configured seed so the visualised sample indices are reproducible.
torch.manual_seed(train_config['seed'])

im_dataset = CelebDataset(split='all',
                          im_path=dataset_config['im_path'],
                          im_size=dataset_config['im_size'],
                          im_channels=dataset_config['im_channels'])
# shuffle=False with batch_size=1 keeps loader step `idx` aligned with im_dataset.images[idx].
data_loader = DataLoader(im_dataset, batch_size=1, shuffle=False)

num_images = train_config['num_samples']
ngrid = train_config['num_grid_rows']

# torch.randint's upper bound is exclusive: len(im_dataset) lets the last image be drawn too.
idxs = torch.randint(0, len(im_dataset), (num_images,))
ims = torch.cat([im_dataset[idx][None, :] for idx in idxs.tolist()]).float()
ims = ims.to(device)

vae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
vae.eval()
vae = vae.to(device)

task_dir = train_config['task_name']
latent_path = os.path.join(task_dir, train_config['vqvae_latent_dir_name'])
LATENTS_PER_SHARD = 1000  # images per .pkl shard

os.makedirs(task_dir, exist_ok=True)
with torch.no_grad():
    encoded_output = vae.encode(ims).latents
    decoded_output = vae.decode(encoded_output).sample
    # Clamp to [-1, 1], then map to [0, 1] so the tensors can be saved as images.
    encoded_output = (torch.clamp(encoded_output, -1., 1.) + 1) / 2
    decoded_output = (torch.clamp(decoded_output, -1., 1.) + 1) / 2
    ims = (ims + 1) / 2

    to_pil = torchvision.transforms.ToPILImage()
    input_grid = to_pil(make_grid(ims.cpu(), nrow=ngrid))
    encoder_grid = to_pil(make_grid(encoded_output.cpu(), nrow=ngrid))
    decoder_grid = to_pil(make_grid(decoded_output.cpu(), nrow=ngrid))

    input_grid.save(os.path.join(task_dir, 'input_samples.png'))
    encoder_grid.save(os.path.join(task_dir, 'encoded_samples.png'))
    decoder_grid.save(os.path.join(task_dir, 'reconstructed_samples.png'))

    os.makedirs(latent_path, exist_ok=True)
    if train_config['save_latents']:
        # Shards are plain pickles of {image path: latent tensor}; load_latents reads them back.
        latent_fnames = glob.glob(os.path.join(latent_path, '*.pkl'))
        assert len(latent_fnames) == 0, 'Latents already present. Delete all latent files and re-run'
        print('Saving Latents for {}'.format(dataset_config['name']))

        fname_latent_map = {}
        part_count = 0
        for idx, im in enumerate(tqdm(data_loader)):
            encoded_output = vae.encode(im.float().to(device)).latents
            fname_latent_map[im_dataset.images[idx]] = encoded_output.cpu()
            # Flush a shard every LATENTS_PER_SHARD images; `with` guarantees the file is closed.
            if (idx + 1) % LATENTS_PER_SHARD == 0:
                with open(os.path.join(latent_path, '{}.pkl'.format(part_count)), 'wb') as f:
                    pickle.dump(fname_latent_map, f)
                part_count += 1
                fname_latent_map = {}
        # Write the remainder that did not fill a whole shard.
        if fname_latent_map:
            with open(os.path.join(latent_path, '{}.pkl'.format(part_count)), 'wb') as f:
                pickle.dump(fname_latent_map, f)
        print('Done saving latents')

100%|██████████| 30000/30000 [00:00<00:00, 308495.44it/s]


Found 30000 images
Found 0 captions


vqvae/config.json:   0%|          | 0.00/486 [00:00<?, ?B/s]

vqvae/diffusion_pytorch_model.bin:   0%|          | 0.00/221M [00:00<?, ?B/s]

Saving Latents for celebhq


100%|██████████| 30000/30000 [22:30<00:00, 22.21it/s]

Done saving latents



