In [40]:
%load_ext autoreload
%autoreload 2

from icecream import ic
import torch
import numpy as np
from diffusers.models.controlnet import ControlNetModel

from pipelines.pipeline_latentman import StableDiffusionLatentMan
from pipelines.scheduling_ddim import DDIMScheduler
from misc.cross_frame_attn import CrossFrameAttnProcessor

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# Build the LatentMan pipeline: a Stable Diffusion v1.5 checkpoint driven by a
# depth ControlNet and a DDIM scheduler, with cross-frame attention installed.

# Prefer CUDA when it is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is requested for both the ControlNet and the pipeline weights.
dtype = torch.float16

scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

controlnet_depth = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1p_sd15_depth",
    torch_dtype=dtype,
)
pipe = StableDiffusionLatentMan.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet_depth,
    torch_dtype=dtype,
    scheduler=scheduler,
)
# pipe.enable_xformers_memory_efficient_attention()
pipe.enable_model_cpu_offload(gpu_id=0)

# Cross-frame attention: the same processor instance is installed on the UNet
# first and then on the ControlNet.
# NOTE(review): unet_chunk_size=2 presumably matches the two halves of a
# classifier-free-guidance batch — confirm against CrossFrameAttnProcessor.
cross_frame_processor = CrossFrameAttnProcessor(unet_chunk_size=2)
for attn_host in (pipe.unet, pipe.controlnet):
    attn_host.set_attn_processor(processor=cross_frame_processor)

Loading pipeline components...: 100%|██████████| 7/7 [00:02<00:00,  2.91it/s]


In [43]:
import matplotlib.pyplot as plt
from external.mdm.utils.fixseed import fixseed
from moviepy.editor import ImageSequenceClip, ipython_display
from torchvision.utils import make_grid
from inference import load_motion, get_dense_correspond, added_prompt, negative_prompt

# Configuration, guidance-data loading, dense-correspondence computation and
# initial-noise sampling for one LatentMan run.

# RNG used for the initial latents and handed to the pipeline later on.
# It is created on `device` instead of a hard-coded "cuda" so that it agrees
# with the `device=device` argument of torch.randn at the end of this cell
# (a CUDA generator cannot be built, or mixed with CPU sampling, on a CPU host).
generator = torch.Generator(device=device)
BATCH_SIZE = 8  # This batch size needs a GPU with 40 GB of memory; otherwise, reduce it.
SKIP = 1  # Step for skipping frames


########### LatentMan Parameters ################
reference_mode = "prev_new"  # prev | first | prev_new
inference_steps = 30
# NOTE(review): this range extends past `inference_steps` (30); presumably the
# extra indices are simply never reached — confirm against the pipeline.
latent_alignment_steps = range(40)
pixel_wise_guidance_steps = range(0)  # empty range -> falsy in the checks below
rho = 0.01
# The 256 resolution is added only when pixel-wise guidance steps are configured.
correspondence_res = [64, 256] if pixel_wise_guidance_steps else [64]


########### Load Guidance Data and Define Prompt ################
motion_prompt = "valz_dance_5"  # <motion_prompt>_<seed>
prompt = "An Stormtrooper dances valz"
seed = 45
start_from = 0  # Which frame to start from
shift_x = 0  # Shift the depth maps horizontally
shift_y = 40  # Shift the depth maps vertically

dps, guidance_imgs, batches, motion_dir, out_dir = load_motion(
    motion=motion_prompt,
    prompt=prompt,
    seed=seed,
    start_from=start_from,
    shift_x=shift_x,
    shift_y=shift_y,
    batch_size=BATCH_SIZE,
    skip=SKIP,
)


########### Compute Dense Correspondences ################
# Both dicts are keyed by correspondence resolution.
xy_xy_dict = {}
new_dps_dict = {}
for new_res in correspondence_res:
    new_dps, xy_xy = get_dense_correspond(dps, new_res, reference_mode, motion_dir, num_dps=BATCH_SIZE)
    new_dps_dict[new_res] = new_dps
    xy_xy_dict[new_res] = xy_xy


########### Generate ################
prompts = prompt + ", " + added_prompt

# Seed the global RNGs via fixseed and the explicit generator with the same value.
fixseed(seed)
generator.manual_seed(seed)

# Run identifier used in output file names, e.g.
# "latentman_L[0,39]_X[]_r0.01_prev_new_B8". An empty step range yields "[]".
if latent_alignment_steps:
    latent_tag = f"L[{min(latent_alignment_steps)},{max(latent_alignment_steps)}]"
else:
    latent_tag = "L[]"
if pixel_wise_guidance_steps:
    pixel_tag = f"_X[{min(pixel_wise_guidance_steps)},{max(pixel_wise_guidance_steps)}]"
else:
    pixel_tag = "_X[]"
ident = f"latentman_{latent_tag}{pixel_tag}_r{rho}_{reference_mode}_B{BATCH_SIZE}"
print(out_dir, ident)


# Single initial-noise tensor of shape [1, 4, 64, 64], cast to the pipeline dtype.
start_code = torch.randn([1, 4, 64, 64], generator=generator, device=device).to(dtype)

==> Motion: valz_dance_5, 	 start_from: 0, 	 shift_x: 0, 	 shift_y: 40
60 60
Using `prev_new` reference: [0, 1, 2, 3, 4, 5, 6]


100%|██████████| 7/7 [00:00<00:00, 48.78it/s]

8 8
/home/jorge/thesis/data/workspace/valz_dance_5/An_Stormtrooper_dances_valz_s45_x0_y40_f0 latentman_L[0,39]_X[]_r0.01_prev_new_B8





In [44]:
# Synthesize the frames: run the pipeline on the guidance batches and keep the
# resulting images as numpy arrays in `out`.
out = []
for i, batch in enumerate(batches):
    num_frames = len(batch)

    # Every frame in the batch starts from the same initial noise tensor.
    latents = start_code.repeat(num_frames, 1, 1, 1)

    # Keyword arguments for this pipeline call.
    pipe_kwargs = dict(
        latents=latents,
        num_inference_steps=inference_steps,
        generator=generator,
        negative_prompt=[negative_prompt] * num_frames,
        controlnet_conditioning_scale=1.0,
        guidance_scale=9.0,
        sla_steps=latent_alignment_steps,
        pwg_steps=pixel_wise_guidance_steps,
        rho=rho,
        dps_low=xy_xy_dict,
        reference_mode=reference_mode,
    )

    # The guidance batch is divided by 255 before being passed as conditioning
    # (presumably mapping 8-bit values to [0, 1] — confirm against load_motion).
    img_new = pipe(
        [prompts] * num_frames,
        [batch / 255],
        **pipe_kwargs,
    ).images

    out = list(map(np.asarray, img_new))

    # Only the first batch is supported for now; this could be extended to
    # further batches in an autoregressive manner.
    if i == 0:
        break

100%|██████████| 30/30 [00:11<00:00,  2.65it/s]


In [45]:
# Write the generated frames to an .mp4, save a collage of the same frames as a
# .png next to it, and display the clip inline.
import os  # local to this cell: only needed for the path handling below

print(len(out))

out_path = f'outs/video_{ident}.mp4'
# Create the output folder up front so a fresh checkout does not fail on write.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

vid = ImageSequenceClip(out, fps=10)
vid.write_videofile(out_path)

# `out` already holds the frames as HxWxC arrays; make_grid expects CxHxW
# tensors, so transpose going in and permute back to HxWxC for saving.
img_collage = make_grid(
    [torch.as_tensor(frame.transpose(2, 0, 1)) for frame in out],
    padding=10,
    pad_value=255,
).permute(1, 2, 0).numpy()

# Swap only the file extension; str.replace("mp4", "png") would also rewrite
# any other "mp4" substring occurring in the path.
collage_path = os.path.splitext(out_path)[0] + ".png"
plt.imsave(collage_path, img_collage)
print(f"Saved to {out_path}")
ipython_display(vid)

8
Moviepy - Building video outs/video_latentman_L[0,39]_X[]_r0.01_prev_new_B8.mp4.
Moviepy - Writing video outs/video_latentman_L[0,39]_X[]_r0.01_prev_new_B8.mp4



                                                  

Moviepy - Done !
Moviepy - video ready outs/video_latentman_L[0,39]_X[]_r0.01_prev_new_B8.mp4
Saved to outs/video_latentman_L[0,39]_X[]_r0.01_prev_new_B8.mp4
Moviepy - Building video __temp__.mp4.
Moviepy - Writing video __temp__.mp4



                                                  

Moviepy - Done !
Moviepy - video ready __temp__.mp4
