In [1]:
%load_ext autoreload
%autoreload 2

import torch
from text3d2video.artifacts.anim_artifact import AnimationArtifact
from text3d2video.artifacts.texture_artifact import TextureArtifact

# Disable autograd globally for the rest of the kernel session.
# The trailing semicolon stops the notebook from echoing the returned
# `set_grad_enabled` object as this cell's output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7f53cd167cd0>

In [2]:
from text3d2video.pipelines.generative_rendering_pipeline import GenerativeRenderingPipeline
from text3d2video.pipelines.pipeline_utils import load_pipeline

# Build the GenerativeRenderingPipeline through the project's `load_pipeline`
# helper. `pipe` is invoked later in the notebook to produce the output frames.
pipe = load_pipeline(GenerativeRenderingPipeline)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [3]:
# W&B artifact tags for the textures used with this notebook.
# Kept as a lookup table so the active choice is a single explicit line
# rather than the last of several overwritten assignments.
TEXTURE_TAGS = {
    "stormtrooper": "human_mvlatest_Stormtrooper:latest",
    "deadpool": "human_mvlatest_Deadpool:latest",
    "silver_cat_statue": "cat_statue_mvlatest_SilverCatStatue:latest",
}

# W&B artifact tags for the animations used with this notebook.
ANIM_TAGS = {
    "mma_20": "mma_20:latest",
    "mv_cat_statue": "mv_cat_statue:latest",
    "ymca_20": "ymca_20:latest",
    "rumba_20": "rumba_20:latest",
    "catwalk_180_20": "catwalk_180_20:latest",
    "mv_cat_statue_25": "mv_cat_statue_25:latest",
}

# Active selection: change these two keys to switch experiment inputs.
tex_tag = TEXTURE_TAGS["silver_cat_statue"]
anim_tag = ANIM_TAGS["mv_cat_statue_25"]

# Fetch the texture artifact by tag and read its texture.
texture_art = TextureArtifact.from_wandb_artifact_tag(tex_tag)
texture = texture_art.read_texture()

# Fetch the animation artifact by tag and read its animation sequence.
anim_art = AnimationArtifact.from_wandb_artifact_tag(anim_tag)
anim = anim_art.read_anim_seq()

In [4]:
from text3d2video.rendering import render_texture
from text3d2video.utilities.video_comparison import display_vids
from text3d2video.utilities.video_util import pil_frames_to_clip

# Per-frame RGB UV maps of the animation.
anim_uvs = anim.render_rgb_uv_maps()

# Per-frame renders of the animation with the loaded texture, as PIL images.
anim_renders = render_texture(
    anim.meshes,
    anim.cams,
    texture,
    anim.verts_uvs,
    anim.faces_uvs,
    return_pil=True,
)

# Show the textured renders and the UV maps as two clips in one display.
render_clip = pil_frames_to_clip(anim_renders)
uv_clip = pil_frames_to_clip(anim_uvs)
display_vids([render_clip, uv_clip])

In [19]:
from text3d2video.artifacts.anim_artifact import AnimSequence

# Frame indices of `anim` that make up the source sequence.
# Selections tried previously:
#   [0]
#   [0, 15]
#   [-1]
#   ordered_sample_indices(anim_renders, 5).tolist()
# Current choice: every fourth frame from 0 to 16, i.e. [0, 4, 8, 12, 16].
src_indices = list(range(0, 17, 4))

# NOTE: the two cells that follow also assign `src_anim`; on a top-to-bottom
# run the last of them is the one the pipeline receives.
src_cams = anim.cams[src_indices]
src_meshes = anim.meshes[src_indices]
src_anim = AnimSequence(
    cams=src_cams,
    meshes=src_meshes,
    verts_uvs=anim.verts_uvs,
    faces_uvs=anim.faces_uvs,
)

In [8]:
# NOTE: rebinds `src_anim`, replacing the subsampled sequence built in the
# previous cell with the sequence stored in the "mv_cat_statue" artifact.
src_artifact = AnimationArtifact.from_wandb_artifact_tag("mv_cat_statue:latest")
src_anim = src_artifact.read_anim_seq()

In [7]:
from torch import Tensor
from pytorch3d.structures.meshes import join_meshes_as_batch
from pytorch3d.renderer.camera_utils import join_cameras_as_batch

# Build a source sequence from the LAST frame of `anim`: the same mesh viewed
# by copies of the last camera whose translation `T` is shifted step by step.
# NOTE: rebinds `src_anim`; on a top-to-bottom run this is the definition the
# pipeline cell ends up using.
mesh = anim.meshes[-1]
cam = anim.cams[-1]

# One offset per source view, added to the camera translation. Only the second
# component changes (0.0, 0.3, 0.6, 0.9). The tensor is created on the camera's
# own device so this cell does not assume a CUDA camera.
offsets = torch.tensor(
    [[0, 0.0, 0], [0, 0.3, 0], [0, 0.6, 0], [0, 0.9, 0]],
    dtype=torch.float32,
    device=cam.T.device,
)

# Clone the camera once per offset so the original `cam` is left untouched.
shifted_cams = []
for offset in offsets:
    shifted = cam.clone()
    shifted.T = shifted.T + offset
    shifted_cams.append(shifted)

# Batch the cameras, and repeat the mesh so there is one mesh per camera.
cams = join_cameras_as_batch(shifted_cams)
meshes = join_meshes_as_batch([mesh] * len(offsets))

src_anim = AnimSequence(
    cams=cams,
    meshes=meshes,
    verts_uvs=anim.verts_uvs,
    faces_uvs=anim.faces_uvs,
)

In [20]:
from text3d2video.utilities.video_comparison import display_vid

# Preview the source sequence: render its RGB UV maps and play them at 5 fps.
src_uv_frames = src_anim.render_rgb_uv_maps()
src_vid = pil_frames_to_clip(src_uv_frames, fps=5)
display_vid(src_vid, title="src")

In [23]:
from pathlib import Path
from torch import Generator
from text3d2video.pipelines.generative_rendering_pipeline import (
    GenerativeRenderingConfig,
)
from text3d2video.utilities.logging import H5Logger

# Configuration for this generative-rendering run.
gr_config = GenerativeRenderingConfig(
    num_inference_steps=15,
    do_pre_attn_injection=True,
    do_post_attn_injection=False,
    num_keyframes=5,
    attend_to_self_kv=True,
)

# Release unused cached GPU memory held by PyTorch before the run.
torch.cuda.empty_cache()

# Two independent CUDA generators, both seeded with 1, so re-running this cell
# starts from the same random state each time.
generator = Generator(device="cuda").manual_seed(1)
kf_generator = Generator(device="cuda").manual_seed(1)

start_noise_level = 0.25

# Optional conditioning inputs handed to the pipeline.
# NOTE(review): the original cell toggled these by assigning None first;
# presumably `pipe` accepts None for either one - confirm before relying on it.
input_src_anim = src_anim
input_texture = texture

# The H5 logger is constructed with enabled=False. Its stored data is deleted
# and the file is opened for writing before the run.
logger = H5Logger(Path("data.h5"), enabled=False)
logger.delete_data()
logger.open_write()

try:
    # Restrict the "layer" key to a single attention module.
    logger.key_greenlists["layer"] = [
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1"
    ]

    out = pipe(
        "Shiny Cat Statue",
        anim,
        gr_config,
        src_anim=input_src_anim,
        start_noise_level=start_noise_level,
        texture=input_texture,
        logger=logger,
        kf_generator=kf_generator,
        generator=generator,
    )
finally:
    # Always close the logger, even when the pipeline raises, so the file
    # handle is not left open in the kernel.
    logger.close()

100%|██████████| 12/12 [00:17<00:00,  1.49s/it]


In [24]:
# Collect the clips to compare side by side, starting with the generated video.
vids = [pil_frames_to_clip(out.images)]

# The texture-conditioned run gets the noise level in its title and the plain
# textured renders as a reference clip.
show_renders = input_texture is not None

title = "Video" + (f" (from_noise {start_noise_level})" if show_renders else "")
titles = [title]

# Extraction images are optional on the pipeline output.
if out.extr_images is not None:
    extr_vid = pil_frames_to_clip(out.extr_images, fps=10)
    vids += [extr_vid]
    titles += ["Extraction Images"]

# Put the reference renders first so they appear before the generated video.
if show_renders:
    vids.insert(0, pil_frames_to_clip(anim_renders))
    titles.insert(0, "Renders")

display_vids(vids, titles=titles, padding_mode="slow_down")

: 