In [1]:
import os
import numpy as np
import json
from typing import Sequence
from collections import UserDict

import torch
from torch.utils.data import DataLoader
from ignite.contrib.handlers import ProgressBar, EpochOutputStore
from ignite.engine import Events, Engine, DeterministicEngine

In [2]:
import sys
# Put the parent of the current working directory on the import path so the
# `model.*` imports in this cell can be resolved (presumably the notebook lives
# one level below the repository root - confirm).
sys.path.append(os.path.abspath('..'))

from model.utils.torch import  set_seed
from model.core.metrics.mpjpe import MeanPerJointPositionError
from model.core.metrics.fde import FinalDisplacementError

In [3]:
# Number of leading frames of every sequence that are handed to the model as history.
history_length = 50
# Number of frames following the history that are used as prediction target.
prediction_horizon = 100

# Dataset folder; its name encodes the two values above, so changing them
# selects a different folder.
dataset_path = f"../data/WalkingDynamicsH36M_history_{history_length}_pred_horiz_{prediction_horizon}"

## Define models & Dataset

In [4]:
def center_pose(output):
    """Express every non-root joint relative to the root joint.

    The root is the entry at index 0 of the second-to-last axis. It is
    subtracted from all remaining entries of that axis and is itself dropped
    from the result, so the second-to-last axis shrinks by one.
    """
    root = output[..., :1, :]
    others = output[..., 1:, :]
    return others - root

class ZeroPoseGTTrajectoryBaseline(torch.nn.Module):
    """Baseline that freezes the last observed pose and moves it along the ground-truth root path.

    The root trajectory is copied from the target ``y`` (joint index 0), so
    only the articulated pose is actually "predicted" - as a constant.
    Constructor keyword arguments are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor,ph: int = 1)  -> torch.Tensor:
        """Return the ground-truth root followed by the frozen pose re-attached to it.

        Args:
            x: history; the last entry of its third-to-last axis is used as the frozen pose.
            y: target; joint 0 of every frame is taken as the root trajectory.
            ph: unused here - the number of output frames follows ``y``.
        """
        root_path = y[..., :1, :]
        # Root-relative pose of the final history frame, with a singleton
        # frame axis so it broadcasts over every frame of the root path.
        frozen_pose = center_pose(x[..., -1, :, :]).unsqueeze(-3)
        joints = frozen_pose + root_path
        return torch.cat((root_path, joints), dim=-2)
        
class ZeroBaseline(torch.nn.Module):
    """Zero-velocity baseline: the last observed frame is repeated for every predicted step.

    Constructor keyword arguments are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor,ph: int = 1)  -> torch.Tensor:
        """Repeat the final history frame ``ph`` times.

        Args:
            x: history whose last three axes are (frame, joint, coordinate);
                any number of leading batch axes is allowed.
            y: target, unused (kept so all baselines share one call signature).
            ph: number of frames to predict.

        Returns:
            Tensor shaped like ``x`` but with the frame axis of length ``ph``.
            It is an expanded *view* of ``x`` (no copy); clone it before any
            in-place modification.
        """
        last_frame = x[..., -1, :, :].unsqueeze(-3)
        # Broadcasting is out-of-place: the previous code called
        # `broadcast_to` and dropped its result, so a single frame was
        # returned no matter what `ph` was.
        target_shape = (*last_frame.shape[:-3], ph, *last_frame.shape[-2:])
        return last_frame.expand(*target_shape)
    
class ZeroPoseLastVelocityBaseline(torch.nn.Module):
    """Baseline that freezes the last observed pose and extrapolates the root linearly.

    The root (joint 0) keeps the velocity measured between the last two
    history frames; every other joint keeps its offset to the root from the
    last history frame. Constructor keyword arguments are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor,ph: int = 1)  -> torch.Tensor:
        """Predict ``ph`` frames by constant-velocity root motion with a frozen pose.

        Args:
            x: history whose last three axes are (frame, joint, coordinate);
                any number of leading batch axes is allowed.
            y: target, unused (kept so all baselines share one call signature).
            ph: number of frames to predict.

        Returns:
            Tensor shaped like ``x`` but with the frame axis of length ``ph``;
            joint 0 is the extrapolated root.

        Raises:
            ValueError: if the history has fewer than two frames, since no
                velocity can be measured then.
        """
        if x.shape[-3] < 2:
            raise ValueError(
                "ZeroPoseLastVelocityBaseline needs at least two history frames to estimate a velocity"
            )

        # Root-relative pose of the final history frame, singleton frame axis.
        pose_centered = center_pose(x[..., -1, :, :]).unsqueeze(-3)

        last_hip = x[..., -1:, 0:1, :]
        # Single-step velocity. Averaging the last three frame-to-frame
        # velocities was tried and gave worse results.
        last_velocity = last_hip - x[..., -2:-1, 0:1, :]

        # Steps 1..ph laid out on the frame axis so that one broadcast
        # multiply replaces the per-step Python loop.
        steps = torch.arange(1, ph + 1, dtype=x.dtype, device=x.device).reshape(-1, 1, 1)
        displacement = last_hip + last_velocity * steps

        return torch.cat([displacement, pose_centered + displacement], dim=-2)

In [5]:
class H36MValDataset(UserDict):
    """Dict-like holder for the pre-extracted H36M evaluation sequences.

    The mapping contains three entries produced by :meth:`_load`:

    * ``"poses_3d"``  - float tensor built from ``{split}_poses.json``
    * ``"img_paths"`` - numpy array built from ``{split}_images.json``
    * ``"seq_names"`` - one name per sequence, derived from its first image path

    ``len(dataset)`` is the number of sequences, not the number of keys.

    Args:
        dataset_val_path: directory containing the two JSON files.
        split: which pair of files to read, ``"val"`` or ``"test"``.
        num_joints: expected length of the second-to-last axis of the pose
            array; one of 32, 25, 17.
        **kwargs: accepted and ignored.
    """

    def __init__(self, dataset_val_path, split="val",
                 num_joints=25, **kwargs):

        # Historical (misspelled) attribute name, kept for existing callers.
        self.datase_file_path = dataset_val_path
        self.dataset_file_path = dataset_val_path
        self.split = split
        # Placeholder (Ellipsis); nothing in this class fills it in.
        self.subjects = ...
        self.n_keypoints = num_joints
        assert self.n_keypoints in [32, 25, 17], \
            f"num_joints must be one of 32, 25, 17, got {num_joints}"
        assert split in ["val", "test"], \
            f"split must be 'val' or 'test', got {split!r}"

        super().__init__(self._load(self.datase_file_path, split, num_joints))

        # Report the split that was actually loaded; the message used to say
        # "val" unconditionally, also for the test split.
        print('Successfully created H36M dataset',
              '\n\tsplit: ', split,
              '\n\tnumber of sequences: ', len(self["poses_3d"]),
              )

    def __len__(self):
        """Number of sequences (length of the ``"poses_3d"`` entry)."""
        return len(self["poses_3d"])

    @staticmethod
    def _load(dataset_path, split, n_kpts):
        """Read ``{split}_poses.json`` and ``{split}_images.json`` from ``dataset_path``.

        Returns:
            dict with the keys ``"poses_3d"``, ``"img_paths"`` and ``"seq_names"``.

        Raises:
            AssertionError: if the second-to-last axis of the pose array does
                not have length ``n_kpts``.
        """
        with open(os.path.join(dataset_path, f"{split}_poses.json"), 'r') as f:
            data = np.array(json.load(f))
        with open(os.path.join(dataset_path, f"{split}_images.json"), 'r') as f:
            frames = np.array(json.load(f))
        # Sequence name = first image path up to "img", minus the one
        # character (separator) right before it.
        seqs = [fs[0].split("img")[0][:-1] for fs in frames]
        assert n_kpts == data.shape[-2], \
            f"expected {n_kpts} joints, pose file has {data.shape[-2]}"
        kpts = torch.from_numpy(data).float()
        result = {  "poses_3d": kpts,
            "img_paths": frames,
            "seq_names": seqs}
        return result
    
class WalDynH36MTorchValDataset():
    """Index-style dataset that splits every stored sequence into (history, future).

    Item ``i`` is sequence ``i`` of ``dataset["poses_3d"]`` cut along its
    first axis: the first ``history_length`` frames and the
    ``prediction_horizon`` frames that follow them.
    """

    def __init__(self, dataset: H36MValDataset, history_length: int,  prediction_horizon: int):
        self._data = dataset
        self.action = "Custom"
        self.history_length = history_length
        self.prediction_horizon = prediction_horizon

        print(f'Successfully created H36M dataloder',
              '\n\taction: ', self.action,
              '\n\tnumber of samples: ', len(self),
              )

    def __len__(self):
        """Number of sequences in the wrapped dataset."""
        return len(self._data["poses_3d"])

    def __getitem__(self, item):
        """Return ``(history, future)`` for sequence ``item``."""
        sequence = self._data["poses_3d"][item]
        split_at = self.history_length
        stop = split_at + self.prediction_horizon
        return sequence[:split_at], sequence[split_at:stop]


## Evaluation loop

In [6]:
# Choose the baseline to evaluate; ZeroBaseline and ZeroPoseGTTrajectoryBaseline
# share the same forward(x, y, ph) signature and can be swapped in here.
model = ZeroPoseLastVelocityBaseline()

In [7]:
# NOTE(review): anomaly detection is a global autograd debugging switch. The
# evaluation step in this notebook runs under torch.no_grad(), so this looks
# like a leftover - confirm before removing.
torch.autograd.set_detect_anomaly(True)
# Device the batches are moved to by `preprocess`.
device = "cpu"
# Init seed
set_seed(seed=0)

# Evaluate on the "test" split files found in `dataset_path`.
dataset = H36MValDataset(dataset_path, split="test")
torch_dataset = WalDynH36MTorchValDataset(dataset=dataset, history_length=history_length,  prediction_horizon=prediction_horizon)
# No shuffling: batches are drawn in dataset order.
data_loader = DataLoader(torch_dataset, shuffle=False, batch_size=50, num_workers=0)

Successfully created H36M dataset 
	split: val my dataset 
	number of sequences:  32
Successfully created H36M dataloder 
	action:  Custom 
	number of samples:  32


In [8]:
def preprocess(engine: Engine):
    """Replace the engine's current batch with a list of its tensors moved to ``device``."""
    moved = []
    for tensor in engine.state.batch[:]:
        moved.append(tensor.to(device))
    engine.state.batch = moved
        
def validation_step_eval(engine: Engine, batch: Sequence[torch.Tensor]):
    """Run the global ``model`` on one ``(history, future)`` batch without gradients.

    Returns:
        ``(prediction, future, history)`` - prediction and target come first,
        which is the order the metric transforms in this notebook read.
    """
    model.eval()
    history, future = batch
    with torch.no_grad():
        prediction = model(history, future, ph=prediction_horizon)
    return prediction, future, history

def extract_hip(output):
    """Keep only joint 0 (as a length-1 joint axis) of the first two entries of ``output``."""
    prediction, target = output[0], output[1]
    return prediction[..., :1, :], target[..., :1, :]

def extract_pose(output):
    """Make the first two entries of ``output`` root-relative and drop the root joint."""
    def _root_relative(joints):
        return joints[..., 1:, :] - joints[..., :1, :]

    prediction, target = output[0], output[1]
    return _root_relative(prediction), _root_relative(target)

# Define ignite metrics
# Three views of the same (prediction, target) pair are scored:
#   - no transform       -> the full output as returned by the step function
#   - extract_pose       -> root-relative joints, root removed
#   - extract_hip        -> root joint only
# NOTE(review): `keep_time_dim` is presumably what makes a metric return one
# value per predicted frame instead of a scalar - confirm in model.core.metrics.
mpjpe = MeanPerJointPositionError()
mpjpe_pose = MeanPerJointPositionError(output_transform=extract_pose)
mpjpe_avg = MeanPerJointPositionError(keep_time_dim=False)
mpjpe_hip = MeanPerJointPositionError(output_transform=extract_hip, keep_time_dim=True)

fde = FinalDisplacementError()
fde_pose = FinalDisplacementError(output_transform=extract_pose)
fde_hip = FinalDisplacementError(output_transform=extract_hip)

# The engine calls validation_step_eval once per batch; preprocess runs at the
# start of every iteration to move the batch to `device`.
evaluator = DeterministicEngine(validation_step_eval)
evaluator.add_event_handler(Events.ITERATION_STARTED, preprocess)
# The names given here are the keys under which results appear in state.metrics.
mpjpe.attach(evaluator, 'MPJPE')
mpjpe_pose.attach(evaluator, 'MPJPE_POSE')
mpjpe_avg.attach(evaluator, 'MPJPE_AVG')
mpjpe_hip.attach(evaluator, 'MPJPE_HIP')
fde.attach(evaluator, 'FDE')
fde_pose.attach(evaluator, 'FDE_POSE')
fde_hip.attach(evaluator, 'FDE_HIP')

# Collect the step outputs of the epoch under the name 'output'.
eos = EpochOutputStore()
eos.attach(evaluator, 'output')

pbar = ProgressBar()
pbar.attach(evaluator)

In [9]:
state = evaluator.run(data_loader)

# Indices into the per-frame metric vectors that are reported in the tables.
table_frames = [1, 24, 49, 74, -1]
# (table key, average key, name of the attached metric)
report_layout = (
    ("MPJPE table", "MPJPE avg", "MPJPE"),
    ("MPJPE POSE table", "MPJPE POSE avg", "MPJPE_POSE"),
    ("MPJPE HIP table", "MPJPE HIP avg", "MPJPE_HIP"),
)

metrics = {}
for table_key, avg_key, metric_name in report_layout:
    per_frame = state.metrics[metric_name]
    metrics[table_key] = [round(per_frame[idx].item(), 1) for idx in table_frames]
    metrics[avg_key] = round(per_frame.mean().item(), 1)

# Every final-displacement metric is copied over, rounded to one decimal.
metrics.update(
    {name: round(value, 1) for name, value in state.metrics.items() if "FDE" in name}
)

print(f"Result on eval val: ", metrics)

[1/1] 100%|########## [00:00<?]

Result on eval val:  {'MPJPE table': [36.4, 355.2, 810.6, 1477.8, 2262.6], 'MPJPE avg': 947.7, 'MPJPE POSE table': [37.3, 206.6, 250.1, 286.4, 344.2], 'MPJPE POSE avg': 240.0, 'MPJPE HIP table': [7.7, 276.9, 740.6, 1400.4, 2180.4], 'MPJPE HIP avg': 874.0, 'FDE': 2262.6, 'FDE_POSE': 344.2, 'FDE_HIP': 2180.4}
