In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Stable Diffusion + Multi-Plane Synchronization

In [None]:
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from diffusers import StableDiffusionPipeline, DiffusionPipeline, AutoencoderKL

from models.multiplane_sync_legacy import apply_custom_processors_for_unet, apply_custom_processors_for_vae
from utils.cube import images_to_equi_and_dice, concat_dice_mask


def build_sd_pipeline(device, version: str):
    """Build a Stable Diffusion pipeline with Multi-Plane Sync processors applied.

    Args:
        device: Target handed to ``pipe.to`` (e.g. ``'cuda'`` or ``'cpu'``).
        version: Which checkpoint to load, either ``'sd2'`` or ``'sdxl'``.

    Returns:
        The loaded pipeline, with the custom UNet and VAE processors applied,
        after being moved to ``device``.

    Raises:
        ValueError: If ``version`` is not one of the supported values.
    """
    supported_versions = ('sd2', 'sdxl')
    if version not in supported_versions:
        # Fail early with a clear message: otherwise neither branch below binds
        # `pipe` and the function dies later with an UnboundLocalError.
        raise ValueError(
            f"Unsupported version {version!r}; expected one of {supported_versions}."
        )

    if version == 'sd2':
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2",
            torch_dtype=torch.float16,
            local_files_only=True,
        )

    elif version == 'sdxl':
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            # The base checkpoint's VAE is swapped for the separately loaded fp16-fix VAE.
            vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16),
            torch_dtype=torch.float16,
            variant="fp16",
            local_files_only=True,
        )

    # Synchronization is enabled for UNet self-attention, conv2d and group norm;
    # cross-attention is explicitly left unsynchronized here.
    apply_custom_processors_for_unet(pipe.unet, enable_sync_self_attn=True, enable_sync_cross_attn=False, enable_sync_conv2d=True, enable_sync_gn=True)
    apply_custom_processors_for_vae(pipe.vae, enable_sync_attn=True, enable_sync_gn=True, enable_sync_conv2d=True)

    return pipe.to(device)


# Get device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the Stable Diffusion pipeline with Multi-Plane Sync.
version = 'sd2'  # 'sd2' or 'sdxl'
pipe = build_sd_pipeline(device, version=version)

# Image generation
SEED = 0  # fixed so that Restart & Run All reproduces the same initial noise
image_size = 512 if version == 'sd2' else 768
prompts = ["Vast cosmos in the style of Van Gogh"] * 6
# Draw the initial latents on the CPU from a seeded generator, then move/cast
# them, so the starting noise does not depend on the device or on prior RNG use.
noise_generator = torch.Generator().manual_seed(SEED)
latents = torch.randn(6, 4, image_size//8, image_size//8, generator=noise_generator).to(device, dtype=torch.float16)
# The same generator is passed to the pipeline so that any further noise it
# samples during denoising also comes from the seeded stream.
images = pipe(prompts, latents=latents, generator=noise_generator, output_type='np').images
# Group the flat batch into sets of six images (one set per panorama).
images = rearrange(images, '(b m) ... -> b m ...', m=6)

# Visualization of generated images
equis, dices = images_to_equi_and_dice(images)
# Only the first set is shown; values are scaled by 255 and cast to uint8 for PIL.
equi_rgb_pil = Image.fromarray((equis[0] * 255).astype(np.uint8))
dice_rgb_pil = Image.fromarray((dices[0] * 255).astype(np.uint8))
equi_rgb_pil.resize((1024, 512)).show()
concat_dice_mask(dice_rgb_pil).resize((1024, 768)).show()


### Marigold + Multi-Plane Synchronization

In [None]:
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from typing import Optional, Union
from diffusers.pipelines.marigold.pipeline_marigold_depth import MarigoldDepthPipeline

from models.multiplane_sync_legacy import apply_custom_processors_for_unet, apply_custom_processors_for_vae
from utils.cube import Cubemap, images_to_equi_and_dice, concat_dice_mask
from utils.depth import z_distance_to_depth
from utils.equi import Equirectangular


def load_images_from_panorama(pano_path: str, cube_size: int) -> np.ndarray:
    """Read a panorama image and return its six cube faces as a single batch.

    Args:
        pano_path: Path of the panorama; it is opened with PIL and converted to RGB.
        cube_size: Size forwarded to ``Equirectangular.to_cubemap``.

    Returns:
        A float32 array laid out as ``(1, 6, h, w, c)``: a leading batch axis of
        length one followed by the six faces, with pixel values divided by 255.
    """
    panorama_rgb = Image.open(pano_path).convert('RGB')
    cubemap = Equirectangular(np.array(panorama_rgb)).to_cubemap(cube_size)

    # Convert the cubemap faces to the 'list' layout.
    face_list = Cubemap.cube_all2all_equilib(
        cubemap.faces,
        cubemap.cube_format,
        'list',
        to_equilib=True,
    )

    # Normalize to [0, 1]
    faces = np.array(face_list, dtype=np.float32) / 255.0

    # Prepend the batch axis; m=6 makes rearrange verify there are six faces.
    return rearrange(faces, 'm h w c -> 1 m h w c', m=6)


def build_marigold_pipeline(device):
    """Build the Marigold depth pipeline with Multi-Plane Sync processors applied.

    Args:
        device: Target handed to the pipeline's ``.to`` call.

    Returns:
        The fp16 ``MarigoldDepthPipeline`` with the custom UNet and VAE
        processors applied, after being moved to ``device``.
    """
    pipeline = MarigoldDepthPipeline.from_pretrained(
        'prs-eth/marigold-depth-v1-0',
        torch_dtype=torch.float16,
        variant="fp16",
        local_files_only=True,
    )

    # Every synchronization switch is turned on for both sub-models.
    unet_sync_options = dict(
        enable_sync_self_attn=True,
        enable_sync_cross_attn=True,
        enable_sync_conv2d=True,
        enable_sync_gn=True,
    )
    vae_sync_options = dict(
        enable_sync_attn=True,
        enable_sync_gn=True,
        enable_sync_conv2d=True,
    )
    apply_custom_processors_for_unet(pipeline.unet, **unet_sync_options)
    apply_custom_processors_for_vae(pipeline.vae, **vae_sync_options)

    pipeline_on_device = pipeline.to(device)
    return pipeline_on_device


# Get device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the Marigold pipeline with Multi-Plane Sync.
pipe = build_marigold_pipeline(device)

# Load cube images from a panorama
images = load_images_from_panorama('assets/abandoned_hall.png', cube_size=512)

# Depth estimation
SEED = 0  # fixed so that Restart & Run All reproduces the same prediction
# Flatten the (batch, face) axes and move channels first for the pipeline input.
inputs = rearrange(torch.from_numpy(images), 'b m h w c -> (b m) c h w', m=6).to(device)
# No latents are supplied, so the pipeline samples its own starting noise;
# a seeded generator makes that draw repeatable.
noise_generator = torch.Generator().manual_seed(SEED)
depths = pipe(inputs, output_type='np', batch_size=6, generator=noise_generator).prediction
depths = rearrange(depths, '(b m) ... -> b m ...', m=6)
depths = z_distance_to_depth(depths, 90.0, 90.0)

# Visualization of images
equis, dices = images_to_equi_and_dice(images)
equi_rgb_pil = Image.fromarray((equis[0] * 255).astype(np.uint8))
dice_rgb_pil = Image.fromarray((dices[0] * 255).astype(np.uint8))
equi_rgb_pil.resize((1024, 512)).show()
concat_dice_mask(dice_rgb_pil).resize((1024, 768)).show()

# Visualization of depths
# The display range is taken from the 2nd and 98th percentiles of the predicted depths.
val_min, val_max = np.percentile(depths, 2), np.percentile(depths, 98)  # 0.0, 1.0
equis, dices = images_to_equi_and_dice(depths)
equi_depth_vis = pipe.image_processor.visualize_depth(equis, val_min, val_max)[0]
dice_depth_vis = pipe.image_processor.visualize_depth(dices, val_min, val_max)[0]
equi_depth_vis.resize((1024, 512)).show()
concat_dice_mask(dice_depth_vis).resize((1024, 768)).show()
