In [None]:
# Work from /content; the clone in the next cell uses a relative target and lands in the current directory.
%cd /content

In [None]:
import os
# Start from a clean checkout: remove any existing copy before cloning again.
if os.path.exists('/content/hiwyn'):
  %rm -rf /content/hiwyn
# The clone target is relative, so this expects the current directory to be /content.
!git clone https://github.com/ByeongHyunPak/hiwyn.git
# Later cells import `multidiffusion` and `utils` by bare name (presumably modules of this repo),
# so the notebook keeps running from inside the checkout.
%cd /content/hiwyn

In [None]:
!sudo apt-get install ninja-build
!ninja --version

In [None]:
if os.path.exists('/content/hiwyn/nvdiffrast'):
  %rm -rf /content/hiwyn/nvdiffrast
!git clone --recursive https://github.com/NVlabs/nvdiffrast
%cd /content/hiwyn/nvdiffrast
!pip install .
%cd /content/hiwyn

In [None]:
import cv2

import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import nvdiffrast.torch as dr

from tqdm import tqdm
from IPython.display import Image
from torchvision.transforms import ToPILImage, ToTensor
from einops import rearrange, reduce, repeat

In [None]:
# Rasterizer backend switch. Author's note: on a T4 GPU only False works,
# although the rasterizer works much better with True.
use_opengl = False
if use_opengl:
    glctx = dr.RasterizeGLContext()
else:
    glctx = dr.RasterizeCudaContext()

In [None]:
from multidiffusion import MultiDiffusion, get_views
from utils import cond_noise_sampling, identity_latent_warping

In [None]:
class HIWYN_identity_MultiDiffusion(MultiDiffusion):
    """MultiDiffusion variant that denoises two latents side by side.

    One is the regular latent; the other ("hiwyn") is obtained by passing an
    upsampled latent (``cond_noise_sampling``) through ``identity_latent_warping``.
    Intermediate decodes and a per-step coverage map can be written to disk for
    inspection.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)
        # Level handed to cond_noise_sampling when building the upsampled latent.
        self.up_level = 3

    @staticmethod
    def _step_dir(save_dir, i):
        """Create (if missing) and return the folder for step index ``i`` (named ``i+1``, zero-padded to 2)."""
        step_dir = os.path.join(save_dir, f"{i+1:0>2}")
        os.makedirs(step_dir, exist_ok=True)
        return step_dir

    @torch.no_grad()
    def latent_to_image_and_save(self, i, latents, save_dir, ret_imgs):
        """Decode ``latents``, record them in ``ret_imgs`` and optionally save them as PNGs.

        Args:
            i: step index; the entry and the output folder are labelled ``i + 1``.
            latents: iterable of latents, each passed to ``self.decode_latents``.
            save_dir: root output directory, or ``None`` to skip writing files.
            ret_imgs: list that receives ``(i + 1, [(k, image), ...])``; mutated in place.

        Returns:
            The same ``ret_imgs`` list.
        """
        imgs = [(k, self.decode_latents(latent)) for k, latent in enumerate(latents)]
        ret_imgs.append((i + 1, imgs))  # [i+1, [k, img]]

        if save_dir is not None:
            step_dir = self._step_dir(save_dir, i)
            for k, im in imgs:
                # Build the path from save_dir itself; the previous f'/{save_dir}/...'
                # form only worked for absolute directories.
                ToPILImage()(im[0].cpu()).save(os.path.join(step_dir, f'pers_{k}.png'))

        return ret_imgs

    @torch.no_grad()
    def text2img(self,
                 prompts,
                 negative_prompts='',
                 height=512,
                 width=512,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 save_dir=None):
        """Generate an image while tracking the plain and the hiwyn latent.

        Args:
            prompts: prompt string or list of prompt strings.
            negative_prompts: negative prompt string or list of strings.
            height, width: output size in pixels; the latent is 1/8 of it per side.
            num_inference_steps: number of scheduler timesteps.
            guidance_scale: classifier-free guidance weight.
            save_dir: if given, per-step decodes and ``count.png`` are written under
                ``<save_dir>/<step>/``; if ``None`` nothing is written to disk.

        Returns:
            A one-element list holding the PIL image decoded from the final plain latent.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # The index arithmetic below maps latent coordinates to up_latent coordinates with
        # this factor. NOTE(review): presumably equals 2 ** self.up_level - confirm against
        # cond_noise_sampling.
        up_factor = 8

        # Initial noise, its upsampled counterpart, and the warped (hiwyn) latent.
        latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
        up_latent = cond_noise_sampling(latent, self.up_level)
        hiwyn_latent = identity_latent_warping(up_latent, latent.shape[-2:], glctx)[0][0]

        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        hiwyn_count = torch.zeros_like(up_latent)
        hiwyn_value = torch.zeros_like(up_latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):

            # NOTE(review): the initial snapshot is ordered [latent, hiwyn_latent] while the
            # per-step snapshots are [hiwyn_latent, latent], so pers_0/pers_1 swap meaning
            # after step 00. Kept as in the original - confirm which order is intended.
            snapshots = []
            snapshots = self.latent_to_image_and_save(-1, [latent, hiwyn_latent], save_dir, snapshots)

            for i, t in enumerate(tqdm(self.scheduler.timesteps)):

                count.zero_()
                value.zero_()
                hiwyn_count.zero_()
                hiwyn_value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    # TODO we can support batches, and pass multiple views at once to the unet
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end]
                    hiwyn_view = hiwyn_latent[:, :, h_start:h_end, w_start:w_end]

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)
                    hiwyn_latent_model_input = torch.cat([hiwyn_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']
                    hiwyn_noise_pred = self.unet(hiwyn_latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    hiwyn_noise_pred_uncond, hiwyn_noise_pred_cond = hiwyn_noise_pred.chunk(2)
                    hiwyn_noise_pred = hiwyn_noise_pred_uncond + guidance_scale * (hiwyn_noise_pred_cond - hiwyn_noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                    # Region of up_latent covered by this view.
                    up_rows = slice(up_factor * h_start, up_factor * h_end)
                    up_cols = slice(up_factor * w_start, up_factor * w_end)

                    # denoising up_latent with noise_pred: the prediction is upsampled to the
                    # region size (derived from the view instead of a fixed 512x512) and is
                    # zero everywhere outside the region.
                    # NOTE(review): the "/ 8" scaling is kept from the original; its
                    # relation to up_factor is not visible here.
                    up_noise_pred = torch.zeros_like(up_latent)
                    up_noise_pred[:, :, up_rows, up_cols] = F.interpolate(
                        hiwyn_noise_pred,
                        size=(up_factor * (h_end - h_start), up_factor * (w_end - w_start)),
                        mode='nearest') / 8
                    up_view_denoised = self.scheduler.step(up_noise_pred, t, up_latent)['prev_sample']

                    # Accumulate (not overwrite): hiwyn_value is divided by hiwyn_count below,
                    # so overlapping views must be summed, mirroring value/count above.
                    hiwyn_value[:, :, up_rows, up_cols] += up_view_denoised[:, :, up_rows, up_cols]
                    hiwyn_count[:, :, up_rows, up_cols] += 1

                    # Signal-power diagnostics for this view.
                    print()
                    print(f"{i+1} / latent_view             : power {latent_view.pow(2).mean():>8.5f}")
                    print(f"{i+1} / noise_pred              : power {noise_pred.pow(2).mean():>8.5f}")
                    print(f"{i+1} / latents_view_denoised   : power {latents_view_denoised.pow(2).mean():>8.5f}")
                    print(f"{i+1} / up_latent               : power {up_latent.pow(2).mean():>8.5f}")
                    print(f"{i+1} / up_noise_pred           : power {up_noise_pred[:, :, up_rows, up_cols].pow(2).mean():>8.5f}")
                    print(f"{i+1} / up_latent_denoised      : power {up_view_denoised[:, :, up_rows, up_cols].pow(2).mean():>8.5f}")

                # take the MultiDiffusion step: average every position over the views covering it
                latent = torch.where(count > 0, value / count, value)
                up_latent = torch.where(hiwyn_count > 0, hiwyn_value / hiwyn_count, hiwyn_value)
                hiwyn_latent = identity_latent_warping(up_latent, latent.shape[-2:], glctx)[0][0]

                if save_dir is not None:
                    # Coverage map normalised to [0, 1]; kept in its own tensor so the
                    # accumulator is not rebound. clamp avoids 0/0 when no view was visited.
                    coverage = hiwyn_count.float() / hiwyn_count.max().float().clamp(min=1)
                    coverage_img = ToPILImage()(coverage.cpu()[0][0])
                    coverage_img.save(os.path.join(self._step_dir(save_dir, i), 'count.png'))

                snapshots = self.latent_to_image_and_save(i, [hiwyn_latent, latent], save_dir, snapshots)

        # Img latents -> imgs
        decoded = self.decode_latents(latent)
        img = T.ToPILImage()(decoded[0].cpu())

        return [img]

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook imports.

    Covers Python's ``random``, NumPy's legacy global generator, and PyTorch on
    CPU and on all CUDA devices, so that re-running with the same ``seed``
    reproduces the same random draws.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: the call is ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # cuDNN flags were left disabled by the author:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True
    
# Runtime target and reproducibility: everything below runs on the CUDA device
# with the RNGs fixed to seed 2024.
device = torch.device('cuda')
seed_everything(2024)

# opt variables: Stable Diffusion version, negative prompt, number of denoising steps
sd_version, negative, steps = '2.0', '', 50

In [None]:
prompt = "Realistic cityscape of Florence."

# Output size in pixels (height, width).
H, W = 512, 1024
sd = HIWYN_identity_MultiDiffusion(device=device, sd_version=sd_version)
dir_name = "imgs"

# Outputs are written under /content/<dir_name>/<first word of the prompt>.
# makedirs creates the whole chain in one call and is a no-op when it already exists.
# NOTE(review): `dir` shadows the Python builtin; the name is kept in case later cells use it.
dir = f'/content/{dir_name}/{prompt.split(" ")[0]}'
os.makedirs(dir, exist_ok=True)

outputs = sd.text2img(prompt, negative, height=H, width=W, num_inference_steps=steps, save_dir=dir)