In [22]:
import yaml
import os.path as osp
import os

import numpy as np
import matplotlib.pyplot as plt
import igl
from tqdm import tqdm

import torch

from preprocess import mesh_sampling_method
from dataset import MeshData
from models import VAE_coma, GraphPredictor

In [23]:
# Read the experiment configuration.
config_path = 'config/general_config.yaml'
# The handle is deliberately NOT called `f`: a later cell binds `f` to the
# template mesh faces, and re-running this cell would otherwise clobber them.
with open(config_path, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Set the device; a single GPU (or CPU fallback) is assumed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = config["dataset"]
template = config["template"]
data_dir = osp.join('data', dataset)
template_fp = osp.join('template', template)

# Build the per-level graph connectivity and down/up-sampling matrices.
ds_factor = config["ds_factor"]
edge_index_list, down_transform_list, up_transform_list = mesh_sampling_method(data_fp=data_dir,
                                                                                template_fp=template_fp,
                                                                                ds_factors=ds_factor,
                                                                                device=device)

# Model hyper-parameters shared by the predictor and the VAE.
in_channels = config["model"]["in_channels"]
out_channels = config["model"]["out_channels"]
latent_channels = config["model"]["latent_channels"]
K = config["model"]["K"]

# Mean and std of the CAESAR dataset (training set), used to de-normalize decoded meshes.
CAESAR_meshdata = MeshData(root=data_dir, template_fp=template_fp)
# torch.Tensor statistics; per the original note, of shape (10002, 3).
caesar_mean = CAESAR_meshdata.mean
caesar_std = CAESAR_meshdata.std

Normalizing the training and testing dataset...
Normalization Done!


## Load the predictor and the CoMA model

In [24]:
# Graph predictor built on the same mesh hierarchy as the VAE encoder.
predictor = GraphPredictor(in_channels=in_channels, out_channels=out_channels,
                           edge_index=edge_index_list, down_transform=down_transform_list,
                           K=K).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (`device` falls back to CPU when CUDA is unavailable).
predictor.load_state_dict(torch.load("predictor_network/predictor.pth", map_location=device))

# Normalization statistics (mean, std) for each body measurement.
# NOTE(review): values look like millimetres from the CAESAR training set —
# confirm their source and units before reusing them.
height_mean = 1716.4373
height_std = 107.98842

arm_length_mean = 612.611
arm_length_std = 45.986

crotch_height_mean = 773.540
crotch_height_std = 55.731

chest_mean = 996.745
chest_std = 124.099

hip_mean = 1050.170
hip_std = 113.026

waist_mean = 848.005
waist_std = 144.338

In [25]:
model = VAE_coma(in_channels=in_channels,
                 out_channels=out_channels,
                 latent_channels=latent_channels,
                 edge_index=edge_index_list,
                 down_transform=down_transform_list,
                 up_transform=up_transform_list,
                 K=K).to(device)

# Trained CoMA VAE checkpoint. An alternative run, kept for reference:
#   out/vae_coma/trainer12/20231220-123538/model.pth
# map_location lets a GPU-saved checkpoint load when running on CPU.
model.load_state_dict(torch.load("out/vae_coma/trainer6/20240117-190657/model.pth", map_location=device))

<All keys matched successfully>

## Define the decoder function

In [26]:
def decode_func(x, model):
    """Run a latent code through ``model.de_layers`` one layer at a time.

    * The first layer is applied to ``x`` alone. Its output is reshaped to
      ``(-1, model.num_vert, model.out_channels[-1])``.
    * Each intermediate layer ``idx`` receives ``model.edge_index[level]`` and
      ``model.up_transform[level]``, with ``level = len(de_layers) - 2 - idx``.
      The levels therefore count down to 0 as decoding proceeds.
    * The last layer receives only ``model.edge_index[0]``.

    Args:
        x: latent input for the first decoder layer.
        model: object exposing ``de_layers``, ``num_vert``, ``out_channels``,
            ``edge_index`` and ``up_transform``.

    Returns:
        The output of the final decoder layer.
    """
    n_layers = len(model.de_layers)
    final_idx = n_layers - 1
    coarsest_level = n_layers - 2
    out = x
    for idx, layer in enumerate(model.de_layers):
        if idx == 0:
            out = layer(out).view(-1, model.num_vert, model.out_channels[-1])
            continue
        if idx == final_idx:
            out = layer(out, model.edge_index[0])
            continue
        level = coarsest_level - idx
        out = layer(out, model.edge_index[level], model.up_transform[level])
    return out

## Generate mesh samples

In [27]:
# Read the template mesh; its faces give the connectivity of every generated mesh.
v, f = igl.read_triangle_mesh(template_fp)
# Keep a uniquely named handle on the template faces. The generic name `f` is
# easily rebound, e.g. by a `with open(...) as f` when an earlier cell is re-run.
template_faces = f


def generate_a_mesh(model, latent_vector: np.ndarray, mean, std, fp, name, faces=None):
    """Decode one latent vector into a mesh and write it to ``osp.join(fp, name)``.

    Processing steps:
    1. Decode with ``decode_func``.
    2. De-normalize with ``mean``/``std``.
    3. Reshape to ``(-1, 3)`` vertex coordinates.
    4. Translate along z so the lowest vertex sits at z = 0.
    5. Write the mesh with ``faces`` as connectivity.

    Args:
        model: decoder model passed to ``decode_func``.
        latent_vector: latent code; reshaped to ``(1, -1)`` before decoding.
        mean: torch tensor used to undo the dataset normalization
            (moved to ``device``).
        std: torch tensor used to undo the dataset normalization
            (moved to ``device``).
        fp: output directory; created if it does not exist.
        name: output file name, e.g. ``"0.0.obj"``.
        faces: optional face array; defaults to the template faces read above.
    """
    if faces is None:
        faces = template_faces
    mean = mean.to(device)
    std = std.to(device)

    # Use the notebook-wide `device` rather than `.cuda()` so this also runs on CPU.
    latent = torch.as_tensor(latent_vector, dtype=torch.float32, device=device).reshape(1, -1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        verts = decode_func(latent, model)
    verts = (verts * std + mean).cpu().numpy().reshape(-1, 3)

    # Translate along z so the lowest vertex rests at z = 0.
    verts[:, 2] -= verts[:, 2].min()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(fp, exist_ok=True)
    igl.write_triangle_mesh(osp.join(fp, name), verts, faces)

In [28]:
# Sweep a single latent dimension while all other dimensions stay at zero,
# writing one mesh per value. np.arange excludes its stop value, so this yields
# -4.5, -3.0, -1.5, 0.0, 1.5, 3.0, 4.5, 6.0 (step 1.5).
sweep_dim = 1
values = np.arange(-4.5, 7.5, 1.5)

# Size the code from the model config instead of a hard-coded 8 so it always
# matches the decoder's expected latent width.
latent_vector = np.zeros(latent_channels)
for value in tqdm(values, desc="generating meshes"):
    latent_vector[sweep_dim] = value
    generate_a_mesh(model, latent_vector, caesar_mean, caesar_std,
                    "result/meshes/teaser2", "{:.1f}.obj".format(value))