In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import trimesh

# Shared trimesh scene: a later cell clears it, adds the meshes to it and shows it.
scene = trimesh.Scene()

In [None]:
import my_code.diffusion_training.data_loading as data_loading

# Load the 'train' and 'test' splits of FAUST_orig. Only element [1] of each
# returned value is kept.
# NOTE(review): both splits go through `get_val_dataset`, including 'train' --
# confirm that the validation-style loader is what is wanted for the train split.
# The meaning of the positional argument `200` is not visible from this notebook.
train_dataset = data_loading.get_val_dataset(
    'FAUST_orig', 'train', 200, canonicalize_fmap=None
    )[1]
test_dataset = data_loading.get_val_dataset(
    'FAUST_orig', 'test', 200, canonicalize_fmap=None
)[1]

In [None]:
import os
import utils.geometry_util as geometry_util
import utils.shape_util as shape_util
from tqdm import tqdm

def load_cached_shapes(save_folder):
    """Load every ``.off`` mesh under ``save_folder`` together with its eigenvectors.

    Meshes are read from ``{save_folder}/meshes`` in sorted filename order. For
    each mesh ``geometry_util.get_operators`` is called with ``k=128`` and
    ``cache_dir={save_folder}/diffusion``; of its seven return values only the
    fifth one is kept (stored under the key ``'evecs'``).

    Args:
        save_folder: Directory that contains a ``meshes`` sub-folder.

    Returns:
        A tuple ``(shapes_list, diff_folder)``: a list with one dict per mesh
        (keys ``'verts'`` as float32 tensor, ``'faces'`` as int32 tensor and
        ``'evecs'``), and the path of the ``diffusion`` sub-folder that was
        passed as ``cache_dir``.
    """
    mesh_folder = f'{save_folder}/meshes'
    diff_folder = f'{save_folder}/diffusion'

    # collect the .off files in a deterministic (sorted) order
    off_names = [name for name in os.listdir(mesh_folder) if name.endswith('.off')]
    off_names.sort()

    shapes_list = []
    for name in tqdm(off_names):
        raw_verts, raw_faces = shape_util.read_shape(os.path.join(mesh_folder, name))
        vert_tensor = torch.tensor(raw_verts, dtype=torch.float32)
        face_tensor = torch.tensor(raw_faces, dtype=torch.int32)

        # get_operators returns seven values; only the fifth one is used here
        _, _, _, _, basis, _, _ = geometry_util.get_operators(
            vert_tensor, face_tensor, k=128, cache_dir=diff_folder
        )

        shapes_list.append({'verts': vert_tensor, 'faces': face_tensor, 'evecs': basis})

    return shapes_list, diff_folder

In [None]:
# Load the cached shapes of the train and test splits; both splits currently
# use the same folder name.
# NOTE(review): the dataset root is an absolute, user-specific path -- consider
# moving it into a configurable constant.
train_folder = 'FAUST_rot_xyz_180_scaling_0.9_1.1'
train_shapes, train_diff_folder = load_cached_shapes(
    '/home/s94zalek_hpc/shape_matching/data_sign_training/train/' + train_folder
)

test_folder = 'FAUST_rot_xyz_180_scaling_0.9_1.1'
test_shapes, test_diff_folder = load_cached_shapes(
    '/home/s94zalek_hpc/shape_matching/data_sign_training/test/' + test_folder
)

In [None]:
from utils.geometry_util import get_operators
import my_code.utils.plotting_utils as plotting_utils
import robust_laplacian
import scipy.sparse.linalg as sla
import utils.geometry_util as geometry_util
import potpourri3d as pp3d

# --- configuration ------------------------------------------------------------
evec_n = 60        # index of the eigenvector that is visualised (must be < feature_dim)
feature_dim = 64   # number of Laplacian eigenpairs computed per shape
shape_idx = 12     # entry of `test_dataset` that is analysed


def laplacian_eigenbasis(verts, faces, k, sigma=1e-8):
    """Compute ``k`` generalized Laplacian eigenpairs of a triangle mesh.

    Args:
        verts: torch tensor of vertex positions (converted with ``.numpy()``).
        faces: torch tensor of triangle indices (converted with ``.numpy()``).
        k: number of eigenpairs requested from ``scipy.sparse.linalg.eigsh``.
        sigma: shift passed to ``eigsh``; it selects shift-invert mode, i.e. the
            eigenvalues closest to ``sigma`` are returned.

    Returns:
        ``(L, M, evals, evecs)``: the Laplacian and mass matrix returned by
        ``robust_laplacian.mesh_laplacian``, the eigenvalues as returned by
        ``eigsh`` and the eigenvectors wrapped in a torch tensor.
    """
    L, M = robust_laplacian.mesh_laplacian(verts.numpy(), faces.numpy())
    evals, evecs = sla.eigsh(L, k, M, sigma=sigma)
    return L, M, evals, torch.tensor(evecs)


# remove any geometry currently stored in the shared scene
scene.geometry.clear()

# Alternatives that were tried here earlier and can be swapped in again:
#   * the cached shapes / eigenvectors from `test_shapes[...]`
#   * pp3d.cotan_laplacian + pp3d.vertex_areas instead of robust_laplacian
#   * robust_laplacian.point_cloud_laplacian instead of mesh_laplacian

# --- shape 0: the test shape as stored in the dataset ---------------------------
verts_0 = test_dataset[shape_idx]['second']['verts']
faces_0 = test_dataset[shape_idx]['second']['faces']

L_0, M_0, evals_0, evecs_0 = laplacian_eigenbasis(verts_0, faces_0, feature_dim)

# --- shape 1: the same shape after data augmentation ----------------------------
verts_1_orig = test_dataset[shape_idx]['second']['verts']
faces_1 = test_dataset[shape_idx]['second']['faces']

# NOTE(review): the augmentation is presumably random; fix a seed beforehand if
# the figures have to be reproducible -- confirm against data_augmentation.
verts_1 = geometry_util.data_augmentation(
    verts_1_orig.unsqueeze(0),
    rot_x=180.0, rot_y=180.0, rot_z=180.0,
    std=0,
    # scale_min=1, scale_max=1
    scale_min=0.9, scale_max=1.1
    )[0]

L_1, M_1, evals_1, evecs_1 = laplacian_eigenbasis(verts_1, faces_1, feature_dim)


# Per-vertex colours for both shapes: take eigenvector number `evec_n` of each
# basis, L2-normalise it and map the values through the 'bwr' colour map.
cmap_0, cmap_1 = (
    trimesh.visual.color.interpolate(
        torch.nn.functional.normalize(basis[:, evec_n], p=2, dim=0),
        'bwr',
    )
    for basis in (evecs_0, evecs_1)
)


# chng_by_evec = (evecs_0.abs() - evecs_1.abs()).abs().sum(dim=0)


fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: absolute value of eigenvector `evec_n` per vertex, for the shape
# as stored in the dataset (evecs_0) and for its augmented copy (evecs_1).
axs[0].plot(evecs_0[:, evec_n].abs().cpu().numpy(), label='original')
axs[0].plot(evecs_1[:, evec_n].abs().cpu().numpy(), label='augmented')
axs[0].set_title(f'|eigenvector {evec_n}| per vertex')
axs[0].set_xlabel('vertex index')
axs[0].set_ylabel('absolute value')
axs[0].legend()

# Right panel: least-squares solution C of  evecs_0 @ C = evecs_1.
C_orig_rot = torch.linalg.lstsq(evecs_0, evecs_1).solution
plotting_utils.plot_Cxy(fig, axs[1], C_orig_rot,
                        'C_orig_rot', 0, 64, show_grid=False, show_colorbar=False)


plt.show()


# Build one coloured mesh per shape. The second mesh is shifted by one unit
# along x, presumably so that the two shapes do not overlap in the viewer.
mesh_0 = trimesh.Trimesh(
    vertices=verts_0,
    faces=faces_0,
    vertex_colors=cmap_0[:len(verts_0)],
)
mesh_1 = trimesh.Trimesh(
    vertices=verts_1 + np.array([1, 0, 0]),
    faces=faces_1,
    vertex_colors=cmap_1[:len(verts_1)],
)

# coordinate axes of length 1, added after the two meshes
axis = trimesh.creation.axis(axis_length=1)

for _geom in (mesh_0, mesh_1, axis):
    scene.add_geometry(_geom)

scene.show()