<a href="https://colab.research.google.com/github/ByeongHyunPak/omni-proj/blob/main/scratchpad.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the project repository; its inner package directory provides the `utils`
# module imported in a later cell.
# NOTE(review): tensorboardX is not imported by this notebook directly; presumably a
# dependency of the repo's modules — confirm. Versions are unpinned.
!git clone https://github.com/ByeongHyunPak/omni-proj.git
!pip install tensorboardX

In [None]:
import os
# Run from inside the cloned package directory so that `import utils` in the next cell resolves.
os.chdir('/content/omni-proj/omni-proj')

In [None]:
import utils
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

def pil_to_tensor(pil_image):
    """Convert a PIL image into a tensor using torchvision's ``ToTensor``.

    The result is channel-first (``[C, H, W]``), which is the layout the rest of
    this notebook assumes for image tensors.
    """
    return transforms.ToTensor()(pil_image)

def tensor_to_pil(tensor_image):
    """Convert a tensor image back into a PIL image using torchvision's ``ToPILImage``."""
    to_pil = transforms.ToPILImage()
    return to_pil(tensor_image)

def load_images_from_folder(folder):
    """Load every ``.png`` / ``.jpg`` / ``.jpeg`` file in ``folder`` as a PIL image.

    Files are visited in sorted filename order. ``os.listdir`` order is arbitrary,
    and later cells select images positionally (``images[0]``, ``images[3]``), so
    sorting makes that selection reproducible. Each image is decoded and copied
    inside a ``with`` block, so no file handle is left open. Files that fail to
    open are reported and skipped.

    Args:
        folder: directory to scan (non-recursive).

    Returns:
        List of PIL images, in sorted filename order.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img_path = os.path.join(folder, filename)
        if not (os.path.isfile(img_path) and img_path.lower().endswith(('.png', '.jpg', '.jpeg'))):
            continue
        try:
            with Image.open(img_path) as img:
                # copy() forces a full decode, so the returned image stays usable
                # after the context manager closes the underlying file.
                images.append(img.copy())
        except Exception as e:
            # Best-effort loading: report the bad file and keep going.
            print(f"Error loading image {img_path}: {e}")
    return images

imgs_folder = '/content/omni-proj/imgs/erps'
images = load_images_from_folder(imgs_folder)
# Later cells index into `images` positionally; fail here with a clear message
# instead of an IndexError further down.
assert images, f"No ERP images found in {imgs_folder}"

# Sanity check: shape of each loaded image once converted to a tensor.
for pil_img in images:
    print(pil_to_tensor(pil_img).shape)

In [None]:
def erp2per(hr_erp_img, THETA, PHI, FOVy, FOVx, HWy=(512, 512)):
    """Sample a perspective view out of an equirectangular (ERP) image.

    Args:
        hr_erp_img: the ERP image, either a PIL image or a tensor already produced
            by ``pil_to_tensor`` (assumed ``[C, H, W]`` in [0, 1] — TODO confirm
            for other tensor sources).
        THETA, PHI: view direction, forwarded to ``utils.gridy2x_erp2per`` and
            ``utils.gridy2x_per2erp``.
        FOVy, FOVx: field-of-view arguments forwarded to the same ``utils`` helpers.
        HWy: (height, width) of the perspective output.

    Returns:
        inp: the perspective image, ``[C, *HWy]``, bicubic-sampled and clamped to [0, 1].
        valid_hr_erp_img: the ERP image multiplied by ``maskx``.
        maskx: the mask returned by ``utils.gridy2x_per2erp``, reshaped to ``[1, H, W]``.
    """
    # Accept a pre-converted tensor so callers can avoid re-converting the same image.
    if not torch.is_tensor(hr_erp_img):
        hr_erp_img = pil_to_tensor(hr_erp_img)
    HWx = hr_erp_img.shape[-2:]

    # Perspective pixel grid -> sampling coordinates in the ERP image.
    gridy = utils.make_coord(HWy)
    gridy2x, _ = utils.gridy2x_erp2per(
        gridy, HWy, HWx, THETA, PHI, FOVy, FOVx)
    gridy2x = gridy2x.view(*HWy, 2)

    # grid_sample reads (x, y) from the last grid dimension; flip(-1) reverses the
    # utils coordinate order (which therefore appears to be (y, x)).
    inp = F.grid_sample(hr_erp_img.unsqueeze(0),
                        gridy2x.unsqueeze(0).flip(-1),
                        mode='bicubic',
                        padding_mode='reflection',
                        align_corners=False).clamp_(0, 1)[0]

    # ERP pixel grid -> mask of ERP pixels covered by this perspective view.
    gridx = utils.make_coord(HWx, flatten=False)
    _, maskx = utils.gridy2x_per2erp(
        gridx, HWx, HWy, THETA, PHI, FOVx, FOVy)

    maskx = maskx.view(1, *HWx)
    valid_hr_erp_img = hr_erp_img * maskx

    return inp, valid_hr_erp_img, maskx

def erp2fis(hr_erp_img, THETA, PHI, FOVy, FOVx, HWy=(1024, 1024)):
    """Sample a fisheye view out of an equirectangular (ERP) image.

    Args:
        hr_erp_img: the ERP image, either a PIL image or a tensor already produced
            by ``pil_to_tensor`` (assumed ``[C, H, W]`` in [0, 1] — TODO confirm
            for other tensor sources).
        THETA, PHI: view direction, forwarded to ``utils.gridy2x_erp2fis`` and
            ``utils.gridy2x_fis2erp``.
        FOVy, FOVx: field-of-view arguments forwarded to the same ``utils`` helpers.
        HWy: (height, width) of the fisheye output.

    Returns:
        inp: the fisheye image, ``[C, *HWy]``, clamped to [0, 1] and zeroed where
            the ``utils`` mask ``masky`` is 0.
        valid_hr_erp_img: the ERP image multiplied by ``maskx``.
        maskx: the mask returned by ``utils.gridy2x_fis2erp``, reshaped to ``[1, H, W]``.
    """
    # Accept a pre-converted tensor so callers can avoid re-converting the same image.
    if not torch.is_tensor(hr_erp_img):
        hr_erp_img = pil_to_tensor(hr_erp_img)
    HWx = hr_erp_img.shape[-2:]

    # Fisheye pixel grid -> sampling coordinates in the ERP image, plus validity mask.
    gridy = utils.make_coord(HWy)
    gridy2x, masky = utils.gridy2x_erp2fis(
        gridy, HWy, HWx, THETA, PHI, FOVy, FOVx)
    gridy2x = gridy2x.view(*HWy, 2)
    masky = masky.view(1, *HWy)

    # grid_sample reads (x, y) from the last grid dimension; flip(-1) reverses the
    # utils coordinate order (which therefore appears to be (y, x)).
    inp = F.grid_sample(hr_erp_img.unsqueeze(0),
                        gridy2x.unsqueeze(0).flip(-1),
                        mode='bicubic',
                        padding_mode='reflection',
                        align_corners=False).clamp_(0, 1)[0]
    # Blank out fisheye pixels the mapping marks as invalid.
    inp = inp * masky

    # ERP pixel grid -> mask of ERP pixels covered by this fisheye view.
    gridx = utils.make_coord(HWx, flatten=False)
    _, maskx = utils.gridy2x_fis2erp(
        gridx, HWx, HWy, THETA, PHI, FOVx, FOVy)

    maskx = maskx.view(1, *HWx)
    valid_hr_erp_img = hr_erp_img * maskx

    return inp, valid_hr_erp_img, maskx

In [None]:
import random

hr_erp_img = images[0]
display(hr_erp_img)

# Set to True to also preview an ERP -> perspective projection of the same image.
SHOW_PERSPECTIVE = False

if SHOW_PERSPECTIVE:
    # ERP to Perspective: random THETA in [-135, 135], random PHI in [-45, 45].
    THETA = random.uniform(-135, 135)
    PHI = random.uniform(-45, 45)

    pers_img, valid_hr_erp_img, _ = erp2per(hr_erp_img, THETA, PHI, FOVy=75, FOVx=360)

    display(tensor_to_pil(valid_hr_erp_img))
    display(tensor_to_pil(pers_img))

# ERP to Fisheye: random THETA in [-135, 135], PHI fixed at 0.
THETA = random.uniform(-135, 135)
PHI = 0

fish_img, valid_hr_erp_img, _ = erp2fis(hr_erp_img, THETA, PHI, FOVy=180, FOVx=360)

display(tensor_to_pil(valid_hr_erp_img))
display(tensor_to_pil(fish_img))

In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

def draw_boundary_on_image(image_tensor, mask_tensor, boundary_color=(0, 0, 255), boundary_thickness=10):
    """Draw the outer contour(s) of a mask onto an image.

    Args:
        image_tensor: ``[C, H, W]`` image, either ``uint8`` or float in [0, 1].
        mask_tensor: ``[H, W]`` or ``[C, H, W]`` mask (only the first channel is
            used), either ``uint8`` or float/bool in [0, 1]; non-zero pixels
            count as inside.
        boundary_color: colour tuple handed to ``cv2.drawContours`` in the image's
            own channel order (for RGB tensors the default ``(0, 0, 255)`` is blue).
        boundary_thickness: contour line thickness in pixels.

    Returns:
        A new ``uint8`` ``[C, H, W]`` tensor on the CPU; the inputs are not modified.
    """
    def _to_uint8(t):
        # Clamp first so out-of-range floats cannot wrap around in the uint8 cast,
        # and round (rather than truncate) so values like k/255 map back to k.
        return (t.float().clamp(0, 1) * 255).round().to(torch.uint8)

    if mask_tensor.dtype != torch.uint8:
        mask_tensor = _to_uint8(mask_tensor)

    if image_tensor.dtype != torch.uint8:
        image_tensor = _to_uint8(image_tensor)

    # Masks with a leading channel dimension: use the first channel.
    if mask_tensor.dim() == 3:
        mask_tensor = mask_tensor[0]

    mask_np = mask_tensor.detach().cpu().numpy()
    # OpenCV 3 returns (image, contours, hierarchy); OpenCV 4 returns
    # (contours, hierarchy). [-2] picks the contours in both.
    contours = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # [C, H, W] -> [H, W, C]; copy() gives the contiguous, writable array OpenCV draws into.
    image_np = image_tensor.detach().cpu().permute(1, 2, 0).numpy().copy()
    cv2.drawContours(image_np, contours, -1, boundary_color, boundary_thickness)
    # Back to [C, H, W].
    output_image_tensor = torch.from_numpy(image_np).permute(2, 0, 1)

    return output_image_tensor

In [None]:
import random

# ERP to Perspective: tile the sphere with perspective views, one latitude band
# per row, and outline each view's footprint on the ERP image.
hr_erp_img = images[3]
lr_pers_hw = (512, 512)

num_rows = 4
num_cols = [3, 6, 6, 3]  # number of views in each latitude band
assert len(num_cols) == num_rows, "num_cols needs exactly one entry per row"
phi_interval = 180 / num_rows
# Centre of each band, from -90 + interval/2 up to 90 - interval/2
# (for 4 rows: [-67.5, -22.5, 22.5, 67.5]).
phi_centers = [-90 + phi_interval * (row + 0.5) for row in range(num_rows)]

# Convert once. draw_boundary_on_image returns a new tensor and leaves its
# input untouched, so this base can be reused for every view.
base_hr_erp_ten = pil_to_tensor(hr_erp_img)

for PHI, n_cols in zip(phi_centers, num_cols):
    # Per-row overview: every view of this band outlined in its own random colour.
    global_hr_erp_ten = base_hr_erp_ten
    theta_interval = 360 / n_cols
    for j in range(n_cols):
        THETA = j * theta_interval + theta_interval / 2
        pers_img, valid_hr_erp_img, valid_hr_erp_mask = \
            erp2per(hr_erp_img, THETA, PHI, FOVy=75, FOVx=360, HWy=lr_pers_hw)

        # Single-view outline on a clean copy of the ERP image.
        hr_erp_ten = draw_boundary_on_image(base_hr_erp_ten, valid_hr_erp_mask)

        random_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        global_hr_erp_ten = draw_boundary_on_image(global_hr_erp_ten, valid_hr_erp_mask, boundary_color=random_color, boundary_thickness=15)

        display(tensor_to_pil(pers_img))
        display(tensor_to_pil(valid_hr_erp_img))
        display(tensor_to_pil(valid_hr_erp_mask))
        display(tensor_to_pil(hr_erp_ten))
    display(tensor_to_pil(global_hr_erp_ten))


In [None]:
# Stable Diffusion components (CLIP text encoder/tokenizer, VAE, UNet, DDIM scheduler).
# NOTE(review): versions are unpinned; diffusers APIs used below may change between releases.
!pip install transformers diffusers

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# Only log errors from transformers: suppresses the warning emitted when a
# checkpoint is loaded partially into a model class.
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torchvision.transforms as T
import argparse
from tqdm import tqdm

def seed_everything(seed, deterministic=False):
    """Seed every random number generator this notebook draws from.

    Seeds Python's ``random`` (view angles, contour colours), NumPy's global RNG,
    and torch on the CPU and on all CUDA devices.

    Args:
        seed: integer seed.
        deterministic: if True, also request deterministic cuDNN kernels and turn
            off cuDNN autotuning. ``benchmark=True`` picks kernels by timing, which
            defeats ``deterministic=True``, so the two are set together. Defaults
            to False, which leaves cuDNN settings untouched.
    """
    import random
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU (not only the current one); silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def get_views(panorama_height, panorama_width, window_size=64, stride=8):
    """Enumerate the latent-space windows that tile a panorama, row-major.

    The pixel size is divided by 8, matching the ``height // 8`` / ``width // 8``
    latent that ``text2panorama`` allocates. Windows of ``window_size`` latent
    pixels are placed every ``stride`` latent pixels. If the stride does not land
    exactly on the far edge, one extra window flush with that edge is appended.
    Without it, the trailing latent rows/columns would get ``count == 0`` and be
    left at zero by the MultiDiffusion average.

    Args:
        panorama_height, panorama_width: image size in pixels.
        window_size: window side in latent pixels (clamped to the latent size).
        stride: step between consecutive window starts, in latent pixels.

    Returns:
        List of integer ``(h_start, h_end, w_start, w_end)`` tuples.
    """
    latent_h = int(panorama_height // 8)
    latent_w = int(panorama_width // 8)

    def _axis_starts(length):
        # Window size and start offsets along one axis, edge window included.
        win = min(window_size, length)
        starts = list(range(0, length - win + 1, stride))
        if starts[-1] + win < length:
            starts.append(length - win)
        return win, starts

    win_h, h_starts = _axis_starts(latent_h)
    win_w, w_starts = _axis_starts(latent_w)
    return [(h, h + win_h, w, w + win_w) for h in h_starts for w in w_starts]

In [None]:
class MultiDiffusion(nn.Module):
    """Stable Diffusion wrapper that generates wide panoramas with MultiDiffusion.

    At every denoising step, overlapping latent windows (from ``get_views``) are
    denoised independently and their results averaged back into one shared latent.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        """Load the VAE, tokenizer, text encoder, UNet and DDIM scheduler.

        Args:
            device: torch device the models are moved to.
            sd_version: '1.5', '2.0' or '2.1'; ignored when ``hf_key`` is given.
            hf_key: optional Hugging Face model id that overrides ``sd_version``.

        Raises:
            ValueError: if no ``hf_key`` is given and ``sd_version`` is unsupported.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print(f'[INFO] loading stable diffusion...')
        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            # runwayml/stable-diffusion-v1-5 was removed from the Hub; this is the
            # maintained mirror of the same weights.
            model_key = "stable-diffusion-v1-5/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        # Create model
        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device)

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

        print(f'[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts for classifier-free guidance.

        Args:
            prompt, negative_prompt: lists of strings.

        Returns:
            ``torch.cat([uncond_embeddings, text_embeddings])`` — unconditional
            (negative-prompt) embeddings first, then conditional ones.
        """
        # Tokenize text and get embeddings
        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                    truncation=True, return_tensors='pt')
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # Same for the negative prompt. truncation=True is required here too: an
        # over-long negative prompt would otherwise exceed model_max_length and
        # break the text encoder / the concatenation below.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
                                      truncation=True, return_tensors='pt')

        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @torch.no_grad()
    def decode_latents(self, latents):
        """Decode latents with the VAE and map the result from [-1, 1] to [0, 1]."""
        # SD 1.x/2.x VAEs use 0.18215; read the value from the VAE config when
        # present so custom hf_key models with a different factor also decode correctly.
        scaling_factor = getattr(self.vae.config, 'scaling_factor', 0.18215)
        latents = latents / scaling_factor
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    @torch.no_grad()
    def text2panorama(self, prompts, negative_prompts='', height=512, width=2048, num_inference_steps=50,
                      guidance_scale=7.5):
        """Generate a ``height`` x ``width`` panorama with MultiDiffusion.

        Args:
            prompts: prompt string or list of strings.
            negative_prompts: negative prompt string or list of strings.
            height, width: output size in pixels; multiples of 64 tile the latent
                exactly with the default ``get_views`` window/stride.
            num_inference_steps: number of DDIM steps.
            guidance_scale: classifier-free guidance weight.

        Returns:
            The generated panorama as a PIL image.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds: [2 * len(prompts), 77, D] (uncond first, then cond).
        # NOTE(review): the latent below has batch 1, so only a single prompt is supported.
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define panorama grid and get views
        latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
        views = get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    # TODO we can support batches, and pass multiple views at once to the unet
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end]

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                # take the MultiDiffusion step: average overlapping window predictions
                latent = torch.where(count > 0, value / count, value)

        # Img latents -> imgs: [1, 3, height, width]
        imgs = self.decode_latents(latent)
        img = T.ToPILImage()(imgs[0].cpu())
        return img

In [None]:
seed_everything(2024)

device = torch.device('cuda')

# opt variables
sd_version = '2.0'
prompt = 'realistic firenze cityscape'
negative = ''
H = 512
# Keep H and W multiples of 64 (8 px per latent pixel x the default 8-latent-pixel
# window stride) so the sliding windows tile the latent exactly. 2048 is also
# text2panorama's default width.
W = 2048
steps = 50
outfile = f'/content/{prompt}.png'

sd = MultiDiffusion(device, sd_version)

img = sd.text2panorama(prompt, negative, height=H, width=W, num_inference_steps=steps)

# save image
img.save(outfile)

In [None]:
def erp_get_views(panorama_height, panorama_width, window_size=64, stride=8):
    """Latent windows for an ERP panorama, as ``(h_start, h_end, w_start, w_end)`` tuples.

    The tiling is currently the same plain (non-wrapping) sliding window as
    ``get_views``, so this delegates instead of keeping a second copy of that logic.
    NOTE(review): ERP images wrap horizontally; a seam-free layout would add
    windows crossing the left/right edge, which is not implemented here.
    """
    return get_views(panorama_height, panorama_width, window_size=window_size, stride=stride)

class ERPMultiDiffusion(MultiDiffusion):
    """MultiDiffusion variant for equirectangular (ERP) panoramas.

    Model loading is inherited unchanged from ``MultiDiffusion``; ``text2erp``
    lays out its latent windows with ``erp_get_views``.
    """

    def __init__(self, device, sd_version='2.0', hf_key=None):
        super().__init__(device, sd_version, hf_key)

    @torch.no_grad()
    def text2erp(self, prompts, negative_prompts='', height=512, width=1024, num_inference_steps=50,
                      guidance_scale=7.5):
        """Generate a ``height`` x ``width`` ERP panorama with MultiDiffusion.

        Args:
            prompts: prompt string or list of strings.
            negative_prompts: negative prompt string or list of strings.
            height, width: output size in pixels (default 512 x 1024, a 1:2 aspect).
            num_inference_steps: number of DDIM steps.
            guidance_scale: classifier-free guidance weight.

        Returns:
            The generated panorama as a PIL image.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds: [2 * len(prompts), 77, D] (uncond first, then cond).
        # NOTE(review): the latent below has batch 1, so only a single prompt is supported.
        text_embeds = self.get_text_embeds(prompts, negative_prompts)

        # Define panorama grid and get the ERP-specific views.
        latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
        views = erp_get_views(height, width)
        count = torch.zeros_like(latent)
        value = torch.zeros_like(latent)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                count.zero_()
                value.zero_()

                for h_start, h_end, w_start, w_end in views:
                    # TODO we can support batches, and pass multiple views at once to the unet
                    latent_view = latent[:, :, h_start:h_end, w_start:w_end]

                    # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                    latent_model_input = torch.cat([latent_view] * 2)

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds)['sample']

                    # perform guidance
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # compute the denoising step with the reference model
                    latents_view_denoised = self.scheduler.step(noise_pred, t, latent_view)['prev_sample']
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1

                # take the MultiDiffusion step: average overlapping window predictions
                latent = torch.where(count > 0, value / count, value)

        # Img latents -> imgs: [1, 3, height, width]
        imgs = self.decode_latents(latent)
        img = T.ToPILImage()(imgs[0].cpu())
        return img