In [None]:
!pip install transformers diffusers tensorboardX
!git clone https://github.com/ByeongHyunPak/omni-proj.git

import os
os.chdir('/content/omni-proj/omni_proj')

# MultiDiffusion for 360-degree Panorama Images

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import utils as utils
from utils import gridy2x_erp2pers, gridy2x_pers2erp
from multidiffusion import MultiDiffusion, seed_everything

class ERPMultiDiffusion(MultiDiffusion):
    """MultiDiffusion variant that denoises an equirectangular (ERP) latent
    through a set of perspective views.

    At every denoising step the ERP latent is resampled into one perspective
    latent per view centre, each view is denoised independently by the UNet,
    and the denoised views are projected back onto the ERP grid, where
    overlapping contributions are averaged.
    """

    # Predefined view layouts keyed by the number of rows:
    # (number of views in each row, PHI centre of each row).
    # Angles are presumably in degrees (a row is split into 360 / n_cols
    # wide THETA intervals) -- confirm against the projection helpers in utils.
    _VIEW_LAYOUTS = {
        3: ([3, 5, 3], [-67.5, 0.0, 67.5]),
        # alternative 4-row layout: ([3, 6, 6, 3], [-67.5, -22.5, 22.5, 67.5])
        4: ([1, 6, 6, 1], [-90.0, -22.5, 22.5, 90.0]),
    }

    def __init__(self, num_rows=4, single_pers=False, **kwargs):
        """Build the model and the list of perspective-view centres.

        Args:
            num_rows: number of rows of perspective views; must be a key of
                ``_VIEW_LAYOUTS`` (3 or 4).
            single_pers: if truthy, keep a single randomly chosen view centre
                instead of the full layout.
            **kwargs: forwarded unchanged to ``MultiDiffusion.__init__``.

        Raises:
            ValueError: if ``num_rows`` has no predefined layout.
        """
        # Validate before the base class is constructed so an unsupported
        # value fails immediately with a clear message (previously this
        # surfaced as an unbound-local error after base-class construction).
        if num_rows not in ERPMultiDiffusion._VIEW_LAYOUTS:
            supported = sorted(ERPMultiDiffusion._VIEW_LAYOUTS)
            raise ValueError(
                f"num_rows must be one of {supported}, got {num_rows!r}")

        super(ERPMultiDiffusion, self).__init__(**kwargs)

        num_cols, phi_centers = ERPMultiDiffusion._VIEW_LAYOUTS[num_rows]

        # get_pers_centers reads self.single_pers, so it must be set first.
        self.single_pers = single_pers
        self.pers_centers = self.get_pers_centers(num_cols, phi_centers)

    def get_pers_centers(self, num_cols, phi_centers):
        """Return the list of ``(THETA, PHI)`` view centres.

        Row ``i`` holds ``num_cols[i]`` views at ``phi_centers[i]``; their
        THETA values are the midpoints of ``num_cols[i]`` equal intervals
        covering 360.

        Args:
            num_cols: number of views in each row.
            phi_centers: PHI value of each row (same length as ``num_cols``).

        Returns:
            A list of ``(THETA, PHI)`` tuples. When ``self.single_pers`` is
            truthy the list contains exactly one randomly chosen centre.
        """
        pers_centers = []
        for n_cols, PHI in zip(num_cols, phi_centers):
            theta_interval = 360 / n_cols
            for j in range(n_cols):
                THETA = j * theta_interval + theta_interval / 2
                pers_centers.append((THETA, PHI))

        if self.single_pers:
            # Keep the result a *list* of centres: callers iterate over it and
            # index it per view, which fails on a bare (THETA, PHI) tuple.
            rand_idx = random.randint(0, len(pers_centers) - 1)
            pers_centers = [pers_centers[rand_idx]]
        return pers_centers

    def projection(self, x_inp, projy2x, projx2y, THETA, PHI, FOVy, FOVx, HWy):
        """Resample ``x_inp`` onto a target grid of size ``HWy``.

        Args:
            x_inp: source tensor; its last two dimensions are taken as the
                source size ``HWx``. Must be 4-D for ``F.grid_sample``.
            projy2x: callable mapping target-grid coordinates to source
                coordinates; returns ``(grid, mask)``.
            projx2y: the opposite-direction callable; only its mask is used.
            THETA, PHI: view centre, passed through to both callables.
            FOVy, FOVx: field-of-view values, passed through to both callables.
            HWy: ``(height, width)`` of the target grid.

        Returns:
            ``(y_inp, masky, maskx)``: the resampled tensor multiplied by
            ``masky``, the mask on the target grid viewed as ``(1, *HWy)``,
            and the mask on the source grid viewed as ``(1, *HWx)``.
        """
        device = x_inp.device
        HWx = x_inp.shape[-2:]

        gridy = utils.make_coord(HWy).to(device)
        gridy2x, masky = projy2x(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        gridy2x, masky = gridy2x.view(*HWy, 2), masky.view(1, *HWy)

        # flip(-1) swaps the two coordinate channels. grid_sample expects
        # (x, y) order, so the projectors presumably return (y, x) -- confirm.
        y_inp = F.grid_sample(
            x_inp, gridy2x.unsqueeze(0).flip(-1),
            mode='nearest', padding_mode='reflection',
            align_corners=True).clamp_(x_inp.min(), x_inp.max())
        # Zero out target positions the mask marks as invalid.
        y_inp = y_inp * masky

        gridx = utils.make_coord(HWx, flatten=False).to(device)
        _, maskx = projx2y(gridx, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        maskx = maskx.view(1, *HWx)

        return y_inp, masky, maskx

    def erp2pers(self,
                 erp_inp,
                 pers_size=(512//8, 512//8)):
        """Project an ERP tensor to one perspective tensor per view centre.

        Args:
            erp_inp: ERP tensor to sample from.
            pers_size: ``(height, width)`` of every perspective output.

        Returns:
            A list with one projected tensor per entry of
            ``self.pers_centers``, in the same order.
        """
        pers_outs = []
        for THETA, PHI in self.pers_centers:
            pers_out, _, _ = self.projection(
                erp_inp, gridy2x_erp2pers, gridy2x_pers2erp,
                THETA, PHI, FOVy=90, FOVx=360, HWy=pers_size)
            pers_outs.append(pers_out)

        return pers_outs

    def pers2erp(self,
                 pers_inps,
                 erp_size=(1024//8, 2048//8)):
        """Project perspective tensors back to the ERP grid and average them.

        ``pers_inps[i]`` is assumed to belong to ``self.pers_centers[i]``.

        Args:
            pers_inps: non-empty sequence of perspective tensors.
            erp_size: ``(height, width)`` of the ERP output.

        Returns:
            The sum of all projected views divided by the per-position sum of
            their masks wherever that sum is positive; other positions keep
            the (masked) sum unchanged.

        Raises:
            ValueError: if ``pers_inps`` is empty.
        """
        if len(pers_inps) == 0:
            raise ValueError("pers_inps must contain at least one perspective view")

        erp_outs = None
        count = None
        for i, pers_inp in enumerate(pers_inps):
            THETA, PHI = self.pers_centers[i]
            erp_out, erp_mask, _ = self.projection(
                pers_inp, gridy2x_pers2erp, gridy2x_erp2pers,
                THETA, PHI, FOVy=360, FOVx=90, HWy=erp_size)

            if erp_outs is None:
                # First view: start both accumulators from its tensors.
                erp_outs, count = erp_out, erp_mask
            else:
                erp_outs += erp_out
                count += erp_mask

        # Average only where at least one view contributed.
        erp_outs = torch.where(count > 0, erp_outs / count, erp_outs)

        return erp_outs

    def _decode_views(self, pers_latents, height, width):
        """Decode each perspective latent and stitch the results into one ERP PIL image.

        Returns ``(erp_img, pers_imgs)`` where ``pers_imgs`` are the outputs
        of ``self.decode_latents`` in view order.
        """
        pers_imgs = [self.decode_latents(pers_latent) for pers_latent in pers_latents]
        erp = self.pers2erp(pers_imgs, erp_size=(height, width))
        return T.ToPILImage()(erp[0].cpu()), pers_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=1024, width=2048,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 visualize_intermidiates=False):
        """Generate an ERP panorama from text.

        Args:
            prompts: a prompt string or a list of prompt strings.
            negative_prompts: a negative prompt string or a list of them.
            height, width: ERP image size in pixels; the latent is created at
                ``height // 8`` x ``width // 8``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            visualize_intermidiates: if truthy, also decode and collect the
                ERP image after every denoising step.

        Returns:
            ``(images, pers_imgs)``. Without intermediates ``images`` is
            ``[erp_img]``. With intermediates it is a list of
            ``(index, erp_img)`` tuples, one per step plus a final entry for
            the finished image. ``pers_imgs`` holds the decoded perspective
            views of the final step.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Initialize ERP noise. The channel count is read from the model
        # config; direct `unet.in_channels` attribute access is deprecated.
        latent_size = (height // 8, width // 8)
        erp_latent = torch.randn(
            (1, self.unet.config.in_channels, *latent_size), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        intermidiate_imgs = []

        # Gradients are already disabled by the decorator on this method.
        for i, t in enumerate(tqdm(self.scheduler.timesteps)):

            pers_latents = []

            # get latents on pers. grid and denoise each view independently
            for latent_view in self.erp2pers(erp_latent):

                # duplicate the latent so the unconditional and conditional
                # predictions come from a single forward pass
                latent_model_input = torch.cat([latent_view] * 2)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the denoising step with the reference model
                pers_latents.append(
                    self.scheduler.step(noise_pred, t, latent_view)['prev_sample'])

            # fuse the denoised views back into the shared ERP latent
            erp_latent = self.pers2erp(pers_latents, erp_size=latent_size)

            # visualize intermidiate timesteps
            if visualize_intermidiates:
                step_img, _ = self._decode_views(pers_latents, height, width)
                intermidiate_imgs.append((i, step_img))

        erp_img, pers_img_inps = self._decode_views(pers_latents, height, width)

        if visualize_intermidiates:
            intermidiate_imgs.append((len(intermidiate_imgs), erp_img))
            return intermidiate_imgs, pers_img_inps
        else:
            return [erp_img], pers_img_inps


In [None]:
# --- options shared by the experiment cells below ---
steps = 50            # number of inference steps passed to the samplers
negative = ''         # negative prompt (empty: none)
sd_version = '2.0'    # version string handed to the MultiDiffusion constructors

# All models and latents are placed on the GPU.
device = torch.device('cuda')

# Fixed seed for repeatable runs (seed_everything comes from multidiffusion;
# presumably it seeds the Python/NumPy/torch RNGs -- confirm there).
seed_everything(2024)

In [None]:
""" MultiDiffusion exp.
"""
# NOTE(review): "cityscpae" looks like a typo for "cityscape"; left unchanged
# because correcting it would change the generated image.
prompt  = "firenze cityscpae"
H = 512
W = 512

# Outputs are grouped by the first word of the prompt. The outer f-string uses
# double quotes so the inner ' ' is valid on every Python 3 version (reusing
# the same quote inside an f-string only parses on Python >= 3.12).
out_dir = f"/content/md/{prompt.split(' ')[0]}"
# Creates missing parents too and is a no-op when the directory exists, so the
# cell can be re-run safely.
os.makedirs(out_dir, exist_ok=True)

sd = MultiDiffusion(device, sd_version)

img = sd.text2panorama(prompt, negative, H, W, steps, visualize_intermidiates=True)

# save image
if len(img) == 1:
    # single result: a one-element list holding the image itself
    img[0].save(f'{out_dir}/output.png')
else:
    # multiple results: (index, image) pairs
    for t, im in tqdm(img):
        im.save(f'{out_dir}/output_t={t:02d}.png')

In [None]:
""" ERP MultiDiffusion exp.
"""
# NOTE(review): "cityscpae" looks like a typo for "cityscape"; left unchanged
# because correcting it would change the generated image.
prompt  = "firenze cityscpae"
H = 1024
W = 2048

single_pers = False

# Separate output roots for the single-view and the full multi-view runs.
dir_name = "single_erp_md" if single_pers else "erp_md"

# Outputs are grouped by the first word of the prompt. The outer f-string uses
# double quotes so the inner ' ' is valid on every Python 3 version.
out_dir = f"/content/{dir_name}/{prompt.split(' ')[0]}"
pers_dir = f"{out_dir}/pers"
# Creates out_dir and its parents as well; a no-op when everything exists.
os.makedirs(pers_dir, exist_ok=True)

sd = ERPMultiDiffusion(num_rows=3, device=device, sd_version=sd_version, single_pers=single_pers)

# Pass H/W explicitly so the values configured above are the ones actually
# used (they match the method defaults, so the output size is unchanged).
erp_img, pers_imgs = sd.text2erp(
    prompt, negative, height=H, width=W,
    num_inference_steps=steps, visualize_intermidiates=True)

# save image
if len(erp_img) == 1:
    # single result: a one-element list holding the image itself
    erp_img[0].save(f'{out_dir}/erp_output.png')
else:
    # multiple results: (index, image) pairs
    for t, im in tqdm(erp_img):
        im.save(f'{out_dir}/erp_output_t={t:02d}.png')

# Save each decoded perspective view of the final step.
for i, pers_img in enumerate(pers_imgs):
    pers_img = T.ToPILImage()(pers_img[0].cpu())
    pers_img.save(f'{pers_dir}/pers_output_i={i:02d}.png')