In [1]:
import yaml
import os.path as osp

import numpy as np
import matplotlib.pyplot as plt
import igl
from tqdm import tqdm

import torch

from preprocess import mesh_sampling_method
from dataset import MeshData
from models import VAE_coma, GraphPredictor

In [2]:
# Read the experiment configuration.
config_path = 'config/general_config.yaml'
with open(config_path, 'r') as config_file:
    # The config is plain data, so the safe loader is sufficient and avoids
    # constructing Python-specific YAML tags.
    config = yaml.safe_load(config_file)

# Use a single GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = config["dataset"]
template = config["template"]
data_dir = osp.join('data', dataset)
template_fp = osp.join('template', template)

# Graph connectivity and the down/up-sampling operators for each mesh level.
ds_factor = config["ds_factor"]
edge_index_list, down_transform_list, up_transform_list = mesh_sampling_method(data_fp=data_dir,
                                                                                template_fp=template_fp,
                                                                                ds_factors=ds_factor,
                                                                                device=device)

# Model hyper-parameters shared by the VAE and the predictors.
in_channels = config["model"]["in_channels"]
out_channels = config["model"]["out_channels"]
latent_channels = config["model"]["latent_channels"]
K = config["model"]["K"]

# Mean and std of the CAESAR dataset (training set); used later to
# de-normalize decoded vertices.
CAESAR_meshdata = MeshData(root=data_dir, template_fp=template_fp)
# Per-vertex statistics, expected shape (10002, 3), torch.Tensor.
caesar_mean = CAESAR_meshdata.mean
caesar_std = CAESAR_meshdata.std

  torch.LongTensor([spmat.tocoo().row,


Normalizing...
Done!


## Load the height/weight predictors and the CoMA model

In [3]:
# Predictor that regresses (standardized) body weight from a decoded mesh.
# .eval() puts the network in inference mode (disables dropout / batch-norm
# statistic updates, a no-op if it has none); map_location lets a checkpoint
# saved on GPU load on a CPU-only machine, matching the device fallback above.
weight_predictor = GraphPredictor(in_channels=in_channels,
                                  out_channels=out_channels,
                                  edge_index=edge_index_list,
                                  down_transform=down_transform_list,
                                  K=K,
                                  type="weight").to(device).eval()
weight_predictor.load_state_dict(
    torch.load("predictor_network/weight/predictor.pth", map_location=device))
# Statistics used to de-standardize the weight predictor's output.
# NOTE(review): presumably the training-set mean/std of the weight target —
# TODO confirm, and consider reading them from the config instead.
weight_mean = 76.13506
weight_std = 19.45952

In [4]:
# Predictor that regresses (standardized) body height from a decoded mesh.
# .eval() puts the network in inference mode (a no-op if it has no dropout /
# batch-norm); map_location lets a GPU-saved checkpoint load on CPU.
height_predictor = GraphPredictor(in_channels=in_channels,
                                  out_channels=out_channels,
                                  edge_index=edge_index_list,
                                  down_transform=down_transform_list,
                                  K=K,
                                  type="height").to(device).eval()
height_predictor.load_state_dict(
    torch.load("predictor_network/height/predictor.pth", map_location=device))
# Statistics used to de-standardize the height predictor's output.
# NOTE(review): presumably the training-set mean/std of the height target
# (the later /1000 suggests millimetres) — TODO confirm.
height_mean = 1716.4373
height_std = 107.98842

In [5]:
# Generative CoMA VAE whose decoder is driven by decode_func below.
# .eval() selects inference mode; map_location lets a GPU-saved checkpoint
# load on a CPU-only machine.
model = VAE_coma(in_channels=in_channels,
                 out_channels=out_channels,
                 latent_channels=latent_channels,
                 edge_index=edge_index_list,
                 down_transform=down_transform_list,
                 up_transform=up_transform_list,
                 K=K).to(device).eval()
model.load_state_dict(
    torch.load("out/vae_coma/weight/20230921-005450/model.pth", map_location=device))

<All keys matched successfully>

## Define the decoder function

In [6]:
def decode_func(x, model):
    """Run a batch of latent codes through ``model``'s decoder layers.

    The first decoder layer is called on ``x`` alone and its output is
    reshaped to ``(-1, model.num_vert, model.out_channels[-1])``. Every
    intermediate layer is called with the ``edge_index`` / ``up_transform``
    entry of its level, with indices counting down toward 0. The final
    layer is called with ``model.edge_index[0]`` only.

    Args:
        x: latent batch fed to the first decoder layer.
        model: object exposing ``de_layers``, ``num_vert``, ``out_channels``,
            ``edge_index`` and ``up_transform`` (e.g. the VAE_coma model).

    Returns:
        The output of the last decoder layer.
    """
    final_idx = len(model.de_layers) - 1
    # Number of intermediate (up-sampling) blocks; their level index counts down.
    n_blocks = final_idx - 1
    for idx, layer in enumerate(model.de_layers):
        if idx == 0:
            x = layer(x).view(-1, model.num_vert, model.out_channels[-1])
            continue
        if idx == final_idx:
            # The last layer only needs the level-0 connectivity.
            x = layer(x, model.edge_index[0])
            continue
        level = n_blocks - idx
        x = layer(x, model.edge_index[level], model.up_transform[level])
    return x

## Predict height and weight along each latent dimension

In [7]:
def get_height_weight(model, sample_vals: np.ndarray,
                      height_mean: float, height_std: float, weight_mean: float, weight_std: float,
                      height_predictor, weight_predictor, n_latent: int = 8, device=None):
    """Sweep each latent dimension and predict body height and weight.

    For every latent dimension ``i`` in ``range(n_latent)`` a batch of latent
    codes is built that is zero everywhere except column ``i``, which takes
    the values of ``sample_vals``. The batch is decoded with ``decode_func``,
    the decoded meshes are passed to both predictors, and the predictor
    outputs are de-standardized with the given mean/std.

    Args:
        model: generative model whose decoder is run by ``decode_func``.
        sample_vals (np.ndarray): shape (n_samples,). The same values are swept
            along every latent dimension.
        height_mean, height_std: statistics that de-standardize the height
            predictor output.
        weight_mean, weight_std: statistics that de-standardize the weight
            predictor output.
        height_predictor, weight_predictor: networks mapping decoded meshes to
            standardized height / weight.
        n_latent (int): size of the latent space (default 8, as before).
        device: device for the latent batch; defaults to the device holding
            ``model``'s parameters.

    Returns:
        (height_vals, weight_vals), each an np.ndarray of shape
        (n_samples, n_latent). Heights are divided by 1000 (presumably
        mm -> m; confirm against the predictor's training targets).
    """
    sample_vals = np.asarray(sample_vals)
    n_samples = sample_vals.shape[0]
    if device is None:
        # The latent batch must live where the decoder's weights live.
        device = next(model.parameters()).device

    height_vals = np.zeros((n_samples, n_latent))
    weight_vals = np.zeros((n_samples, n_latent))
    # Loop-invariant: the swept values are the same for every dimension.
    sweep = torch.as_tensor(sample_vals, dtype=torch.float32)

    # Pure inference: skip building autograd graphs to save memory.
    with torch.no_grad():
        for i in tqdm(range(n_latent)):
            latent_val = torch.zeros((n_samples, n_latent))
            latent_val[:, i] = sweep
            latent_val = latent_val.to(device)

            # Decoded (normalized) vertices, shape (n_samples, num_vert, 3).
            v = decode_func(latent_val, model)

            weight = weight_predictor(v)
            height = height_predictor(v)

            # Undo the standardization of the predictor outputs.
            weight = weight.cpu().numpy().reshape(-1) * weight_std + weight_mean
            height = height.cpu().numpy().reshape(-1) * height_std + height_mean

            height_vals[:, i] = height / 1000
            weight_vals[:, i] = weight

    return height_vals, weight_vals

## Define the plot function

In [8]:
def plot_latent(filename: str, sample_vals: np.ndarray, height_vals: np.ndarray, type: str):
    """Save a 2x4 grid of scatter plots, one panel per latent dimension.

    Panel ``i`` plots ``sample_vals`` (x-axis) against ``height_vals[:, i]``
    (y-axis). Despite its name, ``height_vals`` holds whichever quantity
    ``type`` names (height or weight).

    Args:
        filename (str): output path; its parent directory is created if missing.
        sample_vals (np.ndarray): shape (n_samples,), the swept latent values.
        height_vals (np.ndarray): shape (n_samples, 8), one column per latent
            dimension.
        type (str): "weight" or "height"; selects the y-axis label and limits.

    Raises:
        ValueError: if ``type`` is not "weight"/"height", or if
            ``height_vals`` is not 2-D with shape (len(sample_vals), 8).
    """
    import os

    if type not in ["weight", "height"]:
        raise ValueError("type should be either weight or height")

    # Ensure that the dimensions match up correctly (a 1-D array is rejected
    # here instead of failing later with an IndexError).
    height_vals = np.asarray(height_vals)
    if height_vals.ndim != 2 or len(sample_vals) != height_vals.shape[0] or height_vals.shape[1] != 8:
        raise ValueError("Input dimensions are mismatched!")

    # Fixed y-limits keep the 8 panels comparable; points outside are clipped.
    ylim = {"weight": (60, 90), "height": (1.6, 1.8)}[type]

    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, axs = plt.subplots(2, 4, figsize=(23, 10), tight_layout=True)
    try:
        for i, ax in enumerate(axs.flat):
            ax.scatter(sample_vals, height_vals[:, i], marker='o')
            ax.set_xlim(-1, 1)
            ax.set_ylim(*ylim)
            ax.set_xlabel("latent value")
            ax.set_ylabel(type)
            ax.set_title("Latent dimension {}".format(i))
        fig.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

In [9]:
# Sweep 50 evenly spaced values in [-1, 1] along every latent dimension,
# predict height and weight for each decoded mesh, then plot both sweeps.
sample_vals = np.linspace(-1, 1, 50)
height_vals, weight_vals = get_height_weight(
    model,
    sample_vals,
    height_mean,
    height_std,
    weight_mean,
    weight_std,
    height_predictor,
    weight_predictor,
)
for out_path, vals, quantity in (
    ("result/weight/weight.pdf", weight_vals, "weight"),
    ("result/weight/height.pdf", height_vals, "height"),
):
    plot_latent(out_path, sample_vals, vals, type=quantity)

100%|██████████| 8/8 [00:01<00:00,  6.95it/s]


## Generate mesh samples

In [10]:
# Read the template mesh; its faces `f` are reused for every generated mesh.
v, f = igl.read_triangle_mesh(template_fp)

# Generate a batch of mesh samples from the latent space.
def generate_mesh_samples(model, latent_vals: np.ndarray, mean, std, latent_dim,
                          faces=None, n_latent: int = 8, out_dir: str = "result/weight/meshes"):
    """Decode latent codes swept along one dimension and save them as OBJ meshes.

    Builds a batch of latent codes that are zero except in column
    ``latent_dim``, which takes ``latent_vals``. The batch is decoded with
    ``decode_func`` and the vertices are de-normalized as ``v * std + mean``.
    Each mesh is shifted so its lowest z coordinate is 0 (presumably z is
    the vertical axis) and written to
    ``<out_dir>/latent<latent_dim>/<value:.2f>.obj``.

    Args:
        model: generative model whose decoder is run by ``decode_func``; the
            device of its parameters is used for the computation.
        latent_vals (np.ndarray): shape (n_samples,). Values for dimension
            ``latent_dim``.
        mean, std: torch tensors used to de-normalize decoded vertices.
        latent_dim (int): index of the latent dimension to sweep.
        faces: triangle faces of the written meshes; defaults to the template
            faces ``f`` read at the top of this cell.
        n_latent (int): size of the latent space (default 8, as before).
        out_dir (str): root output directory; the per-dimension subfolder is
            created if missing.
    """
    import os

    if faces is None:
        faces = f

    # Run on the model's own device instead of assuming CUDA is available.
    device = next(model.parameters()).device
    mean = mean.to(device)
    std = std.to(device)

    latent_vals = np.asarray(latent_vals)
    latent_vectors = torch.zeros((latent_vals.shape[0], n_latent))
    latent_vectors[:, latent_dim] = torch.as_tensor(latent_vals, dtype=torch.float32)
    latent_vectors = latent_vectors.to(device)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        # Decoded vertices, shape (n_samples, num_vert, 3), de-normalized.
        verts = decode_func(latent_vectors, model) * std + mean
    verts = verts.cpu().numpy()

    mesh_dir = "{}/latent{}".format(out_dir, latent_dim)
    # Without this the mesh writer cannot create the files.
    os.makedirs(mesh_dir, exist_ok=True)

    for val, vertex in zip(latent_vals, verts):
        # Move the lowest vertex to z = 0.
        vertex[:, 2] -= vertex[:, 2].min()
        igl.write_triangle_mesh("{}/{:.2f}.obj".format(mesh_dir, val), vertex, faces)

In [11]:
# Export meshes for seven evenly spaced values along latent dims 0 and 4.
latent_vals = np.array((-0.9, -0.6, -0.3, 0, 0.3, 0.6, 0.9))
for dim in (0, 4):
    generate_mesh_samples(model=model, latent_vals=latent_vals,
                          mean=caesar_mean, std=caesar_std, latent_dim=dim)