In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networks.diffusion_network as diffusion_network
import os
import utils.geometry_util as geometry_util
import utils.shape_util as shape_util
from tqdm import tqdm

import trimesh

# trimesh scene used to visually preview shapes in this notebook
scene = trimesh.Scene()

In [None]:
import os

start_dim = 0

feature_dim = 64
evecs_per_support = 4
n_block = 6

input_type = 'learned'
lapl_type = 'mesh'

train_folder = 'FAUST_rot_xyz_180_scaling_0.9_1.1'
test_folder = 'FAUST_rot_xyz_180_scaling_0.9_1.1'

# Fail early rather than deep inside the network: each support channel of `net`
# must cover a whole group of `evecs_per_support` eigenvectors.
if feature_dim % evecs_per_support != 0:
    raise ValueError(
        f'feature_dim ({feature_dim}) must be divisible by evecs_per_support ({evecs_per_support})')

chkpt_name = f'sign_double_start_{start_dim}_feat_{feature_dim}_{n_block}block_factor{evecs_per_support}_dataset_{train_folder}_{input_type}'

experiments_root = '/home/s94zalek_hpc/shape_matching/my_code/experiments'
experiment_dir = f'{experiments_root}/{chkpt_name}'
# exist_ok=True keeps this cell re-runnable. Note that a re-run then writes into
# (and may overwrite files of) an existing experiment directory with the same name.
os.makedirs(experiment_dir, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = diffusion_network.DiffusionNet(
    # NOTE(review): 256 matches cond_net's out_channels (the 'learned' input
    # features); confirm it still fits if input_type is changed.
    in_channels=256,
    # one output channel per group of `evecs_per_support` eigenvectors
    out_channels=feature_dim // evecs_per_support,
    cache_dir=None,
    # NOTE(review): hard-coded 'wks' rather than the `input_type` setting above — confirm intended
    input_type='wks',
    k_eig=128,
    n_block=n_block,
    ).to(device)

opt = torch.optim.Adam(net.parameters(), lr=1e-3)
# Linearly decays the learning rate from 1x to 0.1x over `total_iters` scheduler steps.
# NOTE(review): the training cell takes ~40001 steps, so with total_iters=50000 the
# LR only reaches ~0.28x of its initial value — confirm this is intended.
scheduler = torch.optim.lr_scheduler.LinearLR(
    opt, start_factor=1, end_factor=0.1,
    total_iters=50000)


In [None]:
import my_code.sign_canonicalization.training as sign_training

data_root = '/home/s94zalek_hpc/shape_matching/data_sign_training'

# Pass the configured `lapl_type` (instead of a hard-coded 'mesh') so the loaded data
# stays consistent with the training loop, which drops faces when lapl_type == 'pcl'.
# NOTE(review): presumably lapl_type selects which cached Laplacian data is loaded —
# confirm in sign_training.load_cached_shapes.
train_shapes, train_diff_folder = sign_training.load_cached_shapes(
    f'{data_root}/train/{train_folder}',
    lapl_type=lapl_type,
)

test_shapes, test_diff_folder = sign_training.load_cached_shapes(
    f'{data_root}/test/{test_folder}',
    lapl_type=lapl_type,
)

In [None]:
# Preview a few random training shapes (row y=0) and test shapes (row y=-1),
# shifted along x by their index so they do not overlap.
n_preview = 5

# Start from a fresh scene: clearing only `scene.geometry` leaves the old nodes
# in `scene.graph`.
scene = trimesh.Scene()

rand_idx_train = np.random.randint(0, len(train_shapes), n_preview)
rand_idx_test = np.random.randint(0, len(test_shapes), n_preview)

for row_y, shapes, indices in ((0, train_shapes, rand_idx_train),
                               (-1, test_shapes, rand_idx_test)):
    for i, idx in enumerate(indices):
        scene.add_geometry(trimesh.Trimesh(
            vertices=shapes[idx]['verts'] + torch.tensor([i, row_y, 0]),
            faces=shapes[idx]['faces']))

# coordinate axes for orientation
scene.add_geometry(trimesh.creation.axis(axis_length=1))
scene.show()

In [None]:

def predict_sign_change(net, verts, faces, evecs_flip, evecs_cond, input_type, cond_net=None):
    """Predict one signed score per eigenvector whose sign follows the eigenvector's sign.

    The eigenvectors are L2-normalised along dim 1. `net` predicts one or more
    support vectors on the same dim-1 axis. The score for eigenvector j is the
    dot product (over dim 1) of eigenvector j with the support vector assigned to it.
    When the network input does not depend on `evecs_flip` (input_type 'wks' or
    'learned' with no `evecs_cond`), negating eigenvector j negates score j.

    Args:
        net: callable invoked as ``net(verts=..., faces=..., feats=...)``. It must return
            a tensor laid out like ``evecs_flip`` (batch, dim 1, channels) whose channel
            count divides the number of eigenvectors.
        verts, faces: forwarded unchanged to ``net`` (and ``cond_net``).
        evecs_flip: (B, N, k) eigenvectors, possibly with random sign flips
            (this notebook passes shape (1, n_verts, feature_dim)).
        evecs_cond: optional tensor. If given, it is normalised along dim 1 and
            concatenated with ``evecs_flip`` (last dim) as the network input.
        input_type: used only when ``evecs_cond`` is None.
            'wks' -> feats=None; 'learned' -> feats from ``cond_net`` (L2-normalised
            along the last dim); anything else -> feats = normalised ``evecs_flip``.
        cond_net: callable ``cond_net(verts=..., faces=...)``; required for 'learned'.

    Returns:
        sign_flip_predicted: (B, k) per-eigenvector scores, i.e. the diagonal of
            ``product_with_support``.
        support_vector_norm: ``net`` output normalised along dim 1, before any
            channel repetition.
        product_with_support: (B, k, k) all pairwise dot products between the
            (repeated) support vectors and the eigenvectors.

    Raises:
        ValueError: if input_type is 'learned' and ``cond_net`` is None, or if the
            number of eigenvectors is not a multiple of the support channels.
    """
    # unit-normalise each eigenvector (over dim 1) so the scores are scale-independent
    evecs_flip = torch.nn.functional.normalize(evecs_flip, p=2, dim=1)

    if evecs_cond is not None:
        evecs_cond = torch.nn.functional.normalize(evecs_cond, p=2, dim=1)
        evecs_input = torch.cat([evecs_flip, evecs_cond], dim=-1)
    elif input_type == 'wks':
        # feats=None: `net` builds its own input features
        evecs_input = None
    elif input_type == 'learned':
        if cond_net is None:
            raise ValueError("input_type='learned' requires cond_net")
        evecs_input = cond_net(verts=verts, faces=faces)
        evecs_input = torch.nn.functional.normalize(evecs_input, dim=-1, p=2)
    else:
        evecs_input = evecs_flip

    support_vector_flip = net(
        verts=verts,
        faces=faces,
        feats=evecs_input,
    )
    support_vector_norm = torch.nn.functional.normalize(support_vector_flip, p=2, dim=1)

    n_evecs = evecs_flip.shape[-1]
    n_support = support_vector_norm.shape[-1]
    if n_evecs % n_support != 0:
        raise ValueError(
            f'number of eigenvectors ({n_evecs}) must be a multiple of the '
            f'number of support channels ({n_support})')

    # each support channel is shared by `n_evecs // n_support` consecutive eigenvectors
    support_vector_norm_repeated = torch.repeat_interleave(
        support_vector_norm, n_evecs // n_support, dim=-1)

    # (B, k, N) @ (B, N, k) -> (B, k, k); entry [j, j] pairs support j with eigenvector j.
    # After the repetition above this matrix is always square, so the diagonal is
    # always defined.
    product_with_support = support_vector_norm_repeated.transpose(1, 2) @ evecs_flip
    sign_flip_predicted = torch.diagonal(product_with_support, dim1=1, dim2=2)

    return sign_flip_predicted, support_vector_norm, product_with_support


In [None]:
# Pretrained, frozen feature extractor. With input_type == 'learned',
# predict_sign_change uses its output as the input features of `net`.
cond_net = diffusion_network.DiffusionNet(
    in_channels=128,
    out_channels=256,
    cache_dir=None,
    input_type='wks'
    ).to(device)

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
cond_ckpt = torch.load('/home/s94zalek_hpc/shape_matching/checkpoints/faust.pth',
                       map_location=device)
cond_net.load_state_dict(cond_ckpt['networks']['feature_extractor'])

# Freeze the extractor: no gradients for its parameters, and eval mode so layers
# such as dropout behave deterministically instead of in training mode.
cond_net.requires_grad_(False)
cond_net.eval()

In [None]:
tqdm._instances.clear()

# Number of optimisation steps to aim for. The loop runs whole passes over
# train_shapes, so the actual count is rounded down to a multiple of len(train_shapes).
n_train_iter = 40001
# save a loss plot ten times over the run (at least every step for tiny runs)
plot_every = max(1, n_train_iter // 10)

loss_fn = torch.nn.MSELoss()
# Plain Python list: appending is O(1), whereas torch.cat on a growing tensor
# copies the whole history every step.
losses = []
train_iterator = tqdm(total=n_train_iter)

net.cache_dir = train_diff_folder
cond_net.cache_dir = train_diff_folder
# make sure net is in training mode even if an evaluation cell switched it to eval
net.train()

curr_iter = 0
for epoch in range(n_train_iter // len(train_shapes)):

    # Visit every shape once per epoch in random order, without reordering the
    # shared train_shapes list that other cells index into.
    for curr_idx in np.random.permutation(len(train_shapes)):

        ##############################################
        # Select a shape
        ##############################################
        train_shape = train_shapes[curr_idx]

        verts = train_shape['verts'].unsqueeze(0).to(device)
        # with a point-cloud Laplacian the networks get no connectivity
        faces = None if lapl_type == 'pcl' else train_shape['faces'].unsqueeze(0).to(device)

        evecs_orig = train_shape['evecs'][:, start_dim:start_dim + feature_dim].unsqueeze(0).to(device)

        ##############################################
        # Apply two independent random sign patterns
        ##############################################
        # random +1/-1 vectors of shape (1, feature_dim), broadcast over dim 1 of the evecs
        sign_gt_0 = (torch.randint(0, 2, (1, feature_dim), device=device) * 2 - 1).float()
        sign_gt_1 = (torch.randint(0, 2, (1, feature_dim), device=device) * 2 - 1).float()

        evecs_flip_0 = evecs_orig * sign_gt_0
        evecs_flip_1 = evecs_orig * sign_gt_1

        sign_pred_0 = predict_sign_change(net, verts, faces, evecs_flip_0,
                                          evecs_cond=None, input_type=input_type,
                                          cond_net=cond_net)[0]
        sign_pred_1 = predict_sign_change(net, verts, faces, evecs_flip_1,
                                          evecs_cond=None, input_type=input_type,
                                          cond_net=cond_net)[0]

        ##############################################
        # Loss: the product of the two predictions should reproduce
        # the relative sign flip between the two patterns
        ##############################################
        sign_diff_gt = sign_gt_1 * sign_gt_0
        sign_diff_pred = sign_pred_1 * sign_pred_0

        loss = loss_fn(
            sign_diff_pred.reshape(sign_diff_pred.shape[0], -1),
            sign_diff_gt.reshape(sign_diff_gt.shape[0], -1),
        )

        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()

        losses.append(loss.item())
        # mean of the last 10 losses
        train_iterator.set_description(f'loss={np.mean(losses[-10:]):.3f}')

        if curr_iter > 0 and curr_iter % plot_every == 0:
            fig, ax = plt.subplots()
            pd.Series(losses).rolling(10).mean().plot(ax=ax)
            ax.set_yscale('log')
            ax.set(title='Training loss', xlabel='iteration', ylabel='loss (rolling mean of 10)')
            fig.savefig(f'{experiment_dir}/losses_{curr_iter}.png')
            plt.close(fig)

        curr_iter += 1
        train_iterator.update(1)

train_iterator.close()

# save model checkpoint, named by the number of steps taken
torch.save(
    net.state_dict(),
    f'{experiment_dir}/{curr_iter}.pth')

In [None]:
from tqdm import tqdm
import utils.geometry_util as geometry_util
import robust_laplacian
import scipy.sparse.linalg as sla
import utils.geometry_util as geometry_util

tqdm._instances.clear()

shapes_to_test = test_shapes
net.cache_dir = test_diff_folder
cond_net.cache_dir = test_diff_folder

# shapes_to_test = train_shapes
# net.cache_dir = train_diff_folder

# Number of evaluations to aim for. Rounded down to whole passes over shapes_to_test,
# but at least one pass so the summary below always has data.
n_test_iter = 1000
n_epochs = max(1, n_test_iter // len(shapes_to_test))

iterator = tqdm(total=n_epochs * len(shapes_to_test))
incorrect_counts = []

# Evaluate in inference mode (e.g. no dropout). The previous modes are restored
# afterwards so that a later re-run of the training cell still trains.
net_was_training, cond_net_was_training = net.training, cond_net.training
net.eval()
cond_net.eval()
try:
    with torch.no_grad():
        for epoch in range(n_epochs):
            for curr_idx in range(len(shapes_to_test)):

                ##############################################
                # Select a shape
                ##############################################
                test_shape = shapes_to_test[curr_idx]

                verts = test_shape['verts'].unsqueeze(0).to(device)
                # same connectivity handling as the training loop
                faces = None if lapl_type == 'pcl' else test_shape['faces'].unsqueeze(0).to(device)
                evecs_orig = test_shape['evecs'][:, start_dim:start_dim + feature_dim].unsqueeze(0).to(device)

                ##############################################
                # Apply two independent random sign patterns
                ##############################################
                # random +1/-1 vectors of shape (1, feature_dim)
                sign_gt_0 = (torch.randint(0, 2, (1, feature_dim), device=device) * 2 - 1).float()
                sign_gt_1 = (torch.randint(0, 2, (1, feature_dim), device=device) * 2 - 1).float()

                evecs_flip_0 = evecs_orig * sign_gt_0
                evecs_flip_1 = evecs_orig * sign_gt_1

                sign_pred_0, supp_vec_0, _ = predict_sign_change(
                    net, verts, faces, evecs_flip_0, evecs_cond=None,
                    input_type=input_type, cond_net=cond_net)
                sign_pred_1, supp_vec_1, _ = predict_sign_change(
                    net, verts, faces, evecs_flip_1, evecs_cond=None,
                    input_type=input_type, cond_net=cond_net)

                ##############################################
                # Compare predicted and true relative signs
                ##############################################
                sign_diff_gt = sign_gt_1 * sign_gt_0
                sign_diff_pred = sign_pred_1 * sign_pred_0

                # +1 where the predicted relative sign matches the ground truth,
                # -1 where it does not, 0 where the prediction is exactly zero
                sign_correct = sign_diff_pred.sign() * sign_diff_gt.sign()

                # Anything other than +1 (including an undecided 0) counts as
                # incorrect, matching the per-sample summary printed below.
                count_incorrect_signs = int((sign_correct != 1).sum().item())
                incorrect_counts.append(count_incorrect_signs)

                iterator.set_description(
                    f'Mean incorrect signs {np.mean(incorrect_counts):.2f} / {feature_dim}')
                iterator.update(1)
finally:
    net.train(net_was_training)
    cond_net.train(cond_net_was_training)
    iterator.close()

incorrect_signs_list = torch.tensor(incorrect_counts, dtype=torch.float)

print(f'Results for {len(incorrect_signs_list)} test shapes')
print(f'Incorrect signs per shape: {incorrect_signs_list.mean():.2f} / {feature_dim}')

print('Max incorrect signs', incorrect_signs_list.max())

print()
# details of the last evaluated sample
print('GT', sign_diff_gt)
print('PRED', sign_diff_pred)
print('Correct', sign_correct)
print(f'Incorrect signs {torch.sum(sign_correct != 1)} / {feature_dim}')
print(incorrect_signs_list)


# plt.plot(support_vector_norm.squeeze().detach().cpu().numpy(), '.', alpha=0.1)
# plt.ylim(-0.1, 0.1)
# # plt.yscale('log')
# plt.show()