In [1]:
%load_ext autoreload
%autoreload 2

import torch
from diffusers.models import ControlNetModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from pytorch3d.renderer import FoVOrthographicCameras
from pytorch3d.io import load_obj, load_objs_as_meshes

from text3d2video.utilities.camera_placement import front_facing_extrinsics
from text3d2video.utilities.mesh_processing import normalize_meshes
from text3d2video.rendering import render_depth_map
from text3d2video.utilities.video_util import pil_frames_to_clip

# Inference-only notebook: switch autograd off for every later cell.
torch.set_grad_enabled(False)

mesh_path = "data/meshes/mixamo-human.obj"
device = "cuda"

# First pass over the OBJ file: only the UV data is kept from it
# (per-vertex UV coordinates and the per-face UV indices).
verts, faces, aux = load_obj(mesh_path)
verts_uvs = aux.verts_uvs.to(device)
faces_uvs = faces.textures_idx.to(device)

# Second pass: load the same file as a Meshes batch and normalize it in one go.
mesh = normalize_meshes(load_objs_as_meshes([mesh_path], device=device))

# `s` is used as the uniform camera scale_xyz and `dist` as the `zs` argument
# of front_facing_extrinsics in the camera cells.
s = 1.8
dist = 1

In [2]:


# target frames
n_frames = 20

# torch.linspace includes both endpoints, so 0 and 360 are both sampled.
# NOTE(review): if these are rotation angles the first and last frame would
# show the same view -- confirm that this is intended.
angles = torch.linspace(0, 360, n_frames)
# Not consumed by any cell visible here; sized with n_frames rather than a
# second hard-coded 20 so the two cannot drift apart.
xs = torch.linspace(-0.5, 0.5, n_frames)
R, T = front_facing_extrinsics(degrees=angles, zs=dist)

# Cameras are placed on the same `device` the mesh was loaded to, instead of a
# separate "cuda" literal.
frame_cams = FoVOrthographicCameras(R=R, T=T, device=device, scale_xyz=[(s, s, s)])
# One copy of the mesh per camera.
frame_meshes = mesh.extend(len(frame_cams))

# Alternative frame source (disabled): load cameras/meshes from a saved animation.
# animation = AnimationArtifact.from_wandb_artifact_tag('handstand:latest')
# frame_indices = animation.frame_indices(n_frames)
# frame_cams, frame_meshes = animation.load_frames(frame_indices)

In [3]:
import numpy as np
from text3d2video.utilities.camera_placement import turntable_extrinsics
from pytorch3d.renderer import FoVPerspectiveCameras

# Aggregation view(s): a single camera at 0 degrees.
# Note: this rebinds R and T from the target-frame cell.
R, T = front_facing_extrinsics(degrees=[0], zs=dist)

# Same `device` as the mesh and the target-frame cameras, rather than a
# separate "cuda" literal.
aggr_cams = FoVOrthographicCameras(R=R, T=T, device=device, scale_xyz=[(s, s, s)])
# One copy of the mesh per aggregation camera.
aggr_meshes = mesh.extend(len(aggr_cams))

# Alternative (disabled): five perspective cameras on a turntable.
# angles = np.linspace(0, 360, 5, endpoint=False)
# angles = list(reversed(list(angles)))

# R, T = turntable_extrinsics(angles=angles, dists=1)
# aggr_cams = FoVPerspectiveCameras(R=R, T=T, device="cuda", fov=60)
# aggr_meshes = mesh.extend(len(aggr_cams))

In [4]:
from text3d2video.utilities.ipython_utils import display_vid


# Depth renders for the target frames and for the aggregation views.
frame_depths = render_depth_map(frame_meshes, frame_cams)
aggr_depths = render_depth_map(aggr_meshes, aggr_cams)
torch.cuda.empty_cache()

# display_vids([pil_frames_to_clip(aggr_depths), pil_frames_to_clip(frame_depths)])
# Show the aggregation depths first, then the target-frame depths.
for depth_maps in (aggr_depths, frame_depths):
    display(display_vid(pil_frames_to_clip(depth_maps)))

In [5]:
from text3d2video.generative_rendering.reposable_diffusion_pipeline import (
    ReposableDiffusionPipeline,
)

# From here on `device` is a torch.device rather than the plain "cuda" string.
device = torch.device("cuda")
dtype = torch.float16

sd_repo = "runwayml/stable-diffusion-v1-5"
controlnet_repo = "lllyasviel/control_v11f1p_sd15_depth"

# Depth ControlNet, loaded in half precision and then moved to the GPU.
controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=dtype)
controlnet = controlnet.to(device)

# Pipeline built around the ControlNet above, same dtype and device.
pipe: ReposableDiffusionPipeline = ReposableDiffusionPipeline.from_pretrained(
    sd_repo, controlnet=controlnet, torch_dtype=dtype
)
pipe = pipe.to(device)

# Replace the default scheduler with DPM-Solver multistep, reusing its config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [12]:
from text3d2video.artifacts.gr_data import GrSaveConfig
from text3d2video.generative_rendering.configs import ReposableDiffusionConfig
from text3d2video.noise_initialization import FixedNoiseInitializer, UVNoiseInitializer

# Attention module paths handed to both the diffusion config and the save
# config below. Only the `attn1` layers of up_blocks 1-3 are active.
# Currently disabled entries:
#   "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
#   "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
#   "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
#   "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
#   "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
#   "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
#   "mid_block.attentions.0.transformer_blocks.0.attn1",
module_paths = [
    f"up_blocks.{block}.attentions.{attn}.transformer_blocks.0.attn1"
    for block in (1, 2, 3)
    for attn in (0, 1, 2)
]

# Run configuration. Keywords are grouped by apparent purpose (inferred from
# the field names -- check ReposableDiffusionConfig for exact semantics).
rd_config = ReposableDiffusionConfig(
    # sampling
    seed=1,
    resolution=512,
    num_inference_steps=10,
    guidance_scale=7.5,
    controlnet_conditioning_scale=1.0,
    chunk_size=5,
    noise_threshold=1,
    # attention-feature injection / aggregation
    do_pre_attn_injection=True,
    do_post_attn_injection=True,
    aggregate_queries=True,
    attend_to_self_kv=True,
    feature_blend_alpha=0.8,
    mean_features_weight=1.0,
    # targeted attention modules
    module_paths=module_paths,
)

# Intermediate-data saving configuration. Note that `enabled` is False here.
gr_save_cfg = GrSaveConfig(
    enabled=False,
    out_artifact="rumba",
    module_paths=module_paths,
    # how much to save
    n_frames=5,
    n_timesteps=5,
    # what to save
    save_q=True,
    save_k=True,
    save_v=True,
    save_latents=False,
    save_kf_post_attn=False,
    save_aggregated_features=False,
    save_feature_images=False,
)

prompt = "Deadpool, simple blank background"

# Exactly one noise initializer is built. The alternative stays as a comment
# so that no throwaway instance is constructed and immediately overwritten.
# noise_initializer = FixedNoiseInitializer()
noise_initializer = UVNoiseInitializer()

# Run the pipeline on the target frames and the aggregation views together.
# The display cell below slices `video_frames` into the first len(frame_cams)
# entries (target frames) and the remainder (aggregation views).
video_frames = pipe(
    prompt,
    frame_meshes,
    frame_cams,
    aggr_meshes,
    aggr_cams,
    verts_uvs,
    faces_uvs,
    reposable_diffusion_config=rd_config,
    noise_initializer=noise_initializer,
    gr_save_config=gr_save_cfg,
)

# Artifact object exposed by the pipeline after the run; inspected further down.
art = pipe.gr_data_artifact

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:33<00:00,  3.31s/it]


In [None]:
from text3d2video.utilities.ipython_utils import display_vid

# The pipeline output holds the target frames first, then the aggregation views.
n_target = len(frame_cams)
vid_frames, aggr_frames = video_frames[:n_target], video_frames[n_target:]

# Show the target-frame clip, then the aggregation-view clip.
for clip_frames in (vid_frames, aggr_frames):
    display(display_vid(pil_frames_to_clip(clip_frames)))

: 

In [38]:
import os
from text3d2video.artifacts.gr_data import GrDataArtifact
from text3d2video.sd_feature_extraction import read_layer_paths
from text3d2video.utilities.h5_util import print_datasets
import subprocess

# Annotation only (helps editor completion); the value is unchanged.
art: GrDataArtifact = art

# Disk usage of the artifact folder. An argument list is used instead of an
# interpolated shell string, and the command's output is printed rather than
# its exit status. Best-effort: a failing `du` is reported, not raised.
du_result = subprocess.run(
    ["du", "-h", str(art.folder)], capture_output=True, text=True, check=False
)
print(du_result.stdout, end="")
if du_result.returncode != 0:
    print(f"du exited with status {du_result.returncode}: {du_result.stderr.strip()}")

# Bookkeeping recorded with the saved diffusion data.
frame_indices = art.diffusion_data.save_frame_indices
time_steps = art.diffusion_data.save_step_times
modules = art.diffusion_data.save_module_paths

# Split the saved module paths into three groups (encoder / mid / decoder,
# going by the variable names).
enc_layers, mid_layers, dec_layers = read_layer_paths(modules)

print_datasets(art.h5_file_path())

8,0K	/tmp/local_artifacts/gr_data/rumba
0
