In [67]:
#!pip install smplx
#!pip install chamferdist

import time
import json as js
import numpy as np
import os
from os.path import join, exists

import torch
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import smplx
from smplx.lbs import *
#import chamferdist
#from chamferdist import ChamferDist
from chamfer_distance import ChamferDistance

from plyfile import PlyData, PlyElement
import pandas as pd
import pickle as pkl
import math

In [2]:
# Requested GPU ids; an empty list forces CPU execution.
gpu_id = [0]

if len(gpu_id) > 0 and torch.cuda.is_available():
    # NOTE(review): CUDA_VISIBLE_DEVICES is set after torch has already probed
    # CUDA via is_available(); whether it still takes effect depends on the
    # torch version -- confirm. The recorded outputs of this notebook show
    # tensors on 'cuda:1', i.e. more than one device stayed visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id[0])
    device = torch.device('cuda')

    # The notebook historically selected CUDA device 1. Only do so when that
    # ordinal actually exists; otherwise stay on the default device instead of
    # crashing with an invalid-device error.
    # NOTE(review): device 1 is hard-coded and does not follow `gpu_id` --
    # confirm which GPU is intended.
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)
else:
    # No GPU requested or available: never touch the CUDA runtime.
    device = torch.device('cpu')

In [3]:
# Index tables used by the (currently disabled) joint-matching loss, which
# compares `res.joints[SMPL_joints]` with skeleton rows picked by `OPOS_joints`.
# SMPL_joints has 24 entries and OPOS_joints has 20; the disabled code appends
# four interpolated skeleton points to make up the difference.
# NOTE(review): presumably SMPL-X joint ids and OpenPose keypoint ids -- confirm.
SMPL_joints = [
    1, 2, 4, 5, 7, 8, 10, 11,
    12, 16, 17, 18, 19, 20, 21, 24,
    25, 26, 27, 28, 6, 3, 0, 9,
]
OPOS_joints = [
    12, 9, 13, 10, 14, 11, 19, 22,
    1, 5, 2, 6, 3, 7, 4, 0,
    15, 16, 17, 18,
]

print(len(SMPL_joints))

24


In [125]:
def load_ske(path):
    """Load a whitespace-separated skeleton text file into a float64 tensor.

    Every non-blank line is split on whitespace, all of its fields are parsed
    as floats and the first three values are kept as one row.

    Args:
        path: path of the text file to read.

    Returns:
        A ``torch.float64`` tensor with one row per non-blank line. An input
        without any data line yields an empty tensor of shape ``(0, 3)``.

    Raises:
        ValueError: if ``path`` does not exist (the message is the path), or
            if a field cannot be parsed as a float.
    """
    if not exists(path):
        raise ValueError(path)

    skeleton = []
    # `with` guarantees the handle is closed even when parsing fails.
    with open(path) as f:
        for line in f:
            numbers = line.split()  # split the line into its fields
            if not numbers:
                # Blank lines (e.g. a trailing newline) carry no joint and
                # would otherwise produce a ragged array.
                continue
            numbers_float = map(float, numbers)  # convert every field to float
            skeleton.append(list(numbers_float)[:3])

    if not skeleton:
        return torch.zeros((0, 3), dtype=torch.float64)

    skeleton = np.array(skeleton)
    skeleton = torch.from_numpy(skeleton)
    return skeleton

def customize_load_smplx_params(path):
    """Read a pickled parameter dictionary and convert its values to tensors.

    The file is unpickled with ``encoding='latin1'``; every value is then
    passed through ``torch.from_numpy`` and converted to a float64 tensor on
    the notebook-level ``device``.

    Args:
        path: path of the pickle file.

    Returns:
        The unpickled dictionary, with each value replaced by its tensor.
    """
    # NOTE(review): unpickling can execute arbitrary code -- only load files
    # from a trusted source.
    with open(path, 'rb') as handle:
        params = pkl.load(handle, encoding='latin1')

    # Iterate over a snapshot of the keys while the values are replaced.
    for name in list(params):
        params[name] = torch.from_numpy(params[name]).double().to(device)
    return params

def customize_save_smplx_params(model_data, path):
    """Pickle a dictionary of tensors as float64 NumPy arrays.

    Each value is detached, moved to the CPU, cast to double and converted to
    a NumPy array before being written with pickle protocol 2.

    The conversion is done on a fresh dictionary: ``model_data`` itself is
    left untouched, so the caller's tensors are not silently replaced by
    arrays as a side effect of saving.

    Args:
        model_data: mapping from parameter name to ``torch.Tensor``.
        path: destination file path.
    """
    serializable = {
        name: tensor.detach().cpu().double().numpy()
        for name, tensor in model_data.items()
    }
    with open(path, 'wb') as f:
        pkl.dump(serializable, f, protocol=2)

def customize_save_ply(vs, fs, path):
    """Write vertices and triangle indices to ``path`` as a PLY file.

    Args:
        vs: 2-D indexable of vertex coordinates; columns 0..2 are stored as
            the float32 properties ``x``, ``y`` and ``z``.
        fs: face index triples, stored as the int32 list property
            ``vertex_indices``.
        path: destination file path.
    """
    vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    face_dtype = [('vertex_indices', 'i4', (3,))]

    # One (x, y, z) tuple per vertex, as required for a structured array.
    vertex_rows = [(vs[row, 0], vs[row, 1], vs[row, 2]) for row in range(vs.shape[0])]
    vertex_array = np.array(vertex_rows, dtype=vertex_dtype)

    face_array = np.empty(len(fs), dtype=face_dtype)
    face_array['vertex_indices'] = fs

    elements = [
        PlyElement.describe(vertex_array, 'vertex', comments=['vertices']),
        PlyElement.describe(face_array, 'face', comments=['faces']),
    ]
    PlyData(elements).write(path)
            
def customize_load_ply(path):
    """Load the first three properties of a PLY file's first element.

    Args:
        path: path of the PLY file.

    Returns:
        A float32 tensor on the notebook-level ``device`` with one row per
        record of the first element and three columns (the element's first
        three properties, in declaration order).
        NOTE(review): presumably the first element is ``vertex`` with
        properties x, y, z -- confirm for the meshes being loaded.

    Raises:
        IndexError: if the first element declares fewer than three properties.
    """
    vertex_data = PlyData.read(path).elements[0].data

    # Take the field names from the array dtype rather than from its first
    # record, so a mesh with zero vertices does not raise IndexError.
    property_names = vertex_data.dtype.names

    # Copy the columns straight from the structured array; no intermediate
    # DataFrame is needed.
    data_np = np.zeros((len(vertex_data), 3), dtype=np.float64)
    for i in range(3):
        data_np[:, i] = vertex_data[property_names[i]]

    ply_vertices = torch.from_numpy(data_np)
    ply_vertices = ply_vertices.float().to(device)
    return ply_vertices
    
def write_obj(verts, faces, file_name):
    """Write a mesh to ``file_name`` as Wavefront OBJ text.

    Args:
        verts: 2-D array-like with a ``shape`` attribute; the first three
            columns of each row are written as a ``v`` line.
        faces: 2-D array-like of zero-based vertex indices; the first three
            columns of each row are written as an ``f`` line.
        file_name: destination path; an existing file is overwritten.
    """
    vertex_line = 'v %f %f %f\n'
    face_line = 'f %d %d %d\n'

    with open(file_name, 'w') as fp:
        for row in range(verts.shape[0]):
            point = verts[row]
            fp.write(vertex_line % (point[0], point[1], point[2]))

        # OBJ indices are 1-based, hence the +1 on every stored index.
        for row in range(faces.shape[0]):
            tri = faces[row]
            fp.write(face_line % (tri[0] + 1, tri[1] + 1, tri[2] + 1))

In [130]:
# setup smplx
# Build one male and one female SMPL-X model (float64, 12 hand PCA components)
# from ./models and move them to `device`. The female model is the one used by
# the fitting cells below.
s_model = smplx.create(
    "./models", "smplx",
    gender="male",
    num_pca_comps=12,
    use_face_contour=False,
    ext="pkl",
    dtype=torch.float64,
).to(device)
f_model = smplx.create(
    "./models", "smplx",
    gender="female",
    num_pca_comps=12,
    use_face_contour=False,
    ext="pkl",
    dtype=torch.float64,
).to(device)
body_model = f_model

# Three values per body joint.
# NOTE(review): the original comments claimed "54*3" and "10" for the two sizes
# below; verify both against the installed smplx version.
pose_size = body_model.NUM_BODY_JOINTS * 3
beta_size = body_model.NUM_BETAS

# All-zero shape, pose, translation and global orientation.
betas = torch.zeros(beta_size, device=device, dtype=torch.double, requires_grad=True)
pose = torch.zeros(pose_size, device=device, dtype=torch.double, requires_grad=True)
trans = torch.zeros(3, device=device, dtype=torch.double, requires_grad=True)
pos_dir = torch.zeros(3, device=device, dtype=torch.double, requires_grad=True)

# Sanity check: run the model with the zero parameters and dump the mesh.
res = body_model(
    betas=betas.view(1, -1),
    body_pose=pose.view(1, -1),
    global_orient=pos_dir.view(1, -1),
    transl=trans.view(1, -1),
    return_verts=True,
    return_full_pose=True,
)
# Detach and copy to the CPU once before writing: write_obj formats every
# coordinate individually, which should not go through a device tensor that is
# still attached to the autograd graph (the fitting cell does the same before
# saving its PLY files).
write_obj(res.vertices.view(-1, 3).detach().cpu(), body_model.faces, "t.obj")

In [36]:
# Preview one set of precomputed SMPL-X parameters and export the resulting mesh.
smplx_precompute_pattern = "/data/NFS/new_disk/chenxin/relightable-nr/data/white_walkshow/mesh_smplx/results/%03d-removebg-preview/000.pkl"

# NOTE(review): `model_data` is not assigned in any cell visible in this
# notebook; this cell relies on a value left in the kernel (presumably loaded
# with customize_load_smplx_params from the pattern above -- confirm).
print(model_data)
print(model_data['body_pose'].shape)
print(model_data['left_hand_pose'].shape)

# Every input comes from the precomputed dictionary under its own name, except
# the body pose, which is read from the 'body_pose_joint' entry.
res = body_model(
    **{
        key: model_data[key]
        for key in (
            'betas',
            'global_orient',
            'left_hand_pose',
            'right_hand_pose',
            'jaw_pose',
            'leye_pose',
            'reye_pose',
            'expression',
        )
    },
    body_pose=model_data['body_pose_joint'],
)
write_obj(res.vertices.view(-1,3), body_model.faces, "t1.obj")

{'camera_rotation': tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]], device='cuda:1', dtype=torch.float64), 'camera_translation': tensor([[-2.2852e-02,  6.9026e-02,  2.3333e+01]], device='cuda:1',
       dtype=torch.float64), 'betas': tensor([[ 3.1941, -0.9010,  3.1141, -0.2805, -0.0660,  2.0125, -1.3411, -0.0242,
          1.2103, -0.0469]], device='cuda:1', dtype=torch.float64), 'global_orient': tensor([[ 3.1834,  0.0446, -0.1491]], device='cuda:1', dtype=torch.float64), 'left_hand_pose': tensor([[-0.3915,  0.0575,  0.2200, -0.2511,  0.1584, -0.0458, -0.1827,  0.1439,
         -0.1736, -0.0109,  0.0938,  0.0734]], device='cuda:1',
       dtype=torch.float64), 'right_hand_pose': tensor([[-0.1808, -0.7088, -0.0313,  0.2641,  0.4155,  0.1189,  0.3650,  0.2739,
          0.1288, -0.6526,  0.7052, -0.1389]], device='cuda:1',
       dtype=torch.float64), 'jaw_pose': tensor([[0.0655, 0.0007, 0.0003]], device='cuda:1', dtype=torch.float64), 'leye_pose': tensor([[1.0582,

In [129]:
# Fit SMPL-X to one target mesh per frame by minimising the symmetric chamfer
# distance between the model vertices and the target vertices. Shape, body
# pose, translation and global orientation are optimised; hands, jaw, eyes and
# expression are taken unchanged from the per-frame precomputed parameters.
# The optimised values of one frame are the starting point of the next.

# When True, the first frame is not optimised again: its previously saved
# result is loaded and only used to initialise the following frames.
load_precompute = True

# Frames processed are range(frame_range[0], frame_range[1]) -- end exclusive.
frame_range = [111,156]

# body_ske_pattern = "/home/haines/datasets/models/%s/%d/openpose/skeleton_body/skeleton.txt"
# lhand_ske_pattern = "/home/haines/datasets/models/%s/%d/openpose/skeleton_hand_left/skeleton.txt"
# rhand_ske_pattern = "/home/haines/datasets/models/%s/%d/openpose/skeleton_hand_right/skeleton.txt"

smplx_precompute_pattern = "/data/NFS/new_disk/chenxin/relightable-nr/data/white_walkshow/mesh_smplx/results/%03d-removebg-preview/000.pkl"
target_mesh_pattern = "/data/NFS/new_disk/chenxin/relightable-nr/data/white_walkshow/mesh_pifu/result_%03d-removebg-preview_512.ply"
rs_mesh_pattern = '/data/NFS/new_disk/chenxin/relightable-nr/data/white_walkshow/mesh_smplx2pifu/rs_%03d.ply'
rs_params_pattern = '/data/NFS/new_disk/chenxin/relightable-nr/data/white_walkshow/mesh_smplx2pifu/ps_%03d.pkl'

# NOTE(review): anomaly detection is a debugging aid that adds overhead to
# every backward pass; consider disabling it for production runs.
torch.autograd.set_detect_anomaly(True)

chamfer_dist = ChamferDistance()

# Only referenced by the disabled joint-matching code; built once instead of
# once per frame.
criterion = nn.MSELoss()

betas = []
pose = []
glob_t = []
glob_r = []
pose_size = body_model.NUM_BODY_JOINTS * 3
hand_size = body_model.NUM_HAND_JOINTS * 3
beta_size = body_model.NUM_BETAS


def _as_leaf(tensor):
    """Return a detached float64 copy of ``tensor`` that requires grad.

    The dtype conversion happens before ``requires_grad_`` so the result is
    always a leaf tensor the optimiser can update.
    """
    return tensor.clone().detach().double().requires_grad_(True)


def _forward_smplx(betas, pose, glob_t, glob_r, frame_params):
    """Run ``body_model`` with the optimised parameters plus the fixed
    hand/face/expression entries of ``frame_params``."""
    return body_model(betas=betas,
                      global_orient=glob_r.view(1,-1),
                      transl=glob_t.view(1,-1),
                      left_hand_pose=frame_params['left_hand_pose'],
                      right_hand_pose=frame_params['right_hand_pose'],
                      jaw_pose=frame_params['jaw_pose'],
                      leye_pose=frame_params['leye_pose'],
                      reye_pose=frame_params['reye_pose'],
                      expression=frame_params['expression'],
                      body_pose=pose,
                      return_verts=True,
                      return_full_pose=True)


for iF in range(frame_range[0], frame_range[1]):

    # check target
    if not exists(target_mesh_pattern % iF):
        raise ValueError(target_mesh_pattern % iF)

    time_frame_start=time.time()

    # load ske
    #ske_body = load_ske(body_ske_pattern%iF)[OPOS_joints]
    #ske_lhand = load_ske(lhand_ske_pattern%iF)[OPOS_joints]
    #ske_rhand = load_ske(rhand_ske_pattern%iF)[OPOS_joints]

    # load target
    target_vertices = customize_load_ply(target_mesh_pattern%iF)

    # init params
    if iF == frame_range[0]:
        # The first frame gets a larger iteration budget than the others.
        epoch = 1001
        if load_precompute:
            # Resume from the result saved by an earlier run of this cell.
            smplx_params = customize_load_smplx_params(rs_params_pattern%iF)
            betas = _as_leaf(smplx_params['betas'])
            pose = _as_leaf(smplx_params['body_pose_joint'])
            glob_t = _as_leaf(smplx_params['transl'])
            glob_r = _as_leaf(smplx_params['global_orient'])
            print("Load first frame ...")
            continue
        else:
            # Start from the precomputed shape/pose with zero translation and
            # zero global orientation.
            smplx_params = customize_load_smplx_params(smplx_precompute_pattern%iF)
            betas = _as_leaf(smplx_params['betas'])
            pose = _as_leaf(smplx_params['body_pose_joint'])
            glob_t = torch.zeros(3, device=device, dtype=torch.double, requires_grad=True)
            glob_r = torch.zeros(3, device=device, dtype=torch.double, requires_grad=True)

    else:
        # Later frames keep the parameters optimised for the previous frame
        # and only refresh the fixed hand/face/expression inputs.
        smplx_params = customize_load_smplx_params(smplx_precompute_pattern%iF)
        epoch = 501

    ##########################################
    # optimize globally against the target mesh
    ##########################################
    optimizer = torch.optim.Adam(params=[betas, pose, glob_t, glob_r], lr = 0.01)
    # optimizer = torch.optim.Adam(params=[glob_t, glob_r], lr = 0.01)
    for i in range(epoch):
        optimizer.zero_grad()

        # update smplx
        # NOTE(review): the stage timings are plain wall-clock differences; on
        # a GPU they may not reflect the real per-stage cost -- confirm.
        time_start=time.time()
        res = _forward_smplx(betas, pose, glob_t, glob_r, smplx_params)
        time_smpl=time.time() - time_start

        d1, d2 = chamfer_dist(res.vertices.float(), target_vertices.view(1,-1,3).float())
        loss = (torch.mean(d1)) + (torch.mean(d2))
        ##########################################
        # TODO
        # add loss - projected joints
        # add pose embedding - 3D pose embedding prior (see vposer in smplify-x)
        ##########################################
        time_chamdist=time.time()-time_start-time_smpl

        loss.backward()
        optimizer.step()
        time_backward=time.time()-time_start-time_smpl-time_chamdist

        # Read the scalar once; it is used for both logging and the stop test.
        loss_value = loss.item()
        if i % 40 == 0:
            print("epch: %d,  loss: %f, t_smpl: %f, t_chamDist: %f, t_backward: %f" % (i, loss_value, time_smpl, time_chamdist, time_backward))
        if loss_value < 0.000300:
            print("    stop optmization")
            break

    # Final mesh for export only: no gradients are needed here.
    with torch.no_grad():
        res = _forward_smplx(betas, pose, glob_t, glob_r, smplx_params)

    # Shallow copy so the loaded per-frame dictionary is not modified through
    # an alias when the result entries are filled in and saved.
    rs_data = dict(smplx_params)
    rs_data['betas'] = betas.detach().clone()
    rs_data['body_pose_joint'] = pose.detach().clone()
    rs_data['transl'] = glob_t.detach().view(1,-1).clone()
    rs_data['global_orient'] = glob_r.detach().view(1,-1).clone()
    customize_save_smplx_params(rs_data, rs_params_pattern%iF)
    time_opt = time.time()

    customize_save_ply(res.vertices.view(-1,3).clone().detach().cpu(), body_model.faces, rs_mesh_pattern%iF)

    # write_obj(res.vertices.view(-1,3), body_model.faces, rs_mesh_pattern%iF)
    time_total = time.time() - time_frame_start
    # time_save covers only the PLY export (time_opt is taken after the
    # parameter pickle has been written).
    time_save = time.time() - time_opt
    print("Frame %d Save %s, t_save %f, t_totall %f" % (iF, rs_mesh_pattern%iF, time_save, time_total))


Load first frame ...
epch: 0,  loss: 0.000437, t_smpl: 0.307997, t_chamDist: 0.042816, t_backward: 0.087488
epch: 40,  loss: 0.000399, t_smpl: 0.288871, t_chamDist: 0.009977, t_backward: 0.073625
epch: 80,  loss: 0.000393, t_smpl: 0.246073, t_chamDist: 0.004665, t_backward: 0.081347
epch: 120,  loss: 0.000392, t_smpl: 0.276382, t_chamDist: 0.007746, t_backward: 0.078207
epch: 160,  loss: 0.000391, t_smpl: 0.230507, t_chamDist: 0.004776, t_backward: 0.074420
epch: 200,  loss: 0.000391, t_smpl: 0.220127, t_chamDist: 0.013072, t_backward: 0.073458
epch: 240,  loss: 0.000391, t_smpl: 0.288065, t_chamDist: 0.011983, t_backward: 0.074887
epch: 280,  loss: 0.000390, t_smpl: 0.367311, t_chamDist: 0.018193, t_backward: 0.095960
epch: 320,  loss: 0.000390, t_smpl: 0.251924, t_chamDist: 0.006609, t_backward: 0.077235
epch: 360,  loss: 0.000390, t_smpl: 0.237042, t_chamDist: 0.004520, t_backward: 0.085834
epch: 400,  loss: 0.000390, t_smpl: 0.249157, t_chamDist: 0.009787, t_backward: 0.068202
epch

KeyboardInterrupt: 

array([[    3,     1,     0],
       [    7,     5,     4],
       [   12,    14,    13],
       ...,
       [ 9944, 10097, 10084],
       [ 9940, 10084, 10071],
       [10071, 10058,  9932]], dtype=uint32)

In [71]:
# Export an OBJ mesh for the frame index left in `iF`, combining the optimised
# global translation/orientation with the entries of `model_data`.
# NOTE(review): `model_data`, `glob_r`, `glob_t` and `iF` all come from kernel
# state left by earlier cells (`model_data` is not assigned in any visible
# cell) -- confirm before running on a fresh kernel.
# This rebinds `rs_mesh_pattern` to the .obj variant of the result path.
rs_mesh_pattern = '/data/NFS/new_disk/chenxin/relightable-nr/data/white_walkshow/mesh_smplx2pifu/rs_%03d.obj'
res = body_model(
    betas=model_data['betas'],
    global_orient=glob_r.view(1,-1), # model_data['global_orient'],
    transl=glob_t.view(1,-1),
    **{
        key: model_data[key]
        for key in (
            'left_hand_pose',
            'right_hand_pose',
            'jaw_pose',
            'leye_pose',
            'reye_pose',
            'expression',
        )
    },
    body_pose=model_data['body_pose_joint'],
    return_verts=True,
    return_full_pose=True,
)
write_obj(res.vertices.view(-1,3), body_model.faces, rs_mesh_pattern%iF)

In [None]:
# backup (disabled code kept for reference)
    # optimize the model joints against the skeleton joints
#     for i in range(1500):
#         optimizer.zero_grad()

#         res = body_model(betas=betas.view(1, -1), 
#                       body_pose=pose.view(1, -1), 
#                       global_orient=pos_dir.view(1, -1), 
#                       # left_hand_pose=left_hand_pose.view(1,-1),
#                       # right_hand_pose=right_hand_pose.view(1, -1),
#                       transl=trans.view(1,-1), 
#                       return_verts=True, 
#                       return_full_pose=True)

#         j = res.joints.squeeze()

#         s = skeleton[OPOS_joints]

#         # body
#         s = torch.cat([s, (skeleton[1] / 2 + skeleton[8] / 2).unsqueeze(0)],0)
#         s = torch.cat([s, (skeleton[1] / 4 + 3 * skeleton[8] / 4).unsqueeze(0)],0)
#         s = torch.cat([s, (skeleton[1] / 8 + 7 * skeleton[8] / 8).unsqueeze(0)],0)
#         s = torch.cat([s, (5 * skeleton[1] / 8 + 3 * skeleton[8] / 8).unsqueeze(0)],0)

#         loss = criterion(j[SMPL_joints], s.to(device))

#         loss.backward()
#         optimizer.step()

#         if i % 100 == 0:
#             print("epch: %d,  loss: %f" % (i, loss.item()))
#     # save output

#     # write_obj(res.vertices.view(-1,3), s_model.faces, "t1.obj")
        
    # optimize shape and pose with joint

In [None]:
# backup: older two-stage fitting code kept for reference.
# NOTE(review): the statements below are indented one level although no
# enclosing block is visible in this cell, so it looks like a fragment copied
# out of a loop body. It also reads `ply_vertices`, `skeleton`, `seq` and `idx`,
# none of which are assigned in the cells visible here -- confirm before reuse.
    # Chamfer distance between the current model vertices and the target.
    d1, d2 = chamfer_dist(res.vertices.float(), ply_vertices.view(1,-1,3).float())
    nloss = (torch.mean(d1)) + (torch.mean(d2))
    # Stage 1 (only when the current fit is poor): fit the model joints to the
    # skeleton with `criterion`, using the male model `s_model`.
    if nloss > 0.001:
        optimizer = torch.optim.Adam(params=[betas, pose, pos_dir, trans], lr = 0.01)
        for i in range(1500):
            optimizer.zero_grad()

            res = s_model(betas=betas.view(1, -1), 
                          body_pose=pose.view(1, -1), 
                          global_orient=pos_dir.view(1, -1), 
                          transl=trans.view(1,-1), 
                          return_verts=True, 
                          return_full_pose=True)

            # Joints are moved to the CPU before being compared with `s`.
            j = res.joints.cpu().squeeze()
            s = skeleton[OPOS_joints]

            # body
            # Append four points blended from skeleton[1] and skeleton[8]
            # (weights 1/2, 1/4, 1/8 and 5/8 on skeleton[1]); together with the
            # 20 rows selected by OPOS_joints this gives one target row per
            # entry of SMPL_joints (24).
            s = torch.cat([s, (skeleton[1] / 2 + skeleton[8] / 2).unsqueeze(0)],0)
            s = torch.cat([s, (skeleton[1] / 4 + 3 * skeleton[8] / 4).unsqueeze(0)],0)
            s = torch.cat([s, (skeleton[1] / 8 + 7 * skeleton[8] / 8).unsqueeze(0)],0)
            s = torch.cat([s, (5 * skeleton[1] / 8 + 3 * skeleton[8] / 8).unsqueeze(0)],0)

            loss = criterion(j[SMPL_joints], s)

            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print("epch: %d,  loss: %f" % (i, loss.item()))
        # write_obj(res.vertices.view(-1,3), s_model.faces, "t2.obj")
        
    # optimize joint with target
    # Stage 2: minimise the chamfer distance to the target mesh. A poor initial
    # fit optimises all four parameter tensors for 1500 iterations; otherwise
    # `betas` is left out of the optimiser and 500 iterations are run.
    if nloss > 0.001:
        optimizer = torch.optim.Adam(params=[betas, pose, pos_dir, trans], lr = 0.01)
        epoch = 1500
    else:
        epoch = 500
        optimizer = torch.optim.Adam(params=[pose, pos_dir, trans], lr = 0.01)   
    for i in range(epoch):
        optimizer.zero_grad()

        res = s_model(betas=betas.view(1, -1), body_pose=pose.view(1, -1), global_orient=pos_dir.view(1, -1), transl=trans.view(1,-1), return_verts=True, return_full_pose=True)
        d1, d2 = chamfer_dist(res.vertices.float(), ply_vertices.view(1,-1,3).float())
        loss = (torch.mean(d1)) + (torch.mean(d2))

        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print("epch: %d,  loss: %f" % (i, loss.item()))

    # Re-run the model with the final parameters and export the mesh as OBJ.
    res = s_model(betas=betas.view(1, -1), body_pose=pose.view(1, -1), global_orient=pos_dir.view(1, -1), transl=trans.view(1,-1), return_verts=True, return_full_pose=True)
    write_obj(res.vertices.view(-1,3), s_model.faces, "./%s_smpl/%06d.obj"% (seq, idx))