# Load the diffusion model, SAM, and MediaPipe

In [1]:
import torch
import random
from dataclasses import dataclass
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import os.path as osp
import skimage.io as io
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from models import vqvae 
from models import vit
from diffusion import create_diffusion
from utils.utils import (
    scale_keypoint,  
    keypoint_heatmap, 
    check_keypoints_validity)
from utils.segment_hoi import init_sam, show_mask
from tqdm import tqdm
from scipy.spatial import ConvexHull


def frames_to_video(frames, output_path, fps=30, resize=None, rgb2bgr=False):
    """
    Convert a sequence of frames (numpy arrays) to a video using OpenCV.

    Args:
    - frames (list or numpy.ndarray): Frames, each H x W (grayscale) or H x W x C.
    - output_path (str): Path where the video will be saved.
    - fps (int): Frames per second.
    - resize (tuple or None): Resize frames to (width, height) if not None.
    - rgb2bgr (bool): If True, convert every frame from RGB to BGR channel
      order right before it is written.

    Raises:
    - ValueError: If `frames` is empty.
    """
    if len(frames) == 0:
        raise ValueError("No frames provided")

    # Output size defaults to the first frame's size unless overridden.
    height, width = frames[0].shape[:2]
    if resize is not None:
        width, height = resize

    # XVID is an .avi codec tag; request the MPEG-4 tag when the caller asks
    # for an .mp4 container so the tag and the container agree.
    codec = 'mp4v' if str(output_path).lower().endswith('.mp4') else 'XVID'
    fourcc = cv2.VideoWriter_fourcc(*codec)
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    try:
        for frame in tqdm(frames):
            # Decide per frame (not from frames[0]) whether channels must be added.
            if frame.ndim == 2:
                frame = frame[..., None].repeat(3, axis=-1)
            if resize is not None:
                frame = cv2.resize(frame, (width, height))
            if rgb2bgr:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            out.write(frame)
    finally:
        # Always finalize the file, even if a frame fails to convert/write.
        out.release()
    

def remove_prefix(text, prefix):
    """Return `text` with a leading `prefix` stripped; unchanged when it is absent."""
    return text[len(prefix):] if text.startswith(prefix) else text


def unnormalize(x):
    """Rescale an array from the [-1, 1] range to [0, 255] and cast it to uint8."""
    rescaled = ((x + 1) / 2) * 255
    return rescaled.astype(np.uint8)


def make_ref_cond(img, keypts, hand_mask, device='cuda', target_size=(256, 256), latent_size=(32, 32)):
    """
    Build the conditioning tensors for one view: image, keypoint heatmaps, hand mask.

    Args:
    - img: Input image, passed through ToTensor -> Resize(target_size) -> Normalize(0.5, 0.5).
    - keypts: Keypoints given in `target_size` coordinates.
    - hand_mask: Array with an `astype` method; resized to `latent_size` with
      nearest-neighbour interpolation.
    - device: Device on which all returned tensors are placed.
    - target_size: Size the image is resized to.
    - latent_size: Size of the heatmaps and of the resized mask.

    Returns:
    - (image, heatmaps, mask): the image and heatmaps gain one leading axis;
      the mask gains two (unsqueeze followed by a leading axis).
    """
    to_normalized_tensor = Compose([
        ToTensor(),
        Resize(target_size),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    image = to_normalized_tensor(img).to(device)

    # Heatmaps of keypoints flagged invalid are zeroed by the validity factor.
    kpts_valid = check_keypoints_validity(keypts, target_size)
    latent_keypts = scale_keypoint(keypts, target_size, latent_size)
    heatmap_array = keypoint_heatmap(latent_keypts, latent_size, var=1.) * kpts_valid[:, None, None]
    heatmaps = torch.tensor(heatmap_array, dtype=torch.float, device=device)[None, ...]

    mask_array = cv2.resize(hand_mask.astype(int), dsize=latent_size, interpolation=cv2.INTER_NEAREST)
    mask = torch.tensor(mask_array, dtype=torch.float, device=device).unsqueeze(0)[None, ...]

    return image[None, ...], heatmaps, mask


def visualize_hand(ax, all_joints, img):
    """
    Draw two 21-joint hand skeletons as coloured line segments over an image.

    Args:
    - ax: Matplotlib axes to draw on.
    - all_joints: Indexable of joints; entries 0-20 and 21-41 are drawn as two
      separate hands, using index 0 of a joint as x and index 1 as y.
    - img: Image of shape H x W (grayscale) or H x W x C.
    """
    # Connections between joints for drawing lines and their corresponding colors
    connections = [
        ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
        ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
        ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'indigo'),
        ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'navy'), ((15, 16), 'gray'),
        ((0, 17), 'lavender'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
    ]
    # Only height and width are needed, so 2-D grayscale images work as well.
    H, W = img.shape[:2]

    # Show the image, then overlay one line per skeleton connection.
    ax.imshow(img)
    for start_i in (0, 21):
        joints = all_joints[start_i: start_i + 21]
        for (first, second), color in connections:
            joint1 = joints[first]
            joint2 = joints[second]
            ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)

    # Clamp the view to the image extent; the y axis is flipped afterwards so
    # that row 0 ends up at the top.
    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    ax.invert_yaxis()
    

@dataclass
class HandDiffOpts:
    """Bundle of default options for the hand diffusion model.

    In the code visible in this file only `image_size`, `latent_size`,
    `latent_dim`, `n_keypoints`, `n_mask`, `test_sampling_steps` and
    `latent_scaling_factor` are read; the remaining fields look like
    training-time settings (TODO confirm against the training code).
    """
    # --- run identification and filesystem locations (cluster-specific defaults) ---
    run_name: str = 'ViT_256_handmask_heatmap_nvs_b25_lr1e-5'
    sd_path: str = '/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt'
    log_dir: str = '/users/kchen157/scratch/log'
    data_root: str = '/users/kchen157/data/users/kchen157/dataset/handdiff'
    # --- image / latent geometry ---
    image_size: tuple = (256, 256)
    latent_size: tuple = (32, 32)
    latent_dim: int = 4
    mask_bg: bool = False
    # --- conditioning layout ---
    kpts_form: str = 'heatmap'
    n_keypoints: int = 42  # presumably 2 hands x 21 joints -- TODO confirm
    n_mask: int = 1
    # --- diffusion schedule and sampling ---
    noise_steps: int = 1000
    test_sampling_steps: int = 250  # passed (as a string) to create_diffusion in this file
    ddim_steps: int = 100
    ddim_discretize: str = "uniform"
    ddim_eta: float = 0.
    beta_start: float = 8.5e-4
    beta_end: float = 0.012
    latent_scaling_factor: float = 0.18215  # multiplies encoded latents, divides before decoding
    cfg_pose: float = 5.
    cfg_appearance: float = 3.5
    # --- optimisation / trainer settings (not used in the code visible here) ---
    batch_size: int = 25
    lr: float = 1e-5
    max_epochs: int = 500
    log_every_n_steps: int = 100
    limit_val_batches: int = 1
    n_gpu: int = 8
    num_nodes: int = 1
    precision: str = '16-mixed'
    profiler: str = 'simple'
    swa_epoch_start: int = 10
    swa_lrs: float = 1e-3
    num_workers: int = 10
    n_val_samples: int = 4
        

# Instantiate the default options and resolve the checkpoint locations
# relative to a local '../weights' directory.
opts = HandDiffOpts()
model_weights_dir = '../weights'
model_path = osp.join(model_weights_dir, 'DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt')
vae_path = osp.join(model_weights_dir, 'vae-ft-mse-840000-ema-pruned.ckpt')
sam_path = osp.join(model_weights_dir, 'sam_vit_h_4b8939.pth')

print('Load diffusion model...')
# The sampler is configured from the number of test-time sampling steps.
diffusion = create_diffusion(str(opts.test_sampling_steps))
# The network input stacks latent, keypoint and mask channels.
model = vit.DiT_XL_2(
    input_size=opts.latent_size[0],
    latent_dim=opts.latent_dim,
    in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
    learn_sigma=True,
).cuda()

# Use the EMA weights stored in the checkpoint (the non-EMA alternative is kept below for reference).
# ckpt_state_dict = torch.load(model_path)['model_state_dict']
ckpt_state_dict = torch.load(model_path)['ema_state_dict']
# strict=False tolerates checkpoint keys the model does not have; missing
# model keys are still treated as fatal by the assert below.
missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
model.eval()
print(missing_keys, extra_keys)
assert len(missing_keys) == 0


# Load the autoencoder weights; it is frozen (no gradients) and kept in eval mode.
vae_state_dict = torch.load(vae_path)['state_dict']
autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
autoencoder.eval()
assert len(missing_keys) == 0


# NOTE: the message is printed before the detector and SAM are actually created.
print('Mediapipe hand detector and SAM ready...')
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(
    static_image_mode=True,  # Use False if image is part of a video stream
    max_num_hands=2,         # Maximum number of hands to detect
    min_detection_confidence=0.1)
sam_predictor = init_sam(ckpt_path=sam_path)

Load diffusion model...
[] ['projector.0.weight', 'projector.0.bias', 'projector.2.weight', 'projector.2.bias', 'projector.4.weight', 'projector.4.bias']
Mediapipe hand detector and SAM ready...


I0000 00:00:1750622747.767905  417560 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1750622747.800162  719404 gl_context.cc:357] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 535.129.03), renderer: NVIDIA RTX A6000/PCIe/SSE2
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1750622747.858383  718959 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1750622747.881670  718971 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


# 3D functions

In [2]:
def plot_3d_skeleton(ax, keypoints):
    """
    Draw two 21-joint hand skeletons as 3D line segments.

    Rows 0-20 of `keypoints` are drawn in red, rows 21-41 in blue; columns
    0, 1 and 2 are used as the x, y and z coordinates.
    """
    # Joint connectivity shared by both hands (indices relative to each hand).
    bones = [
        (0, 1), (1, 2), (2, 3), (3, 4),
        (0, 5), (5, 6), (6, 7), (7, 8),
        (0, 9), (9, 10), (10, 11), (11, 12),
        (0, 13), (13, 14), (14, 15), (15, 16),
        (0, 17), (17, 18), (18, 19), (19, 20)
    ]

    # The second hand's joints are stored 21 rows after the first hand's.
    for offset, colour in ((0, 'red'), (21, 'blue')):
        for first, second in bones:
            pair = [first + offset, second + offset]
            ax.plot3D(keypoints[pair, 0], keypoints[pair, 1], keypoints[pair, 2], color=colour)
        

def area_of_convex_hull(points):
    try:
        return ConvexHull(points).volume
    except:
        return 0


def are_points_clustered_or_linear(points, reference_points, scale=0.5):
    """
    Tell whether `points` span much less hull area than `reference_points`.

    Returns True when the hull size of `points` is strictly below `scale`
    times the hull size of `reference_points`.
    """
    threshold = scale * area_of_convex_hull(reference_points)
    return area_of_convex_hull(points) < threshold


def generate_intrinsic_matrix(size, focal):
    """
    Build a 3x3 pinhole intrinsic matrix.

    The same `focal` is used on both axes and the principal point is placed
    at (size / 2, size / 2).
    """
    half = size / 2
    return np.array([
        [focal, 0, half],
        [0, focal, half],
        [0, 0, 1],
    ])


def original_camera_extrinsic():
    """Return the 4x4 identity matrix used as the original camera's extrinsic."""
    return np.identity(4)


def unproject_points(points_3d, intrinsic):
    """
    Lift (x, y, depth) points into 3D using the inverse intrinsic matrix.

    Args:
    - points_3d: (N, 3) array whose columns are image x, image y and depth.
    - intrinsic: 3x3 intrinsic matrix.

    Returns:
    - (N, 3) array: the back-projected rays scaled by depth and negated.
    """
    pixels = points_3d[:, :2]
    depth = points_3d[:, 2]

    # Append a column of ones to obtain homogeneous pixel coordinates.
    ones_column = np.ones((pixels.shape[0], 1))
    pixels_h = np.concatenate([pixels, ones_column], axis=1)

    # Rays through each pixel, expressed via the inverse intrinsics.
    rays = (np.linalg.inv(intrinsic) @ pixels_h.T).T  # (N, 3)

    # Scale every ray by its depth; note the sign flip applied to the result.
    return -rays * depth[:, np.newaxis]


def compute_centroid(points_3d):
    """Return the mean of `points_3d` taken along its first axis."""
    return np.asarray(points_3d).mean(axis=0)


def get_rotated_camera_position(angle, centroid, original_camera_pos):
    """
    Rotate a camera position about the z axis through `centroid`.

    Args:
    - angle: Rotation angle in radians.
    - centroid: Point the rotation axis passes through.
    - original_camera_pos: Position to rotate.

    Returns:
    - The rotated position.
    """
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    about_z = np.array([
        [cos_a, -sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1]
    ])
    # Rotate in the centroid-relative frame, then shift back.
    relative = original_camera_pos - centroid
    return np.dot(about_z, relative.T).T + centroid


def extrinsics_from_lookat(camera_pos, target_pos, up=np.array([0, 1, 0])):
    """
    Build a 4x4 look-at matrix for a camera at `camera_pos` facing `target_pos`.

    The rotation rows are the camera's right, up and forward axes, with
    forward pointing from the camera to the target; the translation is chosen
    so that `camera_pos` maps to the origin.

    Args:
    - camera_pos: Camera position (array of 3 values).
    - target_pos: Point the camera looks at (array of 3 values).
    - up: Global up direction used to derive the camera's right axis.
    """
    forward = target_pos - camera_pos
    forward = forward / np.linalg.norm(forward)

    # "Right" is perpendicular to the global up vector and the viewing direction.
    right = np.cross(up, forward)
    right = right / np.linalg.norm(right)

    # Re-derive an up axis orthogonal to the other two.
    camera_up = np.cross(forward, right)

    rotation = np.vstack((right, camera_up, forward))
    extrinsic = np.eye(4)
    extrinsic[:3, :3] = rotation
    extrinsic[:3, 3] = -rotation @ camera_pos
    return extrinsic


def project_points(points, K, E):
    """
    Project 3D points to pixel coordinates with intrinsics `K` and extrinsics `E`.

    Args:
    - points: One point (1-D) or an (N, 3) / (N, 4) array; 3-column input is
      extended with a homogeneous 1, anything else is used as given.
    - K: 3x3 intrinsic matrix.
    - E: Extrinsic matrix; only its first three rows are used.

    Returns:
    - (N, 2) array of pixel coordinates. Points whose camera-space depth is
      exactly 0 are divided by 1 instead; negative depths are not filtered.
    """
    pts = np.asarray(points)
    if pts.ndim == 1:
        # Promote a single point to a batch of one.
        pts = pts[np.newaxis, :]

    if pts.shape[1] == 3:
        pts_h = np.hstack([pts, np.ones((pts.shape[0], 1))])
    else:
        pts_h = pts

    # Camera-space coordinates, laid out as a 3xN matrix.
    cam_coords = E[:3] @ pts_h.T
    depth = cam_coords[2]

    # Guard only against division by zero (depth == 0), nothing else.
    nonzero = depth != 0
    safe_depth = np.where(nonzero, depth, 1.0)

    # Perspective divide after applying the intrinsics; keep x and y only.
    image_coords = (K @ cam_coords) / safe_depth
    return image_coords[:2].T


def compute_azimuthal_angle(point, centroid):
    """Return atan2(dz, dx) of `point` relative to `centroid` (angle in the x-z plane)."""
    offset = point - centroid
    dx, dz = offset[0], offset[2]
    return np.arctan2(dz, dx)


def sample_upright_camera_positions_around_globe(centroid, distance, num_samples=5, height_factor=0.5):
    """
    Randomly sample camera positions on a sphere around `centroid`.

    For each sample a yaw in [0, 2*pi) and a height (y value) within
    `height_factor * distance` of the centroid's height are drawn; x and z are
    then chosen so that the position lies `distance` away from the centroid.

    Returns:
    - List of `num_samples` arrays of shape (3,), ordered as [x, y, z].
    """
    lowest = centroid[1] - height_factor * distance
    highest = centroid[1] + height_factor * distance

    positions = []
    for _ in range(num_samples):
        # Draw order matters for reproducibility: yaw first, then height.
        yaw = np.random.uniform(0, 2*np.pi)
        height = np.random.uniform(lowest, highest)

        # Radius of the horizontal circle cut out of the sphere at this height.
        ring_radius = np.sqrt(distance**2 - (height - centroid[1])**2)

        positions.append(np.array([
            ring_radius * np.cos(yaw) + centroid[0],
            height,
            ring_radius * np.sin(yaw) + centroid[2],
        ]))

    return positions


def get_direction_vector(camera_pos, centroid):
    """Return the (unnormalized) vector pointing from `centroid` to `camera_pos`."""
    direction = camera_pos - centroid
    return direction

def rotation_matrix(axis, theta):
    """
    Build the 3x3 matrix for a counterclockwise rotation of `theta` radians
    about `axis`.

    The axis is normalized internally, so it does not need to be a unit vector.
    """
    direction = np.asarray(axis)
    direction = direction / np.sqrt(np.dot(direction, direction))

    # Four parameters derived from the half angle and the unit axis.
    half_angle = theta / 2.0
    a = np.cos(half_angle)
    b, c, d = -direction * np.sin(half_angle)

    # Squares and pairwise products reused by the matrix entries below.
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    ab, ac, ad = a * b, a * c, a * d
    bc, bd, cd = b * c, b * d, c * d

    row_x = [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)]
    row_y = [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)]
    row_z = [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]
    return np.array([row_x, row_y, row_z])


def uniform_sphere_sampling(num_samples):
    """
    Draw `num_samples` random angle pairs.

    Returns:
    - theta: Azimuthal angles in [0, 2*pi).
    - phi: Polar angles shifted into [-pi/2, pi/2], derived from arccos of a
      uniform value in [-1, 1).
    """
    # Draw order is fixed: azimuth samples first, then polar samples.
    azimuth_draw = np.random.uniform(0, 1, num_samples)
    polar_draw = np.random.uniform(0, 1, num_samples)

    theta = 2 * np.pi * azimuth_draw
    phi = np.arccos(2 * polar_draw - 1) - np.pi / 2
    return theta, phi

def scale_keypoints(keypoints_3d):
    """Rescale keypoints so that the largest absolute coordinate becomes 1."""
    largest_magnitude = np.max(np.abs(keypoints_3d))
    return keypoints_3d * (1.0 / largest_magnitude)

def sample_sphere_vectors(N, radius):
    """
    Sample `N` random vectors of length `radius`.

    The azimuth is uniform in [0, 2*pi) and the cosine of the polar angle is
    uniform in [-1, 1).

    Returns:
    - (N, 3) array of [x, y, z] vectors.
    """
    # Draw order is fixed: azimuths first, then polar cosines.
    azimuth = np.random.uniform(0, 2*np.pi, N)
    cos_polar = np.random.uniform(-1, 1, N)
    polar = np.arccos(cos_polar)

    # Spherical to Cartesian on the unit sphere, then scale to the radius.
    unit_vectors = np.stack(
        [
            np.sin(polar) * np.cos(azimuth),
            np.sin(polar) * np.sin(azimuth),
            cos_polar,
        ],
        axis=-1,
    )
    return unit_vectors * radius

def circular_camera_poses(radius, zenith_range, azimuth_range, num_poses=120):
    """
    Generate camera positions along a path on a sphere of the given radius.

    Zenith and azimuth (both in degrees) are interpolated linearly and in
    lockstep between the ends of their ranges; the zenith is measured from
    the +y axis.

    Returns:
    - List of `num_poses` arrays of shape (3,), ordered as [x, y, z].
    """
    zeniths = np.linspace(zenith_range[0], zenith_range[1], num_poses)
    azimuths = np.linspace(azimuth_range[0], azimuth_range[1], num_poses)

    def _to_cartesian(zenith_deg, azimuth_deg):
        # Spherical (degrees) -> Cartesian, with y as the polar axis.
        x = radius * np.sin(np.radians(zenith_deg)) * np.cos(np.radians(azimuth_deg))
        z = radius * np.sin(np.radians(zenith_deg)) * np.sin(np.radians(azimuth_deg))
        y = radius * np.cos(np.radians(zenith_deg))
        return np.array([x, y, z])

    return [_to_cartesian(zen, azi) for zen, azi in zip(zeniths, azimuths)]

# Load reference views

In [None]:
import os
import pickle

# NOTE: the original cell contained a dangling `root =` statement (a syntax
# error). The name was never read, so the statement has been dropped.
seq = '7-0032_pinkytip'
data_root = './test_data/nvs_interhand'

# Data files are named '<frame_id>-<cam_id>.<ext>'; collect the cameras
# available for every frame.
frame_cam_map = {}
for file in os.listdir(data_root):
    pair = file[:-4].split('-')
    if len(pair) < 2:
        # Skip stray files that do not follow the naming scheme.
        continue
    frame_cam_map.setdefault(pair[0], set()).add(pair[1])

frame_id = '11929'
ref_cam_id = '400002'

if ref_cam_id not in frame_cam_map.get(frame_id, ()):
    raise FileNotFoundError(
        f'No data for reference camera {ref_cam_id} of frame {frame_id} in {data_root}')

# Randomly pick the test cameras, but always keep the reference camera:
# plain shuffle-and-truncate could drop it, which would leave `ref_image`,
# `src_ref_cond` and `test_cam_params[ref_cam_id]` undefined for later cells.
N_test = 21
test_cams = [cam for cam in frame_cam_map[frame_id] if cam != ref_cam_id]
random.shuffle(test_cams)
test_cams = [ref_cam_id] + test_cams[:N_test - 1]

test_cam_target_conds = []
test_cam_keypts = []
test_cam_gt = []
test_cam_ids = []
test_cam_params = {}
for cam_id in test_cams:
    # This camera is explicitly excluded.
    if cam_id == '400067':
        continue

    img = io.imread(osp.join(data_root, f'{frame_id}-{cam_id}.jpg'))
    original_img_size = tuple(img.shape[:2])
    img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)

    # Per-view annotations: hand mask, 2D keypoints and camera calibration.
    anns = np.load(osp.join(data_root, f'{frame_id}-{cam_id}.npz'))
    hand_mask, keypts = anns['hand_mask'], anns['kpts']
    hand_mask = np.array(hand_mask).astype(int)
    keypts = scale_keypoint(keypts, original_img_size, opts.image_size)

    # White out everything outside the hand mask (used for visualization only).
    masked_img = img * hand_mask[..., None] + 255*(1 - hand_mask[..., None])

    # NOTE: `mask` from the last iteration is reused by the camera-trajectory
    # cell (as a shape template for torch.zeros_like).
    image, heatmaps, mask = make_ref_cond(
        img, keypts, hand_mask, device='cuda', target_size=opts.image_size, latent_size=opts.latent_size)
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()

    # Extrinsics from camera rotation and position: t = -R @ campos.
    E = np.eye(4)
    E[:3, :3] = anns['camrot']
    E[:3, -1] = -anns['camrot'] @ anns['campos']
    # Intrinsics from focal lengths and principal point.
    K = np.eye(3)
    K[0, 0] = anns['focal'][0]
    K[1, 1] = anns['focal'][1]
    K[0, -1] = anns['princpt'][0]
    K[1, -1] = anns['princpt'][1]

    test_cam_params[cam_id] = {
        'joint3d': anns['joint3d'],
        'joint_valid': anns['joint_valid'][..., None],
        'K': K,
        'E': E,
    }
    if cam_id == ref_cam_id:
        # Reference view: keep the image/keypoints and build the source condition.
        ref_image = img
        ref_keypts = keypts
        fig, axs = plt.subplots(1, 2, figsize=(3*2, 3))
        visualize_hand(axs[0], keypts, img)
        visualize_hand(axs[1], keypts, masked_img)
        plt.tight_layout()
        plt.show()
        src_ref_cond = torch.cat([latent, heatmaps, mask], 1)
    else:
        # Target view: keep ground truth and a condition with a zeroed mask channel.
        test_cam_ids.append(cam_id)
        test_cam_keypts.append(keypts)
        test_cam_gt.append(img)
        test_cam_target_conds.append(torch.cat([heatmaps, torch.zeros_like(mask)], 1))


save_dir = f'./test_data/nvs_test/{seq}'
os.makedirs(save_dir, exist_ok=True)

io.imsave(osp.join(save_dir, f'ref_{ref_cam_id}.jpg'), ref_image)
with open(osp.join(save_dir, 'joint3d_cam_params.pkl'), 'wb') as f:
    pickle.dump(test_cam_params, f)


# Sample Test View

In [None]:
# Sample novel views for the first three held-out test cameras and save the
# ground-truth / generated image pairs.
cfg_scale = 3.0
z = torch.randn((1, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device='cuda')
# novel view synthesis mode = on
nvs = torch.ones(1, dtype=torch.int, device='cuda')
# Batch of two built from the same noise; the second half is paired with
# zeroed conditions below (presumably the unconditional branch used by
# forward_with_cfg -- TODO confirm against the model code).
z = torch.cat([z, z], 0)

# Generated views are accumulated here and passed as extra references to
# later sampling calls (also reused by the NVS cell).
ref_conds = []
for k, target_cond in enumerate(test_cam_target_conds[:3]):
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]), 
        ref_cond=torch.cat([src_ref_cond, torch.zeros_like(src_ref_cond)]), 
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]), 
        cfg_scale=cfg_scale)
    
    # Only the first half of the sampled batch is kept.
    samples, _ = diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False,
        model_kwargs=model_kwargs, ref_conds=[src_ref_cond] + ref_conds, progress=True, device='cuda'
    ).chunk(2)
    ref_conds.append(torch.cat([samples, target_cond], 1))
    
    # Undo the latent scaling, decode, and convert to uint8 HWC images.
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor) 
    sampled_images = torch.clamp(sampled_images, min=-1., max=1.)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
    sampled_image = sampled_images[0]

    #visualize
    # Panels: reference, ground truth, sample, sample with target keypoints.
    # NOTE: all three image panels are titled with the target camera id.
    fig, axs = plt.subplots(1, 4, figsize=(6*4, 6))
    for i, vis_img in enumerate([ref_image, test_cam_gt[k], sampled_image]):
        axs[i].imshow(vis_img)
        axs[i].axis('off')
        axs[i].grid(False)
        axs[i].set_title(test_cam_ids[k])
    visualize_hand(axs[3], test_cam_keypts[k], sampled_image)
    plt.tight_layout()
    plt.show()
    io.imsave(osp.join(save_dir, f'gt_{test_cam_ids[k]}.jpg'), test_cam_gt[k])
    io.imsave(osp.join(save_dir, f'sampled_{test_cam_ids[k]}.jpg'), sampled_image)
    

# Sample Camera Trajectory

In [None]:
# Build a camera fly-through around the hand: express the 3D joints in the
# reference camera frame, re-centre them on joint 9, generate look-at cameras
# on an arc, and project the joints into every camera to get target keypoints.
E_ref = test_cam_params[ref_cam_id]['E']
# Apply the reference extrinsics to the 42 homogeneous joints, keep xyz.
hand_pose3d = (E_ref @ np.concatenate([test_cam_params[ref_cam_id]['joint3d'], np.ones((42, 1))], -1).transpose(1, 0)).transpose(1, 0)[:, :3]
# Joint 9 becomes the origin of the working frame.
centroid = hand_pose3d[9].copy()
hand_pose3d -= centroid
# In the re-centred frame the reference camera sits at -centroid with identity rotation.
E_ref_c2w = np.eye(4) 
E_ref_c2w[:3, -1] -= centroid
E_ref_w2c = np.linalg.inv(E_ref_c2w)


# Orbit at the reference camera's distance from the new origin; angles are in degrees.
radius = (centroid ** 2).sum() ** 0.5
zenith_range = (80, 100)
azimuth_range = (-140, -40)
num_poses = 45
camera_positions = circular_camera_poses(radius, zenith_range, azimuth_range, num_poses=num_poses)
# Every camera looks at the origin (joint 9).
cam_flythrough = np.array([extrinsics_from_lookat(camera_position, np.zeros(3)) for camera_position in camera_positions])
np.savez(osp.join(save_dir, f'cam_motion.npz'), data=cam_flythrough)

# Project the re-centred joints into each fly-through camera using the
# reference camera's intrinsics.
intrinsic = test_cam_params[ref_cam_id]['K']
keypts_flythrough = []
for i in range(num_poses):
    projected_kpts = project_points(hand_pose3d, intrinsic, cam_flythrough[i])
    keypts_flythrough.append(projected_kpts)


# Turn each set of projected keypoints into a target condition (heatmaps plus
# a zeroed mask channel). `mask` is the variable left over from the
# reference-view loading cell and is used only as a shape template.
target_conds = []
for keypts in keypts_flythrough:
    kpts_valid = check_keypoints_validity(keypts, opts.image_size)
    target_heatmaps = torch.tensor(keypoint_heatmap(
        scale_keypoint(keypts, opts.image_size, opts.latent_size), 
        opts.latent_size, var=1.) * kpts_valid[:, None, None], dtype=torch.float, device='cuda')[None, ...]
    target_cond = torch.cat([
        target_heatmaps, 
        torch.zeros_like(mask)], 1)
    target_conds.append(target_cond)
    

# Camera origins and viewing directions (camera +z axis) for the 3D plot.
cam_flythrough_c2w = np.linalg.inv(cam_flythrough)
dirs =cam_flythrough_c2w[:, :3, :3] @ np.array([0, 0, 1])
origins = cam_flythrough_c2w[:, :3, -1]

ax = plt.figure(figsize=(12, 8)).add_subplot(projection='3d')
plot_3d_skeleton(ax, hand_pose3d)

# Fly-through cameras in green.
_ = ax.quiver(
  origins[..., 0].flatten(),
  origins[..., 1].flatten(),
  origins[..., 2].flatten(),
  dirs[..., 0].flatten(),
  dirs[..., 1].flatten(), 
  dirs[..., 2].flatten(),
  length=0.1*radius, normalize=True, color='green')


# original camera
dirs = (E_ref_c2w[:3, :3] @ np.array([0, 0, 1]))[None]
origins = E_ref_c2w[None, :3, -1]

# Reference camera in black, drawn with a fixed arrow length of 100.
_ = ax.quiver(
  origins[..., 0].flatten(),
  origins[..., 1].flatten(),
  origins[..., 2].flatten(),
  dirs[..., 0].flatten(),
  dirs[..., 1].flatten(), 
  dirs[..., 2].flatten(),
  length=100, normalize=True, color='black')


ax.set_xlim(-radius, radius)
ax.set_ylim(-radius, radius)
ax.set_zlim(-radius, radius)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()

# Sanity check: re-project the joints with the reference camera over the reference image.
fig = plt.figure()
visualize_hand(fig.add_subplot(111), project_points(hand_pose3d, intrinsic, E_ref_w2c), ref_image)

# One figure per fly-through pose, overlaying the projected keypoints on the reference image.
for kpts in keypts_flythrough:
    fig = plt.figure()
    visualize_hand(fig.add_subplot(111), kpts, ref_image)



# Novel view synthesis (NVS)

In [None]:
# Generate one frame per fly-through pose and write the result as a video.
# Relies on `cfg_scale` and `ref_conds` defined in the test-view sampling cell.
# Number of most recent generated frames fed back as additional references.
last_N_frames = 1
z = torch.randn((1, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device='cuda')
# novel view synthesis mode = on
nvs = torch.ones(1, dtype=torch.int, device='cuda')
# Same noise twice; the second half is paired with zeroed conditions below.
z = torch.cat([z, z], 0)

video_frames = [] 
last_frames = []
for k, target_cond in enumerate(target_conds):
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]), 
        ref_cond=torch.cat([src_ref_cond, torch.zeros_like(src_ref_cond)]), 
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]), 
        cfg_scale=cfg_scale)
    
    # References: the source view, the earlier test-view samples, and the most recent frames.
    samples, _ = diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False,
        model_kwargs=model_kwargs, ref_conds=[src_ref_cond] + ref_conds + last_frames, progress=True, device='cuda'
    ).chunk(2)
    
    # Sliding window: drop the oldest frame once the window is full.
    if len(last_frames) >= last_N_frames:
        last_frames.pop(0)
    last_frames.append(torch.cat([samples, target_cond], 1) )

    
    # Undo the latent scaling, decode, and convert to a uint8 HWC image.
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor) 
    sampled_images = torch.clamp(sampled_images, min=-1., max=1.)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
    sampled_image = sampled_images[0]
    video_frames.append(sampled_image)

    #visualize
    # Panels: reference image, generated frame, generated frame with target keypoints.
    fig, axs = plt.subplots(1, 3, figsize=(6*3, 6))
    for i, vis_img in enumerate([ref_image, sampled_image]):
        axs[i].imshow(vis_img)
        axs[i].axis('off')
        axs[i].grid(False)
    visualize_hand(axs[2], keypts_flythrough[k], sampled_image)
    plt.tight_layout()
    plt.show()


# Frames are RGB, so they are converted to BGR while being written.
frames_to_video(video_frames, osp.join(save_dir, 'cam_motion.mp4'), fps=15, rgb2bgr=True)