In [1]:
import os
import numpy as np
import json
from typing import Sequence
from collections import UserDict
from torch.utils import data

import torch
from torch.utils.data import DataLoader
from ignite.contrib.handlers import ProgressBar, EpochOutputStore
from ignite.engine import Events, Engine, DeterministicEngine

In [2]:
import sys
sys.path.append(os.path.abspath('../src'))

from model.utils.torch import  set_seed
from model.core.metrics.mpjpe import MeanPerJointPositionError
from model.core.metrics.fde import FinalDisplacementError

In [3]:
# Number of observed frames fed to a model and number of frames it must predict.
history_length, prediction_horizon = 16, 14

# Directory with the SoMoF 3DPW JSON files, relative to this notebook.
dataset_path = "../data/somof_data_3dpw"

## Define models & Dataset

In [4]:
def center_pose(output):
    """Express all non-root joints relative to the root joint.

    The entry at index 0 of the second-to-last axis is used as the root: it is
    subtracted from every entry along that axis and then dropped, so the result
    has one entry fewer along that axis than ``output``.
    """
    root = output[..., :1, :]
    relative = output - root
    return relative[..., 1:, :]

class ZeroPoseGTTrajectoryBaseline(torch.nn.Module):
    """Baseline that freezes the last observed pose and moves it along the
    ground-truth root trajectory.

    Note that the output is built from ``y`` (the target), so this baseline
    needs the ground truth at prediction time.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for a uniform constructor and ignored.
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, ph: int = 1) -> torch.Tensor:
        """Build the prediction from the history ``x`` and the target ``y``.

        Args:
            x: observed sequence; its last three axes are indexed as
                (time, joint, coordinate) with joint 0 acting as the root.
            y: target sequence with the same axis layout; only joint 0 is read.
            ph: unused here -- the number of output steps follows ``y``.

        Returns:
            Tensor whose joint 0 is the root taken from ``y`` and whose other
            joints are the last root-relative pose of ``x`` added to that root.
        """
        frozen_pose = center_pose(x)[..., -1, :, :].unsqueeze(-3)
        root_path = y[..., :1, :]
        moved_pose = frozen_pose + root_path
        return torch.cat([root_path, moved_pose], dim=-2)
        
class ZeroBaseline(torch.nn.Module):
    """Baseline that repeats the last observed frame for every predicted step."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for a uniform constructor and ignored.
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, ph: int = 1) -> torch.Tensor:
        """Repeat the final frame of ``x`` ``ph`` times along the time axis.

        Args:
            x: observed sequence; its last three axes are indexed as
                (time, joint, coordinate). Any number of leading axes is allowed.
            y: target sequence; unused, kept for a signature shared with the
                other baselines.
            ph: number of steps to produce.

        Returns:
            A tensor shaped like ``x`` except that the time axis has length
            ``ph``; every step equals the last frame of ``x``.
        """
        last_frame = x[..., -1:, :, :]
        target_shape = (*last_frame.shape[:-3], ph, *last_frame.shape[-2:])
        # ``broadcast_to`` returns a new view instead of modifying its input,
        # so its result has to be the value that is returned. ``clone`` turns
        # the view into an independent tensor that is safe to modify in place.
        return last_frame.broadcast_to(target_shape).clone()
    
class ZeroPoseLastVelocityBaseline(torch.nn.Module):
    """Baseline that freezes the last observed pose and moves the root with a
    constant velocity taken from the last two observed frames.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for a uniform constructor and ignored.
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, ph: int = 1) -> torch.Tensor:
        """Extrapolate the root linearly for ``ph`` steps and attach the pose.

        Args:
            x: observed sequence indexed as (batch, time, joint, coordinate),
                with joint 0 acting as the root.
            y: target sequence; unused.
            ph: number of steps to predict.

        Returns:
            Tensor with ``ph`` steps: joint 0 is the extrapolated root, the
            remaining joints are the last root-relative pose added to it.
        """
        frozen_pose = center_pose(x)[..., -1, :, :].unsqueeze(-3)
        root_track = x[..., 0:1, :]
        # Velocity = difference of the last two root positions. (Averaging the
        # last three differences instead is a possible alternative.)
        step = root_track[..., -1:, :, :] - root_track[..., -2:-1, :, :]
        # The values are not used; the four-way unpack rejects non-4-D input.
        _batch, _time, _root, _coord = step.shape
        scaled_steps = [step * k for k in range(1, ph + 1)]
        root_future = torch.cat(scaled_steps, dim=-3)
        root_future += root_track[..., -1, :, :].unsqueeze(-3)
        return torch.cat([root_future, frozen_pose + root_future], dim=-2)

In [5]:
class SoMoFDataset(UserDict):
    """Dictionary-style container for one split of the SoMoF 3DPW pose data.

    After construction the mapping holds two keys, ``"poses_3d"`` and ``"gt"``.
    Both refer to the *same* tensor of shape
    ``(sequences * people, in_frames + out_frames, num_joints, 3)``.
    ``len()`` returns the number of samples, not the number of keys.
    """
    original_resolution = [1002, 1000]  # from video data
    images_subdir = 'ImageSequence'
    dataset_name = "3dpw" #"posetrack" 

    def __init__(self, dataset_path,
                 num_joints=13, split="train", **kwargs):
        """Load the ``split`` files found in ``dataset_path``.

        Args:
            dataset_path: directory containing
                ``{dataset_name}_{split}_in.json`` and ``..._out.json``.
            num_joints: joints per pose; only 13 is supported.
            split: ``"train"`` or ``"valid"``.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: on an unsupported ``split`` or ``num_joints``.
        """
        self.dataset_path = dataset_path
        # Backward-compatible alias for the historical misspelling.
        self.datase_path = dataset_path
        assert split in ["train", "valid"]
        assert num_joints == 13
        self.num_joints = num_joints

        # Use the class attribute so the name is defined in exactly one place.
        super().__init__(self._load(dataset_path, self.dataset_name, num_joints, split))

        print(f'Successfully created SoMoF {self.dataset_name} dataset from file: ', dataset_path,
              '\n\tsplit: ', split,
              '\n\tnumber of samples: ', len(self["poses_3d"]),
            )

    def __len__(self):
        """Return the number of samples (first axis of ``"poses_3d"``)."""
        return len(self["poses_3d"])

    @staticmethod
    def _load(dataset_path, dataset_name, num_joints, name_split="train"):
        """Read the input/output JSON pair of a split into a single tensor.

        Args:
            dataset_path: directory holding the JSON files.
            dataset_name: file-name prefix, e.g. ``"3dpw"``.
            num_joints: joints per pose; the flat last axis of the files is
                split into ``(num_joints, 3)``.
            name_split: split name used in the file names.

        Returns:
            dict with keys ``"poses_3d"`` and ``"gt"``, both bound to the same
            float32 tensor: input and output frames concatenated along the
            time axis and multiplied by 1000 (metres to millimetres).
        """
        def read_part(suffix):
            # Files are 4-D nested lists, e.g. (221, 2, 16, 39) for the input
            # part according to the original note.
            file_path = os.path.join(dataset_path, f"{dataset_name}_{name_split}_{suffix}.json")
            with open(file_path, 'r') as f:
                raw = np.array(json.load(f))
            # Merge the first two axes into one sample axis and split the flat
            # last axis into (num_joints, 3).
            return torch.from_numpy(raw).reshape(
                raw.shape[0] * raw.shape[1], raw.shape[2], num_joints, 3).float()

        poses = torch.cat([read_part("in"), read_part("out")], dim=-3)
        # from meters to mm
        poses *= 1000
        return {"poses_3d": poses,
                "gt": poses}

class SoMoFTorchDataset(data.Dataset):
    """Torch dataset that splits each stored sequence into history and future.

    Each item is a pair ``(history, future)`` sliced from
    ``dataset["poses_3d"][item]`` along its first axis.
    """

    def __init__(self,
                 dataset: "SoMoFDataset",
                 history_length: int,
                 prediction_horizon: int,
                 **kwargs
                 ):
        """Wrap ``dataset`` for use with a ``DataLoader``.

        Args:
            dataset: mapping that provides a ``"poses_3d"`` tensor.
            history_length: number of leading frames returned as history;
                must be 16.
            prediction_horizon: number of following frames returned as
                future; must be 14.
            **kwargs: accepted and ignored.

        Raises:
            AssertionError: if the two lengths are not 16 and 14.
        """
        self.history_length = history_length
        self.prediction_horizon = prediction_horizon
        assert history_length == 16
        assert prediction_horizon == 14

        self._data = dataset
        print("Len of sample: ", self.__len__())

    def __getitem__(self, item):
        """Return ``(history, future)`` for sample ``item``."""
        poses = self._data["poses_3d"][item]
        split_at = self.history_length
        stop = split_at + self.prediction_horizon
        return poses[:split_at], poses[split_at:stop]

    def __len__(self):
        """Return the number of samples in ``"poses_3d"``.

        Measured on the same tensor that ``__getitem__`` indexes, so the
        result does not depend on how the wrapped mapping defines ``len()``.
        """
        return len(self._data["poses_3d"])


## Evaluation loop

In [6]:
# Baseline under evaluation; the other baseline classes defined in this
# notebook can be instantiated here instead.
model = ZeroPoseLastVelocityBaseline() # choose the baseline to evaluate

In [7]:
# Enable autograd anomaly detection globally.
torch.autograd.set_detect_anomaly(True)

# Run on the CPU with a fixed seed.
device = "cpu"
set_seed(seed=0)

# Validation split, served in its stored order in batches of 50.
dataset = SoMoFDataset(dataset_path, split="valid", num_joints=13)
torch_dataset = SoMoFTorchDataset(
    dataset=dataset,
    history_length=history_length,
    prediction_horizon=prediction_horizon,
)
data_loader = DataLoader(
    torch_dataset,
    shuffle=False,
    batch_size=50,
    num_workers=0,
)

Successfully created SoMoF 3dpw dataset from file:  ../data/somof_data_3dpw 
	split:  valid 
	number of samples:  72
Len of sample:  72


In [8]:
def preprocess(engine: Engine):
    """Replace the engine's current batch with a list of its elements moved
    to the notebook-level ``device``.
    """
    state = engine.state
    moved = []
    for tensor in state.batch[:]:
        moved.append(tensor.to(device))
    state.batch = moved
        
def validation_step_eval(engine: Engine, batch: Sequence[torch.Tensor]):
    """Run the notebook-level ``model`` on one batch without gradients.

    Args:
        engine: the calling engine (unused).
        batch: pair ``(x, y)`` of history and target tensors.

    Returns:
        Tuple ``(model_out, y, x)``.
    """
    model.eval()
    with torch.no_grad():
        history, target = batch
        prediction = model(history, target, ph=prediction_horizon)
        result = (prediction, target, history)
    return result

def extract_hip(output):
    """Keep only joint 0 (the root) of the first two tensors in ``output``.

    The joint axis is preserved with length 1. Elements of ``output`` after
    the second one are ignored.
    """
    hip_only = tuple(tensor[..., :1, :] for tensor in output[:2])
    first_hip, second_hip = hip_only[0], hip_only[1]
    return first_hip, second_hip

def extract_pose(output):
    """Return the root-relative pose of the first two tensors in ``output``.

    For each tensor, joint 0 is subtracted from all joints and then dropped.
    Elements of ``output`` after the second one are ignored.
    """
    def relative_to_root(joints):
        root = joints[..., :1, :]
        return (joints - root)[..., 1:, :]

    centered = list(map(relative_to_root, output[:2]))
    return centered[0], centered[1]

# Ignite metrics: position error on the full output, on the root-relative
# pose, and on the root joint alone.
mpjpe = MeanPerJointPositionError()
mpjpe_pose = MeanPerJointPositionError(output_transform=extract_pose)
mpjpe_avg = MeanPerJointPositionError(keep_time_dim=False)
mpjpe_hip = MeanPerJointPositionError(output_transform=extract_hip, keep_time_dim=True)

fde = FinalDisplacementError()
fde_pose = FinalDisplacementError(output_transform=extract_pose)
fde_hip = FinalDisplacementError(output_transform=extract_hip)

# Evaluation engine; every batch is moved to `device` before the step runs.
evaluator = DeterministicEngine(validation_step_eval)
evaluator.add_event_handler(Events.ITERATION_STARTED, preprocess)

# Register each metric under the key used later in `state.metrics`.
for metric_name, metric in (
    ('MPJPE', mpjpe),
    ('MPJPE_POSE', mpjpe_pose),
    ('MPJPE_AVG', mpjpe_avg),
    ('MPJPE_HIP', mpjpe_hip),
    ('FDE', fde),
    ('FDE_POSE', fde_pose),
    ('FDE_HIP', fde_hip),
):
    metric.attach(evaluator, metric_name)

# Keep the per-iteration outputs of the epoch under the name 'output'.
eos = EpochOutputStore()
eos.attach(evaluator, 'output')

pbar = ProgressBar()
pbar.attach(evaluator)

In [9]:
# 0-based prediction steps shown in the result tables.
# NOTE(review): presumably the SoMoF benchmark reporting horizons -- confirm.
REPORTED_STEPS = [1, 3, 7, 9, 13]
# The poses are scaled to millimetres on load; results are reported in cm.
MM_PER_CM = 10


def summarize_per_step(per_step_error):
    """Return ``(table, average)`` in cm, each rounded to one decimal.

    ``per_step_error`` is a 1-D tensor with one error value per predicted
    step. Values are converted to Python floats *before* dividing so that all
    metrics are converted and rounded the same way.
    """
    table = [round(per_step_error[i].item() / MM_PER_CM, 1) for i in REPORTED_STEPS]
    average = round(per_step_error.mean().item() / MM_PER_CM, 1)
    return table, average


state = evaluator.run(data_loader)
metrics = {}
for key in ("MPJPE", "MPJPE_POSE", "MPJPE_HIP"):
    label = key.replace("_", " ")
    metrics[f"{label} table"], metrics[f"{label} avg"] = summarize_per_step(state.metrics[key])
# The final-displacement metrics are single values; report them all.
for m, v in state.metrics.items():
    if "FDE" in m:
        metrics[m] = round(v / MM_PER_CM, 1)


print("Result on eval val in cm: ", metrics)

[1/2]  50%|#####      [00:00<?]

Result on eval val in cm:  {'MPJPE table': [3.5, 7.3, 14.6, 18.1, 25.5], 'MPJPE avg': 13.6, 'MPJPE POSE table': [3.5, 6.5, 10.7, 11.9, 13.6], 'MPJPE POSE avg': 9.3, 'MPJPE HIP table': [1.2, 3.5, 9.8, 13.0, 19.9], 'MPJPE HIP avg': 9.3, 'FDE': 25.5, 'FDE_POSE': 13.6, 'FDE_HIP': 19.9}
