In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up on the next cell run without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from diffusers import SanaPipeline, SanaPAGPipeline
from models.multiplane_sync.processors_sana import (
    apply_custom_processors_for_vae,
    apply_custom_processors_for_transformer,
    get_patch_embed_forward,
)

def build_pipeline(model_name: str):
    """Load a Sana diffusers pipeline in bfloat16 and move it to CUDA.

    Args:
        model_name: Which pipeline class to instantiate. ``'sana'`` selects
            ``SanaPipeline``; ``'sana-pag'`` selects ``SanaPAGPipeline`` with
            ``pag_applied_layers="transformer_blocks.8"``.

    Returns:
        The loaded pipeline after ``.to("cuda")``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    supported = ('sana', 'sana-pag')
    # Validate up front: previously an unknown name skipped both branches and
    # surfaced later as an UnboundLocalError on `pipe`.
    if model_name not in supported:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {supported}"
        )

    # Checkpoint and loader arguments shared by both pipeline variants.
    repo_id = "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
    load_kwargs = dict(
        variant="bf16",
        torch_dtype=torch.bfloat16,
        local_files_only=True,
    )

    if model_name == 'sana':
        pipe = SanaPipeline.from_pretrained(repo_id, **load_kwargs)
    else:  # 'sana-pag'
        pipe = SanaPAGPipeline.from_pretrained(
            repo_id,
            pag_applied_layers="transformer_blocks.8",
            **load_kwargs,
        )

    # Explicitly cast the text encoder and VAE to bfloat16 before moving to GPU.
    pipe.text_encoder.to(torch.bfloat16)
    pipe.vae.to(torch.bfloat16)
    return pipe.to("cuda")

# Pipeline variant to load; build_pipeline accepts 'sana' or 'sana-pag'.
model_name = 'sana-pag'
pipe = build_pipeline(model_name)

# Install the project's custom processors on the VAE with every sync flag on.
apply_custom_processors_for_vae(
    pipe.vae,
    enable_sync_attn=True,
    enable_sync_conv2d=True,
    enable_sync_gn=True,
)

# Same for the transformer (self-attention, conv2d and gn sync flags).
apply_custom_processors_for_transformer(
    pipe.transformer,
    enable_sync_self_attn=True,
    enable_sync_conv2d=True,
    enable_sync_gn=True,
)

# Replace the patch-embedding module's forward with the project-provided one.
pipe.transformer.patch_embed.forward = get_patch_embed_forward(
    pipe.transformer.patch_embed
)


In [None]:
from PIL import Image
from einops import rearrange
import numpy as np
from utils.cube import images_to_equi_and_dice, concat_dice_mask


prompt = 'Floating in the sky, pixel style, white clouds, hot air balloons, and a blue sky'
height, width = 1024, 1024
# Images generated per panorama; presumably one per cube face — confirm against utils.cube.
num_views = 6
# Fixed seed so that Restart & Run All reproduces the same result.
seed = 0

# Inference
# A seeded generator makes the sampled noise reproducible between runs.
generator = torch.Generator(device='cpu').manual_seed(seed)
outputs = pipe(
    prompt=[prompt] * num_views,
    height=height,
    width=width,
    generator=generator,
    output_type='np',
).images

# To PIL
# Group the flat batch into (panorama, view, ...) before stitching.
images = rearrange(outputs, '(b m) ... -> b m ...', m=num_views)
equis, dices = images_to_equi_and_dice(images)
# Clip to [0, 1] before the uint8 cast: any out-of-range value (e.g. from
# resampling inside images_to_equi_and_dice — TODO confirm) would otherwise wrap.
equi_rgb_pil = Image.fromarray((np.clip(equis[0], 0.0, 1.0) * 255).astype(np.uint8))
dice_rgb_pil = Image.fromarray((np.clip(dices[0], 0.0, 1.0) * 255).astype(np.uint8))
dice_rgb_pil = concat_dice_mask(dice_rgb_pil)

# Show
equi_rgb_pil.resize((1024, 512)).show()
dice_rgb_pil.resize((1024, 768)).show()
