In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import dataloader,dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision
from torchvision.utils import save_image
from torch.distributions.normal import Normal
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor 
from torchsummary import summary
import math
import os
import os.path as osp
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import datetime
import cv2
from skimage.util import img_as_ubyte
from skimage import io
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import argparse
from datetime import datetime
import time
from IPython import display
import networkx as nx
import glob
import hashlib
import pickle
from tqdm import tqdm
import json
import scipy.io as sio
import copy
import sys
import copy
from utils import Progbar
from loss import Loss
import utils
from argparse import ArgumentParser
from scipy.linalg import expm
from scipy.spatial import procrustes
from models.indiv_crossAttention import crossAttention

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch.cuda.device_count()
torch.cuda.current_device()

### Config

In [4]:
T_obs                   = 30
T_pred                  = 60
T_total                 = T_obs + T_pred
batch_size              = 16
in_size                 = 63
out_size                = 21
stochastic_out_size     = out_size * 3
hidden_size             = 256
embed_size              = 64
global dropout_val
dropout_val             = 0.2
teacher_forcing_ratio   = 0.7
avg_n_path_eval         = 20
bst_n_path_eval         = 20
path_mode               = "avg" 
startpoint_mode         = "on"

In [7]:
class Args:
    #Model specific parameters
    learn_A=True #Self learning Adjacecny matrix #learn_A=False
    video_back = False #Use Sequence of image embedding 
    background_back = False #Use single image embedding
    #Data specifc paremeters
    obs_seq_len=30
    pred_seq_len=60
    dataset='GTA_IM'
    im_w= 90
    im_h= 160
    
    #Training specifc parameters
    num_epochs=500
    log_frq=32
    batch_size=16
    clip_grad=1.0   
    lr=0.001
    lr_sh_rate=200 
    use_lrschd=True
    tag=''
    eval_only=False #evaluate the model
    torso_joint=13 #center of torso (13 for GTA-IM)
args=Args()

#create a unique tag per exp
original_tag = args.tag
args_hash = ''
for k,v in vars(args).items():
    if k == 'eval_only' or k =='torso_joint':
        continue
    args_hash += str(k)+str(v)
args_hash = hashlib.sha256(args_hash.encode()).hexdigest()
args.tag+=args_hash

input_window_size = 30 
output_window_size = 60

# Dataloader_Gta

In [8]:
class D2_D3_GTA_IM(Dataset):
    def __init__(self, dpaths, tag='', seq_in=5, seq_out=10, load_img = True, load_depth =False, img_resize=(90,160), save= True):
        super(D2_D3_GTA_IM, self).__init__()
        self.load_img = load_img
        self.load_depth = load_depth
        self.seq_in = seq_in
        self.seq_out= seq_out
        self.save = save
        self.img_resize = img_resize
        self.tag = tag
        
        _hash = '-'.join(dpaths)+str(tag)+str(seq_in)+str(seq_out)+str(load_img)+str(load_depth)+str(img_resize[0])+str(img_resize[1]) #dpaths: the path of data
        savefile = './'+hashlib.sha256(_hash.encode()).hexdigest()+'.pkl'
        print("Save file is:", savefile)
        
        if save:
            if os.path.exists(savefile):
                print("Save file:",savefile," exists... loading it")
                with open(savefile,'rb') as f: 
                    data = pickle.load(f)
                self._x = data['_x'] #x is the 2d input joins
                self._xA = data['_xA'] #xA is the adjacency matrix
                self._y = data['_y'] #y is the 3d target poses
                self._x3 = data['_x3'] #x3 is the 3d input joins
                if load_img:
                    self._vrgb = data['_vrgb']
                if load_depth:
                    self._vdepth = data['_vdepth']
                    
            else:
                #This need to be computed once, the Adjacency matrix of the skeleton 
                LIMBS = [
                    (0, 1),  # head_center -> neck
                    (1, 2),  # neck -> right_clavicle
                    (2, 3),  # right_clavicle -> right_shoulder
                    (3, 4),  # right_shoulder -> right_elbow
                    (4, 5),  # right_elbow -> right_wrist
                    (1, 6),  # neck -> left_clavicle
                    (6, 7),  # left_clavicle -> left_shoulder
                    (7, 8),  # left_shoulder -> left_elbow
                    (8, 9),  # left_elbow -> left_wrist
                    (1, 10),  # neck -> spine0
                    (10, 11),  # spine0 -> spine1
                    (11, 12),  # spine1 -> spine2
                    (12, 13),  # spine2 -> spine3
                    (13, 14),  # spine3 -> spine4
                    (14, 15),  # spine4 -> right_hip
                    (15, 16),  # right_hip -> right_knee
                    (16, 17),  # right_knee -> right_ankle
                    (14, 18),  # spine4 -> left_hip
                    (18, 19),  # left_hip -> left_knee
                    (19, 20)  # left_knee -> left_ankle
                ]

                A = np.zeros((21,21))
                for i,j in LIMBS: 
                    A[i,j] = 1
                    A[j,i] = 1

                G = nx.from_numpy_matrix(A)
                Anorm = nx.normalized_laplacian_matrix(G).toarray()

                _x = []
                _xA = []
                if load_img:
                    _vrgb = []
                if load_depth:
                    _vdepth = []
                _y = []
                _x3 = []

                for dpath in dpaths:

                    info = pickle.load(open(dpath + 'info_frames.pickle', 'rb'))
                    info_npz = np.load(dpath+'info_frames.npz')

                    if load_img:
                        _rgb = []
                    if load_depth:
                        _depth = []
              
                    _d2 = [] 
                    _d2A = []
                    _d3 = []
                    for fm_id in range(len(info)):

                        if load_img:
                            rgb = cv2.resize(cv2.cvtColor(cv2.imread(dpath+'{:05d}'.format(fm_id)+'.jpg'), cv2.COLOR_BGR2RGB),img_resize, interpolation = cv2.INTER_LANCZOS4)/255.0

                        if load_depth:
                            depth = cv2.resize(cv2.cvtColor(cv2.imread(dpath+'{:05d}'.format(fm_id)+'.png'), cv2.COLOR_BGR2RGB),img_resize, interpolation = cv2.INTER_LANCZOS4)/255.0
                        
                        d2 = info_npz['joints_2d'][fm_id] 
                        d3 = info_npz['joints_3d_cam'][fm_id]

                        if load_img:
                            _rgb.append(rgb[None,...].transpose(0,3,1,2))
                        if load_depth:
                            _depth.append(depth[None,...].transpose(0,3,1,2))
                        _d2.append(d2[None,...])
                        _d2A.append(A[None,...])
                        _d3.append(d3[None,...])

                    #Create the sequences using a moving window of (seq_in_seq_out)
                    kk =0 
                    for k in range(0,len(info)-(seq_in+seq_out),1):
                        kk =k

                    pbar = tqdm(total=kk) 

                    for i in range(0,len(info)-(seq_in+seq_out),1):
                        _x.append(torch.from_numpy(np.concatenate(_d2[i:i+seq_in],axis=0)).type(torch.float32))
                        if load_img:
                            _vrgb.append(torch.from_numpy(np.concatenate(_rgb[i:i+seq_in],axis=0)).type(torch.float32))
                        if load_depth:
                            _vdepth.append(torch.from_numpy(np.concatenate(_depth[i:i+seq_in],axis=0)).type(torch.float32))

                        _xA.append(torch.from_numpy(np.concatenate(_d2A[i:i+seq_in],axis=0)).type(torch.float32))
                        _y.append(torch.from_numpy(np.concatenate(_d3[i+seq_in:i+seq_in+seq_out],axis=0)).type(torch.float32))
                        _x3.append(torch.from_numpy(np.concatenate(_d3[i:i+seq_in],axis=0)).type(torch.float32))
                        pbar.update(1)
                    pbar.close()

                self._x = _x
                self._xA = _xA
                self._y = _y
                self._x3 = _x3
                if load_img:
                    self._vrgb = _vrgb
                if load_depth:
                    self._vdepth = _vdepth
                    
                if save:
                    with open(savefile,'wb') as f : 
                        data = {}
                        data['_x'] = _x
                        data['_xA'] = _xA
                        data['_y'] = _y
                        data['_x3'] = _x
                        if load_img:
                            data['_vrgb'] = _vrgb
                        if load_depth:
                            data['_vdepth'] = _vdepth
                        pickle.dump(data,f)        

    def __len__(self):
        return len(self._x)

    def __getitem__(self, index):
        output = [self._x[index],self._xA[index],self._y[index],self._x3[index]]
        if self.load_img:
            output.append(self._vrgb[index])
        if self.load_depth:
            output.append(self._vdepth[index])
        return output

In [None]:
if args.dataset == 'GTA_IM':
    datasections = glob.glob('./GTAIM/*/')
    load_img = False
    if args.video_back or args.background_back:
        load_img = True
    dataset_train = D2_D3_GTA_IM(datasections[:8],tag='train', seq_in=args.obs_seq_len, seq_out=args.pred_seq_len,load_img = load_img, load_depth =False, img_resize=(args.im_w,args.im_h))    print("dataset_test")
    dataset_test = D2_D3_GTA_IM(datasections[8:10],tag='test', seq_in=args.obs_seq_len, seq_out=args.pred_seq_len,load_img = load_img, load_depth =False, img_resize=(args.im_w,args.im_h))

elif args.dataset == 'PROX':
    record_sections = glob.glob('PROX/recordings/*/')
    keyp_sections = glob.glob('PROX/keypoints/*/')
    load_img = False
    if args.video_back or args.background_back:
        load_img = True
    dataset_train = D2_D3_PROX(record_sections[:52],keyp_sections[:52],tag='train', seq_in=args.obs_seq_len, seq_out=args.pred_seq_len,load_img = load_img, load_depth =False, img_resize=(args.im_w,args.im_h))
    dataset_test = D2_D3_PROX(record_sections[52:60],keyp_sections[52:60],tag='test', seq_in=args.obs_seq_len, seq_out=args.pred_seq_len,load_img = load_img, load_depth =False, img_resize=(args.im_w,args.im_h))
    
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle =True, num_workers=0, drop_last=True)
loader_val = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size, shuffle =False, num_workers=0, drop_last=True)

print("Normalization")
all_train_2d = []
for cnt,batch in enumerate(loader_train): 
    if args.background_back or args.video_back:
        X,XA,y,x3,scene = batch
    else:
        X,XA,y,x3 = batch

    all_train_2d.extend(X.flatten().numpy())
    
all_train_2d = np.asarray(all_train_2d)
_mean  = all_train_2d.mean()
_std = all_train_2d.std()

print("vision normalization")
if args.background_back:
    v_mean = [0.485, 0.456, 0.406]
    v_std = [0.229, 0.224, 0.225]
elif args.video_back:
    v_mean = [0.43216, 0.394666, 0.37645]
    v_std = [0.22803, 0.22145, 0.216989] 

### Processing

In [16]:
ref_skel = copy.deepcopy(dataset_train._x3[0][0])

#KENDALL:

def CenteredScaled(X):
    # Centering: Subtract the mean of each joint across all frames
    X_reshaped = X - torch.mean(X, dim=0)
    # Calculate the "centered" Frobenius norm for each joint
    normX = torch.norm(X_reshaped, dim=(1, 2), p='fro')
    # Scale to equal (unit) norm for each joint
    X_scaled = X_reshaped / normX[:, None, None]
    return X_scaled

def inv_exp(X, Y):
    # Move tensors to CPU for procrustes analysis
    X_cpu = X.cpu().numpy()
    Y_cpu = Y.cpu().numpy()
    
    n_frames, joints, dim_k = X.shape
    num_frames = X_cpu.shape[0]
    Total_dim = X_cpu.shape[1]*3
    
    X_cpu = X_cpu.reshape(num_frames,Total_dim)
    Y_cpu = X_cpu.reshape(num_frames,Total_dim)
    
    # Check if X and Y have more than one row
    if X_cpu.shape[0] <= 1 or Y_cpu.shape[0] <= 1:
        # Handle the case when matrices have one row
        pass
        return Y

    # Apply Procrustes to align Y on X
    _, Y_aligned, _ = procrustes(X_cpu, Y_cpu)

    # Calculate the invExp matrix on the aligned matrices
    skeleton = X_cpu.dot(Y_aligned.T)
    tr = abs(skeleton.trace())
    if tr > 1:
        tr = 1
    teta_invexp = math.acos(tr)
    if math.sin(teta_invexp) < 0.0001:
        teta_invexp = 0.1
    invExp = (teta_invexp / math.sin(teta_invexp)) * (Y_aligned - (math.cos(teta_invexp)) * X_cpu)
    np_inv = np.array(invExp)
    
    # Reshape back to the original shape
    np_inv = np_inv.reshape((n_frames, joints, dim_k))
    
    return torch.from_numpy(np_inv).to(X.device)


#LIE: 

def calculate_global_transformation(skeleton, ref_skeleton):
        
    #Ensure skeleton and ref_skeleton have compatible shapes for the dot product
    skeleton = skeleton.reshape(-1, 63) if skeleton.shape[1] == 63 else skeleton
    ref_skeleton = ref_skeleton.reshape(-1, 63) if ref_skeleton.shape[1] == 63 else ref_skeleton

    # Calculate global rotation
    rotation_matrix = np.dot(skeleton, np.transpose(ref_skeleton))
    u, s, v = np.linalg.svd(rotation_matrix, full_matrices=False)
    rotation_matrix = np.dot(v.T, u.T)
    # Calculate global translation
    translation_vector = np.mean(ref_skeleton, axis=1) - np.dot(rotation_matrix, np.mean(skeleton, axis=1))
    return rotation_matrix, translation_vector

def to_SE3(rotation_matrix, translation_vector):
    se3_matrix = np.eye(4)
    se3_matrix[:3, :3] = rotation_matrix[:3, :3]  # Take the top-left 3x3 block
    se3_matrix[:3, 3] = translation_vector[:3]  # Take the first 3 elements
    return se3_matrix

def extract_point_in_SE3(se3_matrix):
    return se3_matrix

def derive_tangent_space(rotation_matrix, translation_vector):
    # Ensure rotation_matrix is 3x3
    rotation_matrix = rotation_matrix[:3, :3]
    # Create a 3x3 identity matrix
    identity_matrix = np.eye(3)
    # Ensure translation_vector is a column vector
    translation_vector = translation_vector[:3].reshape(-1, 1)
    # Calculate the skew-symmetric matrix directly
    skew_symmetric_matrix = rotation_matrix - identity_matrix
    skew_symmetric_matrix_flat = skew_symmetric_matrix.flatten()
    tangent_space = np.zeros((4, 4))
    tangent_space[:3, :3] = rotation_matrix
    tangent_space[:3, 3] = translation_vector.flatten()
    tangent_space[3, :3] = skew_symmetric_matrix_flat[:3]  # Take the first 3 elements
    return tangent_space

def lie_group_and_algebra_transform_s(skeleton, ref_skeleton):
    skeleton = skeleton.cpu().numpy() if isinstance(skeleton, torch.Tensor) else skeleton
    ref_skeleton = ref_skeleton.cpu().numpy() if isinstance(ref_skeleton, torch.Tensor) else ref_skeleton
    # Calculate global transformation (rotation and translation)
    rotation_matrix, translation_vector = calculate_global_transformation(skeleton, ref_skeleton)
    se3_matrix = to_SE3(rotation_matrix, translation_vector)
    # Extract a representative point in Lie group (SE(3))
    point_in_SE3 = extract_point_in_SE3(se3_matrix)
    # Derive tangent space (Lie algebra) associated with SE(3)
    tangent_space = derive_tangent_space(rotation_matrix, translation_vector)
    return tangent_space #point_in_SE3, tangent_space

def lie_group_and_algebra_transform(frames, ref_skeleton):    
    result = np.array([lie_group_and_algebra_transform_s(skeleton, ref_skeleton) for skeleton in frames])
    return result

# MODEL

In [18]:
def position_embedding(input, d_model):
    input = input.view(-1, 1)
    dim = torch.arange(d_model // 2, dtype=torch.float32, device=input.device).view(1, -1)
    sin = torch.sin(input / 10000 ** (2 * dim / d_model))
    cos = torch.cos(input / 10000 ** (2 * dim / d_model))

    out = torch.zeros((input.shape[0], d_model), device=input.device)
    out[:, ::2] = sin
    out[:, 1::2] = cos
    return out

def sinusoid_encoding_table(max_len, d_model):
    pos = torch.arange(max_len, dtype=torch.float32)
    out = position_embedding(pos, d_model)
    return out

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, h):
        """
        param:
        d_model: Output dimensionality of the model
        d_k: Dimensionality of queries and keys
        d_v: Dimensionality of values
        h: Number of heads
        """
        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights(gain=1.0)

    def init_weights(self, gain=1.0):
        nn.init.xavier_normal_(self.fc_q.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_k.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_v.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_o.weight, gain=gain)
        nn.init.constant_(self.fc_q.bias, 0)
        nn.init.constant_(self.fc_k.bias, 0)
        nn.init.constant_(self.fc_v.bias, 0)
        nn.init.constant_(self.fc_o.bias, 0)

    def forward(self, queries, keys, values):
        """
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :return:
        """
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)

        att = torch.softmax(att, -1)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out
    
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v, h, dff=2048, dropout=.1):
        super(MultiHeadAttention, self).__init__()

        self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)
        self.fc = nn.Sequential(*[nn.Linear(d_model, dff), nn.ReLU(inplace=True), nn.Dropout(p=dropout),nn.Linear(dff, d_model)])

    def forward(self, queries, keys, values):
        att = self.attention(queries, keys, values)
        att = self.dropout(att)
        att = self.fc(att)
        att = self.dropout(att)
        return self.layer_norm(queries + att)
    
class EncoderSelfAttention(nn.Module):
    def __init__(self, device, d_model, d_k, d_v, n_head, dff=2048, dropout_transformer=.1, n_module=6):
        super(EncoderSelfAttention, self).__init__()
        self.encoder = nn.ModuleList([MultiHeadAttention(d_model, d_k, d_v, n_head, dff, dropout_transformer) for _ in range(n_module)])
        self.device = device
    
    def forward(self, x): 
        in_encoder = x + sinusoid_encoding_table(x.shape[1], x.shape[2]).expand(x.shape).to(self.device)
        for l in self.encoder:
            in_encoder = l(in_encoder, in_encoder, in_encoder)
            
        return in_encoder

In [19]:
class SelfAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, hidden_dim)
        self.key = nn.Linear(input_dim, hidden_dim)
        self.value = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        attention_scores = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (x.size(-1) ** 0.5), dim=-1)
        output = torch.matmul(attention_scores, v)

        return output

class CoordinatesTransformer_k(nn.Module):
    def __init__(self, device, dropout1d,input_dim=63, hidden_size=256, output_dim=256):
        super(CoordinatesTransformer_k, self).__init__()
        self.device = device
        self.dropout1d = dropout1d
        self.self_attention = SelfAttention(input_dim, hidden_size)
        # Feedforward Layers
        self.linear1 = nn.Linear(hidden_size, hidden_size) 
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout1d)
        self.linear2 = nn.Linear(hidden_size, output_dim)

    def forward(self, input_tensor_k):
        self_attended = self.self_attention(input_tensor_k)
        output = self.linear1(self_attended)
        output = self.relu(output)
        output = self.dropout(output)
        output = self.linear2(output)
        return output

In [20]:
class SelfAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, hidden_dim)
        self.key = nn.Linear(input_dim, hidden_dim)
        self.value = nn.Linear(input_dim, hidden_dim)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attention_scores = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (x.size(-1) ** 0.5), dim=-1)
        output = torch.matmul(attention_scores, v)
        return output

class CoordinatesTransformer_l(nn.Module):
    def __init__(self, device, dropout1d, input_dim=16, hidden_size=256, output_dim=256):
        super(CoordinatesTransformer_l, self).__init__()
        self.device = device
        self.dropout1d = dropout1d
        # Self-Attention Layer
        self.self_attention = SelfAttention(input_dim, hidden_size)
        # Feedforward Layers
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout1d)
        self.linear2 = nn.Linear(hidden_size, output_dim)

    def forward(self, input_tensor_l):
        # Assuming input_tensor_k shape: [batch_size, sequence_length, input_dim]
        self_attended = self.self_attention(input_tensor_l)
        output = self.linear1(self_attended)
        output = self.relu(output)
        output = self.dropout(output)
        output = self.linear2(output)
        return output

In [21]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super(DecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.norm1 = nn.LayerNorm(d_model) 
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Linear(d_model, d_model)

    def forward(self, x, e_output):
        # Multihead self-attention
        attn_output, _ = self.self_attn(x, x, x)
        x = x + self.norm1(attn_output)
        # Feed forward
        ff_output = self.feed_forward(x)
        x = x + self.norm2(ff_output)
        return x

    
class DecoderTransformer(nn.Module):
    def __init__(self, in_size, embed_size, hidden_size, d_model=512, dropout_val=dropout_val, batch_size=1, nhead=8, num_layers=6):
        super(DecoderTransformer, self).__init__()
        self.in_size                = in_size 
        self.stochastic_out_size    = stochastic_out_size
        self.hidden_size            = hidden_size
        self.batch_size             = batch_size
        self.embed_size             = embed_size
        self.seq_length             = T_pred
        self.dropout_val            = dropout_val
        self.d_model                = d_model
        self.nhead                  = nhead
        self.num_layers             = num_layers

        self.embedder_rho = nn.Linear(63, 200)
        self.fC_mu = nn.Sequential(nn.Linear(self.hidden_size + self.hidden_size + 2, int(self.hidden_size/2), bias=True),nn.ReLU(),nn.Dropout(p=dropout_val),nn.Linear(int(self.hidden_size/2), self.stochastic_out_size, bias=True))
        self.dropout = nn.Dropout(dropout_val)        
        self.embedding = nn.Linear(200, d_model)
        self.layers = nn.ModuleList([DecoderLayer(d_model, nhead) for _ in range(num_layers)])
        self.output_layer = nn.Linear(d_model, 514)
                
    def forward(self, x, encoder_outputs):        
        embedding = self.embedder_rho(x.view(x.shape[0], 1, -1))                    
        embedding = F.relu(self.dropout(embedding))        
        x = self.embedding(embedding)
        for layer in self.layers: #6 layers
            x = layer(x, encoder_outputs)
            if torch.isnan(x).any():
                print("NaN values found in x layer")
            if torch.isinf(x).any():
                print("infinite values found in x layer")
        output = self.output_layer(x)
        prediction = self.fC_mu(output.squeeze(0))         
        return prediction

In [22]:
class Model(nn.Module):
    def __init__(self, in_size, embed_size, hidden_size, batch_size, d_model=512, d_ff=2048, h=8, dropout_val=dropout_val, N=6, input_dim=512):
        super(Model, self).__init__()
        torch.cuda.empty_cache()
        
        self.encoder_k = CoordinatesTransformer_k(device, dropout1d=dropout_val)
        self.encoder_k.apply(init_weights)
        self.encoder_l = CoordinatesTransformer_l(device, dropout1d=dropout_val)
        self.encoder_l.apply(init_weights)
        self.decoder = DecoderTransformer(in_size, embed_size, hidden_size, num_layers=6, nhead=8)
        self.decoder.apply(init_weights)
        self.crossAttention = crossAttention(N=6,d_model=256, d_ff=2048, h=8, dropout=0.1)
                    
        if device.type=='cuda':
            self.encoder_k.cuda()
            self.encoder_l.cuda()
            self.decoder.cuda()

    def forward(self, input_tensor_k, input_tensor_l, input_tensor, output_tensor, batch_size, train_mode):       
  
        batch_size     = int(input_tensor_k.size(0))
        encoder_outputs = torch.zeros(batch_size, T_obs, hidden_size).cuda()
        start_point = (input_tensor[:,0,:]).to(device).clone().detach()
        if startpoint_mode=="on":
            input_tensor[:,0,:]    = 0
        
        encoder_outputs_k = self.encoder_k(input_tensor_k)
        encoder_outputs_l = self.encoder_l(input_tensor_l)

        src_mask = None
        obd_enc_mask = None
        cross_ouput = self.crossAttention( encoder_outputs_k, encoder_outputs_l, src_mask, obd_enc_mask)
        e_outputs=cross_ouput
        
        decoder_input = input_tensor[:,-1,:] 
        outputs                         = torch.zeros(batch_size, T_pred , in_size).cuda() 
        stochastic_outputs              = torch.zeros(batch_size, T_pred , stochastic_out_size).cuda()
        teacher_force                   = 1
        epsilonX                        = Normal(torch.zeros(batch_size,1),torch.ones(batch_size,1))
        teacher_force                   = int(random.random() < teacher_forcing_ratio) if train_mode else 0

        for t in range(0, output_window_size-1):
            output = self.decoder(decoder_input, e_outputs)                    
            # Reparameterization Trick :)
            decoder_output              = torch.zeros(batch_size,1,63).cuda()                        
            for i in range(0,out_size):
                epsilonX               = Normal(torch.zeros(batch_size,1),torch.ones(batch_size,1))
                decoder_output[:,:,i]  = output[:,:,2*i] + epsilonX.sample((avg_n_path_eval,1)).view(-1,avg_n_path_eval,1).mean(-2).cuda() * output[:,:,2*i+1]
            outputs[:,t,:]               = decoder_output.squeeze(1)
            stochastic_outputs[:,t,:]    = output.squeeze(1)
            predictionns = outputs
            predictionns = outputs.reshape(args.batch_size,T_pred,21,3)
            
        return predictionns

In [None]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.2, 0.2)

Train

In [23]:
def train(epoch):
    global metrics,constant_metrics
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = Model(in_size, embed_size, hidden_size, dropout_val=dropout_val, batch_size=batch_size)
    net.to(device)
    
    model.train()
    loss_train = 0 

    for cnt,batch in enumerate(loader_train): 
        #get data
        if args.background_back:
            X,XA,y,x3,scene = batch
            scene =scene[:,-1,...].cuda()
        elif args.video_back:
            X,XA,y,x3,scene = batch
            scene = scene.view(scene.shape[0],scene.shape[1],scene.shape[2],scene.shape[3],scene.shape[4]).cuda()
        else: 
            X,XA,y,x3 = batch

        X =(X-_mean)/_std
        X,XA,y,x3 = X.cuda(),XA.cuda(),y.cuda(),x3.cuda()
        
        encoder_inputs = x3.reshape(args.batch_size, args.obs_seq_len,-1)
        # Convert to Kendall shape space representations:
        ref_skel = copy.deepcopy(dataset_train._x3[0])
        x3_k = [CenteredScaled(frame) for frame in x3]
        x3_k = torch.stack(x3_k)
        x3_k_tg = [inv_exp(ref_skel,frame) for frame in x3_k]
        x3_k_tg = torch.stack(x3_k_tg)
        #print("x3_k_tg shape:", x3_k_tg.shape) #torch.Size([16, 30, 21, 3])
        encoder_inputs_k = x3_k_tg.reshape(args.batch_size, args.obs_seq_len,-1)

        # Convert to Lie group and then Lie algebra:
        ref_skel_l = copy.deepcopy(dataset_train._x3[0][0])
        x3_l_tg = [lie_group_and_algebra_transform(frame,ref_skel_l) for frame in x3]
        x3_l_tg = np.array(x3_l_tg)
        x3_l_tg= torch.tensor(x3_l_tg)
        #print("x3_l_tg_tensor shape:", x3_l_tg.shape) #torch.Size([16, 30, 4, 4])
        encoder_inputs_l = x3_l_tg.reshape(args.batch_size, args.obs_seq_len, -1)        
    
        decoder_outputs = y.view(args.batch_size, args.pred_seq_len,63)

        optimizer.zero_grad()
        encoder_inputs_k = encoder_inputs_k.float().to(device)
        encoder_inputs_l = encoder_inputs_l.float().to(device)
        encoder_inputs = encoder_inputs.float().to(device)
        decoder_outputs = decoder_outputs.float().to(device)
        
        prediction= net(encoder_inputs_k, encoder_inputs_l, encoder_inputs, decoder_outputs, batch_size,train_mode=True)
        
        loss = loss(prediction,y)
        
        loss.backward()
        if args.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(),args.clip_grad)
        optimizer.step()
        
        loss_train += loss.item()
        if cnt%args.log_frq == 0 and cnt!=0:
            print('Epoch:', epoch,'\t Train Loss:',loss_train/(cnt+1))
            
    metrics['train_loss'].append(loss_train/(cnt+1))

In [24]:
# Metrics

def MPJPE(V_pred,V_trgt):
    return torch.linalg.norm(V_trgt- V_pred,dim=-1).mean() 

def MPJPE_torso(V_pred,V_trgt,torso_joint = 13): 
    return torch.linalg.norm(V_trgt[:,:,torso_joint,:]- V_pred[:,:,torso_joint,:],dim=-1).mean() 

def MPJPE_timelimit(V_pred,V_trgt,time_limit):
    T = V_pred.shape[1]
    T = max(int(T*time_limit),1)
    return torch.linalg.norm(V_trgt[:,:T,...]- V_pred[:,:T,...],dim=-1).mean()

def MPJPE_torso_timelimit(V_pred,V_trgt,time_limit,torso_joint = 13 ): 
    T = V_pred.shape[1]
    T = max(int(T*time_limit),1)
    return torch.linalg.norm(V_trgt[:,:T,torso_joint,:]- V_pred[:,:T,torso_joint,:],dim=-1).mean() 

Vald

In [26]:
def vald(epoch):
    global metrics,constant_metrics
    
    model.eval()
    loss_val = 0 
    mpjpe_avg = 0 
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = Model(in_size, embed_size, hidden_size, dropout_val=dropout_val, batch_size=batch_size)
    net.to(device)

    with torch.no_grad(): #Faster without grad 
        for cnt,batch in enumerate(loader_val): 
            if args.background_back:
                X,XA,y,x3,scene = batch
                scene =scene[:,-1,...].cuda()
            elif args.video_back:
                X,XA,y,x3,scene = batch
                scene = scene.view(scene.shape[0],scene.shape[1],scene.shape[2],scene.shape[3],scene.shape[4]).cuda()
            else: 
                X,XA,y,x3 = batch

            X =(X-_mean)/_std
            X,XA,y,x3 = X.cuda(),XA.cuda(),y.cuda(),x3.cuda()
            
            encoder_inputs = x3.reshape(args.batch_size, args.obs_seq_len,-1)
            ref_skel = copy.deepcopy(dataset_train._x3[0])
            x3_k = [CenteredScaled(frame) for frame in x3]
            x3_k = torch.stack(x3_k)
            x3_k_tg = [inv_exp(ref_skel,frame) for frame in x3_k]
            x3_k_tg = torch.stack(x3_k_tg)
            encoder_inputs_k = x3_k_tg.reshape(args.batch_size, args.obs_seq_len,-1)
            ref_skel_l = copy.deepcopy(dataset_train._x3[0][0])
            x3_l_tg = [lie_group_and_algebra_transform(frame,ref_skel_l) for frame in x3]
            x3_l_tg = np.array(x3_l_tg)
            x3_l_tg= torch.tensor(x3_l_tg)
            encoder_inputs_l = x3_l_tg.reshape(args.batch_size, args.obs_seq_len, -1)
            decoder_outputs = y.view(args.batch_size, args.pred_seq_len,63)

            optimizer.zero_grad()

            encoder_inputs_k = encoder_inputs_k.float().to(device)
            encoder_inputs_l = encoder_inputs_l.float().to(device)
            encoder_inputs = encoder_inputs.float().to(device)
            decoder_outputs = decoder_outputs.float().to(device)

            prediction= net(encoder_inputs_k, encoder_inputs_l, encoder_inputs, decoder_outputs, batch_size,train_mode=0)
        
            loss = loss(prediction,y)
            mpjpe = MPJPE(prediction,y)
            
            loss_val += loss.item()
            mpjpe_avg += mpjpe.item()
            
            if cnt%args.log_frq == 0 and cnt!=0:
                print('Epoch:', epoch,'\t Val Loss:',loss_val/(cnt+1),'\t MPJPE:',mpjpe_avg/(cnt+1))

    metrics['val_loss'].append(loss_val/(cnt+1))
    metrics['val_mpjpe'].append(mpjpe_avg/(cnt+1))

    if  metrics['val_loss'][-1]< constant_metrics['min_val_loss']:
        constant_metrics['min_val_loss'] =  metrics['val_loss'][-1]
        constant_metrics['min_val_epoch'] = epoch
        torch.save(model.state_dict(),checkpoint_dir+'val_loss_best.pth')  
        
    if  metrics['val_mpjpe'][-1]< constant_metrics['min_val_mpjpe']:
        constant_metrics['min_val_mpjpe'] =  metrics['val_mpjpe'][-1]
        constant_metrics['min_val_mpjpe_epoch'] = epoch
        torch.save(model.state_dict(),checkpoint_dir+'val_mpjpe_best.pth')  

### Creating the model

In [28]:
print('Creating the model ....')
model                       = Model(in_size, embed_size, hidden_size, dropout_val=dropout_val, batch_size=batch_size)
model                       = nn.DataParallel(model).cuda()

Creating the model ....


### Train / Test

In [None]:
optimizer = optim.Adam(model.parameters(), lr=args.lr) #optimizer = optim.SGD(model.parameters(),lr=args.lr)

if args.use_lrschd: #Use lr rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_sh_rate, gamma=0.2)
    
checkpoint_dir = './checkpoint/'+args.tag+'/'

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
    
with open(checkpoint_dir+'args.pkl', 'wb') as fp:
    pickle.dump(args, fp)
    
print('Data and model loaded')
print('Checkpoint dir:', checkpoint_dir)

In [None]:
args.video_back = False 
args.background_back = False
args.torso_joint=13

In [None]:
#Train or eval? 
if not args.eval_only:
    print('Training started ...')

    metrics = {'train_loss':[],  'val_loss':[], 'val_mpjpe':[]}
    constant_metrics = {'min_val_epoch':-1, 'min_val_loss':9999999999999999,'min_val_mpjpe_epoch':-1, 'min_val_mpjpe':9999999999999999}

    for epoch in range(args.num_epochs):
    
        with torch.autograd.set_detect_anomaly(True):
            train(epoch)
        vald(epoch)
                    
        if args.use_lrschd:
            scheduler.step()

        print('*'*30)
        print('Epoch:',args.tag,":", epoch)
        for k,v in metrics.items():
            if len(v)>0:
                print(k,v[-1])

        print(constant_metrics)
        print('*'*30)

        with open(checkpoint_dir+'metrics.pkl', 'wb') as fp:
            pickle.dump(metrics, fp)

        with open(checkpoint_dir+'constant_metrics.pkl', 'wb') as fp:
            pickle.dump(constant_metrics, fp) 
else:
    def eval_metrics():
        model.eval()
        loss_val = 0
        
        mpjpe_avg = 0 
        mpjpe_qaurt_avg =0 
        mpjpe_half_avg  =0
        mpjpe_3quart_avg =0 

        mpjpe_path_avg = 0 
        mpjpe_path_qaurt_avg =0 
        mpjpe_path_half_avg  =0
        mpjpe_path_3quart_avg =0 

        with torch.no_grad(): #Faster without grad 
            for cnt,batch in enumerate(loader_val): 
                if args.background_back:
                    X,XA,y,x3,scene = batch
                    scene =scene[:,0,...].cuda()
                elif args.video_back:
                    X,XA,y,x3,scene = batch
                    scene = scene.view(scene.shape[0],scene.shape[1]*scene.shape[2],scene.shape[3],scene.shape[4]).cuda()
                else: 
                    X,XA,y,x3 = batch

                X =(X-_mean)/_std
                X,XA,y,x3 = X.cuda(),XA.cuda(),y.cuda(),x3.cuda()
                
                encoder_inputs = x3.reshape(args.batch_size, args.obs_seq_len,-1)

                ref_skel = copy.deepcopy(dataset_train._x3[0])
                x3_k = [CenteredScaled(frame) for frame in x3]
                x3_k = torch.stack(x3_k)
                x3_k_tg = [inv_exp(ref_skel,frame) for frame in x3_k]
                x3_k_tg = torch.stack(x3_k_tg)
                encoder_inputs_k = x3_k_tg.reshape(args.batch_size, args.obs_seq_len,-1) #torch.Size([16, 50, 63])okk

                ref_skel_l = copy.deepcopy(dataset_train._x3[0][0])
                x3_l_tg = [lie_group_and_algebra_transform(frame,ref_skel_l) for frame in x3]
                x3_l_tg = np.array(x3_l_tg)
                x3_l_tg= torch.tensor(x3_l_tg)
                encoder_inputs_l = x3_l_tg.reshape(args.batch_size, args.obs_seq_len, -1)

                decoder_outputs = y.view(args.batch_size, args.pred_seq_len,63)

                optimizer.zero_grad()

                encoder_inputs_k = encoder_inputs_k.float().to(device)
                encoder_inputs_l = encoder_inputs_l.float().to(device)
                encoder_inputs = encoder_inputs.float().to(device)
                decoder_outputs = decoder_outputs.float().to(device)

                prediction= net(encoder_inputs_k, encoder_inputs_l, encoder_inputs, decoder_outputs, batch_size,train_mode=0)
            
                loss_val += loss(prediction,y).item()
                
                mpjpe_avg+= MPJPE(prediction,y).item()
                mpjpe_qaurt_avg+=MPJPE_timelimit(prediction,y,0.25).item()
                mpjpe_half_avg+=MPJPE_timelimit(prediction,y,0.50).item()
                mpjpe_3quart_avg+=MPJPE_timelimit(prediction,y,0.75).item()
                
                mpjpe_path_avg+= MPJPE_torso(prediction,y,torso_joint=torso_joint).item() 
                mpjpe_path_qaurt_avg+=MPJPE_torso_timelimit(prediction,y,0.25,torso_joint=torso_joint).item()  
                mpjpe_path_half_avg+=MPJPE_torso_timelimit(prediction,y,0.50,torso_joint=torso_joint).item() 
                mpjpe_path_3quart_avg+=MPJPE_torso_timelimit(prediction,y,0.75,torso_joint=torso_joint).item()  
                
        loss_val /= (cnt+1)
        mpjpe_avg /= (cnt+1)
        mpjpe_qaurt_avg /= (cnt+1)
        mpjpe_half_avg /= (cnt+1)
        mpjpe_3quart_avg /= (cnt+1)

        mpjpe_path_avg /= (cnt+1)
        mpjpe_path_qaurt_avg /= (cnt+1)
        mpjpe_path_half_avg  /= (cnt+1)
        mpjpe_path_3quart_avg /= (cnt+1)
        
        mpjpe_avg = int(mpjpe_avg*1000)
        mpjpe_qaurt_avg = int(mpjpe_qaurt_avg*1000)
        mpjpe_half_avg = int(mpjpe_half_avg*1000)
        mpjpe_3quart_avg = int(mpjpe_3quart_avg*1000)

        mpjpe_path_avg = int(mpjpe_path_avg*1000)
        mpjpe_path_qaurt_avg = int(mpjpe_path_qaurt_avg*1000)
        mpjpe_path_half_avg  =int(mpjpe_path_half_avg*1000)
        mpjpe_path_3quart_avg = int(mpjpe_path_3quart_avg*1000)
                
        print('#'*30)
        print('All results are in mm')
        print('*'*30)
        print('MPJPE POSE: 0.25\t 0.50\t 0.75\t full')
        print('MPJPE: ',mpjpe_qaurt_avg,'\t ',mpjpe_half_avg,'\t ',mpjpe_3quart_avg,'\t ',mpjpe_avg,'')
        print('*'*30)
        print('MPJPE PATH: 0.25\t 0.50\t 0.75\t full')
        print('PATH: ',mpjpe_path_qaurt_avg,'\t ',mpjpe_path_half_avg,'\t ',mpjpe_path_3quart_avg,'\t ',mpjpe_path_avg,'')
        print('#'*30)
        print('#'*30)

        f = open(checkpoint_dir+"eval.csv", "w")
        eval_id = ''
        for k,v in vars(args).items():
            if k == 'eval_only' or k =='torso_joint' or k =='tag':
                continue
            eval_id += str(k)+str(v)
        eval_result = [eval_id,',',original_tag,',',mpjpe_path_qaurt_avg,',',mpjpe_path_half_avg,',',mpjpe_path_3quart_avg,',',mpjpe_path_avg,',',mpjpe_qaurt_avg,',',mpjpe_half_avg,',',mpjpe_3quart_avg,',',mpjpe_avg,',',FDE,',',ADE,',',STB,'\n']
        eval_result_row = ''
        for ss in eval_result:
            eval_result_row+= str(ss)
        f.write(eval_result_row)
        f.close()
    model.load_state_dict(torch.load(checkpoint_dir+'val_mpjpe_best.pth'))
    eval_metrics()

In [None]:
def eval_metrics():
    model.eval()
    loss_val = 0

    mpjpe_avg = 0 
    mpjpe_qaurt_avg =0 
    mpjpe_half_avg  =0
    mpjpe_3quart_avg =0 

    mpjpe_path_avg = 0 
    mpjpe_path_qaurt_avg =0 
    mpjpe_path_half_avg  =0
    mpjpe_path_3quart_avg =0 
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = Model(in_size, embed_size, hidden_size, dropout_val=dropout_val, batch_size=batch_size)
    net.to(device)

    with torch.no_grad(): #Faster without grad 
        for cnt,batch in enumerate(loader_val): 
            if args.background_back:
                X,XA,y,x3,scene = batch
                scene =scene[:,0,...].cuda()
            elif args.video_back:
                X,XA,y,x3,scene = batch
                scene = scene.view(scene.shape[0],scene.shape[1],scene.shape[2],scene.shape[3],scene.shape[4]).cuda()
            else: 
                X,XA,y,x3 = batch

            X =(X-_mean)/_std
            X,XA,y,x3 = X.cuda(),XA.cuda(),y.cuda(),x3.cuda()
            
            encoder_inputs = x3.reshape(args.batch_size, args.obs_seq_len,-1)
            ref_skel = copy.deepcopy(dataset_train._x3[0])
            x3_k = [CenteredScaled(frame) for frame in x3]
            x3_k = torch.stack(x3_k)
            x3_k_tg = [inv_exp(ref_skel,frame) for frame in x3_k]
            x3_k_tg = torch.stack(x3_k_tg)
            encoder_inputs_k = x3_k_tg.reshape(args.batch_size, args.obs_seq_len,-1) #torch.Size([16, 50, 63])okk
            ref_skel_l = copy.deepcopy(dataset_train._x3[0][0])
            x3_l_tg = [lie_group_and_algebra_transform(frame,ref_skel_l) for frame in x3]
            x3_l_tg = np.array(x3_l_tg)
            x3_l_tg= torch.tensor(x3_l_tg)
            encoder_inputs_l = x3_l_tg.reshape(args.batch_size, args.obs_seq_len, -1)
            decoder_outputs = y.view(args.batch_size, args.pred_seq_len,63)

            optimizer.zero_grad()

            encoder_inputs_k = encoder_inputs_k.float().to(device)
            encoder_inputs_l = encoder_inputs_l.float().to(device)
            encoder_inputs = encoder_inputs.float().to(device)
            decoder_outputs = decoder_outputs.float().to(device)

            prediction= net(encoder_inputs_k, encoder_inputs_l, encoder_inputs, decoder_outputs, batch_size,train_mode=0)
                                    
            loss_val += loss(prediction,y).item()

            mpjpe_avg+= MPJPE(prediction,y).item()
            mpjpe_qaurt_avg+=MPJPE_timelimit(prediction,y,0.25).item()
            mpjpe_half_avg+=MPJPE_timelimit(prediction,y,0.50).item()
            mpjpe_3quart_avg+=MPJPE_timelimit(prediction,y,0.75).item()

            mpjpe_path_avg+= MPJPE_torso(prediction,y,torso_joint=args.torso_joint).item() 
            mpjpe_path_qaurt_avg+=MPJPE_torso_timelimit(prediction,y,0.25,torso_joint=args.torso_joint).item()  
            mpjpe_path_half_avg+=MPJPE_torso_timelimit(prediction,y,0.50,torso_joint=args.torso_joint).item() 
            mpjpe_path_3quart_avg+=MPJPE_torso_timelimit(prediction,y,0.75,torso_joint=args.torso_joint).item()  

    loss_val /= (cnt+1)
    mpjpe_avg /= (cnt+1)
    mpjpe_qaurt_avg /= (cnt+1)
    mpjpe_half_avg /= (cnt+1)
    mpjpe_3quart_avg /= (cnt+1)

    mpjpe_path_avg /= (cnt+1)
    mpjpe_path_qaurt_avg /= (cnt+1)
    mpjpe_path_half_avg  /= (cnt+1)
    mpjpe_path_3quart_avg /= (cnt+1)

    mpjpe_avg = int(mpjpe_avg*1000)
    mpjpe_qaurt_avg = int(mpjpe_qaurt_avg*1000)
    mpjpe_half_avg = int(mpjpe_half_avg*1000)
    mpjpe_3quart_avg = int(mpjpe_3quart_avg*1000)

    mpjpe_path_avg = int(mpjpe_path_avg*1000)
    mpjpe_path_qaurt_avg = int(mpjpe_path_qaurt_avg*1000)
    mpjpe_path_half_avg  =int(mpjpe_path_half_avg*1000)
    mpjpe_path_3quart_avg = int(mpjpe_path_3quart_avg*1000)

    print('#'*30)
    print('All results are in mm')
    print('*'*30)
    print('MPJPE POSE: 0.25\t 0.50\t 0.75\t full')
    print('MPJPE: ',mpjpe_qaurt_avg,'\t ',mpjpe_half_avg,'\t ',mpjpe_3quart_avg,'\t ',mpjpe_avg,'')
    print('*'*30)
    print('MPJPE PATH: 0.25\t 0.50\t 0.75\t full')
    print('PATH: ',mpjpe_path_qaurt_avg,'\t ',mpjpe_path_half_avg,'\t ',mpjpe_path_3quart_avg,'\t ',mpjpe_path_avg,'')
    print('#'*30)
    print('#'*30)

    f = open(checkpoint_dir+"eval.csv", "w")
    eval_id = ''
    for k,v in vars(args).items():
        if k == 'eval_only' or k =='torso_joint' or k =='tag':
            continue
        eval_id += str(k)+str(v)
    eval_result = [eval_id,',',original_tag,',',mpjpe_path_qaurt_avg,',',mpjpe_path_half_avg,',',mpjpe_path_3quart_avg,',',mpjpe_path_avg,',',mpjpe_qaurt_avg,',',mpjpe_half_avg,',',mpjpe_3quart_avg,',',mpjpe_avg,',',FDE,',',ADE,',',STB,'\n']
    eval_result_row = ''
    for ss in eval_result:
        eval_result_row+= str(ss)
    f.write(eval_result_row)
    f.close()

In [63]:
#Eval at differnet time steps and 3d pose, 3d positions (GTA)

print("Load the model weights")
model.load_state_dict(torch.load(checkpoint_dir+'val_mpjpe_best.pth'))

eval_metrics()

Load the model weights
##############################
All results are in mm
******************************
MPJPE POSE: 0.25	 0.50	 0.75	 full
MPJPE:  48.2 	  66.4 	  73.1 	  85.3 
******************************
MPJPE PATH: 0.25	 0.50	 0.75	 full
PATH:  71.2 	  106.3 	  156.1 	  220.6 
##############################
##############################
