In [1]:
# IPython magics: automatically reload edited modules (e.g. the project's
# `utils`/`models` packages) before each cell runs, so code changes are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np
from PIL import Image
from einops import rearrange

from external.Lotus.my_pipeline import build_lotus_pipeline
from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
from models.multiplane_sync.processors_sd import apply_custom_processors_for_unet, apply_custom_processors_for_vae
from utils.cube import Cubemap, images_to_equi_and_dice, concat_dice_mask
from utils.depth import z_distance_to_depth
from utils.equi import Equirectangular


def load_images_from_panorama(pano_path: str, cube_size: int) -> np.ndarray:
    """Load an equirectangular panorama and split it into six cube-face images.

    Args:
        pano_path: Path to the panorama image file; it is converted to RGB.
        cube_size: Face size passed to ``Equirectangular.to_cubemap``
            (presumably the edge length of each face in pixels).

    Returns:
        A float32 array of shape ``(1, 6, h, w, c)`` (leading batch axis of 1,
        then six faces). Pixel values are divided by 255, which is intended to
        normalize 8-bit input to ``[0, 1]``.
    """
    # Context manager releases the file handle deterministically, even if
    # decoding fails part-way.
    with Image.open(pano_path) as img:
        equi = np.array(img.convert('RGB'))
    cube = Equirectangular(equi).to_cubemap(cube_size)
    # Convert the faces into a list of per-face arrays; `to_equilib=True`
    # presumably selects equilib's face-orientation convention — TODO confirm.
    faces = Cubemap.cube_all2all_equilib(cube.faces, cube.cube_format, 'list', to_equilib=True)
    images = np.array(faces, dtype=np.float32) / 255.0
    # Add the leading batch axis; `m=6` makes einops verify there are exactly six faces.
    return rearrange(images, 'm h w c -> 1 m h w c', m=6)


# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Cube-face resolution used when splitting the panorama; shape comments below
# are written in terms of this value.
CUBE_SIZE = 512
# Number of faces in a cubemap.
NUM_FACES = 6

# Build the Lotus depth pipeline (disparity checkpoint), then patch its UNet
# and VAE with the Multi-Plane Sync processors (which, per the enable_sync_*
# flags, presumably synchronize attention / conv / group-norm across faces).
pipe = build_lotus_pipeline(
    pretrained_model_name_or_path='jingheya/lotus-depth-g-v2-1-disparity',
    mode='generation',
    task_name='depth',
    # Half precision only on CUDA: many fp16 kernels are missing or very slow on CPU.
    half_precision=(device == 'cuda'),
    enable_xformers_memory_efficient_attention=False,
).to(device)

apply_custom_processors_for_unet(pipe.unet, enable_sync_self_attn=True, enable_sync_cross_attn=True, enable_sync_conv2d=True, enable_sync_gn=True)
apply_custom_processors_for_vae(pipe.vae, mode='all', enable_sync_attn=True, enable_sync_gn=True, enable_sync_conv2d=True)

# Load the six cube faces from a panorama: shape (1, NUM_FACES, CUBE_SIZE, CUBE_SIZE, 3).
images = load_images_from_panorama('assets/abandoned_hall.png', cube_size=CUBE_SIZE)

# Depth estimation: flatten (batch, face) into one batch axis, channels-first.
inputs = rearrange(torch.from_numpy(images), 'b m h w c -> (b m) c h w', m=NUM_FACES).to(device)
# Inference only; disabling autograd is harmless if batch_inference already does so.
with torch.no_grad():
    depths, _ = pipe.batch_inference(inputs)
depths = depths[..., None]  # expected np.ndarray, (NUM_FACES, CUBE_SIZE, CUBE_SIZE, 1)
# Invert the prediction (the checkpoint is a disparity model).
# NOTE(review): confirm `1 - x` is the intended disparity-to-depth proxy.
depths = 1.0 - depths
depths = rearrange(depths, '(b m) ... -> b m ...', m=NUM_FACES)
# Each cube face spans a 90-degree field of view (passed here as the two FOV
# arguments — presumably horizontal and vertical; confirm against utils.depth).
depths = z_distance_to_depth(depths, 90.0, 90.0)

# Visualization of images: equirectangular panorama and unfolded-cube ("dice") layout.
equis, dices = images_to_equi_and_dice(images)
equi_rgb_pil = Image.fromarray((equis[0] * 255).astype(np.uint8))
dice_rgb_pil = Image.fromarray((dices[0] * 255).astype(np.uint8))
equi_rgb_pil.resize((1024, 512)).show()
concat_dice_mask(dice_rgb_pil).resize((1024, 768)).show()

# Visualization of depths. Clip the colormap to the 2nd-98th percentile so
# outliers do not compress the contrast (visualize_depth's defaults are 0.0 and 1.0).
val_min, val_max = np.percentile(depths, [2, 98])
equis, dices = images_to_equi_and_dice(depths)
equi_depth_vis = MarigoldImageProcessor.visualize_depth(equis, val_min, val_max)[0]
dice_depth_vis = MarigoldImageProcessor.visualize_depth(dices, val_min, val_max)[0]
equi_depth_vis.resize((1024, 512)).show()
concat_dice_mask(dice_depth_vis).resize((1024, 768)).show()
