In [None]:
import glob
import os
import pickle
import subprocess

import joblib
import numpy as np
import pyglet
import trimesh
from trimesh.viewer import SceneViewer

In [None]:
class SMPLModel():
    """NumPy implementation of the SMPL body model forward pass.

    Given axis-angle joint rotations (``pose``), shape coefficients (``beta``)
    and a global translation (``trans``), ``update`` deforms the template mesh
    with shape and pose blend shapes and poses it by linear blend skinning.

    Args:
        model_path: path to a pickled model dict with keys 'J_regressor',
            'weights', 'posedirs', 'v_template', 'shapedirs', 'f' and
            'kintree_table'.

    After ``update`` (also run at construction):
        verts: (V, 3) posed and translated vertices.
        J: joint locations regressed from ``verts``.
        R: per-joint (3, 3) rotation matrices of the current pose.
    """

    def __init__(self, model_path):
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted model files.
        with open(model_path, 'rb') as f:
            # latin1 allows reading pickles written under Python 2.
            params = pickle.load(f, encoding='latin1')
        self.J_regressor = params['J_regressor']
        self.weights = params['weights']
        self.posedirs = params['posedirs']
        self.v_template = params['v_template']
        self.shapedirs = params['shapedirs']
        self.faces = params['f']
        self.kintree_table = params['kintree_table']

        # kintree_table[1] holds joint ids and kintree_table[0] each joint's
        # parent id; map parent ids to column indices. Column 0 is the root
        # and has no entry in self.parent.
        id_to_col = {
            self.kintree_table[1, i]: i for i in range(self.kintree_table.shape[1])
        }
        self.parent = {
            i: id_to_col[self.kintree_table[0, i]]
            for i in range(1, self.kintree_table.shape[1])
        }

        # Joint count comes from the kinematic tree (24 for standard SMPL)
        # instead of being hard-coded.
        n_joints = self.kintree_table.shape[1]
        self.pose_shape = [n_joints, 3]
        self.beta_shape = [10]
        self.trans_shape = [3]

        self.pose = np.zeros(self.pose_shape)
        self.beta = np.zeros(self.beta_shape)
        self.trans = np.zeros(self.trans_shape)

        self.verts = None
        self.J = None
        self.R = None

        self.update()

    def set_params(self, pose=None, beta=None, trans=None):
        """Replace any of pose / beta / trans (None keeps the current value),
        recompute the mesh and return the new vertices."""
        if pose is not None:
            self.pose = pose
        if beta is not None:
            self.beta = beta
        if trans is not None:
            self.trans = trans
        self.update()
        return self.verts

    def update(self):
        """Recompute ``J``, ``R`` and ``verts`` from the current pose, beta and trans."""
        n_joints = self.kintree_table.shape[1]

        # Shape blend shapes: how beta deforms the template mesh.
        v_shaped = self.shapedirs.dot(self.beta) + self.v_template

        # Joint locations of the shaped body in the rest pose.
        self.J = self.J_regressor.dot(v_shaped)

        # One rotation matrix per joint from the axis-angle pose.
        self.R = self.rodrigues(self.pose.reshape((-1, 1, 3)))

        # Pose blend shapes are driven by (R - I) of every non-root joint.
        I_cube = np.broadcast_to(np.eye(3), (self.R.shape[0] - 1, 3, 3))
        lrotmin = (self.R[1:] - I_cube).ravel()
        v_posed = v_shaped + self.posedirs.dot(lrotmin)

        # World transform of each joint, accumulated down the kinematic tree.
        G = np.empty((n_joints, 4, 4))
        G[0] = self.with_zeros(np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
        for i in range(1, n_joints):
            parent = self.parent[i]
            offset = (self.J[i, :] - self.J[parent, :]).reshape([3, 1])
            G[i] = G[parent].dot(self.with_zeros(np.hstack([self.R[i], offset])))

        # Remove the rest-pose joint location so G maps rest-pose points
        # directly to posed points.
        J_h = np.hstack([self.J, np.zeros([n_joints, 1])]).reshape([n_joints, 4, 1])
        G = G - self.pack(np.matmul(G, J_h))

        # Linear blend skinning: per-vertex weighted sum of joint transforms.
        T = np.tensordot(self.weights, G, axes=[[1], [0]])
        rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
        v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
        self.verts = v + self.trans.reshape([1, 3])

        # Joints regressed from the final (posed, translated) vertices.
        self.J = self.J_regressor.dot(self.verts)

    def rodrigues(self, r):
        """Axis-angle vectors -> rotation matrices (Rodrigues' formula).

        Args:
            r: (K, 1, 3) array of axis-angle vectors.
        Returns:
            (K, 3, 3) rotation matrices.
        """
        theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
        # avoid zero divide: a zero vector gives r_hat = 0 and hence R = I
        theta = np.maximum(theta, np.finfo(np.float64).tiny)
        r_hat = r / theta
        cos = np.cos(theta)
        z_stick = np.zeros(theta.shape[0])

        # Skew-symmetric cross-product matrix of each unit axis.
        m = np.dstack([z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick, -r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick]).reshape([-1, 3, 3])

        i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0), [theta.shape[0], 3, 3])

        # Outer product of each axis with itself.
        A = np.transpose(r_hat, axes=[0, 2, 1])
        B = r_hat
        dot = np.matmul(A, B)
        R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m

        return R

    def with_zeros(self, x):
        """Append the homogeneous row [0, 0, 0, 1] to a (3, 4) matrix."""
        return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]])))

    def pack(self, x):
        """Turn (K, 4, 1) column vectors into (K, 4, 4) matrices zero except the last column."""
        return np.dstack((np.zeros((x.shape[0], 4, 3)), x))

In [None]:
def load_pose(model_path):
    """Load the 'pose' entry of a joblib file and return it as a (24, 3) array."""
    with open(model_path, 'rb') as handle:
        raw_pose = joblib.load(handle)['pose']
    return np.reshape(raw_pose, [24, 3])

def load_shape(model_path):
    """Load the 'beta' entry of a joblib file with singleton axes squeezed out."""
    with open(model_path, 'rb') as handle:
        raw_beta = joblib.load(handle)['beta']
    return np.squeeze(raw_beta)

def color(dataset):
    """Return the face colour used to render ``dataset``, chosen by name prefix.

    'GT' and the fallback are RGB floats in [0, 1]; the other entries are
    RGBA values in 0-255. The first matching prefix wins.
    """
    palette = (
        ('GT', [0.65, 0.74, 0.86]),
        ('Tang2018', [255, 0, 0, 255]),
        ('Ren2020', [0, 255, 0, 255]),
        ('Huang2021', [0, 0, 255, 255]),
        ('WGAN', [255, 0, 255, 255]),
        ('Remove_GW', [255, 255, 0, 255]),
    )
    for prefix, swatch in palette:
        if dataset.startswith(prefix):
            # Hand out a fresh list so callers can't mutate the table.
            return list(swatch)
    return [0.9, 0.7, 0.7]

In [None]:
class MeshWindow():
    """Render one result set's SMPL pose sequence with trimesh.

    Use ``show`` for an interactive viewer that steps through the frames, or
    ``save_frames`` followed by ``make_video`` to write numbered PNG frames and
    an MP4 offline.

    Args:
        dataset: result-set name; ``'<results_dir>/<dataset>.pkl'`` is loaded
            with joblib and must contain a per-frame ``'pose'`` sequence.
        results_dir: folder holding the result ``.pkl`` files; frames are
            written to ``'<results_dir>/Rendered/<dataset>'``.
        smpl_path: path of the pickled SMPL model file.
    """

    # Legacy class-level cursor, no longer advanced: each instance keeps its
    # own ``self.counter`` so separate windows don't skip each other's frames.
    counter = 0

    def __init__(self, dataset, results_dir='AAAI2022_Results',
                 smpl_path='SMPL/SMPLmodel/models/smpl/SMPL_NEUTRAL.pkl'):
        self.dataset = dataset
        self.colors = color(dataset)
        self.dir_save = '{}/Rendered/{}'.format(results_dir, dataset)
        os.makedirs(self.dir_save, exist_ok=True)

        self.smpl = SMPLModel(smpl_path)

        # NOTE(review): joblib.load unpickles, which can execute arbitrary
        # code; only load trusted result files.
        params = joblib.load('{}/{}.pkl'.format(results_dir, dataset))
        self.pose_params = params['pose']
        # Translation is not used yet (params['trans'] is ignored): one zero
        # row per frame, sized from the data rather than a fixed 250 frames.
        self.trans_params = np.zeros([len(self.pose_params), 3])
        # Index of the next frame shown by the interactive viewer.
        self.counter = 0

        # Placeholder mesh (initial pose, fully transparent faces) so the scene
        # and camera exist before the first frame is drawn.
        mesh = trimesh.Trimesh(self.smpl.verts, self.smpl.faces, face_colors=[0, 0, 0, 0])
        self.scene = trimesh.Scene(mesh)
        self.scene.set_camera(distance=4, center=[0, 0, 0], resolution=[720, 720])

    def _current_mesh(self):
        """Build a trimesh of the SMPL model's current vertices in this window's colour."""
        return trimesh.Trimesh(self.smpl.verts, self.smpl.faces, face_colors=self.colors)

    def _last_saved_frame(self):
        """Highest frame index among ``<dir_save>/NNN.png`` files, or 0 if there are none."""
        indices = []
        for path in glob.glob('{}/*.png'.format(glob.escape(self.dir_save))):
            stem = os.path.splitext(os.path.basename(path))[0]
            if stem.isdigit():
                indices.append(int(stem))
        # glob order is arbitrary, so take the max rather than the last entry.
        return max(indices, default=0)

    def update(self, scene):
        """SceneViewer callback: draw the next frame into ``scene``, looping at the end."""
        frame = self.counter
        self.smpl.pose = self.pose_params[frame]
        # NOTE(review): translations are scaled by 1/100 (presumably cm -> m; confirm).
        self.smpl.trans = self.trans_params[frame] / 100
        self.smpl.update()
        scene.geometry.clear()
        scene.add_geometry(self._current_mesh())
        self.counter = (frame + 1) % len(self.pose_params)

    def save_frames(self, max_retries=10):
        """Render every pose to ``<dir_save>/NNN.png``, resuming an earlier run.

        Starts at the highest frame index already on disk (re-rendering it in
        case it was only partially written). Each frame is attempted up to
        ``max_retries`` times: a rendering error, or ``save_image`` returning
        no data, triggers a retry, and the last error is re-raised once the
        retries are used up instead of looping forever.
        """
        attempts = max(1, max_retries)
        for i in range(self._last_saved_frame(), len(self.pose_params)):
            self.smpl.pose = self.pose_params[i]
            # NOTE(review): self.smpl.trans is not reset here, unlike update().
            self.smpl.update()
            file_name = '{}/{:03d}.png'.format(self.dir_save, i)
            for attempt in range(attempts):
                try:
                    self.scene.geometry.clear()
                    self.scene.add_geometry(self._current_mesh())
                    # save a render of the object as a png
                    png = self.scene.save_image(resolution=[720, 720], visible=True)
                    if png is None:
                        raise RuntimeError('no image data for frame {}'.format(i))
                    with open(file_name, 'wb') as f:
                        f.write(png)
                    break
                except Exception:
                    if attempt == attempts - 1:
                        raise

    def make_video(self, out_dir='Videos', fps=25):
        """Encode ``<dir_save>/NNN.png`` into ``<out_dir>/<dataset>.mp4`` with ffmpeg.

        Returns ffmpeg's exit code.
        """
        os.makedirs(out_dir, exist_ok=True)
        # Argument list, no shell: dataset names are never shell-interpreted.
        cmd = [
            'ffmpeg', '-y',
            '-i', '{}/%03d.png'.format(self.dir_save),
            '-c:v', 'libx264',
            '-vf', 'fps={},format=yuv420p'.format(fps),
            '{}/{}.mp4'.format(out_dir, self.dataset),
        ]
        return subprocess.run(cmd).returncode

    def show(self):
        """Open an interactive viewer that advances one frame per callback."""
        SceneViewer(self.scene, callback=self.update)

In [None]:
# Dataset names = result-file names without folder or extension (the old
# fixed slice i[17:-4] left a leading '/' for this 17-character folder name).
# Sorted because glob returns files in arbitrary order.
# NOTE(review): this lists IJCAI2022_Results, but MeshWindow reads
# 'AAAI2022_Results/<dataset>.pkl' by default; confirm which folder is intended.
all_datasets = sorted(
    os.path.splitext(os.path.basename(path))[0]
    for path in glob.glob('IJCAI2022_Results/*.pkl')
)
all_datasets

In [None]:
for dataset in all_datasets:
    # Only render datasets whose video has not been produced yet.
    if not os.path.exists('Videos/{}.mp4'.format(dataset)):
        mymesh = MeshWindow(dataset)
        # mymesh.show()  # interactive preview instead of offline rendering
        mymesh.save_frames()
        mymesh.make_video()