In [None]:
import os

# Set this BEFORE importing densematcher: if any densematcher module reads the flag at import
# time, assigning it afterwards would have no effect. Setting it early is harmless otherwise.
# Speeds up the model loading time by directly loading stuff onto GPU.
os.environ["INFERENCE"] = "1"

import numpy as np
import torch

from densematcher.model import MeshFeaturizer
from densematcher.utils import load_pytorch3d_mesh, recenter, get_groups_dmtx, get_uniform_SO3_RT
from densematcher.pyFM.mesh.geometry import geodesic_distmat_dijkstra, heat_geodmat_robust
from densematcher import diffusion_net
from densematcher.diffusion_net.utils import random_rotate_points, random_rotation_matrix

In [None]:
# Hyper-parameters. They are also baked into the checkpoint paths, so they must match the released weights.
width = 512       # number of channels in DiffusionNet
num_blocks = 8    # number of blocks in DiffusionNet
imsize = 384      # render resolution; slightly affects accuracy, but not much

# Build the featurizer from the FeatUp checkpoint plus the SD-DINO aggregation-network weights.
model = MeshFeaturizer(
    f"checkpoints/featup_imsize={imsize}_channelnorm=False_unitnorm=False_rotinv=True/final.ckpt",
    (3, 1),  # NOTE(review): presumably the view-grid layout; confirm against MeshFeaturizer's signature
    num_blocks,
    width,
    aggre_net_weights_folder="checkpoints/SDDINO_weights",
)


In [None]:
# Load the trained 3D-extractor (DiffusionNet) weights. The checkpoint's "state_dict" holds the whole
# training module, so keep only the "model.extractor_3d." entries and strip that prefix.
ckpt_file = f"checkpoints/exp_mvmatcher_imsize={imsize}_width={width}_nviews=3x1_wrecon=10.0_cutprob=0.5_blocks={num_blocks}_release_jitter=0.0/final.ckpt"
# map_location="cpu" makes loading independent of the device the checkpoint was saved on;
# load_state_dict copies the tensors into the parameters wherever they live.
# NOTE(review): torch>=2.6 defaults to weights_only=True; if this checkpoint stores non-tensor
# objects, pass weights_only=False (only for trusted files).
ckpt = torch.load(ckpt_file, map_location="cpu")
extractor_3d_prefix = "model.extractor_3d."  # same string for filtering and stripping, so keys stay consistent
state_dict = {
    key.removeprefix(extractor_3d_prefix): value
    for key, value in ckpt["state_dict"].items()
    if key.startswith(extractor_3d_prefix)
}
model.extractor_3d.load_state_dict(state_dict)

# move model to gpu in half precision
model.to("cuda:0").half()
# inference only: disable train-mode behaviour such as dropout so features are deterministic
model.eval()
model.extractor_2d.featurizer.mem_eff = True # tradeoff speed to save memory by forwarding one view at a time into 2D backbone

In [None]:
def get_mesh(instance, num_views=(1, 3), random_rotation=True):
    '''
    Load one object instance and prepare meshes, DiffusionNet operators and rendering cameras.

    args:
        instance: path to object folder; reads color_mesh.obj, simple_mesh.obj and groups.txt from it
        num_views: number of azimuth and elevations for the view grid. Does not count the north/south poles. Total #views is num_views[0] * num_views[1] + 2
        random_rotation: if True, randomly rotate the object (the same rotation is applied to both meshes)
    return:
        mesh_color: PyTorch3D Mesh with texture/color, recentered (and rotated if requested). Assets are normalized to 0.3 on the largest axis
        mesh_simp: PyTorch3D Mesh with remeshed geometry, transformed exactly like mesh_color
        groups: list of list of int, each sublist is a semantic group (one per non-blank line of groups.txt)
        groups_dmtx: [num_groups, num_groups] semantic distance matrix between semantic groups, $D_{semantic}$ referred to in the paper
        operators: tuple (frames, mass, L, evals, evecs, gradX, gradY) of DiffusionNet operators; L, gradX, gradY are densified
        cameras: list [Rs, ts] of camera rotation matrices and translations (extrinsics)
        geodesic_distance: [V, V] geodesic distance between vertices of mesh_simp
    '''
    mesh_color, _, _, _ = load_pytorch3d_mesh(f"{instance}/color_mesh.obj")
    mesh_simp, _, _, _ = load_pytorch3d_mesh(f"{instance}/simple_mesh.obj")

    # semantic groups is not used in inference, I place it here to illustrate how groups are computed
    groups = []
    with open(f"{instance}/groups.txt") as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                # skip blank lines instead of creating empty groups
                continue
            groups.append([int(tok) for tok in tokens])
    # geodesics are computed before recentering/rotation; rigid motions do not change them
    geodesic_distance = heat_geodmat_robust(
        mesh_simp.verts_packed().numpy(), mesh_simp.faces_packed().numpy()
    )  # [V, V] geodesic distance between vertices
    groups_dmtx = get_groups_dmtx(geodesic_distance, groups)  # [num_groups, num_groups] $D_{semantic}$

    # move both meshes bounding box center to origin
    recenter(mesh_color, mesh_simp)
    # camera distance: largest |bbox coordinate| times a random factor in [2.0, 2.5)
    bb = mesh_color.get_bounding_boxes()
    cam_dist = bb.abs().max() * (np.random.rand() * 0.5 + 2.0)

    # Compute DiffusionNet operators on the remeshed geometry (vertices AND faces).
    # They are computed before the random rotation below; frames aren't rotation invariant
    # but they aren't needed.
    operators = diffusion_net.geometry.get_operators(
        mesh_simp.verts_list()[0].cpu(),
        mesh_simp.faces_list()[0].cpu(),
        k_eig=128,  # default
        op_cache_dir=os.environ.get("OP_CACHE_DIR", None),
        normals=mesh_simp.verts_normals_list()[0],
    )
    frames, mass, L, evals, evecs, gradX, gradY = operators
    # convert to dense, since dataloader workers cannot pickle sparse tensors
    operators = (frames, mass, L.to_dense(), evals, evecs, gradX.to_dense(), gradY.to_dense())

    # do random rotation (identity otherwise); vertices are row vectors, so right-multiply
    if random_rotation:
        R_inv = torch.from_numpy(random_rotation_matrix()).type_as(mesh_simp.verts_packed())
    else:
        R_inv = torch.eye(3).to(frames)
    new_verts_color = torch.matmul(mesh_color.verts_padded(), R_inv)
    new_verts_simp = torch.matmul(mesh_simp.verts_padded(), R_inv)
    mesh_color = mesh_color.update_padded(new_verts_color)
    mesh_simp = mesh_simp.update_padded(new_verts_simp)

    # uniformly sample cameras around sphere
    Rs, ts, _, _ = get_uniform_SO3_RT(
        num_azimuth=num_views[0],
        num_elevation=num_views[1],
        distance=cam_dist,
        center=bb.mean(2),
    )
    cameras = [Rs, ts]
    return mesh_color, mesh_simp, groups, groups_dmtx, operators, cameras, geodesic_distance
