In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers.utils import export_to_video
from external.video_crafter_diffusers.pipeline_text_to_video_videocrafter import TextToVideoVideoCrafterPipeline, tensor2vid
from external.video_crafter_diffusers.unet_3d_videocrafter import UNet3DVideoCrafterConditionModel
from models.multiplane_sync.processors_videocrafter2 import (
    apply_custom_processors_for_unet,
    apply_custom_processors_for_vae,
)

# Build the text-to-video pipeline from the zeroscope checkpoint, then replace
# its UNet with the VideoCrafter2 weights. Both are loaded in float16.
pipe = TextToVideoVideoCrafterPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
pipe.unet = UNet3DVideoCrafterConditionModel.from_pretrained("adamdad/videocrafterv2_diffusers", torch_dtype=torch.float16)

# The scheduler option is named `use_karras_sigmas`. The previous spelling
# (`use_karras`) is not a scheduler argument, so `from_config` dropped it and
# Karras sigmas were never actually enabled.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_karras_sigmas=True,
    algorithm_type="sde-dpmsolver++",
)
pipe.enable_model_cpu_offload()

# NOTE(review): VAE tiling/slicing are left disabled; presumably they would
# interfere with the synchronized VAE processors installed below — confirm.
# pipe.vae.enable_tiling()
# pipe.vae.enable_slicing()

# Install the multiplane-sync processors on the UNet: synchronized
# self-attention, conv2d and group-norm are on, cross-attention sync is off.
apply_custom_processors_for_unet(
    pipe.unet,
    enable_sync_self_attn=True,
    enable_sync_cross_attn=False,
    enable_sync_conv2d=True,
    enable_sync_gn=True,
)
# Same for the VAE (mode='all'): conv2d and group-norm sync on, attention sync off.
apply_custom_processors_for_vae(
    pipe.vae,
    mode='all',
    enable_sync_attn=False,
    enable_sync_conv2d=True,
    enable_sync_gn=True,
)


In [7]:
import torch
import torch.amp
import numpy as np
from PIL import Image
from einops import rearrange, repeat
from utils.cube import images_to_equi_and_dice


def frames_to_pano_video(frames):
    """Convert a batch of per-view videos into panorama and cube videos.

    Args:
        frames: 5D array laid out as (B*M, F, H, W, C) whose leading dimension
            is a multiple of 6; every consecutive group of M=6 entries is
            treated as one sample.

    Returns:
        Tuple ``(pano_frames, cube_frames)``; each is the per-frame output of
        ``images_to_equi_and_dice`` stacked along a new axis 1 (the frame
        axis).

    Raises:
        AssertionError: if ``frames`` is not 5D or its first dimension is not
            divisible by 6.
    """
    assert frames.ndim == 5 and frames.shape[0] % 6 == 0, f"Expected 5D tensor with shape (B*M, F, H, W, C), got {frames.shape}"

    # Split the flat leading axis into (sample, view) so each frame can be
    # handed over as a [B, 6, H, W, C] group.
    grouped = rearrange(frames, "(B M) F H W C -> B M F H W C", M=6)
    num_frames = grouped.shape[2]

    # One conversion call per time step; each call returns a (panos, cubes) pair.
    # Presumably panos are equirectangular images and cubes a dice-style cube
    # layout — verify against utils.cube.
    converted = [images_to_equi_and_dice(grouped[:, :, idx]) for idx in range(num_frames)]

    pano_frames = np.stack([panos for panos, _ in converted], axis=1)  # [B, F, H, W, C]
    cube_frames = np.stack([cubes for _, cubes in converted], axis=1)  # [B, F, H, W, C]
    return pano_frames, cube_frames


def decode_latents(vae, latents):
    """Decode video latents to pixel space, one frame at a time.

    Args:
        vae: object exposing ``config.scaling_factor`` and ``decode(x)`` whose
            result has a ``.sample`` tensor.
        latents: 5D tensor laid out as (batch, channels, frames, height, width).

    Returns:
        float32 tensor laid out as (batch, channels_out, frames, height_out,
        width_out), where the trailing sizes are whatever ``vae.decode``
        produces.

    Raises:
        ValueError: if ``latents`` does not have exactly five dimensions.
    """
    latents = 1 / vae.config.scaling_factor * latents

    # Unpacking doubles as the 5D shape check; only the frame count is needed.
    _, _, num_frames, _, _ = latents.shape

    # Each decode call receives the whole batch for a single time step.
    # NOTE(review): presumably required so the synchronized VAE processors see
    # all views of a frame together — confirm before batching across frames.
    decoded = [vae.decode(latents[:, :, idx]).sample for idx in range(num_frames)]

    # Stacking directly on dim 2 yields (batch, channels, frames, H, W); this
    # replaces the former stack -> flatten -> reshape -> permute round trip.
    video = torch.stack(decoded, dim=2)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return video.float()


# Generation settings: text prompt, output filename prefix, and square frame size.
prompt = 'Underwater world, Ghibli style'
filename = 'vc2'
height = width = 512

# Run sampling and decoding without autograd and under CUDA autocast.
with torch.inference_mode():
    with torch.amp.autocast('cuda'):
        # Place the UNet and text encoder on the GPU for sampling.
        pipe.unet = pipe.unet.to('cuda')
        pipe.text_encoder = pipe.text_encoder.to('cuda')
        
        # The prompt is repeated 6 times: frames_to_pano_video later groups the
        # batch in sixes (presumably one entry per cube face — confirm).
        # output_type='latent' skips the pipeline's own decoding; the latents
        # are decoded below with decode_latents.
        video_frames = pipe(
            [prompt] * 6,
            height=height,
            width=width,
            num_frames=16, # do not modify this because it is used in the multiplane sync processor
            num_inference_steps=50,
            output_type='latent',
        ).frames  # torch.Tensor, (6, 4, 16, 64, 64)

        # Move the sampling models back to the CPU and release cached CUDA
        # memory before decoding (presumably to leave room for the VAE — confirm).
        pipe.unet = pipe.unet.to('cpu')
        pipe.text_encoder = pipe.text_encoder.to('cpu')
        
        torch.cuda.empty_cache()

        # Latents -> float32 pixel tensor -> numpy frames.
        video_frames = decode_latents(pipe.vae, video_frames)
        video_frames = tensor2vid(video_frames, pipe.image_processor, output_type='np')


# Each result is [B, F, H, W, C]; with 6 inputs grouped in sixes B is 1 here.
# The original note gave the pano frames as 1024 x 2048 x 3 — TODO confirm.
pano_frames, cube_frames = frames_to_pano_video(video_frames)

# Save the first frame of the first sample as JPEG previews (scaled by 255 and cast to uint8).
Image.fromarray(np.asarray(pano_frames[0][0] * 255.0, np.uint8)).save(f'{filename}_pano_frame.jpg')
Image.fromarray(np.asarray(cube_frames[0][0] * 255.0, np.uint8)).save(f'{filename}_cube_frame.jpg')

# Export the first sample's full frame sequence as 8 fps videos.
export_to_video(pano_frames[0], output_video_path=f'{filename}_pano.mp4', fps=8)
export_to_video(cube_frames[0], output_video_path=f'{filename}_cube.mp4', fps=8)
