In [None]:
# Work from /content so the repository clone below lands at /content/hiwyn.
%cd /content

In [None]:
import os
# Remove any previous checkout so the notebook always runs against a fresh clone.
if os.path.exists('/content/hiwyn'):
  %rm -rf /content/hiwyn
!git clone https://github.com/ByeongHyunPak/hiwyn.git
# Work inside the repo; presumably this is what makes `multidiffusion` / `utils` importable later.
%cd /content/hiwyn

In [None]:
!sudo apt-get install ninja-build
!ninja --version

In [None]:
if os.path.exists('/content/hiwyn/nvdiffrast'):
  %rm -rf /content/hiwyn/nvdiffrast
!git clone --recursive https://github.com/NVlabs/nvdiffrast
%cd /content/hiwyn/nvdiffrast
!pip install .
%cd /content/hiwyn

In [None]:
import cv2

import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import nvdiffrast.torch as dr

from tqdm import tqdm
from IPython.display import Image
from torchvision.transforms import ToPILImage, ToTensor
from einops import rearrange, reduce, repeat

In [None]:
# True selects the OpenGL rasterizer (noted to give much better results);
# on a T4 GPU only the CUDA rasterizer (False) works.
use_opengl = False
if use_opengl:
    glctx = dr.RasterizeGLContext()
else:
    glctx = dr.RasterizeCudaContext()

In [None]:
from multidiffusion import MultiDiffusion
from utils import cond_noise_sampling, erp2pers_latent_warping, compute_erp_up_noise_pred, compute_erp_up_noise_denoised

In [None]:
class ERPMultiDiffusion_v3_5(MultiDiffusion):
    """MultiDiffusion over perspective views of an upsampled ERP (equirectangular) latent.

    At every step each perspective view is denoised independently, warped back onto the
    upsampled ERP grid, and the contributions are averaged wherever views overlap.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Level passed to cond_noise_sampling when upsampling the ERP source noise.
        self.up_level = 3
        # Perspective views as (theta, phi) pairs (presumably yaw/pitch in degrees -- TODO
        # confirm against erp2pers_latent_warping). Presets tried in earlier experiments:
        #   3x5 grid:       theta in {0, +-15, +-30} x phi in {-22.5, 0, 22.5}
        #   vertical sweep: theta = 0, phi in {-40, -30, -35, -30, -25, -20, -15, -10, -5, 0, 5}
        #   single view:    [(0.0, 0.0)]
        self.views = [
            (0.0, 0.0), (-45.0, 0.0),
        ]

    @staticmethod
    def _step_dir(save_dir, i):
        """Return ``<save_dir>/<i+1 as two digits>`` after creating it, or None if ``save_dir`` is None."""
        if save_dir is None:
            return None
        step_dir = os.path.join(save_dir, f"{i+1:0>2}")
        os.makedirs(step_dir, exist_ok=True)
        return step_dir

    @staticmethod
    def _save_gray(step_dir, filename, img_2d):
        """Save a 2-D tensor (values expected in [0, 1]) as a PNG in ``step_dir``; no-op if it is None."""
        if step_dir is not None:
            ToPILImage()(img_2d.float().cpu()).save(os.path.join(step_dir, filename))

    def _guided_noise_pred(self, pers_latent, t, text_embeds, guidance_scale):
        """Classifier-free-guided UNet noise prediction for a single perspective latent."""
        # Duplicate the latent so the unconditional (first half) and conditional (second half)
        # predictions come out of a single forward pass.
        model_input = torch.cat([pers_latent] * 2)
        noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeds)['sample']
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    @torch.no_grad()
    def latent_to_image_and_save(self, i, pers_latents, save_dir, ret_imgs):
        """Decode each perspective latent, record the images and optionally save them.

        Args:
            i: zero-based denoising step (-1 for the initial noise); results are tagged ``i + 1``.
            pers_latents: one latent per entry of ``self.views``, in the same order.
            save_dir: root output directory, or None to skip writing files.
            ret_imgs: list that receives ``(i + 1, [((theta, phi), image), ...])``.

        Returns:
            ``ret_imgs`` (appended to in place).
        """
        pers_imgs = []
        for k, pers_latent in enumerate(pers_latents):
            pers_imgs.append((self.views[k], self.decode_latents(pers_latent)))
        ret_imgs.append((i + 1, pers_imgs))

        step_dir = self._step_dir(save_dir, i)
        if step_dir is not None:
            for (theta, phi), im in pers_imgs:
                ToPILImage()(im[0].cpu()).save(os.path.join(step_dir, f'pers_{theta}_{phi}.png'))

        return ret_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 save_dir=None,
                 verbose=True):
        """Generate perspective views of an ERP panorama from a text prompt.

        Args:
            prompts: prompt string or list of prompts.
            negative_prompts: negative prompt string or list.
            height, width: ERP image size; the source latent is ``height // 8 x width // 8``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            save_dir: if given, per-step images and masks are written to ``<save_dir>/<NN>/``;
                if None nothing is written.
            verbose: print per-view latent statistics at every step.

        Returns:
            List of ``(step, [((theta, phi), image), ...])``; step 0 is the initial noise.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]

        # ERP source noise at latent resolution (1/8 of the image size), then upsampled.
        erp_latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        erp_up_latent = cond_noise_sampling(erp_latent, self.up_level)
        count = torch.zeros_like(erp_up_latent)  # number of views covering each ERP element
        value = torch.zeros_like(erp_up_latent)  # sum of the covering views' denoised latents

        self.scheduler.set_timesteps(num_inference_steps)

        HW_pers = (64, 64)  # perspective latent size
        pers_latents, erp2pers_indices, fin_v_num = \
            erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)

        imgs = self.latent_to_image_and_save(-1, pers_latents, save_dir, [])

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            step_dir = self._step_dir(save_dir, i)
            count.zero_()
            value.zero_()

            for pers_latent, erp2pers_ind, view in zip(pers_latents, erp2pers_indices, self.views):
                pers_noise_pred = self._guided_noise_pred(pers_latent, t, text_embeds, guidance_scale)
                pers_latent_denoised = self.scheduler.step(pers_noise_pred, t, pers_latent)['prev_sample']

                # Warp the denoised perspective latent back onto the ERP grid.
                erp_up_noise_denoised, erp_up_valid_region = compute_erp_up_noise_denoised(
                    pers_latent_denoised, erp2pers_ind, fin_v_num)

                # Accumulate this view's contribution only where it covers the ERP grid.
                value += torch.where(erp_up_valid_region, erp_up_noise_denoised,
                                     torch.zeros_like(erp_up_noise_denoised))
                count += erp_up_valid_region

                theta, phi = view
                valid_2d = erp_up_valid_region[0, 0]
                self._save_gray(step_dir, f'mask_{theta}_{phi}.png', valid_2d)

                if verbose:
                    print()
                    print(f"{i+1} / fin_v_num               : mean {fin_v_num.mean():>8.2f} std {fin_v_num.std():>8.2f}")
                    print(f"{i+1} / pers_latent             : mean {pers_latent.abs().mean():>8.5f} std {pers_latent.abs().std():>8.5f}")
                    print(f"{i+1} / pers_noise_pred         : mean {pers_noise_pred.abs().mean():>8.5f} std {pers_noise_pred.abs().std():>8.5f}")
                    print(f"{i+1} / pers_latent_denoised    : mean {pers_latent_denoised.abs().mean():>8.5f} std {pers_latent_denoised.abs().std():>8.5f}")
                    print(f"{i+1} / erp_up_latent           : mean {erp_up_latent[:,:,valid_2d].abs().mean():>8.5f} std {erp_up_latent[:,:,valid_2d].abs().std():>8.5f}")
                    print(f"{i+1} / erp_up_noise_denoised   : mean {erp_up_noise_denoised[:,:,valid_2d].abs().mean():>8.5f} std {erp_up_noise_denoised[:,:,valid_2d].abs().std():>8.5f}")

            # Average overlapping contributions; uncovered elements (count == 0) become 0.
            erp_up_latent = value / torch.clamp(count, min=1)

            # Coverage map normalised to [0, 1]; the clamp avoids 0/0 when nothing is covered.
            self._save_gray(step_dir, 'count.png', count[0, 0] / count.max().clamp(min=1))

            # Re-project the updated ERP latent into every perspective view for the next step.
            pers_latents, _, _ = erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)
            imgs = self.latent_to_image_and_save(i, pers_latents, save_dir, imgs)

        return imgs

In [None]:
class ERPMultiDiffusion_v3_6(MultiDiffusion):
    """MultiDiffusion over perspective views of an upsampled ERP (equirectangular) latent.

    At every step each perspective view is denoised independently and warped back onto the
    upsampled ERP grid. In overlap regions, each ERP pixel takes its value from one covering
    view chosen uniformly at random (instead of averaging).
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Level passed to cond_noise_sampling when upsampling the ERP source noise.
        self.up_level = 3
        # Perspective views as (theta, phi) pairs (presumably yaw/pitch in degrees -- TODO
        # confirm against erp2pers_latent_warping). Presets tried in earlier experiments:
        #   3x5 grid:       theta in {0, +-15, +-30} x phi in {-22.5, 0, 22.5}
        #   vertical sweep: theta = 0, phi in {-40, -30, -35, -30, -25, -20, -15, -10, -5, 0, 5}
        #   single view:    [(0.0, 0.0)]
        self.views = [
            (0.0, 0.0), (-45.0, 0.0),
        ]

    @staticmethod
    def _step_dir(save_dir, i):
        """Return ``<save_dir>/<i+1 as two digits>`` after creating it, or None if ``save_dir`` is None."""
        if save_dir is None:
            return None
        step_dir = os.path.join(save_dir, f"{i+1:0>2}")
        os.makedirs(step_dir, exist_ok=True)
        return step_dir

    @staticmethod
    def _save_gray(step_dir, filename, img_2d):
        """Save a 2-D tensor (values expected in [0, 1]) as a PNG in ``step_dir``; no-op if it is None."""
        if step_dir is not None:
            ToPILImage()(img_2d.float().cpu()).save(os.path.join(step_dir, filename))

    @staticmethod
    def _random_view_selection(values, masks):
        """Per ERP pixel, take the value of one covering view chosen uniformly at random.

        Args:
            values: one tensor per view on the ERP grid (assumed [1, C, H, W]), zero outside
                that view's valid region.
            masks: one valid-region mask per view, broadcastable to the matching value
                (assumed [1, 1 or C, H, W]).

        Returns:
            ``(selected, view_idx)``: ``selected`` has the shape of one value (pixels covered by
            no view are 0); ``view_idx`` holds the chosen view index per pixel, shaped like one mask.
        """
        value_stack = torch.stack(values, dim=0)       # [V, 1, C, H, W]
        mask_stack = torch.stack(masks, dim=0).bool()  # [V, 1, 1|C, H, W]
        num_views = mask_stack.shape[0]
        h, w = mask_stack.shape[-2:]
        # One random score per (view, pixel), shared by all channels so every channel of a
        # pixel comes from the same view. Non-covering views get -1 and can never win.
        score_shape = (num_views,) + (1,) * (mask_stack.dim() - 3) + (h, w)
        scores = torch.rand(score_shape, device=value_stack.device)
        scores = torch.where(mask_stack, scores, torch.full_like(scores, -1.0))
        view_idx = scores.argmax(dim=0, keepdim=True)  # [1, 1, 1|C, H, W]
        gather_idx = view_idx.expand(1, *value_stack.shape[1:])
        selected = torch.gather(value_stack, 0, gather_idx)[0]
        return selected, view_idx[0]

    def _guided_noise_pred(self, pers_latent, t, text_embeds, guidance_scale):
        """Classifier-free-guided UNet noise prediction for a single perspective latent."""
        # Duplicate the latent so the unconditional (first half) and conditional (second half)
        # predictions come out of a single forward pass.
        model_input = torch.cat([pers_latent] * 2)
        noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeds)['sample']
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    @torch.no_grad()
    def latent_to_image_and_save(self, i, pers_latents, save_dir, ret_imgs):
        """Decode each perspective latent, record the images and optionally save them.

        Args:
            i: zero-based denoising step (-1 for the initial noise); results are tagged ``i + 1``.
            pers_latents: one latent per entry of ``self.views``, in the same order.
            save_dir: root output directory, or None to skip writing files.
            ret_imgs: list that receives ``(i + 1, [((theta, phi), image), ...])``.

        Returns:
            ``ret_imgs`` (appended to in place).
        """
        pers_imgs = []
        for k, pers_latent in enumerate(pers_latents):
            pers_imgs.append((self.views[k], self.decode_latents(pers_latent)))
        ret_imgs.append((i + 1, pers_imgs))

        step_dir = self._step_dir(save_dir, i)
        if step_dir is not None:
            for (theta, phi), im in pers_imgs:
                ToPILImage()(im[0].cpu()).save(os.path.join(step_dir, f'pers_{theta}_{phi}.png'))

        return ret_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 save_dir=None):
        """Generate perspective views of an ERP panorama from a text prompt.

        Args:
            prompts: prompt string or list of prompts.
            negative_prompts: negative prompt string or list.
            height, width: ERP image size; the source latent is ``height // 8 x width // 8``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            save_dir: if given, per-step images, masks, coverage and selection maps are written
                to ``<save_dir>/<NN>/``; if None nothing is written.

        Returns:
            List of ``(step, [((theta, phi), image), ...])``; step 0 is the initial noise.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]

        # ERP source noise at latent resolution (1/8 of the image size), then upsampled.
        erp_latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        erp_up_latent = cond_noise_sampling(erp_latent, self.up_level)
        count = torch.zeros_like(erp_up_latent)  # number of views covering each ERP element

        self.scheduler.set_timesteps(num_inference_steps)

        HW_pers = (64, 64)  # perspective latent size
        pers_latents, erp2pers_indices, fin_v_num = \
            erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)

        imgs = self.latent_to_image_and_save(-1, pers_latents, save_dir, [])

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            step_dir = self._step_dir(save_dir, i)
            count.zero_()
            values, masks = [], []  # per-view ERP latents and their valid regions

            for pers_latent, erp2pers_ind, view in zip(pers_latents, erp2pers_indices, self.views):
                pers_noise_pred = self._guided_noise_pred(pers_latent, t, text_embeds, guidance_scale)
                pers_latent_denoised = self.scheduler.step(pers_noise_pred, t, pers_latent)['prev_sample']

                # Warp the denoised perspective latent back onto the ERP grid.
                erp_up_noise_denoised, erp_up_valid_region = compute_erp_up_noise_denoised(
                    pers_latent_denoised, erp2pers_ind, fin_v_num)

                values.append(torch.where(erp_up_valid_region, erp_up_noise_denoised,
                                          torch.zeros_like(erp_up_noise_denoised)))
                masks.append(erp_up_valid_region)
                count += erp_up_valid_region

                theta, phi = view
                self._save_gray(step_dir, f'mask_{theta}_{phi}.png', erp_up_valid_region[0, 0])

            # Resolve overlaps by a uniform random choice among the views covering each pixel.
            erp_up_latent, view_idx = self._random_view_selection(values, masks)

            # Chosen view index normalised to [0, 1] for visualisation.
            self._save_gray(step_dir, 'random_ind.png',
                            view_idx[0, 0].float() / max(len(self.views) - 1, 1))
            # Coverage map normalised to [0, 1]; the clamp avoids 0/0 when nothing is covered.
            self._save_gray(step_dir, 'count.png', count[0, 0] / count.max().clamp(min=1))

            # Re-project the updated ERP latent into every perspective view for the next step.
            pers_latents, _, _ = erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)
            imgs = self.latent_to_image_and_save(i, pers_latents, save_dir, imgs)

        return imgs


In [None]:
class ERPMultiDiffusion_v3_6_2(MultiDiffusion):
    """MultiDiffusion over perspective views of an upsampled ERP (equirectangular) latent.

    Variant of v3_6: the guided noise prediction of each view (rather than its denoised
    latent) is warped onto the ERP grid and the scheduler step is taken there. Overlaps are
    resolved by picking one covering view per ERP pixel uniformly at random.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Level passed to cond_noise_sampling when upsampling the ERP source noise.
        self.up_level = 3
        # Perspective views as (theta, phi) pairs (presumably yaw/pitch in degrees -- TODO
        # confirm against erp2pers_latent_warping). Presets tried in earlier experiments:
        #   3x5 grid:       theta in {0, +-15, +-30} x phi in {-22.5, 0, 22.5}
        #   vertical sweep: theta = 0, phi in {-40, -30, -35, -30, -25, -20, -15, -10, -5, 0, 5}
        #   single view:    [(0.0, 0.0)]
        self.views = [
            (0.0, 0.0), (-45.0, 0.0),
        ]

    @staticmethod
    def _step_dir(save_dir, i):
        """Return ``<save_dir>/<i+1 as two digits>`` after creating it, or None if ``save_dir`` is None."""
        if save_dir is None:
            return None
        step_dir = os.path.join(save_dir, f"{i+1:0>2}")
        os.makedirs(step_dir, exist_ok=True)
        return step_dir

    @staticmethod
    def _save_gray(step_dir, filename, img_2d):
        """Save a 2-D tensor (values expected in [0, 1]) as a PNG in ``step_dir``; no-op if it is None."""
        if step_dir is not None:
            ToPILImage()(img_2d.float().cpu()).save(os.path.join(step_dir, filename))

    @staticmethod
    def _random_view_selection(values, masks):
        """Per ERP pixel, take the value of one covering view chosen uniformly at random.

        Args:
            values: one tensor per view on the ERP grid (assumed [1, C, H, W]), zero outside
                that view's valid region.
            masks: one valid-region mask per view, broadcastable to the matching value
                (assumed [1, 1 or C, H, W]).

        Returns:
            ``(selected, view_idx)``: ``selected`` has the shape of one value (pixels covered by
            no view are 0); ``view_idx`` holds the chosen view index per pixel, shaped like one mask.
        """
        value_stack = torch.stack(values, dim=0)       # [V, 1, C, H, W]
        mask_stack = torch.stack(masks, dim=0).bool()  # [V, 1, 1|C, H, W]
        num_views = mask_stack.shape[0]
        h, w = mask_stack.shape[-2:]
        # One random score per (view, pixel), shared by all channels so every channel of a
        # pixel comes from the same view. Non-covering views get -1 and can never win.
        score_shape = (num_views,) + (1,) * (mask_stack.dim() - 3) + (h, w)
        scores = torch.rand(score_shape, device=value_stack.device)
        scores = torch.where(mask_stack, scores, torch.full_like(scores, -1.0))
        view_idx = scores.argmax(dim=0, keepdim=True)  # [1, 1, 1|C, H, W]
        gather_idx = view_idx.expand(1, *value_stack.shape[1:])
        selected = torch.gather(value_stack, 0, gather_idx)[0]
        return selected, view_idx[0]

    def _guided_noise_pred(self, pers_latent, t, text_embeds, guidance_scale):
        """Classifier-free-guided UNet noise prediction for a single perspective latent."""
        # Duplicate the latent so the unconditional (first half) and conditional (second half)
        # predictions come out of a single forward pass.
        model_input = torch.cat([pers_latent] * 2)
        noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeds)['sample']
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    @torch.no_grad()
    def latent_to_image_and_save(self, i, pers_latents, save_dir, ret_imgs):
        """Decode each perspective latent, record the images and optionally save them.

        Args:
            i: zero-based denoising step (-1 for the initial noise); results are tagged ``i + 1``.
            pers_latents: one latent per entry of ``self.views``, in the same order.
            save_dir: root output directory, or None to skip writing files.
            ret_imgs: list that receives ``(i + 1, [((theta, phi), image), ...])``.

        Returns:
            ``ret_imgs`` (appended to in place).
        """
        pers_imgs = []
        for k, pers_latent in enumerate(pers_latents):
            pers_imgs.append((self.views[k], self.decode_latents(pers_latent)))
        ret_imgs.append((i + 1, pers_imgs))

        step_dir = self._step_dir(save_dir, i)
        if step_dir is not None:
            for (theta, phi), im in pers_imgs:
                ToPILImage()(im[0].cpu()).save(os.path.join(step_dir, f'pers_{theta}_{phi}.png'))

        return ret_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 save_dir=None):
        """Generate perspective views of an ERP panorama from a text prompt.

        Args:
            prompts: prompt string or list of prompts.
            negative_prompts: negative prompt string or list.
            height, width: ERP image size; the source latent is ``height // 8 x width // 8``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            save_dir: if given, per-step images, masks and coverage maps are written to
                ``<save_dir>/<NN>/``; if None nothing is written.

        Returns:
            List of ``(step, [((theta, phi), image), ...])``; step 0 is the initial noise.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]

        # ERP source noise at latent resolution (1/8 of the image size), then upsampled.
        erp_latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        erp_up_latent = cond_noise_sampling(erp_latent, self.up_level)
        count = torch.zeros_like(erp_up_latent)  # number of views covering each ERP element

        self.scheduler.set_timesteps(num_inference_steps)

        HW_pers = (64, 64)  # perspective latent size
        pers_latents, erp2pers_indices, fin_v_num = \
            erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)

        imgs = self.latent_to_image_and_save(-1, pers_latents, save_dir, [])

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            step_dir = self._step_dir(save_dir, i)
            count.zero_()
            values, masks = [], []  # per-view ERP latents and their valid regions

            for pers_latent, erp2pers_ind, view in zip(pers_latents, erp2pers_indices, self.views):
                pers_noise_pred = self._guided_noise_pred(pers_latent, t, text_embeds, guidance_scale)

                # Warp the guided noise prediction onto the ERP grid and take the scheduler
                # step there, against the current ERP latent.
                erp_up_noise_pred, erp_up_valid_region = compute_erp_up_noise_pred(
                    pers_noise_pred, erp2pers_ind, fin_v_num)
                erp_up_noise_denoised = self.scheduler.step(erp_up_noise_pred, t, erp_up_latent)['prev_sample']

                values.append(torch.where(erp_up_valid_region, erp_up_noise_denoised,
                                          torch.zeros_like(erp_up_noise_denoised)))
                masks.append(erp_up_valid_region)
                count += erp_up_valid_region

                theta, phi = view
                self._save_gray(step_dir, f'mask_{theta}_{phi}.png', erp_up_valid_region[0, 0])

            # Resolve overlaps by a uniform random choice among the views covering each pixel.
            erp_up_latent, _ = self._random_view_selection(values, masks)

            # Coverage map normalised to [0, 1]; the clamp avoids 0/0 when nothing is covered.
            self._save_gray(step_dir, 'count.png', count[0, 0] / count.max().clamp(min=1))

            # Re-project the updated ERP latent into every perspective view for the next step.
            pers_latents, _, _ = erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)
            imgs = self.latent_to_image_and_save(i, pers_latents, save_dir, imgs)

        return imgs


In [None]:
class ERPMultiDiffusion_v3_7(MultiDiffusion):
    """MultiDiffusion over perspective views of an upsampled ERP (equirectangular) latent.

    The guided noise prediction of each view is warped onto the ERP grid and the scheduler
    step is taken there. Overlaps are resolved per latent element (each channel
    independently) by a uniform random choice among the covering views.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Level passed to cond_noise_sampling when upsampling the ERP source noise.
        self.up_level = 3
        # Perspective views as (theta, phi) pairs (presumably yaw/pitch in degrees -- TODO
        # confirm against erp2pers_latent_warping). Presets tried in earlier experiments:
        #   3x5 grid:       theta in {0, +-15, +-30} x phi in {-22.5, 0, 22.5}
        #   vertical sweep: theta = 0, phi in {-40, -30, -35, -30, -25, -20, -15, -10, -5, 0, 5}
        #   single view:    [(0.0, 0.0)]
        self.views = [
            (0.0, 0.0), (-45.0, 0.0),
        ]

    @staticmethod
    def _step_dir(save_dir, i):
        """Return ``<save_dir>/<i+1 as two digits>`` after creating it, or None if ``save_dir`` is None."""
        if save_dir is None:
            return None
        step_dir = os.path.join(save_dir, f"{i+1:0>2}")
        os.makedirs(step_dir, exist_ok=True)
        return step_dir

    @staticmethod
    def _save_gray(step_dir, filename, img_2d):
        """Save a 2-D tensor (values expected in [0, 1]) as a PNG in ``step_dir``; no-op if it is None."""
        if step_dir is not None:
            ToPILImage()(img_2d.float().cpu()).save(os.path.join(step_dir, filename))

    @staticmethod
    def _random_overlap_selection(values, masks):
        """Per ERP element, take the value of one covering view chosen uniformly at random.

        Elements covered by exactly one view keep that view's value; elements covered by no
        view are 0. The choice is independent for every element (channels included), and is
        correct for any number of views.

        Args:
            values: one tensor per view on the ERP grid (assumed [1, C, H, W]), zero outside
                that view's valid region.
            masks: one valid-region mask per view, broadcastable to the matching value.

        Returns:
            Tensor with the shape of one value.
        """
        value_stack = torch.stack(values, dim=0)                                # [V, 1, C, H, W]
        mask_stack = torch.stack(masks, dim=0).bool().expand_as(value_stack)    # [V, 1, C, H, W]
        # Random score per (view, element); non-covering views get -1 and can never win.
        scores = torch.rand(value_stack.shape, device=value_stack.device)
        scores = torch.where(mask_stack, scores, torch.full_like(scores, -1.0))
        view_idx = scores.argmax(dim=0, keepdim=True)                            # [1, 1, C, H, W]
        return torch.gather(value_stack, 0, view_idx)[0]

    def _guided_noise_pred(self, pers_latent, t, text_embeds, guidance_scale):
        """Classifier-free-guided UNet noise prediction for a single perspective latent."""
        # Duplicate the latent so the unconditional (first half) and conditional (second half)
        # predictions come out of a single forward pass.
        model_input = torch.cat([pers_latent] * 2)
        noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeds)['sample']
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    @torch.no_grad()
    def latent_to_image_and_save(self, i, pers_latents, save_dir, ret_imgs):
        """Decode each perspective latent, record the images and optionally save them.

        Args:
            i: zero-based denoising step (-1 for the initial noise); results are tagged ``i + 1``.
            pers_latents: one latent per entry of ``self.views``, in the same order.
            save_dir: root output directory, or None to skip writing files.
            ret_imgs: list that receives ``(i + 1, [((theta, phi), image), ...])``.

        Returns:
            ``ret_imgs`` (appended to in place).
        """
        pers_imgs = []
        for k, pers_latent in enumerate(pers_latents):
            pers_imgs.append((self.views[k], self.decode_latents(pers_latent)))
        ret_imgs.append((i + 1, pers_imgs))

        step_dir = self._step_dir(save_dir, i)
        if step_dir is not None:
            for (theta, phi), im in pers_imgs:
                ToPILImage()(im[0].cpu()).save(os.path.join(step_dir, f'pers_{theta}_{phi}.png'))

        return ret_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 save_dir=None):
        """Generate perspective views of an ERP panorama from a text prompt.

        Args:
            prompts: prompt string or list of prompts.
            negative_prompts: negative prompt string or list.
            height, width: ERP image size; the source latent is ``height // 8 x width // 8``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            save_dir: if given, per-step images, masks and coverage maps are written to
                ``<save_dir>/<NN>/``; if None nothing is written.

        Returns:
            List of ``(step, [((theta, phi), image), ...])``; step 0 is the initial noise.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]

        # ERP source noise at latent resolution (1/8 of the image size), then upsampled.
        erp_latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        erp_up_latent = cond_noise_sampling(erp_latent, self.up_level)
        count = torch.zeros_like(erp_up_latent)  # number of views covering each ERP element

        self.scheduler.set_timesteps(num_inference_steps)

        HW_pers = (64, 64)  # perspective latent size
        pers_latents, erp2pers_indices, fin_v_num = \
            erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)

        imgs = self.latent_to_image_and_save(-1, pers_latents, save_dir, [])

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            step_dir = self._step_dir(save_dir, i)
            count.zero_()
            values, masks = [], []  # per-view ERP latents and their valid regions

            for pers_latent, erp2pers_ind, view in zip(pers_latents, erp2pers_indices, self.views):
                pers_noise_pred = self._guided_noise_pred(pers_latent, t, text_embeds, guidance_scale)

                # Warp the guided noise prediction onto the ERP grid and take the scheduler
                # step there, against the current ERP latent.
                erp_up_noise_pred, erp_up_valid_region = compute_erp_up_noise_pred(
                    pers_noise_pred, erp2pers_ind, fin_v_num)
                erp_up_noise_denoised = self.scheduler.step(erp_up_noise_pred, t, erp_up_latent)['prev_sample']

                values.append(torch.where(erp_up_valid_region, erp_up_noise_denoised,
                                          torch.zeros_like(erp_up_noise_denoised)))
                masks.append(erp_up_valid_region)
                count += erp_up_valid_region

                theta, phi = view
                self._save_gray(step_dir, f'mask_{theta}_{phi}.png', erp_up_valid_region[0, 0])

            # Non-overlap elements keep their single view; overlaps pick a covering view at random.
            erp_up_latent = self._random_overlap_selection(values, masks)

            # Coverage map normalised to [0, 1]; the clamp avoids 0/0 when nothing is covered.
            self._save_gray(step_dir, 'count.png', count[0, 0] / count.max().clamp(min=1))

            # Re-project the updated ERP latent into every perspective view for the next step.
            pers_latents, _, _ = erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)
            imgs = self.latent_to_image_and_save(i, pers_latents, save_dir, imgs)

        return imgs


In [None]:
class ERPMultiDiffusion_v3_7_2(MultiDiffusion):
    """MultiDiffusion over perspective views of an upsampled ERP (equirectangular) latent.

    Each perspective view is denoised independently and its denoised latent is warped onto
    the ERP grid. Overlaps are resolved per latent element (each channel independently) by a
    uniform random choice among the covering views.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Level passed to cond_noise_sampling when upsampling the ERP source noise.
        self.up_level = 3
        # Perspective views as (theta, phi) pairs (presumably yaw/pitch in degrees -- TODO
        # confirm against erp2pers_latent_warping). Presets tried in earlier experiments:
        #   3x5 grid:       theta in {0, +-15, +-30} x phi in {-22.5, 0, 22.5}
        #   vertical sweep: theta = 0, phi in {-40, -30, -35, -30, -25, -20, -15, -10, -5, 0, 5}
        #   single view:    [(0.0, 0.0)]
        self.views = [
            (0.0, 0.0), (-45.0, 0.0),
        ]

    @staticmethod
    def _step_dir(save_dir, i):
        """Return ``<save_dir>/<i+1 as two digits>`` after creating it, or None if ``save_dir`` is None."""
        if save_dir is None:
            return None
        step_dir = os.path.join(save_dir, f"{i+1:0>2}")
        os.makedirs(step_dir, exist_ok=True)
        return step_dir

    @staticmethod
    def _save_gray(step_dir, filename, img_2d):
        """Save a 2-D tensor (values expected in [0, 1]) as a PNG in ``step_dir``; no-op if it is None."""
        if step_dir is not None:
            ToPILImage()(img_2d.float().cpu()).save(os.path.join(step_dir, filename))

    @staticmethod
    def _random_overlap_selection(values, masks):
        """Per ERP element, take the value of one covering view chosen uniformly at random.

        Elements covered by exactly one view keep that view's value; elements covered by no
        view are 0. The choice is independent for every element (channels included), and is
        correct for any number of views.

        Args:
            values: one tensor per view on the ERP grid (assumed [1, C, H, W]), zero outside
                that view's valid region.
            masks: one valid-region mask per view, broadcastable to the matching value.

        Returns:
            Tensor with the shape of one value.
        """
        value_stack = torch.stack(values, dim=0)                                # [V, 1, C, H, W]
        mask_stack = torch.stack(masks, dim=0).bool().expand_as(value_stack)    # [V, 1, C, H, W]
        # Random score per (view, element); non-covering views get -1 and can never win.
        scores = torch.rand(value_stack.shape, device=value_stack.device)
        scores = torch.where(mask_stack, scores, torch.full_like(scores, -1.0))
        view_idx = scores.argmax(dim=0, keepdim=True)                            # [1, 1, C, H, W]
        return torch.gather(value_stack, 0, view_idx)[0]

    def _guided_noise_pred(self, pers_latent, t, text_embeds, guidance_scale):
        """Classifier-free-guided UNet noise prediction for a single perspective latent."""
        # Duplicate the latent so the unconditional (first half) and conditional (second half)
        # predictions come out of a single forward pass.
        model_input = torch.cat([pers_latent] * 2)
        noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeds)['sample']
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    @torch.no_grad()
    def latent_to_image_and_save(self, i, pers_latents, save_dir, ret_imgs):
        """Decode each perspective latent, record the images and optionally save them.

        Args:
            i: zero-based denoising step (-1 for the initial noise); results are tagged ``i + 1``.
            pers_latents: one latent per entry of ``self.views``, in the same order.
            save_dir: root output directory, or None to skip writing files.
            ret_imgs: list that receives ``(i + 1, [((theta, phi), image), ...])``.

        Returns:
            ``ret_imgs`` (appended to in place).
        """
        pers_imgs = []
        for k, pers_latent in enumerate(pers_latents):
            pers_imgs.append((self.views[k], self.decode_latents(pers_latent)))
        ret_imgs.append((i + 1, pers_imgs))

        step_dir = self._step_dir(save_dir, i)
        if step_dir is not None:
            for (theta, phi), im in pers_imgs:
                ToPILImage()(im[0].cpu()).save(os.path.join(step_dir, f'pers_{theta}_{phi}.png'))

        return ret_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 save_dir=None):
        """Generate perspective views of an ERP panorama from a text prompt.

        Args:
            prompts: prompt string or list of prompts.
            negative_prompts: negative prompt string or list.
            height, width: ERP image size; the source latent is ``height // 8 x width // 8``.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            save_dir: if given, per-step images, masks and coverage maps are written to
                ``<save_dir>/<NN>/``; if None nothing is written.

        Returns:
            List of ``(step, [((theta, phi), image), ...])``; step 0 is the initial noise.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)  # [2, 77, 768]

        # ERP source noise at latent resolution (1/8 of the image size), then upsampled.
        erp_latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        erp_up_latent = cond_noise_sampling(erp_latent, self.up_level)
        count = torch.zeros_like(erp_up_latent)  # number of views covering each ERP element

        self.scheduler.set_timesteps(num_inference_steps)

        HW_pers = (64, 64)  # perspective latent size
        pers_latents, erp2pers_indices, fin_v_num = \
            erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)

        imgs = self.latent_to_image_and_save(-1, pers_latents, save_dir, [])

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            step_dir = self._step_dir(save_dir, i)
            count.zero_()
            values, masks = [], []  # per-view ERP latents and their valid regions

            for pers_latent, erp2pers_ind, view in zip(pers_latents, erp2pers_indices, self.views):
                pers_noise_pred = self._guided_noise_pred(pers_latent, t, text_embeds, guidance_scale)
                pers_latent_denoised = self.scheduler.step(pers_noise_pred, t, pers_latent)['prev_sample']

                # Warp the denoised perspective latent back onto the ERP grid.
                erp_up_noise_denoised, erp_up_valid_region = compute_erp_up_noise_denoised(
                    pers_latent_denoised, erp2pers_ind, fin_v_num)

                values.append(torch.where(erp_up_valid_region, erp_up_noise_denoised,
                                          torch.zeros_like(erp_up_noise_denoised)))
                masks.append(erp_up_valid_region)
                count += erp_up_valid_region

                theta, phi = view
                self._save_gray(step_dir, f'mask_{theta}_{phi}.png', erp_up_valid_region[0, 0])

            # Non-overlap elements keep their single view; overlaps pick a covering view at random.
            erp_up_latent = self._random_overlap_selection(values, masks)

            # Coverage map normalised to [0, 1]; the clamp avoids 0/0 when nothing is covered.
            self._save_gray(step_dir, 'count.png', count[0, 0] / count.max().clamp(min=1))

            # Re-project the updated ERP latent into every perspective view for the next step.
            pers_latents, _, _ = erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)
            imgs = self.latent_to_image_and_save(i, pers_latents, save_dir, imgs)

        return imgs


In [None]:
def seed_everything(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # For bit-exact GPU reproducibility additionally set
    #   torch.backends.cudnn.deterministic = True
    #   torch.backends.cudnn.benchmark = False   # benchmark=True may pick kernels non-deterministically

seed_everything(2024)
device = torch.device('cuda')

# opt variables
sd_version = '2.0'
negative = ''
steps = 50

In [None]:
prompt = "Realistic cityscape of Florence."

H, W = 1024, 2048
sd = ERPMultiDiffusion_v3_5(device=device, sd_version=sd_version)
dir_name = "imgs"

# Outputs go to /content/<dir_name>/<first word of the prompt>/<step>/...
# (the name `dir` shadows the builtin but is kept in case later cells use it)
dir = os.path.join('/content', dir_name, prompt.split(" ")[0])
os.makedirs(dir, exist_ok=True)

outputs = sd.text2erp(prompt, negative, height=H, width=W, num_inference_steps=steps, save_dir=dir)