## Demo script for the forecasting model

### Follow the setup instructions in README.md to install dependencies and download the pretrained model.

In [None]:
import os
import sys
import argparse
import random
from easydict import EasyDict

# Set this path to your downloads directory which contains all the data & models
os.environ["DOWNLOADS_DIR"] = '../downloads_forehand4d' # "<path to downloads directory>"

import numpy as np
from glob import glob
from tqdm import tqdm
import torch
from loguru import logger

import src.factory as factory
from common.torch_utils import reset_all_seeds
import common.data_utils as data_utils
from common.args_utils import set_default_params, set_extra_params
from src.parsers.generic_parser import add_generic_args
import src.parsers.configs.mdm as config

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load and set all the required args.
# sys.argv is reset before parsing, presumably so argparse does not pick up the
# notebook kernel's own command-line flags -- confirm if run as a plain script.
sys.argv = ['']
parser = argparse.ArgumentParser()
parser = add_generic_args(parser)
# Parse exactly once; `unknown` collects any arguments the parser does not define.
parsed_args, unknown = parser.parse_known_args()
args = EasyDict(vars(parsed_args))

default_args = config.DEFAULT_ARGS_EGO
args = set_default_params(args, default_args) # only preserves keys from args
args = set_extra_params(args, default_args) # also preserves new keys from default_args which are not present in args

# Demo-specific overrides applied on top of the defaults above
args.exp_key = 'logs/mdm_demo'
args.experiment = None
# per-channel normalization statistics passed to torchvision's Normalize below
args.img_norm_mean = [0.485, 0.456, 0.406]
args.img_norm_std = [0.229, 0.224, 0.225]
args.num_workers = 0
args.batch_size = 1 # fix this for demo
args.focal_length = 1000.0
args.debug = False

# Augmentation is switched off for the demo (aug_data=False, flip_prob=0, rot_factor=0)
args.use_gt_k = True
args.aug_data = False
args.rot_factor = 0.0
args.scale_factor = 1.0
args.flip_prob = 0.0
args.noise_factor = 1.0
# augmentation parameters consumed later by data_utils.rgb_processing
augm_dict = data_utils.augm_params(
            args.aug_data,
            args.flip_prob,
            args.noise_factor,
            args.rot_factor,
            args.scale_factor,
        )
use_gt_k = args.use_gt_k

# Image normalizer applied to the processed image tensor before it is fed to the model
from torchvision.transforms import Normalize
normalize_img = Normalize(mean=args.img_norm_mean, std=args.img_norm_std)

In [None]:
# cleanup cuda memory
import gc

# The notebook falls back to CPU when no GPU is present, so only touch the
# CUDA allocator when a CUDA device is actually available.
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.empty_cache()
    # print the report explicitly: a bare expression in the middle of a cell is not displayed
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
# collect unreachable Python objects, then release the cached blocks they were holding
gc.collect()
if cuda_available:
    torch.cuda.empty_cache()

In [None]:
# Load the pretrained model
args.method = 'mdm_light'
args.load_ckpt = f'{os.environ["DOWNLOADS_DIR"]}/model/forehand4d/model_v2.ckpt'

# A new random seed is drawn on every run of this cell
args.seed = random.randint(0, 100000)
reset_all_seeds(args.seed)
torch.set_num_threads(args.num_threads)

device = "cuda" if torch.cuda.is_available() else "cpu"
wrapper = factory.fetch_model(args).to(device)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load(args.load_ckpt, map_location=device)
wrapper.load_state_dict(ckpt["state_dict"])
logger.info(f"Loaded weights from {args.load_ckpt}")
# switch the model to evaluation mode for inference
wrapper = wrapper.eval()

In [None]:
# Load a sample image with hand-object scenario
img_dir = './assets/examples'
img_files = glob(os.path.join(img_dir, '*.jpg'))
if not img_files:
    # fail with a clear message instead of random.choice's IndexError on an empty list
    raise FileNotFoundError(f"No .jpg example images found in {img_dir!r}")

# pick one of the example images at random
imgname = random.choice(img_files)

# load image
# NOTE(review): img_status is not checked here -- confirm what read_img reports on failure,
# and what the (2800, 2000, 3) shape argument is used for.
cv_img, img_status = data_utils.read_img(imgname, (2800, 2000, 3))

# process image into the required format
img_w, img_h = cv_img.shape[1], cv_img.shape[0]
# crop is centered on the image; scale is the longer side divided by 200
# (presumably the box-size convention rgb_processing expects -- confirm)
center = [img_w // 2, img_h // 2]
scale = max(img_w, img_h) / 200.0
rgb_img = data_utils.rgb_processing(
            args.aug_data,
            cv_img,
            center,
            scale,
            augm_dict,
            img_res=args.img_res,
        )
print ('Original image shape:', cv_img.shape)
print ('Processed image shape:', rgb_img.shape)
# plot cv_img and rgb_img side by side
plt.close('all')
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(cv_img.astype(np.uint8))
ax[0].axis('off')
ax[0].set_title('Original Image')
# rgb_img is channel-first, so scale it by 255 and move channels last for display
ax[1].imshow((255*rgb_img).transpose(1, 2, 0).astype(np.uint8))
ax[1].axis('off')
ax[1].set_title('Processed Image')
plt.show()

In [None]:
# Assemble the three dictionaries (inputs, targets, meta_info) passed to the model
torch_img = torch.from_numpy(rgb_img).float()
norm_img = normalize_img(torch_img)
inputs = {"img": norm_img.unsqueeze(0)} # history size is 1
targets = {}
meta_info = {"imgname": imgname}

# define ground truth camera intrinsics, this is for the 224x224 image, not the original image
# this value is for ARCTIC
intrx = np.array(
    [
        [192.79396, 0.0, 112.32694],
        [0.0, 192.7464, 103.70519],
        [0.0, 0.0, 1.0],
    ],
    dtype=np.float32,
)
meta_info["intrinsics"] = torch.FloatTensor(intrx).unsqueeze(0)

# these are required so that code doesn't break
is_valid, left_valid, right_valid = 1, 1, 1
num_joints = 21
targets["is_valid"] = torch.tensor([float(is_valid)])
# a hand only counts as valid when the whole sample is valid
targets["left_valid"] = torch.tensor([float(left_valid) * float(is_valid)])
targets["right_valid"] = torch.tensor([float(right_valid) * float(is_valid)])
# per-joint validity: the hand-level flag broadcast over all joints, with a leading axis of size 1
targets["joints_valid_r"] = (targets["right_valid"] * torch.ones(num_joints)).unsqueeze(0)
targets["joints_valid_l"] = (targets["left_valid"] * torch.ones(num_joints)).unsqueeze(0)

vis_timesteps = 60 # number of timesteps to visualize
# boolean mask over the motion length: True for the first vis_timesteps entries
meta_info['mask_timesteps'] = torch.arange(args.max_motion_length) < vis_timesteps
meta_info['lengths'] = torch.tensor([vis_timesteps], dtype=torch.int32)
# every joint at every future timestep is marked valid
future_valid_shape = (args.max_motion_length, num_joints)
targets['future_valid_r'] = torch.ones(future_valid_shape)
targets['future_valid_l'] = torch.ones(future_valid_shape)

In [None]:
# The model does not predict betas, but they are needed to build the MANO mesh.
# These are mean beta values computed from the val set of arctic, used for datasets without MANO fits.
# MANO's default betas would work too, as long as the choice is consistent across training.
mean_beta_r = [0.82747316,  0.13775729, -0.39435294, 0.17889787, -0.73901576, 0.7788163, -0.5702684, 0.4947751, -0.24890041, 1.5943261]
mean_beta_l = [-0.19330633, -0.08867972, -2.5790455, -0.10344583, -0.71684015, -0.28285977, 0.55171007, -0.8403888, -0.8490544, -1.3397144]
# one identical row of betas per timestep: (1, 10) tiled to (max_motion_length, 10)
beta_row_r = torch.tensor([mean_beta_r])
beta_row_l = torch.tensor([mean_beta_l])
targets['future_betas_r'] = beta_row_r.repeat(args.max_motion_length, 1)
targets['future_betas_l'] = beta_row_l.repeat(args.max_motion_length, 1)

In [None]:
# Move to device and add batch dimension
def to_device(data, device):
    """Move every tensor value of ``data`` to ``device``.

    The dictionary is modified in place; values that are not
    ``torch.Tensor`` instances are left untouched.

    Args:
        data: dictionary whose tensor values should be moved.
        device: target passed straight to ``Tensor.to``.

    Returns:
        The same dictionary object that was passed in.
    """
    tensor_keys = [key for key, value in data.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        data[key] = data[key].to(device)
    return data

def unsqueeze_batch(data):
    """Prepend a batch dimension of size 1 to the values of ``data``.

    Tensor values get a new leading axis. Python ``float`` values are wrapped
    into a one-element tensor and moved to the module-level ``device``
    (this function reads that global; it is not a parameter). All other
    values are left as they are. The dictionary is modified in place.

    Args:
        data: dictionary to update.

    Returns:
        The same dictionary object that was passed in.
    """
    for key in list(data):
        value = data[key]
        if isinstance(value, torch.Tensor):
            data[key] = value.unsqueeze(0)
            continue
        if isinstance(value, float):
            data[key] = torch.tensor([value]).to(device)
    return data

# For each dictionary: move its tensors to the device, then unsqueeze the batch dimension
inputs, targets, meta_info = (
    unsqueeze_batch(to_device(group, device))
    for group in (inputs, targets, meta_info)
)

In [None]:
# Run inference
num_samples = 5 # number of samples to generate from diffusion model
# batch size is read from the leading dimension of the timestep mask
bz = meta_info['mask_timesteps'].size(0)
wrapper.max_vis_examples = bz
all_vis_dict = wrapper.inference(
    inputs,
    targets,
    meta_info,
    num_samples=num_samples,
)

In [None]:
# Visualize the generated motion
# The PyOpenGL platform is set before the visualization module is imported
# (presumably to select off-screen EGL rendering -- confirm for your machine).
os.environ['PYOPENGL_PLATFORM'] = 'egl'
from src.callbacks.vis.visualize_arctic import visualize_motion_viz
vis_fn = visualize_motion_viz
# one combined panel per (sample, batch element); panels are kept in memory only
all_vis_imgs = []
for vis_dict in tqdm(all_vis_dict):
    curr_im_list = vis_fn(
        vis_dict,
        wrapper.max_vis_examples,
        wrapper.renderer,
        postfix='',
        no_tqdm=True,
        only_hands=True,
    )

    for b in range(bz):
        img = curr_im_list[b]['im'] # (4*H, W, 3)
        inp = curr_im_list[b]['inp_img']

        # separate into 4 images and lay them out side by side
        img = np.hstack(np.split(img, 4, axis=0))
        # clip BEFORE casting: clipping an already-uint8 array is a no-op and
        # out-of-range values would have wrapped around during the cast
        img = np.clip(img, 0, 255).astype(np.uint8)
        inp = np.clip(inp * 255, 0, 255).astype(np.uint8)

        # concatenate inp and img
        combined = np.hstack((inp, img))
        all_vis_imgs.append(combined)

In [None]:
# lighter shades denote farther away timesteps
plt.close('all')
# squeeze=False always returns a 2-D array of Axes, so indexing also works
# when num_samples == 1 (otherwise subplots returns a single bare Axes)
fig, ax_grid = plt.subplots(num_samples, 1, figsize=(15, 15), squeeze=False)
ax = ax_grid[:, 0]
# one row per generated sample
for axis, panel in zip(ax, all_vis_imgs):
    axis.imshow(panel)
    axis.axis('off')
plt.show()