In [None]:
# Colab setup: install the libraries used below, fetch the omni-proj repo, and
# make its package directory the working directory (presumably so the local
# `utils` and `multidiffusions` modules imported later can be found).
# NOTE(review): package versions are unpinned; pin them for reproducible runs.
!pip install transformers diffusers tensorboardX
# NOTE(review): `git clone` errors if the target directory already exists
# (e.g. when this cell is re-run); the `!` escape does not stop the cell.
!git clone https://github.com/ByeongHyunPak/omni-proj.git

import os
os.chdir('/content/omni-proj/omni_proj')

# MultiDiffusion for 360-degree panorama images

In [None]:
import numpy as np

def icosahedron_sample_camera():
    """Return one camera direction per face of a regular icosahedron.

    The 20 faces are visited as four rings of five: the top cap (0-4), the
    middle-up band (5-9), the middle-down band (10-14) and the bottom cap
    (15-19). Each entry is a ``(theta, phi)`` pair in degrees: ``theta`` is the
    longitude (starting at -180 plus the ring's offset) and ``phi`` the
    latitude.

    Reference: https://en.wikipedia.org/wiki/Regular_icosahedron
    """
    # Circumscribed, inscribed and mid radii (unit edge length).
    r_circ = np.sin(2 * np.pi / 5.0)
    r_insc = np.sqrt(3) / 12.0 * (3 + np.sqrt(5))
    r_mid = np.cos(np.pi / 5.0)
    step = 2.0 * np.pi / 5.0

    # Latitudes (radians) of the cap faces and of the middle-band faces.
    cap_lat = np.pi / 2 - np.arccos(r_insc / r_circ)
    band_lat = np.pi / 2.0 - np.arccos(r_insc / r_circ) - 2 * np.arccos(r_insc / r_mid)

    # (longitude offset, latitude) for each ring of five faces; the upper two
    # rings are shifted by half a step relative to the lower two.
    rings = (
        (step / 2.0, cap_lat),    # top cap
        (step / 2.0, band_lat),   # middle-up band
        (0.0, -band_lat),         # middle-down band
        (0.0, -cap_lat),          # bottom cap
    )

    thetas, phis = [], []
    for offset, lat in rings:
        for k in range(5):
            lon = -np.pi + offset + k * step
            thetas.append(lon * 180 / np.pi)
            phis.append(lat * 180 / np.pi)

    return list(zip(np.array(thetas), np.array(phis)))

def horizon_sample_camera(num_rows):
    """Return perspective camera centres arranged in horizontal latitude rows.

    Args:
        num_rows: number of latitude rows; only 3 and 4 have a layout.

    Returns:
        List of ``(theta, phi)`` pairs in degrees. Within a row of ``n``
        cameras, 360 degrees of longitude are split into ``n`` equal sectors
        and each camera sits at the centre of its sector.

    Raises:
        ValueError: if ``num_rows`` has no predefined layout.
    """
    # num_rows -> (cameras per row, latitude of each row in degrees).
    # Alternatives tried earlier: 3 rows [3, 5, 3] at [-45, 0, 45];
    # 4 rows [3, 6, 6, 3] at [-67.5, -22.5, 22.5, 67.5].
    layouts = {
        3: ([3, 5, 3], [-67.5, 0.0, 67.5]),
        4: ([1, 4, 4, 1], [-90.0, -22.5, 22.5, 90.0]),
    }
    if num_rows not in layouts:
        # Previously fell through and failed with an UnboundLocalError.
        raise ValueError(
            f"horizon_sample_camera supports num_rows in {sorted(layouts)}, got {num_rows!r}")
    num_cols, phi_centers = layouts[num_rows]

    pers_centers = []
    for n_cols, PHI in zip(num_cols, phi_centers):
        theta_interval = 360 / n_cols  # loop-invariant per row
        for j in range(n_cols):
            THETA = j * theta_interval + theta_interval / 2
            pers_centers.append((THETA, PHI))

    return pers_centers

In [None]:
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchvision.transforms as T

import utils
from utils import gridy2x_erp2pers, gridy2x_pers2erp
from multidiffusions import MultiDiffusion, seed_everything

class ERPMultiDiffusion(MultiDiffusion):
    """MultiDiffusion over an equirectangular (ERP) panorama canvas.

    At every denoising step the ERP latent is projected onto a fixed set of
    perspective views (``self.pers_centers``). Each view is denoised
    independently with the base model's UNet and scheduler. The denoised views
    are then projected back onto the ERP canvas, where overlapping
    contributions are averaged.
    """

    # Scale factor applied (bicubic) before projection when pre_upsampling is on.
    UPSAMPLE_SCALE = 4

    def __init__(self, get_tagent_method, pre_upsampling=False, projection_mode='nearest', initialize_mode='pers_first', **kwargs):
        """
        Args:
            get_tagent_method: how perspective view centres are chosen;
                ``"icosahedron_sample"`` or ``"horizon_sample-<num_rows>"``.
            pre_upsampling: if truthy, upscale inputs by ``UPSAMPLE_SCALE``
                before every projection.
            projection_mode: ``mode`` passed to ``F.grid_sample``.
            initialize_mode: ``"erp_first"`` (sample noise on the ERP canvas)
                or ``"pers_first"`` (sample per-view noise and project it onto
                the ERP canvas). Validated in ``text2erp``.
            **kwargs: forwarded to ``MultiDiffusion.__init__``.

        Raises:
            NotImplementedError: for an unknown ``get_tagent_method``.
        """
        super().__init__(**kwargs)

        self.get_tagent_method = get_tagent_method
        self.pre_upsampling = pre_upsampling
        self.projection_mode = projection_mode
        self.initialize_mode = initialize_mode

        if get_tagent_method == "icosahedron_sample":
            self.pers_centers = icosahedron_sample_camera()
        elif get_tagent_method.startswith("horizon_sample"):
            # e.g. "horizon_sample-3" -> 3 latitude rows
            num_rows = int(get_tagent_method.split("-")[-1])
            self.pers_centers = horizon_sample_camera(num_rows)
        else:
            raise NotImplementedError(
                f"Unknown get_tagent_method: {get_tagent_method!r}")

    def projection(self, x_inp, proj_type, THETA, PHI, FOVy, FOVx, HWy):
        """Resample ``x_inp`` onto a target grid of size ``HWy``.

        Args:
            x_inp: source tensor; presumably (1, C, H, W) - TODO confirm.
            proj_type: ``'e2p'`` (ERP -> perspective) or ``'p2e'``
                (perspective -> ERP).
            THETA, PHI: view centre in degrees.
            FOVy: field of view of the target (y) image in degrees.
            FOVx: field of view of the source (x) image in degrees.
            HWy: (height, width) of the target image.

        Returns:
            ``(y_inp, masky, maskx)``:
            - ``y_inp`` is the resampled image with masked-out pixels zeroed.
            - ``masky`` is the mask on the target grid, shaped (1, *HWy).
            - ``maskx`` is the mask from the inverse map on the source grid,
              shaped (1, *HWx). Presumably it marks source pixels seen by
              the target view.

        Raises:
            NotImplementedError: for an unknown ``proj_type``.
        """
        # Pair each forward map (target grid -> source coords) with its
        # inverse (source grid -> target coords). Previously both were the
        # forward map, so maskx was not the inverse-projection mask.
        if proj_type == 'e2p':
            proj_y2x, proj_x2y = gridy2x_erp2pers, gridy2x_pers2erp
        elif proj_type == 'p2e':
            proj_y2x, proj_x2y = gridy2x_pers2erp, gridy2x_erp2pers
        else:
            raise NotImplementedError(f"Unknown proj_type: {proj_type!r}")

        device = x_inp.device
        HWx = x_inp.shape[-2:]

        gridy = utils.make_coord(HWy).to(device)
        gridy2x, masky = proj_y2x(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx, device)
        gridy2x, masky = gridy2x.view(*HWy, 2), masky.view(1, *HWy)

        # The flip presumably converts (row, col) coords from make_coord into
        # the (x, y) order grid_sample expects - TODO confirm. The clamp keeps
        # interpolated values (e.g. bicubic overshoot) within the input range.
        y_inp = F.grid_sample(
            x_inp, gridy2x.unsqueeze(0).flip(-1),
            mode=self.projection_mode, padding_mode='reflection',
            align_corners=self.projection_mode == 'nearest',
        ).clamp_(x_inp.min(), x_inp.max())
        y_inp = y_inp * masky

        # The inverse map swaps the roles of x and y, so sizes and FOVs are
        # passed in swapped order too (same convention as the p2e call in
        # pers2erp: grid/size/FOV of the grid being mapped come first).
        gridx = utils.make_coord(HWx, flatten=False).to(device)
        _, maskx = proj_x2y(gridx, HWx, HWy, THETA, PHI, FOVx, FOVy, device)
        maskx = maskx.view(1, *HWx)

        return y_inp, masky, maskx

    def _maybe_upsample(self, x):
        """Bicubically upscale ``x`` by ``UPSAMPLE_SCALE`` when ``pre_upsampling`` is set."""
        if not self.pre_upsampling:
            return x
        scale = self.UPSAMPLE_SCALE
        return F.interpolate(
            x, size=(x.shape[-2] * scale, x.shape[-1] * scale),
            mode='bicubic', align_corners=True)

    def erp2pers(self, erp_inp, pers_size=(512//8, 512//8)):
        """Project an ERP tensor onto every view in ``self.pers_centers``.

        Returns:
            List with one tensor of spatial size ``pers_size`` per view, in
            the order of ``self.pers_centers``.
        """
        erp_inp = self._maybe_upsample(erp_inp)
        return [
            self.projection(erp_inp, 'e2p', THETA, PHI, FOVy=90, FOVx=360, HWy=pers_size)[0]
            for THETA, PHI in self.pers_centers
        ]

    def pers2erp(self, pers_inps, erp_size=(1024//8, 2048//8)):
        """Project perspective views back onto the ERP canvas and average overlaps.

        Args:
            pers_inps: iterable of view tensors; the i-th input is placed at
                ``self.pers_centers[i]``.
            erp_size: (height, width) of the output ERP tensor.

        Returns:
            ERP tensor where each pixel is the mask-weighted mean of the views
            covering it; pixels covered by no view keep the (zero) sum.

        Raises:
            ValueError: if ``pers_inps`` is empty.
        """
        erp_sum = None
        erp_count = None

        for i, pers_inp in enumerate(pers_inps):
            pers_inp = self._maybe_upsample(pers_inp)
            THETA, PHI = self.pers_centers[i]
            erp_out, erp_mask, _ = self.projection(
                pers_inp, 'p2e', THETA, PHI, FOVy=360, FOVx=90, HWy=erp_size)

            # Accumulate out of place and in the output dtype: the mask
            # returned by projection() is never mutated, and a boolean mask
            # still yields a real per-pixel view count (bool += is a logical OR).
            erp_mask = erp_mask.to(erp_out.dtype)
            if erp_sum is None:
                erp_sum, erp_count = erp_out, erp_mask
            else:
                erp_sum = erp_sum + erp_out
                erp_count = erp_count + erp_mask

        if erp_sum is None:
            raise ValueError("pers2erp requires at least one perspective input")

        return torch.where(erp_count > 0, erp_sum / erp_count, erp_sum)

    def _denoise_view(self, latent_view, t, text_embeds, guidance_scale):
        """Run one classifier-free-guided denoising step on a single view latent."""
        # Duplicate the latent so the unconditional and conditional branches
        # share one UNet forward pass.
        latent_model_input = torch.cat([latent_view] * 2)
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # NOTE(review): scheduler.step is called once per view for the same t;
        # this assumes a scheduler whose step() keeps no per-call state
        # (e.g. DDIM) - confirm against the base class's scheduler.
        return self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

    def _decode_to_erp(self, pers_latents, height, width):
        """Decode each view latent and stitch the images into one ERP PIL image.

        Returns:
            ``(erp_img, pers_imgs)``: the stitched PIL image (height x width)
            and the list of decoded per-view image tensors.
        """
        pers_imgs = [self.decode_latents(latent) for latent in pers_latents]
        erp_tensor = self.pers2erp(pers_imgs, erp_size=(height, width))
        return T.ToPILImage()(erp_tensor[0].cpu()), pers_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=1024, width=2048,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 visualize_intermidiates=False):
        """Generate an ERP panorama from a text prompt.

        Args:
            prompts: prompt string or list of prompts.
            negative_prompts: negative prompt string or list.
            height, width: output ERP size in pixels; the latent canvas is
                (height // 8, width // 8).
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            visualize_intermidiates: if exactly ``True``, decode and stitch an
                ERP image after every step.

        Returns:
            If visualizing: ``(steps, pers_imgs)``. ``steps`` is a list of
            ``(step_number, erp_pil_image)`` with 1-based step numbers, and
            ``pers_imgs`` are the decoded views of the last step (empty if
            there were no steps). Otherwise: ``([erp_pil_image], pers_imgs)``
            for the final result.

        Raises:
            NotImplementedError: for an unknown ``initialize_mode``.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        latent_size = (height // 8, width // 8)

        # Prompts -> text embeds. Assumed ordered [uncond, cond] to match the
        # chunk(2) split in _denoise_view; the shape depends on the SD version
        # (e.g. [2, 77, 1024] for SD 2.x) - TODO confirm.
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        pers_latents = None
        if self.initialize_mode == "erp_first":
            # Sample noise directly on the ERP latent canvas.
            erp_latent = torch.randn((1, self.unet.in_channels, *latent_size), device=self.device)
        elif self.initialize_mode == "pers_first":
            # Sample per-view noise and project it onto the ERP canvas.
            pers_latents = [
                torch.randn((1, self.unet.in_channels, 512 // 8, 512 // 8), device=self.device)
                for _ in range(len(self.pers_centers))
            ]
            erp_latent = self.pers2erp(pers_latents, erp_size=latent_size)
        else:
            raise NotImplementedError(f"Unknown initialize_mode: {self.initialize_mode!r}")

        self.scheduler.set_timesteps(num_inference_steps)

        visualize = visualize_intermidiates is True
        intermidiate_imgs = []
        pers_imgs = []

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            # Re-project the current ERP latent onto every view, denoise each
            # view, then fuse the views back into a single ERP latent.
            view_latents = self.erp2pers(erp_latent)
            pers_latents = [
                self._denoise_view(view, t, text_embeds, guidance_scale)
                for view in view_latents
            ]
            erp_latent = self.pers2erp(pers_latents, erp_size=latent_size)

            if visualize:
                erp_img, pers_imgs = self._decode_to_erp(pers_latents, height, width)
                intermidiate_imgs.append((i + 1, erp_img))

        if visualize:
            return intermidiate_imgs, pers_imgs

        if pers_latents is None:
            # Only reachable with 'erp_first' and no scheduler timesteps.
            pers_latents = self.erp2pers(erp_latent)
        erp_img, pers_imgs = self._decode_to_erp(pers_latents, height, width)
        return [erp_img], pers_imgs


In [None]:
# Shared configuration for the experiment cells below.
device = torch.device('cuda')  # device both pipelines are created on
seed_everything(2024)          # project helper from multidiffusions; seed 2024

# Options forwarded to both pipelines
sd_version = '2.0'  # Stable Diffusion version string (`sd_version=`)
negative = ''       # negative prompt; empty string means none
steps = 50          # number of inference (denoising) steps

In [None]:
""" MultiDiffusion exp.
"""
# NOTE(review): 'cityscpae' looks like a typo of 'cityscape' (see the
# alternative prompts in the ERP cell); left unchanged so outputs stay comparable.
prompt = "firenze cityscpae"
H = 512
W = 512
visualize = True  # save every denoising step instead of only the final image

# Output directory: /content/md/<first word of prompt>
out_dir = f'/content/md/{prompt.split(" ")[0]}'
os.makedirs(out_dir, exist_ok=True)

sd = MultiDiffusion(device, sd_version)

img = sd.text2panorama(prompt, negative, H, W, steps, visualize_intermidiates=visualize)

# Save image(s). Branch on the flag rather than len(img): a one-step
# visualization run also returns a one-element list, but of (step, image)
# pairs (presumably, matching the loop below).
if visualize:
    for t, im in tqdm(img):
        im.save(f'{out_dir}/output_t={t:02d}.png')
else:
    img[0].save(f'{out_dir}/output.png')

In [None]:
""" ERP MultiDiffusion exp.
"""

prompt = "firenze cityscpae"
# Other prompts tried:
# prompt = "firenze cityscape"
# prompt = "Realistic cityscape of Florence."
# prompt = "Street art area under a cloudy sky."
# prompt = "Arctic wilderness, aurora, trail."

H = 1024
W = 2048

# Camera layout: "horizon_sample-<num_rows>" or "icosahedron_sample"
get_tagent_method = "horizon_sample-3"

pre_upsampling = False
projection_mode = "nearest"
initialize_mode = "pers_first"
visualize = True  # save an ERP image for every denoising step

sd = ERPMultiDiffusion(get_tagent_method, pre_upsampling, projection_mode, initialize_mode,
                       device=device, sd_version=sd_version)

# Output layout: /content/<dir_name>/<first word of prompt>/{erp images, pers/}
dir_name = "erp_md_(d)"
out_dir = f'/content/{dir_name}/{prompt.split(" ")[0]}'
pers_dir = f'{out_dir}/pers'
os.makedirs(pers_dir, exist_ok=True)

erp_img, pers_imgs = sd.text2erp(prompt, negative, height=H, width=W,
                                 num_inference_steps=steps,
                                 visualize_intermidiates=visualize)

# With visualization text2erp returns (step, image) pairs, otherwise a
# one-element list; branch on the flag, not on len(), so a one-step run is
# handled correctly.
if visualize:
    for t, im in tqdm(erp_img):
        im.save(f'{out_dir}/erp_output_t={t:02d}.png')
else:
    erp_img[0].save(f'{out_dir}/erp_output.png')

# Per-view images from the last step (decoded tensors -> PIL).
for i, pers_img in enumerate(pers_imgs):
    pers_img = T.ToPILImage()(pers_img[0].cpu())
    pers_img.save(f'{pers_dir}/pers_output_i={i:02d}.png')