In [1]:
import os
import numpy as np
from PIL import Image

from functools import partial, wraps
from tqdm.auto import tqdm
from contextlib import contextmanager

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Pad, Resize, ToPILImage, InterpolationMode, Normalize

from dalle2_laion import ModelLoadConfig, DalleModelManager
from dalle2_laion.scripts import InferenceScript
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CocoDataset(torch.utils.data.Dataset):
    def __init__(self, im_dir, caption_dir):
        self.im_dir = im_dir
        self.caption_dir = caption_dir
        self.im_fnames = sorted(os.listdir(im_dir))
        self.caption_fnames = sorted(os.listdir(caption_dir))
        self.to_tensor_transform = ToTensor()
        self.pad_transform = lambda im, pad_right, pad_bottom: Pad(padding=(0, 0, pad_right, pad_bottom))(im)
        self.resize_transform = Resize((256, 256), interpolation=InterpolationMode.BILINEAR)

    def __getitem__(self, idx):
        im = Image.open(os.path.join(self.im_dir, self.im_fnames[idx]))
        w, h = im.size
        im = self.to_tensor_transform(im)
        new_size = max(w, h)
        im = self.pad_transform(im, new_size - w, new_size - h)
        im = self.resize_transform(im)

        with open(os.path.join(self.caption_dir, self.caption_fnames[idx]), 'r') as captions_f:
            captions = [caption.strip() for caption in captions_f.readlines()]

        return self.im_fnames[idx], im, captions

    def __len__(self):
        return len(self.im_fnames)

        
class ExampleInference(InferenceScript):
    def run(self, text: str):
        """
        Takes a string and returns a single image.
        """
        text = [text]
        image_embedding_map = self._sample_prior(text)
        image_embedding = image_embedding_map[0][0].unsqueeze(0)
        image_map = self._sample_decoder(text=text, image_embed=image_embedding)
        return image_map[0][0]

In [3]:
def exists(val):
    return val is not None

def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None, validate = True):
    if isinstance(val, list):
        val = tuple(val)

    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length) and validate:
        assert len(out) == length

    return out

@contextmanager
def null_context(*args, **kwargs):
    yield

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    nearest = False,
    **kwargs
):
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    if not nearest:
        scale_factors = target_image_size / orig_image_size
        out = resize(image, scale_factors = scale_factors, **kwargs)
    else:
        out = F.interpolate(image, target_image_size, mode = 'bicubic') #'nearest')

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

In [4]:
def module_device(module):
    if isinstance(module, nn.Identity):
        return 'cpu' # It doesn't matter
    return next(module.parameters()).device

@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None, device=torch.device('cpu')):
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.get_unet(unet_number)

    self.to(device)

    devices = [module_device(unet) for unet in self.unets]
    self.unets.cpu()
    unet.to(device)

    yield
    
    for unet, device in zip(self.unets, devices):
        unet.to(device)


def p_mean_variance_custom(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
    assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

    pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level))

    if learned_variance:
        pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)

    if predict_x_start:
        x_recon = pred
    else:
        x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

    if clip_denoised:
        x_recon = self.dynamic_threshold(x_recon)

    model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)

    if learned_variance:
        # if learned variance, posterio variance and posterior log variance are predicted by the network
        # by an interpolation of the max and min log beta values
        # eq 15 - https://arxiv.org/abs/2102.09672
        min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
        var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)

        if self.learned_variance_constrain_frac:
            var_interp_frac = var_interp_frac.sigmoid()

        posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
        posterior_variance = posterior_log_variance.exp()

    return model_mean, posterior_variance, posterior_log_variance


def p_sample_custom(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance = p_mean_variance_custom(self, unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
    noise = torch.randn_like(x)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise


def p_sample_loop_ddpm_custom(
    self,
    unet,
    shape,
    img_start,
    image_embed,
    noise_scheduler,
    predict_x_start = False,
    learned_variance = False,
    clip_denoised = True,
    lowres_cond_img = None,
    text_encodings = None,
    cond_scale = 1,
    is_latent_diffusion = False,
    lowres_noise_level = None,
    inpaint_image = None,
    inpaint_mask = None,
    inpaint_resample_times = 5,
    gt=None
):
    for param in unet.parameters():
        param.requires_grad = False
    
    device = self.device
    
    b = shape[0]
    img = img_start.to(device) #torch.randn(shape, device = device)

    resample_times = 1

    if not is_latent_diffusion:
        lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
    
    img = F.interpolate(img, size=(256, 256), mode='bicubic')

    text_encodings.requires_grad = True
    image_embed.requires_grad = True
    lowres_cond_img.requires_grad = True
    img.requires_grad = True
    
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW([lowres_cond_img], lr=1e-4)
    
    psnrs = []
    
    for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
        is_last_timestep = time == 0

        
        if time > 100:
            continue
        

        for r in reversed(range(0, resample_times)):
            is_last_resample_step = r == 0

            times = torch.full((b,), time, device = device, dtype = torch.long)

            img = p_sample_custom(
                self,
                unet,
                img.detach(),
                times,
                image_embed = image_embed,
                text_encodings = text_encodings,
                cond_scale = cond_scale,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_level = lowres_noise_level,
                predict_x_start = predict_x_start,
                noise_scheduler = noise_scheduler,
                learned_variance = learned_variance,
                clip_denoised = clip_denoised
            )
            loss = loss_fn(img, gt.unsqueeze(0))
            loss.backward()
            optimizer.step()

            psnrs.append(psnr(self.unnormalize_img(img.detach()).squeeze(0), gt.cpu()).item())

    unnormalize_img = self.unnormalize_img(img)
    return unnormalize_img, psnrs


def p_sample_loop_custom(self, *args, noise_scheduler, timesteps = None, **kwargs):
    num_timesteps = noise_scheduler.num_timesteps

    timesteps = default(timesteps, num_timesteps)
    assert timesteps <= num_timesteps
    is_ddim = timesteps < num_timesteps

    return p_sample_loop_ddpm_custom(inference.model_manager.decoder_info.model,*args, noise_scheduler = noise_scheduler, **kwargs)



def sample_custom(
    self,
    lowres_cond_img=None,
    image = None,
    image_embed = None,
    text = None,
    text_encodings = None,
    batch_size = 1,
    cond_scale = 1.,
    start_at_unet_number = 1,
    stop_at_unet_number = None,
    distributed = False,
    inpaint_image = None,
    inpaint_mask = None,
    inpaint_resample_times = 5,
    gt=None
):
    assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'

    if not self.unconditional:
        batch_size = image_embed.shape[0]

    if exists(text) and not exists(text_encodings) and not self.unconditional:
        assert exists(self.clip)
        _, text_encodings = self.clip.embed_text(text)

    assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
    assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

    assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'

    img = None
    if start_at_unet_number > 1:
        # Then we are not generating the first image and one must have been passed in
        assert exists(image), 'image must be passed in if starting at unet number > 1'
        assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
        prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
        img = resize_image_to(image, prev_unet_output_size, nearest = True)
    is_cuda = next(self.parameters()).is_cuda
    num_unets = self.num_unets
    cond_scale = cast_tuple(cond_scale, num_unets)

    for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
        if unet_number < start_at_unet_number:
            continue  # It's the easiest way to do it

        context = one_unet_in_gpu(self, unet = unet, device=device) if is_cuda else null_context()
        
        with context:
            # prepare low resolution conditioning for upsamplers

            #lowres_cond_img = lowres_noise_level = None
            lowres_noise_level = None
            shape = (batch_size, channel, image_size, image_size)

            is_latent_diffusion = isinstance(vae, VQGanVAE)
            image_size = vae.get_encoded_fmap_size(image_size)
            shape = (batch_size, vae.encoded_dim, image_size, image_size)
            
            lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
            
            # denoising loop for image

            img, psnrs = p_sample_loop_custom(
                self=inference.model_manager.decoder_info.model,
                unet=unet,
                shape=shape,
                img_start=image,
                image_embed = image_embed,
                text_encodings = text_encodings,
                cond_scale = unet_cond_scale,
                predict_x_start = predict_x_start,
                learned_variance = learned_variance,
                clip_denoised = not is_latent_diffusion,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_level = lowres_noise_level,
                is_latent_diffusion = is_latent_diffusion,
                noise_scheduler = noise_scheduler,
                timesteps = sample_timesteps,
                inpaint_image = inpaint_image,
                inpaint_mask = inpaint_mask,
                inpaint_resample_times = inpaint_resample_times,
                gt=gt
            )

            img = vae.decode(img)

        if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
            break

    return img, psnrs


In [5]:
def psnr(pred, gt):
    # pred, gt \in [0, 1]
    pred_int = (pred * 255).to(torch.uint8).cpu().numpy()
    gt_int = (gt * 255).to(torch.uint8).cpu().numpy()
    return 20 * np.log10(255) - 10 * np.log10(((pred_int - gt_int) ** 2).mean())

## Optimization

In [6]:
dataset = CocoDataset(
    im_dir='./data/images',
    caption_dir='./data/captions'
)

In [7]:
device = torch.device('cuda:0')

In [8]:
model_config = ModelLoadConfig.from_json_path('./configs/dalle2.json')

In [9]:
model_manager = DalleModelManager(model_config)

FIX: Switch to this version with `pip install DALLE2-pytorch==1.1.0`. If different models suggest different versions, you may just need to choose one.


In [10]:
inference = ExampleInference(model_manager)
inference.model_manager.decoder_info.model.sample_timesteps = (None, 1000)

In [12]:
psnrs = []

interpolation_mode = 'bicubic'

psnrs_from_steps = []
for source_image_i in range(len(dataset)):
    cur_image_psnrs_from_steps = []

    source_image_name, source_image, source_captions = dataset[source_image_i]
    max_psnr = 0
    for text_str_i, text_str in enumerate(source_captions):
        text = [text_str]

        with torch.no_grad():
            image_embedding_map = inference._sample_prior(text)
            image_embedding = image_embedding_map[0][0].unsqueeze(0)
            source_image_lowres_small = F.interpolate(source_image.unsqueeze(0), size=(64, 64), mode=interpolation_mode)
        lowres_cond_img = F.interpolate(source_image_lowres_small, size=(256, 256), mode=interpolation_mode)
        
        inference.model_manager.decoder_info.model.to(device)
        
        lowres_cond_img = lowres_cond_img.to(device)
        with torch.no_grad():
            text_encodings = inference._encode_text(text).to(device)
        image_embed = image_embedding.to(device)
        source_image_lowres_small = source_image_lowres_small.to(device)
        source_image = source_image.to(device)

        source_image_lowres_small = source_image_lowres_small * 2. - 1.

        source_image.requires_grad = False

        res, cur_img_text_psnrs_from_steps = sample_custom(
            self=inference.model_manager.decoder_info.model,
            lowres_cond_img=lowres_cond_img,
            image_embed = image_embed,
            text_encodings = text_encodings,
            image=source_image_lowres_small,
            start_at_unet_number = 2,
            gt=source_image
        )

        res = res.detach().squeeze(0)
        cur_image_psnrs_from_steps.append(cur_img_text_psnrs_from_steps)

        with torch.no_grad():
            max_psnr = max(psnr(res, source_image).item(), max_psnr)
    
    psnrs_from_steps.append(cur_image_psnrs_from_steps)

    print(f"image = {source_image_name}")
    print(max_psnr)
    print("\n---------------------------\n")
    psnrs.append(max_psnr)

sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 47.79it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.23it/s]
2it [00:53, 26.53s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.94it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.14it/s]
2it [00:53, 26.66s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.46it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.05it/s]
2it [00:53, 26.79s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 49.18it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.96it/s]
2it [00:53, 26.90s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 49.50it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.94it/s]
2it [00:53, 26.93s/it]


image = 1.jpg
32.65313127129459

---------------------------



sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 51.01it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.94it/s]
2it [00:53, 26.93s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.23it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.93it/s]
2it [00:53, 26.95s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.66it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.92it/s]
2it [00:53, 26.97s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.97it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.92it/s]
2it [00:53, 26.96s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.58it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.93it/s]
2it [00:53, 26.95s/it]


image = 2.jpg
32.08402842591964

---------------------------



sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 51.12it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.94it/s]
2it [00:53, 26.94s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 48.91it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.95it/s]
2it [00:53, 26.92s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.91it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.96it/s]
2it [00:53, 26.92s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.98it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.97it/s]
2it [00:53, 26.90s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 49.05it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.96it/s]
2it [00:53, 26.91s/it]


image = 3.jpg
31.590432737255895

---------------------------



sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.41it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.94it/s]
2it [00:53, 26.93s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.40it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.93it/s]
2it [00:53, 26.95s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.72it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.91it/s]
2it [00:53, 26.97s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 49.98it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 18.94it/s]
2it [00:53, 26.94s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.71it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.07it/s]
2it [00:53, 26.76s/it]


image = 4.jpg
36.71517332465518

---------------------------



sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.75it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.06it/s]
2it [00:53, 26.76s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.40it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.06it/s]
2it [00:53, 26.77s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.25it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.07it/s]
2it [00:53, 26.76s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.76it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.09it/s]
2it [00:53, 26.73s/it]
sampling loop time step: 100%|██████████| 64/64 [00:01<00:00, 50.85it/s]
sampling loop time step: 100%|██████████| 1000/1000 [00:52<00:00, 19.11it/s]
2it [00:53, 26.70s/it]

image = 5.jpg
30.37966538876232

---------------------------






In [14]:
print(np.mean(psnrs))

32.68448622957752
