In [None]:
# Install the third-party packages this notebook relies on (versions are not pinned).
!pip install transformers diffusers tensorboardX
# Fetch the project sources; presumably the `utils` module imported in a later cell lives in this checkout.
!git clone https://github.com/ByeongHyunPak/omni-proj.git

import os
# Switch the working directory into the cloned project so later relative imports/paths resolve from there.
os.chdir('/content/omni-proj/omni-proj')

<a href="https://colab.research.google.com/github/ByeongHyunPak/omni-proj/blob/main/scratchpad.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MultiDiffusion

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torchvision.transforms as T
import argparse
from tqdm import tqdm

def seed_everything(seed):
    """Seed PyTorch's random number generators so sampling is repeatable.

    Args:
        seed: Integer seed passed to both `torch.manual_seed` and
            `torch.cuda.manual_seed`.
    """
    for apply_seed in (torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)
    # Optional cuDNN settings, currently not applied:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True

def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    panorama_height /= 8
    panorama_width /= 8
    num_blocks_height = (panorama_height - window_size) // stride + 1
    num_blocks_width = (panorama_width - window_size) // stride + 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size
        views.append((h_start, h_end, w_start, w_end))
    return views

class MultiDiffusion(nn.Module):
    """Stable Diffusion wrapper that denoises a wide latent through overlapping windows.

    At every scheduler step each window returned by `get_views` is denoised
    independently, and the per-window results are averaged where they overlap
    to form the next panorama latent.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        """Load the VAE, tokenizer, text encoder, UNet and DDIM scheduler.

        Args:
            device: Device the VAE, text encoder and UNet are moved to.
            sd_version: One of '2.1', '2.0' or '1.5'; ignored when `hf_key` is given.
            hf_key: Optional Hugging Face model id that overrides `sd_version`.

        Raises:
            ValueError: If `hf_key` is None and `sd_version` is not supported.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts and negative prompts for classifier-free guidance.

        Args:
            prompt: List of prompt strings.
            negative_prompt: List of negative prompt strings.

        Returns:
            Tensor with the negative-prompt embeddings stacked before the
            prompt embeddings along the batch dimension.
        """
        max_length = self.tokenizer.model_max_length

        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings. Truncation must match the
        # positive prompt: without it an over-long negative prompt produces
        # more tokens than the text encoder accepts.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=max_length,
                                      truncation=True, return_tensors='pt')
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings: unconditional first, conditional second,
        # which is the order `noise_pred.chunk(2)` relies on in text2panorama.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode latents with the VAE into images clamped to [0, 1].

        Args:
            latents: Latent tensor to decode.

        Returns:
            Image tensor produced by the VAE, rescaled from [-1, 1] to [0, 1] and clamped.
        """
        # 0.18215 is the latent scaling factor of the Stable Diffusion VAE
        # (the documented default `scaling_factor` of diffusers' AutoencoderKL).
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    @torch.no_grad()
    def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                      guidance_scale=7.5, visualize_intermidiates=False):
        """Generate a panorama by averaging denoised overlapping latent windows.

        Args:
            prompts: Prompt string or list of prompt strings.
            negative_prompts: Negative prompt string or list of strings.
            height: Panorama height in pixels.
            width: Panorama width in pixels.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.
            visualize_intermidiates: When True, also decode the latent after every step.

        Returns:
            `[image]` with the final PIL image when `visualize_intermidiates` is
            False; otherwise a list of `(index, image)` tuples, one per step,
            followed by one extra entry for the final decode.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (unconditional rows first, then conditional rows)
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define panorama grid and get views.
        # `config.in_channels` replaces the deprecated `unet.in_channels` attribute.
        latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            if visualize_intermidiates is True:
                intermidiate_imgs = []
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                # Accumulators are reused across steps, so reset them in place.
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    # TODO we can support batches, and pass multiple views at once to the unet
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end]

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                # take the MultiDiffusion step: average where at least one view
                # contributed; uncovered cells keep `value` (zero) instead of 0/0.
                latent = torch.where(count > 0, value / count, value)

                # visualize intermidiate timesteps
                if visualize_intermidiates is True:
                    imgs = self.decode_latents(latent)
                    img = T.ToPILImage()(imgs[0].cpu())
                    intermidiate_imgs.append((i, img))

        # Img latents -> imgs
        imgs = self.decode_latents(latent)
        img = T.ToPILImage()(imgs[0].cpu())

        if visualize_intermidiates is True:
            intermidiate_imgs.append((len(intermidiate_imgs), img))
            return intermidiate_imgs
        else:
            return [img]

In [None]:
import os

# Output directory for the generated frames. Created with exist_ok=True so the
# cell can be re-run without a "File exists" error.
out_dir = '/content/md'
os.makedirs(out_dir, exist_ok=True)

seed_everything(2024)

device = torch.device('cuda')

# opt variables
sd_version = '2.0'
prompt = 'frog PePe riding a horse'
negative = ''
H = 512
W = 512
steps = 50
outfile = f'{out_dir}/{prompt}.png'

sd = MultiDiffusion(device, sd_version)

# With visualize_intermidiates=True the result is a list of (index, image) tuples,
# one per denoising step plus the final decode.
img = sd.text2panorama(prompt, negative, H, W, steps, visualize_intermidiates=True)

# save image
if len(img) == 1:
    img[0].save(outfile)
else:
    for t, im in tqdm(img):
        im.save(f'{out_dir}/{t:03d}_{prompt}.png')

# MultiDiffusion for 360-degree panorama images

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import utils
from utils import gridy2x_erp2pers, gridy2x_pers2erp

def seed_everything(seed):
    """Make PyTorch sampling repeatable by seeding its CPU and CUDA generators.

    Args:
        seed: Integer seed handed to `torch.manual_seed` and `torch.cuda.manual_seed`.
    """
    seeders = [torch.manual_seed, torch.cuda.manual_seed]
    for seeder in seeders:
        seeder(seed)
    # cuDNN flags below are not set by this function:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True

class ERPMultiDiffusion(MultiDiffusion):
    """MultiDiffusion variant that fuses perspective views on an equirectangular (ERP) grid.

    A fixed set of perspective view centers is denoised independently at every
    step; the views are projected onto a shared ERP tensor, averaged where they
    overlap, and re-projected back into perspective views for the next step.
    """

    def __init__(self, num_rows=4, **kwargs):
        """Build the view layout and load the Stable Diffusion components.

        Args:
            num_rows: Number of latitude rows of perspective views. Only 4 is implemented.
            **kwargs: Forwarded to `MultiDiffusion.__init__` (device, sd_version, hf_key).

        Raises:
            ValueError: If `num_rows` is not a supported layout.
        """
        # Validate before the (slow) model loading in the parent constructor.
        if num_rows == 4:
            num_cols = [3, 6, 6, 3]
            phi_centers = [-67.5, -22.5, 22.5, 67.5]
        else:
            raise ValueError(f'num_rows={num_rows} is not supported; only num_rows=4 is implemented.')

        super(ERPMultiDiffusion, self).__init__(**kwargs)

        self.pers_centers = self.get_pers_centers(num_cols, phi_centers)

    def get_pers_centers(self, num_cols, phi_centers):
        """Return the (THETA, PHI) center of every perspective view.

        Args:
            num_cols: Number of views in each row.
            phi_centers: PHI value of each row, same length as `num_cols`.

        Returns:
            List of `(THETA, PHI)` tuples, row by row. Within a row the THETA
            values are evenly spaced over 360 and offset by half an interval.
        """
        pers_centers = []
        for i, n_cols in enumerate(num_cols):
            PHI = phi_centers[i]
            theta_interval = 360 / n_cols
            for j in np.arange(n_cols):
                THETA = j * theta_interval + theta_interval / 2
                pers_centers.append((THETA, PHI))
        return pers_centers

    def projection(self, x_inp, proj_y2x, proj_x2y, THETA, PHI, FOVy, FOVx, HWy):
        """Resample `x_inp` onto a target grid of size `HWy` using a projection function.

        `utils.make_coord` and the `proj_*` callables are project helpers that
        are not visible here; the shapes below are the ones implied by the
        `.view(...)` calls in this method.

        Args:
            x_inp: Source tensor; its last two dimensions are taken as the source size.
            proj_y2x: Callable mapping target-grid coordinates to source coordinates;
                returns `(grid, mask)`.
            proj_x2y: Callable for the opposite direction; only its mask is used.
            THETA, PHI: View center forwarded to the projection callables.
            FOVy, FOVx: Field-of-view values forwarded to the projection callables.
            HWy: `(height, width)` of the target grid.

        Returns:
            `(y_inp, masky, maskx)`: the resampled tensor multiplied by `masky`,
            the target-side mask viewed as `(1, *HWy)`, and the source-side mask
            viewed as `(1, *HWx)`.
        """
        device = x_inp.device
        HWx = x_inp.shape[-2:]

        gridy = utils.make_coord(HWy).to(device)
        gridy2x, masky = proj_y2x(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        gridy2x, masky = gridy2x.view(*HWy, 2), masky.view(1, *HWy)

        # flip(-1) reverses the coordinate pair; presumably the helper emits (y, x)
        # while grid_sample expects (x, y) -- TODO confirm against utils.
        y_inp = F.grid_sample(
                    x_inp, gridy2x.unsqueeze(0).flip(-1),
                    mode='nearest', padding_mode='reflection',
                    align_corners=False)
        y_inp = y_inp * masky

        gridx = utils.make_coord(HWx, flatten=False).to(device)
        _, maskx = proj_x2y(gridx, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        maskx = maskx.view(1, *HWx)

        return y_inp, masky, maskx

    def erp2pers(self,
                 erp_inp,
                 pers_size=(512//8, 512//8)):
        """Project an ERP tensor into one perspective view per configured center.

        Args:
            erp_inp: ERP tensor to sample from.
            pers_size: `(height, width)` of each perspective output.

        Returns:
            List of perspective tensors, ordered like `self.pers_centers`.
        """
        # # Upscale the ERP input before projection
        # erp_inp = F.interpolate(erp_inp,
        #     size=(erp_inp.shape[-2]*4, erp_inp.shape[-1]*4),
        #     mode='bicubic', align_corners=True)

        # ERP2Pers Projection
        pers_outs = []
        for THETA, PHI in self.pers_centers:
            pers_out, _, _ = self.projection(
                erp_inp, gridy2x_erp2pers, gridy2x_pers2erp,
                THETA, PHI, FOVy=90, FOVx=360, HWy=pers_size)
            pers_outs.append(pers_out)

        return pers_outs

    def pers2erp(self,
                 pers_inps,
                 erp_size=(1024//8, 2048//8)):
        """Fuse perspective views into a single ERP tensor by masked averaging.

        Args:
            pers_inps: Non-empty sequence of perspective tensors; the i-th entry
                is projected with the i-th center in `self.pers_centers`.
            erp_size: `(height, width)` of the ERP output.

        Returns:
            ERP tensor holding, at every location, the sum of the projected views
            divided by the accumulated mask. Locations whose accumulated mask is
            zero keep the (masked) sum instead of producing NaN from 0/0.

        Raises:
            ValueError: If `pers_inps` is empty.
        """
        if len(pers_inps) == 0:
            raise ValueError('pers_inps must contain at least one perspective view.')

        erp_sum = None
        count = None
        for i, pers_inp in enumerate(pers_inps):
            # # Upscale the Pers input before projection
            # pers_inp = F.interpolate(pers_inp,
            #     size=(pers_inp.shape[-2]*4, pers_inp.shape[-1]*4),
            #     mode='bicubic', align_corners=True)

            # Pers2ERP Projection
            THETA, PHI = self.pers_centers[i]
            erp_out, erp_mask, _ = self.projection(
                pers_inp, gridy2x_pers2erp, gridy2x_erp2pers,
                THETA, PHI, FOVy=360, FOVx=90, HWy=erp_size)

            # Accumulate the mask in the output dtype so overlaps are counted
            # numerically even if the helper returns a boolean mask.
            erp_mask = erp_mask.to(erp_out.dtype)

            erp_sum = erp_out if erp_sum is None else erp_sum + erp_out
            count = erp_mask if count is None else count + erp_mask

        # Same guard as MultiDiffusion.text2panorama: only divide where a view contributed.
        return torch.where(count > 0, erp_sum / count, erp_sum)

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=2048, width=4096,
                 num_inference_steps=50,
                 guidance_scale=7.5):
        """Generate an ERP panorama from text by jointly denoising perspective views.

        Args:
            prompts: Prompt string or list of prompt strings.
            negative_prompts: Negative prompt string or list of strings.
            height: ERP height passed to `pers2erp` as the fusion grid height.
            width: ERP width passed to `pers2erp` as the fusion grid width.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.

        Returns:
            `(erp_img, pers_img_inps)`: the fused ERP as a PIL image, and the list
            of decoded perspective image tensors (values clamped to [0, 1] by
            `decode_latents`), ordered like `self.pers_centers`.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (unconditional rows first, then conditional rows)
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # One independent noise latent per perspective view. The count follows the
        # configured view centers instead of a hard-coded 18, and `config.in_channels`
        # replaces the deprecated `unet.in_channels` attribute.
        # erp_latent = torch.randn((1, self.unet.in_channels, height//8, width//8), device=self.device)
        num_views = len(self.pers_centers)
        latent_channels = self.unet.config.in_channels
        pers_latents = [torch.randn((1, latent_channels, 512//8, 512//8), device=self.device)
                        for _ in range(num_views)]

        # pers_latents = self.erp2pers(erp_latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):

                # pers_latents = self.erp2pers(erp_latent)
                denoised_pers_latents = []
                for latent_view in pers_latents:

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                    denoised_pers_latents.append(latents_view_denoised)

                # Fuse the denoised views on the ERP grid, then re-sample the views
                # so overlapping regions agree before the next step.
                # NOTE(review): latents are fused on a (height, width) grid rather than
                # (height//8, width//8) -- confirm this full-resolution grid is intended.
                erp_latent = self.pers2erp(denoised_pers_latents, erp_size=(height, width))
                pers_latents = self.erp2pers(erp_latent)

        # Decode every perspective latent, then fuse the decoded images into the ERP.
        pers_img_inps = []
        for pers_latent in pers_latents:
            pers_img = self.decode_latents(pers_latent)
            pers_img_inps.append(pers_img)

        erp_img = T.ToPILImage()(self.pers2erp(pers_img_inps, erp_size=(height, width))[0].cpu())
        return erp_img, pers_img_inps


In [None]:
import os

seed_everything(2024)

device = torch.device('cuda')

# opt variables
sd_version = '2.0'
prompt = 'realistic firenze cityscape'
negative = ''

steps = 50

# The output directory is not created anywhere else in the notebook; create it up
# front so the save calls below cannot fail after the generation has finished.
out_dir = '/content/imgs'
os.makedirs(out_dir, exist_ok=True)
outfile = f'{out_dir}/{prompt}.png'

sd = ERPMultiDiffusion(num_rows=4, device=device, sd_version=sd_version)

erp_img, pers_imgs = sd.text2erp(prompt, negative, num_inference_steps=steps)

# save image
erp_img.save(outfile)

# Each entry of pers_imgs is a decoded image batch; save its first (only) image.
for i, pers_img in enumerate(pers_imgs):
    pers_img = T.ToPILImage()(pers_img[0].cpu())
    pers_img.save(f'{out_dir}/{prompt}_pers_{i}.png')