In [1]:
cd ..

/home/ubuntu/Code/experiments-motion


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import os
sys.path.append('./Motion')

In [4]:
import torch
import numpy as np
import pandas as pd
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm
from IPython.display import display, HTML

In [5]:
from motion import Quaternion
from h36m.skeleton import H36MSkeleton
from h36m.dataset.h36m_torch_dataset import H36MTorchDataset
from h36m.dataset.h36m_dataset import H36MDataset
from h36m.dataset.h36m_test_dataset import H36MTestDataset

In [6]:
from metrics import MeanAngleL2Error, MeanPerJointPositionError, NegativeLogLikelihood

# Load H3.6M Skeleton

In [7]:
# Load skeleton configuration
with open('./config/h36m_skeleton.yaml', 'r') as stream:
    skeleton = H36MSkeleton(**yaml.safe_load(stream))

## Define Eval Metrics

In [8]:
mae_l2_metric = MeanAngleL2Error(ignore_root=True) # We ignore the root rotation in the world for the origin joint
mpjpe_metric = MeanPerJointPositionError()
nll_metric = NegativeLogLikelihood()

In [9]:
t = [0.08, 0.16, 0.32, 0.4, 0.56, 0.72, 0.88, 1.]

# Generative Evaluation DLow

In [10]:
DATA_PATH = './data/processed/h3.6m.npz'
h36m_dataset = H36MDataset(DATA_PATH, dataset_fps=50, dataset_downsample_factor=1)

In [11]:
cd ./compare/DLow

/home/ubuntu/Code/experiments-motion/compare/DLow


In [12]:
import pickle

In [13]:
from models.motion_pred import *
from motion_pred.utils.config import Config
from motion_pred.eval import *

In [14]:
from scipy.stats import gaussian_kde
def compute_kde_nll(y, y_pred):
    bs, sp, ts, ns, d = y_pred.shape
    kde_ll = torch.zeros((bs, ts, ns))

    for b in range(bs):
        for t in range(ts):
            for n in range(ns):
                try:
                    kde = gaussian_kde(y_pred[b, :, t, n].T)
                    pdf = kde.logpdf(y[b, t, n].T)
                    kde_ll[b, t, n] = torch.tensor(pdf)
                except np.linalg.LinAlgError:
                    print(b, t, n)
                    print('nan')
                    pass

    return -kde_ll

In [15]:
def skeleton_set(skeleton, s):
    jl = []
    for i, p in enumerate(skeleton._parents):
        if p > -1:
            jl.append((s[i] - s[p]) * 1000.)
    l = torch.tensor(np.linalg.norm(np.array(jl), axis=-1))
    skeleton._offsets[skeleton._offsets.abs() > 0.] = l[l.abs() > 0.] * torch.sign(skeleton._offsets[skeleton._offsets.abs() > 0.])

In [16]:
def get_prediction(data, algo, sample_num, num_seeds=1, concat_hist=True):
    traj_np = data[..., 1:, :].reshape(data.shape[0], data.shape[1], -1)
    traj = tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous()
    X = traj[:t_his]

    if algo == 'dlow':
        X = X.repeat((1, num_seeds, 1))
        Z_g = models[algo].sample(X)
        X = X.repeat_interleave(sample_num, dim=1)
        Y = models['vae'].decode(X, Z_g)
    elif algo == 'vae':
        X = X.repeat((1, sample_num * num_seeds, 1))
        Y = models[algo].sample_prior(X)

    if concat_hist:
        Y = torch.cat((X, Y), dim=0)
    Y = Y.permute(1, 0, 2).contiguous().cpu().numpy()
    if Y.shape[0] > 1:
        Y = Y.reshape(-1, num_seeds*sample_num, Y.shape[-2], Y.shape[-1])
    else:
        Y = Y[None, ...]
    return Y

In [17]:
algos = ['dlow', 'vae']
cfg = 'h36m_nsamp50'
traj_dim = 48
num_seeds = 1
device = 'cpu'
dtype = torch.float32

In [18]:
cfg = Config(cfg)

In [19]:
t_his = cfg.t_his
t_pred = cfg.t_pred

In [20]:
"""models"""
model_generator = {
    'vae': get_vae_model,
    'dlow': get_dlow_model,
}
models = {}
for algo in algos:
    models[algo] = model_generator[algo](cfg, traj_dim)
    model_path = getattr(cfg, f"{algo}_model_path") % getattr(cfg, 'num_%s_epoch' % algo)
    print(f'loading {algo} model from checkpoint: {model_path}')
    model_cp = pickle.load(open(model_path, "rb"))
    models[algo].load_state_dict(model_cp['model_dict'])
    models[algo].to(device)
    models[algo].eval()

loading dlow model from checkpoint: results/h36m_nsamp50/models/dlow_0500.p
loading vae model from checkpoint: results/h36m_nsamp50/models/vae_0500.p


In [21]:
removed_joints = {4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31}
kept_joints = np.array([x for x in range(32) if x not in removed_joints])

In [22]:
kept_joints

array([ 0,  1,  2,  3,  6,  7,  8, 12, 13, 14, 15, 17, 18, 19, 25, 26, 27])

In [23]:
dataset_eval_9 = H36MTorchDataset(h36m_dataset,
                                subjects=['S9'],
                                history_length=25,
                                prediction_horizon=100,
                                step=25)
data_loader_eval = DataLoader(dataset_eval_9, shuffle=False, batch_size=128)

In [24]:
d = np.load('./data/data_3d_h36m.npz', allow_pickle=True)
pos = d['positions_3d'].item()

In [25]:
skeleton_set(skeleton, pos['S9']['Phoning'][100])

In [26]:
nll_list = []

In [27]:
y_pos_pred_l = []
y_pos_l = []
cfg.nk = 50
with torch.no_grad():
    for x, y in tqdm(data_loader_eval):
        x_pos = skeleton(x.view(-1, 32, 4), ignore_root=False).view(-1, x.shape[1], 32, 3) / 1000.
        y_pos = skeleton(y.view(-1, 32, 4), ignore_root=False).view(-1, y.shape[1], 32, 3) / 1000.
        
        # Remove static nodes
        data = torch.cat([x_pos, y_pos], dim=1)[..., kept_joints, :].clone()
        
        # Account for different coordinate system during training
        data = data[..., [0, 2, 1]] * torch.tensor([[[1., -1., 1.]]])
        
        y_pred = []
        for i in range(20):
            y_pred_i = get_prediction(data, 'dlow', sample_num=cfg.nk, num_seeds=num_seeds, concat_hist=False)
            y_pred_i = torch.tensor(y_pred_i).view(y_pred_i.shape[:-1] + (-1, 3))
            y_pred.append(y_pred_i)
        y_pred = torch.cat(y_pred, dim=1)
        
        # Add static positions
        y_final = y_pred
        y_pos = y_pos[..., kept_joints[1:], :] # Origin is filtered in get_prediction
        
        nll_list.append(compute_kde_nll(data[:, 25:, 1:], y_final))
        
        y_pos_pred_l.append(y_final)
        y_pos_l.append(data[:, 25:, 1:])
    
y_pos_pred_9 = torch.cat(y_pos_pred_l, dim=0)
y_pos_9 = torch.cat(y_pos_l, dim=0)

  traj = tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous()
100%|██████████| 24/24 [00:39<00:00,  1.63s/it]


In [28]:
dataset_eval_11 = H36MTorchDataset(h36m_dataset,
                                subjects=['S11'],
                                history_length=25,
                                prediction_horizon=100,
                                skip_11_d=True,
                                step=25)
data_loader_eval = DataLoader(dataset_eval_11, shuffle=False, batch_size=128)
skeleton_set(skeleton, pos['S11']['Phoning 2'][100])

In [29]:
y_pos_pred_l = []
y_pos_l = []
cfg.nk = 50
with torch.no_grad():
    for x, y in tqdm(data_loader_eval):
        x_pos = skeleton(x.view(-1, 32, 4), ignore_root=False).view(-1, x.shape[1], 32, 3) / 1000.
        y_pos = skeleton(y.view(-1, 32, 4), ignore_root=False).view(-1, y.shape[1], 32, 3) / 1000.
        
        # Remove static nodes
        data = torch.cat([x_pos, y_pos], dim=1)[..., kept_joints, :].clone()
        
        data = data[..., [0, 2, 1]] * torch.tensor([[[1., -1., 1.]]])# + torch.tensor([[[0., 0., 1.]]])
        
        y_pred = []
        for i in range(20):
            y_pred_i = get_prediction(data, 'dlow', sample_num=cfg.nk, num_seeds=num_seeds, concat_hist=False)
            y_pred_i = torch.tensor(y_pred_i).view(y_pred_i.shape[:-1] + (-1, 3))
            y_pred.append(y_pred_i)
        y_pred = torch.cat(y_pred, dim=1)
        
        # Add static positions
        y_final = y_pred
        y_pos = y_pos[..., kept_joints[1:], :] # Origin is filtered in get_prediction
        
        nll_list.append(compute_kde_nll(data[:, 25:, 1:], y_final))
        
        y_pos_pred_l.append(y_final)
        y_pos_l.append(data[:, 25:, 1:])
    
y_pos_pred_11 = torch.cat(y_pos_pred_l, dim=0)
y_pos_11 = torch.cat(y_pos_l, dim=0)

  traj = tensor(traj_np, device=device, dtype=dtype).permute(1, 0, 2).contiguous()
100%|██████████| 17/17 [00:27<00:00,  1.60s/it]


In [30]:
y_pos_pred = torch.cat([y_pos_pred_9, y_pos_pred_11], dim=0)
y_pos = torch.cat([y_pos_9, y_pos_11], dim=0)

In [31]:
y_pos_pred.shape

torch.Size([5168, 50, 100, 16, 3])

In [32]:
apd = 0.
for i in range(y_pos.shape[0]):
    apd += compute_diversity(y_pos_pred[i, :50].flatten(start_dim=-2))

In [33]:
apd / y_pos.shape[0]

11.60039364358613

In [34]:
ade = 0.
for i in range(y_pos.shape[0]):
    ade += compute_ade(y_pos_pred[i, :50].flatten(start_dim=-2), y_pos[i].flatten(start_dim=-2))

In [35]:
ade / y_pos.shape[0]

0.41767544380386273

In [34]:
fde = 0.
for i in range(y_pos.shape[0]):
    fde += compute_fde(y_pos_pred[i, :50].flatten(start_dim=-2), y_pos[i].flatten(start_dim=-2))

In [35]:
fde / y_pos.shape[0]

0.5141418853964967

In [36]:
from scipy.spatial.distance import pdist, squareform
def get_multimodal_gt(dataset):
    all_data = []
    for x, y in dataset:
        data = torch.cat([x, y], dim=0)
        all_data.append(data)
    all_data = torch.stack(all_data, axis=0)
    all_data = skeleton(all_data.view(-1, 32, 4)).view(-1, all_data.shape[1], 32, 3) / 1000.
    all_data = all_data[..., [0, 2, 1]] * torch.tensor([[[1., -1., 1.]]])
    all_data = all_data[..., kept_joints, :]
    all_data = all_data[..., 1:, :]
    all_start_pose = all_data[:, t_his - 1, :].flatten(start_dim=-2)
    pd = squareform(pdist(all_start_pose))
    traj_gt_arr = []
    for i in range(pd.shape[0]):
        ind = np.nonzero(pd[i] < 0.5)
        traj_gt_arr.append(all_data[ind][:, t_his:, :])
    return traj_gt_arr

In [37]:
skeleton_set(skeleton, pos['S9']['Phoning'][100])
m_gt_9 = get_multimodal_gt(dataset_eval_9)
skeleton_set(skeleton, pos['S11']['Phoning 2'][100])
m_gt_11 = get_multimodal_gt(dataset_eval_11)
m_gt = m_gt_9 + m_gt_11

In [38]:
mmade = 0.
for i in tqdm(range(y_pos.shape[0])):
    mmade += compute_mmade(y_pos_pred[i, :50].flatten(start_dim=-2), None, m_gt[i].flatten(start_dim=-2))

100%|██████████| 5168/5168 [06:12<00:00, 13.89it/s] 


In [39]:
mmade / y_pos.shape[0]

0.48554390516628054

In [40]:
mmfde = 0.
for i in tqdm(range(y_pos.shape[0])):
    mmfde += compute_mmfde(y_pos_pred[i, :50].flatten(start_dim=-2), None, m_gt[i].flatten(start_dim=-2))

100%|██████████| 5168/5168 [05:51<00:00, 14.69it/s] 


In [41]:
mmfde / y_pos.shape[0]

0.5261071764924578

In [42]:
apd = 0.
for i in range(y_pos.shape[0]):
    apd += compute_diversity(y_pos_pred[i, :50, :50].flatten(start_dim=-2))

In [43]:
apd / y_pos.shape[0]

5.180428011832206

In [44]:
ade = 0.
for i in range(y_pos.shape[0]):
    ade += compute_ade(y_pos_pred[i, :50, :50].flatten(start_dim=-2), y_pos[i, :50].flatten(start_dim=-2))

In [45]:
ade / y_pos.shape[0]

0.305018526300119

In [46]:
fde = 0.
for i in range(y_pos.shape[0]):
    fde += compute_fde(y_pos_pred[i, :50, :50].flatten(start_dim=-2), y_pos[i, :50].flatten(start_dim=-2))

In [47]:
fde / y_pos.shape[0]

0.4189206726572277

In [48]:
mmade = 0.
for i in tqdm(range(y_pos.shape[0])):
    mmade += compute_mmade(y_pos_pred[i, :50, :50].flatten(start_dim=-2), None, m_gt[i][:, :50].flatten(start_dim=-2))

100%|██████████| 5168/5168 [03:07<00:00, 27.52it/s] 


In [49]:
mmade / y_pos.shape[0]

0.41667923668324486

In [50]:
mmfde = 0.
for i in tqdm(range(y_pos.shape[0])):
    mmfde += compute_mmfde(y_pos_pred[i, :50, :50].flatten(start_dim=-2), None, m_gt[i][:, :50].flatten(start_dim=-2))

100%|██████████| 5168/5168 [03:02<00:00, 28.24it/s] 


In [51]:
mmfde / y_pos.shape[0]

0.453848048191137

In [36]:
nll = torch.cat(nll_list, dim=0)

In [38]:
nll.clip(max=20).sum(-1).mean(0)

tensor([115.6573,  53.8431, -12.9685, -51.4870, -69.4903, -78.1288, -82.5270,
        -84.7968, -85.8376, -86.4732, -86.9339, -87.1482, -87.2519, -87.2515,
        -87.2234, -87.0629, -86.8170, -86.5770, -86.1807, -85.7386, -85.3267,
        -84.9387, -84.5291, -84.0931, -83.6289, -83.1264, -82.6000, -82.0796,
        -81.4962, -80.8790, -80.2513, -79.6091, -79.0010, -78.3597, -77.7429,
        -77.1220, -76.4933, -75.9142, -75.3304, -74.7518, -74.1841, -73.6319,
        -73.0862, -72.5533, -72.0395, -71.5494, -71.0771, -70.6249, -70.1895,
        -69.7316, -69.2791, -68.8391, -68.3947, -67.9507, -67.5247, -67.1117,
        -66.6923, -66.2854, -65.8880, -65.5115, -65.1341, -64.7701, -64.4187,
        -64.0621, -63.7194, -63.3797, -63.0506, -62.7341, -62.4353, -62.1468,
        -61.8750, -61.6272, -61.3977, -61.1749, -60.9442, -60.7227, -60.5096,
        -60.3074, -60.1050, -59.9169, -59.7381, -59.5634, -59.4109, -59.2708,
        -59.1418, -59.0219, -58.9086, -58.8024, -58.6978, -58.60

In [39]:
nll.sum(-1).mean(0)

tensor([ 3.9678e+03,  1.8360e+03,  7.4814e+02,  3.4263e+02,  1.6973e+02,
         8.4026e+01,  3.4978e+01,  2.8629e+00, -1.8358e+01, -3.3132e+01,
        -4.4400e+01, -5.2544e+01, -5.8882e+01, -6.3797e+01, -6.7706e+01,
        -7.0686e+01, -7.3053e+01, -7.4968e+01, -7.6394e+01, -7.7393e+01,
        -7.8195e+01, -7.8833e+01, -7.9099e+01, -7.9325e+01, -7.9435e+01,
        -7.9377e+01, -7.9292e+01, -7.9068e+01, -7.8844e+01, -7.8516e+01,
        -7.8109e+01, -7.7710e+01, -7.7343e+01, -7.6873e+01, -7.6407e+01,
        -7.5954e+01, -7.5395e+01, -7.4909e+01, -7.4412e+01, -7.3916e+01,
        -7.3413e+01, -7.2934e+01, -7.2450e+01, -7.1975e+01, -7.1517e+01,
        -7.1058e+01, -7.0632e+01, -7.0210e+01, -6.9807e+01, -6.9374e+01,
        -6.8953e+01, -6.8557e+01, -6.8147e+01, -6.7707e+01, -6.7294e+01,
        -6.6893e+01, -6.6487e+01, -6.6094e+01, -6.5707e+01, -6.5341e+01,
        -6.4977e+01, -6.4620e+01, -6.4275e+01, -6.3927e+01, -6.3591e+01,
        -6.3258e+01, -6.2936e+01, -6.2623e+01, -6.2

In [36]:
from scipy.spatial.distance import pdist, squareform
t_his=25
def get_multimodal_gt(dataset, threshold):
    all_data = []
    for x, y in dataset:
        data = torch.cat([x, y], dim=0)
        all_data.append(data)
    all_data = torch.stack(all_data, axis=0)
    all_data = skeleton(all_data.view(-1, 32, 4)).view(-1, all_data.shape[1], 32, 3) / 1000.
    all_data = all_data[..., [0, 2, 1]] * torch.tensor([[[1., -1., 1.]]])
    all_data = all_data[..., kept_joints, :]
    all_data = all_data[..., 1:, :]
    all_start_pose = all_data[:, t_his - 1, :].flatten(start_dim=-2)
    pd = squareform(pdist(all_start_pose))
    traj_gt_arr = []
    for i in range(pd.shape[0]):
        ind = np.nonzero(pd[i] < threshold)
        traj_gt_arr.append(all_data[ind][:, t_his:, :])
    return traj_gt_arr

In [37]:
mmade_list = []
mmfde_list = []
with torch.no_grad():
    for t in [0.1, 0.2, 0.3, 0.4, 0.5]:
        skeleton_set(skeleton, pos['S9']['Phoning'][100])
        m_gt_9 = get_multimodal_gt(dataset_eval_9, t)
        skeleton_set(skeleton, pos['S11']['Phoning 2'][100])
        m_gt_11 = get_multimodal_gt(dataset_eval_11, t)
        m_gt = m_gt_9 + m_gt_11

        mmade = 0.
        for i in tqdm(range(y_pos.shape[0])):
            mmade += compute_mmade(y_pos_pred[i, :50].flatten(start_dim=-2), None, m_gt[i].flatten(start_dim=-2))
        mmade_list.append(mmade / y_pos.shape[0])

        mmfde = 0.
        for i in tqdm(range(y_pos.shape[0])):
            mmfde += compute_mmfde(y_pos_pred[i, :50].flatten(start_dim=-2), None, m_gt[i].flatten(start_dim=-2))
        mmfde_list.append(mmfde / y_pos.shape[0])

100%|██████████| 5168/5168 [00:04<00:00, 1169.19it/s]
100%|██████████| 5168/5168 [00:05<00:00, 991.56it/s] 
100%|██████████| 5168/5168 [00:08<00:00, 582.30it/s]
100%|██████████| 5168/5168 [00:07<00:00, 653.46it/s]
100%|██████████| 5168/5168 [00:26<00:00, 194.52it/s]
100%|██████████| 5168/5168 [00:26<00:00, 197.13it/s]
100%|██████████| 5168/5168 [01:41<00:00, 50.88it/s] 
100%|██████████| 5168/5168 [01:32<00:00, 55.83it/s] 
100%|██████████| 5168/5168 [04:59<00:00, 17.23it/s] 
100%|██████████| 5168/5168 [04:46<00:00, 18.04it/s] 


In [38]:
print(mmade_list)
print(mmfde_list)

[0.41866363429582615, 0.42348777630507667, 0.43801822789278655, 0.459772969079968, 0.4856139307916626]
[0.5127856655870983, 0.5126198851950351, 0.5158396173649159, 0.5196737145065062, 0.5255972387344107]
