In [None]:
import torch
import tensorflow as tf
import numpy as np
from main.smpl import Smpl

In [None]:
def smpl2vertices(pose_and_shape, **kwargs):
    """Run the SMPL body model on a batch of pose+shape parameter rows.

    Args:
        pose_and_shape: parameter tensor with one row per sample; its trailing
            10 columns are the shape parameters.
        **kwargs: forwarded unchanged to the ``Smpl`` call.

    Returns:
        ``tf.tuple([vertices, joints_3d, shapes])``: the first two outputs of
        the ``Smpl`` call plus the trailing 10 parameter columns. The third
        ``Smpl`` output (rotations) is discarded.
    """
    num_shape_params = 10  # NUM_SHAPE_PARAMS
    # NOTE(review): a fresh Smpl instance is built on every call; if its
    # construction is expensive, consider creating it once and reusing it.
    body_model = Smpl()
    mesh, joints, _unused_rotations = body_model(pose_and_shape, **kwargs)
    shape_params = pose_and_shape[:, -num_shape_params:]
    return tf.tuple([mesh, joints, shape_params])

In [None]:
def batch_align_by_pelvis(kp3d, left_id=3, right_id=2):
    """Centre each skeleton in a batch on the midpoint of its two hip joints.

    Args:
        kp3d: keypoints indexed as ``kp3d[batch, joint, coord]``. In this
            notebook it receives the 14-joint arrays built by ``cocomap`` and
            ``gtmap``. ``gtmap`` takes source joints ``[3, 2, 1, 4, ...]``, so
            (assuming the 17-joint ground truth uses the Human3.6M ordering
            where joints 1 and 4 are the hips -- TODO confirm) the hips land at
            indices 2 and 3 and index 0 is NOT a hip. The previous hard-coded
            ``right_id = 0`` therefore averaged a hip with a non-hip joint.
        left_id: joint index of the left hip.
        right_id: joint index of the right hip.

    Returns:
        ``kp3d`` with the per-sample hip midpoint subtracted from every joint.
    """
    pelvis = (kp3d[:, left_id, :] + kp3d[:, right_id, :]) / 2.
    # Insert a joint axis so the [batch, coord] pelvis broadcasts over joints.
    return kp3d - tf.expand_dims(pelvis, axis=1)

def batch_compute_similarity_transform(real_kp3d, pred_kp3d):
    """Procrustes-align each predicted skeleton to its reference skeleton.

    For every sample, finds the scale ``s``, rotation ``R`` (``det(R) = +1``)
    and translation ``t`` minimising ``||s * R @ pred + t - real||`` and
    returns the transformed prediction.

    Args:
        real_kp3d: reference keypoints, ``[batch, K, 3]`` (per the transpose
            below).
        pred_kp3d: keypoints to align, same shape as ``real_kp3d``.

    Returns:
        ``pred_kp3d`` after the similarity transform, ``[batch, K, 3]``.
    """
    # Work in [batch, 3, K] so each point set is a matrix of column vectors.
    real_kp3d = tf.transpose(real_kp3d, perm=[0, 2, 1])
    pred_kp3d = tf.transpose(pred_kp3d, perm=[0, 2, 1])

    # 1. Remove the per-sample centroid.
    mean_real = tf.reduce_mean(real_kp3d, axis=2, keepdims=True)
    mean_pred = tf.reduce_mean(pred_kp3d, axis=2, keepdims=True)

    centered_real = real_kp3d - mean_real
    centered_pred = pred_kp3d - mean_pred

    # 2. Sum of squares of the centred *prediction*; the optimal scale is
    #    trace(R K) divided by this quantity.
    variance = tf.reduce_sum(centered_pred ** 2, axis=[-2, -1], keepdims=True)

    # 3. Cross-covariance K = X_pred @ X_real^T, [batch, 3, 3].
    K = tf.matmul(centered_pred, centered_real, transpose_b=True)

    # 4. With K = U S V^T, the rotation maximising trace(R K) is R = V Z U^T,
    #    where Z = diag(1, ..., 1, sign(det(U V^T))) excludes reflections.
    with tf.device('/CPU:0'):
        # SVD is terrifyingly slow on GPUs, use cpus for this. Makes it a lot faster.
        s, u, v = tf.linalg.svd(K, full_matrices=True)

        # Sign needed to force det(R) = +1.
        det = tf.sign(tf.linalg.det(tf.matmul(u, v, transpose_b=True)))

    # Only the LAST diagonal entry of Z carries the sign. Scaling the whole
    # identity by det would give Z = -I when det = -1, i.e. R = -V U^T: not
    # the optimal rotation, and it produces a negative scale below.
    correction = tf.concat(
        [tf.ones_like(s[:, :-1]), tf.expand_dims(det, -1)], axis=-1)
    Z = tf.linalg.diag(correction)

    # Construct R.
    R = tf.matmul(v, tf.matmul(Z, u, transpose_b=True))

    # 5. Recover scale.
    trace = tf.linalg.trace(tf.matmul(R, K))
    trace = tf.expand_dims(tf.expand_dims(trace, -1), -1)
    scale = trace / variance

    # 6. Recover translation.
    trans = mean_real - scale * tf.matmul(R, mean_pred)

    # 7. Align.
    aligned_kp3d = scale * tf.matmul(R, pred_kp3d) + trans

    return tf.transpose(aligned_kp3d, perm=[0, 2, 1])

def cocomap(data, batch_size=None):
    """Reduce 19-joint keypoints to the first 14 joints.

    Args:
        data: array-like indexed as ``data[batch, joint, coord]`` with at
            least 14 joints and 3 coordinates. TensorFlow eager tensors are
            accepted (converted with ``np.asarray``).
        batch_size: number of rows in the output. ``None`` (default) infers
            it from ``data``; a hard-coded 512 made every other batch size
            fail. An explicit value is broadcast against ``data`` with the
            usual NumPy rules.

    Returns:
        ``np.ndarray`` of shape ``(batch_size, 14, 3)``, dtype float64.
    """
    joints = np.asarray(data)
    if batch_size is None:
        batch_size = joints.shape[0]
    out = np.zeros((batch_size, 14, 3))
    out[:] = joints[:, :14]
    return out

def gtmap(data, batch_size=None):
    """Reduce and reorder 17-joint ground-truth keypoints to 14 joints.

    Output joint ``j`` is input joint ``source_joints[j]``.

    Args:
        data: array-like indexed as ``data[batch, joint, coord]`` with 17
            joints and 3 coordinates.
        batch_size: number of rows in the output. ``None`` (default) infers
            it from ``data``; a hard-coded 512 made every other batch size
            fail. An explicit value is broadcast against ``data`` with the
            usual NumPy rules.

    Returns:
        ``np.ndarray`` of shape ``(batch_size, 14, 3)``, dtype float64.
    """
    source_joints = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 9, 10]
    joints = np.asarray(data)
    if batch_size is None:
        batch_size = joints.shape[0]
    out = np.zeros((batch_size, len(source_joints), 3))
    out[:] = joints[:, source_joints]
    return out


In [None]:
def loss_3d(kp3d_sym, gt3d):
    """Mean Euclidean distance between corresponding joints.

    Args:
        kp3d_sym: aligned predicted joints; axis 2 holds the coordinates
            (``[batch, K, 3]`` as returned by the similarity transform).
        gt3d: ground-truth joints with the same shape.

    Returns:
        Scalar tensor: per-joint L2 error averaged over samples and joints.
    """
    per_joint_error = tf.norm(kp3d_sym - gt3d, axis=2)
    return tf.reduce_mean(per_joint_error)

In [None]:
def loss_compute(pred, central, gt3d):
    """Combined SMPL-parameter loss and Procrustes-aligned 3-D joint loss.

    Args:
        pred: torch tensor of predicted SMPL pose+shape parameters, one row
            per sample after ``squeeze``.
        central: torch tensor shaped like ``pred``; reference parameters for
            the parameter-space loss.
        gt3d: torch tensor of ground-truth joints in the 17-joint layout that
            ``gtmap`` expects (after ``squeeze``).

    Returns:
        ``(loss, loss1, loss2)``:
        ``loss1`` is the mean row-wise L2 distance between ``pred`` and
        ``central``; ``loss2`` is the mean joint error after similarity
        alignment; ``loss = loss1 + 10 * loss2``.

    Note:
        ``loss2`` is computed in NumPy/TensorFlow from a detached copy of
        ``pred`` and wrapped with ``torch.from_numpy``, so it carries no
        gradient; only ``loss1`` back-propagates.
    """
    pred = pred.cpu()
    central = central.cpu()
    loss1 = torch.mean(torch.norm(pred.squeeze() - central.squeeze(), dim=1))

    # Joint loss: run SMPL on a detached copy of the predicted parameters.
    pred_params = tf.convert_to_tensor(pred.squeeze().detach().numpy())
    _, kp3d_pred, _ = smpl2vertices(pred_params)
    joint_3d = cocomap(kp3d_pred, batch_size=kp3d_pred.shape[0])

    # .cpu() before .numpy(): the ground truth may still live on the GPU.
    gt_np = gt3d.squeeze().detach().cpu().numpy()
    joint_gt_3d = gtmap(gt_np, batch_size=gt_np.shape[0])

    gt_aligned = tf.cast(batch_align_by_pelvis(joint_gt_3d), float)
    kp3d_aligned = tf.cast(batch_align_by_pelvis(joint_3d), float)

    kp3d_sym = batch_compute_similarity_transform(gt_aligned, kp3d_aligned)

    loss2 = torch.from_numpy(np.asarray(loss_3d(kp3d_sym, gt_aligned).numpy()))

    loss = loss1 + 10 * loss2
    return loss, loss1, loss2

# loss_compute(central, central, gt_3d)

In [None]:
def loss_compute_growth(pred, central, gt3d, epoch):
    """Parameter + joint loss whose joint weight grows linearly with epoch.

    Args:
        pred: torch tensor of predicted SMPL pose+shape parameters, one row
            per sample after ``squeeze``.
        central: torch tensor shaped like ``pred``; reference parameters for
            the parameter-space loss.
        gt3d: torch tensor of ground-truth joints in the 17-joint layout that
            ``gtmap`` expects (after ``squeeze``).
        epoch: current epoch; the joint term is weighted by
            ``1 + 0.1 * epoch``.

    Returns:
        ``(loss, 100 * loss1, 1000 * loss2)`` with
        ``loss = 100 * loss1 + 1000 * (1 + 0.1 * epoch) * loss2``, where
        ``loss1`` is the mean row-wise L2 distance between ``pred`` and
        ``central`` and ``loss2`` the mean joint error after similarity
        alignment.

    Note:
        ``loss2`` is computed from a detached copy of ``pred`` and carries no
        gradient; only ``loss1`` back-propagates.
    """
    pred = pred.cpu()
    central = central.cpu()
    loss1 = torch.mean(torch.norm(pred.squeeze() - central.squeeze(), dim=1))

    # Joint loss: run SMPL on a detached copy of the predicted parameters.
    pred_params = tf.convert_to_tensor(pred.squeeze().detach().numpy())
    _, kp3d_pred, _ = smpl2vertices(pred_params)
    joint_3d = cocomap(kp3d_pred, batch_size=kp3d_pred.shape[0])

    # .cpu() before .numpy(): the ground truth may still live on the GPU.
    gt_np = gt3d.squeeze().detach().cpu().numpy()
    joint_gt_3d = gtmap(gt_np, batch_size=gt_np.shape[0])

    gt_aligned = tf.cast(batch_align_by_pelvis(joint_gt_3d), float)
    kp3d_aligned = tf.cast(batch_align_by_pelvis(joint_3d), float)

    kp3d_sym = batch_compute_similarity_transform(gt_aligned, kp3d_aligned)

    loss2 = torch.from_numpy(np.asarray(loss_3d(kp3d_sym, gt_aligned).numpy()))

    joint_weight = 1 + epoch * 0.1
    loss = 100 * loss1 + 1000 * joint_weight * loss2
    return loss, 100 * loss1, 1000 * loss2

# loss_compute(central, central, gt_3d)

In [None]:
def loss_compute_grad(pred, central, gt3d, epoch):
    """Multiplicative combination of the parameter loss and the joint loss.

    Args:
        pred: torch tensor of predicted SMPL pose+shape parameters, one row
            per sample after ``squeeze``.
        central: torch tensor shaped like ``pred``; reference parameters for
            the parameter-space loss.
        gt3d: torch tensor of ground-truth joints in the 17-joint layout that
            ``gtmap`` expects (after ``squeeze``).
        epoch: unused; kept so the signature matches ``loss_compute_growth``.

    Returns:
        ``(loss, loss1, 1000 * loss2)`` with ``loss = loss1 * loss2 * 1000``,
        where ``loss1`` is the mean row-wise L2 distance between ``pred`` and
        ``central`` and ``loss2`` the mean joint error after similarity
        alignment.

    Note:
        ``loss2`` is computed from a detached copy of ``pred`` and carries no
        gradient, so it acts only as a per-batch weight on ``loss1``.
    """
    pred = pred.cpu()
    central = central.cpu()
    loss1 = torch.mean(torch.norm(pred.squeeze() - central.squeeze(), dim=1))

    # Joint loss: run SMPL on a detached copy of the predicted parameters.
    pred_params = tf.convert_to_tensor(pred.squeeze().detach().numpy())
    _, kp3d_pred, _ = smpl2vertices(pred_params)
    joint_3d = cocomap(kp3d_pred, batch_size=kp3d_pred.shape[0])

    # .cpu() before .numpy(): the ground truth may still live on the GPU.
    gt_np = gt3d.squeeze().detach().cpu().numpy()
    joint_gt_3d = gtmap(gt_np, batch_size=gt_np.shape[0])

    gt_aligned = tf.cast(batch_align_by_pelvis(joint_gt_3d), float)
    kp3d_aligned = tf.cast(batch_align_by_pelvis(joint_3d), float)

    kp3d_sym = batch_compute_similarity_transform(gt_aligned, kp3d_aligned)

    loss2 = torch.from_numpy(np.asarray(loss_3d(kp3d_sym, gt_aligned).numpy()))

    loss = loss1 * loss2 * 1000
    return loss, loss1, 1000 * loss2

# loss_compute(central, central, gt_3d)

In [None]:
def mpjpe(pred, central, gt3d):
    """Procrustes-aligned mean per-joint position error, multiplied by 1000.

    Runs SMPL on ``pred``, reduces predicted and ground-truth joints to the
    shared 14-joint layout, pelvis-centres both, similarity-aligns the
    prediction to the ground truth, and returns the mean joint distance
    times 1000 (presumably metres -> millimetres; TODO confirm units).

    The previous version also built a weighted training loss here from an
    undefined ``epoch`` variable (NameError on a fresh kernel) and then
    discarded it; that dead code has been removed.

    Args:
        pred: torch tensor of predicted SMPL pose+shape parameters, one row
            per sample after ``squeeze``.
        central: unused; kept so the signature matches ``loss_compute``.
        gt3d: torch tensor of ground-truth joints in the 17-joint layout that
            ``gtmap`` expects (after ``squeeze``).

    Returns:
        0-d torch tensor (no gradient): ``1000 * mean aligned joint error``.
    """
    pred = pred.cpu()

    pred_params = tf.convert_to_tensor(pred.squeeze().detach().numpy())
    _, kp3d_pred, _ = smpl2vertices(pred_params)
    joint_3d = cocomap(kp3d_pred, batch_size=kp3d_pred.shape[0])

    # .cpu() before .numpy(): the ground truth may still live on the GPU.
    gt_np = gt3d.squeeze().detach().cpu().numpy()
    joint_gt_3d = gtmap(gt_np, batch_size=gt_np.shape[0])

    gt_aligned = tf.cast(batch_align_by_pelvis(joint_gt_3d), float)
    kp3d_aligned = tf.cast(batch_align_by_pelvis(joint_3d), float)

    kp3d_sym = batch_compute_similarity_transform(gt_aligned, kp3d_aligned)

    loss2 = torch.from_numpy(np.asarray(loss_3d(kp3d_sym, gt_aligned).numpy()))
    return loss2 * 1000

# loss_compute(central, central, gt_3d)