In [1]:
from dalle2_pytorch import Decoder

In [2]:
from functools import partial, wraps
from tqdm.auto import tqdm
from contextlib import contextmanager
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE

def exists(val):
    return val is not None

def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None, validate = True):
    if isinstance(val, list):
        val = tuple(val)

    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length) and validate:
        assert len(out) == length

    return out

@contextmanager
def null_context(*args, **kwargs):
    yield

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

def resize_image_to(
    image,
    target_image_size,
    clamp_range = None,
    nearest = False,
    **kwargs
):
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    if not nearest:
        scale_factors = target_image_size / orig_image_size
        out = resize(image, scale_factors = scale_factors, **kwargs)
    else:
        out = F.interpolate(image, target_image_size, mode = 'nearest')

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

In [3]:
def my_sample(
    self,
    source_image_small_64,
    lowres_cond_img=None,
    image = None,
    image_embed = None,
    text = None,
    text_encodings = None,
    batch_size = 1,
    cond_scale = 1.,
    start_at_unet_number = 1,
    stop_at_unet_number = None,
    distributed = False,
    inpaint_image = None,
    inpaint_mask = None,
    inpaint_resample_times = 5,
):
    print("@@@ inside my_sample")
    assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'

    if not self.unconditional:
        batch_size = image_embed.shape[0]

    if exists(text) and not exists(text_encodings) and not self.unconditional:
        assert exists(self.clip)
        _, text_encodings = self.clip.embed_text(text)

    assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
    assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'

    assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'

    img = None
    if start_at_unet_number > 1:
        # Then we are not generating the first image and one must have been passed in
        assert exists(image), 'image must be passed in if starting at unet number > 1'
        assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
        prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
        img = resize_image_to(image, prev_unet_output_size, nearest = True)
    is_cuda = next(self.parameters()).is_cuda
    num_unets = self.num_unets
    cond_scale = cast_tuple(cond_scale, num_unets)

    for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
        if unet_number < start_at_unet_number:
            continue  # It's the easiest way to do it

        context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()

        with context:
            # prepare low resolution conditioning for upsamplers

            #lowres_cond_img = lowres_noise_level = None
            lowres_noise_level = None
            shape = (batch_size, channel, image_size, image_size)

#             if unet.lowres_cond:
#                 lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)

#                 if lowres_cond.use_noise:
#                     lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
#                     lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)

            # latent diffusion

            is_latent_diffusion = isinstance(vae, VQGanVAE)
            image_size = vae.get_encoded_fmap_size(image_size)
            shape = (batch_size, vae.encoded_dim, image_size, image_size)

            lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

            # denoising loop for image
            
            def prepare_args(unet_number, lowres_cond_img):
                if unet_number == 1:
                    return lowres_cond_img
                elif unet_number == 2:
                    lowres_cond_img_res = resize_image_to(
                        source_image_small_64, 
                        target_image_size = image_size, 
                        clamp_range = self.input_image_range, 
                        nearest = True
                    )
                    return lowres_cond_img_res
                else:
                    raise ValueError()
            
            img = self.p_sample_loop(
                unet,
                shape,
                image_embed = image_embed,
                text_encodings = text_encodings,
                cond_scale = unet_cond_scale,
                predict_x_start = predict_x_start,
                learned_variance = learned_variance,
                clip_denoised = not is_latent_diffusion,
                lowres_cond_img = prepare_args(unet_number, lowres_cond_img),
                lowres_noise_level = lowres_noise_level,
                is_latent_diffusion = is_latent_diffusion,
                noise_scheduler = noise_scheduler,
                # timesteps = sample_timesteps,
                inpaint_image = inpaint_image,
                inpaint_mask = inpaint_mask,
                inpaint_resample_times = inpaint_resample_times
            )

            img = vae.decode(img)

        if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
            break

    return img


Decoder.sample = my_sample

In [4]:
from dalle2_laion import ModelLoadConfig, DalleModelManager
from dalle2_laion.scripts import InferenceScript
from typing import List

import torch

class ExampleInference(InferenceScript):
    def run(self, text: str, source_image_small_64):
        """
        Takes a string and returns a single image.
        """
        text = [text]
        image_embedding_map = self._sample_prior(text)
        image_embedding = image_embedding_map[0][0].unsqueeze(0)
        image_map = self._sample_decoder(
            source_image_small_64,
            text=text, 
            image_embed=image_embedding, 
            
        )
        return image_map[0][0]

    def _sample_decoder(
        self,
        source_image_small_64,*,
        images = None, image_embed: List[torch.Tensor] = None,
        text: List[str] = None, text_encoding: List[torch.Tensor] = None,
        inpaint_images = None, inpaint_image_masks: List[torch.Tensor] = None,
        cond_scale: float = None, sample_count: int = 1, batch_size: int = 10,
    ):
        """
        Samples images from the decoder
        Capable of doing basic generation with a list of image embeddings (possibly also conditioned with a list of strings or text embeddings)
        Also capable of two more advanced generation techniques:
        1. Variation generation: If images are passed in the image embeddings will be generated based on those.
        2. In-painting generation: If images and masks are passed in, the images will be in-painted using the masks and the image embeddings.
        """
        if cond_scale is None:
            # Then we use the default scale
            load_config = self.model_manager.model_config.decoder
            unet_configs = load_config.unet_sources
            cond_scale = [1.0] * load_config.final_unet_number
            for unet_config in unet_configs:
                if unet_config.default_cond_scale is not None:
                    for unet_number, new_cond_scale in zip(unet_config.unet_numbers, unet_config.default_cond_scale):
                        cond_scale[unet_number - 1] = new_cond_scale
        self.print(f"Sampling decoder with cond_scale: {cond_scale}")
            
        decoder_info = self.model_manager.decoder_info
        assert decoder_info is not None, "No decoder loaded."
        data_requirements = decoder_info.data_requirements
        min_image_size = min(min(image.size) for image in images) if images is not None else None
        is_valid, errors = data_requirements.is_valid(
            has_image_emb=image_embed is not None, has_image=images is not None,
            has_text_encoding=text_encoding is not None, has_text=text is not None,
            image_size=min_image_size
        )
        assert is_valid, f"The data requirements for the decoder are not satisfied: {errors}"

        # Prepare the data
        image_embeddings = []  # The null case where nothing is done. This should never be used in actuality, but for stylistic consistency I'm keeping it.
        if data_requirements.image_embedding:
            if image_embed is None:
                # Then we need to use clip to generate the image embedding
                image_embed = self._embed_images(images)
            # Then we need to group these tensors into batches of size batch_size such that the total number of samples is sample_count
            image_embeddings, image_embeddings_map = self._repeat_tensor_and_batch(image_embed, repeat_num=sample_count, batch_size=batch_size)
            self.print(f"Decoder batched inputs into {len(image_embeddings)} batches. Total number of samples: {sum(len(t) for t in image_embeddings)}.")
        
        if data_requirements.text_encoding:
            if text_encoding is None:
                text_encoding = self._encode_text(text)
            text_encodings, text_encodings_map = self._repeat_tensor_and_batch(text_encoding, repeat_num=sample_count, batch_size=batch_size)

        assert len(image_embeddings) > 0, "No data provided for decoder inference."
        output_image_map: Dict[int, List[PILImage.Image]] = {}
        with self._decoder_in_gpu() as decoder:
            for i in range(len(image_embeddings)):
                args = {}
                ### HW CODE START ###
                
                ### HW CODE END ###
                embeddings_map = []
                if data_requirements.image_embedding:
                    args["image_embed"] = image_embeddings[i].to(self.device)
                    embeddings_map = image_embeddings_map[i]
                if data_requirements.text_encoding:
                    args["text_encodings"] = text_encodings[i].to(self.device)
                    embeddings_map = text_encodings_map[i]
                if inpaint_images is not None:
                    assert len(inpaint_images) == len(inpaint_image_masks), "Number of inpaint images and masks must match."
                    inpaint_image_tensors = self._pil_to_torch(inpaint_images, resize_for_clip=False)
                    args["inpaint_image"] = inpaint_image_tensors.to(self.device)
                    args["inpaint_mask"] = torch.stack(inpaint_image_masks).to(self.device)
                    self.print(f"image tensor shape: {args['inpaint_image'].shape}. mask shape: {args['inpaint_mask'].shape}")
                args.update({
                    "source_image_small_64": source_image_small_64,
                })
                output_images = decoder.sample(**args, cond_scale=cond_scale)
                for output_image, input_embedding_number in zip(output_images, embeddings_map):
                    if input_embedding_number not in output_image_map:
                        output_image_map[input_embedding_number] = []
                    output_image_map[input_embedding_number].append(self._torch_to_pil(output_image))
            return output_image_map

# model_config = ModelLoadConfig.from_json_path("dalle2_laion.json")
# model_manager = DalleModelManager(model_config)
# inference = ExampleInference(model_manager)
# image = inference.run("Hello World")

In [5]:
model_config = ModelLoadConfig.from_json_path("dalle2_laion.json")

In [6]:
model_manager = DalleModelManager(model_config)

FIX: Switch to this version with `pip install DALLE2-pytorch==1.1.0`. If different models suggest different versions, you may just need to choose one.


In [7]:
inference = ExampleInference(model_manager)

In [8]:
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms import ToTensor, ToPILImage
import numpy as np

import os
from torchvision.transforms import Compose, ToTensor, Pad, Resize, ToPILImage, InterpolationMode

class CocoDataset(torch.utils.data.Dataset):
    def __init__(self, im_dir, caption_dir):
        self.im_dir = im_dir
        self.caption_dir = caption_dir
        self.im_fnames = sorted(os.listdir(im_dir))
        self.caption_fnames = sorted(os.listdir(caption_dir))
        self.to_tensor_transform = ToTensor()
        self.pad_transform = lambda im, pad_right, pad_bottom: Pad(padding=(0, 0, pad_right, pad_bottom))(im)
        self.resize_transform = Resize((256, 256), interpolation=InterpolationMode.BILINEAR)

    def __getitem__(self, idx):
        im = Image.open(os.path.join(self.im_dir, self.im_fnames[idx]))
        w, h = im.size
        im = self.to_tensor_transform(im)
        new_size = max(w, h)
        im = self.pad_transform(im, new_size - w, new_size - h)
        im = self.resize_transform(im)

        with open(os.path.join(self.caption_dir, self.caption_fnames[idx]), 'r') as captions_f:
            captions = captions_f.readlines()

        return self.im_fnames[idx], im, captions

    def __len__(self):
        return len(self.im_fnames)

In [9]:
def psnr(pred, gt):
    # pred \in [0, 1]
    pred_int = pred * 255
    gt = gt * 255
    pred_int = torch.round(pred_int)
    gt = torch.round(gt)
    return 20 * torch.log10(255 / torch.sqrt(F.mse_loss(pred_int, gt)))

In [16]:
dataset = CocoDataset(
    im_dir="images/",
    caption_dir="captions/"
)
psnrs = []
with torch.no_grad():
    for source_image_i in range(len(dataset)):
        (source_image_name, source_image, source_captions) = dataset[source_image_i]
        max_psnr = 0
        for text_str_i, text_str in enumerate(source_captions):
            source_image_small_64 = F.interpolate(source_image.unsqueeze(0), size=(64, 64))
            # normalize

            source_image_small_64 = source_image_small_64.cuda()

            res = inference.run(
                text=text_str,
                source_image_small_64 = source_image_small_64,
            )
            max_psnr = max(max_psnr, (psnr(ToTensor()(res), source_image).item()))
        print(f"image = {source_image_name}")
        print(max_psnr)
        print("\n---------------------------\n")
        psnrs.append(max_psnr)

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

image = banana.jpg
22.468290328979492

---------------------------



sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

image = cat.jpg
21.827993392944336

---------------------------



sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

image = elephant.jpg
23.965728759765625

---------------------------



sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

image = pizza.jpg
22.672077178955078

---------------------------



sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]

@@@ inside my_sample


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

image = tie.jpg
19.42002296447754

---------------------------



In [17]:
import numpy as np

print(np.mean(psnrs))

22.070822525024415
