In [1]:
import argparse
import datetime
import os
import sys
from enum import Enum

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from utils.general import instantiate_from_config
from torch.utils.tensorboard import SummaryWriter

from torch import nn

from utils.data import dict2device
import numpy as np
import json
import smplx



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_parser(**parser_kwargs):
    """Build the command-line parser used by this notebook.

    Any keyword arguments are forwarded unchanged to ``argparse.ArgumentParser``.

    Options:
        -b / --base        zero or more config paths (defaults to an empty list)
        -t / --test_mode   boolean flag, ``False`` unless given
        -p / --pretrained  optional checkpoint path; ``None`` when the option is
                           absent, ``True`` when it is given without a value

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    base_help = (
        "Paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`."
    )

    cli_parser = argparse.ArgumentParser(**parser_kwargs)
    cli_parser.add_argument(
        "-b", "--base",
        nargs="*", metavar="base_config.yaml", default=list(), help=base_help,
    )
    cli_parser.add_argument(
        "-t", "--test_mode",
        action='store_true', help="Only evaluate metrics from the checkpoint",
    )
    cli_parser.add_argument(
        "-p", "--pretrained",
        type=str, nargs="?", const=True, default=None,
        help="Load pretrained weights from the checkpoint",
    )
    return cli_parser

def generate_path_to_logs(config, opt, sequence_name):
    """Compose the log directory for the current run.

    The result is ``<config.logdir>/<experiment>/<sequence_name>-<timestamp>``.
    ``<experiment>`` is the last ``/``-separated component of the first
    ``opt.base`` entry with everything from ``.yaml`` onwards removed; the
    timestamp is the current local time formatted as ``%Y_%m-%d_%H-%M``.
    """
    config_file = opt.base[0].split('/')[-1]
    experiment_name = config_file.split('.yaml')[0]
    stamp = datetime.datetime.now().strftime("-%Y_%m-%d_%H-%M")
    return os.path.join(config.logdir, experiment_name, sequence_name + stamp)


def create_test_datasets(config):
    """Instantiate the test dataset and wrap it in a non-shuffling ``DataLoader``.

    The dataset is built from ``config.test_dataloader``; batch size and worker
    count are read from ``config.val_dataloader``.

    Returns:
        ``(test_dataset, test_dataloader)``.
    """
    test_dataset = instantiate_from_config(config.test_dataloader)

    # NOTE(review): loader settings come from the *val* section, not the test one — confirm intended.
    loader_cfg = config.val_dataloader
    loader_options = {
        "batch_size": loader_cfg.batch_size,
        "num_workers": loader_cfg.num_workers,
        "shuffle": False,
        "pin_memory": False,
        "drop_last": False,
    }
    test_dataloader = DataLoader(test_dataset, **loader_options)

    # len() of a DataLoader is what gets reported here.
    print("Test samples:", len(test_dataloader))
    return test_dataset, test_dataloader


def create_train_val_datasets(config):
    """Instantiate the training and validation datasets with their loaders.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader does neither. Both take batch size and worker count from
    their own config section.

    Returns:
        ``(train_dataset, train_dataloader, val_dataloader)`` — the validation
        dataset itself is not returned.
    """
    train_cfg = config.train_dataloader
    train_dataset = instantiate_from_config(train_cfg)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_cfg.batch_size,
        num_workers=train_cfg.num_workers,
        shuffle=True,
        pin_memory=False,
        drop_last=True,
    )
    print("Training samples:", len(train_dataloader))

    val_cfg = config.val_dataloader
    val_dataset = instantiate_from_config(val_cfg)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=val_cfg.batch_size,
        num_workers=val_cfg.num_workers,
        shuffle=False,
        pin_memory=False,
        drop_last=False,
    )
    print("Validation samples:", len(val_dataloader))

    return train_dataset, train_dataloader, val_dataloader


def setup_tensorboard_logger(runner, config, opt, sequence_name):
    """Create the run's log directory and attach a ``SummaryWriter`` to the runner.

    The directory comes from ``generate_path_to_logs``; ``-test`` is appended
    when ``opt.test_mode`` is set. The directory must not exist yet
    (``exist_ok=False``), so an existing one raises ``FileExistsError``.
    The writer is stored as ``runner.logger``.
    """
    base_dir = generate_path_to_logs(config, opt, sequence_name)
    log_dir = base_dir + '-test' if opt.test_mode else base_dir
    os.makedirs(log_dir, exist_ok=False)
    runner.logger = SummaryWriter(log_dir)


def setup_callbacks(runner, config):
    """Instantiate every entry of ``config.callbacks`` and register them on the runner.

    Callbacks are created in the iteration order of ``config.callbacks.values()``
    and handed over in a single ``runner.set_callbacks`` call.
    """
    runner.set_callbacks(
        [instantiate_from_config(cb_config) for cb_config in config.callbacks.values()]
    )

In [3]:
parser = get_parser()

# The notebook has no real command line, so the arguments are spelled out here.
# The flag must be `--test_mode` (as declared in get_parser); any other spelling is
# not recognised by the parser and is returned in `unknown` instead of setting
# opt.test_mode.
cli_args = [
    '--base=./configs/gaussians_docker_male3.yaml',
    '--pretrained=./logs/gaussians_docker_male3/male-3-casual-2024_12-20_21-04/checkpoints/OPTIMIZE_OPACITY_10500.ckpt',
    '--test_mode',
]
opt, unknown = parser.parse_known_args(cli_args)

# Base configs are merged left-to-right; unrecognised CLI arguments are applied last
# as dot-list overrides.
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)

runner = instantiate_from_config(config.runner)
runner.to(runner.device)

setup_callbacks(runner, config)
runner.load_checkpoint(opt.pretrained)
test_dataset, test_dataloader = create_test_datasets(config)
# With test_mode set, the log directory created here carries a '-test' suffix.
setup_tensorboard_logger(runner, config, opt, test_dataset.sequence_name)
runner.initialize_optimizable_pose(test_dataset)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /home/user/miniconda/envs/textured-avatar/lib/python3.8/site-packages/lpips/weights/v0.1/vgg.pth


  self._gaussian_to_face = torch.nn.Parameter(torch.range(0, npoints - 1, dtype=torch.long, device=self.device),


Loaded Gaussians: 13541
Test samples: 14


In [None]:
# Write the runner's _xyz / _color / _opacity / _scaling / _rotation attributes to ./male3.ply.
runner.save_ply('./male3.ply',runner._xyz, runner._color, runner._opacity, runner._scaling,runner._rotation)

In [4]:
def setTPoseAndZero(batch):
    """Zero the pose, translation and expression of the first batch entry.

    For each listed key of ``batch['smplx_params']``, index 0 along the leading
    dimension is overwritten in place with a zero tensor of the given shape;
    all other entries are left untouched.

    Returns:
        The same ``batch`` object that was passed in.
    """
    # key -> shape of the zero tensor assigned to entry 0
    zero_shapes = {
        'global_orient': 3,
        'transl': 3,
        'body_pose': [1, 63],
        'left_hand_pose': [1, 12],
        'right_hand_pose': [1, 12],
        'jaw_pose': [1, 3],
        'leye_pose': [1, 3],
        'reye_pose': [1, 3],
        'expression': [1, 10],
    }
    for key, shape in zero_shapes.items():
        batch['smplx_params'][key][0] = torch.zeros(shape)
    return batch

In [5]:
class TrainingStage(Enum):
    """Named training stages; each member's value is the same string as its name.

    NOTE(review): presumably these strings match the stage identifiers the runner
    uses (a "FINETUNE_POSE" literal and an OPTIMIZE_OPACITY_* checkpoint name
    appear elsewhere in this notebook) — confirm against the runner.
    """
    INIT_TEXTURE = "INIT_TEXTURE"
    OPTIMIZE_GAUSSIANS = "OPTIMIZE_GAUSSIANS"
    FINETUNE_TEXTURE = "FINETUNE_TEXTURE"
    OPTIMIZE_OPACITY = "OPTIMIZE_OPACITY"
    FINETUNE_POSE = "FINETUNE_POSE"

In [5]:
# Fetch the first batch of the test loader; `it` is kept so further batches can be pulled with next(it).
it = iter(test_dataloader)
batch = next(it)
# Disabled experiment: zero the first entry's pose and render it.
# batch = setTPoseAndZero(batch)
# data_dict = dict2device(batch, "cuda")
# output = runner.predict_smplx_vertices(batch, calc_gaussians=True)
# runner._render_frame(data_dict, "FINETUNE_POSE")
# for key in output.keys():
    # output[key] = output[key][0,:,:]


In [7]:
# Inspect the shape of the 'transl' entry of the current batch's smplx_params.
print(batch["smplx_params"]["transl"].shape)

torch.Size([4, 1, 3])


In [12]:
# Show the 'pid' entry of the current batch.
batch['pid']

['00000', '00001', '00002', '00003']

In [None]:
# Write the runner's _xyz / _color / _opacity / _scaling / _rotation attributes to ./male3.ply
# (same call as the earlier export cell).
runner.save_ply('./male3.ply',runner._xyz, runner._color, runner._opacity, runner._scaling,runner._rotation)

In [None]:
# Run the runner on the current batch with calc_gaussians=True, then keep only
# index 1 along the first dimension of every returned entry.
# NOTE(review): the [1,:,:] index assumes each entry is 3-D with at least two items
# in dim 0 — confirm against predict_smplx_vertices.
output = runner.predict_smplx_vertices(batch, calc_gaussians=True)
# runner.save_ply('./male3-test.ply',runner._xyz, runner._color, runner._opacity, runner._scaling,runner._rotation)
for key in output.keys():
    output[key] = output[key][1,:,:]

In [None]:
# List the keys available in `output`.
for k in output.keys():
    print(k)
    

In [None]:
# List the keys of `output`, then write its gaussians_* entries to ./male3-test.ply.
for k in output.keys():
    print(k)
runner.save_ply('./male3-test.ply',
                output['gaussians_xyz'],
                output['gaussians_colors'],
                output['gaussians_opacity'],
                output['gaussians_scales'],
                output['gaussians_rotations'])

In [None]:
# Load the checkpoint file at opt.pretrained into `data` and print its path.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
data = torch.load(opt.pretrained)
print(opt.pretrained)


In [51]:
# Rebind `data` to a fresh dict whose only entry is the runner's current state_dict.
data = {}
data['state_dict'] = runner.state_dict()

In [None]:
# Reduce the state dict to the entries kept for export and report their shapes.
# Entries whose key starts with one of these prefixes are dropped; the `_smplx*`
# entries are deliberately kept.
dropped_prefixes = ('_body_pose_dict', 'lpips', '_xyz_gradient', '_max')
state_dict = {
    name: tensor
    for name, tensor in data['state_dict'].items()
    if not name.startswith(dropped_prefixes)
}

# Expose the SMPL-X face indices under a short key (same tensor object, not a copy).
state_dict['_faces'] = state_dict['_smplx_model.faces_tensor']

for key, value in state_dict.items():
    print(f'{key}: {value.shape}')

# From here on `data` is the filtered state dict itself, no longer a {'state_dict': ...} wrapper.
data = state_dict


In [None]:
# Display the '_smplx_model.faces_tensor' entry of the filtered state dict.
state_dict['_smplx_model.faces_tensor']

In [53]:
def RemapToUnity(smplx_path, unity_path, data, gender='male', tolerance=5e-5):
    """Re-index ``data['_faces']`` from SMPL-X vertex order to Unity vertex order.

    A default SMPL-X mesh is generated, its X axis is mirrored, and every vertex
    is matched to the nearest vertex listed in the Unity export. A match only
    counts when the distance is below ``tolerance``. If every vertex is matched,
    the indices stored in ``data['_faces']`` are rewritten in place through the
    resulting mapping.

    Args:
        smplx_path: model folder passed to ``smplx.create``.
        unity_path: JSON file with a ``"vertices"`` list exported from Unity.
        data: dict holding an integer ``torch.Tensor`` under ``'_faces'``
            (assumed to index SMPL-X vertices — TODO confirm against caller).
        gender: gender passed to ``smplx.create``.
        tolerance: maximum distance for two vertices to count as the same point.

    Returns:
        ``data`` (same object, ``'_faces'`` updated in place) when all vertices
        were matched, otherwise ``None`` with ``data`` left untouched.
    """
    # get canonical vertices, flat_hand_mean=True to align with unity default pose
    smplx_model = smplx.create(smplx_path, model_type='smplx', gender=gender, flat_hand_mean=True)
    output = smplx_model(return_full_pose=True)
    canonical_vertices = output.vertices.detach().cpu().numpy().squeeze()
    # Mirror the X axis before matching; the multiplication yields a new array,
    # so the model output is not modified.
    canonical_vertices = canonical_vertices * np.array([-1.0, 1.0, 1.0], dtype=canonical_vertices.dtype)

    # get unity vertices from the file the caller asked for
    with open(unity_path, "r") as f:
        unity_vertices = np.array(json.load(f)['vertices'])

    # create mapping of canonical -> unity; -1 marks "no Unity vertex within tolerance"
    vertex_map = np.full(len(canonical_vertices), -1, dtype=np.int64)
    for i, canon_v in enumerate(canonical_vertices):
        distances = np.linalg.norm(unity_vertices - canon_v, axis=1)
        idx = np.argmin(distances)
        if distances[idx] < tolerance:
            vertex_map[i] = idx

    n_mapped = int(np.count_nonzero(vertex_map != -1))
    print(f"Mapped {n_mapped}/{len(canonical_vertices)} vertices successfully.")

    if n_mapped != len(canonical_vertices):
        return None

    # remap face vertices: the new indices are computed in full first, then
    # written back into the existing tensor (copy_ converts dtype and device).
    faces = data['_faces']
    remapped = vertex_map[faces.detach().cpu().numpy()]
    faces.copy_(torch.from_numpy(remapped))
    return data
        


    

In [8]:
from utils.data import pass_smplx_dict

def CheckAgainstSMPLX(smplx_path, params_dict, output):
    """Compare the vertices in ``output`` with a reference SMPL-X forward pass.

    An SMPL-X model is created on the GPU with all ``create_*`` options
    disabled, called with ``params_dict``, and each resulting vertex is matched
    to the closest entry of ``output['vertices']``. A match counts when the
    distance is below 5e-5. Only the number of matches is printed; nothing is
    returned.
    """
    model_params = {
        'model_path': smplx_path,
        'model_type': 'smplx',
        'gender': 'Male',
        'use_pca': True,
        'use_hands': True,
        'use_face': True,
        'num_pca_comps': 12,
        'use_face_contour': False,
        'flat_hand_mean': False,
        'dtype': torch.float32,
    }
    # Every create_<param> option is switched off.
    for param in ('global_orient', 'body_pose', 'betas', 'left_hand_pose', 'right_hand_pose',
                  'expression', 'jaw_pose', 'leye_pose', 'reye_pose', 'transl'):
        model_params[f'create_{param}'] = False

    reference_model = smplx.create(**model_params).cuda()
    reference_output = reference_model(**params_dict)
    reference_vertices = reference_output.vertices.detach().cpu().numpy().squeeze()

    predicted_vertices = output['vertices'].squeeze().detach().cpu().numpy()

    tolerance = 5e-5
    matches = []
    for ref_vertex in reference_vertices:
        # distance from this reference vertex to every predicted vertex
        distances = np.linalg.norm(predicted_vertices - ref_vertex, axis=1)
        nearest = np.argmin(distances)
        # -1 marks a reference vertex without a predicted vertex within tolerance
        matches.append(nearest if distances[nearest] < tolerance else -1)

    matched_count = sum(1 for m in matches if m != -1)
    print(f"Mapped {matched_count}/{len(reference_vertices)} vertices successfully.")

    # # remap face vertices
    # if len([i for i in map if i != -1])==len(canonical_vertices):
    #     faces = data['_faces'].cpu().numpy()
    #     for i,face in enumerate(faces):
    #         x,y,z = face
    #         data['_faces'][i][0] = map[x]
    #         data['_faces'][i][1] = map[y]
    #         data['_faces'][i][2] = map[z]
    #         pass
            
    #     return data
            
    # else:
    #     return None
        


    

In [None]:
# Reload the checkpoint and pull one fresh batch from a new test loader.
runner.load_checkpoint(opt.pretrained)
test_dataset, test_dataloader = create_test_datasets(config)
it = iter(test_dataloader)
batch = next(it)

# SMPL-X parameters of batch entry 1 on the GPU, plus the runner's betas.
smplx_params = {k: v[1, ...].cuda() for k, v in batch['smplx_params'].items()}
smplx_params["betas"] = runner._betas

# Runner prediction for the whole batch, reduced to entry 1 as well.
output = runner.predict_smplx_vertices(batch, calc_gaussians=True)
for key in output.keys():
    output[key] = output[key][1, :, :].cuda()

smplx_path = '/mounted/home/dresden/repositories/HAHA/data'
# CheckAgainstSMPLX(smplx_path,smplx_params,output)

# Print the shape of every entry in both dicts.
for title, entries in (('smplx_params...', smplx_params), ('output...', output)):
    print(title)
    for k in entries:
        print(f'k: {k}, shape: {entries[k].shape}')

In [31]:
# Build a male SMPL-X model with flat_hand_mean=True and call it with no pose arguments.
# NOTE: this rebinds the notebook-level names `smplx_model` and `output`.
smplx_path = '/mounted/home/dresden/repositories/HAHA/data'
smplx_model = smplx.create(smplx_path, model_type='smplx', gender='male', flat_hand_mean=True)
output = smplx_model(return_full_pose=True)

In [None]:
# List the keys of `output` (the values are not used here).
for k, v in output.items():
    print(k)

In [None]:
# Size of dimension 1 of the 'full_pose' entry of `output`.
output['full_pose'].shape[1]

In [None]:
# Compare output['vertices'] against an SMPL-X forward pass driven by smplx_params
# (prints the number of matched vertices).
CheckAgainstSMPLX(smplx_path,smplx_params,output)

In [None]:
# List the keys of `output`; the PLY export below is currently disabled.
for k in output.keys():
    print(k)
# runner.save_ply('./male3-test.ply',
#                 output['gaussians_xyz'],
#                 output['gaussians_colors'],
#                 output['gaussians_opacity'],
#                 output['gaussians_scales'],
#                 output['gaussians_rotations'])

In [None]:
# Collect the SMPL-X parameters and the selected runner outputs in one dict.
data = dict(smplx_params)
data.update(
    vertices=output['vertices'],
    offset_xyz=output['offset_xyz'],
    xyz=output['gaussians_xyz'],
    rot=output['gaussians_rotations'],
    scale=output['gaussians_scales'],
    color=output['gaussians_colors'],
    opacity=output['gaussians_opacity'],
)

# The camera entries are removed from the export; a missing key raises KeyError.
del data['camera_matrix']
del data['camera_transform']

for k, entry in data.items():
    print(f'k: {k}, shape: {entry.shape}')

In [None]:
# remap faces
# NOTE(review): `data` must hold a '_faces' tensor at this point (e.g. the filtered
# state dict) — confirm which earlier cell is expected to have produced it.
print(data['_faces'])
test = RemapToUnity('/mounted/home/dresden/repositories/HAHA/data',
             '/mounted/home/dresden/repositories/HAHA/ExportedVertices.json',
             data)
# RemapToUnity returns None when not every vertex could be matched; fail with a
# clear message instead of a TypeError from subscripting None.
if test is None:
    raise RuntimeError("RemapToUnity could not match every SMPL-X vertex to a Unity vertex")
print(test['_faces'])

In [None]:
# Show the current '_faces' entry of `data`.
print(data['_faces'])

In [21]:
# Write `data` to JSON. Tensors are converted to nested lists in a separate dict,
# so `data` itself keeps its tensors and later cells that index or cast
# data['_faces'] as a tensor still work.
serializable = {
    key: value.detach().cpu().tolist() if isinstance(value, torch.Tensor) else value
    for key, value in data.items()
}

with open("sample_output_dict.json", 'w') as f:
    json.dump(serializable, f)

In [31]:
# Build a female SMPL-X model (flat_hand_mean=True), call it with no pose arguments
# and keep its vertices both as a tensor (`verts`) and as a nested list.
# NOTE: rebinds the notebook-level names `smplx_model` and `output`.
smplx_model = smplx.create('/mounted/home/dresden/repositories/HAHA/data', model_type='smplx', gender='female', flat_hand_mean=True)
output = smplx_model(return_full_pose=True)
verts = output.vertices
canonical_vertices = verts.detach().cpu().numpy().squeeze().tolist()

In [None]:
# Rebuild the same female SMPL-X model and list the keys of its output.
smplx_model = smplx.create('/mounted/home/dresden/repositories/HAHA/data', model_type='smplx', gender='female', flat_hand_mean=True)
output = smplx_model(return_full_pose=True)
for k in output.keys():
    print(k)

In [21]:
# Write the nested vertex list `canonical_vertices` to smplxVerts.json.
with open("smplxVerts.json",'w') as f:
    json.dump(canonical_vertices,f)

In [None]:
# First row of data['_faces']; the [0,:] index requires a tensor/array, not a nested list.
data['_faces'][0,:]

In [None]:
# Drop size-1 dimensions from `verts` (rebinds the name).
verts = verts.squeeze()

In [53]:
# Gather the vertices referenced by each face and form difference vectors between them.
# presumably data['_faces'] is (F, 3) and verts is (V, 3), so sampled is (F, 3, 3) — TODO confirm
faceVerts = data['_faces'].long()
sampled = verts[faceVerts]
vec1 = sampled[:, 2] - sampled[:, 1]  # third minus second face vertex
vec2 = sampled[:, 0] - sampled[:, 1]  # first minus second face vertex
vec3 = sampled[:, 0] - sampled[:, 2]  # first minus third face vertex

In [None]:
# First row of sampled[:, 2].
sampled[:, 2][0,:]

In [None]:
# updating the modified state_dict used in unity

with open('updated_dict.json', 'r') as f:
    data = json.load(f)

# JSON stores the faces as nested lists; RemapToUnity works on a tensor.
data['_faces'] = torch.tensor(data['_faces'], dtype=torch.int32)

print(data['_faces'])
result = RemapToUnity('/mounted/home/dresden/repositories/HAHA/data',
             '/mounted/home/dresden/repositories/HAHA/ExportedVertices.json',
             data)
# RemapToUnity returns None when not every vertex could be matched; fail with a
# clear message instead of a TypeError from subscripting None.
if result is None:
    raise RuntimeError("RemapToUnity could not match every SMPL-X vertex to a Unity vertex")
print(result['_faces'])

In [None]:
# Load the raw female SMPL-X model file with numpy.
# NOTE(review): the path has no file extension — confirm the file name on disk
# (and whether np.load needs allow_pickle for this file).
model = np.load('/mounted/home/dresden/repositories/HAHA/data/smplx/SMPLX_FEMALE')

In [None]:
# List the attribute names of the current `smplx_model` object.
print(dir(smplx_model))

In [None]:
# Function to compute centroids for a set of faces and vertices
def compute_centroids(vertices, faces):
    """Return one centroid per face as a single NumPy array.

    For every item of ``faces`` the first three entries are used to index
    ``vertices``; the centroid is the sum of those three vertices divided by 3.
    """
    return np.array([
        (vertices[face[0]] + vertices[face[1]] + vertices[face[2]]) / 3.0
        for face in faces
    ])
# Compute centroids for SMPL-X and Unity faces
# NOTE(review): smplx_vertices, smplx_faces, unity_vertices and unity_faces are not
# defined in the visible part of this notebook — they must exist before this runs.
smplx_centroids = compute_centroids(smplx_vertices, smplx_faces)
unity_centroids = compute_centroids(unity_vertices, unity_faces)
# Function to remap SMPL-X faces based on normals and centroids
def remap_faces_refined(smplx_normals, smplx_centroids, smplx_faces, unity_normals, unity_centroids, unity_faces, w_normal=1.0, w_centroid=1.0):
    """Pick, for every SMPL-X face, the Unity face with the smallest weighted distance.

    The distance between two faces is
    ``w_normal * ||normal difference|| + w_centroid * ||centroid difference||``.
    The three SMPL-X sequences are walked in parallel (stopping at the shortest),
    and the selected rows of ``unity_faces`` are returned as one array.
    """
    def closest_unity_face(normal, centroid):
        # weighted distance from this SMPL-X face to every Unity face
        cost = (
            w_normal * np.linalg.norm(unity_normals - normal, axis=1)
            + w_centroid * np.linalg.norm(unity_centroids - centroid, axis=1)
        )
        return unity_faces[np.argmin(cost)]

    return np.array([
        closest_unity_face(normal, centroid)
        for normal, centroid, _face in zip(smplx_normals, smplx_centroids, smplx_faces)
    ])
# Perform the refined remapping
# NOTE(review): smplx_normals and unity_normals (like the vertex/face arrays) are not
# defined in the visible part of this notebook — they must exist before this runs.
remapped_faces_refined = remap_faces_refined(
    smplx_normals, smplx_centroids, smplx_faces, unity_normals, unity_centroids, unity_faces, w_normal=1.0, w_centroid=1.0
)
# Save the refined remapped faces to a file
# (one face per line, integer indices separated by commas)
remapped_faces_refined_path = 'RefinedRemappedUnityFaces.txt'
np.savetxt(remapped_faces_refined_path, remapped_faces_refined, fmt='%d', delimiter=',')
remapped_faces_refined_path

- pid
- smplx_params


In [27]:
class DataLogger:
    """Callback that exports the runner's Gaussians for three pose files.

    Three JSON pose files (stored as ``T``, ``A`` and ``C``) are loaded on
    construction. At the end of testing, each pose is passed through
    ``runner.predict_smplx_vertices`` and the result is written to
    ``<log_dir>/data/output_<name>.ply`` and ``output_<name>.json``.
    All other hooks are no-ops.
    """

    # Entries read from every pose JSON file, in the order they are stored.
    _POSE_KEYS = (
        "transl", "global_orient", "body_pose",
        "right_hand_pose", "left_hand_pose",
        "leye_pose", "reye_pose", "jaw_pose",
        "expression", "betas",
    )
    # Entries that are moved to the runner's device before prediction.
    _DEVICE_KEYS = ("transl", "global_orient", "body_pose", "right_hand_pose", "left_hand_pose")

    def __init__(self, tpath, apath, cpath):
        """Load the three pose files given by ``tpath``, ``apath`` and ``cpath``."""
        self.T = self.load_pose(tpath)
        self.A = self.load_pose(apath)
        self.C = self.load_pose(cpath)

    def load_pose(self, path):
        """Read one pose JSON file into a dict of tensors.

        Every pose entry becomes a float tensor with two leading singleton
        dimensions added. A fixed camera matrix and camera transform (created
        on the CUDA device) are added with one extra leading dimension.

        Args:
            path: JSON file containing all keys in ``_POSE_KEYS``.

        Returns:
            Dict with the pose entries plus ``camera_matrix`` and ``camera_transform``.
        """
        camera_matrix = torch.tensor([[[4.9270 * 3, 0.0000, -0.0519, 0.0000],
                                       [0.0000, 4.9415 * 3, 0.0000, 0.0000],
                                       [0.0000, 0.0000, 1.0001, -0.0101],
                                       [0.0000, 0.0000, 1.0000, 0.0000]]], device='cuda')
        camera_transform = torch.tensor([[[1., 0., 0., 0.],
                                          [0., 1., 0., 0.9],
                                          [0., 0., 1., 15.2122],
                                          [0., 0., 0., 1.]]], device="cuda")

        with open(path, 'r') as f:
            data = json.load(f)

        smplx_params = {
            key: torch.FloatTensor(data[key]).unsqueeze(0).unsqueeze(0)
            for key in self._POSE_KEYS
        }
        smplx_params["camera_matrix"] = camera_matrix.unsqueeze(0)
        smplx_params["camera_transform"] = camera_transform.unsqueeze(0)

        return smplx_params

    def _log_data_val(self, save_folder, batch, global_step):
        """Placeholder; does nothing."""
        pass

    def on_train_batch_end(
            self,
            runner,
            outputs
    ):
        """No-op hook."""
        pass

    def on_validation_batch_end(
            self,
            runner,
            outputs
    ):
        """No-op hook (validation logging is currently disabled)."""
        # save_folder = os.path.join(runner.logger.log_dir, 'val')
        # self._log_data_val(save_folder, outputs, runner.global_step)
        pass

    def on_test_batch_end(
            self,
            runner,
            outputs
    ):
        """No-op hook."""
        pass

    def on_test_end(self, runner):
        """Export the Gaussians predicted for the T, A and C poses.

        For each pose the selected entries are moved to ``runner.device``, the
        runner is evaluated without gradients, index 0 along the first
        dimension of every output entry is taken, and the result is written as
        a PLY file (via ``runner.save_ply``) and as JSON into
        ``<runner.logger.log_dir>/data``.
        """
        save_folder = os.path.join(runner.logger.log_dir, 'data')
        # The JSON files are written directly by this method, so the folder must exist.
        os.makedirs(save_folder, exist_ok=True)

        data_dict = {"pid": ["000"]}
        poses = (("T", self.T), ("A", self.A), ("C", self.C))

        for pose_name, smplx_params in poses:
            # Moved in place, so the stored pose dicts stay on the runner's device afterwards.
            for key in self._DEVICE_KEYS:
                smplx_params[key] = smplx_params[key].to(runner.device)

            data_dict["smplx_params"] = smplx_params

            with torch.no_grad():
                output = runner.predict_smplx_vertices(data_dict, calc_gaussians=True)

            for key in output.keys():
                output[key] = output[key][0, ...]

            runner.save_ply(os.path.join(save_folder, f'output_{pose_name}.ply'),
                            output['gaussians_xyz'],
                            output['gaussians_colors'],
                            output['gaussians_opacity'],
                            output['gaussians_scales'],
                            output['gaussians_rotations'])

            for key, value in output.items():
                if isinstance(value, torch.Tensor):
                    output[key] = value.cpu().tolist()
            # The mode belongs to open(), not to os.path.join().
            with open(os.path.join(save_folder, f'output_{pose_name}.json'), 'w') as f:
                json.dump(output, f)

In [25]:
# Build a DataLogger from the three pose JSON files (stored as d.T, d.A and d.C).
tpath = '/mounted/home/dresden/repositories/HAHA/T.json'
apath = '/mounted/home/dresden/repositories/HAHA/A.json'
cpath = '/mounted/home/dresden/repositories/HAHA/C.json'
d = DataLogger(tpath,apath,cpath)

In [17]:
# Shape of the 'transl' entry loaded from the T pose file.
d.T["transl"].shape

torch.Size([1, 1, 3])

In [28]:
# Export PLY and JSON files for the three poses into <runner.logger.log_dir>/data.
d.on_test_end(runner)

torch.Size([1, 13541, 3])
torch.Size([1, 13541, 4])
torch.Size([1, 13541, 1])
torch.Size([1, 13541, 3])
torch.Size([1, 13541, 3])


KeyboardInterrupt: 