In [None]:
!pip install transformers diffusers tensorboardX
!git clone https://github.com/ByeongHyunPak/omni-proj.git

import os
os.chdir('/content/omni-proj/omni_proj')

In [None]:
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchvision.transforms as T

import utils
from utils import gridy2x_erp2pers, gridy2x_pers2erp
from multidiffusions import MultiDiffusion, seed_everything, get_views

# --- Reproducibility & device ---------------------------------------------
SEED = 2024
seed_everything(SEED)
device = torch.device('cuda')  # MultiDiffusion is constructed on this device below

# --- Sampling options -------------------------------------------------------
sd_version = '2.0'  # forwarded to MultiDiffusion(sd_version=...)
negative = ''       # negative prompt passed to text2erp ('' = no negative text)
steps = 50          # number of denoising steps passed to text2erp

In [None]:
class ERPMultiDiffusion_v2(MultiDiffusion):
	"""MultiDiffusion variant for equirectangular (ERP) panorama generation.

	Adds two optional tricks on top of the parent's per-view denoising loop:

	* ``latent_rotation``: after every denoising step the latent is rolled
	  along width (``"horizontal_only"``) or along height and width
	  (``"vertical_too"``). Any other value (e.g. ``"none"``) disables rotation.
	* ``circular_padding``: the latent is padded left/right with a wrap-around
	  copy of its edges (half the latent width on each side), so views crossing
	  the 0/360-degree seam see continuous content. The padding is cropped off
	  after VAE decoding.
	"""

	def __init__(self, latent_rotation, circular_padding, **kwargs):
		"""
		Args:
			latent_rotation: "horizontal_only", "vertical_too", or anything else to disable.
			circular_padding: if truthy, use wrap-around horizontal padding of the latent.
			**kwargs: forwarded unchanged to MultiDiffusion.__init__ (e.g. device, sd_version).
		"""
		super().__init__(**kwargs)

		self.latent_rotation = latent_rotation
		self.circular_padding = circular_padding

	@staticmethod
	def _circular_pad(latent, pad):
		"""Wrap-pad the last dim: prepend its last `pad` columns and append its first `pad` columns."""
		width = latent.shape[-1]
		return torch.cat((latent[..., width - pad:], latent, latent[..., :pad]), dim=-1)

	@staticmethod
	def _circular_unpad(latent, pad):
		"""Inverse of `_circular_pad`: drop `pad` columns from each side of the last dim."""
		return latent[..., pad:latent.shape[-1] - pad]

	def _rotate(self, latent, num_inference_steps):
		"""Roll the (unpadded) latent by a fixed per-step shift selected by `self.latent_rotation`."""
		h, w = latent.shape[-2], latent.shape[-1]
		if self.latent_rotation == "horizontal_only":
			return torch.roll(latent, w // num_inference_steps, dims=-1)
		if self.latent_rotation == "vertical_too":
			# NOTE(review): a vertical roll wraps the bottom rows of the ERP latent to the top,
			# which is not a continuous operation on the sphere; kept as an experiment option.
			return torch.roll(latent, (h // num_inference_steps, w // num_inference_steps), dims=(-2, -1))
		return latent

	@torch.no_grad()
	def decode_latents(self, latents):
		"""Decode latents with the parent VAE; crop the wrap-around margins when padding is on.

		A circularly padded latent is twice its unpadded width (one quarter of the
		total on each side), so a quarter of the decoded width is removed per side.
		"""
		imgs = super().decode_latents(latents)
		if self.circular_padding:
			margin = imgs.shape[-1] // 4
			imgs = imgs[:, :, :, margin:-margin]
		return imgs

	@torch.no_grad()
	def text2erp(self,
				 prompts,
				 negative_prompts='',
				 height=512, width=1024,
				 num_inference_steps=50,
				 guidance_scale=7.5,
				 visualize_intermidiates=False):
		"""Generate an ERP panorama from a text prompt.

		Args:
			prompts: prompt string or list of prompts.
			negative_prompts: negative prompt string or list.
			height, width: output image size in pixels (latent is height//8 x width//8).
				With circular padding, width must be a multiple of 16.
			num_inference_steps: number of scheduler steps.
			guidance_scale: classifier-free guidance weight.
			visualize_intermidiates: if True, decode the latent after every step.

		Returns:
			``[PIL.Image]`` if ``visualize_intermidiates`` is False. Otherwise a list of
			``(step_index, PIL.Image)`` tuples: one per step, then the final image with
			index ``num_inference_steps``.

		Raises:
			ValueError: if circular padding is enabled and ``width`` is not a multiple of 16
				(the wrap-around pad of ``width // 16`` latent columns would not be exact).
		"""
		if self.circular_padding and width % 16 != 0:
			raise ValueError(f"width must be a multiple of 16 when circular_padding is enabled, got {width}")

		if isinstance(prompts, str):
			prompts = [prompts]
		if isinstance(negative_prompts, str):
			negative_prompts = [negative_prompts]

		# Prompts -> text embeds. presumably [2 * n_prompts, 77, hidden_dim] with the
		# unconditional half first (matches the chunk(2) order below); hidden_dim depends on sd_version.
		text_embeds = self.get_text_embeds(prompts, negative_prompts)

		# Wrap-around padding per side, in latent columns (half the unpadded latent width).
		pad = width // 16

		# Define panorama grid and get views
		latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
		if self.circular_padding:
			latent = self._circular_pad(latent, pad)
			views = get_views(height, 2 * width)  # padded latent is twice as wide
		else:
			views = get_views(height, width)
		count = torch.zeros_like(latent)
		value = torch.zeros_like(latent)

		self.scheduler.set_timesteps(num_inference_steps)

		intermediate_imgs = []
		for i, t in enumerate(tqdm(self.scheduler.timesteps)):
			count.zero_()
			value.zero_()

			for h_start, h_end, w_start, w_end in views:
				# TODO we can support batches, and pass multiple views at once to the unet
				latent_view = latent[:, :, h_start:h_end, w_start:w_end]

				# expand the latents for classifier-free guidance to avoid doing two forward passes
				latent_model_input = torch.cat([latent_view] * 2)

				# predict the noise residual
				noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

				# perform guidance
				noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
				noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

				# compute the denoising step with the reference model
				latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
				value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
				count[:, :, h_start:h_end, w_start:w_end] += 1

			# take the MultiDiffusion step: average overlapping view predictions
			latent = torch.where(count > 0, value / count, value)

			# Rotate the true (unpadded) latent, then rebuild the wrap-around padding from it.
			# NOTE(review): rotation is also applied after the last step and never undone,
			# so the returned panorama is rolled by the accumulated shift.
			if self.circular_padding:
				latent = self._circular_unpad(latent, pad)
			latent = self._rotate(latent, num_inference_steps)
			if self.circular_padding:
				latent = self._circular_pad(latent, pad)

			# visualize intermediate timesteps
			if visualize_intermidiates:
				imgs = self.decode_latents(latent)  # [1, 3, height, width]
				intermediate_imgs.append((i, T.ToPILImage()(imgs[0].cpu())))

		# Img latents -> imgs
		imgs = self.decode_latents(latent)  # [1, 3, height, width]
		img = T.ToPILImage()(imgs[0].cpu())

		if visualize_intermidiates:
			intermediate_imgs.append((len(intermediate_imgs), img))
			return intermediate_imgs
		return [img]

In [None]:
""" ERPMultiDiffusion_v2 Exp.
"""
prompt  = "360-degree panoramic image, Japanese anime style downtown city street"
H = 512
W = 1024

# Output folder is named after the 4th word of the prompt ("Japanese" here); fall back
# to the last word for shorter prompts instead of raising IndexError.
prompt_words = prompt.split(" ")
out_dir = f'/content/emd2/{prompt_words[3] if len(prompt_words) > 3 else prompt_words[-1]}'
# makedirs also creates /content/emd2/ and is a no-op on re-run.
os.makedirs(out_dir, exist_ok=True)


latent_rotation = "vertical_too" # horizontal_only / vertical_too / none
circular_padding = True
visualize = True  # save every denoising step, not only the final panorama

sd = ERPMultiDiffusion_v2(latent_rotation, circular_padding, device=device, sd_version=sd_version)

outputs = sd.text2erp(prompt, negative, H, W, steps, visualize_intermidiates=visualize)

# save image(s): text2erp returns (step, image) tuples when visualizing, else [image]
if visualize:
    for t, im in tqdm(outputs):
        im.save(f'{out_dir}/output_t={t:02d}.png')
else:
    outputs[0].save(f'{out_dir}/output.png')

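Experiment configurations. Here `latent_rotation=True/False` means rotation enabled/disabled; the current `latent_rotation` argument takes "horizontal_only", "vertical_too", or "none".
a - latent_rotation=True; circular_padding=True
b - latent_rotation=False; circular_padding=True
c - latent_rotation=True; circular_padding=False
d - latent_rotation=False; circular_padding=False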