In [None]:
# Colab setup: install the third-party packages and fetch the project sources.
!pip install transformers diffusers tensorboardX
!git clone https://github.com/ByeongHyunPak/omni-proj.git

import os
# Work from inside the cloned checkout. Later cells do `import utils` and
# `from multidiffusion import ...` — presumably those modules live in this
# directory; verify against the repository layout.
os.chdir('/content/omni-proj/omni-proj')

<a href="https://colab.research.google.com/github/ByeongHyunPak/omni-proj/blob/main/scratchpad.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MultiDiffusion

In [None]:
# Imports are kept in this cell so that it also runs on a fresh kernel
# (Restart & Run All): nothing above this cell brings these names into scope.
import os

import torch
from tqdm import tqdm

from multidiffusion import MultiDiffusion, seed_everything

seed_everything(2024)

device = torch.device('cuda')

# opt variables
sd_version = '2.0'
prompt = "Under a breathtaking night sky filled with stars and the Milky Way, a road marked with a 'Safety' sign leads to a quiet building, creating a serene contrast between the universe's vastness and human presence."
negative = ''
H = 512
W = 512
steps = 50

# Results go to /content/md/<first five characters of the prompt>/.
# makedirs(exist_ok=True) also creates the missing parent '/content/md' and
# makes the cell safe to re-run; plain os.mkdir fails when the parent is absent.
out_dir = f'/content/md/{prompt[0:5]}/'
os.makedirs(out_dir, exist_ok=True)

outfile = f'{out_dir}output.png'

sd = MultiDiffusion(device, sd_version)

img = sd.text2panorama(prompt, negative, H, W, steps, visualize_intermidiates=True)

# save image
if len(img) == 1:
    # A single entry is saved directly as the final output.
    img[0].save(outfile)
else:
    # Otherwise every entry is unpacked as a (step index, image) pair.
    for t, im in tqdm(img):
        im.save(f'{out_dir}{t:03d}_output.png')

# MultiDiffusion for 360-degree panorama images

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import utils
from utils import gridy2x_erp2pers, gridy2x_pers2erp
from multidiffusion import MultiDiffusion, seed_everything

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed. It is passed unchanged to ``random`` and PyTorch;
            NumPy only accepts seeds in ``[0, 2**32 - 1]``, so the value is
            reduced modulo ``2**32`` for ``np.random.seed``.
    """
    import random  # local import: the notebook's import cell does not provide `random`

    random.seed(seed)
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    # Covers every visible CUDA device; ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True

class ERPMultiDiffusion(MultiDiffusion):
    """MultiDiffusion sampler operating on an equirectangular (ERP) latent.

    One ERP latent is projected into a fixed set of perspective views, each
    view is denoised on its own with the inherited UNet and scheduler, and the
    views are merged back into ERP space by mask-weighted averaging.
    """

    def __init__(self, num_rows=4, **kwargs):
        """Select the perspective-view layout and initialise the base class.

        Args:
            num_rows: Number of rows of perspective views; 3 and 4 are supported.
            **kwargs: Forwarded unchanged to ``MultiDiffusion.__init__``.

        Raises:
            ValueError: If ``num_rows`` is neither 3 nor 4.
        """
        # Validate first so an unsupported layout fails with a clear message
        # before the base class is initialised.
        if num_rows == 3:
            num_cols = [3, 5, 3]
            phi_centers = [-67.5, 0.0, 67.5]
        elif num_rows == 4:
            num_cols = [3, 6, 6, 3]
            phi_centers = [-67.5, -22.5, 22.5, 67.5]
        else:
            raise ValueError(f"num_rows must be 3 or 4, got {num_rows!r}")

        super(ERPMultiDiffusion, self).__init__(**kwargs)

        self.pers_centers = self.get_pers_centers(num_cols, phi_centers)

    def get_pers_centers(self, num_cols, phi_centers):
        """Build the list of ``(THETA, PHI)`` view centers.

        Args:
            num_cols: Number of views in each row.
            phi_centers: PHI value of each row; indexed in step with ``num_cols``.

        Returns:
            A list of ``(THETA, PHI)`` tuples, row by row. Within a row the
            THETA values are spaced ``360 / n_cols`` apart, starting at half
            that interval.
        """
        pers_centers = []
        for i, n_cols in enumerate(num_cols):
            PHI = phi_centers[i]
            for j in np.arange(n_cols):
                theta_interval = 360 / n_cols
                THETA = j * theta_interval + theta_interval / 2
                pers_centers.append((THETA, PHI))
        return pers_centers

    def projection(self, x_inp, proj_y2x, proj_x2y, THETA, PHI, FOVy, FOVx, HWy):
        """Resample ``x_inp`` onto a target grid of size ``HWy``.

        Args:
            x_inp: Source tensor; its last two dimensions are taken as the
                source size ``HWx``. It is sampled with ``F.grid_sample``
                using a grid of batch size 1.
            proj_y2x: Callable mapping target-grid coordinates to source
                coordinates; returns ``(grid, mask)``.
            proj_x2y: Callable for the opposite direction; only its mask is used.
            THETA, PHI: View center passed through to both callables.
            FOVy, FOVx: Field-of-view values passed through to both callables.
            HWy: ``(height, width)`` of the target grid.

        Returns:
            ``(y_inp, masky, maskx)``: the resampled tensor multiplied by the
            target mask, the target mask viewed as ``(1, *HWy)`` and the
            source mask viewed as ``(1, *HWx)``.
        """
        device = x_inp.device
        HWx = x_inp.shape[-2:]

        gridy = utils.make_coord(HWy).to(device)
        gridy2x, masky = proj_y2x(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        gridy2x, masky = gridy2x.view(*HWy, 2), masky.view(1, *HWy)

        # flip(-1) reverses the coordinate pair — presumably the helpers emit
        # (y, x) while grid_sample expects (x, y); verify against utils.
        y_inp = F.grid_sample(
                    x_inp, gridy2x.unsqueeze(0).flip(-1),
                    mode='nearest', padding_mode='reflection',
                    align_corners=True).clamp_(x_inp.min(), x_inp.max())
        y_inp = y_inp * masky

        gridx = utils.make_coord(HWx, flatten=False).to(device)
        _, maskx = proj_x2y(gridx, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        maskx = maskx.view(1, *HWx)

        return y_inp, masky, maskx

    def erp2pers(self,
                 erp_inp,
                 pers_size=(512//8, 512//8)):
        """Project an ERP tensor into one perspective view per view center.

        Args:
            erp_inp: ERP tensor to sample from.
            pers_size: ``(height, width)`` of every perspective view.

        Returns:
            A list of perspective tensors ordered like ``self.pers_centers``.
        """
        # NOTE: an optional bicubic 4x upscaling of `erp_inp` before the
        # projection was tried here and is currently disabled.

        # ERP2Pers Projection
        pers_outs = []
        for THETA, PHI in self.pers_centers:
            pers_out, _, _ = self.projection(
                erp_inp, gridy2x_erp2pers, gridy2x_pers2erp,
                THETA, PHI, FOVy=90, FOVx=360, HWy=pers_size)
            pers_outs.append(pers_out)

        return pers_outs

    def pers2erp(self,
                 pers_inps,
                 erp_size=(1024//8, 2048//8)):
        """Merge perspective views back into a single ERP tensor.

        Each view is projected into ERP space and the projections are summed.
        Where the accumulated mask is positive the sum is divided by it, which
        averages overlapping views when the masks are 0/1 valued (assumed —
        confirm against the projection helpers).

        Args:
            pers_inps: Sequence of perspective tensors; element ``i`` is
                projected with ``self.pers_centers[i]``.
            erp_size: ``(height, width)`` of the ERP output.

        Returns:
            The merged ERP tensor.

        Raises:
            ValueError: If ``pers_inps`` is empty.
        """
        if len(pers_inps) == 0:
            raise ValueError("pers_inps must contain at least one perspective view")

        erp_outs = None
        count = None
        for i, pers_inp in enumerate(pers_inps):
            # NOTE: an optional bicubic 4x upscaling of `pers_inp` before the
            # projection was tried here and is currently disabled.

            # Pers2ERP Projection
            THETA, PHI = self.pers_centers[i]
            erp_out, erp_mask, _ = self.projection(
                pers_inp, gridy2x_pers2erp, gridy2x_erp2pers,
                THETA, PHI, FOVy=360, FOVx=90, HWy=erp_size)

            if erp_outs is None:
                # The first projection seeds both accumulators.
                erp_outs = erp_out
                count = erp_mask
            else:
                erp_outs += erp_out
                count += erp_mask
        erp_outs = torch.where(count > 0, erp_outs / count, erp_outs)

        return erp_outs

    def _decode_views_to_erp(self, pers_latents, height, width):
        """Decode every view latent and stitch the decoded views into an ERP PIL image."""
        pers_img_inps = []
        for pers_latent in pers_latents:
            pers_img = self.decode_latents(pers_latent)
            pers_img_inps.append(pers_img)
        erp_img = T.ToPILImage()(self.pers2erp(pers_img_inps, erp_size=(height, width))[0].cpu())
        return erp_img, pers_img_inps

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=1024, width=2048,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 visualize_intermidiates=False):
        """Generate an ERP panorama from text.

        Args:
            prompts: Prompt string or list of prompt strings.
            negative_prompts: Negative prompt string or list of strings.
            height, width: Output ERP size in pixels; the latent is 1/8 of it.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.
            visualize_intermidiates: When exactly ``True``, an ERP image is
                decoded after every step and returned as well.

        Returns:
            ``(images, pers_img_inps)``. ``images`` is ``[erp_img]``, or a list
            of ``(index, erp_img)`` pairs ending with the final image when
            ``visualize_intermidiates`` is ``True``. ``pers_img_inps`` holds
            the decoded perspective views of the final step.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds. The result is fed to a doubled UNet batch
        # whose output is split into (unconditional, conditional) halves below.
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define ERP panorama grid and get Perspective views
        erp_latent = torch.randn((1, self.unet.in_channels, height//8, width//8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        if visualize_intermidiates is True:
            intermidiate_imgs = []

        # Gradients are already disabled by the @torch.no_grad() decorator.
        pers_latents = self.erp2pers(erp_latent)
        last_step = len(self.scheduler.timesteps) - 1
        for i, t in enumerate(tqdm(self.scheduler.timesteps)):

            denoised_pers_latents = []
            for latent_view in pers_latents:

                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latent_view] * 2)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the denoising step with the reference model
                latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                denoised_pers_latents.append(latents_view_denoised)

            # Synchronise the views through ERP space only while i < 10 or
            # i > 40, and never on the final step. These indices are absolute
            # and do not scale with num_inference_steps. An earlier
            # unconditional round trip whose result was always overwritten by
            # this branch has been dropped.
            if i != last_step and ((i < 10) or (i > 40)):
                erp_latent = self.pers2erp(denoised_pers_latents, erp_size=(height//8, width//8))
                pers_latents = self.erp2pers(erp_latent)
            else:
                pers_latents = denoised_pers_latents

            # visualize intermidiate timesteps
            if visualize_intermidiates is True:
                erp_img, pers_img_inps = self._decode_views_to_erp(pers_latents, height, width)
                intermidiate_imgs.append((i, erp_img))

        erp_img, pers_img_inps = self._decode_views_to_erp(pers_latents, height, width)

        if visualize_intermidiates is True:
            intermidiate_imgs.append((len(intermidiate_imgs), erp_img))
            return intermidiate_imgs, pers_img_inps
        else:
            return [erp_img], pers_img_inps


In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import utils
from utils import gridy2x_erp2pers, gridy2x_pers2erp
from multidiffusion import MultiDiffusion, seed_everything

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed. It is passed unchanged to ``random`` and PyTorch;
            NumPy only accepts seeds in ``[0, 2**32 - 1]``, so the value is
            reduced modulo ``2**32`` for ``np.random.seed``.
    """
    import random  # local import: the notebook's import cell does not provide `random`

    random.seed(seed)
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    # Covers every visible CUDA device; ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True

class singleERPMultiDiffusion(MultiDiffusion):
    """Single-view variant of the ERP MultiDiffusion sampler.

    The pipeline is ERP latent -> perspective view(s) -> denoise -> ERP
    latent, but ``get_pers_centers`` deliberately keeps only the first view
    center, so exactly one perspective view is denoised and re-projected.
    """

    def __init__(self, num_rows=4, **kwargs):
        """Select the view layout and initialise the base class.

        Args:
            num_rows: Layout selector; 3 and 4 are supported.
            **kwargs: Forwarded unchanged to ``MultiDiffusion.__init__``.

        Raises:
            ValueError: If ``num_rows`` is neither 3 nor 4.
        """
        # Validate first so an unsupported layout fails with a clear message
        # before the base class is initialised.
        if num_rows == 3:
            num_cols = [1, 6, 1]
            phi_centers = [-90.0, 0.0, 90.0]
        elif num_rows == 4:
            num_cols = [3, 6, 6, 3]
            phi_centers = [-67.5, -22.5, 22.5, 67.5]
        else:
            raise ValueError(f"num_rows must be 3 or 4, got {num_rows!r}")

        super(singleERPMultiDiffusion, self).__init__(**kwargs)

        self.pers_centers = self.get_pers_centers(num_cols, phi_centers)

    def get_pers_centers(self, num_cols, phi_centers):
        """Return the first ``(THETA, PHI)`` view center of the layout.

        Args:
            num_cols: Number of views in each row.
            phi_centers: PHI value of each row; indexed in step with ``num_cols``.

        Returns:
            A list holding at most one ``(THETA, PHI)`` tuple: the first view
            of the first row. The list is empty when ``num_cols`` is empty or
            its first entry is 0.
        """
        pers_centers = []
        for i, n_cols in enumerate(num_cols):
            PHI = phi_centers[i]
            for j in np.arange(n_cols):
                theta_interval = 360 / n_cols
                THETA = j * theta_interval + theta_interval / 2
                pers_centers.append((THETA, PHI))
                # Keep only the first view of the row ...
                break
            # ... and only the first row: this class works on a single view.
            break
        return pers_centers

    def projection(self, x_inp, proj_y2x, proj_x2y, THETA, PHI, FOVy, FOVx, HWy):
        """Resample ``x_inp`` onto a target grid of size ``HWy``.

        Args:
            x_inp: Source tensor; its last two dimensions are taken as the
                source size ``HWx``. It is sampled with ``F.grid_sample``
                using a grid of batch size 1.
            proj_y2x: Callable mapping target-grid coordinates to source
                coordinates; returns ``(grid, mask)``.
            proj_x2y: Callable for the opposite direction; only its mask is used.
            THETA, PHI: View center passed through to both callables.
            FOVy, FOVx: Field-of-view values passed through to both callables.
            HWy: ``(height, width)`` of the target grid.

        Returns:
            ``(y_inp, masky, maskx)``: the resampled tensor multiplied by the
            target mask, the target mask viewed as ``(1, *HWy)`` and the
            source mask viewed as ``(1, *HWx)``.
        """
        device = x_inp.device
        HWx = x_inp.shape[-2:]

        gridy = utils.make_coord(HWy).to(device)
        gridy2x, masky = proj_y2x(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        gridy2x, masky = gridy2x.view(*HWy, 2), masky.view(1, *HWy)

        # flip(-1) reverses the coordinate pair — presumably the helpers emit
        # (y, x) while grid_sample expects (x, y); verify against utils.
        # This class samples with align_corners=False.
        y_inp = F.grid_sample(
                    x_inp, gridy2x.unsqueeze(0).flip(-1),
                    mode='nearest', padding_mode='reflection',
                    align_corners=False).clamp_(x_inp.min(), x_inp.max())
        y_inp = y_inp * masky

        gridx = utils.make_coord(HWx, flatten=False).to(device)
        _, maskx = proj_x2y(gridx, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        maskx = maskx.view(1, *HWx)

        return y_inp, masky, maskx

    def erp2pers(self,
                 erp_inp,
                 pers_size=(512//8, 512//8)):
        """Project an ERP tensor into one perspective view per view center.

        Args:
            erp_inp: ERP tensor to sample from.
            pers_size: ``(height, width)`` of every perspective view.

        Returns:
            A list of perspective tensors ordered like ``self.pers_centers``.
        """
        # NOTE: an optional bicubic 4x upscaling of `erp_inp` before the
        # projection was tried here and is currently disabled.

        # ERP2Pers Projection
        pers_outs = []
        for THETA, PHI in self.pers_centers:
            pers_out, _, _ = self.projection(
                erp_inp, gridy2x_erp2pers, gridy2x_pers2erp,
                THETA, PHI, FOVy=90, FOVx=360, HWy=pers_size)
            pers_outs.append(pers_out)

        return pers_outs

    def pers2erp(self,
                 pers_inps,
                 erp_size=(1024//8, 2048//8)):
        """Merge perspective views back into a single ERP tensor.

        Each view is projected into ERP space and the projections are summed.
        Where the accumulated mask is positive the sum is divided by it, which
        averages overlapping views when the masks are 0/1 valued (assumed —
        confirm against the projection helpers).

        Args:
            pers_inps: Sequence of perspective tensors; element ``i`` is
                projected with ``self.pers_centers[i]``.
            erp_size: ``(height, width)`` of the ERP output.

        Returns:
            The merged ERP tensor.

        Raises:
            ValueError: If ``pers_inps`` is empty.
        """
        if len(pers_inps) == 0:
            raise ValueError("pers_inps must contain at least one perspective view")

        erp_outs = None
        count = None
        for i, pers_inp in enumerate(pers_inps):
            # NOTE: an optional bicubic 4x upscaling of `pers_inp` before the
            # projection was tried here and is currently disabled.

            # Pers2ERP Projection
            THETA, PHI = self.pers_centers[i]
            erp_out, erp_mask, _ = self.projection(
                pers_inp, gridy2x_pers2erp, gridy2x_erp2pers,
                THETA, PHI, FOVy=360, FOVx=90, HWy=erp_size)

            if erp_outs is None:
                # The first projection seeds both accumulators.
                erp_outs = erp_out
                count = erp_mask
            else:
                erp_outs += erp_out
                count += erp_mask
        erp_outs = torch.where(count > 0, erp_outs / count, erp_outs)

        return erp_outs

    def _decode_views_to_erp(self, pers_latents, height, width):
        """Decode every view latent and stitch the decoded views into an ERP PIL image."""
        pers_img_inps = []
        for pers_latent in pers_latents:
            pers_img = self.decode_latents(pers_latent)
            pers_img_inps.append(pers_img)
        erp_img = T.ToPILImage()(self.pers2erp(pers_img_inps, erp_size=(height, width))[0].cpu())
        return erp_img, pers_img_inps

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=1024, width=2048,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 visualize_intermidiates=False):
        """Generate an ERP image from text using the single perspective view.

        Args:
            prompts: Prompt string or list of prompt strings.
            negative_prompts: Negative prompt string or list of strings.
            height, width: Output ERP size in pixels; the latent is 1/8 of it.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.
            visualize_intermidiates: When exactly ``True``, an ERP image is
                decoded after every step and returned as well.

        Returns:
            ``(images, pers_img_inps)``. ``images`` is ``[erp_img]``, or a list
            of ``(index, erp_img)`` pairs ending with the final image when
            ``visualize_intermidiates`` is ``True``. ``pers_img_inps`` holds
            the decoded perspective views of the final step.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds. The result is fed to a doubled UNet batch
        # whose output is split into (unconditional, conditional) halves below.
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define ERP panorama grid and get Perspective views
        erp_latent = torch.randn((1, self.unet.in_channels, height//8, width//8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        if visualize_intermidiates is True:
            intermidiate_imgs = []

        # Gradients are already disabled by the @torch.no_grad() decorator.
        pers_latents = self.erp2pers(erp_latent)
        for i, t in enumerate(tqdm(self.scheduler.timesteps)):

            denoised_pers_latents = []
            for latent_view in pers_latents:

                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latent_view] * 2)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the denoising step with the reference model
                latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                denoised_pers_latents.append(latents_view_denoised)

            # Round-trip through ERP space after every step, the last one included.
            erp_latent = self.pers2erp(denoised_pers_latents, erp_size=(height//8, width//8))
            pers_latents = self.erp2pers(erp_latent)

            # visualize intermidiate timesteps
            if visualize_intermidiates is True:
                erp_img, pers_img_inps = self._decode_views_to_erp(pers_latents, height, width)
                intermidiate_imgs.append((i, erp_img))

        erp_img, pers_img_inps = self._decode_views_to_erp(pers_latents, height, width)

        if visualize_intermidiates is True:
            intermidiate_imgs.append((len(intermidiate_imgs), erp_img))
            return intermidiate_imgs, pers_img_inps
        else:
            return [erp_img], pers_img_inps


In [None]:
seed_everything(2024)

device = torch.device('cuda')

# opt variables
sd_version = '2.0'
prompt = "Under a breathtaking night sky filled with stars and the Milky Way, a road marked with a 'Safety' sign leads to a quiet building, creating a serene contrast between the universe's vastness and human presence."
negative = ''
steps = 50

# NOTE: `dir` shadows the built-in of the same name; the name is kept because
# other cells may read it.
dir = 'semd'

# Output layout:
#   /content/<dir>/<first five prompt characters>/NNN_output.png  (ERP frames)
#   /content/<dir>/<first five prompt characters>/pers/output_<i>.png
# A single makedirs call creates the whole chain and is safe to re-run. The
# previous mkdir sequence skipped `pers/` whenever the prompt directory
# already existed, which made the perspective saves below fail.
pers_dir = f'/content/{dir}/{prompt[0:5]}/pers/'
os.makedirs(pers_dir, exist_ok=True)

outfile = f'/content/{dir}/{prompt[0:5]}/output.png'

sd = singleERPMultiDiffusion(num_rows=3, device=device, sd_version=sd_version)

erp_img, pers_imgs = sd.text2erp(prompt, negative, num_inference_steps=steps, visualize_intermidiates=True)

# save image
if len(erp_img) == 1:
    # A single entry is the final ERP image.
    erp_img[0].save(outfile)
else:
    # With intermediates enabled every entry is an (index, image) pair.
    for t, im in tqdm(erp_img):
        im.save(f'/content/{dir}/{prompt[0:5]}/{t:03d}_output.png')

# Save each decoded perspective view; the leading batch entry is converted to a PIL image.
for i, pers_img in enumerate(pers_imgs):
    pers_img = T.ToPILImage()(pers_img[0].cpu())
    pers_img.save(f'/content/{dir}/{prompt[0:5]}/pers/output_{i}.png')