In [1]:
cd ..

/experiments_motion


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import os
import gc
import importlib.util
sys.path.append('./Motion')

In [4]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import yaml
from tqdm import tqdm
from IPython.display import display, HTML

In [5]:
from h36m.skeleton import H36MSkeleton
from h36m.dataset.h36m_torch_dataset import H36MTorchDataset
from h36m.dataset.h36m_dataset import H36MDataset
from h36m.dataset.h36m_test_dataset import H36MTestDataset
from helper.helper import add_static_nodes

In [6]:
from metrics import MeanAngleL2Error, MeanPerJointPositionError, NegativeLogLikelihood

# Load H3.6M Skeleton

In [7]:
# Load skeleton configuration
with open('./config/h36m_skeleton.yaml', 'r') as stream:
    skeleton = H36MSkeleton(**yaml.safe_load(stream))

## Define Eval Metrics

In [8]:
mae_l2_metric = MeanAngleL2Error(ignore_root=True) # We ignore the root rotation in the world for the origin joint
mpjpe_metric = MeanPerJointPositionError()
nll_metric = NegativeLogLikelihood()

In [9]:
t = [0.08, 0.16, 0.32, 0.4, 0.56, 0.72, 0.88, 1.]

In [10]:
from scipy.stats import gaussian_kde
def compute_kde_nll(y, y_pred):
    bs, sp, ts, ns, d = y_pred.shape
    kde_ll = torch.zeros((bs, ts, ns))

    for b in range(bs):
        for t in range(ts):
            for n in range(ns):
                try:
                    kde = gaussian_kde(y_pred[b, :, t, n].T)
                    pdf = kde.logpdf(y[b, t, n].T)
                    kde_ll[b, t, n] = torch.tensor(pdf)
                except np.linalg.LinAlgError:
                    print(b, t, n)
                    print('nan')
                    pass

    return -kde_ll

In [11]:
models = ['abl_1_modes', 
          'abl_2_modes', 
          'abl_3_modes', 
          'abl_4_modes', 
          'abl_5_modes', 
          'abl_6_modes', 
          'abl_gmm', 
          'abl_no_tg', 
          'abl_one_hot', 
          'abl_lg']

In [12]:
DATA_PATH = './data/processed/h3.6m.npz'

In [13]:
h36m_dataset = H36MDataset(DATA_PATH, dataset_fps=50, dataset_downsample_factor=2)

In [14]:
dataset_test = H36MTestDataset(h36m_dataset,
                                action='average',
                                subjects=['S5'],
                                num_samples=256,
                                history_length=50,
                                prediction_horizon=25)

In [15]:
data_loader = DataLoader(dataset_test, batch_size=64)

## Load Evaluation Data

In [16]:
ph = 25
n_samp = 1000
removed_joints = {0, 1, 6, 11}
kept_joints = np.array([x for x in range(32) if x not in removed_joints])

In [17]:
sys.path.remove('./Motion')

In [18]:
def delet_motion_import():
    keys = []
    for key in sys.modules.keys():   # iter on both keys and values
        if key.startswith('motion'):
            keys.append(key)
    for key in keys:
        sys.modules.pop(key, None)

In [19]:
result = pd.DataFrame(columns=pd.MultiIndex.from_product([['Mean Angle Error (L2)', 'NLL'], models]))
for model_str in models:
    model_path = f'./output/h36m/{model_str}/'
    load_adapted_code = False
    
    if os.path.isdir(os.path.join(model_path, 'Motion')):
        motion_path = os.path.join(model_path, 'Motion')
    else:
        motion_path = './Motion'
    delet_motion_import()
    spec = importlib.util.spec_from_file_location("motion", os.path.join(motion_path, 'motion', '__init__.py'))
    sys.path.append(motion_path)
    motion = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(motion)
    Motion = motion.Motion
    Quaternion = motion.Quaternion
    
    # Load model config
    with open(os.path.join(model_path, 'config.yaml'), 'r') as file:
        model_config = yaml.safe_load(file)
    checkpoint_path = os.path.join(model_path, 'model.pth.tar')
    model = Motion(skeleton, T=skeleton.nodes_type_id_dynamic, **model_config)
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    model.eval()
    
    mae_l2_metric.reset()
    nll_list = []
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            # Remove static nodes
            x_dynamic = Quaternion.qfix_positive_(x[:, :, skeleton.dynamic_nodes])

            # Run Model 
            p_q, _, _, _ = model(x_dynamic, None, ph=ph)
            
            q_w = p_q.weighted_mean

            # Add static rotations q = [1., 0., 0., 0.]
            q_w = add_static_nodes(q_w, skeleton.static_nodes, skeleton.dynamic_nodes)

            mae_l2_metric.update((q_w, y))

            q_samp = p_q.sample((n_samp,))
            # Add static rotations q = [1., 0., 0., 0.]
            q_samp_all = add_static_nodes(q_samp, skeleton.static_nodes, skeleton.dynamic_nodes)
            q_samp_all = q_samp_all.permute(1, 0, 2, 3, 4).contiguous()

            pos_sampl_all = skeleton(q_samp_all.view(-1, 32, 4), ignore_root=False).view(-1, n_samp, ph, 32, 3) / 1000.
            pos_y = skeleton(y.view(-1, 32, 4), ignore_root=False).view(-1, ph, 32, 3) / 1000.

            nll_list.append(compute_kde_nll(pos_y[..., kept_joints, :], pos_sampl_all[..., kept_joints, :]))
                                                          
        mae_l2 = mae_l2_metric.compute().numpy()
        nll = torch.cat(nll_list, dim=0).clip(max=20).sum(-1).mean(0).numpy()
        
        result['Mean Angle Error (L2)', model_str] = pd.Series(data=mae_l2, index=0.04*np.arange(1, mae_l2.shape[0] + 1))
        result['NLL', model_str] = pd.Series(data=nll, index=0.04*np.arange(1, mae_l2.shape[0] + 1))
        
        del nll
        del q_samp
        del q_samp_all
        del pos_sampl_all
    sys.path.remove(motion_path)
    del motion
    motion = None
    gc.collect()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [32:09<00:00, 32.16s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [36:46<00:00, 36.77s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [41:16<00:00, 41.28s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [45:25<00:00, 45.42s/it]
100%|███████████████████████████████████████

In [20]:
display(HTML(result['Mean Angle Error (L2)'].loc[t].to_html(float_format=lambda x: '%.2f' % x)))

Unnamed: 0,abl_1_modes,abl_2_modes,abl_3_modes,abl_4_modes,abl_5_modes,abl_6_modes,abl_gmm,abl_no_tg,abl_one_hot,abl_lg
0.08,0.29,0.28,0.29,0.29,0.28,0.28,0.29,0.3,0.29,0.29
0.16,0.53,0.52,0.52,0.52,0.51,0.51,0.53,0.54,0.53,0.52
0.32,0.9,0.88,0.88,0.88,0.87,0.88,0.92,0.93,0.9,0.88
0.4,1.05,1.02,1.02,1.01,1.01,1.02,1.06,1.07,1.04,1.02
0.56,1.26,1.23,1.23,1.22,1.22,1.23,1.28,1.28,1.25,1.23
0.72,1.44,1.4,1.41,1.4,1.4,1.41,1.46,1.47,1.43,1.41
0.88,1.59,1.55,1.56,1.54,1.54,1.55,1.6,1.63,1.57,1.55
1.0,1.68,1.63,1.65,1.63,1.63,1.64,1.69,1.73,1.67,1.64


In [21]:
display(HTML(result['NLL'].loc[t].to_html(float_format=lambda x: '%.2f' % x)))

Unnamed: 0,abl_1_modes,abl_2_modes,abl_3_modes,abl_4_modes,abl_5_modes,abl_6_modes,abl_gmm,abl_no_tg,abl_one_hot,abl_lg
0.08,-265.37,-275.68,-277.7,-280.17,-283.44,-274.9,-246.83,-281.06,-279.45,-281.2
0.16,-229.74,-241.36,-243.04,-244.2,-245.28,-241.41,-221.95,-240.14,-242.39,-242.75
0.32,-177.2,-189.85,-192.27,-192.83,-193.2,-191.03,-174.47,-186.56,-190.91,-190.69
0.4,-160.77,-173.21,-176.25,-176.74,-177.01,-174.5,-158.56,-170.19,-174.78,-174.68
0.56,-138.96,-150.98,-154.8,-155.07,-155.2,-152.2,-135.97,-148.28,-153.19,-152.95
0.72,-124.27,-135.33,-139.82,-139.76,-139.98,-136.18,-120.38,-132.33,-138.08,-137.62
0.88,-113.04,-123.53,-128.45,-128.31,-128.48,-124.71,-109.09,-120.04,-126.48,-126.41
1.0,-106.1,-117.2,-121.86,-122.03,-122.02,-117.7,-102.24,-112.35,-119.33,-119.99


In [27]:
pd.options.display.float_format = "{:.2f}".format
display(result['NLL'].sum())

abl_1_modes   -4032.58
abl_2_modes   -4320.73
abl_3_modes   -4405.46
abl_4_modes   -4418.98
abl_5_modes   -4432.40
abl_6_modes   -4340.94
abl_gmm       -3879.12
abl_no_tg     -4264.17
abl_one_hot   -4372.70
abl_lg        -4374.77
dtype: float32

## Bingham

In [19]:
data_loader = DataLoader(dataset_test, batch_size=8)

In [20]:
result = pd.DataFrame(columns=pd.MultiIndex.from_product([['Mean Angle Error (L2)', 'NLL'], models]))
model_str =  'abl_bingham'
model_path = f'./output/h36m/{model_str}/'
load_adapted_code = False

if os.path.isdir(os.path.join(model_path, 'Motion')):
    motion_path = os.path.join(model_path, 'Motion')
else:
    motion_path = './Motion'
delet_motion_import()
spec = importlib.util.spec_from_file_location("motion", os.path.join(motion_path, 'motion', '__init__.py'))
sys.path.append(motion_path)
motion = importlib.util.module_from_spec(spec)
spec.loader.exec_module(motion)
Motion = motion.Motion
Quaternion = motion.Quaternion

# Load model config
with open(os.path.join(model_path, 'config.yaml'), 'r') as file:
    model_config = yaml.safe_load(file)
checkpoint_path = os.path.join(model_path, 'model.pth.tar')
model = Motion(skeleton, T=skeleton.nodes_type_id_dynamic, **model_config)
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
model.eval()

mae_l2_metric.reset()
nll_list = []
with torch.no_grad():
    for x, y in tqdm(data_loader):
        # Remove static nodes
        x_dynamic = Quaternion.qfix_positive_(x[:, :, skeleton.dynamic_nodes])

        # Run Model 
        p_q, _, _, _ = model(x_dynamic, None, ph=ph)

        q_w = p_q.weighted_mean

        # Add static rotations q = [1., 0., 0., 0.]
        q_w = add_static_nodes(q_w, skeleton.static_nodes, skeleton.dynamic_nodes)

        mae_l2_metric.update((q_w, y))

        q_samp = p_q.sample((n_samp,))
        # Add static rotations q = [1., 0., 0., 0.]
        q_samp_all = add_static_nodes(q_samp, skeleton.static_nodes, skeleton.dynamic_nodes)
        q_samp_all = q_samp_all.permute(1, 0, 2, 3, 4).contiguous()

        pos_sampl_all = skeleton(q_samp_all.view(-1, 32, 4), ignore_root=False).view(-1, n_samp, ph, 32, 3) / 1000.
        pos_y = skeleton(y.view(-1, 32, 4), ignore_root=False).view(-1, ph, 32, 3) / 1000.

        nll_list.append(compute_kde_nll(pos_y[..., kept_joints, :], pos_sampl_all[..., kept_joints, :]))

    mae_l2 = mae_l2_metric.compute().numpy()
    nll = torch.cat(nll_list, dim=0).clip(max=20).sum(-1).mean(0).numpy()

    result['Mean Angle Error (L2)', model_str] = pd.Series(data=mae_l2, index=0.04*np.arange(1, mae_l2.shape[0] + 1))
    result['NLL', model_str] = pd.Series(data=nll, index=0.04*np.arange(1, mae_l2.shape[0] + 1))

    del nll
    del q_samp
    del q_samp_all
    del pos_sampl_all
sys.path.remove(motion_path)
del motion
motion = None
gc.collect()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 480/480 [2:26:21<00:00, 18.30s/it]


0

In [21]:
display(HTML(result['Mean Angle Error (L2)'].loc[t].to_html(float_format=lambda x: '%.2f' % x)))

Unnamed: 0,abl_1_modes,abl_2_modes,abl_3_modes,abl_4_modes,abl_5_modes,abl_6_modes,abl_gmm,abl_no_tg,abl_one_hot,abl_lg,abl_bingham
0.08,,,,,,,,,,,0.31
0.16,,,,,,,,,,,0.54
0.32,,,,,,,,,,,0.92
0.4,,,,,,,,,,,1.07
0.56,,,,,,,,,,,1.29
0.72,,,,,,,,,,,1.46
0.88,,,,,,,,,,,1.59
1.0,,,,,,,,,,,1.67


In [22]:
display(HTML(result['NLL'].loc[t].to_html(float_format=lambda x: '%.2f' % x)))

Unnamed: 0,abl_1_modes,abl_2_modes,abl_3_modes,abl_4_modes,abl_5_modes,abl_6_modes,abl_gmm,abl_no_tg,abl_one_hot,abl_lg,abl_bingham
0.08,,,,,,,,,,,-247.26
0.16,,,,,,,,,,,-223.11
0.32,,,,,,,,,,,-178.43
0.4,,,,,,,,,,,-162.58
0.56,,,,,,,,,,,-140.18
0.72,,,,,,,,,,,-124.88
0.88,,,,,,,,,,,-114.0
1.0,,,,,,,,,,,-107.52


In [23]:
pd.options.display.float_format = "{:.2f}".format
display(result['NLL'].sum())

abl_1_modes          0
abl_2_modes          0
abl_3_modes          0
abl_4_modes          0
abl_5_modes          0
abl_6_modes          0
abl_gmm              0
abl_no_tg            0
abl_one_hot          0
abl_lg               0
abl_bingham   -3983.10
dtype: object