In [None]:
# Switch the kernel's working directory to /content.
%cd /content

In [None]:
import os

# Go to /content before touching the checkout: a previous run leaves the
# kernel inside /content/hiwyn, and removing the current working directory
# can make the following shell commands fail.
%cd /content

# Always start from a fresh checkout of the project.
if os.path.exists('/content/hiwyn'):
  %rm -rf /content/hiwyn

# Clone to an explicit destination so it always matches the path removed above,
# independent of where the kernel happened to be.
!git clone https://github.com/ByeongHyunPak/hiwyn.git /content/hiwyn
%cd /content/hiwyn

In [None]:
# Install the ninja build tool. -y auto-confirms, because a notebook cell has
# no interactive stdin to answer apt's prompt.
!sudo apt-get install -y ninja-build
# Print the installed version as a quick sanity check.
!ninja --version

In [None]:
# Work from the project root so this cell does not depend on where a previous
# (possibly interrupted) run left the kernel.
%cd /content/hiwyn

# Always start from a fresh nvdiffrast checkout.
if os.path.exists('/content/hiwyn/nvdiffrast'):
  %rm -rf /content/hiwyn/nvdiffrast

# Explicit destination keeps the clone consistent with the path removed above.
!git clone --recursive https://github.com/NVlabs/nvdiffrast /content/hiwyn/nvdiffrast
%cd /content/hiwyn/nvdiffrast
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install .
%cd /content/hiwyn

In [None]:
import cv2

import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import nvdiffrast.torch as dr

from tqdm import tqdm
from IPython.display import Image
from torchvision.transforms import ToPILImage, ToTensor
from einops import rearrange, reduce, repeat

In [None]:
# Rasterizer backend switch. Per the author's note: on a T4 GPU only the CUDA
# context (False) works, although the OpenGL rasterizer (True) gives much
# better results where it is available.
use_opengl = False

if use_opengl:
    glctx = dr.RasterizeGLContext()
else:
    glctx = dr.RasterizeCudaContext()

In [None]:
from multidiffusion import MultiDiffusion
from utils import cond_noise_sampling, erp2pers_latent_warping, compute_erp_up_noise_pred

In [None]:
class ERPMultiDiffusion_v3_4(MultiDiffusion):
    """MultiDiffusion variant that denoises one shared, upsampled ERP latent.

    At every scheduler step each perspective view predicts noise with the
    UNet, the prediction is mapped onto the upsampled ERP latent, a scheduler
    step is taken there, and the per-view results are averaged wherever views
    overlap. The perspective latents are then re-sampled from the averaged
    ERP latent for the next step.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        """Set up the base pipeline plus the ERP-specific view configuration."""
        super().__init__(device, sd_version, hf_key)
        # Level handed to cond_noise_sampling when building the upsampled ERP noise.
        self.up_level = 3
        # Perspective views as (theta, phi) pairs: five theta values for each of
        # three phi rows (presumably in degrees -- TODO confirm against utils).
        self.views = [
                (0.0, -22.5), (15.0, -22.5), (30.0, -22.5), (-15.0, -22.5), (-30.0, -22.5),
                (0.0,   0.0), (15.0,   0.0), (30.0,   0.0), (-15.0,   0.0), (-30.0,   0.0),
                (0.0,  22.5), (15.0,  22.5), (30.0,  22.5), (-15.0,  22.5), (-30.0,  22.5),
            ]

    @torch.no_grad()
    def latent_to_image_and_save(self, i, pers_latents, save_dir, ret_imgs):
        """Decode the per-view latents, record them and optionally save PNGs.

        Args:
            i: Zero-based step index. Results are labelled ``i + 1``; the
                caller passes ``-1`` for the initial noise, giving label 0.
            pers_latents: Per-view latents, ordered like ``self.views``.
            save_dir: Directory to write images into, or ``None`` to skip
                saving. A sub-directory named after the zero-padded step
                label is created inside it.
            ret_imgs: List that receives the decoded images (mutated in place).

        Returns:
            ``ret_imgs`` with ``(i + 1, [((theta, phi), img), ...])`` appended.
        """
        step = i + 1

        pers_imgs = [
            (self.views[k], self.decode_latents(pers_latent))  # ((theta, phi), img)
            for k, pers_latent in enumerate(pers_latents)
        ]
        ret_imgs.append((step, pers_imgs))

        if save_dir is not None:
            # Build the path once so the directory that is created and the
            # directory that is written to can never disagree (the previous
            # code prepended a stray '/' only when saving).
            step_dir = os.path.join(save_dir, f"{step:0>2}")
            os.makedirs(step_dir, exist_ok=True)
            for (theta, phi), im in pers_imgs:
                # im[0]: first element of the decoded output (assumes
                # decode_latents returns a batched tensor -- TODO confirm).
                pil_im = ToPILImage()(im[0].cpu())
                pil_im.save(os.path.join(step_dir, f"pers_{theta}_{phi}.png"))

        return ret_imgs

    @torch.no_grad()
    def text2erp(self,
                 prompts,
                 negative_prompts='',
                 height=512, width=1024,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 save_dir=None):
        """Run the ERP-consistent denoising loop for a text prompt.

        Args:
            prompts: Prompt string or list of prompt strings.
            negative_prompts: Negative prompt string or list of strings.
            height: ERP height in pixels; the latent uses ``height // 8``.
            width: ERP width in pixels; the latent uses ``width // 8``.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight.
            save_dir: If given, decoded perspective images of every step are
                written below this directory.

        Returns:
            List of ``(step, [((theta, phi), img), ...])`` entries: step 0 is
            the decoded initial noise, followed by one entry per timestep.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds (both negative and positive prompts are
        # embedded; they feed the classifier-free guidance below).
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define ERP source noise
        erp_latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)

        # Conditional white noise sampling
        erp_up_latent = cond_noise_sampling(erp_latent, self.up_level)
        count = torch.zeros_like(erp_up_latent)
        value = torch.zeros_like(erp_up_latent)

        self.scheduler.set_timesteps(num_inference_steps)

        # Size handed to the ERP -> perspective warping for each view latent.
        HW_pers = (64, 64)

        # The method is already wrapped in @torch.no_grad(), so no additional
        # no_grad context is needed around the loop.
        pers_latents, erp2pers_indices, fin_v_num =\
            erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)

        # Record (and optionally save) the decoded initial noise as step 0.
        imgs = []
        imgs = self.latent_to_image_and_save(-1, pers_latents, save_dir, imgs)

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):

            count.zero_()
            value.zero_()

            for pers_latent, erp2pers_ind in zip(pers_latents, erp2pers_indices):

                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                pers_latent_model_input = torch.cat([pers_latent] * 2)

                # predict the noise residual
                pers_noise_pred = self.unet(pers_latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_cond = pers_noise_pred.chunk(2)
                pers_noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # map this view's noise prediction onto the upsampled ERP
                # latent and take the scheduler step there; only the region
                # flagged valid for this view is accumulated.
                erp_up_noise_pred, erp_up_valid_region = compute_erp_up_noise_pred(pers_noise_pred, erp2pers_ind, fin_v_num)
                erp_up_denoised = self.scheduler.step(erp_up_noise_pred, t, erp_up_latent)['prev_sample']

                value += torch.where(erp_up_valid_region, erp_up_denoised, torch.zeros_like(erp_up_denoised))
                count += torch.where(erp_up_valid_region, torch.ones_like(erp_up_denoised), torch.zeros_like(erp_up_denoised))

            # average erp_up_latent on overlap region; clamping avoids a
            # division by zero where no view contributed (value is 0 there,
            # so those positions become 0).
            count = torch.clamp(count, min=1)
            erp_up_latent = value / count

            # update pers_latents from denoised erp_up_latent
            pers_latents, _, _ = erp2pers_latent_warping(erp_up_latent, HW_pers, self.views, glctx)

            imgs = self.latent_to_image_and_save(i, pers_latents, save_dir, imgs)

        return imgs

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook has imported.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not just the current one; it
    # is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # NOTE: cuDNN determinism/benchmark flags are intentionally left at their
    # defaults, as in the original cell.

seed_everything(2024)
device = torch.device('cuda')

# opt variables
sd_version = '2.0'
negative = ''
steps = 50

In [None]:
prompt = "Realistic cityscape of Florence."

# ERP resolution in pixels (height, width).
H, W = 1024, 2048
# The class defined in this notebook is ERPMultiDiffusion_v3_4; the previous
# reference to *_v3_3 raised NameError on a fresh kernel.
sd = ERPMultiDiffusion_v3_4(device=device, sd_version=sd_version)
dir_name = "imgs"

# Results go to /content/<dir_name>/<first word of the prompt>.
save_dir = f'/content/{dir_name}/{prompt.split(" ")[0]}'
# Creates any missing parent directories and is a no-op when the path exists.
os.makedirs(save_dir, exist_ok=True)

# Legacy name kept in case later cells still read it (it shadows the builtin).
dir = save_dir
outputs = sd.text2erp(prompt, negative, height=H, width=W, num_inference_steps=steps, save_dir=save_dir)