In [None]:
# Work from ../src so the project packages imported below (common, features, data, ...)
# resolve; the relative '../models' and '../data' paths used later are resolved from here too.
%cd ../src

In [None]:
# Debug aid: list the current directory to confirm the %cd above landed in src/.
%ls

In [None]:
import os
from common.utils import Params, set_logger, copy_weight, load_checkpoint, save_checkpoint_pos_ori, write_log, get_lr, write_train_summary_scalars, write_val_summary_joint, change_momentum
from features.networks import TGraphNet, TGraphNetSeq
from vizualization.vizualize import plot_adjacency_matrix, plot_pose_animation, plot_poses_only, plot_poses_merged
from common.h36m_skeleton import get_node_names, get_edge_names
import torch
from graph import Graph
import numpy as np
import seaborn as sns
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from pytransform3d import rotations as pr
from pytransform3d.plot_utils import make_3d_axis
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib.transforms import Affine2D
from matplotlib.text import TextPath
from matplotlib.patches import PathPatch
from common.h36m_skeleton import *
from matplotlib.animation import FuncAnimation

from data.h36m_dataset import Human36M
from data.pw3d_dataset import PW3D
from data.dhp_dataset import DHPDataset
from common.h36m_skeleton import joint_id_to_names
from data.generators import ChunkedGenerator_Seq, UnchunkedGenerator_Seq, ChunkedGenerator_Frame, ChunkedGenerator_Seq2Seq, eval_data_prepare, ChunkedGeneratorDHP
from pytransform3d.plot_utils import make_3d_axis

from evaluation import mpjpe

In [None]:
# Hyper-parameters of the multi-frame (sequence) model stored next to its checkpoint.
json_path = '../models/stgcn/root_rel/params.json'
assert os.path.isfile(json_path), f"No json file found at {json_path}"
params = Params(json_path)

In [None]:
# Multi-frame (sequence) model: every architectural hyper-parameter is read from
# `params`; only learn_adj is pinned here.
seq_model_kwargs = dict(
    infeat_v=params.input_node_feat,
    infeat_e=params.input_edge_feat,
    nhid_v=params.num_hidden_nodes,
    nhid_e=params.num_hidden_edges,
    n_oute=params.output_edge_feat,
    n_outv=params.output_node_feat,
    gcn_window=params.gcn_window,
    tcn_window=params.tcn_window,
    in_frames=params.in_frames,
    gconv_stages=params.gconv_stages,
    num_groups=params.num_groups,
    dropout=params.dropout,
    aggregate=params.aggregate,
    use_residual_connections=params.use_residual_connections,
    use_non_parametric=params.use_non_parametric,
    use_edge_conv=params.use_edge_conv,
    learn_adj=False,
)
model = TGraphNetSeq(**seq_model_kwargs)

load_checkpoint('../models/stgcn/root_rel/best_pos.pth.tar', model)

In [None]:
# Single-frame model. NOTE: this rebinds `params` to the run1 hyper-parameters.
json_path = '../models/stgcn/run1/params.json'
assert os.path.isfile(json_path), f"No json file found at {json_path}"
params = Params(json_path)

single_model_kwargs = dict(
    infeat_v=params.input_node_feat,
    infeat_e=params.input_edge_feat,
    nhid_v=params.num_hidden_nodes,
    nhid_e=params.num_hidden_edges,
    n_oute=params.output_edge_feat,
    n_outv=params.output_node_feat,
    gcn_window=params.gcn_window,
    tcn_window=params.tcn_window,
    in_frames=params.in_frames,
    gconv_stages=params.gconv_stages,
    num_groups=params.num_groups,
    dropout=params.dropout,
    aggregate=params.aggregate,
    use_residual_connections=params.use_residual_connections,
    use_non_parametric=params.use_non_parametric,
    use_edge_conv=params.use_edge_conv,
    learn_adj=False,
)
model_single = TGraphNet(**single_model_kwargs)

load_checkpoint('../models/stgcn/run1/best_pos.pth.tar', model_single)

In [None]:
# Visualise the vertex adjacency ('adj_v') stored in the single-frame model's state dict.
# NOTE(review): index 3 selects one entry of 'adj_v' and is paired with get_node_names(3) —
# presumably a graph level/stage; TODO confirm.
plot_adjacency_matrix(model_single.state_dict()['adj_v'][3].cpu(), node_names=get_node_names(3), annotate_values=True)

# Plotting Predicted 3D Poses

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Test split to evaluate on; swap in one of the commented alternatives to change dataset.
# test_dataset = Human36M(data_dir="/media/HDD3/datasets/Human3.6M/pose_zip", train=False, ds_category=params.ds_category, actions='SittingDown')
# test_dataset = PW3D(data_file="../data/pw3d_test.pkl", actions=['downtown_runForBus_01_0'])
test_dataset = DHPDataset(data_dir="../data/", train=False)

# Generator with shuffling and augmentation turned off.
# NOTE(review): a later cell rebinds `val_generator` to a ChunkedGenerator_Seq2Seq,
# so cells using the DHP `pairs`/`get_batch` API depend on run order.
val_generator = ChunkedGeneratorDHP(
    1,
    cameras=None,
    poses_2d=test_dataset.pos2d,
    poses_3d=test_dataset.pos3d,
    valid_frame=test_dataset.valid_frame,
    train=False,
    chunk_length=1,
    pad=40,
    out_all=True,
    shuffle=False,
    augment=False,
    reverse_aug=False,
)


In [None]:
# Inspect the keys of the test set's 2D-pose mapping (presumably one per sequence; TODO confirm).
test_dataset.pos2d.keys()

In [None]:
# Generator inputs for the sequence model: cameras, 2D poses and the
# `pos3d_centered` 3D poses of the test set; angle/edge slots stay empty.
cam, pos2d, pos3d = test_dataset.cam, test_dataset.pos2d, test_dataset.pos3d_centered
angles_6d, edge_features = [], []

# NOTE(review): chunk_length=31 with pad=25 presumably yields 81-frame windows
# (31 + 2 * 25) — TODO confirm against ChunkedGenerator_Seq2Seq.
val_generator = ChunkedGenerator_Seq2Seq(
    1,
    cameras=cam,
    poses_2d=pos2d,
    poses_3d=pos3d,
    chunk_length=31,
    pad=25,
    out_all=True,
    shuffle=False,
    augment=False,
    reverse_aug=False,
)

In [None]:
# Generator for the single-frame model: same test-set inputs, but one centre
# frame per chunk (chunk_length=1) with pad=40 of context on each side.
cam, pos2d, pos3d = test_dataset.cam, test_dataset.pos2d, test_dataset.pos3d_centered
angles_6d, edge_features = [], []

val_generator_single = ChunkedGenerator_Seq2Seq(
    1,
    cameras=cam,
    poses_2d=pos2d,
    poses_3d=pos3d,
    chunk_length=1,
    pad=40,
    out_all=True,
    shuffle=False,
    augment=False,
    reverse_aug=False,
)

In [None]:
# Peek at the first entry of `cam` (test_dataset.cam, unless an evaluation loop has since rebound the name).
cam[0]

In [None]:
# Move the sequence model to `device` and switch it to eval mode;
# the trailing semicolon keeps the cell from echoing the module.
model.to(device).eval();

In [None]:
from matplotlib.animation import PillowWriter, FFMpegWriter

In [None]:
# Per-frame MPJPE; this definition shadows the `mpjpe` imported from `evaluation` above.

def mpjpe(predicted, target):
    """
    Per-frame mean per-joint position error (MPJPE, "Protocol #1").

    Takes the Euclidean distance between corresponding joints over the last
    (coordinate) axis, then averages it over the joint axis.

    Args:
        predicted: tensor of shape (..., J, D), e.g. (B, T, J, D) or (B, J, D).
        target: tensor with exactly the same shape as ``predicted``.

    Returns:
        Tensor with the last two axes reduced: (B, T) for (B, T, J, D) inputs,
        i.e. one error per frame of every sample; (B,) for (B, J, D) inputs.

    Raises:
        AssertionError: if the shapes of ``predicted`` and ``target`` differ.
    """
    assert predicted.shape == target.shape
    # Negative axes give the same result as the old fixed axes for 4-D inputs
    # while also supporting single-frame (B, J, D) inputs.
    per_joint = torch.norm(predicted - target, dim=-1)
    return torch.mean(per_joint, dim=-1)

In [None]:
# Evaluate the sequence model window by window via the generator's
# (seq_name, start, end, flip, reverse) pairs, keeping only the centre frame.
# NOTE(review): this uses the `pairs` / `get_batch` API of the ChunkedGeneratorDHP built
# earlier; another cell rebinds `val_generator` to a ChunkedGenerator_Seq2Seq, so check
# which generator is live before running.
gt_chunks, pred_chunks, err_chunks, traj_err_chunks = [], [], [], []

i_batch = 0

with torch.no_grad():
    for seq_name, start_3d, end_3d, flip, reverse in val_generator.pairs:
        # `batch_cam` (not `cam`) so the notebook-level `cam` is not clobbered.
        batch_cam, batch_3d, batch_2d, seq, subject, cam_ind = val_generator.get_batch(seq_name, start_3d, end_3d, flip, reverse)

        # 81-frame windows of 17 joints: 2D inputs, 3D targets.
        input_2d = torch.FloatTensor(batch_2d).to(device).reshape(-1, 81, 17, 2)
        target_pose_3d = torch.FloatTensor(batch_3d).to(device).reshape(-1, 81, 17, 3)

        middle_index = int((target_pose_3d.shape[1] - 1) / 2)
        pad = 0  # evaluate 2 * pad + 1 = 1 frame: the window centre
        start_index = middle_index - pad
        end_index = middle_index + pad + 1
        B, T, J, D = target_pose_3d.shape

        predicted_pos3d = model(input_2d)
        predicted_pos3d_center = predicted_pos3d[:, start_index:end_index].reshape(B, (2 * pad + 1), J, D)

        target_pose_3d = target_pose_3d.view_as(predicted_pos3d)
        target_pose_3d_center = target_pose_3d[:, start_index:end_index].reshape(B, (2 * pad + 1), J, D)

        # Add joint 0 to joints 1.. (presumably turning root-relative joints into
        # absolute positions — the model dir is 'root_rel'; TODO confirm).
        target_pose_3d_center[:, :, 1:] += target_pose_3d_center[:, :, :1]
        predicted_pos3d_center[:, :, 1:] += predicted_pos3d_center[:, :, :1]

        target_cpu = target_pose_3d_center.cpu()
        pred_cpu = predicted_pos3d_center.cpu()
        # reshape(-1) keeps the per-frame errors of every sample in the batch, in the
        # same batch-major order as the poses appended below.
        err_chunks.append(mpjpe(target_cpu, pred_cpu).numpy().reshape(-1))
        traj_err_chunks.append(mpjpe(target_cpu[:, :, :1], pred_cpu[:, :, :1]).numpy().reshape(-1))
        gt_chunks.append(target_cpu.reshape(-1, J, D).numpy())
        pred_chunks.append(pred_cpu.reshape(-1, J, D).numpy())

        i_batch += 1

# Concatenate once after the loop (concatenating per batch is quadratic).
gt_pos3d = np.concatenate(gt_chunks, axis=0) if gt_chunks else np.empty([0, 17, 3])
pred_pose3d = np.concatenate(pred_chunks, axis=0) if pred_chunks else np.empty([0, 17, 3])
mpjpe_err = np.concatenate(err_chunks, axis=0) if err_chunks else np.empty([0])
mpjpe_err_traj = np.concatenate(traj_err_chunks, axis=0) if traj_err_chunks else np.empty([0])


In [None]:
import json

# actions = ['Directions', 'Discussion','Eating', 'Greeting','Phoning','Photo', 'Posing', 'Purchases','Sitting', 'SittingDown', 'Smoking', 'Waiting','WalkDog','WalkTogether','Walking']
# actions = ['Eating']
actions = ['WalkDog']

# Evaluate the sequence model on ChunkedGenerator_Seq2Seq batches, keeping the
# central 2 * pad + 1 frames of every output window.
gt_chunks, pred_chunks, err_chunks, traj_err_chunks = [], [], [], []

i_batch = 0

with torch.no_grad():
    for cameras_val, batch_3d, batch_6d, batch_2d, batch_edge in val_generator.next_epoch():
        input_2d = torch.FloatTensor(batch_2d).to(device)
        target_pose_3d = torch.FloatTensor(batch_3d).to(device)

        middle_index = int((target_pose_3d.shape[1] - 1) / 2)
        pad = 15  # keep 2 * 15 + 1 = 31 central frames
        start_index = middle_index - pad
        end_index = middle_index + pad + 1
        B, T, J, D = target_pose_3d.shape

        predicted_pos3d = model(input_2d)
        predicted_pos3d_center = predicted_pos3d[:, start_index:end_index].reshape(B, (2 * pad + 1), J, D)

        target_pose_3d = target_pose_3d.view_as(predicted_pos3d)
        target_pose_3d_center = target_pose_3d[:, start_index:end_index].reshape(B, (2 * pad + 1), J, D)

        # Add joint 0 to joints 1.. (presumably root-relative -> absolute; TODO confirm).
        target_pose_3d_center[:, :, 1:] += target_pose_3d_center[:, :, :1]
        predicted_pos3d_center[:, :, 1:] += predicted_pos3d_center[:, :, :1]

        target_cpu = target_pose_3d_center.cpu()
        pred_cpu = predicted_pos3d_center.cpu()
        # reshape(-1) keeps the per-frame errors of every sample in the batch, aligned
        # (batch-major) with the poses appended below.
        err_chunks.append(mpjpe(target_cpu, pred_cpu).numpy().reshape(-1))
        traj_err_chunks.append(mpjpe(target_cpu[:, :, :1], pred_cpu[:, :, :1]).numpy().reshape(-1))
        gt_chunks.append(target_cpu.reshape(-1, J, D).numpy())
        pred_chunks.append(pred_cpu.reshape(-1, J, D).numpy())

        i_batch += 1

# Concatenate once after the loop (concatenating per batch is quadratic).
gt_pos3d = np.concatenate(gt_chunks, axis=0) if gt_chunks else np.empty([0, 17, 3])
pred_pose3d = np.concatenate(pred_chunks, axis=0) if pred_chunks else np.empty([0, 17, 3])
mpjpe_err = np.concatenate(err_chunks, axis=0) if err_chunks else np.empty([0])
mpjpe_err_traj = np.concatenate(traj_err_chunks, axis=0) if traj_err_chunks else np.empty([0])

In [None]:
# Single frame prediction

import json

actions = ['WalkDog']

# Evaluate the single-frame model on the centre frame of each window.
# The loop stops once i_batch exceeds MAX_BATCHES, i.e. after MAX_BATCHES + 1 batches.
MAX_BATCHES = 1000

gt_chunks, pred_chunks, err_chunks, traj_err_chunks = [], [], [], []

i_batch = 0

model_single.to(device).eval()

with torch.no_grad():
    for cameras_val, batch_3d, batch_6d, batch_2d, batch_edge in val_generator_single.next_epoch():
        input_2d = torch.FloatTensor(batch_2d).to(device)
        target_pose_3d = torch.FloatTensor(batch_3d).to(device)

        middle_index = int((target_pose_3d.shape[1] - 1) / 2)

        # Only the window's centre frame is fed to the single-frame model.
        predicted_pos3d_center = model_single(input_2d[:, middle_index])
        target_pose_3d_center = target_pose_3d[:, middle_index].view_as(predicted_pos3d_center)

        target_cpu = target_pose_3d_center.cpu()
        pred_cpu = predicted_pos3d_center.cpu()
        err_chunks.append(mpjpe(target_cpu, pred_cpu).numpy().reshape(-1))
        # Also record the joint-0 (root) error so `mpjpe_err_traj` describes the same
        # frames as the poses collected here instead of a previous run's.
        traj_err_chunks.append(mpjpe(target_cpu[..., :1, :], pred_cpu[..., :1, :]).numpy().reshape(-1))
        gt_chunks.append(target_cpu.reshape(-1, 17, 3).numpy())
        pred_chunks.append(pred_cpu.reshape(-1, 17, 3).numpy())

        i_batch += 1

        if i_batch > MAX_BATCHES:
            break

# Concatenate once after the loop (concatenating per batch is quadratic).
gt_pos3d = np.concatenate(gt_chunks, axis=0) if gt_chunks else np.empty([0, 17, 3])
pred_pose3d = np.concatenate(pred_chunks, axis=0) if pred_chunks else np.empty([0, 17, 3])
mpjpe_err = np.concatenate(err_chunks, axis=0) if err_chunks else np.empty([0])
mpjpe_err_traj = np.concatenate(traj_err_chunks, axis=0) if traj_err_chunks else np.empty([0])

In [None]:
scale = 30


def plot_pose_animation(pred_pos3d, gt_pos3d, mpjpe_err=None, action="", num_frames=81, save_path=None, axis_limit=None):
    """
    Build an animation that draws ground-truth and predicted poses frame by frame.

    This local definition shadows the `plot_pose_animation` imported at the top.

    Args:
        pred_pos3d: per-frame predicted poses, indexed as ``pred_pos3d[i]``.
        gt_pos3d: per-frame ground-truth poses, indexed the same way.
        mpjpe_err: optional per-frame error passed to the plot; 0 when omitted.
        action: not used by this function; accepted so callers can pass it.
        num_frames: frames to animate; clamped to the number of frames actually
            available so saving the animation never indexes past the data.
        save_path: must not be None. Nothing is saved here — call ``.save`` on
            the returned animation.
        axis_limit: half-width of the plotted cube; defaults to the module-level
            ``scale``.

    Returns:
        matplotlib.animation.FuncAnimation over the (clamped) frames.

    Raises:
        AssertionError: if ``save_path`` is None.
    """
    assert save_path is not None
    limit = scale if axis_limit is None else axis_limit

    # FuncAnimation calls animate(i) for i in range(frames); never exceed the data.
    available = min(len(pred_pos3d), len(gt_pos3d))
    if mpjpe_err is not None:
        available = min(available, len(mpjpe_err))
    n_frames = min(num_frames, available)

    fig = plt.figure(figsize=(12, 12))
    ax = make_3d_axis(20, pos=111, n_ticks=5)

    def animate(i):
        ax.clear()
        pred_pos = pred_pos3d[i]
        gt_pos = gt_pos3d[i]
        pos_err = mpjpe_err[i] if mpjpe_err is not None else 0
        ax.view_init(-80, 90)  # view them from front
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_zlabel("")
        ax.set_ylim(-limit, limit)
        ax.set_xlim(-limit, limit)
        ax.set_zlim(-limit, limit)

        plot_poses_only(gt_pos, pred_pos, "", pos_err, ax=ax, x_offset=-20, y_offset=-14)

    # interval: delay between frames in milliseconds
    ani = FuncAnimation(fig, animate, frames=n_frames, repeat=False, interval=10)

    return ani

In [None]:
# Animate predictions vs ground truth, both shifted by the ground-truth joint 0 so
# they share an origin; capped at the number of collected frames.
n_anim_frames = min(500, len(gt_pos3d))
anim = plot_pose_animation(pred_pose3d - gt_pos3d[:, :1], gt_pos3d=gt_pos3d - gt_pos3d[:, :1], save_path="", num_frames=n_anim_frames)
os.makedirs("../reports", exist_ok=True)  # the GIF writer does not create directories
anim.save("../reports/sittingdown.gif", dpi=50, writer=PillowWriter(fps=24))

In [None]:
# Sort frames by root-joint (trajectory) error. The percentile plots index gt_pos3d,
# pred_pose3d and mpjpe_err with this order, so all four must describe the same frames —
# a mismatch means an evaluation cell left stale arrays behind.
assert len(mpjpe_err_traj) == len(mpjpe_err) == len(gt_pos3d) == len(pred_pose3d), (
    f"length mismatch: traj={len(mpjpe_err_traj)}, err={len(mpjpe_err)}, "
    f"gt={len(gt_pos3d)}, pred={len(pred_pose3d)}; re-run the evaluation cell"
)
order = np.argsort(mpjpe_err_traj)

In [None]:
# Root-joint (joint 0) errors in ascending order.
mpjpe_err_traj[order]

In [None]:
# Sanity check: the number of per-frame errors should match the number of collected poses.
mpjpe_err.shape, gt_pos3d.shape

In [None]:
# Fractions of the sorted frames at which example poses are plotted; 1 selects the last (worst) frame.
percentiles = [0.25, 0.5, 0.75, 0.95, 0.99, 1]

In [None]:
# One front-view figure per percentile of the sorted error: ground truth vs prediction,
# both shifted by the ground-truth joint 0 so they share an origin.
sns.set_style("whitegrid")
pose_scale = 18

for p in percentiles:
    # Nearest-rank position of percentile p among the sorted frames.
    perc_index = int((mpjpe_err.shape[0] - 1) * p + 0.5)
    # Look the frame up once instead of re-indexing the whole arrays with `order`.
    frame = order[perc_index]

    lab = "Median" if p == 0.5 else f"{int(p * 100)}th Percentile"
    lab = "Worst" if p == 1 else lab

    fig = plt.figure(figsize=(12, 12))
    ax = make_3d_axis(20, pos=111, n_ticks=5)
    ax.clear()
    ax.view_init(-90, 90)  # view them from front
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    ax.set_xlabel("")
    ax.set_zlabel("")
    ax.set_ylim(-pose_scale, pose_scale)
    ax.set_xlim(-pose_scale, pose_scale)
    ax.set_zlim(-pose_scale, pose_scale)

    ax.yaxis.set_rotate_label(False)
    ax.set_ylabel(lab, fontsize=24, rotation=90)

    gt_root = gt_pos3d[frame][:1]
    plot_poses_merged(gt_pos3d[frame] - gt_root, pred_pose3d[frame] - gt_root, action="SittingDown", mpjpe=mpjpe_err[frame], ax=ax, x_offset=0, y_offset=-2, text_y=16)
    plt.tight_layout()

    # To save: plt.savefig(f"../reports/dhp_plots_traj/dhp_subplots{p}th.png", bbox_inches='tight', pad_inches = 0, dpi=300)

In [None]:
# Two panels per percentile: ax1 gets the (negated) ground truth paired with an all-zero
# pose, ax2 an all-zero pose paired with the (negated) prediction.
sns.set_style("whitegrid")
pose_scale = 18


def _make_percentile_axis(subplot_pos, ylabel):
    """Create a 3D subplot at `subplot_pos` with the shared view, limits and labels."""
    ax = make_3d_axis(20, pos=subplot_pos, n_ticks=5)
    ax.clear()
    ax.view_init(30, 90)
    ax.set_xlabel("X", fontsize=22)  # same size on both panels
    ax.set_zlabel("Z", fontsize=22)
    ax.set_ylim(-pose_scale, pose_scale)
    ax.set_xlim(-pose_scale, pose_scale)
    ax.set_zlim(-pose_scale, pose_scale)
    ax.yaxis.set_rotate_label(False)
    ax.set_ylabel(ylabel, fontsize=24, rotation=90)
    return ax


for p in percentiles:
    # Nearest-rank position of percentile p among the sorted frames.
    perc_index = int((mpjpe_err.shape[0] - 1) * p + 0.5)
    frame = order[perc_index]

    lab = "Median" if p == 0.5 else f"{int(p * 100)}th Percentile"
    lab = "Worst" if p == 1 else lab

    fig = plt.figure(figsize=(24, 12))
    ax1 = _make_percentile_axis(121, lab)
    ax2 = _make_percentile_axis(122, lab)

    gt_pose = gt_pos3d[frame]
    pred_pose = pred_pose3d[frame]
    plot_poses_merged(-gt_pose, np.zeros_like(gt_pose), action="SittingDown", mpjpe=mpjpe_err[frame], ax=ax1, x_offset=0, y_offset=-2, text_y=16)
    plot_poses_merged(np.zeros_like(gt_pose), -pred_pose, action="SittingDown", mpjpe=mpjpe_err[frame], ax=ax2, x_offset=0, y_offset=-2, text_y=16)
    plt.tight_layout()

    # To save: plt.savefig(f"../reports/dhp_plots_traj/dhp_subplots{p}th.png", bbox_inches='tight', pad_inches = 0, dpi=300)

In [None]:
# Front-view percentile figures, saved to disk (action label "Photo").
sns.set_style("whitegrid")
pose_scale = 14
out_dir = "../reports/traj_plots_h36m"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories

for p in percentiles:
    # Nearest-rank position of percentile p. Using n - 1 makes p == 1 select the last
    # (worst) frame, matching the "Worst" label and the other percentile cells.
    perc_index = int((mpjpe_err.shape[0] - 1) * p + 0.5)
    frame = order[perc_index]

    lab = "Median" if p == 0.5 else f"{int(p * 100)}th Percentile"
    lab = "Worst" if p == 1 else lab

    fig = plt.figure(figsize=(12, 12))
    ax = make_3d_axis(20, pos=111, n_ticks=5)
    ax.clear()
    ax.view_init(-80, 90)  # view them from front
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])
    ax.set_xlabel("")
    ax.set_zlabel("")
    ax.set_ylim(-pose_scale, pose_scale)
    ax.set_xlim(-pose_scale, pose_scale)
    ax.set_zlim(-pose_scale, pose_scale)

    ax.yaxis.set_rotate_label(False)
    ax.set_ylabel(lab, fontsize=24, rotation=90)

    gt_root = gt_pos3d[frame][:1]
    plot_poses_merged(gt_pos3d[frame] - gt_root, pred_pose3d[frame] - gt_root, action="Photo", mpjpe=mpjpe_err[frame], ax=ax, x_offset=0, y_offset=-2, text_y=13)
    plt.tight_layout()

    plt.savefig(f"{out_dir}/Photo{p}th.png", bbox_inches='tight', pad_inches = 0, dpi=300)
    plt.show()

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
# Hard-coded per-frame MPJPE values, one for each of 81 frames (plotted as "MPJPE (mm)"
# per frame in the next cell). NOTE(review): presumably copied from an evaluation run —
# TODO record their source.
arr = np.array([49.7538, 48.3057, 47.8841, 47.6104, 47.7223, 47.3221, 47.2411, 47.2078,
        47.1978, 47.1767, 47.1144, 47.0833, 47.0485, 46.9908, 46.9347, 46.9129,
        46.8645, 46.8407, 46.8305, 46.7807, 46.7581, 46.7608, 46.7286, 46.7075,
        46.6989, 46.6775, 46.6921, 46.7029, 46.6718, 46.6809, 46.7028, 46.6962,
        46.6921, 46.6889, 46.6647, 46.6807, 46.6998, 46.6770, 46.6880, 46.7206,
        46.7394, 46.7421, 46.7347, 46.7025, 46.7159, 46.7400, 46.7235, 46.7272,
        46.7431, 46.7554, 46.7867, 46.7975, 46.7783, 46.7961, 46.8182, 46.8030,
        46.8229, 46.8512, 46.8844, 46.9349, 46.9697, 46.9793, 47.0194, 47.0600,
        47.0803, 47.1180, 47.1604, 47.2124, 47.2804, 47.3459, 47.3640, 47.4203,
        47.4780, 47.4975, 47.5314, 47.6380, 48.0289, 47.9970, 48.4052, 48.8925,
  50.4541])

In [None]:
# Bar chart of the per-frame MPJPE in `arr`; only the first and last frames get tick
# labels. The frame count follows len(arr) instead of a hard-coded 81.
n_frames = len(arr)

fig = plt.figure(figsize=(12, 8))

labs = [' '] * n_frames
labs[0] = 1
labs[-1] = n_frames

sns.set(font_scale=1.4)
sns.set_style('whitegrid')
ax = sns.barplot(y=arr, x=np.arange(1, n_frames + 1), color="steelblue")
ax.set_xticklabels(labs, fontsize=20)
ax.set_xlabel("Frame", fontsize=20, fontdict={'weight': 'bold'})
ax.set_ylabel("MPJPE (mm)", fontsize=20, fontdict={'weight': 'bold'})
plt.tight_layout()

os.makedirs("../reports/figures", exist_ok=True)  # savefig does not create directories
plt.savefig("../reports/figures/frame_wise_mpjpe.pdf")
plt.show()

In [None]:
def plot_3d_traj(pred_pose3d_first, pred_pose3d_second, joint_idx, savepath=None):
    """
    Plot one joint's 3D trajectory from two pose sequences in side-by-side panels.

    The left panel draws ``pred_pose3d_first`` in black; the right panel (sharing the
    left panel's axes) draws ``pred_pose3d_second`` in red.

    Args:
        pred_pose3d_first: array indexed as [frame, joint, coord] with coord 0/1/2 = x/y/z.
        pred_pose3d_second: array with the same layout.
        joint_idx: index of the joint whose trajectory is drawn.
        savepath: if truthy, the figure is also saved there (300 dpi, tight bbox).
    """
    sns.set_style("whitegrid")
    # set up a figure twice as wide as it is tall
    fig = plt.figure(figsize=(24, 12))
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')
    plt.subplots_adjust(wspace=0.1, hspace=5)  # set the spacing between axes.

    first = pred_pose3d_first[:, joint_idx]
    ax1.plot3D(first[:, 0], first[:, 1], first[:, 2], 'black', linewidth=1.5, alpha=0.8)

    ax2 = fig.add_subplot(1, 2, 2, projection='3d', sharex=ax1, sharey=ax1, sharez=ax1)

    # Draw the second trajectory on its own (titled) panel so ax2 is not left empty.
    second = pred_pose3d_second[:, joint_idx]
    ax2.plot3D(second[:, 0], second[:, 1], second[:, 2], 'red', linewidth=1.5, alpha=0.6)

    font = {
        'family': 'arial',
        'color':  'black',
        'weight': 'normal',
        'size': 24,
    }

    for ax, title in ((ax1, "Trajectory of Hip Joint"), (ax2, "Single-Frame Output")):
        ax.set_xlabel("X", fontdict=font, labelpad=8.)
        ax.set_ylabel("Y", fontdict=font, labelpad=8.)
        ax.set_zlabel("Z", fontdict=font, labelpad=8.)
        ax.set_title(title, fontdict=font)
        ax.grid(True)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

    if savepath:
        plt.savefig(f"{savepath}", bbox_inches='tight', pad_inches = 0, dpi=300)
    plt.show()

In [None]:
def plot_3d_traj_cmp(pred_pose3d_first, pred_pose3d_second, joint_idx, savepath=None, scale=0.5):
    """
    Compare one joint's 3D trajectory from two pose sequences in side-by-side panels.

    Left panel, titled "Single-Frame Output", draws ``pred_pose3d_first`` in royal blue;
    right panel, titled "Multi-Frame Output", draws ``pred_pose3d_second`` in dark orange.
    Both panels share axes limited to [-scale, scale] on x, y and z.

    Args:
        pred_pose3d_first: array indexed as [frame, joint, coord] with coord 0/1/2 = x/y/z.
        pred_pose3d_second: array with the same layout.
        joint_idx: index of the joint whose trajectory is drawn.
        savepath: if truthy, the figure is also saved there (300 dpi, tight bbox).
        scale: half-width of the plotted cube.
    """
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(24, 12))  # twice as wide as tall: two square panels

    panels = [
        (pred_pose3d_first, 'royalblue', "Single-Frame Output"),
        (pred_pose3d_second, 'darkorange', "Multi-Frame Output"),
    ]

    axes = []
    for position, (poses, colour, _title) in enumerate(panels, start=1):
        if axes:
            ref = axes[0]
            ax = fig.add_subplot(1, 2, position, projection='3d', sharex=ref, sharey=ref, sharez=ref)
        else:
            ax = fig.add_subplot(1, 2, position, projection='3d')
            plt.subplots_adjust(wspace=0.1, hspace=5)  # spacing between the panels
        ax.set_xlim([-scale, scale])
        ax.set_zlim([-scale, scale])
        ax.set_ylim([-scale, scale])
        track = poses[:, joint_idx]
        ax.plot3D(track[:, 0], track[:, 1], track[:, 2], colour, linewidth=2, alpha=1)
        axes.append(ax)

    font = {'family': 'arial', 'color': 'black', 'weight': 'normal', 'size': 24}

    for ax, (_poses, _colour, title) in zip(axes, panels):
        for set_label, text in ((ax.set_xlabel, "X"), (ax.set_ylabel, "Y"), (ax.set_zlabel, "Z")):
            set_label(text, fontdict=font, labelpad=8.)
        ax.set_title(title, fontdict=font)
        ax.grid(True)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

    if savepath:
        plt.savefig(f"{savepath}", bbox_inches='tight', pad_inches = 0, dpi=300)
    plt.show()

In [None]:
# Root-relative trajectories of joint 12 over frames 100-249, divided by 1000
# (presumably mm -> m; TODO confirm units).
# NOTE(review): the left panel ("Single-Frame Output") gets the current predictions and
# the right panel ("Multi-Frame Output") gets the ground truth.
plot_3d_traj_cmp(
    pred_pose3d_first=(pred_pose3d - pred_pose3d[:, :1])[100:250] / 1000,
    pred_pose3d_second=(gt_pos3d - gt_pos3d[:, :1])[100:250] / 1000,
    joint_idx=12,
    savepath="../reports/traj_plots_h36m/walking_traj_single_lelbow.png",
    scale=0.3,
)

In [None]:
# Presumably the `scale` values used per joint: rwrist 0.4, lwrist 0.4, relbow 0.3, lelbow 0.3.
# Root-relative trajectories of joint 9 over frames 100-249, divided by 1000.
# NOTE(review): here the ground truth goes to the left ("Single-Frame Output") panel and
# the predictions to the right ("Multi-Frame Output") one.
plot_3d_traj_cmp(
    pred_pose3d_second=(pred_pose3d - pred_pose3d[:, :1])[100:250] / 1000,
    pred_pose3d_first=(gt_pos3d - gt_pos3d[:, :1])[100:250] / 1000,
    joint_idx=9,
    savepath="../reports/traj_plots_h36m/walking_traj_multi_head.png",
    scale=0.3,
)

In [None]:
# Ground-truth trajectory of joint 16 relative to joint 0, divided by 1000 (presumably mm -> m; TODO confirm).
((gt_pos3d - gt_pos3d[:, :1]) / 1000)[:, 16]