In [None]:
from misc_utils.train_utils import unit_test_create_model
from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
# Build the diffusion model wrapper from the inference YAML config. The rest of
# the notebook uses its .unet, encode_text, encode_image_to_latent and
# decode_latent_to_image; the checkpoint weights are loaded in the next cell.
config_path = 'configs/instruct_v2v_inference.yaml'
diffusion_model = unit_test_create_model(config_path)

In [None]:
import torch
# NOTE: torch.load is pickle-based (restricted only when weights_only=True, the
# default from PyTorch 2.6 on) — only load checkpoints from a trusted source.
ckpt = torch.load('insv2v.pth', map_location='cpu')
# strict=False: mismatched keys are not an error. The returned
# missing_keys / unexpected_keys are this cell's displayed output — check them
# to make sure the weights actually loaded.
diffusion_model.load_state_dict(ckpt, strict=False)

In [None]:
# edit params
# Edit instruction; encoded with diffusion_model.encode_text and used as text_cond.
# NOTE(review): "alone beach" may be a typo for "along the beach" — confirm before
# changing, since the prompt text directly changes the generated video.
EDIT_PROMPT = 'make the car red Porsche and drive alone beach'
# Passed to the pipeline as img_cfg (presumably the video/image guidance weight).
VIDEO_CFG = 1.2
# Passed to the pipeline as text_cfg (presumably the text guidance weight).
TEXT_CFG = 7.5
# Passed as noise_correct_step to inf_pipe.second_clip_forward for every
# window after the first.
LONG_VID_SAMPLING_CORRECTION_STEP = 0.5

# video params
VIDEO_PATH = 'data/car-turn.mp4'
# Passed to SingleVideoDataset as output_size=(IMGSIZE, IMGSIZE).
IMGSIZE = 384
# Passed to SingleVideoDataset as num_frames (length of the clip to edit).
NUM_FRAMES = 32
# Passed to SingleVideoDataset as sampling_fps.
VIDEO_SAMPLE_RATE = 10

# sampling params
# Window length handed to split_batch: every sampled window spans this many frames.
FRAMES_IN_BATCH = 16
# Nominal number of frames each later window reuses from the previous one
# (split_batch may use a different count for the final, shorter window).
NUM_REF_FRAMES = 4
# True: use InferenceIP2PVideoOpticalFlow and pass reference/query frames to it;
# False: use plain InferenceIP2PVideo.
USE_MOTION_COMPENSATION = True

In [None]:
from pl_trainer.inference.inference import InferenceIP2PVideo, InferenceIP2PVideoOpticalFlow

# Both sampler variants take exactly the same constructor arguments, so choose
# the class once and instantiate it in a single place. The optical-flow variant
# is the one that later receives ref_images / query_images.
pipeline_cls = InferenceIP2PVideoOpticalFlow if USE_MOTION_COMPENSATION else InferenceIP2PVideo
inf_pipe = pipeline_cls(
    unet=diffusion_model.unet,
    num_ddim_steps=20,
    scheduler='ddpm',
)

In [None]:
from dataset.single_video_dataset import SingleVideoDataset

dataset = SingleVideoDataset(
    video_file=VIDEO_PATH,
    video_description='',
    sampling_fps=VIDEO_SAMPLE_RATE,
    num_frames=NUM_FRAMES,
    output_size=(IMGSIZE, IMGSIZE),
)


def _to_gpu_with_batch_dim(value):
    """Move a tensor to the current CUDA device and prepend a batch axis; pass other values through unchanged."""
    if not isinstance(value, torch.Tensor):
        return value
    return value.cuda()[None]


# Sample index 20 ("start from 20th frame"). How the dataset maps an index to a
# start frame is defined in SingleVideoDataset — TODO confirm.
batch = {key: _to_gpu_with_batch_dim(value) for key, value in dataset[20].items()}

In [None]:
def split_batch(cond, frames_in_batch=16, num_ref_frames=4):
    """Split a sequence along dim 1 (frames) into overlapping sampling windows.

    The first returned slice is the first ``frames_in_batch`` frames. Every later
    slice contains only the *new* frames of its window; the caller prepends the
    last ``num_ref_frames_each_batch[i]`` frames of the previous window so each
    window again spans ``frames_in_batch`` frames.

    Args:
        cond: sliceable with ``cond[:, a:b]`` and having ``cond.shape[1]`` frames
            (e.g. a tensor of shape (B, T, ...)).
        frames_in_batch: window length; must be >= 1.
        num_ref_frames: frames a full later window reuses from the previous one;
            must satisfy ``0 <= num_ref_frames < frames_in_batch``.

    Returns:
        (conds, num_ref_frames_each_batch):
        - conds: list of slices of ``cond``; concatenated along dim 1 they
          reproduce ``cond`` exactly.
        - num_ref_frames_each_batch: one entry per element of ``conds[1:]``.
          It is ``num_ref_frames`` for full windows; for a final window with
          fewer than ``frames_in_batch`` frames left it is
          ``frames_in_batch - frames_left``, keeping the window full length.

    Raises:
        ValueError: if the arguments cannot make progress (previously this
            looped forever when ``num_ref_frames >= frames_in_batch``).
    """
    if frames_in_batch < 1:
        raise ValueError(f'frames_in_batch must be >= 1, got {frames_in_batch}')
    if not 0 <= num_ref_frames < frames_in_batch:
        raise ValueError(
            f'num_ref_frames must be in [0, frames_in_batch={frames_in_batch}), got {num_ref_frames}'
        )

    total_frames = cond.shape[1]
    new_frames_per_full_window = frames_in_batch - num_ref_frames
    conds = [cond[:, :frames_in_batch]]
    num_ref_frames_each_batch = []
    frame_ptr = frames_in_batch

    while frame_ptr < total_frames:
        remaining_frames = total_frames - frame_ptr
        # A short tail takes every remaining frame and borrows as many reference
        # frames as needed to keep the window frames_in_batch long.
        if remaining_frames < frames_in_batch:
            new_frames = remaining_frames
        else:
            new_frames = new_frames_per_full_window
        conds.append(cond[:, frame_ptr:frame_ptr + new_frames])
        num_ref_frames_each_batch.append(frames_in_batch - new_frames)
        frame_ptr += new_frames

    return conds, num_ref_frames_each_batch

In [None]:
# VAE-encode the input frames a few at a time to bound peak GPU memory.
# `split` takes a chunk *size*, so lowering VAE_ENCODE_FRAMES_PER_CALL lowers
# memory use on a smaller GPU. (With `.chunk(k)` the argument is the number of
# chunks, and reducing it would make every chunk larger.)
VAE_ENCODE_FRAMES_PER_CALL = 2
# Pure inference: no_grad keeps autograd from recording the encoders (a no-op
# if the model helpers already disable gradients).
with torch.no_grad():
    # NOTE(review): dividing by 0.18215 presumably undoes the SD VAE latent scale
    # applied inside encode_image_to_latent (unscaled image conditioning) — confirm.
    cond = torch.cat(
        [
            diffusion_model.encode_image_to_latent(frames) / 0.18215
            for frames in batch['frames'].split(VAE_ENCODE_FRAMES_PER_CALL, dim=1)
        ],
        dim=1,
    )
    text_cond = diffusion_model.encode_text([EDIT_PROMPT])
    text_uncond = diffusion_model.encode_text([''])

# Cut the conditioning latents and the matching pixel frames into the same
# overlapping windows. Both have the same number of frames, so the second call's
# reference counts would be identical and are discarded.
conds, num_ref_frames_each_batch = split_batch(cond, frames_in_batch=FRAMES_IN_BATCH, num_ref_frames=NUM_REF_FRAMES)
splitted_frames, _ = split_batch(batch['frames'], frames_in_batch=FRAMES_IN_BATCH, num_ref_frames=NUM_REF_FRAMES)

In [None]:
# Fix the RNG so re-running this cell reproduces the same edit
# (torch.manual_seed seeds the CPU and all CUDA generators).
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)

# The window slicing below uses `x[:, -n:]`, which would select *all* frames if
# n were 0, so every window after the first must reuse at least one frame.
assert all(n > 0 for n in num_ref_frames_each_batch), \
    'every window after the first needs at least one reference frame'

# First video clip: denoise the first window from pure noise.
cond1 = conds[0]
latent_pred_list = []
init_latent = torch.randn_like(cond1)
latent_pred = inf_pipe(
    latent = init_latent,
    text_cond = text_cond,
    text_uncond = text_uncond,
    img_cond = cond1,
    text_cfg = TEXT_CFG,
    img_cfg = VIDEO_CFG,
)['latent']
latent_pred_list.append(latent_pred)


# Subsequent video clips: each window is [last num_ref_frames_ frames of the
# previous window] + [new frames]. The reference part reuses the previous
# window's noise, conditioning latents and predicted latents (latent_ref);
# only the new frames of each prediction are kept.
for prev_cond, cond_, prev_frame, curr_frame, num_ref_frames_ in zip(
    conds[:-1], conds[1:], splitted_frames[:-1], splitted_frames[1:], num_ref_frames_each_batch
):
    init_latent = torch.cat([init_latent[:, -num_ref_frames_:], torch.randn_like(cond_)], dim=1)
    cond_ = torch.cat([prev_cond[:, -num_ref_frames_:], cond_], dim=1)
    if USE_MOTION_COMPENSATION:
        # Pixel frames for the optical-flow pipeline: the reference frames taken
        # from the previous window and the new frames of this window.
        ref_images = prev_frame[:, -num_ref_frames_:]
        query_images = curr_frame
        additional_kwargs = {
            'ref_images': ref_images,
            'query_images': query_images,
        }
    else:
        additional_kwargs = {}
    latent_pred = inf_pipe.second_clip_forward(
        latent = init_latent,
        text_cond = text_cond,
        text_uncond = text_uncond,
        img_cond = cond_,
        latent_ref = latent_pred[:, -num_ref_frames_:],
        noise_correct_step = LONG_VID_SAMPLING_CORRECTION_STEP,
        text_cfg = TEXT_CFG,
        img_cfg = VIDEO_CFG,
        **additional_kwargs,
    )['latent']
    latent_pred_list.append(latent_pred[:, num_ref_frames_:])

# Stitch the windows back together and decode to pixels a few frames per VAE
# call (mirroring the chunked encoding) to bound peak GPU memory; lower
# VAE_DECODE_FRAMES_PER_CALL if decoding runs out of memory.
VAE_DECODE_FRAMES_PER_CALL = 2
latent_pred = torch.cat(latent_pred_list, dim=1)
with torch.no_grad():
    image_pred = torch.cat(
        [
            diffusion_model.decode_latent_to_image(latent_chunk)
            for latent_chunk in latent_pred.split(VAE_DECODE_FRAMES_PER_CALL, dim=1)
        ],
        dim=1,
    ).clip(-1, 1)

In [None]:
import os

# Both outputs are written under results/; create it first so a fresh checkout
# does not fail here (no-op if it already exists).
os.makedirs('results', exist_ok=True)

# Side-by-side comparison: input frames next to the edited frames, concatenated
# along dim 4 (presumably the width axis of (B, T, C, H, W) — confirm).
original_images = batch['frames'].cpu()
transferred_images = image_pred.float().cpu()
concat_images = torch.cat([original_images, transferred_images], dim=4)

save_tensor_to_gif(concat_images, 'results/video_edit.gif', fps=5)
save_tensor_to_images(transferred_images, 'results/video_edit_images')

In [None]:
# visualize the gif
from IPython.display import Image
# Render the side-by-side GIF written by the previous cell; as the last
# expression of the cell it is displayed inline.
Image(filename='results/video_edit.gif')