In [None]:
# Auto-reload edited project modules so changes under text3d2video are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import torch
from text3d2video.artifacts.anim_artifact import AnimationArtifact
from text3d2video.artifacts.texture_artifact import TextureArtifact

# Disable autograd globally for the whole kernel session. No cell shown here
# calls backward, so this presumably just avoids building graphs and saves GPU
# memory during the pipeline run.
torch.set_grad_enabled(False)

In [43]:
from text3d2video.pipelines.generative_rendering_pipeline import GenerativeRenderingPipeline
from text3d2video.pipelines.pipeline_utils import load_pipeline

# Load the generative-rendering pipeline once; the generation cell below reuses `pipe`.
pipe = load_pipeline(GenerativeRenderingPipeline)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [44]:
# Texture artifact to load. Previously tried alternatives:
#   "human_mvlatest_Stormtrooper:latest", "cat_statue_mvlatest_MetalicCatStatue:latest"
tex_tag = "human_mvlatest_Deadpool:latest"

texture_art = TextureArtifact.from_wandb_artifact_tag(tex_tag)
texture = texture_art.read_texture()

# Target animation artifact. Previously tried alternatives (all ":latest"):
#   mma_20, mv_cat_statue, ymca_20, rumba_20, mv_cat_statue_25,
#   catwalk_180_20, mv_helmet
anim_tag = "mv_helmet_25:latest"

anim_art = AnimationArtifact.from_wandb_artifact_tag(anim_tag)
anim = anim_art.read_anim_seq()

In [45]:
from text3d2video.artifacts.anim_artifact import AnimSequence
from text3d2video.util import ordered_sample_indices

# Source views: 5 camera indices picked from the target animation by
# `ordered_sample_indices` (presumably spread across the sequence, in order —
# confirm in text3d2video.util). Fixed choices tried before: [0], [0, 15], [-1].
src_indices = ordered_sample_indices(anim.cams, 5).tolist()

# Source animation = the selected frames of `anim`, sharing its UV layout.
src_anim = AnimSequence(
    cams=anim.cams[src_indices],
    meshes=anim.meshes[src_indices],
    verts_uvs=anim.verts_uvs,
    faces_uvs=anim.faces_uvs,
)

In [32]:
# Alternative source: a standalone animation artifact. Running this cell
# replaces any `src_anim` built by an earlier cell.
src_anim_art = AnimationArtifact.from_wandb_artifact_tag("mv_cat_statue:latest")
src_anim = src_anim_art.read_anim_seq()

In [33]:
from torch import Tensor
from pytorch3d.structures.meshes import join_meshes_as_batch
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

# Alternative source: the last frame of `anim` seen from several cameras whose
# translation T is offset along its second (y) component. Running this cell
# replaces any `src_anim` built by an earlier cell.
mesh = anim.meshes[-1]
cam = anim.cams[-1]

# Create the offsets on the camera's own device/dtype so `cam.T + offset`
# cannot hit a device mismatch (previously hard-coded to `.cuda()`).
offsets = torch.tensor(
    [[0, 0.0, 0], [0, 0.3, 0], [0, 0.6, 0], [0, 0.9, 0]],
    device=cam.T.device,
    dtype=cam.T.dtype,
)

shifted_cams = []
for offset in offsets:
    shifted = cam.clone()
    shifted.T = shifted.T + offset
    shifted_cams.append(shifted)

# One camera per offset, paired with the same mesh repeated.
cams = join_cameras_as_batch(shifted_cams)
meshes = join_meshes_as_batch([mesh] * len(offsets))

src_anim = AnimSequence(
    cams=cams,
    meshes=meshes,
    verts_uvs=anim.verts_uvs,
    faces_uvs=anim.faces_uvs,
)

In [68]:
from text3d2video.rendering import render_texture
from text3d2video.util import chw_to_hwc
from text3d2video.utilities.testing_utils import checkerboard_img
from text3d2video.utilities.video_comparison import display_vid, video_grid
from text3d2video.utilities.video_util import pil_frames_to_clip

prompt = "Shiny Mandalorian Helmet"

# Source-view conditioning; set to None to run without source views.
input_src_anim = src_anim

# Input texture; set to `texture` (loaded above) to use it. With None, the
# preview below renders a checkerboard placeholder instead.
input_texture = None

# NOTE(review): interpreted by the pipeline (presumably the noise level at which
# source/texture-guided denoising starts) — confirm; the preview only uses it as a label.
start_noise_level = 0.20


def _render_clip(seq, texture):
    """Render every frame of `seq` with `texture` and wrap the PIL frames as a clip."""
    frames = render_texture(
        seq.meshes,
        seq.cams,
        texture,
        seq.verts_uvs,
        seq.faces_uvs,
        return_pil=True,
    )
    return pil_frames_to_clip(frames)


def inputs_video(prompt, input_src_anim, input_texture, start_noise_level, target_anim=None):
    """Build a one-row, labelled preview grid of the pipeline inputs.

    Renders the source animation (if given) and the target animation with the
    input texture, or with a white/gray checkerboard placeholder when no texture
    is given, then lays the clips out side by side with `video_grid`.

    Args:
        prompt: Text prompt; used only as the row label.
        input_src_anim: Optional source AnimSequence; its column is omitted when None.
        input_texture: Texture passed to `render_texture`; when None a 500px
            checkerboard (converted with `chw_to_hwc`, moved to CUDA) is used.
        start_noise_level: Used only in the source column's label.
        target_anim: Animation rendered in the "anim" column. Defaults to the
            notebook-global `anim`, which was the only behaviour before.

    Returns:
        Whatever `video_grid` returns for the assembled clips.
    """
    if target_anim is None:
        target_anim = anim

    if input_texture is None:
        # 255 is the maximum 8-bit channel value (256 is out of range).
        white = (255, 255, 255)
        gray = (200, 200, 200)
        texture = checkerboard_img(
            return_type="pt", color1=white, color2=gray, res=500, square_size=30
        ).cuda()
        texture = chw_to_hwc(texture)
    else:
        texture = input_texture

    vids = []
    titles = []
    if input_src_anim is not None:
        vids.append(_render_clip(input_src_anim, texture))
        titles.append(f"src({start_noise_level})")

    vids.append(_render_clip(target_anim, texture))
    titles.append("anim")

    return video_grid(
        [vids], x_labels=titles, padding_mode="slow_down", y_labels=[prompt]
    )


# Preview what the pipeline will receive before running it.
display_vid(
    inputs_video(
        prompt=prompt,
        input_src_anim=input_src_anim,
        input_texture=input_texture,
        start_noise_level=start_noise_level,
    )
)

In [75]:
from pathlib import Path
from torch import Generator
from text3d2video.pipelines.generative_rendering_pipeline import (
    GenerativeRenderingConfig,
)
from text3d2video.utilities.logging import H5Logger

# Generative-rendering settings for this run.
gr_config = GenerativeRenderingConfig(
    num_inference_steps=15,
    do_pre_attn_injection=True,
    do_post_attn_injection=True,
    num_keyframes=5,
    attend_to_self_kv=True,
    guidance_scale=7.5,
)

# Release unoccupied cached CUDA memory left over from earlier cells.
torch.cuda.empty_cache()

# Two independently seeded CUDA generators; `kf_generator` is presumably used
# for keyframe sampling inside the pipeline — confirm.
generator = Generator(device="cuda").manual_seed(0)
kf_generator = Generator(device="cuda").manual_seed(1)

# Start from an empty data.h5 log and open it for writing.
logger = H5Logger(Path("data.h5"), enabled=True)
logger.delete_data()
logger.open_write()

# NOTE(review): presumably restricts logged "layer" keys to this attention layer — confirm in H5Logger.
logger.key_greenlists["layer"] = ["up_blocks.3.attentions.2.transformer_blocks.0.attn1"]

try:
    out = pipe(
        prompt,
        anim,
        gr_config,
        src_anim=input_src_anim,
        start_noise_level=start_noise_level,
        texture=input_texture,
        logger=logger,
        kf_generator=kf_generator,
        generator=generator,
    )
finally:
    # Close the write handle even if the pipeline raises, so data.h5 is not left open.
    logger.close()

logger.open_read()

100%|██████████| 15/15 [01:10<00:00,  4.68s/it]


In [76]:
from text3d2video.utilities.video_comparison import display_vids

# Side by side: `out.extr_images` (presumably the extraction/keyframe pass)
# next to the final frames in `out.images`.
vid_extraction = pil_frames_to_clip(out.extr_images)
vid = pil_frames_to_clip(out.images)

display_vids(
    [vid_extraction, vid],
    titles=["extr", "anim"],
    padding_mode="slow_down",
)

In [67]:
from text3d2video.backprojection import (
    compute_texel_projections,
    project_views_to_video_texture,
)
import torchvision.transforms.functional as TF


# Texture-space resolution shared by the texel projections and the UV video.
uv_res = 500

projections = compute_texel_projections(
    anim.meshes, anim.cams, anim.verts_uvs, anim.faces_uvs, uv_res
)
frames_pt = torch.stack([TF.to_tensor(frame) for frame in out.images]).cuda()
uv_vid = project_views_to_video_texture(frames_pt, uv_res, projections)

# to_pil_image scales float tensors by 255 and casts to uint8, so clamp first to
# avoid wrap-around. Assumes uv_vid is float in [0, 1] like frames_pt — confirm.
uv_frames_pil = [TF.to_pil_image(frame.clamp(0, 1)) for frame in uv_vid]
uv_vid_pil = pil_frames_to_clip(uv_frames_pil)
display_vid(uv_vid_pil)