# Multi Diffusion
by Tuval Gelvan & Nitzan Bar




In [1]:
!pip3 install pycocotools



In [2]:
!pip3 install scikit-image



In [3]:
# from pycocotools.coco import COCO
# import matplotlib.pyplot as plt
# import skimage.io as io
 
# # Path to the directory containing annotations
# dataDir = 'data'
 
# # Type of dataset (train, val, test)
# dataType = 'train2017'  # or 'val2017', 'test2017', depending on your dataset
 
# # Path to the annotation file
# annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataType)
 
# # Initialize COCO object
# coco = COCO(annFile)
 
# # Get IDs of all images in the dataset
# imgIds = coco.getImgIds()
 
# # Load the first image
# img = coco.loadImgs(imgIds[0])[0]
 
# # Load the image using its URL
# I = io.imread(img['coco_url'])
 
# # Display the image
# plt.imshow(I)
# plt.axis('off')
# plt.show()

## Setup

### Install Dependencies

In [4]:
!pip3 install numpy
!pip3 install matplotlib
!pip3 install fastai
!pip3 install accelerate
!pip3 install --upgrade transformers diffusers ftfy
!pip3 install cv



### Imports and Setup

In [5]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torchvision.transforms as T
import argparse
import numpy as np
from PIL import Image

# Set device
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# summarize tensor
_s = lambda x: (x.shape,x.max(),x.min())

  from .autonotebook import tqdm as notebook_tqdm


## Authenticate with Hugging Face

To run Stable Diffusion on your computer you have to accept the model license. It's an open CreativeML OpenRail-M license that claims no rights on the outputs you generate and prohibits you from deliberately producing illegal or harmful content. The [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) provides more details. If you do accept the license, you need to be a registered user in 🤗 Hugging Face Hub and use an access token for the code to work. You have two options to provide your access token:

* Use the `huggingface-cli login` command-line tool in your terminal and paste your token when prompted. It will be saved in a file in your computer.
* Or use `notebook_login()` in a notebook, which does the same thing.

In [6]:
# torch.manual_seed(1)
# if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

## Multi Diffusion Class

In [7]:
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True

In [8]:
def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    panorama_height /= 8
    panorama_width /= 8
    num_blocks_height = (panorama_height - window_size) // stride + 1
    num_blocks_width = (panorama_width - window_size) // stride + 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size
        views.append((h_start, h_end, w_start, w_end))
    return views

In [9]:
def preprocess_mask(mask_path, h, w, device):
    mask = np.array(Image.open(mask_path).convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask).to(device)
    mask = torch.nn.functional.interpolate(mask, size=(h, w), mode='nearest')
    return mask

In [20]:
class MultiDiffusion(nn.Module):
    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__()

        self.device = device
        self.sd_version = sd_version
        # for unet text embeddings
        self.tokenizer_2 = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder_2 = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            model_key = self.sd_version #For custom models or fine-tunes, allow people to use arbitrary versions
            #raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        # self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_random_background(self, n_samples):
        # sample random background with a constant rgb value
        backgrounds = torch.rand(n_samples, 3, device=self.device)[:, :, None, None].repeat(1, 1, 512, 512)
        return torch.cat([self.encode_imgs(bg.unsqueeze(0)) for bg in backgrounds])

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      return_tensors='pt')

        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def encode_imgs(self, imgs):
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * 0.18215
        return latents

    @torch.no_grad()
    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    @torch.no_grad()
    def get_embedding_for_prompt(self, prompt):
        print("prompt:", prompt)
        max_length = self.tokenizer_2.model_max_length
        tokens = self.tokenizer_2([prompt],padding="max_length",max_length=max_length,truncation=True,return_tensors="pt")
        embeddings = self.text_encoder_2(tokens.input_ids.to(self.device))[0]
        return embeddings

    @torch.no_grad()
    def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                      guidance_scale=0, bootstrapping=20, bg_img=None, our_bg_img=False):

        # get bootstrapping backgrounds
        # can move this outside of the function to speed up generation. i.e., calculate in init
        if our_bg_img:
            bg_orig = bg_img
            print("bg shape:", bg_orig.shape)
            img = T.ToPILImage()(bg_orig.squeeze().cpu())
            img.save('bg_img.png')
        else:
            bootstrapping_backgrounds = self.get_random_background(bootstrapping)
            print("bootstrapping_backgrounds shape:", bootstrapping_backgrounds.shape)

        # Prompts -> text embeds
        # text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2 * len(prompts), 77, 768]

        uncond_prompts = [""] * len(prompts)
        text_embeddings_list = []
        uncond_list = []
        for prompt, uncond in zip(prompts, uncond_prompts):
            text_embeddings_list.append(self.get_embedding_for_prompt(prompt))
            uncond_list.append(self.get_embedding_for_prompt(uncond))
        text_embeds = torch.cat(uncond_list + text_embeddings_list)
        text_embeds = torch.cat(uncond_list + text_embeddings_list)
        print(" text_embeds shape:", text_embeds.shape)

        # Define panorama grid and get views
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(self.scheduler.timesteps):
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    masks_view = masks[:, :, h_start:h_end, w_start:w_end]
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end].repeat(len(prompts), 1, 1, 1)
                    if not(our_bg_img):
                        if i < bootstrapping:
                            bg = bootstrapping_backgrounds[torch.randint(0, bootstrapping, (len(prompts) - 1,))]
                            bg = self.scheduler.add_noise(bg, noise[:, :, h_start:h_end, w_start:w_end], t)
                            # Check the sizes of the tensors
                            assert latent_view[1:].size(0) == masks_view[1:].size(0), "Size mismatch between latent_view and masks_view"
                            assert bg.size(0) == (1 - masks_view[1:]).size(0), "Size mismatch between bg and masks_view"
                            latent_view[1:] = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])
                    else:
                        if i < bootstrapping:
                            # bg = bg_orig[torch.randint(0, bootstrapping, (len(prompts) - 1,))]
                            bg = bg_orig[:, :, h_start:h_end, w_start:w_end].repeat(len(prompts), 1, 1, 1)
                            print("bg_view shape:", bg.shape)
                            img = T.ToPILImage()(bg[0][0].cpu())
                            img.save(f'bg_img_{i}.png')
                            # Check the sizes of the tensors
                            assert latent_view[1:].size(0) == masks_view[1:].size(0), "Size mismatch between latent_view and masks_view"
                            assert bg[1:].size(0) == (1 - masks_view[1:]).size(0), "Size mismatch between bg and masks_view"
                            latent_view[1:] = latent_view[1:] * masks_view[1:] + bg[1:] * (1 - masks_view[1:])


                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                    value[:, :, h_start:h_end, w_start:w_end] += (latents_view_denoised * masks_view).sum(dim=0,
                                                                                                          keepdims=True)
                    count[:, :, h_start:h_end, w_start:w_end] += masks_view.sum(dim=0, keepdims=True)

                # take the MultiDiffusion step
                latent = torch.where(count > 0, value / count, value)

        # Img latents -> imgs
        imgs = self.decode_latents(latent)  # [1, 3, 512, 512]
        img = T.ToPILImage()(imgs[0].cpu())
        return img

### Run Multi Diffusion

In [22]:
mask_paths = ["masks/mask_0.jpg", "masks/mask_1.jpg"]
# important: it is necessary that SD output high-quality images for the bg/fg prompts.
bg_prompt = ['']
bg_negative = ""  # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
# fg_prompts = ["Sand", "Zebra"]
fg_prompts = ["Horse", "Zebra"]
fg_negative = ""  # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
sd_version='2.0'
H=512
W=512
seed=42
steps=100
# bootstrapping encourages high fidelity to tight masks, the value can be lowered is most cases
# bootstrapping=20
bootstrapping=1

seed_everything(seed)

device = torch.device('cuda')

sd = MultiDiffusion(device, sd_version)

# fg_masks = torch.cat([preprocess_mask(mask_path, H // 8, W // 8, device) for mask_path in mask_paths])
fg_masks = torch.cat([torch.from_numpy(np.array(Image.open(mask_path))).unsqueeze(0).unsqueeze(0).to(device) for mask_path in mask_paths])
print(fg_masks.shape)
bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
bg_mask[bg_mask < 0] = 0
masks = torch.cat([bg_mask, fg_masks])

prompts = bg_prompt + fg_prompts
# neg_prompts = [bg_negative] + [fg_negative]
neg_prompts = ''
bg_img_path = "data/horse_scaled.jpg"
bg_img_array = np.array(Image.open(bg_img_path).convert("L"))
print(bg_img_array.shape)
bg_img = torch.from_numpy(np.array(Image.open(bg_img_path).convert("L").resize((64,64)))).unsqueeze(0).unsqueeze(0).to(device)
img = sd.generate(masks, prompts, neg_prompts, H, W, steps, bootstrapping=bootstrapping, bg_img=bg_img, our_bg_img=True, guidance_scale=15)

# save image
img.save('out.png')

[INFO] loading stable diffusion...
[INFO] loaded stable diffusion!
torch.Size([2, 1, 512, 512])
(512, 512)
bg shape: torch.Size([1, 1, 64, 64])
prompt: 
prompt: 
prompt: Horse
prompt: 
prompt: Zebra
prompt: 
 text_embeds shape: torch.Size([6, 77, 768])
bg_view shape: torch.Size([3, 1, 64, 64])


  latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
