In [1]:
import os
import numpy as np
import json
import torch

In [2]:
import sys
sys.path.append(os.path.abspath('../src/'))

from skeleton import SoMoFSkeleton

from model.utils.plot import get_np_frames_3d_projection
from model.utils.image import  save_gif

In [3]:
# Ground-truth SoMoF inputs: which dataset/split to read and where it lives.
dataset_name = "3dpw"
name_split = "test"
dataset_path = f"../data/somof_data_{dataset_name}"

# Model output to visualise.
# NOTE(review): "exp_name/test_num_epoch" looks like a placeholder — point it at a real run directory.
path_to_submission = "../output/somof/exp_name/test_num_epoch"
sub_file_name = f"{dataset_name}_predictions.json"


In [4]:
# 13-joint SoMoF skeleton (hip included); only its limb definitions are used below,
# pose_box_size is passed but not used by anything in this notebook.
skeleton = SoMoFSkeleton(
    pose_box_size=1200,
    if_consider_hip=True,
    num_joints=13,
)

In [5]:
# Observed (input) poses. The flat JSON is reshaped to
# (sequence, 2, 16 frames, 13 joints, 3) — axis 1 is presumably the two people; TODO confirm.
in_file = os.path.join(dataset_path, f"{dataset_name}_{name_split}_in.json")
with open(in_file, 'r') as f:
    data_in = np.array(json.load(f))

num_sequences = data_in.shape[0]
data_in = torch.from_numpy(data_in).view(num_sequences, 2, 16, 13, 3)
data_in.shape

torch.Size([85, 2, 16, 13, 3])

In [6]:
# Predicted poses from the submission: same layout as data_in but 14 frames per person.
pred_file = os.path.join(path_to_submission, sub_file_name)
with open(pred_file, 'r') as f:
    data_pred = np.array(json.load(f))

num_pred_sequences = data_pred.shape[0]
data_pred = torch.from_numpy(data_pred).view(num_pred_sequences, 2, 14, 13, 3)
data_pred.shape

torch.Size([85, 2, 14, 13, 3])

In [13]:
# Join observed and predicted frames along the time axis: 16 + 14 = 30 frames per person.
data = torch.cat((data_in, data_pred), dim=2)

In [8]:
def save_visual(data_track, name="frames", num_in_frames=16, out_dir=None, fps=5):
    """Render one person's pose track to a GIF: observed frames, then predicted frames.

    The first ``num_in_frames`` frames are drawn on their own. The remaining frames
    are drawn with ``data_pred=track`` over a static copy of the first frame, and only
    those trailing frames are appended after the observed ones.

    Args:
        data_track: tensor of shape (T, J, 3). Values are multiplied by 1000 and
            plotted with ``units="mm"``, so presumably metres — TODO confirm.
        name: GIF file name (without directory) handed to ``save_gif``.
        num_in_frames: number of leading observed frames; frames after this index
            are treated as predictions. Default 16 matches the input length used
            when loading ``data_in``.
        out_dir: directory to write the GIF into; defaults to the notebook-level
            ``path_to_submission`` (the previous hard-wired behaviour).
        fps: frame rate of the saved GIF.
    """
    if out_dir is None:
        out_dir = path_to_submission

    # Work on a copy so the y-flip below does not mutate the caller's tensor.
    track = data_track.clone()
    # Negate y before plotting — presumably to match orientation_like="h36m"; TODO confirm.
    track[..., 1] *= -1

    # NOTE(review): both panels are titled "gt Kpts", including the one showing predictions.
    plot_kwargs = dict(
        limbseq=skeleton.limbseq,
        left_right_limb=skeleton.left_right_limb,
        xyz_range=None,
        center_pose=False,
        units="mm",
        as_tensor=True,
        orientation_like="h36m",
        title="gt Kpts",
    )

    # Each plotting call gets its own freshly scaled array, as in the original.
    frames_in = torch.stack(
        get_np_frames_3d_projection(track[:num_in_frames].numpy() * 1000, **plot_kwargs),
        dim=0,
    )

    # Static reference pose: frame 0 repeated for every time step.
    fake_start = track[0].unsqueeze(0).broadcast_to(track.shape)
    frames_pred = torch.stack(
        get_np_frames_3d_projection(
            fake_start.numpy() * 1000, data_pred=track.numpy() * 1000, **plot_kwargs
        ),
        dim=0,
    )

    # Keep only the prediction part of the overlay sequence.
    frames = torch.cat([frames_in, frames_pred[num_in_frames:]], dim=0)
    save_gif(frames, name=os.path.join(out_dir, name), fps=fps)

In [14]:
# Sequence 0, person 0: its 16 observed frames followed by its 14 predicted frames.
# Indexing directly avoids hard-coding the total frame count (30) in a reshape.
track = data[0, 0]
save_visual(track, name="frames")

In [15]:
# Root-relative motion: subtract joint 0 (presumably the hip/root — TODO confirm)
# from every joint in each frame, so the GIF shows articulation without global translation.
pose = track - track[:, :1, :]
save_visual(pose, name="frames_pose")

In [16]:
# Joint 0 at the last observed frame (15) vs. the first predicted frame (16);
# a large gap here indicates a discontinuity between input and prediction.
print(*track[15:17, 0])

tensor([ 0.3121, -0.7862,  1.0315], dtype=torch.float64) tensor([ 0.5139, -1.0141,  1.2629], dtype=torch.float64)
