In [None]:
import os
import networks.diffusion_network as diffusion_network
from tqdm import tqdm
import my_code.sign_canonicalization.training as sign_training
import my_code.sign_canonicalization.remesh as remesh
import torch
import my_code.diffusion_training_sign_corr.data_loading as data_loading
import yaml
import my_code.datasets.preprocessing as preprocessing
import trimesh
import argparse
import utils.fmap_util as fmap_util
import numpy as np
    

In [None]:
# Experiment to evaluate. Its directory must contain config.yaml and numbered
# checkpoints named `<iteration>.pth`. The experiments root can be overridden
# with the SIGN_NET_EXP_ROOT environment variable (default: the original path).
exp_name = 'signNet_remeshed_mass_6b_1ev_10_0.2_0.8'

exp_root = os.environ.get(
    'SIGN_NET_EXP_ROOT',
    '/home/s94zalek_hpc/shape_matching/my_code/experiments/sign_net',
)
exp_dir = f'{exp_root}/{exp_name}'

# yaml.FullLoader is only acceptable because config.yaml is a local experiment
# file; do not point this at untrusted input.
with open(f'{exp_dir}/config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

start_dim = config['start_dim']
feature_dim = config['feature_dim']
evecs_per_support = config['evecs_per_support']

# Number of random landmarks test_on_dataset uses per run.
n_landmarks = 1

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = diffusion_network.DiffusionNet(
    **config['net_params']
    ).to(device)

input_type = config['net_params']['input_type']

log_file = f'{exp_dir}/log_landmarks.txt'
os.makedirs(os.path.dirname(log_file), exist_ok=True)

# (dataset name, split) pairs to evaluate. Add (config["train_folder"], 'train')
# to also evaluate on the training set.
dataset_list = [
    ('FAUST_a', 'test'),
    ('SHREC19_r', 'test'),
    ('FAUST_r', 'test'),
    ('SCAPE_r_pair', 'test'),
    ('SCAPE_a_pair', 'test'),
]

# Checkpoints are saved as `<iteration>.pth`; skip any other .pth file
# (e.g. `best.pth`) instead of crashing on int().
checkpoint_files = sorted(
    int(fname[:-len('.pth')])
    for fname in os.listdir(exp_dir)
    if fname.endswith('.pth') and fname[:-len('.pth')].isdigit()
)
if not checkpoint_files:
    raise FileNotFoundError(f'No `<iteration>.pth` checkpoints found in {exp_dir}')

last_checkpoint = checkpoint_files[-1]

# Iterations to evaluate; e.g. [200, 600, 1000, 1400, 2000] for a sweep.
checkpoints_to_evaluate = [last_checkpoint]

# NOTE: test_on_dataset and predict_sign_change_landmarks are defined in later
# cells; run those cells before this loop.
# NOTE(review): test_on_dataset does not use `net`, so the loaded weights do not
# influence these numbers.
results = {}
for n_iter in checkpoints_to_evaluate:
    # map_location lets GPU-saved checkpoints load on CPU-only machines.
    net.load_state_dict(torch.load(f'{exp_dir}/{n_iter}.pth', map_location=device))

    for dataset_name, split in dataset_list:
        test_dataset_curr = data_loading.get_val_dataset(
            dataset_name, split, 128, canonicalize_fmap=None, preload=False, return_evecs=True, centering='mean'
            )[0]

        # test_on_dataset takes feature_dim / n_landmarks, not the whole config.
        mean_incorrect_signs, max_incorrect_signs = test_on_dataset(
            test_dataset_curr, n_epochs=100,
            feature_dim=feature_dim, n_landmarks=n_landmarks)

        results[(n_iter, dataset_name, split)] = (
            float(mean_incorrect_signs), float(max_incorrect_signs))
        print(f'iter {n_iter} {dataset_name}/{split}: mean incorrect signs '
              f'{float(mean_incorrect_signs):.2f} / {feature_dim}, '
              f'max {float(max_incorrect_signs):.0f}')
            

In [None]:
def predict_sign_change_landmarks(evecs_flip, landmarks_vector):
    """Score each eigenvector column against the landmark indicator columns.

    Both inputs are L2-normalised along dim 1 (per column). The score for column i
    is the inner product of the i-th normalised landmark column with the i-th
    normalised eigenvector column. Callers take ``.sign()`` of it. Flipping the
    sign of an eigenvector column flips the sign of its score.

    Args:
        evecs_flip: 3-D tensor, assumed [batch x n_vertices x n_evecs] (TODO confirm).
        landmarks_vector: 3-D tensor with the same number of columns as
            ``evecs_flip`` (enforced by the square-product assertion).

    Returns:
        (scores, None, product), where:
            product: full [batch x n_evecs x n_evecs] matrix of column inner products.
            scores: its diagonal, shape [batch x n_evecs]. These are raw inner
                products, not signs.
    """
    assert evecs_flip.dim() == 3
    assert landmarks_vector.dim() == 3

    unit_columns = torch.nn.functional.normalize
    evecs_unit = unit_columns(evecs_flip, p=2, dim=1)
    landmarks_unit = unit_columns(landmarks_vector, p=2, dim=1)

    # [B x F x N] @ [B x N x F] -> [B x F x F]; only the diagonal is the score,
    # but the full matrix is returned for inspection.
    projection = torch.matmul(landmarks_unit.permute(0, 2, 1), evecs_unit)
    assert projection.shape[1] == projection.shape[2]

    scores = projection.diagonal(dim1=1, dim2=2)
    return scores, None, projection


In [None]:

def test_on_dataset(test_dataset, n_epochs, feature_dim, n_landmarks, seed=None):
    """Count how often landmark-based sign prediction disagrees with the ground truth.

    Landmark vertices are sampled once on the template and mapped to each shape via
    its 'corr' entry. For every shape and epoch:
        1. The first ``feature_dim`` eigenvectors are multiplied by two independent
           random +1/-1 sign vectors.
        2. The sign of each copy is scored with ``predict_sign_change_landmarks``.
        3. The predicted sign difference is compared with the ground-truth one.

    Args:
        test_dataset: indexable sequence of dicts with 'evecs' (tensor; columns are
            sliced to ``feature_dim``) and 'corr' (tensor indexed by template vertex
            ids, whose values are used as vertex indices into 'evecs').
            Every 'corr' must have the same shape.
        n_epochs: number of passes over the dataset (>= 1).
        feature_dim: number of leading eigenvectors whose signs are tested.
        n_landmarks: number of template vertices sampled (without replacement).
        seed: optional int seeding the landmark choice and the random sign vectors.
            None keeps using the global numpy / torch RNGs.

    Returns:
        (mean, max) number of incorrect signs per shape, as 0-d float32 tensors.

    Raises:
        ValueError: if ``test_dataset`` is empty or ``n_epochs < 1``.
    """
    if len(test_dataset) == 0 or n_epochs < 1:
        raise ValueError('test_dataset must be non-empty and n_epochs must be >= 1')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Optional reproducibility; without a seed the global RNGs are used as before.
    np_rng = np.random if seed is None else np.random.default_rng(seed)
    torch_gen = None if seed is None else torch.Generator().manual_seed(seed)

    corr_shape = test_dataset[0]['corr'].shape

    # Landmarks are chosen once on the template and reused for every shape.
    landmarks_on_template = torch.tensor(
        np_rng.choice(corr_shape[0], n_landmarks, replace=False))

    def _random_signs():
        # Random +1/-1 vector of length feature_dim, shaped [1 x feature_dim].
        signs = torch.randint(0, 2, (feature_dim,), generator=torch_gen).float() * 2 - 1
        return signs.unsqueeze(0).to(device)

    incorrect_counts = []
    running_sum = 0
    running_max = 0

    # Drop progress bars left behind by interrupted runs in this kernel.
    tqdm._instances.clear()
    with tqdm(total=len(test_dataset) * n_epochs) as iterator:
        for _ in range(n_epochs):
            for curr_idx in range(len(test_dataset)):
                train_shape = test_dataset[curr_idx]
                assert train_shape['corr'].shape == corr_shape

                evecs_orig = train_shape['evecs'].unsqueeze(0)[:, :, :feature_dim].to(device)

                # Indicator vector: 1 at the landmark vertices of this shape, 0 elsewhere,
                # repeated in every column. NOTE(review): if 'corr' can contain -1 for
                # missing matches, that would mark the last vertex — confirm per dataset.
                landmarks_on_shape = train_shape['corr'][landmarks_on_template]
                landmarks_vector = torch.zeros_like(evecs_orig)
                landmarks_vector[:, landmarks_on_shape] = 1

                sign_gt_0 = _random_signs()
                sign_pred_0, _, _ = predict_sign_change_landmarks(
                    evecs_orig * sign_gt_0, landmarks_vector)

                sign_gt_1 = _random_signs()
                sign_pred_1, _, _ = predict_sign_change_landmarks(
                    evecs_orig * sign_gt_1, landmarks_vector)

                sign_diff_gt = sign_gt_1 * sign_gt_0
                sign_diff_pred = sign_pred_1 * sign_pred_0
                sign_correct = sign_diff_pred.sign() * sign_diff_gt.sign()

                # Both copies come from the same evecs_orig, so
                # sign_diff_pred == sign_diff_gt * c**2 and the only failure is c == 0.
                # A zero score means the sign could not be determined, so it is
                # counted as incorrect (with `< 0` it was silently counted as correct).
                count_incorrect_signs = int((sign_correct <= 0).sum().item())

                incorrect_counts.append(count_incorrect_signs)
                running_sum += count_incorrect_signs
                running_max = max(running_max, count_incorrect_signs)

                iterator.set_description(
                    f'Mean incorrect signs {running_sum / len(incorrect_counts):.2f} / {feature_dim}, max {running_max}')
                iterator.update(1)

    counts = torch.tensor(incorrect_counts, dtype=torch.float32)
    return counts.mean(), counts.max()


In [None]:
# Load the FAUST_r test split and materialise every sample into a plain list,
# presumably so the repeated passes in test_on_dataset do not redo per-item
# loading (preload=False) — confirm against data_loading.
faust_r_test = data_loading.get_val_dataset(
    'FAUST_r', 'test', 128,
    canonicalize_fmap=None, preload=False, return_evecs=True, centering='mean',
)[0]
test_dataset_curr = [faust_r_test[sample_idx] for sample_idx in range(len(faust_r_test))]


In [None]:
# Landmark sign test on the materialised FAUST_r samples:
# a single landmark, the first 96 eigenvectors, 100 passes over the data.
mean_incorrect_signs, max_incorrect_signs = test_on_dataset(
    test_dataset_curr,
    n_landmarks=1,
    feature_dim=96,
    n_epochs=100,
)