In [1]:
import os
import numpy as np
import json
from collections import UserDict
from typing import Sequence

import torch
from torch.utils import data
from torch.utils.data import DataLoader
from ignite.contrib.handlers import ProgressBar, EpochOutputStore
from ignite.engine import Events, Engine, DeterministicEngine

In [2]:
import sys
sys.path.append(os.path.abspath('../src/'))

from model.utils.torch import  set_seed

In [3]:
# SoMoF protocol lengths: 16 observed frames in, 14 predicted frames out.
# These are the only values SoMoFTorchDataset accepts.
history_length = 16
prediction_horizon = 14

# Directory holding the SoMoF 3DPW JSON files (plain string; no interpolation needed).
dataset_path = "../data/somof_data_3dpw"

## Define baseline models and dataset

In [4]:
def center_pose(kpts):
    """Express every joint except joint 0 relative to joint 0, and drop joint 0.

    ``kpts`` is indexed as (..., joints, coords); the result has shape
    (..., joints - 1, coords).
    """
    root = kpts[..., :1, :]
    return kpts[..., 1:, :] - root

class ZeroPoseGTTrajectoryBaseline(torch.nn.Module):
    """Freeze the last observed pose and carry it along a trajectory taken from ``y``.

    The joint offsets relative to joint 0 from the final observed frame of ``x``
    are held fixed and re-anchored at every frame of joint 0's path in ``y``.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted (and ignored) so every baseline shares
        # the same constructor signature.
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor, ph: int = 1) -> torch.Tensor:
        """Predict future keypoints.

        Args:
            x: observed keypoints; the last three axes are indexed as
               (frame, joint, coord) and joint 0 serves as the root.
            y: future keypoints with the same trailing layout; only joint 0 is read.
            ph: unused — the number of output frames follows ``y``.

        Returns:
            Tensor (..., frames_of_y, joints_of_x, coord): joint 0 copied from
            ``y``, every other joint = frozen offset + joint 0 position.
        """
        frozen_offsets = center_pose(x)[..., -1:, :, :]
        root_path = y[..., :1, :]
        return torch.cat((root_path, frozen_offsets + root_path), dim=-2)
        
class ZeroBaseline(torch.nn.Module):
    """Zero-velocity baseline: repeat the last observed frame ``ph`` times."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted (and ignored) so every baseline shares
        # the same constructor signature.
        super().__init__()

    def forward(self, x: torch.Tensor, ph: int = 1) -> torch.Tensor:
        """Predict ``ph`` future frames as copies of the last observed frame.

        Args:
            x: observed keypoints; the last three axes are indexed as
               (frame, joint, coord). Any number of leading batch axes is allowed.
            ph: number of frames to predict (must be >= 0).

        Returns:
            New tensor of shape (..., ph, joints, coord).

        Raises:
            ValueError: if ``ph`` is negative.
        """
        if ph < 0:
            # expand() would silently treat -1 as "keep size", so reject explicitly.
            raise ValueError(f"ph must be non-negative, got {ph}")
        last_frame = x[..., -1:, :, :]  # (..., 1, joints, coord)
        # Broadcast the singleton frame axis to ph frames; clone() so the result
        # owns its memory instead of aliasing x through an expanded view.
        target_shape = (*last_frame.shape[:-3], ph, *last_frame.shape[-2:])
        return last_frame.expand(target_shape).clone()
    
class ZeroPoseLastVelocityBaseline(torch.nn.Module):
    """Constant-velocity root with a frozen pose.

    Joint 0 is extrapolated linearly using its last observed frame-to-frame
    displacement; the remaining joints keep their last observed offsets
    relative to joint 0.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted (and ignored) so every baseline shares
        # the same constructor signature.
        super().__init__()

    def forward(self, x: torch.Tensor, ph: int = 1) -> torch.Tensor:
        """Predict ``ph`` future frames.

        Args:
            x: observed keypoints; the last three axes are indexed as
               (frame, joint, coord) and at least 2 frames are required.
               Any number of leading batch axes is allowed.
            ph: number of frames to predict (must be >= 0).

        Returns:
            Tensor of shape (..., ph, joints, coord).

        Raises:
            ValueError: if ``ph`` is negative or ``x`` has fewer than 2 frames.
        """
        if ph < 0:
            raise ValueError(f"ph must be non-negative, got {ph}")
        if x.shape[-3] < 2:
            # A velocity needs two frames; with one, the slice below would be
            # empty and silently yield an empty prediction.
            raise ValueError(f"need at least 2 observed frames, got {x.shape[-3]}")

        pose_centered = center_pose(x)[..., -1:, :, :]      # (..., 1, joints-1, coord)
        last_root = x[..., -1:, 0:1, :]                     # (..., 1, 1, coord)
        last_velocity = last_root - x[..., -2:-1, 0:1, :]   # (..., 1, 1, coord)

        # Step multipliers 1..ph broadcast against the singleton frame axis,
        # giving root positions last_root + k * velocity for k = 1..ph.
        steps = torch.arange(1, ph + 1, dtype=x.dtype, device=x.device).view(-1, 1, 1)
        root_traj = last_root + last_velocity * steps       # (..., ph, 1, coord)

        return torch.cat([root_traj, pose_centered + root_traj], dim=-2)

In [5]:
class SoMoFDataset(UserDict):
    """Dict-like container for SoMoF observation sequences loaded from JSON.

    Reads ``{dataset_name}_test_in.json`` from ``dataset_path`` and exposes the
    keypoints under the keys ``"poses_3d"`` and ``"gt"`` (both reference the same
    tensor). ``len(dataset)`` is the number of samples, not the number of keys.
    """

    original_resolution = [1002, 1000]  # from video data
    images_subdir = 'ImageSequence'
    dataset_name = "3dpw"  # alternative: "posetrack"

    def __init__(self, dataset_path,
                 num_joints=13, **kwargs):
        """
        Args:
            dataset_path: directory containing the SoMoF JSON files.
            num_joints: joints per pose; only 13 is supported.
            **kwargs: accepted and ignored. NOTE(review): a ``split=`` keyword
                has no effect — the "test" split file is always loaded.
        """
        self.dataset_path = dataset_path
        name_split = "test"
        assert num_joints == 13
        self.num_joints = num_joints

        super().__init__(self._load(self.dataset_path, self.dataset_name, self.num_joints, name_split))

        print(f'Successfully created SoMoF {self.dataset_name} dataset from file: ', dataset_path,
              '\n\tnumber of samples: ', len(self["poses_3d"]),
              )

    @property
    def datase_path(self):
        """Misspelled legacy alias of ``dataset_path``, kept for existing callers."""
        return self.dataset_path

    @datase_path.setter
    def datase_path(self, value):
        self.dataset_path = value

    def __len__(self):
        # Number of samples (first axis of the pose tensor), not number of keys.
        return len(self["poses_3d"])

    @staticmethod
    def _load(dataset_path, dataset_name, num_joints, name_split="train"):
        """Read ``{dataset_name}_{name_split}_in.json`` into a float32 tensor.

        The JSON array's last axis is split into ``(num_joints, 3)``; the first
        three axes are kept (e.g. a file of shape (221, 2, 16, 39) becomes
        (221, 2, 16, 13, 3) — presumably samples, persons, frames; confirm).

        Returns:
            ``{"poses_3d": tensor, "gt": tensor}`` — the same tensor under both keys.
        """
        file_path = os.path.join(dataset_path, f"{dataset_name}_{name_split}_in.json")
        with open(file_path, 'r') as f:
            raw = np.array(json.load(f))  # e.g. (221, 2, 16, 39)

        poses = torch.from_numpy(raw).view(*raw.shape[:3], num_joints, 3).float()
        # Values are kept in the file's units (noted as meters); the x1000
        # meters->mm conversion is intentionally disabled.
        return {"poses_3d": poses,
                "gt": poses}

class SoMoFTorchDataset(data.Dataset):
    """``torch.utils.data.Dataset`` view over a :class:`SoMoFDataset`.

    Item ``i`` is ``dataset["poses_3d"][i]``. The history/horizon arguments are
    validated against the only supported setting (16 / 14) and stored; they do
    not affect indexing. Extra keyword arguments are ignored.
    """

    def __init__(self,
                 dataset: SoMoFDataset,
                 history_length: int,
                 prediction_horizon: int,
                **kwargs
                 ):
        self.history_length, self.prediction_horizon = history_length, prediction_horizon
        # Only the fixed SoMoF protocol lengths are supported.
        assert history_length == 16
        assert prediction_horizon == 14

        self._data = dataset
        print("Len of sample: ", len(self))

    def __getitem__(self, item):
        return self._data["poses_3d"][item]

    def __len__(self):
        # Delegates to SoMoFDataset.__len__ (number of samples).
        return len(self._data)


## Evaluation: run the baseline and export predictions

In [6]:
model = ZeroPoseLastVelocityBaseline()  # choose baseline
# Derive the output-folder name from the chosen model so the two cannot drift apart.
baseline_name = type(model).__name__

In [7]:
# Init seed
device = "cpu"
set_seed(seed=0)

# NOTE: SoMoFDataset always reads the "test" split file; any split= keyword is
# ignored by it, so none is passed here.
# (Autograd anomaly detection was dropped: inference runs under no_grad, so it only
# added a global debug setting with no effect.)
dataset = SoMoFDataset(dataset_path, num_joints=13)
torch_dataset = SoMoFTorchDataset(dataset=dataset, history_length=history_length, prediction_horizon=prediction_horizon)
data_loader = DataLoader(torch_dataset, shuffle=False, batch_size=50, num_workers=0)

Successfully created SoMoF 3dpw dataset from file:  ../data/somof_data_3dpw 
	number of samples:  85
Len of sample:  85


In [8]:
def preprocess(engine: Engine):
    """Move the current batch onto the global ``device`` before the step runs."""
    batch = engine.state.batch
    engine.state.batch = batch.to(device)
def validation_step_eval(engine: Engine, batch: torch.Tensor):
    """Run the selected baseline on one batch of observed poses.

    Args:
        engine: the ignite engine driving the loop (unused).
        batch: observed keypoints collated by the DataLoader. SoMoFTorchDataset
            returns a single tensor per item, so this is one stacked tensor,
            not a sequence of tensors.

    Returns:
        ``(prediction, observed_input)`` — predictions for ``prediction_horizon``
        frames and the input batch they were computed from.
    """
    model.eval()
    with torch.no_grad():
        model_out = model(batch, ph=prediction_horizon)
    return model_out, batch


# ignite DeterministicEngine that calls validation_step_eval for every batch.
evaluator = DeterministicEngine(validation_step_eval)
# Move each batch to `device` before the step function receives it.
evaluator.add_event_handler(Events.ITERATION_STARTED, preprocess)

# Store every iteration's (prediction, input) output for the epoch; after run()
# the collected list is read back as state.output (the name given to attach).
eos = EpochOutputStore()
eos.attach(evaluator, 'output')

# Console progress bar for the evaluation run.
pbar = ProgressBar()
pbar.attach(evaluator)

In [9]:
state = evaluator.run(data_loader)
# Each stored output is (prediction, observed_input); keep only the predictions.
output = torch.cat([prediction for prediction, _observed in state.output], dim=0)
# Merge (joints, xyz) into one axis -> (samples, persons, ph, joints*3), the same
# flat layout as the input JSON, without hard-coding persons/horizon/joint counts.
output = output.flatten(start_dim=-2)
assert output.shape[-2] == prediction_horizon, output.shape
print(output.shape)

result_dir = f"result_baselines/somof/baselines_{baseline_name}"
os.makedirs(result_dir, exist_ok=True)
with open(os.path.join(result_dir, '3dpw_predictions.json'), 'w') as f:
    json.dump(output.tolist(), f)

[1/2]  50%|#####      [00:00<?]

torch.Size([85, 2, 14, 39])
