In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

import trimesh

# Trimesh scene used by the visualisation cells below (starts with no geometry).
scene = trimesh.Scene(geometry=None)

In [None]:
import my_code.datasets.shape_dataset as shape_dataset

# Both splits share every option except `phase`; building them from a single
# helper keeps the train and test configuration from drifting apart.
SCAPE_ROOT = 'data/SCAPE_r'
SCAPE_DIFFUSION_DIR = f'{SCAPE_ROOT}/diffusion'


def _make_scape_dataset(phase):
    """Build one SingleScapeDataset split with the notebook's shared settings.

    Args:
        phase: split name passed through to SingleScapeDataset ('train' / 'test').

    Returns:
        The constructed dataset.
    """
    return shape_dataset.SingleScapeDataset(
        phase=phase,
        data_root=SCAPE_ROOT,
        centering='bbox',
        num_evecs=200,
        lb_cache_dir=SCAPE_DIFFUSION_DIR,
    )


train_dataset = _make_scape_dataset('train')
test_dataset = _make_scape_dataset('test')

In [None]:
def _materialize(dataset):
    """Eagerly load every item of `dataset`, in index order, into a list."""
    return list(map(dataset.__getitem__, range(len(dataset))))


train_shapes = _materialize(train_dataset)
test_shapes = _materialize(test_dataset)

# Both splits point at the same diffusion cache directory.
train_diff_folder = test_diff_folder = 'data/SCAPE_r/diffusion'

In [None]:
# Preview N_PREVIEW random training shapes (row y=0) and test shapes (row y=-1),
# each shifted along x by its position in the row so the meshes do not overlap.
N_PREVIEW = 5

# Start from a fresh scene: `scene.geometry.clear()` only empties the geometry
# dict and leaves the scene-graph nodes that referenced that geometry behind.
scene = trimesh.Scene()

# Sample without replacement so no shape is shown twice; cap at the split size
# so a split with fewer than N_PREVIEW shapes still works.
rand_idx_train = np.random.choice(
    len(train_shapes), min(N_PREVIEW, len(train_shapes)), replace=False)
rand_idx_test = np.random.choice(
    len(test_shapes), min(N_PREVIEW, len(test_shapes)), replace=False)

for shapes, indices, y_offset in ((train_shapes, rand_idx_train, 0),
                                  (test_shapes, rand_idx_test, -1)):
    for i, idx in enumerate(indices):
        shape = shapes[idx]
        scene.add_geometry(trimesh.Trimesh(
            vertices=shape['verts'] + torch.tensor([i, y_offset, 0]),
            faces=shape['faces']))

axis = trimesh.creation.axis(axis_length=1)
scene.add_geometry(axis)
scene.show()

In [None]:
import networks.diffusion_network as diffusion_network

condition_dim = 0  # not used in this cell
start_dim = 0      # index of the first eigenvector fed to the sign network

feature_dim = 64       # number of eigenvectors whose signs are predicted
evecs_per_support = 4  # out_channels = feature_dim // evecs_per_support

# The integer division below would silently truncate out_channels otherwise.
if feature_dim % evecs_per_support != 0:
    raise ValueError(
        f'feature_dim ({feature_dim}) must be divisible by '
        f'evecs_per_support ({evecs_per_support})')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = diffusion_network.DiffusionNet(
    in_channels=feature_dim,
    out_channels=feature_dim // evecs_per_support,
    cache_dir=None,
    input_type='wks',
    k_eig=128,
    n_block=6
    ).to(device)

opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# LinearLR scales the lr linearly from 1.0x to 0.1x over the first 50k
# scheduler.step() calls and then holds it at 0.1x (this is a linear ramp,
# not a step decay).
scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1, end_factor=0.1, total_iters=50000)

In [None]:
# Must match the `input_type` the network was constructed with.
input_type = 'wks'

# Absolute path on the original author's machine; adjust to your setup.
CHECKPOINT_PATH = '/home/s94zalek_hpc/shape_matching/my_code/experiments/sign_double_start_0_feat_64_6block_factor4_180xyz_09_11_wks/39360.pth'
# Alternative checkpoints:
# '/home/s94zalek_hpc/shape_matching/notebooks/03.07.2024/sign_double_start_0_feat_64_6ch_180xyz_09_11_factor4.pth'
# '/home/s94zalek_hpc/shape_matching/my_code/experiments/sign_double_start_0_feat_64_6block_factor4_180xyz_09_11_noise0.01_meshLapl_wks/39360.pth'

# map_location remaps stored tensors onto `device`, so a checkpoint that was
# saved from a CUDA device still loads when `device` fell back to 'cpu'.
net.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

In [None]:
from tqdm import tqdm
import utils.geometry_util as geometry_util
import robust_laplacian
import scipy.sparse.linalg as sla
import my_code.sign_canonicalization.training as sign_training

# Sign-consistency evaluation.
# 1. Flip each shape's eigenvector signs with two independent random patterns.
# 2. Predict a sign vector for each flipped copy.
# 3. Check that the predicted sign *difference* matches the ground-truth one.

tqdm._instances.clear()  # drop stale progress bars left over from earlier runs

shapes_to_test = test_shapes
net.cache_dir = test_diff_folder

# shapes_to_test = train_shapes
# net.cache_dir = train_diff_folder

n_trials_target = 1000  # total trials, rounded down to whole passes over the shapes
eval_seed = 0           # seeds the random sign patterns so results are reproducible

if len(shapes_to_test) == 0:
    raise ValueError('shapes_to_test is empty; nothing to evaluate')
# Always make at least one full pass, even when there are more shapes than trials.
n_epochs = max(1, n_trials_target // len(shapes_to_test))


def _random_signs(n, generator, device):
    """Return a (1, n) float tensor of independent random +1/-1 entries on `device`."""
    bits = torch.randint(0, 2, (n,), generator=generator)
    return (bits * 2 - 1).float().unsqueeze(0).to(device)


def _predict_signs(net, verts, faces, evecs_flipped, input_type):
    """Run sign_training.predict_sign_change without gradient tracking.

    Returns the 3-tuple produced by predict_sign_change (sign prediction,
    support vectors, and a third value unused here).
    """
    with torch.no_grad():
        return sign_training.predict_sign_change(
            net, verts, faces, evecs_flipped,
            evecs_cond=None, input_type=input_type)


sign_generator = torch.Generator().manual_seed(eval_seed)
incorrect_counts = []
running_incorrect = 0
curr_iter = 0  # not used in this cell

# torch.no_grad() does not disable dropout or other train-only behaviour;
# switch to eval mode for the measurement and restore the previous mode after.
was_training = net.training
net.eval()

iterator = tqdm(total=n_epochs * len(shapes_to_test))
try:
    for epoch in range(n_epochs):
        for test_shape in shapes_to_test:
            verts = test_shape['verts'].unsqueeze(0).to(device)
            faces = test_shape['faces'].unsqueeze(0).to(device)
            # [1, V, feature_dim]: the eigenvectors whose signs are evaluated
            evecs_orig = test_shape['evecs'][:, start_dim:start_dim + feature_dim].unsqueeze(0).to(device)

            # Two independent random sign flips of the same eigenvectors;
            # the (1, feature_dim) sign vector broadcasts over vertices.
            sign_gt_0 = _random_signs(feature_dim, sign_generator, device)
            evecs_flip_0 = evecs_orig * sign_gt_0
            sign_pred_0, supp_vec_0, _ = _predict_signs(net, verts, faces, evecs_flip_0, input_type)

            sign_gt_1 = _random_signs(feature_dim, sign_generator, device)
            evecs_flip_1 = evecs_orig * sign_gt_1
            sign_pred_1, supp_vec_1, _ = _predict_signs(net, verts, faces, evecs_flip_1, input_type)

            # Ground-truth vs predicted sign difference between the two copies.
            sign_diff_gt = sign_gt_1 * sign_gt_0
            sign_diff_pred = sign_pred_1 * sign_pred_0
            # +1 where they agree; -1 on disagreement, 0 if a prediction is exactly zero.
            sign_correct = sign_diff_pred.sign() * sign_diff_gt.sign()

            # Anything other than +1 is counted as incorrect (matches the summary below).
            count_incorrect_signs = int((sign_correct != 1).sum())
            incorrect_counts.append(count_incorrect_signs)
            running_incorrect += count_incorrect_signs

            mean_incorrect = running_incorrect / len(incorrect_counts)
            iterator.set_description(f'Mean incorrect signs {mean_incorrect:.2f} / {feature_dim}')
            iterator.update(1)
finally:
    iterator.close()
    net.train(was_training)

# Build the per-trial tensor once instead of growing it with torch.cat.
incorrect_signs_list = torch.tensor(incorrect_counts, dtype=torch.float32)

print(f'Results for {len(incorrect_signs_list)} trials '
      f'({n_epochs} passes over {len(shapes_to_test)} shapes)')
print(f'Incorrect signs per shape: {incorrect_signs_list.mean():.2f} / {feature_dim}')

print('Max incorrect signs', incorrect_signs_list.max())

print()
print('Last trial:')
print('GT', sign_diff_gt)
print('PRED', sign_diff_pred)
print('Correct', sign_correct)
print(f'Incorrect signs {torch.sum(sign_correct != 1)} / {feature_dim}')
print(incorrect_signs_list)

In [None]:
# Report how many parameters of `net` are trainable (requires_grad=True).
n_trainable_params = sum(param.numel() for param in net.parameters() if param.requires_grad)
print(f'Number of trainable parameters: {n_trainable_params}')

In [None]:
# Compare one support-vector channel across the two independently sign-flipped
# copies of the last shape evaluated in the loop above.
plot_channel = -2  # channel c in supp_vec_*[0, :, c]

fig, ax = plt.subplots()
ax.plot(supp_vec_0[0, :, plot_channel].cpu(), '-', label='flip 0')
ax.plot(supp_vec_1[0, :, plot_channel].cpu(), '-', label='flip 1')
ax.set(title=f'Support vector channel {plot_channel}',
       xlabel='vertex index', ylabel='value')
ax.legend()
plt.show()

In [None]:
# Colour the last evaluated shape: white by default, blue wherever the chosen
# support-vector channel of flip 0 has |value| above the threshold.
highlight_channel = -2
highlight_threshold = 0.015

# Fresh scene: `scene.geometry.clear()` only empties the geometry dict and
# leaves the scene-graph nodes that referenced that geometry behind.
scene = trimesh.Scene()

verts = test_shape['verts'].cpu().numpy()
faces = test_shape['faces'].cpu().numpy()

# One RGBA row per vertex, initialised to opaque white (all ones).
cmap = np.ones((verts.shape[0], 4))

# Build the mask as a numpy bool array before indexing the numpy colour array.
highlight = (supp_vec_0[0, :, highlight_channel].abs() > highlight_threshold).cpu().numpy()
# Zeroing R and G leaves (0, 0, 1, 1), i.e. blue.
cmap[highlight, :2] = 0

mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_colors=cmap)
scene.add_geometry(mesh)

scene.show()

In [None]:
# test the model with 1 summary per evec