# Load the diffusion model, SAM, Mediapipe

In [1]:
import torch
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os.path as osp
import skimage.io as io
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from models import vqvae 
from models import vit
from diffusion import create_diffusion
from utils.utils import (
    scale_keypoint,  
    keypoint_heatmap, 
    check_keypoints_validity)
from utils.segment_hoi import init_sam, show_mask
import pickle
        

def remove_prefix(text, prefix):
    """Return ``text`` without a leading ``prefix``.

    If ``text`` does not begin with ``prefix`` it is returned unchanged.
    """
    has_prefix = text.startswith(prefix)
    return text[len(prefix):] if has_prefix else text


def unnormalize(x):
    """Convert an array from the [-1, 1] range to uint8 pixel values.

    -1 maps to 0 and 1 maps to 255; fractional results are truncated by the
    cast to ``np.uint8``.
    """
    unit_range = (x + 1) / 2
    pixels = unit_range * 255
    return pixels.astype(np.uint8)


def visualize_hand(ax, all_joints, img):
    """Overlay hand skeletons on ``img`` inside a matplotlib axes.

    Args:
        ax: Axes-like object to draw into (``imshow``, ``plot``, ``set_xlim``,
            ``set_ylim``, ``grid``, ``set_axis_off`` and ``invert_yaxis`` are
            called on it).
        all_joints: Sequence of joints, 21 consecutive entries per hand, each
            joint indexable as ``joint[0]`` (x) and ``joint[1]`` (y). Up to
            two hands (entries 0-20 and 21-41) are drawn; a hand with fewer
            than 21 entries is skipped instead of raising ``IndexError``.
        img: Image array of shape (H, W) or (H, W, C).
    """
    joints_per_hand = 21
    # Bones of one hand as (joint index pair, line colour).
    connections = [
        ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
        ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
        ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'indigo'),
        ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'navy'), ((15, 16), 'gray'),
        ((0, 17), 'lavender'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
    ]
    # Only the spatial extent is needed, so grayscale (2-D) images work too.
    H, W = img.shape[:2]

    ax.imshow(img)
    # Draw each bone as a coloured line segment.
    for start_i in (0, joints_per_hand):
        joints = all_joints[start_i: start_i + joints_per_hand]
        if len(joints) < joints_per_hand:
            # Input holds a single hand (or a truncated one): nothing to draw.
            continue
        for (first, second), color in connections:
            joint1 = joints[first]
            joint2 = joints[second]
            ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)

    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    # set_ylim above puts y=0 at the bottom; flip so the origin is top-left
    # like the image.
    ax.invert_yaxis()
    

@dataclass
class HandDiffOpts:
    """Hyper-parameters and paths for the hand diffusion model.

    Only a subset is read in this notebook (image/latent sizes, channel
    counts, ``test_sampling_steps`` and ``latent_scaling_factor``); the rest
    appear to be training-time settings kept for completeness.
    """

    # --- experiment name and filesystem locations ---
    run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
    sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
    log_dir: str = "/users/kchen157/scratch/log"
    data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"

    # --- image, latent and conditioning layout ---
    image_size: tuple = (256, 256)
    latent_size: tuple = (32, 32)
    latent_dim: int = 4
    mask_bg: bool = False
    kpts_form: str = "heatmap"
    n_keypoints: int = 42  # two hands x 21 joints
    n_mask: int = 1

    # --- diffusion process and sampling ---
    noise_steps: int = 1000
    test_sampling_steps: int = 250
    ddim_steps: int = 100
    ddim_discretize: str = "uniform"
    ddim_eta: float = 0.0
    beta_start: float = 8.5e-4
    beta_end: float = 0.012
    latent_scaling_factor: float = 0.18215
    cfg_pose: float = 5.0
    cfg_appearance: float = 3.5

    # --- optimisation and trainer settings ---
    batch_size: int = 25
    lr: float = 1e-5
    max_epochs: int = 500
    log_every_n_steps: int = 100
    limit_val_batches: int = 1
    n_gpu: int = 8
    num_nodes: int = 1
    precision: str = "16-mixed"
    profiler: str = "simple"
    swa_epoch_start: int = 10
    swa_lrs: float = 1e-3
    num_workers: int = 10
    n_val_samples: int = 4
        

# Instantiate the default options and resolve the checkpoint locations.
opts = HandDiffOpts()
model_weights_dir = '../weights'
model_path = osp.join(model_weights_dir, 'DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt')
vae_path = osp.join(model_weights_dir, 'vae-ft-mse-840000-ema-pruned.ckpt')
sam_path = osp.join(model_weights_dir, 'sam_vit_h_4b8939.pth')

print('Load diffusion model...')
diffusion = create_diffusion(str(opts.test_sampling_steps))
# Input channels = noisy latent + one heatmap per keypoint + hand mask.
model = vit.DiT_XL_2(
    input_size=opts.latent_size[0],
    latent_dim=opts.latent_dim,
    in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
    learn_sigma=True,
).cuda()

# Deserialize on the CPU: the checkpoint may have been saved from a GPU index
# that does not exist on this machine, and load_state_dict copies the tensors
# onto the model's device anyway.
# The EMA weights are used here (the checkpoint presumably also stores the
# raw training weights under 'model_state_dict').
ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
# strict=False tolerates checkpoint-only entries (e.g. the 'projector.*' keys
# printed below); missing model weights are still treated as fatal.
missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
model.eval()
print(missing_keys, extra_keys)
if missing_keys:
    raise RuntimeError(f'Diffusion checkpoint is missing weights: {missing_keys}')


# Frozen autoencoder used to move between image space and latent space.
vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
if missing_keys:
    raise RuntimeError(f'Autoencoder checkpoint is missing weights: {missing_keys}')


print('Mediapipe hand detector and SAM ready...')
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(
    static_image_mode=True,  # Use False if image is part of a video stream
    max_num_hands=2,         # Maximum number of hands to detect
    min_detection_confidence=0.1)
sam_predictor = init_sam(ckpt_path=sam_path)

Load diffusion model...
[] ['projector.0.weight', 'projector.0.bias', 'projector.2.weight', 'projector.2.bias', 'projector.4.weight', 'projector.4.bias']
Mediapipe hand detector and SAM ready...


I0000 00:00:1750404705.115486 3524965 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1750404705.144195 3702037 gl_context.cc:357] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 535.129.03), renderer: NVIDIA RTX A6000/PCIe/SSE2
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1750404705.194073 3701596 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1750404705.205530 3701630 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


In [None]:
import pickle

def draw_keypoint_trajectories(image_path, keypoints_list, image_size=(256, 256)):
    """Draw the path of every keypoint over time on top of an image.

    Args:
        image_path: Path of the background image, read with ``cv2.imread``.
        keypoints_list: Sequence of per-frame keypoint arrays; entry ``t`` is
            indexable as ``keypoints_list[t][kp_idx]`` and yields an (x, y)
            pair. Coordinates are assumed to be expressed in ``image_size``
            pixels, as in the rest of this notebook.
        image_size: ``(width, height)`` the background is resized to before
            drawing so that it matches the keypoint coordinate frame. Pass
            ``None`` to keep the image at its native resolution.

    Returns:
        The BGR image with one coloured polyline per keypoint.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'Could not read image: {image_path}')

    if image_size is not None:
        target_size = tuple(image_size)
        if (image.shape[1], image.shape[0]) != target_size:
            image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

    if len(keypoints_list) == 0:
        return image

    # Number of keypoints, taken from the first time step.
    num_keypoints = len(keypoints_list[0])

    # One distinct hue per keypoint.
    cmap = plt.get_cmap('hsv')

    for kp_idx in range(num_keypoints):
        # Matplotlib returns RGBA floats in [0, 1]; OpenCV wants BGR ints.
        rgb = cmap(kp_idx / num_keypoints)[:3]
        color = tuple(int(255 * c) for c in reversed(rgb))
        # Connect each position to the one in the following frame.
        for prev_kpts, curr_kpts in zip(keypoints_list[:-1], keypoints_list[1:]):
            pt1 = tuple(np.asarray(prev_kpts[kp_idx]).astype(int).tolist())
            pt2 = tuple(np.asarray(curr_kpts[kp_idx]).astype(int).tolist())
            cv2.line(image, pt1, pt2, color, 1)  # 1-pixel-wide line

    return image


# Which recording to use and which slice of it to animate.
data_root = '../test_data/iphone_video'
idx = 'IMG_1173'
start_frame = 6
max_frames = 100
right_hand_only = False

image_idx = path_idx = idx
image_file = osp.join(data_root, idx, f'{start_frame:04d}.jpg')
path_file = osp.join(data_root, f'{idx}.pkl')

# Per-frame keypoints for the whole recording.
with open(path_file, 'rb') as f:
    data = pickle.load(f)

# Preview the motion: trajectories of the selected window drawn on the
# first frame of that window.
selected_keypoints = data[start_frame:start_frame+max_frames]
print(len(selected_keypoints))
image_with_trajectories = draw_keypoint_trajectories(image_file, selected_keypoints)

# OpenCV images are BGR; matplotlib expects RGB.
rgb_preview = cv2.cvtColor(image_with_trajectories, cv2.COLOR_BGR2RGB)
plt.imshow(rgb_preview)
plt.grid(False)
plt.axis('off')
plt.show()

# Get the target pose sequence

In [None]:
with open(path_file, 'rb') as f:
    keypts_sequence = pickle.load(f)

# One conditioning tensor per target frame: keypoint heatmaps at latent
# resolution followed by an all-zero mask channel.
target_conds = []
for keypts in keypts_sequence[start_frame:start_frame+max_frames]:
    kpts_valid = check_keypoints_validity(keypts, opts.image_size)
    if right_hand_only:
        # Entries 21 onwards belong to the second hand; blank them out.
        kpts_valid[21:] *= 0

    latent_kpts = scale_keypoint(keypts, opts.image_size, opts.latent_size)
    heatmaps_np = keypoint_heatmap(latent_kpts, opts.latent_size, var=1.)
    # Zero the heatmaps of keypoints flagged as invalid.
    heatmaps_np = heatmaps_np * kpts_valid[:, None, None]
    target_heatmaps = torch.tensor(heatmaps_np, dtype=torch.float, device='cuda')[None, ...]

    # The target has no mask; fill that channel with zeros (same dtype/device).
    empty_mask = torch.zeros((1, 1, opts.latent_size[0], opts.latent_size[1])).to(target_heatmaps)
    target_cond = torch.cat([target_heatmaps, empty_mask], 1)
    target_conds.append(target_cond)


# Load the reference image and get annotations.

In [None]:
def make_ref_cond(img, keypts, hand_mask, device='cuda', target_size=(256, 256), latent_size=(32, 32)):
    """Build the reference conditioning tensors for one annotated image.

    Args:
        img: Reference image accepted by torchvision's ``ToTensor``.
        keypts: Keypoints of the reference image, in ``target_size`` pixels.
        hand_mask: 2-D array marking the hand region (cast to int and resized
            to ``latent_size`` with nearest-neighbour interpolation).
        device: Device the returned tensors are placed on.
        target_size: Size the image is resized to.
        latent_size: Spatial size of the heatmaps and of the mask.

    Returns:
        Tuple ``(image, heatmaps, mask)``: the image normalised with mean and
        std 0.5 with a leading batch axis, the keypoint heatmaps with a
        leading batch axis, and the resized mask with two leading axes.
    """
    preprocess = Compose([
        ToTensor(),
        Resize(target_size),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    image = preprocess(img).to(device)

    # Heatmaps at latent resolution, zeroed for keypoints flagged invalid.
    kpts_valid = check_keypoints_validity(keypts, target_size)
    latent_kpts = scale_keypoint(keypts, target_size, latent_size)
    heatmaps_np = keypoint_heatmap(latent_kpts, latent_size, var=1.) * kpts_valid[:, None, None]
    heatmaps = torch.tensor(heatmaps_np, dtype=torch.float, device=device)[None, ...]

    # Nearest-neighbour keeps the mask binary after downscaling.
    small_mask = cv2.resize(hand_mask.astype(int), dsize=latent_size, interpolation=cv2.INTER_NEAREST)
    mask = torch.tensor(small_mask, dtype=torch.float, device=device)[None, None, ...]

    return image[None, ...], heatmaps, mask

# Build the reference conditions. The first bootstrap frame becomes the
# source reference (src_ref_cond); any further frames go into ref_conds.
ref_conds = []
bootstrap_frames = [start_frame]
for k, frame in enumerate(bootstrap_frames):
    image_file = osp.join(data_root, idx, f'{frame:04d}.jpg')
    flip_image = False
    img = io.imread(image_file)
    print(img.shape)
    if flip_image:
        img = np.fliplr(img)
    img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
    keypts = keypts_sequence[frame]

    # Prompt SAM with the wrist (joint 0) of every hand that was detected;
    # an undetected hand is assumed to have all-zero keypoints.
    sam_predictor.set_image(img)
    if keypts[0].sum() != 0 and keypts[21].sum() != 0:
        input_point = np.array([keypts[0], keypts[21]])
        input_label = np.array([1, 1])
    elif keypts[0].sum() != 0:
        input_point = np.array(keypts[:1])
        input_label = np.array([1])
    elif keypts[21].sum() != 0:
        input_point = np.array(keypts[21:22])
        input_label = np.array([1])
    else:
        # Without this branch the prompt would be undefined on the first
        # frame, or silently reused from the previous frame afterwards.
        raise ValueError(f'No hand keypoints available for reference frame {frame}')
    masks, _, _ = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    hand_mask = masks[0]
    # White background outside the predicted hand mask, for display only.
    masked_img = img * hand_mask[..., None] + 255*(1 - hand_mask[..., None])
    fig, axs = plt.subplots(1, 2, figsize=(3*2, 3))
    visualize_hand(axs[0], keypts, img)
    visualize_hand(axs[1], keypts, masked_img)
    plt.tight_layout()
    plt.show()


    image, heatmaps, mask = make_ref_cond(
        img, keypts, hand_mask, device='cuda', target_size=opts.image_size, latent_size=opts.latent_size)
    # Encode the reference image and scale the latent.
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    if k == 0:
        ref_image = img
        ref_keypts = keypts
        src_ref_cond = torch.cat([latent, heatmaps, mask], 1)
    else:
        ref_conds.append(torch.cat([latent, heatmaps, mask], 1))

print('ref_conds:', len(ref_conds))

# Sample from diffusion model

In [None]:
from tqdm import tqdm

def frames_to_video(frames, output_path, fps=30, resize=None, rgb2bgr=False):
    """
    Convert a list of frames (numpy arrays) to a video using OpenCV.

    Args:
    - frames (list or numpy.ndarray): Frames of shape (H, W) or (H, W, 3).
      2-D (grayscale) frames are replicated to three channels.
    - output_path (str): Path where the video will be saved. The codec is
      chosen from the extension: 'mp4v' for .mp4/.m4v/.mov, 'XVID' otherwise.
    - fps (int): Frames per second.
    - resize (tuple or None): Resize frames to (width, height) if not None.
    - rgb2bgr (bool): Convert frames from RGB to the BGR order OpenCV
      expects before writing.

    Raises:
    - ValueError: If no frames are provided.
    - IOError: If the video writer cannot be opened for ``output_path``.
    """
    if len(frames) == 0:
        raise ValueError("No frames provided")

    height, width = frames[0].shape[:2]

    # If a resize shape is provided, adjust width and height
    if resize is not None:
        width, height = resize

    # XVID is not a valid tag for MP4-family containers, so match the codec
    # to the container instead of relying on the backend's fallback.
    extension = osp.splitext(output_path)[1].lower()
    codec = 'mp4v' if extension in ('.mp4', '.m4v', '.mov') else 'XVID'
    fourcc = cv2.VideoWriter_fourcc(*codec)
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    if not out.isOpened():
        # E.g. missing directory or unavailable codec; otherwise every
        # write below would be dropped silently.
        out.release()
        raise IOError(f"Could not open video writer for {output_path}")

    try:
        for frame in tqdm(frames):
            # Check the current frame, not just the first one.
            if frame.ndim == 2:
                frame = frame[..., None].repeat(3, axis=-1)
            if resize is not None:
                frame = cv2.resize(frame, (width, height))
            if rgb2bgr:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            out.write(frame)
    finally:
        # Always finalize the file, even if a frame fails to convert.
        out.release()


import os

cfg_scale = 2.5
last_N_frames = 1
save_dir = './video_results'
# The per-frame images and the final video are written here; create the
# directory up front so saving does not fail on a fresh checkout.
os.makedirs(save_dir, exist_ok=True)
# novel view synthesis mode = off
nvs = torch.zeros(1, dtype=torch.int, device='cuda')
# The same initial noise is reused for every frame. It is duplicated along
# the batch axis because forward_with_cfg receives a paired batch (the second
# half gets zeroed conditions below, presumably the unconditional branch).
z = torch.randn((1, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device='cuda')
z = torch.cat([z, z], 0)

temp_ref_conds = []
video_frames = []
for k, target_cond in enumerate(target_conds):
    print(f'{k}/{min(max_frames, len(target_conds))}')
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
        ref_cond=torch.cat([src_ref_cond, torch.zeros_like(src_ref_cond)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg_scale)

    # References: the source frame, any extra bootstrap frames, and the most
    # recently generated frame(s).
    samples, _ = diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False,
        model_kwargs=model_kwargs, ref_conds=[src_ref_cond] + ref_conds + temp_ref_conds, progress=True, device='cuda'
    ).chunk(2)

    # Decode the latent back to an RGB uint8 image.
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1., max=1.)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
    sampled_image = sampled_images[0]
    video_frames.append(sampled_image)

    # Segment the generated hand, prompting SAM with the first wrist keypoint.
    sam_predictor.set_image(sampled_image)
    masks, _, _ = sam_predictor.predict(
        point_coords=np.array([keypts_sequence[start_frame+k][0]]),
        point_labels=np.array([1]),
        multimask_output=False,
    )
    hand_mask = masks[0]
    mask = torch.tensor(
            cv2.resize(masks[0].astype(int), dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST),
            dtype=torch.float, device='cuda').unsqueeze(0)[None, ...]
    # NOTE(review): ref_cond (which carries the SAM mask) is built but never
    # used; the entry appended below uses target_cond, whose mask channel is
    # all zeros. Confirm which of the two is intended.
    ref_cond = torch.cat([samples, target_cond[:, :-1], mask], 1)
    # Keep only the last `last_N_frames` generated frames as references.
    if len(temp_ref_conds) >= last_N_frames and len(temp_ref_conds) > 0:
        temp_ref_conds.pop(0)
    temp_ref_conds.append(torch.cat([samples, target_cond], 1))

    #visualize
    fig, axs = plt.subplots(1, 3, figsize=(6*3, 6))
    for i, vis_img in enumerate([ref_image, sampled_image]):
        axs[i].imshow(vis_img)
        axs[i].axis('off')
        axs[i].grid(False)
    visualize_hand(axs[2], keypts_sequence[start_frame+k], sampled_image)
    # Overlay the summed target heatmaps on the third panel.
    axs[2].imshow(cv2.resize(
        target_cond.cpu().numpy()[0, :42].sum(0), opts.image_size, interpolation=cv2.INTER_AREA), cmap='hot', alpha=0.5)
    plt.tight_layout()
    plt.title(f'{k}/{len(target_conds)}')
    plt.show()
    io.imsave(osp.join(save_dir, f'{idx}_{k}.jpg'), sampled_image)


frames_to_video(video_frames, osp.join(save_dir, f'{idx}.mp4'), fps=20, rgb2bgr=True)