# Install Requirements

In [None]:
# Install the Ninja build tool and print its version to confirm the install.
# Presumably needed so nvdiffrast's PyTorch extension can be JIT-compiled later — TODO confirm.
!sudo apt-get install ninja-build
!ninja --version

In [None]:
import os

# Start from a clean checkout: remove any previous clone of nvdiffrast.
if os.path.exists('/content/nvdiffrast'):
  %rm -rf /content/nvdiffrast

# Clone nvdiffrast (with submodules), pip-install it from source, then return to /content.
!git clone --recursive https://github.com/NVlabs/nvdiffrast
%cd /content/nvdiffrast
!pip install .
%cd /content/

In [None]:
import cv2

import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import nvdiffrast.torch as dr

from tqdm import tqdm
from IPython.display import Image
from torchvision.transforms import ToPILImage, ToTensor
from einops import rearrange, reduce, repeat

# Load Rasterizer Context (OpenGL or CUDA)

In [None]:
# nvdiffrast rasterizer backend. According to the author, the OpenGL context
# rasterizes noticeably better, but on a T4 GPU only the CUDA context works,
# so keep this False there.
use_opengl = False
if use_opengl:
    glctx = dr.RasterizeGLContext()
else:
    glctx = dr.RasterizeCudaContext()

# MultiDiffusion Code

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    """Tile a panorama's latent grid with overlapping square windows.

    Sizes are given in pixels and converted to latent cells (pixels // 8, the
    same conversion used for the latent tensor). Windows start every ``stride``
    latent cells. If the last regular window does not reach the border, one
    extra window aligned to the border is added, so every latent cell is
    covered by at least one view.

    Args:
        panorama_height: panorama height in pixels.
        panorama_width: panorama width in pixels.
        window_size: window side length in latent cells.
        stride: step between window starts in latent cells.

    Returns:
        List of ``(h_start, h_end, w_start, w_end)`` latent-space slices in
        row-major order.

    Raises:
        ValueError: if the latent grid is smaller than one window.
    """
    latent_h = int(panorama_height // 8)
    latent_w = int(panorama_width // 8)
    if latent_h < window_size or latent_w < window_size:
        raise ValueError(
            f"panorama latent size {latent_h}x{latent_w} is smaller than the "
            f"window size {window_size}")

    def _starts(length):
        # Regular stride positions, plus a border-aligned one if the strip at
        # the end would otherwise be left uncovered.
        starts = list(range(0, length - window_size + 1, stride))
        if starts[-1] + window_size < length:
            starts.append(length - window_size)
        return starts

    return [(h, h + window_size, w, w + window_size)
            for h in _starts(latent_h)
            for w in _starts(latent_w)]

class MultiDiffusion(nn.Module):
    """Text-to-panorama generation with MultiDiffusion on top of Stable Diffusion.

    The panorama latent is covered by overlapping square latent windows
    (``get_views``). At every timestep each window is denoised with the UNet,
    and the per-window results are averaged where windows overlap.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        """Load the VAE, tokenizer, text encoder, UNet and DDIM scheduler.

        Args:
            device: device the VAE, text encoder and UNet are moved to.
            sd_version: '1.5', '2.0' or '2.1'; ignored when ``hf_key`` is given.
            hf_key: optional Hugging Face model id that overrides ``sd_version``.

        Raises:
            ValueError: if ``sd_version`` is unsupported and ``hf_key`` is None.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print('[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print('[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts for classifier-free guidance.

        Args:
            prompt: prompt string or list of prompt strings.
            negative_prompt: negative prompt string or list; a single entry is
                broadcast to every prompt.

        Returns:
            ``cat([uncond, cond])``; the batch dimension holds the unconditional
            embeddings first and has twice as many rows as there are prompts.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        if len(negative_prompt) == 1 and len(prompt) > 1:
            negative_prompt = list(negative_prompt) * len(prompt)

        # Both branches are padded AND truncated to the same fixed length so the
        # two embedding batches can be concatenated (an over-long negative prompt
        # previously exceeded the text encoder's maximum length).
        tokenizer_kwargs = dict(padding='max_length', max_length=self.tokenizer.model_max_length,
                                truncation=True, return_tensors='pt')

        text_input = self.tokenizer(prompt, **tokenizer_kwargs)
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        uncond_input = self.tokenizer(negative_prompt, **tokenizer_kwargs)
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        return torch.cat([uncond_embeddings, text_embeddings])

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode VAE latents to images in [0, 1], shape [B, 3, 8*h, 8*w]."""
        # 0.18215 is the latent scaling constant used by these SD checkpoints.
        latents = 1 / 0.18215 * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    @torch.no_grad()
    def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                      guidance_scale=7.5, visualize_intermidiates=False):
        """Generate a ``height x width`` panorama with MultiDiffusion.

        Returns:
            ``[img]`` with one PIL image by default. With
            ``visualize_intermidiates=True``, a list of ``(step, PIL image)``
            pairs: one per timestep, followed by the final image.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (uncond first, then cond)
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define panorama grid and get views
        latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):

            if visualize_intermidiates is True:
                intermidiate_imgs = []

            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    # TODO we can support batches, and pass multiple views at once to the unet
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end]

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                # take the MultiDiffusion step: average overlapping window predictions
                latent = torch.where(count > 0, value / count, value)

                # visualize intermidiate timesteps
                if visualize_intermidiates is True:
                    imgs = self.decode_latents(latent)  # [1, 3, height, width]
                    img = T.ToPILImage()(imgs[0].cpu())
                    intermidiate_imgs.append((i, img))

        # Img latents -> imgs
        imgs = self.decode_latents(latent)  # [1, 3, height, width]
        img = T.ToPILImage()(imgs[0].cpu())

        if visualize_intermidiates is True:
            intermidiate_imgs.append((len(intermidiate_imgs), img))
            return intermidiate_imgs
        else:
            return [img]

# How I Warped Your Noise (ERP-Pers.)

In [None]:
def gridy2x_pers2erp(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx):
    """Map normalized perspective-view coordinates to normalized ERP coordinates.

    Each (row, col) point is placed on the image plane x = 1, scaled by the
    half-FOV tangents, normalized onto the unit sphere, and rotated by R1
    (THETA degrees about z) followed by R2 (PHI degrees about R1 @ y). It is
    then converted to (lat, lon).

    Args:
        gridy: tensor whose last dim is 2, holding (row, col) in [-1, 1]; it is
            reshaped to (-1, 2). The caller's tensor is NOT modified.
        HWy: (H, W) of the perspective view; the vertical FOV is FOVy * H / W.
        HWx: (h, w) of the ERP map; unused (ERP output is normalized).
        THETA: rotation angle about z, in degrees.
        PHI: rotation angle about the R1-rotated y axis, in degrees.
        FOVy: horizontal FOV of the perspective view, in degrees.
        FOVx: unused (the ERP target spans the whole sphere).

    Returns:
        (gridx, mask): gridx is (N, 2) float32 (lat, lon), with lat divided by
        pi/2 and lon by pi. mask is (N,) float32, 1 where both coordinates
        lie within [-1, 1].
    """
    H, W = HWy
    h_fov = FOVy * float(H) / W
    w_fov = FOVy

    # gridy2x
    ### onto sphere
    # clone(): reshape() and .float() may return views of the caller's grid,
    # and the in-place scaling below would otherwise modify it on every call.
    grid = gridy.reshape(-1, 2).float().clone()
    grid[:, 0] *= np.tan(np.radians(h_fov / 2.0))
    grid[:, 1] *= np.tan(np.radians(w_fov / 2.0))
    # (row, col) -> (col, row), i.e. (y, z) on the plane x = 1.
    grid = grid.double().flip(-1)

    ones = torch.ones(grid.shape[0], 1, dtype=torch.float64)
    xyz = torch.cat((ones, grid), dim=-1)
    xyz /= torch.norm(xyz, p=2, dim=-1, keepdim=True)

    ### rotation
    y_axis = np.array([0.0, 1.0, 0.0], np.float64)
    z_axis = np.array([0.0, 0.0, 1.0], np.float64)
    [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA))
    [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(PHI))

    xyz = torch.mm(torch.from_numpy(R1), xyz.T).T
    xyz = torch.mm(torch.from_numpy(R2), xyz.T).T

    ### sphere to gridx
    lat = torch.arcsin(xyz[:, 2]) / np.pi * 2
    lon = torch.atan2(xyz[:, 1], xyz[:, 0]) / np.pi
    gridx = torch.stack((lat, lon), dim=-1)

    # masky
    mask = torch.where(torch.abs(gridx) > 1, 0, 1)
    mask = mask[:, 0] * mask[:, 1]

    return gridx.float(), mask.float()

def gridy2x_erp2pers(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx):
    """Map normalized ERP coordinates to normalized perspective-view coordinates.

    This is the inverse of ``gridy2x_pers2erp``. (lat, lon) points are lifted
    onto the unit sphere, un-rotated (R2^-1, then R1^-1), projected onto the
    plane x = 1, and divided by the half-FOV tangents so that the view border
    maps to +/-1.

    Args:
        gridy: tensor whose last dim is 2, holding (lat, lon) in [-1, 1] (lat
            scaled by pi/2, lon by pi); it is reshaped to (-1, 2).
        HWy: (H, W) of the ERP map; unused.
        HWx: (h, w) of the perspective view; the vertical FOV is FOVx * h / w.
        THETA: rotation angle about z, in degrees (as in ``gridy2x_pers2erp``).
        PHI: rotation angle about the R1-rotated y axis, in degrees.
        FOVy: unused (the ERP source spans the whole sphere).
        FOVx: horizontal FOV of the perspective view, in degrees.

    Returns:
        (gridx, mask): gridx is (N, 2) float32 (row, col) perspective
        coordinates. mask is (N,) float32, 1 where both coordinates lie
        within [-1, 1] and the point is not behind the camera (x >= 0).
    """
    h, w = HWx
    h_fov = FOVx * float(h) / w
    w_fov = FOVx

    # gridy2x
    ### onto sphere
    gridy = gridy.reshape(-1, 2).float()
    lat = gridy[:, 0] * np.pi / 2
    lon = gridy[:, 1] * np.pi

    z0 = torch.sin(lat)
    y0 = torch.cos(lat) * torch.sin(lon)
    x0 = torch.cos(lat) * torch.cos(lon)
    xyz = torch.stack((x0, y0, z0), dim=-1).double()

    ### rotation (undo R2 @ R1 by applying R2^-1, then R1^-1)
    y_axis = np.array([0.0, 1.0, 0.0], np.float64)
    z_axis = np.array([0.0, 0.0, 1.0], np.float64)
    [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA))
    [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(PHI))

    R1_inv = torch.inverse(torch.from_numpy(R1))
    R2_inv = torch.inverse(torch.from_numpy(R2))

    xyz = torch.mm(R2_inv, xyz.T).T
    xyz = torch.mm(R1_inv, xyz.T).T

    ### sphere to gridx
    # Perspective projection onto x = 1, normalized by the same half-FOV
    # tangents that gridy2x_pers2erp multiplies by (a no-op for FOV 90, square).
    row = xyz[:, 2] / xyz[:, 0] / np.tan(np.radians(h_fov / 2.0))
    col = xyz[:, 1] / xyz[:, 0] / np.tan(np.radians(w_fov / 2.0))
    gridx = torch.stack((row, col), dim=-1).float()

    # masky
    mask = torch.where(torch.abs(gridx) > 1, 0, 1)
    mask = mask[:, 0] * mask[:, 1]
    mask *= torch.where(xyz[:, 0] < 0, 0, 1)

    return gridx, mask.float()

# 1. Naive Implementation
- Sample the initial perspective-view noises by warping a single ERP noise map
- and denoise every view independently

In [None]:
def cond_noise_sampling(src_noise, level=3):
    """Conditionally upsample a noise map by ``2 ** level`` per spatial axis.

    Each source pixel becomes an ``f x f`` block (``f = 2 ** level``) of fresh
    Gaussian samples. The block mean of those samples is replaced by
    ``src / f``, so summing a block and dividing by ``f`` recovers the source
    value.

    Args:
        src_noise: tensor of shape (B, C, H, W).
        level: upsampling level.

    Returns:
        Tensor of shape (B, C, H * 2**level, W * 2**level) on ``src_noise``'s
        device and dtype.
    """
    B, C, H, W = src_noise.shape
    up_factor = 2 ** level
    upscaled_means = F.interpolate(src_noise, scale_factor=(up_factor, up_factor), mode='nearest')
    # Sample on the input's device/dtype; a fixed CPU/float32 draw broke CUDA inputs.
    raw_rand = torch.randn(B, C, up_factor * H, up_factor * W,
                           device=src_noise.device, dtype=src_noise.dtype)
    # Per-block mean of the fresh samples, broadcast back to full resolution.
    Z_mean = raw_rand.unfold(2, up_factor, up_factor).unfold(3, up_factor, up_factor).mean((4, 5))
    Z_mean = F.interpolate(Z_mean, scale_factor=up_factor, mode='nearest')
    mean_removed_rand = raw_rand - Z_mean
    return upscaled_means / up_factor + mean_removed_rand


def get_pers_view_noises(src_noise, src_cfg, tgt_cfg):
    """Warp an ERP noise map into one noise map per perspective view.

    The ERP noise is conditionally upsampled (``up_level > 1``). Every target
    pixel is split into 8 triangles, and their vertices are warped onto the
    ERP map with ``gridy2x_pers2erp`` and rasterized there. The upsampled
    samples landing in a pixel's triangles are summed and divided by
    sqrt(count).

    Args:
        src_noise: ERP noise of shape (B, C, H_src, W_src).
        src_cfg: dict with "up_level": upsampling level (>= 1; 1 = no upsampling).
        tgt_cfg: dict with "size": (H_tgt, W_tgt) and "view_dirs": iterable of
            (theta, phi) angles in degrees passed to ``gridy2x_pers2erp``.

    Returns:
        List with one (B, C, H_tgt, W_tgt) float32 CUDA tensor per view.

    Raises:
        NotImplementedError: if up_level < 1.
        ValueError: if some target pixel of a view receives no source sample.
    """
    up_level = src_cfg["up_level"]
    B, C, H_src, W_src = src_noise.shape
    if up_level > 1:
        src_up_noise = cond_noise_sampling(src_noise, level=up_level)
    elif up_level == 1:
        src_up_noise = src_noise
    else:
        raise NotImplementedError(f"up_level must be >= 1, got {up_level}")

    # Rasterize at the resolution of the noise actually being sampled, so the
    # flattened rasterizer output lines up element-wise with src_up_noise.
    resolution = [int(s) for s in src_up_noise.shape[-2:]]

    # Defining the partitioned polygons for the target noise map: each target
    # pixel is a 2x2 cell of a (2H+1) x (2W+1) vertex lattice, cut into 8 triangles.
    H_tgt, W_tgt = tgt_cfg["size"]
    tr_H_tgt, tr_W_tgt = 2 * H_tgt + 1, 2 * W_tgt + 1

    front_tri_verts = torch.tensor([
        [0, 1, 1 + tr_W_tgt], [0, tr_W_tgt, 1 + tr_W_tgt],
        [tr_W_tgt, 1 + tr_W_tgt, 1 + 2 * tr_W_tgt], [tr_W_tgt, 2 * tr_W_tgt, 1 + 2 * tr_W_tgt]])
    per_tri_verts = torch.cat((front_tri_verts, front_tri_verts + 1), dim=0)
    width = torch.arange(0, tr_W_tgt - 1, 2)
    height = torch.arange(0, tr_H_tgt - 1, 2) * tr_W_tgt
    start_idxs = (width[None, ...] + height[..., None]).reshape(-1, 1)
    # Triangles are listed pixel by pixel in row-major order, 8 per pixel.
    vertices = (start_idxs.repeat(1, 8)[..., None] + per_tri_verts[None, ...]).reshape(-1, 3)
    vertices = vertices.int().to("cuda")

    # Perspective view vertex grid (row-major, coordinates in [-1, 1])
    pers_i, pers_j = torch.meshgrid(
        torch.linspace(-1, 1, tr_H_tgt),
        torch.linspace(-1, 1, tr_W_tgt),
        indexing="ij")
    pers_grid = torch.stack((pers_i, pers_j), dim=-1)
    n_verts = tr_H_tgt * tr_W_tgt
    # Homogeneous clip-space z = 0, w = 1 for every vertex.
    zw = torch.cat((torch.zeros(n_verts, 1), torch.ones(n_verts, 1)), dim=-1)

    # The source noise is the same for every view: flatten it once.
    src_up_noise_flat = src_up_noise.reshape(B * C, -1).cpu()
    ones_flat = torch.ones_like(src_up_noise_flat[:1])
    n_tgt = H_tgt * W_tgt

    res = []
    for theta, phi in tgt_cfg["view_dirs"]:
        # Warping Rasterized Pers. grid
        pers2erp_grid, _ = gridy2x_pers2erp(gridy=pers_grid,
            HWy=(2 * H_tgt, 2 * W_tgt), HWx=(2 * H_src, 2 * W_src),
            THETA=theta, PHI=phi, FOVy=90, FOVx=360)

        # Rows follow pers_grid's row-major vertex order; fliplr turns
        # (lat, lon) into clip-space (x = lon, y = lat).
        warped_coords = pers2erp_grid.fliplr()
        warped_vtx_pos = torch.cat((warped_coords, zw), dim=-1)[None, ...].to("cuda")
        # NOTE(review): triangles straddling the lon = +/-1 seam would be
        # rasterized across the whole map; the configured views presumably
        # never cross it — confirm before adding more view directions.

        with torch.no_grad():
            rast_out, _ = dr.rasterize(glctx, warped_vtx_pos, vertices, resolution=resolution)
        # Channel 3 is the rasterized triangle id, offset by one (0 = no triangle).
        rast = rast_out[:, :, :, 3:].permute(0, 3, 1, 2).to(torch.int64)

        # Finding pixel indices in cond-upsampled map: triangle id -> 1-based
        # target pixel index (8 triangles per pixel); 0 stays "uncovered".
        indices = (rast - 1) // 8 + 1
        indices_flat = indices.reshape(1, -1).cpu()

        # Get warped target noise: sum samples per pixel (slot 0 collects uncovered ones).
        fin_v_val = torch.zeros(B * C, n_tgt + 1).scatter_add_(
            1, index=indices_flat.repeat(B * C, 1), src=src_up_noise_flat)[..., 1:]
        fin_v_num = torch.zeros(1, n_tgt + 1).scatter_add_(
            1, index=indices_flat, src=ones_flat)[..., 1:]
        if fin_v_num.min() == 0:
            raise ValueError(
                f"view ({theta}, {phi}) has target pixels without any source sample")

        # Dividing the sum of k unit-variance samples by sqrt(k) keeps unit variance.
        final_values = fin_v_val / torch.sqrt(fin_v_num)
        tgt_warped_noise = final_values.reshape(B, C, H_tgt, W_tgt).float().cuda()
        res.append(tgt_warped_noise)

    return res

In [None]:
class ERPMultiDiffusion_v3(MultiDiffusion):
    """Naive ERP generation: denoise warped perspective views independently.

    The initial latents of every view in ``tgt_cfg["view_dirs"]`` are obtained
    by warping one shared ERP noise map (``get_pers_view_noises``). Each view
    is then denoised on its own, with no exchange of information between views.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Conditional-upsampling level of the ERP source noise (factor 2 ** 3 = 8).
        self.src_cfg = {"up_level": 3}
        self.tgt_cfg = {
            # "view_dirs": [
            #     (0,0), (0,30), (0,-30), (0,60), (0,-60), (0,90), (0,-90), (0,120), (0,-120),
            #     (22.5, 0), (22.5, 30), (22.5, -30), (22.5, 60), (22.5, -60), (22.5, 90), (22.5, -90), (22.5, 90), (22.5, 120), (22.5, -120),
            #     (45.0, 0), (45.0, 30), (45.0, -30), (45.0, 60), (45.0, -60), (45.0, 90), (45.0, -90), (45.0, 90), (45.0, 120), (45.0, -120),
            # ]
            # (theta, phi) in degrees, passed as THETA / PHI to gridy2x_pers2erp.
            "view_dirs": [
                (0.0, -45.0), (30.0, -45.0), (60.0, -45.0), (90.0, -45.0), (-30.0, -45.0), (-60.0, -45.0), (-90.0, -45.0),
                (0.0, -22.5), (30.0, -22.5), (60.0, -22.5), (90.0, -22.5), (-30.0, -22.5), (-60.0, -22.5), (-90.0, -22.5),
                (0.0, 0.0), (30.0, 0.0), (60.0, 0.0), (90.0, 0.0), (-30.0, 0.0), (-60.0, 0.0), (-90.0, 0.0),
                (0.0, 22.5), (30.0, 22.5), (60.0, 22.5), (90.0, 22.5), (-30.0, 22.5), (-60.0, 22.5), (-90.0, 22.5),
                (0.0, 45.0), (30.0, 45.0), (60.0, 45.0), (90.0, 45.0), (-30.0, 45.0), (-60.0, 45.0), (-90.0, 45.0),
            ]
        }

    def _decode_views(self, pers_latents):
        """Decode every view latent and pair it with its (theta, phi) direction."""
        return [(view_dir, self.decode_latents(pers_latent))
                for view_dir, pers_latent in zip(self.tgt_cfg['view_dirs'], pers_latents)]

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 visualize_intermidiates=False,
                 save_dir=None):
        """Denoise the warped perspective views of a random ERP latent.

        Args:
            prompts: prompt string or list of strings.
            negative_prompts: negative prompt string or list of strings.
            height, width: ERP size in pixels; the ERP latent is height//8 x width//8.
            num_inference_steps: number of DDIM steps.
            guidance_scale: classifier-free guidance weight.
            visualize_intermidiates: if True, keep decoded views after every step.
            save_dir: if given, decoded views of step i are written to
                ``{save_dir}/{i:0>2}/pers_{theta}_{phi}.png``.

        Returns:
            List of ``(step, [((theta, phi), image_tensor), ...])``. With
            ``visualize_intermidiates`` there is one entry per step (step = i + 1);
            otherwise a single entry holding the final views.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (uncond first, then cond)
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define ERP source noise
        latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        intermidiate_imgs = []

        self.tgt_cfg["size"] = (64, 64)
        pers_latents = get_pers_view_noises(latent.to("cpu"), self.src_cfg, self.tgt_cfg)

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            denoised_pers_latents = []

            for latent_view in pers_latents:
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latent_view] * 2)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the denoising step with the reference model
                latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
                denoised_pers_latents.append(latents_view_denoised)

            pers_latents = denoised_pers_latents

            # Decode whenever the images are needed for visualization OR saving
            # (saving alone previously hit an undefined `pers_img_inps`).
            if visualize_intermidiates or save_dir is not None:
                pers_img_inps = self._decode_views(pers_latents)

            if visualize_intermidiates:
                intermidiate_imgs.append((i + 1, pers_img_inps))

            if save_dir is not None:
                step_dir = f"{save_dir}/{i:0>2}"
                os.makedirs(step_dir, exist_ok=True)
                for (theta, phi), im in pers_img_inps:
                    ToPILImage()(im[0].cpu()).save(f'{step_dir}/pers_{theta}_{phi}.png')

        if not visualize_intermidiates:
            # Previously this path raised NameError; return the final views instead.
            intermidiate_imgs.append((len(self.scheduler.timesteps), self._decode_views(pers_latents)))

        return intermidiate_imgs


In [None]:
def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    # For bit-exact cuDNN results (slower), uncomment both:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

seed_everything(2024)
device = torch.device('cuda')

# opt variables
sd_version = '2.0'
negative = ''
steps = 50

In [None]:
prompt = "Realistic cityscape of Florence."

H, W = 1024, 2048
sd = ERPMultiDiffusion_v3(device=device, sd_version=sd_version)

dir_name = "hiwyn"

# Results go to /content/<dir_name>/<first word of the prompt>/ (parents are
# created as needed). `dir` shadows the builtin; the name is kept for
# compatibility with the other cells of this notebook.
dir = f'/content/{dir_name}/{prompt.split(" ")[0]}'
os.makedirs(dir, exist_ok=True)

outputs = sd.text2erp(prompt, negative, height=H, width=W, num_inference_steps=steps,
                      visualize_intermidiates=True, save_dir=dir)

In [None]:
# save image: every (step, views) entry goes to <dir>/<step>/pers_<theta>_<phi>.png
dir = f'/content/{dir_name}/{prompt.split(" ")[0]}'
for i, vim in tqdm(outputs):
    step_dir = os.path.join(dir, str(i))
    os.makedirs(step_dir, exist_ok=True)
    for (theta, phi), im in vim:
        ToPILImage()(im[0].cpu()).save(os.path.join(step_dir, f'pers_{theta}_{phi}.png'))

# 2. Denoising dependently

In [None]:
def cond_noise_sampling(src_noise, level=3):
    """Conditionally upsample a noise map by ``2 ** level`` per spatial axis.

    Each source pixel becomes an ``f x f`` block (``f = 2 ** level``) of fresh
    Gaussian samples. The block mean of those samples is replaced by
    ``src / f``, so summing a block and dividing by ``f`` recovers the source
    value.

    Args:
        src_noise: tensor of shape (B, C, H, W).
        level: upsampling level.

    Returns:
        Tensor of shape (B, C, H * 2**level, W * 2**level) on ``src_noise``'s
        device and dtype.
    """
    B, C, H, W = src_noise.shape
    up_factor = 2 ** level
    upscaled_means = F.interpolate(src_noise, scale_factor=(up_factor, up_factor), mode='nearest')
    # Sample on the input's device/dtype; a fixed CPU/float32 draw broke CUDA inputs.
    raw_rand = torch.randn(B, C, up_factor * H, up_factor * W,
                           device=src_noise.device, dtype=src_noise.dtype)
    # Per-block mean of the fresh samples, broadcast back to full resolution.
    Z_mean = raw_rand.unfold(2, up_factor, up_factor).unfold(3, up_factor, up_factor).mean((4, 5))
    Z_mean = F.interpolate(Z_mean, scale_factor=up_factor, mode='nearest')
    mean_removed_rand = raw_rand - Z_mean
    return upscaled_means / up_factor + mean_removed_rand


def get_pers_view_noises(src_noise, up_level, tgt_cfg):
    """Warp an ERP noise map into perspective views and keep the pixel mapping.

    The warping is the same as in the naive version. This version also returns
    what is needed to push per-view results back onto the (upsampled) ERP
    noise.

    Args:
        src_noise: ERP noise of shape (B, C, H, W). With ``up_level > 1`` it is
            conditionally upsampled by ``2 ** up_level``; with ``up_level == 1``
            it is used as-is (presumably an already-upsampled map — TODO confirm).
        up_level: upsampling level, must be >= 1.
        tgt_cfg: dict with "size": (H_tgt, W_tgt) and "view_dirs": iterable of
            (theta, phi) angles in degrees passed to ``gridy2x_pers2erp``.

    Returns:
        res: list of (B, C, H_tgt, W_tgt) float32 CUDA tensors, one per view.
        inds: list of int64 tensors shaped like the spatial dims of
            ``src_up_noise``. A value k > 0 means that sample falls in
            flattened (row-major) target pixel k - 1; 0 means it is uncovered.
        src_up_noise: the (possibly upsampled) ERP noise that was sampled.
        fin_v_num: (1, 1, H_tgt, W_tgt) CUDA tensor with the per-pixel sample
            counts of the LAST view only (None if there are no views).

    Raises:
        NotImplementedError: if ``up_level < 1``.
        ValueError: if some target pixel of a view receives no source sample.
    """
    B, C, H_src, W_src = src_noise.shape
    if up_level > 1:
        src_up_noise = cond_noise_sampling(src_noise, level=up_level)
    elif up_level == 1:
        src_up_noise = src_noise
    else:
        raise NotImplementedError(f"up_level must be >= 1, got {up_level}")

    # Rasterize at the resolution of the noise actually being sampled (this
    # replaces the former hard-coded `// 8` for up_level == 1), so the flattened
    # rasterizer output and `inds` line up element-wise with src_up_noise.
    resolution = [int(s) for s in src_up_noise.shape[-2:]]

    # Defining the partitioned polygons for the target noise map: each target
    # pixel is a 2x2 cell of a (2H+1) x (2W+1) vertex lattice, cut into 8 triangles.
    H_tgt, W_tgt = tgt_cfg["size"]
    tr_H_tgt, tr_W_tgt = 2 * H_tgt + 1, 2 * W_tgt + 1

    front_tri_verts = torch.tensor([
        [0, 1, 1 + tr_W_tgt], [0, tr_W_tgt, 1 + tr_W_tgt],
        [tr_W_tgt, 1 + tr_W_tgt, 1 + 2 * tr_W_tgt], [tr_W_tgt, 2 * tr_W_tgt, 1 + 2 * tr_W_tgt]])
    per_tri_verts = torch.cat((front_tri_verts, front_tri_verts + 1), dim=0)
    width = torch.arange(0, tr_W_tgt - 1, 2)
    height = torch.arange(0, tr_H_tgt - 1, 2) * tr_W_tgt
    start_idxs = (width[None, ...] + height[..., None]).reshape(-1, 1)
    # Triangles are listed pixel by pixel in row-major order, 8 per pixel.
    vertices = (start_idxs.repeat(1, 8)[..., None] + per_tri_verts[None, ...]).reshape(-1, 3)
    vertices = vertices.int().to("cuda")

    # Perspective view vertex grid (row-major, coordinates in [-1, 1])
    pers_i, pers_j = torch.meshgrid(
        torch.linspace(-1, 1, tr_H_tgt),
        torch.linspace(-1, 1, tr_W_tgt),
        indexing="ij")
    pers_grid = torch.stack((pers_i, pers_j), dim=-1)
    n_verts = tr_H_tgt * tr_W_tgt
    # Homogeneous clip-space z = 0, w = 1 for every vertex.
    zw = torch.cat((torch.zeros(n_verts, 1), torch.ones(n_verts, 1)), dim=-1)

    # The source noise is the same for every view: flatten it once.
    src_up_noise_flat = src_up_noise.reshape(B * C, -1).cpu()
    ones_flat = torch.ones_like(src_up_noise_flat[:1])
    n_tgt = H_tgt * W_tgt

    res = []
    inds = []
    fin_v_num = None
    for theta, phi in tgt_cfg["view_dirs"]:
        # Warping Rasterized Pers. grid
        pers2erp_grid, _ = gridy2x_pers2erp(gridy=pers_grid,
            HWy=(2 * H_tgt, 2 * W_tgt), HWx=(2 * H_src, 2 * W_src),
            THETA=theta, PHI=phi, FOVy=90, FOVx=360)

        # Rows follow pers_grid's row-major vertex order; fliplr turns
        # (lat, lon) into clip-space (x = lon, y = lat).
        warped_coords = pers2erp_grid.fliplr()
        warped_vtx_pos = torch.cat((warped_coords, zw), dim=-1)[None, ...].to("cuda")

        with torch.no_grad():
            rast_out, _ = dr.rasterize(glctx, warped_vtx_pos, vertices, resolution=resolution)
        # Channel 3 is the rasterized triangle id, offset by one (0 = no triangle).
        rast = rast_out[:, :, :, 3:].permute(0, 3, 1, 2).to(torch.int64)

        # Finding pixel indices in cond-upsampled map: triangle id -> 1-based
        # target pixel index (8 triangles per pixel); 0 stays "uncovered".
        indices = (rast - 1) // 8 + 1
        indices_flat = indices.reshape(1, -1).cpu()

        # Get warped target noise: sum samples per pixel (slot 0 collects uncovered ones).
        fin_v_val = torch.zeros(B * C, n_tgt + 1).scatter_add_(
            1, index=indices_flat.repeat(B * C, 1), src=src_up_noise_flat)[..., 1:]
        fin_v_num = torch.zeros(1, n_tgt + 1).scatter_add_(
            1, index=indices_flat, src=ones_flat)[..., 1:]
        if fin_v_num.min() == 0:
            raise ValueError(
                f"view ({theta}, {phi}) has target pixels without any source sample")

        # Dividing the sum of k unit-variance samples by sqrt(k) keeps unit variance.
        final_values = fin_v_val / torch.sqrt(fin_v_num)
        tgt_warped_noise = final_values.reshape(B, C, H_tgt, W_tgt).float().cuda()
        res.append(tgt_warped_noise)
        inds.append(indices.reshape(*resolution).to(torch.int64))
        fin_v_num = fin_v_num.reshape(1, 1, H_tgt, W_tgt).cuda()

    return res, inds, src_up_noise, fin_v_num

In [None]:
# Hand-written version (originally per-pixel Python loops; now vectorized).
def denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise):
    """Subtract the per-pixel mean of the perspective residuals from the ERP noise.

    Each ERP pixel collects the residual of every view whose index map covers
    it. The collected residuals are averaged and subtracted. Pixels that no
    view covers are returned unchanged.

    Args:
        residual_pers_noises: list of (B, C, H_tgt, W_tgt) tensors, one per view.
        pers_indices: list of (H_up, W_up) integer tensors (any device). A value
            k > 0 selects flattened perspective pixel k - 1; 0 means "no mapping".
        erp_up_noise: (B, C, H_up, W_up) tensor.

    Returns:
        Tensor shaped like ``erp_up_noise``, on its device and dtype.
    """
    B, C, H_up_src, W_up_src = erp_up_noise.shape
    device, dtype = erp_up_noise.device, erp_up_noise.dtype

    # Accumulators are created ONCE so every view contributes; counts are shared
    # across batch and channel because the pixel mapping does not depend on them.
    residual_sum = torch.zeros(B, C, H_up_src * W_up_src, device=device, dtype=dtype)
    counts = torch.zeros(H_up_src * W_up_src, device=device, dtype=dtype)

    for residual_pers_noise, indices in zip(residual_pers_noises, pers_indices):
        residual_flat = residual_pers_noise.reshape(B, C, -1).to(device=device, dtype=dtype)
        idx = indices.reshape(-1).to(device=device, dtype=torch.int64) - 1
        valid = idx >= 0  # unmapped pixels (-1 after the shift) are skipped
        residual_sum[:, :, valid] += residual_flat[:, :, idx[valid]]
        counts[valid] += 1

    # clamp(min=1): uncovered pixels have a zero sum, so they stay unchanged.
    residual_mean = residual_sum / counts.clamp(min=1)
    return erp_up_noise - residual_mean.view(B, C, H_up_src, W_up_src)

In [None]:
# GPT-assisted version (made device-agnostic and vectorized over batch/channel).
def denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise):
    """Subtract the per-pixel mean of the perspective residuals from the ERP noise.

    Each ERP pixel collects the residual of every view whose index map covers
    it. The collected residuals are averaged and subtracted. Pixels that no
    view covers are returned unchanged.

    Args:
        residual_pers_noises: list of (B, C, H_tgt, W_tgt) tensors, one per view.
        pers_indices: list of (H_up, W_up) integer tensors (any device). Indices
            start at 1; 0 means "no mapping".
        erp_up_noise: (B, C, H_up, W_up) tensor.

    Returns:
        Tensor shaped like ``erp_up_noise``, on its device and dtype.
    """
    B, C, H_up_src, W_up_src = erp_up_noise.shape
    device, dtype = erp_up_noise.device, erp_up_noise.dtype

    # Initialize the accumulation tensors (flattened spatially).
    residual_sum = torch.zeros(B, C, H_up_src * W_up_src, device=device, dtype=dtype)
    counts = torch.zeros(H_up_src * W_up_src, device=device, dtype=dtype)

    for residual_pers_noise, indices in zip(residual_pers_noises, pers_indices):
        residual_flat = residual_pers_noise.reshape(B, C, -1).to(device=device, dtype=dtype)
        # Indices start at 1, so subtract 1 to get 0-based indices; keep only valid ones.
        idx = indices.reshape(-1).to(device=device, dtype=torch.int64) - 1
        valid = idx >= 0
        # Add the matching residuals and counts to each covered ERP pixel (all B, C at once).
        residual_sum[:, :, valid] += residual_flat[:, :, idx[valid]]
        counts[valid] += 1

    # Average, then apply the denoising step.
    residual_mean = residual_sum / counts.clamp(min=1)
    return erp_up_noise - residual_mean.view(B, C, H_up_src, W_up_src)


In [None]:
# def denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise):
#     B, C, H_tgt, W_tgt = residual_pers_noises[0].shape
#     H_up_src, W_up_src = erp_up_noise.shape[-2:]

#     # Ensure erp_up_noise is on the correct device
#     device = erp_up_noise.device

#     # Initialize result tensors
#     residual_erp_noise = torch.zeros(B, C, H_up_src, W_up_src, device=device)
#     residual_erp_counts = torch.zeros(B, C, H_up_src, W_up_src, device=device)

#     # Loop over residual_pers_noises and indices
#     for residual_pers_noise, indices in zip(residual_pers_noises, pers_indices):
#         residual_pers_noise_flat = residual_pers_noise.reshape(B, C, -1)  # (B, C, H_tgt * W_tgt)

#         # Adjust indices (-1 for no mapping)
#         indices = indices.to(device) - 1  # Move indices to the correct device and adjust

#         # Create a mask to filter out invalid indices (-1)
#         valid_mask = indices >= 0

#         # Ensure indices are valid for `scatter_add_` and `gather`
#         indices_valid = torch.where(valid_mask, indices, torch.zeros_like(indices))

#         # Expand indices and mask for batch and channel dimensions
#         indices_expanded = indices_valid.unsqueeze(0).unsqueeze(0).expand(B, C, H_up_src, W_up_src)
#         valid_mask_expanded = valid_mask.unsqueeze(0).unsqueeze(0).expand(B, C, H_up_src, W_up_src)

#         # Map ERP noise using valid indices
#         gathered_residuals = torch.zeros_like(residual_erp_noise)
#         for b in range(B):
#             for c in range(C):
#                 gathered_residuals[b, c] = residual_pers_noise_flat[b, c, indices_valid.to(torch.int64)]

#         # Accumulate noise and counts
#         residual_erp_noise += gathered_residuals * valid_mask_expanded
#         residual_erp_counts += valid_mask_expanded.float()

#     # Avoid division by zero
#     residual_erp_counts = torch.where(residual_erp_counts == 0, torch.ones_like(residual_erp_counts), residual_erp_counts)

#     # Compute averaged noise and denoise the input
#     residual_erp_noise = residual_erp_noise / residual_erp_counts
#     erp_up_noise_denoised = erp_up_noise - residual_erp_noise

#     return erp_up_noise_denoised

In [None]:
# def denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise):
#     B, C, H_tgt, W_tgt = residual_pers_noises[0].shape
#     H_up_src, W_up_src = erp_up_noise.shape[-2:]

#     # Initialize result tensors
#     residual_erp_noise = torch.zeros(B, C, H_up_src, W_up_src, device=erp_up_noise.device)
#     residual_erp_counts = torch.zeros(B, C, H_up_src, W_up_src, device=erp_up_noise.device)

#     # Loop over residual_pers_noises and indices
#     for residual_pers_noise, indices in zip(residual_pers_noises, pers_indices):
#         residual_pers_noise_flat = residual_pers_noise.reshape(B, C, -1)  # (B, C, H_tgt * W_tgt)

#         # Adjust indices (-1 for no mapping)
#         indices = indices - 1  # indices now range from -1 to (H_tgt * W_tgt - 1)

#         # Create a mask to filter out invalid indices (-1)
#         valid_mask = indices >= 0

#         # Ensure indices are valid for `scatter_add_` and `gather`
#         indices_valid = torch.where(valid_mask, indices, torch.zeros_like(indices))

#         # Expand indices and mask for batch and channel dimensions
#         indices_expanded = indices_valid.unsqueeze(0).unsqueeze(0).expand(B, C, H_up_src, W_up_src)
#         valid_mask_expanded = valid_mask.unsqueeze(0).unsqueeze(0).expand(B, C, H_up_src, W_up_src)

#         # Map ERP noise using valid indices
#         gathered_residuals = torch.zeros_like(residual_erp_noise, device=erp_up_noise.device)
#         for b in range(B):
#             for c in range(C):
#                 gathered_residuals[b, c] = residual_pers_noise_flat[b, c, indices_valid]

#         # Accumulate noise and counts
#         residual_erp_noise += gathered_residuals * valid_mask_expanded
#         residual_erp_counts += valid_mask_expanded.float()

#     # Avoid division by zero
#     residual_erp_counts = torch.where(residual_erp_counts == 0, torch.ones_like(residual_erp_counts), residual_erp_counts)

#     # Compute averaged noise and denoise the input
#     residual_erp_noise = residual_erp_noise / residual_erp_counts
#     erp_up_noise_denoised = erp_up_noise - residual_erp_noise

#     return erp_up_noise_denoised

In [None]:
def denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise):
    """Subtract the averaged perspective-view residuals from the upsampled ERP noise.

    For every upsampled-ERP pixel that maps onto a perspective view, the view's
    residual at the mapped pixel is gathered; pixels covered by several views
    receive the mean of those residuals, and unmapped pixels are left unchanged.

    Args:
        residual_pers_noises: sequence of per-view residual tensors, each
            reshapeable to ``(B, C, H_tgt * W_tgt)`` where B, C come from
            ``erp_up_noise``.
        pers_indices: sequence (parallel to ``residual_pers_noises``) of integer
            index maps with ``H_up * W_up`` elements; a value ``k > 0`` means
            "flat perspective pixel ``k - 1``" and ``0`` means "no mapping".
        erp_up_noise: ``(B, C, H_up, W_up)`` upsampled ERP noise.

    Returns:
        Tensor with the shape of ``erp_up_noise``: ``erp_up_noise`` minus the
        per-pixel mean residual.
    """
    B, C, H_up_src, W_up_src = erp_up_noise.shape
    device = erp_up_noise.device

    # Accumulated residual and number of contributing views per ERP pixel.
    residual_erp_noise = torch.zeros(B, C, H_up_src, W_up_src, device=device)
    residual_erp_counts = torch.zeros(B, C, H_up_src, W_up_src, device=device)

    for residual_pers_noise, indices in zip(residual_pers_noises, pers_indices):
        residual_pers_noise_flat = residual_pers_noise.reshape(B, C, -1).to(device)

        # One (H_up, W_up) map shared by every batch/channel; shift to 0-based so
        # that "no mapping" (stored as 0) becomes -1.
        indices_adj = indices.to(device).reshape(H_up_src, W_up_src).long() - 1
        valid_mask = indices_adj >= 0

        # Gather the residual for every mapped ERP pixel -> (B, C, num_valid).
        # Boolean-mask order is row-major on both sides, so the slots line up.
        curr_noise = residual_pers_noise_flat[:, :, indices_adj[valid_mask]]

        # A 2-D boolean mask over the trailing dims addresses the same
        # (B, C, num_valid) slots in the ERP accumulators.
        residual_erp_noise[:, :, valid_mask] += curr_noise
        residual_erp_counts[:, :, valid_mask] += 1

    # Pixels no view covers keep count 0; dividing by 1 leaves them untouched.
    residual_erp_counts = torch.clamp(residual_erp_counts, min=1)
    residual_erp_noise /= residual_erp_counts

    erp_up_noise_denoised = erp_up_noise - residual_erp_noise
    return erp_up_noise_denoised

In [None]:
# def denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise):
#     B, C, H_tgt, W_tgt = residual_pers_noises[0].shape
#     H_up_src, W_up_src = erp_up_noise.shape[-2:]

#     # Initialize result tensors
#     residual_erp_noise = torch.zeros_like(erp_up_noise, device=erp_up_noise.device)
#     residual_erp_counts = torch.zeros_like(erp_up_noise, device=erp_up_noise.device)

#     # Loop over residual_pers_noises and indices
#     for residual_pers_noise, indices in zip(residual_pers_noises, pers_indices):
#         # Flatten residual noise and indices
#         residual_pers_noise_flat = residual_pers_noise.reshape(B, C, -1)  # (B, C, H_tgt * W_tgt)
#         indices = indices.view(-1)  # Flatten indices to (H_tgt * W_tgt)

#         # Create a mask for valid indices
#         valid_mask = (indices >= 0)

#         # Filter valid indices and corresponding residuals
#         valid_indices = indices[valid_mask]  # Only valid indices
#         valid_residuals = residual_pers_noise_flat[:, :, valid_indices]  # (B, C, valid_count)

#         # Scatter add residuals to the ERP noise grid
#         residual_erp_noise.scatter_add_(-1, valid_indices.expand_as(valid_residuals), valid_residuals)

#         # Scatter add counts
#         residual_erp_counts.scatter_add_(-1, valid_indices.expand_as(valid_residuals), torch.ones_like(valid_residuals))

#     # Avoid division by zero
#     residual_erp_counts = torch.where(residual_erp_counts == 0, torch.ones_like(residual_erp_counts), residual_erp_counts)

#     # Compute averaged noise and denoise the input
#     residual_erp_noise = residual_erp_noise / residual_erp_counts
#     erp_up_noise_denoised = erp_up_noise - residual_erp_noise

#     return erp_up_noise_denoised

In [None]:
def denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise):
    """Subtract the averaged perspective-view residuals from the upsampled ERP noise.

    Each upsampled-ERP pixel that maps onto a perspective view takes that view's
    residual at the mapped pixel (a gather, not a scatter). Pixels covered by
    several views get the mean of those residuals; unmapped pixels are unchanged.

    Args:
        residual_pers_noises: sequence of per-view residual tensors, each
            reshapeable to ``(B, C, H_tgt * W_tgt)`` where B, C come from
            ``erp_up_noise``.
        pers_indices: sequence (parallel to ``residual_pers_noises``) of integer
            index maps with ``H_up * W_up`` elements; a value ``k > 0`` means
            "flat perspective pixel ``k - 1``" and ``0`` means "no mapping".
        erp_up_noise: ``(B, C, H_up, W_up)`` upsampled ERP noise.

    Returns:
        Tensor with the shape of ``erp_up_noise``: ``erp_up_noise`` minus the
        per-pixel mean residual.
    """
    B, C, H_up_src, W_up_src = erp_up_noise.shape
    device = erp_up_noise.device

    residual_erp_noise = torch.zeros(B, C, H_up_src, W_up_src, device=device)
    # The mapping is shared by all batch/channel entries, so one (H, W) count map suffices.
    residual_erp_counts = torch.zeros(H_up_src, W_up_src, device=device)

    for residual_pers_noise, indices in zip(residual_pers_noises, pers_indices):
        residual_pers_noise_flat = residual_pers_noise.reshape(B, C, -1).to(device)  # (B, C, H_tgt*W_tgt)

        # Flatten the ERP->pers map and shift to 0-based; -1 now marks "no mapping".
        flat_indices = indices.to(device).reshape(-1).long() - 1
        valid_mask = flat_indices >= 0

        # Clamp invalid entries to 0 so the gather stays in bounds, then zero them out.
        gathered = residual_pers_noise_flat[:, :, flat_indices.clamp(min=0)]  # (B, C, H_up*W_up)
        gathered = gathered.masked_fill(~valid_mask, 0)

        residual_erp_noise += gathered.reshape(B, C, H_up_src, W_up_src)
        residual_erp_counts += valid_mask.reshape(H_up_src, W_up_src).to(residual_erp_counts.dtype)

    # Pixels no view covers keep count 0; dividing by 1 leaves them untouched.
    residual_erp_noise = residual_erp_noise / residual_erp_counts.clamp(min=1)

    erp_up_noise_denoised = erp_up_noise - residual_erp_noise
    return erp_up_noise_denoised

In [None]:
class ERPMultiDiffusion_v3_2(MultiDiffusion):
    """MultiDiffusion variant that denoises a shared, upsampled ERP noise field.

    Each timestep runs one guided UNet + scheduler step per perspective view
    (directions in ``self.tgt_cfg["view_dirs"]``), folds the per-view residuals
    back into the upsampled ERP noise with ``denoise_erp_up_noise``, and
    re-derives the view latents with ``get_pers_view_noises`` (both defined
    elsewhere in this notebook).
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Upsampling level passed to get_pers_view_noises for the initial projection.
        self.up_level = 3
        self.tgt_cfg = {
            # "view_dirs": [
            #     (0,0), (0,30), (0,-30), (0,60), (0,-60), (0,90), (0,-90), (0,120), (0,-120),
            #     (22.5, 0), (22.5, 30), (22.5, -30), (22.5, 60), (22.5, -60), (22.5, 90), (22.5, -90), (22.5, 90), (22.5, 120), (22.5, -120),
            #     (45.0, 0), (45.0, 30), (45.0, -30), (45.0, 60), (45.0, -60), (45.0, 90), (45.0, -90), (45.0, 90), (45.0, 120), (45.0, -120),
            # ]
            # "view_dirs": [
            #     (0.0, -45.0), (30.0, -45.0), (60.0, -45.0), (90.0, -45.0), (-30.0, -45.0), (-60.0, -45.0), (-90.0, -45.0),
            #     (0.0, -22.5), (30.0, -22.5), (60.0, -22.5), (90.0, -22.5), (-30.0, -22.5), (-60.0, -22.5), (-90.0, -22.5),
            #     (0.0, 0.0), (30.0, 0.0), (60.0, 0.0), (90.0, 0.0), (-30.0, 0.0), (-60.0, 0.0), (-90.0, 0.0),
            #     (0.0, 22.5), (30.0, 22.5), (60.0, 22.5), (90.0, 22.5), (-30.0, 22.5), (-60.0, 22.5), (-90.0, 22.5),
            #     (0.0, 45.0), (30.0, 45.0), (60.0, 45.0), (90.0, 45.0), (-30.0, 45.0), (-60.0, 45.0), (-90.0, 45.0),
            # ]
            # "view_dirs": [
            #     (0.0, -22.5), (30.0, -22.5), (60.0, -22.5), (-30.0, -22.5), (-60.0, -22.5),
            #     (0.0, 0.0), (30.0, 0.0), (60.0, 0.0), (-30.0, 0.0), (-60.0, 0.0),
            #     (0.0, 22.5), (30.0, 22.5), (60.0, 22.5), (-30.0, 22.5), (-60.0, 22.5),
            # ]
            "view_dirs": [
                (0.0, -22.5), (30.0, -22.5), (-30.0, -22.5),
                (0.0, 0.0), (30.0, 0.0), (-30.0, 0.0),
                (0.0, 22.5), (30.0, 22.5), (-30.0, 22.5),
            ]
        }

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 visualize_intermidiates=False,
                 save_dir=None):
        """Generate perspective views of an ERP panorama from a text prompt.

        Args:
            prompts: prompt string or list of prompt strings.
            negative_prompts: negative prompt string or list of strings.
            height, width: ERP size in pixels; the latent is ``height // 8`` by
                ``width // 8``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            visualize_intermidiates: if true, decode every view after each step
                and collect the images in the return value.
            save_dir: if given, the decoded views are written after each step to
                ``<save_dir>/<step>/pers_<theta>_<phi>.png`` (directories are
                created as needed).

        Returns:
            List of ``(step, [((theta, phi), image), ...])`` entries when
            ``visualize_intermidiates`` is true; otherwise an empty list.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (unconditional and conditional halves).
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # ERP source noise in latent space.
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        # Always defined so the method can return it even when not visualizing.
        intermidiate_imgs = []
        # Decoding is needed for visualization and for saving images to disk.
        decode_views = visualize_intermidiates or save_dir is not None

        self.tgt_cfg["size"] = (64, 64)
        pers_latents, pers_indices, erp_up_noise, fin_v_num = \
            get_pers_view_noises(latent.to("cpu"), self.up_level, self.tgt_cfg)

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            residual_pers_noises = []

            for latent_view in pers_latents:
                # Duplicate the latent so a single UNet pass yields both the
                # unconditional and the conditional noise prediction.
                latent_model_input = torch.cat([latent_view] * 2)
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                # Classifier-free guidance.
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']

                # Residual removed by this step, rescaled by sqrt(fin_v_num)
                # before it is folded back into the upsampled ERP noise.
                residual_noise = (latent_view - latents_view_denoised) * torch.sqrt(fin_v_num)
                residual_pers_noises.append(residual_noise)

            erp_up_noise_denoised = denoise_erp_up_noise(residual_pers_noises, pers_indices, erp_up_noise)

            # Re-derive the view latents from the updated ERP noise for the next step.
            pers_latents, _, erp_up_noise, _ = \
                get_pers_view_noises(erp_up_noise_denoised, 1, self.tgt_cfg)

            if not decode_views:
                continue

            pers_img_inps = []
            for k, pers_latent in enumerate(pers_latents):
                pers_img = self.decode_latents(pers_latent)
                pers_img_inps.append((self.tgt_cfg['view_dirs'][k], pers_img))

            if visualize_intermidiates:
                intermidiate_imgs.append((i + 1, pers_img_inps))

            if save_dir is not None:
                # os.path.join keeps relative save_dirs relative (no forced leading '/').
                step_dir = os.path.join(save_dir, f"{i:0>2}")
                os.makedirs(step_dir, exist_ok=True)
                for (theta, phi), im in pers_img_inps:
                    im = ToPILImage()(im[0].cpu())
                    im.save(os.path.join(step_dir, f'pers_{theta}_{phi}.png'))

        return intermidiate_imgs


In [None]:
def seed_everything(seed, deterministic=False):
    """Seed every random number generator the notebook uses.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices).

    Args:
        seed: integer seed shared by all generators.
        deterministic: if true, also force cuDNN to choose deterministic
            kernels (``cudnn.deterministic = True``, ``cudnn.benchmark = False``).
            Defaults to False, which leaves cuDNN settings untouched.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable; covers every visible GPU.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # benchmark=True lets cuDNN autotune to nondeterministic kernels, so it must be off.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    
# Fixed seed for reproducible runs, and the device the pipeline runs on.
seed_everything(2024)
device = torch.device('cuda')

# opt variables: Stable Diffusion version, negative prompt, denoising steps.
sd_version, negative, steps = '2.0', '', 50

In [None]:
prompt = "Realistic cityscape of Florence."

H, W = 1024, 2048
sd = ERPMultiDiffusion_v3_2(device=device, sd_version=sd_version)

dir_name = "hiwyn"

# Outputs go to /content/<dir_name>/<first word of the prompt>/.
# os.makedirs creates both directory levels at once and tolerates existing ones.
# NOTE: `dir` shadows the builtin dir(); the name is kept because later cells may read it.
dir = f'/content/{dir_name}/{prompt.split(" ")[0]}'
os.makedirs(dir, exist_ok=True)

outputs = sd.text2erp(prompt, negative, height=H, width=W, num_inference_steps=steps, visualize_intermidiates=True, save_dir=dir)